/** file:        hikmeans.c
 ** description: MEX hierarchical ikmeans.
 ** author:      Brian Fulkerson
 **/

/*
Copyright (C) 2014 Andrea Vedaldi.
Copyright (C) 2007-12 Andrea Vedaldi and Brian Fulkerson.
All rights reserved.

This file is part of the VLFeat library and is made available under
the terms of the BSD license (see the COPYING file).
*/

#include<mexutils.h>

#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<string.h>
#include<assert.h>

#include <vl/hikmeans.h>
#include <vl/generic.h>

/* Number of elements of the array `names`; the argument must be a true
 * array (not a pointer), e.g. a local table of field-name strings. */
#define NFIELDS(names) (sizeof (names) / sizeof ((names)[0]))

/* Codes identifying the optional arguments of the MEX driver. They are
 * stored in the option table and used as `case` labels when the
 * options are parsed. */
enum {
  opt_max_niters = 0,   /* "MaxIters" */
  opt_method     = 1,   /* "Method"   */
  opt_verbose    = 2    /* "Verbose"  */
} ;

/* Table of the optional arguments accepted by the MEX driver, passed to
 * vlmxNextOption(). Each row is {name, flag, code}; the all-zero row
 * terminates the table.
 * NOTE(review): the middle column looks like "option takes a value"
 * (1 for MaxIters/Method, 0 for Verbose) -- confirm against the
 * vlmxOption declaration in mexutils.h. */
vlmxOption  options [] = {
  { "MaxIters", 1, opt_max_niters },
  { "Method",   1, opt_method     },
  { "Verbose",  0, opt_verbose    },
  { 0,          0, 0              }
} ;

/** ------------------------------------------------------------------
 ** @internal
 ** @brief Copy HIKM tree node to a MATLAB structure
 **
 ** @param mnode MATLAB struct array that receives the node; it must
 **              have the fields "centers" and "sub".
 ** @param i     index of the element of @a mnode to fill.
 ** @param node  HIKM node to copy.
 **
 ** The centers of the node's IKM filter are copied into a new INT32
 ** matrix with one column per center, stored in field "centers". If
 ** the node has children, a 1 x K struct array (fields "centers" and
 ** "sub") is created, filled recursively with one element per child,
 ** and stored in field "sub"; otherwise "sub" is left unset.
 **/

static void
xcreate (mxArray *mnode, int i, VlHIKMNode *node)
{
  int node_K = vl_ikm_get_K (node->filter) ;
  int M = vl_ikm_get_ndims (node->filter) ;
  vl_ikmacc_t const *centers = vl_ikm_get_centers (node->filter) ;

  mxArray *mcenters ;

  mcenters = mxCreateNumericMatrix (M, node_K, mxINT32_CLASS, mxREAL);
  /* The destination is an INT32 array, so its buffer must be obtained
   * with mxGetData(): mxGetPr() is meant for double arrays only.
   * NOTE(review): the raw copy assumes vl_ikmacc_t is a 32-bit
   * integer, matching mxINT32_CLASS -- confirm in vl/ikmeans.h. */
  memcpy (mxGetData (mcenters), centers, sizeof(*centers) * M * node_K) ;
  mxSetField (mnode, i, "centers", mcenters) ;

  if (node->children) {
    mxArray * msub ;
    const char * field_names[] = {"centers", "sub" } ;
    mwSize dims [2] ;
    int k ;

    dims[0] = 1 ;
    dims[1] = node_K ;

    /* One struct element per child; field count derived from the
     * table above rather than hard-coded. */
    msub = mxCreateStructArray
      (2, dims, NFIELDS(field_names), field_names) ;

    for (k = 0 ; k < node_K ; ++k) {
      xcreate (msub, k, node -> children [k]) ;
    }

    mxSetField (mnode, i, "sub", msub) ;
  }
}

/** ------------------------------------------------------------------
 ** @internal
 ** @brief Copy HIKM tree to a MATLAB structure
 **
 ** @param tree HIKM tree to convert.
 ** @return a new 1 x 1 MATLAB struct with fields "K", "depth",
 **         "centers" and "sub".
 **
 ** "K" and "depth" are stored as double scalars. "centers" and "sub"
 ** are filled from the root node by xcreate(); they are left unset
 ** when the tree has no root.
 **/

