/*
Sensitivity analysis script for MemBrain

Usage: 
- Load a trained net in MemBrain and start the script

What the script does:

- Apply default activations to all input neurons except the one currently being analysed
- For choosing the default values there are two options, adjustable via the constant 'USE_CURRENT_INPUT_ACTICATIONS_AS_DEFAULT'
  in the constant section of this script:
  		USE_CURRENT_INPUT_ACTICATIONS_AS_DEFAULT = true
  		--> The activations of the input neurons when the script is started are used as default values
  		USE_CURRENT_INPUT_ACTICATIONS_AS_DEFAULT = false
  		--> The mean of each input neuron's activation range, (min + max) / 2, is used as its default value
- For the analysed input neuron, apply a series of evenly spaced activations from the neuron's minimum to its maximum and record the net output(s)
- Perform this analysis for all the input neurons in the net
- Export the results as csv file
- Optionally automatically open the csv file in an external application (default application registered for csv files with Windows)
- Optionally present the results in the Pattern Error Viewer of MemBrain

What the script can be used for:
- Get an idea of the influence of the different input neurons on the output(s) of the net

Note: The section 'Global constants' below lets you adjust some settings of the script
*/

// Global constants --> Adjustable settings for the script -------------------------------------
// Declared 'const' so the script cannot modify them by accident at run time;
// change the initial values here to adjust the analysis.

// Number of data points for each sensitivity series (must be 2 or higher, checked in Init())
const uint TEST_POINT_COUNT = 100;

// Name of the generated result csv file
// NOTE(review): the name is spelled "Sensititivity"; left unchanged in case other tooling expects it.
const string RESULT_FILE_NAME = "SensititivityAnalysis.csv";

// Wait time between two curve displays in the result animation [ms]
const uint SLEEP_TIME_FOR_RESULT_ANIMATION = 2000;

// Use the current input activations of the loaded net as default values for all other neurons than the
// currently examined neuron. If set to false then the mean values of min/max for each input neuron are used.
const bool USE_CURRENT_INPUT_ACTICATIONS_AS_DEFAULT = true;

// ---------------------------------------------------------------------------------------------


// Global variables

// Number of input / output neurons of the loaded net (set in Init())
uint gInCount = 0;
uint gOutCount = 0;
// Total number of test patterns: TEST_POINT_COUNT per input neuron (set in Init())
uint gTestPatternCount = 0;
// Write cursor into gTestLesson while the test patterns are being generated
uint gCurrentPatternIdx = 0;


// gTestLesson:   all test patterns that are fed to the net
// gRecordLesson: the net activations recorded while thinking gTestLesson
// gResultLesson: the per-input sensitivity series that are exported and displayed
Lesson gTestLesson;
Lesson gRecordLesson;
Lesson gResultLesson;

// Per input neuron (0-based index): minimum and maximum activation, and the default
// activation applied to that input while another input is being swept
array<double> gActMin;
array<double> gActMax;
array<double> gActDefault;


// Program entry point: runs the complete sensitivity analysis step by step.
// Init() must run first because all later steps rely on the counts, ranges and
// default activations it stores in the global variables.
void main()
{
	Init();                                 // read net dimensions, activation ranges and defaults
	PrepareTestLesson();                    // build the sweep patterns in gTestLesson
	PrepareResultLesson();                  // set up columns and size of gResultLesson (uses gTestLesson names)
	RecordNetReactionToTestLesson();        // think the test lesson and capture the activations
	FillResultLessonFromRecordedLesson();   // rearrange the recorded data into per-input series
	ExportResultLesson();                   // write the csv file
	OpenResultLessonInExternalApp();        // optional, asks the user first
	PresentResultLessonInMemBrain();        // optional, asks the user first; may loop until the script is aborted
}

// Initialize the script variables.
// Validates the settings and the loaded net, then caches for every input neuron its
// activation range (gActMin / gActMax) and the default activation (gActDefault) that is
// applied to it while some other input is being swept.
void Init()
{
	// Sanity checks; every failure ends the script via AbortWithMessage()
	if (TEST_POINT_COUNT < 2)
	{
		AbortWithMessage("TEST_POINT_COUNT must be 2 or higher!");
	}

	gInCount = GetInputCount();
	gOutCount = GetOutputCount();

	if (gInCount == 0)
	{
		AbortWithMessage("AT least one input neuron required to be in net!");
	}
	if (gOutCount == 0)
	{
		AbortWithMessage("AT least one output neuron required to be in net!");
	}

	// TEST_POINT_COUNT test patterns for every input neuron
	gTestPatternCount = gInCount * TEST_POINT_COUNT;

	gActMin.resize(gInCount);
	gActMax.resize(gInCount);
	gActDefault.resize(gInCount);

	double rangeLow;
	double rangeHigh;
	double defaultAct;

	uint idx = 0;
	while (idx < gInCount)
	{
		// The MemBrain neuron functions are called with idx + 1 (1-based neuron numbers)
		const uint neuronNo = idx + 1;

		GetInputActRange(neuronNo, rangeLow, rangeHigh);
		gActMin[idx] = rangeLow;
		gActMax[idx] = rangeHigh;

		if (!USE_CURRENT_INPUT_ACTICATIONS_AS_DEFAULT)
		{
			// Centre of the neuron's activation range
			defaultAct = (rangeLow + rangeHigh) / 2;
		}
		else
		{
			// Activation the neuron has at the moment the script runs
			GetInputAct(neuronNo, defaultAct);
		}
		gActDefault[idx] = defaultAct;

		idx++;
	}
}


