/***************************************************************************************************
*             C# sample for the usage of Episodic TopoART (class Fast_Episodic_TopoART)            *
****************************************************************************************************
*                             Created by Marko Tscherepanow, 27 May 2016                           *
***************************************************************************************************/

// Compile and run from the console: dotnet run --project Episodic_TopoART_sample2.csproj

using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
using System.Runtime.InteropServices;
using LibTopoART;

namespace LibTopoART_samples
{
	/// <summary>
	/// Episodic clustering sample using real-world video data. [C#]
	/// <para>
	/// Like in Section 4.2 of "Marko Tscherepanow, Sina Kühnel, and Sören Riechers (2012). Episodic
	/// Clustering of Data Streams Using a Topology-Learning Neural Network. In Proceedings of the European 
	/// Conference on Artificial Intelligence (ECAI), Workshop on Active and Incremental Learning (AIL), 
	/// pp. 24-29. Montpellier, France.", an Episodic TopoART network is trained with real-world video data. 
	/// Each image has a size of 64x36 pixels. As each pixel comprises 3 color channels (RGB), the input 
	/// length equals 6912. After finishing training, recall is performed for a single input stimulus.
	/// </para>
	/// <para>The recall results can be visualised using the script <c>ShowEpisodicTopoARTRecallResults</c> 
	/// provided for R and MATLAB in the subfolder <c>visualisation</c>.
	/// </para>
	/// </summary>
	class Episodic_TopoART_sample2
	{
		// Dataset (containing training and test images); relative to the executable's directory
		private const string DatasetPath	=	"../../../../../data/AIL12-like_video_dataset/";
		// Destination directory for trained networks
		private const string NetworkPath	=	"../../../../../results/networks/";
		// File name used both for saving and for loading the trained network, so the two always match
		private const string NetworkFileName	=	"Fast_Episodic_TopoART_AIL12-like_video_dataset.feta";
		// Exclusive upper bound of the training image indices
		private const long SampleNumber		=	59523;

		// Recall stimulus from the test set (only odd images allowed)
		private const string RecallStimulus	=	"image_54971.jpg";
		// Destination directory for recall results
		private const string RecallPath		=	"../../../../../results/recall/AIL12-like_video_dataset_recall_results/";
		// Recall control parameters
		private const long MaxInterEpisodeRecallSteps			=	15;
		private const long MaxIntraEpisodeRecallSteps			=	50;
		private const decimal MinInterEpisodeRecallActivation	=	0.5m;

		/// <summary>
		/// Entry point: trains (or loads) a Fast_Episodic_TopoART network and performs recall for
		/// <see cref="RecallStimulus"/>, printing the time required for training and recall.
		/// </summary>
		private static void Main()
		{
			// Set to true to save the trained network
			var saveNetwork = false;

			// Set to true to omit training and use a saved network file
			var useSavedNetwork = false;

			// Resolve the relative paths above against the directory containing the executable.
			// AppContext.BaseDirectory also works for single-file deployments, where Assembly.Location is empty.
			Directory.SetCurrentDirectory(AppContext.BaseDirectory);

			Fast_Episodic_TopoART feta;
			if(!useSavedNetwork) {
				// Load dataset and train a new network
				feta = TrainNetwork(LoadTrainingImages());

				if(saveNetwork) {
					Console.WriteLine("Save network");
					Directory.CreateDirectory(NetworkPath);

					// Save network in human-readable form
					// feta.SaveText(NetworkPath + "Fast_Episodic_TopoART_AIL12-like_video_dataset.txt");

					// Save network in binary form
					feta.Save(NetworkPath + NetworkFileName);
				}
			}
			else {
				// Load the network previously written by feta.Save above
				feta = new Fast_Episodic_TopoART(NetworkPath + NetworkFileName);
			}

			var recallTimer = System.Diagnostics.Stopwatch.StartNew();
			PerformRecall(feta);
			recallTimer.Stop();

			Console.WriteLine("Time for recall: " + recallTimer.Elapsed);
		}

		/// <summary>
		/// Loads the training images train/image_00000.jpg, train/image_00002.jpg, ... (every even index
		/// below <see cref="SampleNumber"/>), printing a progress dot every 1000 indices.
		/// </summary>
		/// <returns>The images in index order, each as interleaved RGB bytes.</returns>
		private static List<byte[]> LoadTrainingImages()
		{
			Console.WriteLine("Load training images");

			var images = new List<byte[]>((int)((SampleNumber + 1) / 2));
			for(var index = 0; index < SampleNumber; index += 2) {
				LoadImage(DatasetPath + "train/image_" + index.ToString("D5") + ".jpg", out _, out _, out var image);
				images.Add(image);
				if((index % 1000) == 0) {
					Console.Write(".");
				}
			}
			Console.WriteLine("");

			return images;
		}

