% Adding the path of functions and loading data files
% Paths are resolved against this script's folder so the script also works when
% MATLAB's current folder is elsewhere. When run interactively (e.g. evaluating
% a selection), mfilename is empty and the paths stay relative to the current
% folder, exactly as before.
root_dir = fileparts(mfilename('fullpath'));
if ispc
    % libsvm ships precompiled Windows MEX files in its windows/ folder
    addpath(fullfile(root_dir, 'libsvm-3.21', 'windows'));
else
    % on other platforms the MEX files are built in place by running make in libsvm's matlab/ folder
    addpath(fullfile(root_dir, 'libsvm-3.21', 'matlab'));
end
addpath(genpath(fullfile(root_dir, 'vlfeat-0.9.20', 'toolbox', 'mex')));
addpath(fullfile(root_dir, 'cifar10'));
tr_meta  = load(fullfile(root_dir, 'cifar10', 'batches.meta.mat'));
tr_set   = load(fullfile(root_dir, 'cifar10', 'data_batch_1.mat'));
test_set = load(fullfile(root_dir, 'cifar10', 'test_batch.mat'));

% Prepare training and testing sets
% Every sample_step-th image of each batch is kept. Each raw row is assumed to
% hold 1024 red, then 1024 green, then 1024 blue values (CIFAR-10 layout), so
% reshaping to N x 1024 x 3 lets rgb2gray convert every row to 1024 gray pixels.
% N is inferred ([]) instead of hard-coding 200, so changing the batch or the
% step no longer breaks the reshape.
sample_step  = 50;
to_gray_rows = @(raw) double(rgb2gray(reshape(raw(1 : sample_step : end, :), [], 1024, 3)));

% Prepare training set
tr_set_data   = to_gray_rows(tr_set.data);
tr_set_labels = double(tr_set.labels(1 : sample_step : end, :));

% Prepare testing set
test_set_data   = to_gray_rows(test_set.data);
test_set_labels = double(test_set.labels(1 : sample_step : end, :));

% 1. tunes the SVM hyper-parameters on a random subset of the pixels
% Only 650 of the 1024 pixel positions (feature columns) are used; this reduces
% the feature dimension, not the number of images. The generator is seeded so
% the same pixel subset is chosen on every run.
rng(0, 'twister');
perm_vector = randperm(1024, 650);
[ C, g ] = gridsearch(tr_set_labels, tr_set_data(:, perm_vector));
% 2. trains an SVM model on the basis of images (pixel intensities)
% -s 0: C-SVC, -t 0: linear kernel (libsvm ignores -g for it), -q: quiet,
% -h 1: shrinking, -b 1: probability estimates.
% %.17g keeps C and g at full precision; num2str would round them to ~5 digits.
svm_options = sprintf('-s 0 -t 0 -q -h 1 -b 1 -c %.17g -g %.17g', C, g);
final_model = svmtrain(tr_set_labels, tr_set_data(:, perm_vector), svm_options);
% 3. tests the previous SVM model on the test images (same pixel subset)
[ ~ , acc, ~ ] = svmpredict(test_set_labels, test_set_data(:, perm_vector), final_model, '-b 1 -q');
% acc(1) is the classification accuracy in percent
fprintf('The accuracy of the basic svm model on the test set is %4.1f%%\n', acc(1));

% 4. extracts the HOG-features from the train and test images
% Each row of tr_set_data / test_set_data is a 32x32 gray image whose 1024
% pixels are stored row by row (CIFAR-10 order). MATLAB's reshape fills
% column by column, so the result is transposed to get the upright image.
% Intensities stay in 0..255: im2single does not rescale double input, so
% single() gives the same values with a clearer intent.
row_to_image = @(row) single(reshape(row, 32, 32).');
% HOG descriptors at three cell sizes, flattened column-major and concatenated
% in the order 32, 16, 8 -> one feature row per image.
hog_cell_sizes = [32 16 8];
hog_at_cell    = @(im, cs) reshape(vl_hog(im, cs), 1, []);
hog_features   = @(im) cell2mat(arrayfun(@(cs) hog_at_cell(im, cs), hog_cell_sizes, ...
                                         'UniformOutput', false));
% Derive the feature width from vl_hog itself instead of hard-coding (1+4+16)*31
n_hog = numel(hog_features(zeros(32, 32, 'single')));

tr_hog = zeros(size(tr_set_data, 1), n_hog);
parfor i = 1:size(tr_set_data, 1)
    tr_hog(i, :) = hog_features(row_to_image(tr_set_data(i, :)));
end

test_hog = zeros(size(test_set_data, 1), n_hog);
parfor i = 1:size(test_set_data, 1)
    test_hog(i, :) = hog_features(row_to_image(test_set_data(i, :)));
end

% 5. trains another SVM model on the basis of the prev. extracted HOG-features
% The hyper-parameters are re-tuned on the HOG features. The C found for raw
% pixel intensities (0..255, 650 columns) does not carry over to block-normalized
% HOG descriptors, which live on a different scale and in a different space.
[ C_hog, g_hog ] = gridsearch(tr_set_labels, tr_hog);
% Same libsvm options as the pixel model; %.17g avoids num2str's rounding.
svm_options = sprintf('-s 0 -t 0 -q -h 1 -b 1 -c %.17g -g %.17g', C_hog, g_hog);
svm_model_hog = svmtrain( tr_set_labels, tr_hog, svm_options );
[ ~ , acc, ~ ] = svmpredict(test_set_labels, test_hog, svm_model_hog, '-b 1 -q');
% acc(1) is the classification accuracy in percent
fprintf('The accuracy of the svm model of hog features on the test set is %4.1f%%\n', acc(1));