/*
    In order to adapt this script according to your needs please edit the file 
    'GlobalDefinitions.as' that is located in the \Includes directory for this
    script.
*/

#include "Includes\\WaveRecorder.as"
#include "Includes\\WaveFile.as" 
#include "Includes\\GlobalDefinitions.as"
#include "Includes\\ErrorCheck.as"
#include "Includes\\NetEditor.as"

// Script-wide state shared by the functions below
string sWorkDir;                    // The working directory of the script (filled in by Init() via GetWorkDir())
string sWavePath;                   // The full path of the wave file currently being recorded/converted
uint sParticipantCount = 0;         // Number of participants for the experiment (PARTICIPANTS.length(), set in Init())
uint sCurParticipantIdx;            // Index into PARTICIPANTS of the participant currently being processed.
                                    // Several functions use it as their loop variable, and RecordParticipant()/
                                    // BuildParticipantLesson() read it, so its value changes as a side effect of calls.
bool sRandomized = false;           // Set by CreateNet() after randomizing; Training() then skips the randomize prompt

// Script-wide objects created from script classes
WaveRecorder sRecorder;             // Controls the external wave recorder (Hard Disk Ogg, started in Init())
NetEditor sNetEditor;               // Helper used to create the net structure and normalize the net


// Length of a single sample snippet [sample points], i.e. the number of raw wave samples per pattern.
// Presumably SAMPLE_FREQ is in Hz, so dividing by 1000 converts the millisecond length; if the constants
// are integers the result is truncated - TODO confirm the types in GlobalDefinitions.as.
const uint SAMPLE_COUNT_PER_SNIPPET = (SNIPPET_LEN_MS * SAMPLE_FREQ) / 1000; 

/*
    Script entry point.

    Guides the user through the phases of the experiment, each gated by a Yes/No
    question: recording training data, creating or loading the net, training it and
    finally online voice recognition.
    Note: VoiceRecognition() loops forever, so CleanUp() is only reached when the
    user declines the recognition phase.
*/
void main()
{
    Init();

    // Phase 1: optionally collect fresh training data
    bool collectData = (MessageBox("Do you want to record new training data first?", MB_YESNO) == IDYES);
    if (collectData)
    {
        CollectTrainData();
    }

    // Phase 2: build a new net or load an existing one (LoadNet() aborts the script on 'No')
    bool createNewNet = (MessageBox("Do you want to create a new voice detection\nnet for the collected data now ?", MB_YESNO) == IDYES);
    if (createNewNet)
    {
        CreateNet();
    }
    else
    {
        LoadNet();
    }

    // Phase 3: optionally train the net
    bool train = (MessageBox("Do you want to train the net now ?", MB_YESNO) == IDYES);
    if (train)
    {
        Training();
    }

    // Phase 4: optionally run the online voice recognition
    bool recognize = (MessageBox("Do you want to start the online voice recognition ?", MB_YESNO) == IDYES);
    if (recognize)
    {
        VoiceRecognition();
    }

    // Orderly shutdown for a normal script exit
    CleanUp();
}


/*
    Initialization after the script has been started.

    Resets the Lesson Editor, hides the editor/viewer windows, checks that at least
    one participant is configured, starts the wave recorder, determines the script's
    working directory and makes sure that the participant database directory and one
    sub directory per participant exist below that working directory. These are the
    same absolute paths that are used later when wave files and lessons are written.
*/
void Init()
{
    // Clear any data currently available in the Lesson Editor
    SetLessonCount(1);
    ClearLesson();

    // Hide all windows that are not required right now
    ShowLessonEditor(false);
    ShowErrorViewer(false);
    ShowPatternErrorViewer(false);

    // Obtain the number of participants for the experiment. Without at least one
    // participant there is nothing to record, train or detect. This check is done
    // before the recorder is started so an abort does not leave it running.
    sParticipantCount = PARTICIPANTS.length();
    EnsureTrue(sParticipantCount > 0, "No participants defined in GlobalDefinitions.as");

    // Initialize the wave recorder. This starts "Hard Disk Ogg" (freeware application)
    // and minimizes it to the windows toolbar.
    sRecorder.Init();

    // Retrieve the script's working directory (this is where this file is located)
    GetWorkDir(sWorkDir);

    // The database directory only needs to be created once (not once per participant).
    // Build it below sWorkDir, exactly like the paths used when recording and saving lessons.
    string databaseDir = sWorkDir + "\\" + PARTICIPANT_DIR;
    EnsureTrue(MakeDir(databaseDir), "Unable to create participant database directory");

    // Create a sub directory for every participant (if not already existing)
    for (sCurParticipantIdx = 0; sCurParticipantIdx < sParticipantCount; sCurParticipantIdx++)
    {
        string participantDir = databaseDir + "\\" + PARTICIPANTS[sCurParticipantIdx];
        EnsureTrue(MakeDir(participantDir), "Unable to create directory: " + participantDir);
    }
}