		/// <summary>
		/// Creates a Fast_Episodic_TopoART network with this sample's parameters and presents all images to it
		/// once, in list order. Prints a progress dot every 500 images and the total training time.
		/// </summary>
		/// <param name="images">Training images; must contain at least one element (its length is the input length).</param>
		/// <returns>The trained network.</returns>
		private static Fast_Episodic_TopoART TrainNetwork(List<byte[]> images)
		{
			var trainingTimer = System.Diagnostics.Stopwatch.StartNew();

			// NOTE(review): the constructor arguments after the input length (2, 0.7m, 400) are passed through
			// unchanged from the original sample; see the LibTopoART documentation for their meaning.
			var feta = new Fast_Episodic_TopoART(images[0].LongLength, 2, 0.7m, 400) {
				Phi = 5,
				Tau = 200,
				Beta_sbm = 0.25m
			};

			for(var index = 0; index < images.Count; ++index) {
				feta.Learn(images[index]);
				if((index % 500) == 0) {
					Console.Write(".");
				}
			}
			Console.WriteLine("");

			trainingTimer.Stop();
			Console.WriteLine("Time for training: " + trainingTimer.Elapsed);

			return feta;
		}

		/// <summary>
		/// Performs recall for <see cref="RecallStimulus"/> and writes the results to <see cref="RecallPath"/>:
		/// the stimulus is copied to stimulus.jpg, the i-th inter-episode recall result is saved as
		/// inter_image_iii.jpg and the intra-episode results that follow it as intra_image_iii_jjj.jpg.
		/// </summary>
		/// <param name="feta">The trained network.</param>
		private static void PerformRecall(Fast_Episodic_TopoART feta)
		{
			// File.Copy / JPEG saving fail if the destination folder does not exist yet
			Directory.CreateDirectory(RecallPath);

			var stimulusPath = DatasetPath + "test/" + RecallStimulus;

			// Copy stimulus
			File.Copy(stimulusPath, RecallPath + "stimulus.jpg", true);

			// Load recall stimulus; its size is also used for all saved recall results
			LoadImage(stimulusPath, out var width, out var height, out var image);

			feta.BeginRecall(image);
			feta.InterEpisodeRecallStep(out byte[] recallResult, out decimal activation);

			// Inter-episode recall loop
			int i;
			for(i = 1; (recallResult != null) && (activation > MinInterEpisodeRecallActivation) && (i <= MaxInterEpisodeRecallSteps); ++i) {
				SaveImage(RecallPath + "inter_image_" + i.ToString("D3") + ".jpg", width, height, recallResult);

				// Intra-episode recall loop (stops early when the network returns no further result)
				int j;
				for(j = 1; j <= MaxIntraEpisodeRecallSteps; ++j) {
					feta.IntraEpisodeRecallStep(out recallResult);
					if(recallResult == null) {
						break;
					}
					SaveImage(RecallPath + "intra_image_" + i.ToString("D3") + "_" + j.ToString("D3") + ".jpg", width, height, recallResult);
				}

				// Signify stopped intra-episode recall: remove a file with the next index (presumably left over
				// from an earlier run) so the saved sequence ends where this recall stopped.
				var nextIntraImage = RecallPath + "intra_image_" + i.ToString("D3") + "_" + j.ToString("D3") + ".jpg";
				if(File.Exists(nextIntraImage)) {
					File.Delete(nextIntraImage);
				}

				feta.InterEpisodeRecallStep(out recallResult, out activation);
			}

			// Signify stopped inter-episode recall (same idea as above)
			var nextInterImage = RecallPath + "inter_image_" + i.ToString("D3") + ".jpg";
			if(File.Exists(nextInterImage)) {
				File.Delete(nextInterImage);
			}

			feta.EndRecall();
		}

		/// <summary>
		/// Loads an image file and converts it to interleaved RGB bytes (3 bytes per pixel, row by row).
		/// </summary>
		/// <param name="path">Image file to load.</param>
		/// <param name="width">Receives the image width in pixels.</param>
		/// <param name="height">Receives the image height in pixels.</param>
		/// <param name="imageArray">Receives width * height * 3 bytes of RGB pixel data.</param>
		private static void LoadImage(string path, out int width, out int height, out byte[] imageArray)
		{
			using var image = Image.Load<Rgb24>(path);
			width = image.Width;
			height = image.Height;

			var pixels = new Rgb24[width * height];
			image.CopyPixelDataTo(pixels);
			imageArray = MemoryMarshal.AsBytes(pixels.AsSpan()).ToArray();
		}

		/// <summary>
		/// Saves interleaved RGB bytes (the layout produced by <see cref="LoadImage"/>) as a JPEG file.
		/// </summary>
		/// <param name="path">Destination file.</param>
		/// <param name="width">Image width in pixels.</param>
		/// <param name="height">Image height in pixels.</param>
		/// <param name="imageArray">RGB pixel data; must contain at least width * height * 3 bytes.</param>
		private static void SaveImage(string path, int width, int height, byte[] imageArray)
		{
			using var image = Image.LoadPixelData<Rgb24>(imageArray, width, height);
			image.SaveAsJpeg(path);
		}
	}
}