classdef EpisodicTopoARTRecallViewer < handle
    % EpisodicTopoARTRecallViewer  Plots a stimulus image together with the
    % recall image sequences stored in a folder and saves the figure as PDF.
    %
    % Files read from the folder:
    %   stimulus.jpg             - the stimulus image
    %   inter_image_III.jpg      - inter-episode recall images, III = 001, 002, ...
    %   intra_image_III_JJJ.jpg  - intra-episode recall images belonging to
    %                              inter image III, JJJ = 001, 002, ...
    %
    % The class relies on the helper functions ifelse and DrawArrow, which
    % are not defined in this file.

    properties (Access = protected)
        % Folder containing the images (with or without trailing separator)
        path char

        % Colours (RGB triplets in [0, 1])
        lightBlue  = [240/255 240/255 1];
        darkBlue   = [112/255 112/255 191/255];
        lightGreen = [240/255 1 240/255];
        darkGreen  = [112/255 191/255 112/255];
        lightRed   = [1 240/255 240/255];
        darkRed    = [191/255 112/255 112/255];
        black      = [0 0 0];

        % Large arrow parameters
        lArrowLength = 8;

        % Image parameters
        imageDist = 20;   % gap between neighbouring images (axes data units)
        imageBox  = 5;    % margin between an image and its surrounding box

        % Small arrow parameters
        sArrowLength = 6;

        % Box parameters
        curvature = 0.05;
        refSize   = 100;  % box size at which 'curvature' applies unscaled

        % Text parameters
        stimulusTextYOff  = 19;
        recallTextMinSize = 120;  % minimum box extent required to print a recall label
        fontSize          = 7;

        % ... parameters (ellipsis shown when images are omitted)
        pointsXOff   = 30;
        pointsYOff   = 20;
        pointsString = '...';

        % Plot parameters (paper margins in inches)
        plotXMargin = 0.1;
        plotYMargin = 0.1;
    end

    methods (Access = protected)
        function filename = InterImageFile(rv, interIdx)
            % Returns the full file name of inter-episode image number interIdx.
            % fullfile is used so that rv.path works with or without a
            % trailing file separator.
            filename = fullfile(rv.path, sprintf('inter_image_%03i.jpg', interIdx));
        end

        function filename = IntraImageFile(rv, interIdx, intraIdx)
            % Returns the full file name of intra-episode image number
            % intraIdx that belongs to inter-episode image number interIdx.
            filename = fullfile(rv.path, sprintf('intra_image_%03i_%03i.jpg', interIdx, intraIdx));
        end

        function [interNum, intraNumArray, maxIntraNum] = CountImages(rv)
            % Counts the consecutively numbered recall images in rv.path.
            %
            % Returns:
            %   interNum      - number of consecutive inter_image_III.jpg files,
            %                   starting at 001
            %   intraNumArray - row vector; element i is the number of consecutive
            %                   intra_image_III_JJJ.jpg files for inter image i
            %   maxIntraNum   - maximum of intraNumArray (0 if no images exist)
            %
            % Counting stops at the first missing index, so files after a gap
            % in the numbering are ignored.
            arguments
                rv EpisodicTopoARTRecallViewer
            end

            interNum = 0;
            intraNumArray = [];
            maxIntraNum = 0;

            while isfile(rv.InterImageFile(interNum + 1))
                intraNum = 0;

                while isfile(rv.IntraImageFile(interNum + 1, intraNum + 1))
                    intraNum = intraNum + 1;
                end

                intraNumArray(end + 1) = intraNum; %#ok<AGROW>
                maxIntraNum = max(intraNum, maxIntraNum);
                interNum = interNum + 1;
            end
        end

        function ShowInterEpisodeRecallResults(rv, interNum, plotInterNum, xOff, yOff, lArrowOff, recallTextOff, ...
                stimulusSize, plotPYOff, plotYShift, plotWidth, plotHeight)
            % Draws the column of inter-episode recall images into the current axes.
            %
            % Parameters:
            %   interNum      - number of available inter-episode images
            %   plotInterNum  - number of inter-episode images to show
            %   xOff, yOff    - offsets of the image grid inside the plot
            %   lArrowOff     - space reserved for the large arrow
            %   recallTextOff - space reserved for the recall label
            %   stimulusSize  - size(stimulus); (1) is the height, (2) the width
            %   plotPYOff     - extra vertical space reserved for the '...' marker
            %   plotYShift    - y coordinate from which all rows are measured downwards
            %   plotWidth, plotHeight - plot extent; used to normalise the
            %                   coordinates handed to DrawArrow
            arguments
                rv EpisodicTopoARTRecallViewer
                interNum {mustBeInteger, mustBeNonnegative}
                plotInterNum {mustBeInteger, mustBeNonnegative}
                xOff {mustBeInteger, mustBeNonnegative}
                yOff {mustBeInteger, mustBeNonnegative}
                lArrowOff {mustBeInteger, mustBeNonnegative}
                recallTextOff {mustBeInteger, mustBeNonnegative}
                stimulusSize {mustBeInteger, mustBePositive}
                plotPYOff {mustBeInteger, mustBeNonnegative}
                plotYShift {mustBeInteger, mustBeNonnegative}
                plotWidth {mustBeInteger, mustBePositive}
                plotHeight {mustBeInteger, mustBePositive}
            end

            if plotInterNum == 0
                return
            end

            % Lower (1) and upper (2) edge of the box enclosing all inter images
            interYPos1    = plotYShift - (plotInterNum + 1) * stimulusSize(1) - plotInterNum * rv.imageDist - yOff - rv.imageBox - plotPYOff;
            interYPos2    = plotYShift - stimulusSize(1) - rv.imageDist - yOff + rv.imageBox;
            interTextXPos = xOff - rv.imageBox - lArrowOff - recallTextOff / 2 + rv.fontSize / 2 - 1;
            interTextYPos = interYPos1 - (interYPos1 - interYPos2) / 2;

            % The label is only printed if the box is tall enough for it
            if interYPos2 - interYPos1 >= rv.recallTextMinSize
                text(interTextXPos, interTextYPos, 'inter-episode recall', 'Color', rv.darkRed, 'FontSize', ...
                    rv.fontSize, 'Rotation', 90, 'HorizontalAlignment', 'Center');
            end

            if interYPos2 > interYPos1
                w = stimulusSize(2) + 2 * rv.imageBox;
                h = interYPos2 - interYPos1;
                % Scale the curvature by the box size so that corners look
                % alike for boxes of different extent
                curvatureW = rv.curvature * rv.refSize / w;
                curvatureH = rv.curvature * rv.refSize / h;
                rectangle('Position', [xOff - rv.imageBox, interYPos1, w, h], 'FaceColor', rv.lightRed, ...
                    'EdgeColor', rv.darkRed, 'Curvature', [curvatureW, curvatureH], 'LineWidth', 1);

                % Large arrow beside the box, from its upper to its lower edge
                interArrowXPos = xOff - rv.imageBox - lArrowOff / 2;
                DrawArrow([interArrowXPos interArrowXPos] / plotWidth, [interYPos2 interYPos1] / plotHeight, rv.darkRed)
            end

            % Draw small arrows (one more if a '...' marker follows the last image)
            arrowXPos = xOff + stimulusSize(2) / 2;
            for i = 1 : ifelse(plotInterNum < interNum, plotInterNum + 1, plotInterNum)
                gapTop = plotYShift - i * stimulusSize(1) - (i - 1) * rv.imageDist - yOff;
                DrawArrow([arrowXPos, arrowXPos] / plotWidth, ...
                    [gapTop - 0.5, gapTop - rv.imageDist + 1] / plotHeight, rv.black, 0.7)
            end

            % Show ... if not all available images are plotted
            if plotInterNum < interNum
                text(xOff + stimulusSize(2) / 2, plotYShift - (plotInterNum + 1) * (stimulusSize(1) + rv.imageDist) ...
                    - yOff - rv.pointsYOff / 2, rv.pointsString, 'Color', rv.black, 'FontSize', ...
                    rv.fontSize, 'HorizontalAlignment', 'Center');
            end

            % Load and show images (flipped vertically before being handed to image)
            for i = 1 : plotInterNum
                interImage = imread(rv.InterImageFile(i));
                image([xOff, xOff + stimulusSize(2)], [plotYShift - (i + 1) * stimulusSize(1) - i * rv.imageDist - yOff, ...
                    plotYShift - i * stimulusSize(1) - i * rv.imageDist - yOff], flip(interImage, 1));
            end
        end

        function ShowIntraEpisodeRecallResults(rv, plotInterNum, plotIntraNum, intraNumArray, xOff, yOff, lArrowOff, recallTextOff, ...
                stimulusSize, plotPXOff, plotYShift, plotWidth, plotHeight)
            % Draws one row of intra-episode recall images per plotted
            % inter-episode image into the current axes.
            %
            % Parameters:
            %   plotInterNum  - number of inter-episode images (rows) to show
            %   plotIntraNum  - maximum number of intra-episode images per row
            %   intraNumArray - available intra-episode images per inter image
            %   xOff, yOff    - offsets of the image grid inside the plot
            %   lArrowOff     - space reserved for the large arrow
            %   recallTextOff - space reserved for the recall label
            %   stimulusSize  - size(stimulus); (1) is the height, (2) the width
            %   plotPXOff     - extra horizontal space reserved for the '...' marker
            %   plotYShift    - y coordinate from which all rows are measured downwards
            %   plotWidth, plotHeight - plot extent; used to normalise the
            %                   coordinates handed to DrawArrow
            arguments
                rv EpisodicTopoARTRecallViewer
                plotInterNum {mustBeInteger, mustBeNonnegative}
                plotIntraNum {mustBeInteger, mustBeNonnegative}
                intraNumArray
                xOff {mustBeInteger, mustBeNonnegative}
                yOff {mustBeInteger, mustBeNonnegative}
                lArrowOff {mustBeInteger, mustBeNonnegative}
                recallTextOff {mustBeInteger, mustBeNonnegative}
                stimulusSize {mustBeInteger, mustBePositive}
                plotPXOff {mustBeInteger, mustBeNonnegative}
                plotYShift {mustBeInteger, mustBeNonnegative}
                plotWidth {mustBeInteger, mustBePositive}
                plotHeight {mustBeInteger, mustBePositive}
            end

            if plotInterNum == 0 || plotIntraNum == 0
                return
            end

            % Left (1) and right (2) edge of the widest possible row box
            intraXPos1    = stimulusSize(2) + rv.imageDist + xOff - rv.imageBox;
            intraXPos2    = (plotIntraNum + 1) * stimulusSize(2) + plotIntraNum * rv.imageDist + xOff + rv.imageBox + plotPXOff;
            intraTextXPos = intraXPos1 + (intraXPos2 - intraXPos1) / 2;
            intraTextYPos = plotYShift - yOff - stimulusSize(1) - rv.imageDist + rv.imageBox + lArrowOff + ...
                                recallTextOff / 2 - rv.fontSize / 2 + 1;

            % The label is only printed if the rows are wide enough for it
            if intraXPos2 - intraXPos1 >= rv.recallTextMinSize
                text(intraTextXPos, intraTextYPos, 'intra-episode recall', 'Color', rv.darkBlue, 'FontSize', ...
                    rv.fontSize, 'HorizontalAlignment', 'Center');
            end

            % Large arrow above the first row, from left to right
            if intraXPos2 > intraXPos1
                intraArrowYPos = plotYShift - yOff - stimulusSize(1) - rv.imageDist + rv.imageBox + lArrowOff / 2;
                DrawArrow([intraXPos1 intraXPos2] / plotWidth, [intraArrowYPos intraArrowYPos] / plotHeight, rv.darkBlue)
            end

            for i = 1 : plotInterNum
                % Number of intra images actually shown in row i
                stepIntraNum = min(intraNumArray(i), plotIntraNum);

                if stepIntraNum > 0
                    % Rows with omitted images are widened to hold the '...' marker
                    stepPlotPXoff = ifelse(stepIntraNum < intraNumArray(i), plotPXOff, 0);
                    rowXPos2      = (stepIntraNum + 1) * stimulusSize(2) + stepIntraNum * rv.imageDist + xOff + rv.imageBox + stepPlotPXoff;
                    intraYPos1    = plotYShift - yOff - (i + 1) * stimulusSize(1) - i * rv.imageDist - rv.imageBox;
                    intraYPos2    = plotYShift - yOff - i * (stimulusSize(1) + rv.imageDist) + rv.imageBox;

                    w = rowXPos2 - intraXPos1;
                    h = intraYPos2 - intraYPos1;
                    curvatureW = rv.curvature * rv.refSize / w;
                    curvatureH = rv.curvature * rv.refSize / h;
                    rectangle('Position', [intraXPos1, intraYPos1, w, h], 'FaceColor', rv.lightBlue, ...
                        'EdgeColor', rv.darkBlue, 'Curvature', [curvatureW, curvatureH], 'LineWidth', 1);
                end

                % Draw small arrows (one more if a '...' marker follows the last image)
                arrowYPos = plotYShift - i * (stimulusSize(1) + rv.imageDist) - stimulusSize(1) / 2 - yOff;
                for j = 1 : ifelse(plotIntraNum < intraNumArray(i), stepIntraNum + 1, stepIntraNum)
                    DrawArrow([xOff + j * stimulusSize(2) + (j - 1) * rv.imageDist + 0.5, ...
                        xOff + j * (stimulusSize(2) + rv.imageDist) - 1] / plotWidth, ...
                        [arrowYPos, arrowYPos] / plotHeight, rv.black, 0.7)
                end

                % Show ... if not all available images of this row are plotted
                if plotIntraNum < intraNumArray(i)
                    text(xOff + (plotIntraNum + 1) * (stimulusSize(2) + rv.imageDist) + rv.pointsXOff / 2, ...
                        plotYShift - i * (stimulusSize(1) + rv.imageDist) - stimulusSize(1) / 2 - yOff, ...
                        rv.pointsString, 'Color', rv.black, 'FontSize', rv.fontSize, 'HorizontalAlignment', 'Center');
                end

                % Load and show images (flipped vertically before being handed to image)
                for j = 1 : stepIntraNum
                    intraImage = imread(rv.IntraImageFile(i, j));
                    image([j * stimulusSize(2) + j * rv.imageDist + xOff, (j + 1) * stimulusSize(2) + j * rv.imageDist + xOff], ...
                        [plotYShift - (i + 1) * stimulusSize(1) - i * rv.imageDist - yOff, ...
                        plotYShift - i * stimulusSize(1) - i * rv.imageDist - yOff], flip(intraImage, 1));
                end
            end
        end

        function ShowStimulus(rv, stimulus, xOff, yOff, stimulusSize, plotYShift)
            % Draws the 'stimulus' label, a green box and the stimulus image
            % into the current axes.
            %
            % Parameters:
            %   stimulus     - image data as returned by imread
            %   xOff, yOff   - offsets of the image grid inside the plot
            %   stimulusSize - size(stimulus); (1) is the height, (2) the width
            %   plotYShift   - y coordinate from which the image is placed downwards
            arguments
                rv EpisodicTopoARTRecallViewer
                stimulus
                xOff {mustBeInteger, mustBeNonnegative}
                yOff {mustBeInteger, mustBeNonnegative}
                stimulusSize {mustBeInteger, mustBePositive}
                plotYShift {mustBeInteger, mustBeNonnegative}
            end

            text(xOff - rv.imageBox, plotYShift - yOff + rv.stimulusTextYOff - rv.fontSize / 2, 'stimulus', ...
                'Color', rv.darkGreen, 'FontSize', rv.fontSize);

            w = stimulusSize(2) + 2 * rv.imageBox;
            h = stimulusSize(1) + 2 * rv.imageBox;
            curvatureW = rv.curvature * rv.refSize / w;
            curvatureH = rv.curvature * rv.refSize / h;
            rectangle('Position', [xOff - rv.imageBox, plotYShift - stimulusSize(1) - yOff - rv.imageBox, w, h], ...
                'FaceColor', rv.lightGreen, 'EdgeColor', rv.darkGreen, 'Curvature', [curvatureW, curvatureH], ...
                'LineWidth', 1);

            image([xOff, xOff + stimulusSize(2)], [plotYShift - stimulusSize(1) - yOff, plotYShift - yOff], flip(stimulus, 1));
        end
    end

    methods
        function rv = EpisodicTopoARTRecallViewer(path)
            % Creates a viewer for the images stored in the folder 'path'.
            % The folder must exist; a trailing file separator is optional.
            arguments
                path char {mustBeFolder}
            end

            rv.path = path;
        end

        function PlotRecallResults(rv, maxInterSteps, maxIntraSteps)
            % Plots the stimulus and the recall images in a new figure and
            % saves the figure as ../images/ETA_recall.pdf relative to the
            % folder containing this class file.
            %
            % Parameters:
            %   maxInterSteps - maximum number of inter-episode images to show;
            %                   a negative value shows all available images,
            %                   0 shows the stimulus only
            %   maxIntraSteps - maximum number of intra-episode images per row;
            %                   a negative value shows all available images,
            %                   0 shows none
            arguments
                rv EpisodicTopoARTRecallViewer
                % Negative and zero values are handled explicitly below, so
                % they must pass validation
                maxInterSteps {mustBeInteger}
                maxIntraSteps {mustBeInteger}
            end

            % Count available images
            [interNum, intraNumArray, maxIntraNum] = rv.CountImages();

            % Incorporate maximum values
            plotInterNum = ifelse(maxInterSteps < 0, interNum, min(interNum, maxInterSteps));
            plotIntraNum = ifelse(maxInterSteps == 0, 0, ifelse(maxIntraSteps < 0, maxIntraNum, min(maxIntraNum, maxIntraSteps)));

            stimulus = imread(fullfile(rv.path, 'stimulus.jpg'));
            stimulusSize = size(stimulus);

            % Large arrow parameters
            lArrowOff = ifelse(interNum > 0, 20, 0);

            % Text parameters
            recallTextOff = ifelse(interNum > 0, rv.fontSize, 0);

            xOff = rv.imageBox + ifelse(maxInterSteps == 0, 0, lArrowOff + recallTextOff + 5);
            yOff = rv.imageBox + rv.stimulusTextYOff;

            % Plot parameters (extra space is reserved for '...' markers
            % whenever available images are omitted)
            plotPXOff  = ifelse(maxInterSteps == 0, 0, ifelse(plotIntraNum > 0 && plotIntraNum < maxIntraNum, rv.imageDist + rv.pointsXOff, 0));
            plotPYOff  = ifelse(plotInterNum > 0 && plotInterNum < interNum, rv.imageDist + rv.pointsYOff, 0);
            plotXSize  = stimulusSize(2) * (plotIntraNum + 1) + rv.imageDist * plotIntraNum + rv.imageBox + xOff + plotPXOff;
            plotYSize  = stimulusSize(1) * (plotInterNum + 1) + rv.imageDist * plotInterNum + rv.imageBox + yOff + plotPYOff + 1;
            plotYShift = plotYSize;

            % Axes fill the whole figure; axis limits equal the plot size
            f = figure('Name', 'Episodic TopoART: recall results');
            f.Position(3:4) = [plotXSize / 2, plotYSize / 2];
            set(gca, 'Position', [0 0 1 1]);
            set(gca, 'Visible', 'off')
            xlim([0, plotXSize])
            ylim([0, plotYSize])
            hold on;

            rv.ShowStimulus(stimulus, xOff, yOff, stimulusSize, plotYShift)
            rv.ShowInterEpisodeRecallResults(interNum, plotInterNum, xOff, yOff, lArrowOff, recallTextOff, ...
                stimulusSize, plotPYOff, plotYShift, plotXSize, plotYSize)
            rv.ShowIntraEpisodeRecallResults(plotInterNum, plotIntraNum, intraNumArray, xOff, yOff, lArrowOff, recallTextOff, ...
                stimulusSize, plotPXOff, plotYShift, plotXSize, plotYSize)

            % Paper size = figure size plus a margin on every side
            f.Units = 'inches';
            set(f, 'PaperPositionMode', 'auto', 'PaperUnits', 'inches', 'PaperPosition', ...
               [rv.plotXMargin, rv.plotYMargin, f.Position(3:4)], 'PaperSize', ...
               [f.Position(3) + rv.plotXMargin * 2, f.Position(4) + rv.plotYMargin * 2])

            % Save next to this class file; create the output folder if missing
            basePath  = fileparts(mfilename('fullpath'));
            outputDir = fullfile(basePath, '..', 'images');
            if ~isfolder(outputDir)
                mkdir(outputDir);
            end
            saveas(f, fullfile(outputDir, 'ETA_recall.pdf'));
        end
    end
end