/*
    Orderly shutdown, called by main() when the script ends normally.
    Stops the external recorder and switches link display back on.
*/
void CleanUp()
{
    // Terminate the external recorder application (Hard Disk Ogg)
    sRecorder.Shutdown();

    // Link display may have been switched off during the run (see VoiceRecognition()),
    // so make sure it is visible again
    ViewSetting(SHOW_LINKS, true);
}


/*
    Create a new net that fits the experiment's data layout.

    The net gets INPUT_DIMENSION inputs, one output neuron per participant (named
    after that participant) and the hidden layers described by HID_LAYERS. It is
    then randomized and saved under NET_NAME.
*/
void CreateNet()
{
    sNetEditor.CreateNet(INPUT_DIMENSION, sParticipantCount, HID_LAYERS.length());

    // Output neurons are numbered from 1: output n belongs to PARTICIPANTS[n - 1]
    for (uint outputNr = 1; outputNr <= sParticipantCount; outputNr++)
    {
        SelectOutput(outputNr, false);
        SetSelectedNeuronName(PARTICIPANTS[outputNr - 1]);
    }
    ClearSelection();

    sNetEditor.CreateHiddenLayers(HID_LAYERS);
    RandomizeNet();
    sRandomized = true;     // Training() will not offer to randomize again
    SaveNet(NET_NAME);
}


/*
    Ask whether an existing net shall be loaded.

    'Yes' opens a file dialog for net files (extension "mbn") and opens the selected
    file. If the dialog is not confirmed with OK, no net is loaded and the script
    continues. 'No' aborts the script.
*/
void LoadNet()
{
    string question = "Do you want to load an existing voice detection\nnet for the collected data now ?\n\n";

    question += "Selecting No will abort the script.";

    if (MessageBox(question, MB_YESNO) != IDYES)
    {
        AbortScript();
        return;
    }

    // Pass the NET_NAME constant (the name CreateNet() saves the net under),
    // not the literal text "NET_NAME".
    string path;
    if (FileOpenDlg("Open Neural Net File", "mbn", NET_NAME, path) == IDOK)
    {
        OpenNet(path);
    }
}


/*
    Train the net on the merged training data of all participants.

    The participant lessons are merged into one training lesson, the net is
    normalized with it and (unless CreateNet() already did so) optionally randomized.
    Teaching then runs until it stops by itself. During teaching the Pattern Error
    Viewer cycles through the outputs every PATTERN_ERR_VIEWER_TOGGLE_SECONDS so the
    target/actual activations of each participant are shown in turn.
*/
void Training()
{
    // Build one training lesson out of all participants' collected lessons
    MergeToTrainLesson();
    EnsureTrue(GetLessonSize() > 0, "No data to train from");

    // The net takes over all names from the active lesson (where possible)
    NamesToNet();
    ShowLessonEditor(false);

    // Normalize with the full training lesson; only ask about randomizing when the
    // net has not been randomized by CreateNet() during this run
    sNetEditor.NormalizeNetWithActiveLesson();
    bool randomize = !sRandomized && (MessageBox("Randomize before training ?", MB_YESNO) == IDYES);
    if (randomize)
    {
        RandomizeNet();
    }

    // Show the error viewers, configure the target error and start teaching
    ShowPatternErrorViewer(true);
    ShowErrorViewer(true);
    TeacherSetting(TARGET_ERR, TARGET_NET_ERROR);
    Trace("Training...");
    StartTeaching();

    // Cycle the Pattern Error Viewer through outputs 1..N (nothing to cycle for N == 1)
    SecondsTimer viewerTimer;
    viewerTimer.Start(PATTERN_ERR_VIEWER_TOGGLE_SECONDS);
    uint shownOutput = 1;

    while (IsTeaching())
    {
        if (sParticipantCount > 1 && viewerTimer.IsElapsed())
        {
            viewerTimer.Start(PATTERN_ERR_VIEWER_TOGGLE_SECONDS);
            // Next output, wrapping from N back to 1
            shownOutput = (shownOutput % sParticipantCount) + 1;
            EnsureTrue(SelectPatternErrorViewerData(shownOutput), "Error in setting selection for Pattern Error Viewer");
        }
        Sleep(100);
    }
    Trace(" Finished.\n");
}

