(***************************************************************************************************
*              Extended F# sample for the usage of TopoART-AM (class Fast_TopoART_AM)              *
****************************************************************************************************
*                            Created by Marko Tscherepanow, 14 March 2020                          *
***************************************************************************************************)

// Compile and run from the console: dotnet run --project TopoART-AM_sample2.fsproj

/// <summary>
/// Learning of bidirectional associations between images. [F#]
/// <para>
/// Similar to Section 4.2 of "Marko Tscherepanow, Marco Kortkamp, and Marc Kammer (2011). A Hierarchical ART
/// Network for the Stable Incremental Learning of Topological Structures and Associations from Noisy Data. 
/// Neural Networks 24(8): 906-916. Elsevier.", a TopoART-AM network is trained with real-world image data.
/// The images are divided into two groups: owners and objects. TopoART-AM learns a bidirectional mapping
/// between the images of these two groups. Each image has a size of about 34,500 pixels. As each pixel comprises 3
/// color channels (RGB) and the input vector contains one key from each group, the total length of the input
/// vector exceeds 200,000. After training has finished, recall is performed for a single input stimulus from
/// each group. The recall results are saved in the folder <c>results/recall/ObjectsOwners_dataset_recall_results</c>.
/// </para>
///</summary>
module LibTopoART_samples.TopoART_AM_sample2

open LibTopoART
open System
open SixLabors.ImageSharp
open SixLabors.ImageSharp.PixelFormats
open System.Globalization
open System.IO
open System.Reflection
open System.Runtime.InteropServices

// Dataset (containing training and test images)
let datasetPath = "../../../../../data/ObjectsOwners_dataset/"

// Destination directory for the recall results
let resultPath = "../../../../../results/recall/ObjectsOwners_dataset_recall_results/"

// Destination directory for trained networks
let networkPath = "../../../../../results/networks/"

// Key sizes: each key is an RGB image flattened to width * height * 3 bytes.
// The products are computed in int64 so that larger images cannot overflow int32.
let key1Width = 214
let key1Height = 161
let key1Len = int64 key1Width * int64 key1Height * 3L // Size of key 1 (object images)
let key2Width = 172
let key2Height = 201
let key2Len = int64 key2Width * int64 key2Height * 3L // Size of key 2 (owner images)

// Application parameters
let objectNum = 20
let ownerNum = 6
let trainImageNum = 20
let testImageNum = 5
let maxTrainIterations = 25

// m to n mapping from owners to objects (some objects are shared by multiple owners)
let objectMap = [| [| 1; 2; 3; 9; 17; 19 |];             // objects of owner 1
                   [| 3; 4; 9; 15; 17 |];                // objects of owner 2
                   [| 1; 3; 7; 10; 11; 12; 16; 17 |];    // objects of owner 3
                   [| 3; 5; 17; 18; 20 |];               // objects of owner 4
                   [| 3; 8; 14; 19; 20 |];               // objects of owner 5
                   [| 3; 5; 10; 13; 16 |] |]             // objects of owner 6

// TopoART-AM parameters
let rho_a = 0.80m
let beta_sbm = 0.8m
let phi = 5L
// One fifth of the training dataset size, i.e. (number of entries in objectMap) * trainImageNum / 5.
// Derived from objectMap so it stays consistent when the mapping is edited (currently 34 * 20 / 5 = 136).
let tau = int64 (Array.sumBy Array.length objectMap * trainImageNum / 5)

(*--------------------------------------------------------------------------------------------------
-                                           Helper functions                                       -
--------------------------------------------------------------------------------------------------*)