// Build gTestLesson: for every input neuron a consecutive block of TEST_POINT_COUNT patterns
// in which only that input is swept (see PrepareTestPatternsForInput()).
void PrepareTestLesson()
{
	// Editor lesson #1 provides the column layout: a single, empty lesson named after the net
	SetLessonCount(1);
	NamesFromNet();
	ClearLesson();

	// Take over that layout and reserve room for all test patterns
	gTestLesson.CloneFromLessonEditor();
	gTestLesson.SetSize(gTestPatternCount);

	// PrepareTestPatternsForInput() advances this global write cursor
	gCurrentPatternIdx = 0;

	uint inputIdx = 0;
	while (inputIdx < gInCount)
	{
		PrepareTestPatternsForInput(inputIdx);
		inputIdx++;
	}
}

// Set up format and size of gResultLesson.
// It keeps the input columns of editor lesson #1 and gets one output column per
// (output, input) pair, named "<output>[<input>]". The column of a pair is
// outIdx * gInCount + inIdx, i.e. all inputs of output 0 first, then output 1, etc.
void PrepareResultLesson()
{
	SelectLesson(1);
	gResultLesson.CloneFromLessonEditor();
	gResultLesson.SetOutputCount(gInCount * gOutCount);
	gResultLesson.SetSize(TEST_POINT_COUNT);

	string outName;
	string inName;

	// Visiting outputs in the outer loop and inputs in the inner loop makes the running
	// column counter equal to outIdx * gInCount + inIdx
	uint column = 0;
	for (uint o = 0; o < gOutCount; o++)
	{
		gTestLesson.GetOutputName(o, outName);
		for (uint n = 0; n < gInCount; n++)
		{
			gTestLesson.GetInputName(n, inName);
			gResultLesson.SetOutputName(column, outName + "[" + inName + "]");
			column++;
		}
	}
}


// Prepare all test patterns for the 0-based input index <inIdx>.
// Writes TEST_POINT_COUNT patterns starting at gCurrentPatternIdx and advances it accordingly.
// Input <inIdx> is swept in equal steps from gActMin[inIdx] to gActMax[inIdx], both ends
// inclusive; all other inputs get their default activation (see PrepareTestPattern()).
void PrepareTestPatternsForInput(uint inIdx)
{
	const double lo = gActMin[inIdx];
	const double hi = gActMax[inIdx];
	// TEST_POINT_COUNT >= 2 is enforced by Init(), so this is never zero
	const double lastStep = double(TEST_POINT_COUNT - 1);

	for (uint step = 0; step < TEST_POINT_COUNT; step++)
	{
		// The weighted form lo * (1 - t) + hi * t gives exactly lo at t = 0 and exactly hi
		// at t = 1, whereas lo + (hi - lo) / (n - 1) * step can miss hi by rounding error.
		const double t = step / lastStep;
		const double act = lo * (1 - t) + hi * t;
		PrepareTestPattern(gCurrentPatternIdx, inIdx, act);
		gCurrentPatternIdx++;
	}
}


// Prepare the test pattern at 0-based index <patternIdx> for input at index <inIdx>.
// Sets pattern input <inIdx> to the value <act> and every other input to its default
// activation from gActDefault. That is either the activation at script start or the mean
// of the input's range, depending on USE_CURRENT_INPUT_ACTICATIONS_AS_DEFAULT.
void PrepareTestPattern(uint patternIdx, uint inIdx, double act)
{
	for (uint i = 0; i < gInCount; i++)
	{
		const double value = (i == inIdx) ? act : gActDefault[i];
		gTestLesson.SetInput(patternIdx, i, value);
	}
}

// Record the net's reaction to the test lesson.
// Loads gTestLesson into the lesson editor and thinks all of its patterns while recording
// into editor lesson #2. It then copies that recording into gRecordLesson and finally
// re-applies the default input activations to the net.
void RecordNetReactionToTestLesson()
{
	// NOTE(review): loads into the currently selected editor lesson, which is #1 here because
	// PrepareResultLesson() selected it just before; re-check if the order in main() changes.
	LoadLesson(gTestLesson);
	// A second editor lesson is needed to receive the recording
	SetLessonCount(2);
	// RT_ACT: presumably records activations (per the constant's name) - confirm in MemBrain docs
	SetRecordingType(RT_ACT);
	StartRecording(2);
	// View updates during thinking are switched off here and switched back on below
	ViewSetting(UPDATE_THINK, false);
	ThinkLesson();
	// NOTE(review): presumably waits until ThinkLesson() has finished before recording stops
	SleepExec();
	StopRecording();
	ViewSetting(UPDATE_THINK, true);
	SelectLesson(2);
	gRecordLesson.CloneFromLessonEditor();
	
	// Re-apply default values to all neurons to leave a cleaned-up state for the user in MemBrain.
	// Neuron functions are called with i + 1 (1-based), unlike the 0-based array index.
	for (uint i = 0; i < gInCount; i++)
	{
		ApplyInputAct(i + 1, gActDefault[i]);
	}
}