/*
    Interactive collection of new training data.

    Every participant is offered once: 'Yes' records a new sample (RecordParticipant())
    and, if it is accepted, appends it to that participant's lesson
    (BuildParticipantLesson()); 'No' skips the participant; 'Cancel' (or any other
    answer) ends the collection for all remaining participants.
    Afterwards all participant lessons are merged into the training lesson.
*/
void CollectTrainData()
{
    bool keepGoing = true;

    // sCurParticipantIdx is global on purpose: RecordParticipant() and
    // BuildParticipantLesson() read it to know which participant is being handled.
    for (sCurParticipantIdx = 0; keepGoing && sCurParticipantIdx < sParticipantCount; sCurParticipantIdx++)
    {
        string participantName = PARTICIPANTS[sCurParticipantIdx];
        string question = "Record new training sample for\n\n" + participantName + " ?\n\n";
        question += "Select Cancel to Abort the training data collection process";

        EDlgRet answer = MessageBox(question, MB_YESNOCANCEL);

        if (answer == IDYES)
        {
            // Record into the participant's directory; convert only an accepted sample
            sWavePath = sWorkDir + "\\" + PARTICIPANT_DIR + "\\"+ participantName + "\\" + WAVE_FILE_NAME;
            if (RecordParticipant())
            {
                BuildParticipantLesson();
            }
        }
        else if (answer != IDNO)
        {
            // Cancel or any other answer: stop collecting for everybody
            keepGoing = false;
        }
    }

    // Merge all old and new training data into one large training lesson for the net
    MergeToTrainLesson();
}


/*
    Record a new wave file (to sWavePath) for the participant PARTICIPANTS[sCurParticipantIdx]
    and let the user judge it.

    'Yes' accepts the recording, 'No' records again, 'Cancel' (or any other answer)
    gives up on this participant. Returns true only if a recording was accepted.
*/
bool RecordParticipant()
{
    string participantName = PARTICIPANTS[sCurParticipantIdx];
    EDlgRet answer;

    do
    {
        Trace("Recording " + participantName + "...\n");
        Record(OVERALL_REC_SECONDS_PER_PARTICIPANT);

        string question = "Accept recorded sample for\n\n" + participantName + " ?\n\n";
        question += "Select No to repeat the record.\n\n";
        question += "Select Cancel to skip recording " + participantName;

        answer = MessageBox(question, MB_YESNOCANCEL);
    }
    while (answer == IDNO);

    return answer == IDYES;
}


/*
    Record a new wave file to the path given by sWavePath, lasting at least
    'durationSeconds'.

    Any existing file at that path is deleted first. Hard Disk Ogg reacts slowly to
    the issued commands, so the file usually turns out a bit longer than requested.
    While recording, the recorder lights up red in the Windows tool bar; this
    function stops it once the time has elapsed and only returns after the file can
    be opened for reading without another process holding write access (see FileExists()).
*/
void Record(uint durationSeconds)
{
    // Remove a possibly existing file from an earlier recording
    DeleteFile(sWavePath);

    // The recorder's own timeout (one second beyond the requested duration) makes sure
    // Hard Disk Ogg does not keep recording forever should the script be aborted meanwhile.
    uint recorderTimeout = durationSeconds + 1;
    uint recordMs = durationSeconds * 1000;

    sRecorder.Start(sWavePath, recorderTimeout);
    Sleep(recordMs);
    sRecorder.Stop();

    // Hard Disk Ogg needs some time to write and close the file: poll until it is
    // readable and no longer write-locked, yielding the CPU in between.
    // NOTE(review): this wait has no upper bound if the file never appears.
    while (!FileExists(sWavePath))
    {
        Sleep(100);
    }
}