mxArray *
hikm_to_matlab (VlHIKMTree * tree)
{
  vl_size num_clusters = vl_hikm_get_K (tree) ;
  vl_size num_levels = vl_hikm_get_depth (tree) ;
  const char *tree_fields[] = {"K", "depth", "centers", "sub"} ;
  mwSize  scalar_dims [2] = {1, 1} ;
  mxArray *result ;

  /* A single struct element holds the whole tree */
  result = mxCreateStructArray
    (2, scalar_dims, NFIELDS(tree_fields), tree_fields) ;

  /* Tree parameters */
  mxSetField (result, 0, "K", mxCreateDoubleScalar (num_clusters)) ;
  mxSetField (result, 0, "depth", mxCreateDoubleScalar (num_levels)) ;

  /* Node hierarchy, if any */
  if (tree->root) {
    xcreate (result, 0, tree->root) ;
  }

  return result ;
}

/** ------------------------------------------------------------------
 ** @internal
 ** @brief MEX driver
 **/
void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
{
  enum {IN_DATA = 0, IN_K, IN_NLEAVES, IN_END} ;
  enum {OUT_TREE = 0, OUT_ASGN} ;
  VlHIKMTree* tree ;
  int nleaves = 1 ;
  int method_type = VL_IKM_LLOYD ;
  int max_niters = 200 ;
  int verb = 0 ;
  vl_uint8 *data ;
  vl_size M, N, K = 2 ;
  vl_size depth = 0 ;

  int opt ;
  int next = IN_END ;
  mxArray const *optarg ;

  VL_USE_MATLAB_ENV ;

  /* -----------------------------------------------------------------
   *                                               Check the arguments
   * -------------------------------------------------------------- */

  if (nin < 3) {
      mexErrMsgTxt ("At least three arguments required.");
  } else if (nout > 2) {
    mexErrMsgTxt ("Too many output arguments.");
  }

  if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS) {
    mexErrMsgTxt ("DATA must be of class UINT8.");
  }

  if (! vlmxIsPlainScalar (in[IN_NLEAVES])          ||
      (nleaves = (int) *mxGetPr (in[IN_NLEAVES])) < 1) {
    mexErrMsgTxt ("NLEAVES must be a scalar not smaller than 2.") ;
  }

  M = mxGetM (in[IN_DATA]);   /* n of components */
  N = mxGetN (in[IN_DATA]);   /* n of elements */

  if (! vlmxIsPlainScalar (in[IN_K])    ||
      (K = (int) *mxGetPr (in[IN_K])) > N) {
    mexErrMsgTxt ("Cannot have more clusters than data.") ;
  }

  data = (vl_uint8 *) mxGetPr (in[IN_DATA]) ;

  while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
    char buf [1024] ;

    switch (opt) {

    case opt_verbose :
      ++ verb ;
      break ;

    case opt_max_niters :
      if (!vlmxIsPlainScalar(optarg) ||
          (max_niters = (int) *mxGetPr(optarg)) < 1) {
        mexErrMsgTxt("MaxNiters must be not smaller than 1.") ;
      }
      break ;

    case opt_method :
      if (!vlmxIsString (optarg, -1)) {
        mexErrMsgTxt("'Method' must be a string.") ;
      }
      if (mxGetString (optarg, buf, sizeof(buf))) {
        mexErrMsgTxt("Option argument too long.") ;
      }
      if (strcmp("lloyd", buf) == 0) {
        method_type = VL_IKM_LLOYD ;
      } else if (strcmp("elkan", buf) == 0) {
        method_type = VL_IKM_ELKAN ;
      } else {
        mexErrMsgTxt("Unknown cost type.") ;
      }
      break ;

    default :
      abort() ; break ;
    }
  }

  /* -----------------------------------------------------------------
   *                                                        Do the job
   * -------------------------------------------------------------- */

  depth = VL_MAX(1, ceil (log (nleaves) / log(K))) ;
  tree  = vl_hikm_new  (method_type) ;

  if (verb) {
    mexPrintf("hikmeans: # dims: %d\n", M) ;
    mexPrintf("hikmeans: # data: %d\n", N) ;
    mexPrintf("hikmeans: K: %d\n", K) ;
    mexPrintf("hikmeans: depth: %d\n", depth) ;
  }

  vl_hikm_set_verbosity (tree, verb) ;
  vl_hikm_init (tree, M, K, depth) ;
  vl_hikm_train (tree, data, N) ;

  out[OUT_TREE] = hikm_to_matlab (tree) ;

  if (nout > 1) {
    vl_uint *asgn ;
    vl_uindex j ;
    out [OUT_ASGN] = mxCreateNumericMatrix
      (vl_hikm_get_depth (tree), N, mxUINT32_CLASS, mxREAL) ;
    asgn = mxGetData(out[OUT_ASGN]) ;
    vl_hikm_push (tree, asgn, data, N) ;
    for (j = 0 ; j < N*depth ; ++ j) asgn [j] ++ ;
  }

  if (verb) {
    mexPrintf("hikmeans: done.\n") ;
  }

  /* vl_hikm_delete (tree) ; */
}