/// Loads the image at `path` and returns its pixels as a flat byte array in row-major
/// order with 3 bytes (R, G, B) per pixel.
/// Only images whose size matches one of the two key sizes (key1Width x key1Height or
/// key2Width x key2Height) are accepted; for any other size a message is printed and
/// null is returned. Exceptions raised by Image.Load (e.g. a missing file) propagate.
let LoadSingleImage (path : string) =
    // Image<T> is IDisposable; `use` releases its pixel buffers once the data is copied out
    use image = Image.Load<Rgb24>(path)
    let width = image.Width
    let height = image.Height

    if (width = key1Width && height = key1Height) || (width = key2Width && height = key2Height) then
        let pixels = Array.zeroCreate<Rgb24> (width * height)
        image.CopyPixelDataTo(Span<Rgb24>(pixels))
        // Reinterpret the Rgb24 structs as their raw R, G, B bytes
        MemoryMarshal.AsBytes(Span<Rgb24>(pixels)).ToArray()
    else
        printf "loading %s failed\n" path
        null

/// Loads `num` images named image_01.jpg, image_02.jpg, ... from the directory
/// `datasetPath + subpath` and returns them in index order. Entries for images
/// rejected by LoadSingleImage are null.
let LoadImages num subpath =
    let commonPath = datasetPath + subpath + "image_"
    Array.init num (fun index ->
        let fileName = commonPath + (index + 1).ToString("D2") + ".jpg"
        LoadSingleImage fileName)

/// Loads `sets` image sets from the directories `subpath + "01/"`, `subpath + "02/"`, ...
/// (relative to datasetPath); each set contains trainImageNum images.
let LoadImageArray sets (subpath : string) =
    Array.init sets (fun setIndex ->
        let setPath = subpath + (setIndex + 1).ToString("D2") + "/"
        LoadImages trainImageNum setPath)

/// Builds the path `datasetPath + folder + NN + "/image_" + MM + ".jpg"` (NN = index,
/// MM = imageIndex, both as two-digit numbers), loads it with LoadSingleImage and returns
/// the pixel data (null if rejected) together with the path.
let private LoadTestImage (folder : string) (index : int) (imageIndex : int) =
    let path =
        datasetPath + folder + index.ToString("D2") + "/" + "image_" + imageIndex.ToString("D2") + ".jpg"
    (LoadSingleImage path, path)

/// Loads test image `image` of object `object`; returns (pixel data, path).
let LoadObjectImage (object : int) (image : int) =
    LoadTestImage "test/objects/object_" object image

/// Loads test image `image` of owner `owner`; returns (pixel data, path).
let LoadOwnerImage (owner : int) (image : int) =
    LoadTestImage "test/owners/owner_" owner image

/// Trains `tam` with the (object image, owner image) pairs defined by objectMap.
/// In every epoch, for each owner and each object mapped to it, image k of the object is
/// learned together with image k of the owner (k = 0 .. trainImageNum - 1); a dot is printed
/// after each owner. Epochs repeat until GetAdaptationState(0.001m) reports no permanent
/// adaptation or maxTrainIterations epochs have run. Returns the total number of Learn calls.
let Train (tam : IFast_TopoART_AM) (objects : byte[][][]) (owners : byte[][][]) =
    let mutable learnCalls = 0
    let mutable epoch = 0
    let mutable converged = false
    while (not converged) && epoch < maxTrainIterations do
        tam.ResetAdaptationState()
        epoch <- epoch + 1
        for ownerIdx in 0 .. objectMap.Length - 1 do
            for objectId in objectMap[ownerIdx] do
                for k in 0 .. trainImageNum - 1 do
                    tam.Learn(objects[objectId - 1][k], owners[ownerIdx][k])
                    learnCalls <- learnCalls + 1
            printf "."
        let permanentChanges =
            tam.GetAdaptationState(0.001m) &&& AdaptationState.ANY_PERMANENT_ADAPTATION_MASK
        converged <- (permanentChanges = AdaptationState.NO_ADAPTATION)
    printf "\n"
    learnCalls