/*
    Convert the wave file at sWavePath into lesson data and append it to the
    training lesson stored for the current participant (PARTICIPANTS[sCurParticipantIdx]).

    The new patterns are first saved to a temporary lesson file in the participant's
    directory; that file is then appended to the participant's existing lesson and the
    expanded lesson is saved back.
*/
void BuildParticipantLesson()
{
    string participantName = PARTICIPANTS[sCurParticipantIdx];
    Trace("Creating/expanding training lesson for " + participantName + "...\n");

    // Both lesson files live in the participant's own directory
    string participantDir = sWorkDir + "\\" + PARTICIPANT_DIR + "\\"+ participantName + "\\";
    string lessonPathTemp = participantDir + TEMP_LESSON_NAME;
    string lessonPathParticipant = participantDir + TRAIN_LESSON_NAME;

    // Lesson 1: the participant's existing data, prepared for expansion
    SetLessonCount(1);
    PrepareParticipantLessonForExpansion(lessonPathParticipant);

    // Further lessons: the converted wave data; the active (last) one goes to the temp file
    ConvertWaveToNewAveragedFftLesson();
    SaveLesson(lessonPathTemp);

    // Back to the existing data only (drops the newly created lessons), append the
    // temporary lesson file and save the expanded lesson for the participant
    SetLessonCount(1);
    AppendLesson(lessonPathTemp);
    SaveLesson(lessonPathParticipant);

    EnableLessonEditorUpdate(true);
    ShowLessonEditor(false);
}


/*
    Convert the wave file (sWavePath) into new lessons in the Lesson Editor and make
    the last of them (the averaged FFT lesson) the active one.

    The wave file is read with an object of the script class 'WaveFile'. Consecutive,
    non-overlapping portions of SAMPLE_COUNT_PER_SNIPPET sample points become one
    pattern each in a new raw lesson; samples at the end that do not fill a whole
    snippet are ignored. Each input is the mean over the channels (for a mono file
    simply the sample value). The raw lesson is transformed to the frequency domain
    (CreateFftLesson()), and the inputs of that FFT lesson are averaged down to
    INPUT_DIMENSION (CreateInputAverageLesson()). Each of these steps selects the
    newly created lesson as the last lesson in the editor.
*/
void ConvertWaveToNewAveragedFftLesson()
{
    // A zero snippet length would never advance the read loop below
    EnsureTrue(SAMPLE_COUNT_PER_SNIPPET > 0, "Snippet length results in zero samples per pattern");

    // The wave file object is constructed and read in from file
    WaveFile sampleFile;

    EnsureTrue(sampleFile.Read(sWavePath), "Unable to read from wave file or file not found");

    uint overallSampleCount = sampleFile.GetSampleCountPerChannel();
    uint processedSampleCount = 0;
    uint channelCount = sampleFile.GetChannelCount();

    EnsureTrue((channelCount > 0) && (channelCount <= 2), "Wave file channel count is out of range");

    // Create new lesson to be filled with raw data from wave file and set this lesson as the active one.
    // Prepare the new lesson with the correct input dimension and output dimension = 0.
    SetLessonCount(GetLessonCount() + 1);
    SelectLesson(GetLessonCount());
    SetLessonInputCount(SAMPLE_COUNT_PER_SNIPPET);
    SetLessonOutputCount(0);

    // Left and right channel time point values from the wave file.
    // Note: For a MONO recording the channel 'right' will always be 0, so dividing the
    // channel sum by the channel count gives the plain sample for mono and the mean for stereo.
    int sampleLeft;
    int sampleRight;
    double channelDivisor = double(channelCount);

    // Suppress Lesson Editor update while filling the lesson. Increases processing speed.
    EnableLessonEditorUpdate(false);

    // Fill the inputs of the lesson with SAMPLE_COUNT_PER_SNIPPET wave time points per pattern
    while (processedSampleCount + SAMPLE_COUNT_PER_SNIPPET <= overallSampleCount)
    {
        AddPattern();
        for (uint i = 1; i <= SAMPLE_COUNT_PER_SNIPPET; i++)
        {
            bool ok = sampleFile.GetNextSample(sampleLeft, sampleRight);

            EnsureTrue(ok, "Error in reading from recorded file");
            // Store the channel average (not just the left channel)
            double value = (double(sampleLeft) + double(sampleRight)) / channelDivisor;
            SetPatternInput(i, value);
        }
        processedSampleCount += SAMPLE_COUNT_PER_SNIPPET;
    }

    // Create the FFT lesson from the raw lesson
    EnsureTrue(CreateFftLesson(false, false), "Unable to create FFT Lesson");
    // Select the newly created FFT lesson and create an averaged lesson with the
    // reduced input dimension INPUT_DIMENSION from it
    SelectLesson(GetLessonCount());
    EnsureTrue(CreateInputAverageLesson(INPUT_DIMENSION), "Unable to create averaged input Lesson from FFT Lesson");
    // Make the averaged lesson the active one (saving/evaluating it is up to the caller)
    SelectLesson(GetLessonCount());

    // Re-enable the Lesson Editor update
    EnableLessonEditorUpdate(true);
}