// Fill gResultLesson from the content of gRecordLesson.
// Assumes the recording holds one pattern per test pattern, in test-lesson order: gInCount
// consecutive blocks of TEST_POINT_COUNT patterns, block <inIdx> sweeping input <inIdx>.
// Each block becomes result input column <inIdx> plus the result output columns
// outIdx * gInCount + inIdx (one per net output).
void FillResultLessonFromRecordedLesson()
{
	for (uint inIdx = 0; inIdx < gInCount; inIdx++)
	{
		const uint blockStart = inIdx * TEST_POINT_COUNT;

		for (uint row = 0; row < TEST_POINT_COUNT; row++)
		{
			const uint srcRow = blockStart + row;
			double v;

			// The swept input value itself
			gRecordLesson.GetInput(srcRow, inIdx, v);
			gResultLesson.SetInput(row, inIdx, v);

			// The net outputs observed for that input value
			for (uint outIdx = 0; outIdx < gOutCount; outIdx++)
			{
				gRecordLesson.GetOutput(srcRow, outIdx, v);
				gResultLesson.SetOutput(row, outIdx * gInCount + inIdx, v);
			}
		}
	}
}


// Export gResultLesson to the csv file RESULT_FILE_NAME (via Lesson::ExportRaw);
// the script is aborted with an error message if the export fails.
void ExportResultLesson()
{
	if (!gResultLesson.ExportRaw(RESULT_FILE_NAME))
	{
		AbortWithMessage("Unable to export result lesson to file: " + RESULT_FILE_NAME);
	}
}

// Ask the user whether to open RESULT_FILE_NAME with the application registered for csv
// files in Windows; on "Yes" the file is opened via ShellExecute("open", ...).
// The script is aborted with an error message if opening fails.
void OpenResultLessonInExternalApp()
{
	const bool wantsOpen = MessageBox("Open generated csv file in external application?", MB_YESNO) == IDYES;
	if (!wantsOpen)
	{
		return;
	}

	const bool opened = ShellExecute("open", RESULT_FILE_NAME, "");
	if (!opened)
	{
		AbortWithMessage("Unable to open result file with associated application!");
	}
}
	
// Present the generated result lesson in MemBrain using the Pattern Error Viewer.
// After the user confirms, gResultLesson is loaded as editor lesson #3 and shown in the viewer.
// There is one data series per result output column, i.e. gInCount * gOutCount series.
// If there is more than one, the viewer cycles through them every
// SLEEP_TIME_FOR_RESULT_ANIMATION ms in an endless loop that the user has to stop by
// aborting the script.
void PresentResultLessonInMemBrain()
{
	// Number of data series to show. The animation depends on this count, not on the input count
	// alone: a single input with several outputs also yields several series.
	const uint seriesCount = gInCount * gOutCount;

	string message = "Shall MemBrain visualize the results in the Pattern Error Viewer?";
	
	if (seriesCount > 1)
	{
		message += "\r\n"
				   "If you select 'Yes' then abort the script later on manually in order to\r\n"
				   "stop the animation.";
	}

	if (MessageBox(message, MB_YESNO) != IDYES)
	{
		return;
	}

	SetLessonCount(3);
	SelectLesson(3);
	LoadLesson(gResultLesson);

	ShowErrorViewer(true);
	// NOTE(review): switching to lesson 2 and back to 3 looks like a way to make the viewer
	// pick up the newly loaded lesson; kept as is - confirm whether it is still needed.
	SelectLesson(2);
	SelectLesson(3);
	SelectPatternErrorViewerData(1);
	PatternErrorViewerFitYScale();
	SelectPatternErrorViewerLesson(TRAIN_LESSON);

	if (seriesCount > 1)
	{
		// Series are selected 1-based; wrap from the last series back to the first
		uint series = 1;
		while (true)
		{
			Sleep(SLEEP_TIME_FOR_RESULT_ANIMATION);
			series = (series % seriesCount) + 1;
			SelectPatternErrorViewerData(series);
		}
	}
}
	
	


// Show <message> in a message box, prefixed by a generic error headline,
// and then terminate the script via AbortScript().
void AbortWithMessage(string message)
{
	string fullText = "Script aborts due to error!\r\n\r\n" + message;
	MessageBox(fullText);
	AbortScript();
}