/// File name of recalled image number `iteration` inside directory `path`:
/// `<path>recall_image_<NN>.jpg`, with the number zero-padded to at least two digits.
let RecallImageName (path : string) (iteration : int) =
    let number = iteration.ToString("D2")
    sprintf "%srecall_image_%s.jpg" path number

/// Repeatedly calls tam.RecallStep() and collects (iteration, recalled key, activation)
/// tuples, numbering iterations from 1. Collection stops at the first step that fails or
/// whose activation is below `min_F3_activation`; that step is not included.
/// The returned list is ordered newest-first (last accepted iteration at the head).
let RecallLoop min_F3_activation (tam : IFastAssociativeRecall) =
    let rec collect iteration acc =
        let (success, result : byte[], activation : decimal) = tam.RecallStep()
        if success && activation >= min_F3_activation then
            collect (iteration + 1) ((iteration, result, activation) :: acc)
        else
            acc
    collect 1 []

/// Writes the activations to `<path>activations.txt`, one value per line, formatted with
/// the invariant culture. An existing file is overwritten.
let SaveActivations (activations : decimal list) path =
    let fileName = path + "activations.txt"
    use writer = new StreamWriter(fileName)
    activations
    |> List.iter (fun value -> writer.WriteLine(value.ToString(CultureInfo.InvariantCulture)))

/// Converts a recalled key (R, G, B byte triples in row-major order) into a width x height
/// RGB image and saves it as JPEG under the name given by RecallImageName path iteration.
let SaveRecallImage (array : byte[]) width height iteration path =
    use image = new Image<Rgb24>(width, height)
    for offset in 0 .. 3 .. array.Length - 1 do
        // Byte offset 3*p belongs to pixel p
        let pixel = offset / 3
        image[pixel % width, pixel / width] <- Rgb24(array[offset], array[offset + 1], array[offset + 2])
    image.SaveAsJpeg(RecallImageName path iteration)

/// Saves every recalled key in `results` as an image in `path`, deletes leftover recall
/// images of an earlier run (consecutively numbered from results.Length + 1 upward), and
/// writes the activations to activations.txt in reverse order of `results`.
/// RecallLoop returns its results newest-first, so this yields ascending iteration order.
let SaveResults (results : (int * byte [] * decimal) list) width height path =
    for (iteration, result, _) in results do
        SaveRecallImage result width height iteration path

    // Remove stale images left behind by a previous run that recalled more keys
    let rec deleteStale number =
        let staleFile = RecallImageName path number
        if File.Exists(staleFile) then
            File.Delete(staleFile)
            deleteStale (number + 1)
    deleteStale (results.Length + 1)

    let activations = results |> List.rev |> List.map (fun (_, _, activation) -> activation)
    SaveActivations activations path

/// Shared driver for one recall run (used by RecallObjects and RecallOwners).
/// Copies the stimulus file `spath` to `<dpath>stimulus.jpg`, starts the recall via
/// `beginRecall`, collects results with RecallLoop until the F3 activation drops below
/// `min_F3_activation`, ends the recall, prints the elapsed time and saves the recalled
/// images (width x height) and their activations in `dpath`.
/// If the stimulus could not be loaded (null), the recall is skipped with a message.
let private RunRecall (keyName : string) (stimulus : byte[]) (spath : string) (dpath : string)
                      (beginRecall : byte[] -> unit) min_F3_activation
                      (tam : IFastAssociativeRecall) width height =
    if isNull stimulus then
        // LoadSingleImage rejected the image; feeding null into the network is pointless
        printf "Recall of %s skipped: stimulus %s could not be loaded\n" keyName spath
    else
        // The result directory may not exist yet (File.Copy does not create it)
        Directory.CreateDirectory(dpath) |> ignore

        // Copy stimulus
        File.Copy(spath, dpath + "stimulus.jpg", true)

        // Stopwatch is monotonic and high-resolution, unlike differences of DateTime.Now
        let stopwatch = System.Diagnostics.Stopwatch.StartNew()

        beginRecall stimulus
        // Always end the recall, even if a recall step throws
        let results =
            try
                RecallLoop min_F3_activation tam
            finally
                tam.EndRecall()

        stopwatch.Stop()
        printf "Time for recalling %s: %s\n" keyName <| stopwatch.Elapsed.ToString()

        // Save results
        SaveResults results width height dpath