/*
    Prepare the active lesson to receive new patterns for the participant lesson
    stored at 'path'.

    If the file exists it is loaded. If its input count differs from INPUT_DIMENSION,
    a warning is printed to the script trace window, the loaded data is saved as a
    backup (path + ".bak") and the lesson is cleared. If the file does not exist, the
    currently active lesson is cleared.
    Finally the input count is set to INPUT_DIMENSION and any outputs are removed.
*/
void PrepareParticipantLessonForExpansion(string path)
{
    bool startEmpty = true;

    if (FileExists(path))
    {
        LoadLesson(path);

        if (GetLessonInputCount() == INPUT_DIMENSION)
        {
            // Compatible layout: keep the existing patterns
            startEmpty = false;
        }
        else
        {
            string message = "Existing participant lesson has different input count than required.\n";
            message += "A backup copy will be created (extension '.bak') and the data in the original\n";
            message += "lesson will be erased.\n";
            Trace(message);
            SaveLesson(path + ".bak");
        }
    }

    if (startEmpty)
    {
        ClearLesson();
    }

    SetLessonInputCount(INPUT_DIMENSION);

    if (GetLessonOutputCount() > 0)
    {
        SetLessonOutputCount(0);
    }
}


/*
    Merge the training lessons of all participants into one training lesson, which
    ends up as lesson #1 in the Lesson Editor, and save it.

    The lesson gets one output per participant, named after that participant. Each
    pattern gets the value '1' on the output of the participant it came from; all
    other outputs stay at '0'. Participants without a lesson file contribute no patterns.
*/
void MergeToTrainLesson()
{
    Trace("Merging training lessons...\n");

    // Lesson Editor refreshes slow things down; suspend them while building the lesson
    EnableLessonEditorUpdate(false);

    // Number of patterns contributed by each participant, in PARTICIPANTS order
    uint[] patternCount(sParticipantCount);

    // Start from an empty lesson #1 that has inputs only
    SetLessonCount(1);
    ClearLesson();
    SetLessonInputCount(INPUT_DIMENSION);
    SetLessonOutputCount(0);

    // Append every participant lesson and note how many patterns it added
    for (sCurParticipantIdx = 0; sCurParticipantIdx < sParticipantCount; sCurParticipantIdx++)
    {
        string lessonPath = sWorkDir + "\\" + PARTICIPANT_DIR + "\\"+ PARTICIPANTS[sCurParticipantIdx] + "\\" + TRAIN_LESSON_NAME;
        uint sizeBefore = GetLessonSize();
        uint added = 0;

        if (FileExists(lessonPath))
        {
            AppendLesson(lessonPath);
            added = GetLessonSize() - sizeBefore;
        }
        patternCount[sCurParticipantIdx] = added;
    }

    // One output column per participant, named after the participant
    SetLessonOutputCount(sParticipantCount);
    for (sCurParticipantIdx = 0; sCurParticipantIdx < sParticipantCount; sCurParticipantIdx++)
    {
        SetLessonOutputName(sCurParticipantIdx + 1, PARTICIPANTS[sCurParticipantIdx]);
    }

    // Outputs start as all 0's; the patterns are grouped per participant in append
    // order, so walk through them and set the owning participant's output to '1'
    uint nextPattern = 1;
    for (sCurParticipantIdx = 0; sCurParticipantIdx < sParticipantCount; sCurParticipantIdx++)
    {
        uint outputNr = sCurParticipantIdx + 1;
        for (uint n = 0; n < patternCount[sCurParticipantIdx]; n++)
        {
            SelectPattern(nextPattern);
            SetPatternOutput(outputNr, 1);
            nextPattern++;
        }
    }

    // Save, re-enable Lesson Editor updates and hide the Lesson Editor.
    // NOTE(review): saved under the bare file name TRAIN_LESSON_NAME, so the target
    // directory depends on the current directory - confirm this is intended.
    SaveLesson(TRAIN_LESSON_NAME);
    EnableLessonEditorUpdate(true);
    ShowLessonEditor(false);
}


