/* Jim Susinno 2002
 * Multi-layer Neural Network
 *
 *
 *
 *
 */
#ifndef NEURALNET_H
#define NEURALNET_H

#include <math.h> /* log10 for graph, pow(double,double) for sigmoid */
#include "matrix.h" /* for matrix operations */
#include "iopair.h" /* stores list of training examples */

/*
 * Weight seeding method identifiers.
 * Presumably these are the values accepted by NeuralNet::seedWeights(int),
 * selecting how the weight matrices are initialised (random, identity,
 * all zeros, all ones) -- TODO confirm against the implementation.
 */
#define NEURALNET_SEED_RANDOM   1
#define NEURALNET_SEED_IDENTITY 2
#define NEURALNET_SEED_ZEROS    3
#define NEURALNET_SEED_ONES     4


/* v0.1.6: length of error graph (x) */
//#define ERROR_HISTORY 100




/*
 * NeuralNet: multi-layer neural network (see file header).
 *
 * This header only declares the interface; every member function is defined
 * elsewhere. The notes below restate the original author's comments, or are
 * explicitly marked as assumptions to be confirmed.
 *
 * All data members are public.
 *
 * NOTE(review): the class declares a destructor and holds raw pointer
 * members, but declares no copy constructor or copy-assignment operator.
 * If the destructor releases those pointers, copying a NeuralNet by value
 * would be unsafe -- confirm against the implementation.
 */
class NeuralNet {
public:
	NeuralNet();
	NeuralNet(int); /* takes # of layers */

	/* not implemented yet (original author's note; presumably covers the
	 * four size-taking constructors below -- verify in the implementation) */
	NeuralNet(int,int); /* for 1 layer perceptron case */
	NeuralNet(int,int,int); /* for 2 layer case */
	NeuralNet(int,int,int,int); /* for 3 layer case */
	NeuralNet(int,int,int,int,int); /* for 4 layer case */

	~NeuralNet();

	/* error history management (see the "error history graph" members below) */
	void initHistory(); /* convenience 0.1.6 */
	void changeErrorHistorySize(int); /* v0.1.6 */
	void changeErrorScaleMode(); /* v0.1.6 */
	
	/* network topology */
	/* NOTE(review): rename takes a non-const char*; whether the string is
	 * copied or the pointer is stored is not visible from this header. */
	void rename(char*); /* change the name value */
	void setLayers(int); /* set number of layers in network */
	void changeLayers(int); /* v0.2.7 - try to keep weights as much intact as possible */
	void setInputNodes(int); /* set number of inputs */
	void setOutputNodes(int); /* set number of outputs */
	void setNodesInLayer(int,int); /* set number of nodes in a layer */
	/* convenience */
	int Inputs(); /* return the number of input nodes */
	int Outputs(); /* return the number of output nodes */

	/* weight initialisation */
	/* presumably takes one of the NEURALNET_SEED_* constants -- TODO confirm */
	void seedWeights(int); /* use a certain method */
	void seedWeightsRandomly(float); /* v0.2.4: max val */
	void jitter(float); /* v0.1.9: shake the matrices up */

	/* evaluation */
	float sigmoid(float x); /* squashing function */
	float d_sigmoid(float x); /* derivative of squashing function */
	/* NOTE(review): ownership and length of the returned array are not
	 * visible from this header -- confirm who frees it, if anyone. */
	float* calculateOutput(float*); /* run a vector through the net */


	
	/* functions to manipulate training examples */
	/* the two float* arguments are presumably (input vector, output vector),
	 * matching the IOpair lists below -- TODO confirm order */
	void addTrainingExample(float*,float*); /* add a pair to the list */
	void addTestExample(float*,float*); /* v0.2.1: add a pair to the validation list */
	void nextExample(); /* learning utility func */
	void gradientDescent(); /* learn on one training example(at a time) */
	void calculateError(); /* iterate over all examples and sum mse */
	void calculateTestError(); /* v0.2.1: iterate over all test examples and sum mse */
	/* meaning of the (float, int) arguments is not documented here -- see
	 * the implementation before calling */
	void updateWeights(float,int); /* v0.2.2: momentum vs. learning */
	void clearDeltaWs();

	void updateErrorHistory(); /* v0.2.1: now a learning function */

	/* graphical visualizing/debugging methods */
	/* the int arguments of the draw* overloads are undocumented in this
	 * header except where noted; consult the implementation */
	void draw3DWeightMatrices(); /* display in OpenGL */
	void draw3DWeightMatrices(int,int,int); /* v0.1.7: wrap - 3d display in OpenGL */
	void draw3DWeightMatrices(int,int,int,int,int,int,int,int); /* v0.2.8: wrap & more - 3d display in OpenGL */
	void drawRasterMatrices(int,int); /* display in glut */
	void drawTrainingExamples(int,int);
	void drawNodeValues(int,int,int); /* v0.1.6: takes window height last */
	void drawErrorHistoryGraph(); /* v0.1.6 */
	void drawSigmoidGraph(); /* v0.1.7 */


	/* v0.2.1: load/save functions
	 *
	 * The char* argument is presumably a file path -- TODO confirm.
	 */
	void loadFromFile(char*);
	void saveToFile(char*);

	void dump();

	/* data members */
	char *name;
	unsigned int layers;
	Matrix **matrices; /* #layers of pointers to them */
	Matrix **deltaWs; /* #layers of pointers to them */
	Matrix **last_deltaWs; /* v0.2.1 - #layers of pointers to them */
	//Matrix **biases; /* #layers of pointers to them */
	float **node_layers; /* stores node values in evaluation */

	/* example lists (IOpair is declared in iopair.h) */
	IOpair *training_examples;
	IOpair *test_examples; /* v0.2.1 - test validation set */
	IOpair *current_example;
	unsigned int training_count;
	unsigned int training_count_neg; /* v0.2.5 */
	unsigned int test_count;
	unsigned int test_count_neg; /* v0.2.5 */
	float mean_squared_error;
	float test_squared_error; /* v0.2.1 - error over test set */
	float learning_rate; /* eta */
	float momentum_rate; /* alpha */
	unsigned int batch_mode; /* v0.2.1- standard vs. stochastic */
	unsigned int use_bias; /* v0.2.6: switch */

	float steepness; /* control sigmoid threshold */
	unsigned int printout; /* debug on/off */

	float stop_condition; /* v0.2.7 */
	float annealing_rate; /* v0.2.7 */

	/* v0.1.6: error history graph
	 * These are for the error graph
	 *  - list of former values for graph
	 *  - max and min for scaling the graph
	 *  - count of values
	 */
	float *last_few_errors;
	float *last_few_test_errors;
	int error_history; /* how far back in the past we keep samples of error */
	int error_count;
	int iteration_count;
	float maximum_error;
	float minimum_error;
	int error_history_full;

	int error_scale_mode;
	int error_history2; // for just displaying recent ones??
};

#endif