/// Recalls key 1 (object images) from test image `image` of owner `owner` (key 2) and
/// stores the results in `resultPath + "key1/"`.
let RecallObjects owner image min_F3_activation (tam : IFastAssociativeRecall) =
    let stimulus, spath = LoadOwnerImage owner image
    RunRecall "key 1" stimulus spath (resultPath + "key1/")
              (fun key -> tam.BeginRecallKey1(key) |> ignore)
              min_F3_activation tam key1Width key1Height

/// Recalls key 2 (owner images) from test image `image` of object `object` (key 1) and
/// stores the results in `resultPath + "key2/"`.
let RecallOwners object image min_F3_activation (tam : IFastAssociativeRecall) =
    let stimulus, spath = LoadObjectImage object image
    RunRecall "key 2" stimulus spath (resultPath + "key2/")
              (fun key -> tam.BeginRecallKey2(key) |> ignore)
              min_F3_activation tam key2Width key2Height

(*--------------------------------------------------------------------------------------------------
-                                             Main program                                         -
--------------------------------------------------------------------------------------------------*)

[<EntryPoint>]
let main args =

    // Set the working directory to the application's base directory so that the relative
    // dataset/result paths resolve against the build output folder. AppContext.BaseDirectory
    // also works for single-file deployments, where Assembly.Location is an empty string.
    Directory.SetCurrentDirectory(AppContext.BaseDirectory)

    printf "Load training images\n"
    let TrainingObjectImages = LoadImageArray objectNum "train/objects/object_"
    let TrainingOwnerImages = LoadImageArray ownerNum "train/owners/owner_"

    // Create TopoART-AM network (with only a single module in order to accelerate computations)
    let tam = new Fast_TopoART_AM(key1Len, key2Len, 1L, rho_a, Beta_sbm = beta_sbm, Tau = tau, Phi = phi)

    let filePrefix = "Fast_TopoART-AM"
    let fileSuffix = "ftam"

    // Train and measure the required time (Stopwatch is monotonic, unlike DateTime.Now);
    // the number of training steps returned by Train is not needed here
    let trainingWatch = System.Diagnostics.Stopwatch.StartNew()
    Train tam TrainingObjectImages TrainingOwnerImages |> ignore
    trainingWatch.Stop()
    printf "Time for training: %s\n" <| trainingWatch.Elapsed.ToString()

    // Bidirectional recall of objects and owners based on images from the test set

    // The number of images that are recalled is controlled by the parameter min_F3_activation;
    // the lower the activation of an F3 node the less accurate is the respectively recalled
    // image. Due to the similarity of the images, only a small fraction of the theoretically
    // possible interval of [0, 1] can be used here.

    // The indexes may be changed in order to use different test stimuli. There are five test
    // images per object and per owner.

    // Recall key 1 (objects) by presenting a stimulus key 2 (owner)
    RecallObjects 3 3 0.975m tam   // key 2: owner 3; test image: 3; recall threshold: 0.975

    // Recall key 2 (owners) by presenting a stimulus key 1 (object)
    RecallOwners 20 4 0.970m tam   // key 1: object 20; test image: 4; recall threshold: 0.970

    printf "Save network\n"

    // The destination directory may not exist yet on a fresh checkout
    Directory.CreateDirectory(networkPath) |> ignore

    // Save network in human-readable form
    // tam.SaveText(networkPath + filePrefix + "_ObjectsOwners_dataset.txt");

    // Save network in binary form
    tam.Save(networkPath + filePrefix + "_ObjectsOwners_dataset." + fileSuffix)

    0