/*
    Start and control the online voice recognition.

    After preparing the views, this loops forever: record a short wave file into the
    working directory (WAVE_FILE_NAME) and let RecognizeVoice() evaluate it. The result
    is shown in the script trace window. This function does not return normally.
*/
void VoiceRecognition()
{
    // Only lesson #1 is needed; hide the editor and viewer windows
    SetLessonCount(1);
    ShowLessonEditor(false);
    ShowErrorViewer(false);
    ShowPatternErrorViewer(false);

    // Hide the links and let the net update on every think step
    ViewSetting(SHOW_LINKS, false);
    ViewSetting(UPDATE_THINK, true);

    Trace("Starting online voice recognition...\n");

    // Recognition recordings go directly into the working directory
    sWavePath = sWorkDir + "\\" + WAVE_FILE_NAME;

    while (true)
    {
        RecordVoiceForRecognition();
        RecognizeVoice();
    }
}

/*
    Record the next wave file (to sWavePath) to be evaluated by the voice
    recognition; requested length: OVERALL_REC_SECONDS_FOR_DETECTION seconds.
*/
void RecordVoiceForRecognition()
{
    Record(OVERALL_REC_SECONDS_FOR_DETECTION);
}


/*
    Evaluate the wave file at sWavePath with respect to detection.

    The file is converted into an averaged FFT lesson, as during training. Every
    resulting pattern is applied to the net, one think step is performed and the
    output winner neuron (the one with the highest activation) is counted. The
    participant whose output won the most patterns is printed to the script trace
    window together with its percentage of wins. On a tie, the participant listed
    first wins.
    A recording too short to yield any pattern is reported and skipped.
*/
void RecognizeVoice()
{
    SetLessonCount(1);
    ClearLesson();
    ConvertWaveToNewAveragedFftLesson();
    EnableLessonOutData(false);

    uint patternCount = GetLessonSize();

    // Without patterns there is nothing to vote on, and the percentage below would
    // divide by zero (AngelScript raises a script exception on a zero divisor).
    if (patternCount == 0)
    {
        ShowLessonEditor(false);
        Trace("Recording too short for detection (no patterns) - skipped.\n");
        return;
    }

    // One win counter per participant / output neuron
    uint[] winnerCount(sParticipantCount);

    for (uint i = 0; i < sParticipantCount; i++)
    {
        winnerCount[i] = 0;
    }

    for (uint i = 0; i < patternCount; i++)
    {
        SelectPattern(i + 1);
        ApplyPattern();
        ThinkSteps(1);
        SleepExec();

        // Output neurons are numbered from 1; 0 means no winner could be determined
        uint winner = GetOutputWinnerNeuron();

        EnsureTrue(winner != 0, "Unable to get winner neuron");
        // A loaded net may have more outputs than there are participants
        EnsureTrue(winner <= sParticipantCount, "Winner neuron index exceeds participant count");
        winnerCount[winner - 1]++;
    }
    ShowLessonEditor(false);

    uint overallWinnerIdx = FindMaxIdx(winnerCount);

    double percent = 100 * double(winnerCount[overallWinnerIdx]) / patternCount;

    Trace("Detected:\t\t" + PARTICIPANTS[overallWinnerIdx] + "\t(" + percent + " %)\n");
}

/*
    Return the index of the largest value in 'arr'.

    On ties the lowest index wins. For an empty array, or one containing only
    zeros, 0 is returned.
*/
uint FindMaxIdx(uint[] &in arr)
{
    uint winnerIdx = 0;
    uint winnerValue = 0;
    uint idx = 0;

    while (idx < arr.length())
    {
        uint candidate = arr[idx];
        if (candidate > winnerValue)
        {
            winnerValue = candidate;
            winnerIdx = idx;
        }
        idx++;
    }

    return winnerIdx;
}

/*
    Return true if 'fileName' exists and can be opened for reading while denying
    write access to others. A file still being written by another process therefore
    reports false.
*/
bool FileExists(const string &in fileName)
{
    file probe;

    if (!probe.Open(fileName, FILE_MODE_READ | FILE_SHARE_DENY_WRITE))
    {
        return false;
    }

    probe.Close();
    return true;
}





