function imdb = cifar_loader(cifar_dir, nr_datafile_to_load, tr_val_ratio)
%CIFAR_LOADER Load the CIFAR-10 dataset (MATLAB .mat version) into an imdb struct.
%   IMDB = CIFAR_LOADER(CIFAR_DIR, NR_DATAFILE_TO_LOAD, TR_VAL_RATIO) reads
%   data_batch_1.mat ... data_batch_<NR_DATAFILE_TO_LOAD>.mat, test_batch.mat
%   and batches.meta.mat from CIFAR_DIR. It returns a struct whose layout is
%   compatible with the cnn training function.
%
%   Inputs:
%     cifar_dir           - directory containing the CIFAR-10 .mat files.
%     nr_datafile_to_load - (optional, default 5) number of training batch
%                           files to load. It is rounded down and clamped to
%                           [1, 5]. There are 5 batch files on disk.
%     tr_val_ratio        - (optional, default 0.9) fraction of each loaded
%                           training batch that is marked as training data.
%                           The remaining samples are marked as validation
%                           data. The value is clamped to [0, 1].
%
%   Output fields:
%     imdb.images.data   - single 4-D array [32, 32, 3, N] laid out as
%                          (row, column, channel, image index). Pixel values
%                          are divided by 255 (i.e. scaled to [0, 1] assuming
%                          uint8 source data). The mean image of the
%                          train+val samples is then subtracted.
%     imdb.images.labels - single row vector: the labels stored in the files
%                          plus 1, so indexing starts at 1.
%     imdb.images.set    - uint8 row vector: 1 = train, 2 = val, 3 = test.
%     imdb.meta.sets     - {'train', 'val', 'test'}.
%     imdb.meta.classes  - label_names read from batches.meta.mat.
%
%   The train/val assignment within each batch is a random permutation
%   (randperm), so it differs between calls unless the RNG is seeded.

%%%% Initialization %%%%
% Force an integer batch count in [1, 5]. A non-integer count would break the
% cell preallocation below.
if nargin < 2 || isempty(nr_datafile_to_load)
    nr_datafile_to_load = 5;
end
nr_datafile_to_load = min(max(floor(nr_datafile_to_load), 1), 5);

% Default: 90% of each loaded training batch is used for training and 10% for
% validation. Clamping keeps the split indices below valid.
if nargin < 3 || isempty(tr_val_ratio)
    tr_val_ratio = 0.9;
end
tr_val_ratio = min(max(tr_val_ratio, 0), 1);

% One cell per loaded training batch plus one for the test batch.
nr_parts = nr_datafile_to_load + 1;
data = cell(1, nr_parts);
labels = cell(1, nr_parts);
% Named split_ids (not "set") to avoid shadowing the built-in SET function.
split_ids = cell(1, nr_parts);

%%%% Load the Training Set %%%%
for b = 1:nr_datafile_to_load
    c_filename = fullfile(cifar_dir, ['data_batch_' num2str(b) '.mat']);
    [data{b}, labels{b}] = load_cifar_batch(c_filename);

    % Exactly floor(n * tr_val_ratio) samples become training samples (1).
    % The remaining samples become validation samples (2). The vector is then
    % shuffled so the choice of validation samples is random.
    nr_samples = numel(labels{b});
    nr_train = floor(nr_samples * tr_val_ratio);
    split_b = ones(1, nr_samples, 'uint8');
    split_b(nr_train+1:end) = 2;
    split_ids{b} = split_b(randperm(nr_samples));
end

%%%% Load the Test Set %%%%
[data{nr_parts}, labels{nr_parts}] = ...
    load_cifar_batch(fullfile(cifar_dir, 'test_batch.mat'));
% Size the test split by the test labels themselves, not by a training batch.
split_ids{nr_parts} = 3 * ones(1, numel(labels{nr_parts}), 'uint8');

%%%% Assemble %%%%
% Concatenate the separate parts into one tensor / row vectors.
data = single(cat(4, data{:})) / 255;
labels = single(cat(2, labels{:}));
split_ids = cat(2, split_ids{:});

% Subtract the mean image of all training-batch samples (train and val
% splits, i.e. everything except the test set) from every image.
dataMean = mean(data(:,:,:,split_ids < 3), 4);
data = bsxfun(@minus, data, dataMean);

% Put the data samples, the labels and the set ids into one struct.
imdb.images.data = data;        % [rows, cols, channels, image index]
imdb.images.labels = labels;    % ground-truth labels, 1-based
imdb.images.set = split_ids;    % 1 = train, 2 = val, 3 = test
imdb.meta.sets = {'train', 'val', 'test'};

% Add the class names as meta information to the database structure.
lobj = load(fullfile(cifar_dir, 'batches.meta.mat'));
imdb.meta.classes = lobj.label_names;


function [images, labels] = load_cifar_batch(filename)
%LOAD_CIFAR_BATCH Read one CIFAR-10 .mat batch file into image format.
%   [IMAGES, LABELS] = LOAD_CIFAR_BATCH(FILENAME) loads FILENAME and expects
%   the fields 'data' and 'labels'. Each row of 'data' is one image stored as
%   3 * 32 * 32 values. The rows are reshaped and permuted into a
%   [32, 32, 3, N] array laid out as (row, column, channel, image index).
%   LABELS is 'labels' transposed to a row vector, with 1 added so that
%   indexing starts at 1.
lobj = load(filename);
images = permute(reshape(lobj.data', 32, 32, 3, []), [2 1 3 4]);
labels = lobj.labels' + 1;


