%Set the output directory, where the trained model will be saved.
%fullfile/filesep insert the platform's path separator, so the script also
%runs on Linux/macOS (a hard-coded '\' only produces valid paths on Windows).
work_path = fileparts(mfilename('fullpath')); %folder containing this script
output_dir = [fullfile(work_path, 'Output') filesep]; %path to the output folder (with trailing separator)

%Load the CIFAR-10 dataset.
%Paths are built with fullfile/filesep so they are valid on every platform;
%the trailing separator is kept in case cifar_loader appends file names
%directly to cifar_dir.
%cifar_dir = [fullfile(work_path, 'Input', 'cifar-10-batches-mat') filesep]; %alternative location of the CIFAR-10 directory
cifar_dir = [fullfile(work_path, 'matconvnet-1.0-beta23', 'data', 'cifar', 'cifar-10-batches-mat') filesep];
%Set how many training batch files will be loaded from the train set.
%The value must be an integer between 1 and 5 (meaning 10 to 50 thousand samples).
nr_datafile_to_load = 1;
%Fail early with a clear message instead of deep inside the loader.
validateattributes(nr_datafile_to_load, {'numeric'}, ...
                   {'scalar', 'integer', '>=', 1, '<=', 5}, '', 'nr_datafile_to_load');
%Use the cifar loader function to load the training, validation and test data
cifar_db = cifar_loader(cifar_dir, nr_datafile_to_load);

%Build the network: the init function defines the architecture and sets the
%initial weights.
%For the exercises, create one initializer file per exercise and call the
%matching one here instead:
%   init_network_ex1  - exercise 1
%   init_network_ex2  - exercise 2
%   init_network_ex3  - exercise 3 (optional)
net = init_network_ex1;


%Train the network.
%cnn_train (presumably MatConvNet's trainer, given the bundled matconvnet
%folder) trains the network initialized above using only the training
%images. The validation images are used to check for overfitting; the test
%samples are not used here.
%get_batch generates the training batches: for a given batch it returns the
%training images and their labels.
%The training options are not set in this script; they are taken from
%net.meta.trainOpts, which the init_network_* function is expected to fill in.
st = tic;
[net, stats] = cnn_train(net, cifar_db, @get_batch, ...
                        'expDir', output_dir, net.meta.trainOpts) ;
%Report the wall-clock training time.
toc(st);

%Evaluate the trained network on the held-out test set
accuracy = evaluation(net, cifar_db);