 /*
  * Khoros: $Id: llrftrain.c,v 1.1 1991/05/10 15:41:54 khoros Exp $
  */

#if !defined(lint) && !defined(SABER)
static char rcsid[] = "Khoros: $Id: llrftrain.c,v 1.1 1991/05/10 15:41:54 khoros Exp $";
#endif

 /*
  * $Log: llrftrain.c,v $
 * Revision 1.1  1991/05/10  15:41:54  khoros
 * Initial revision
 *
  */ 

/*
 *----------------------------------------------------------------------
 *
 * Copyright 1991, University of New Mexico.  All rights reserved.
 * Permission to copy and modify this software and its documen-
 * tation only for internal use in your organization is hereby
 * granted, provided that this notice is retained thereon and
 * on all copies.  UNM makes no representations as too the sui-
 * tability and operability of this software for any purpose.
 * It is provided "as is" without express or implied warranty.
 * 
 * UNM DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
 * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FIT-
 * NESS.  IN NO EVENT SHALL UNM BE LIABLE FOR ANY SPECIAL,
 * INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY OTHER DAMAGES WHAT-
 * SOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
 * IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PER-
 * FORMANCE OF THIS SOFTWARE.
 * 
 * No other rights, including for example, the right to redis-
 * tribute this software and its documentation or the right to
 * prepare derivative works, are granted unless specifically
 * provided in a separate license agreement.
 *---------------------------------------------------------------------
 */

#include "unmcopyright.h"        /* Copyright 1991 by UNM */

/*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
 >>>>
 >>>>         File Name: llrftrain.c
 >>>>
 >>>>      Program Name: lrftrain
 >>>>
 >>>> Date Last Updated: Tue Apr  9 08:34:35 1991 
 >>>>
 >>>>          Routines: llrftrain - the library call for lrftrain
 >>>>
 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>   <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<*/


#include "vinclude.h"


/* -library_includes */
/* Upper bound on the smoothing factor alpha that llrftrain uses when it
 * updates its running MSE estimate:
 *     mse = alpha * mse + (1 - alpha) * (latest error term)
 * alpha grows as (iterations - 1) / iterations and is capped here, so the
 * latest error term always keeps at least 0.5% of the weight. */
#define ALPHA_MAX 0.995
/* -library_includes_end */


/****************************************************************
*
* Routine Name: llrftrain - library call for lrftrain
*
* Purpose:
*    
*    Trains on an image for the weights  used  in  the  Localized
*    Receptive Field classifier.
*    
*    
* Input:
*    
*    image          the input image used for training.
*    
*    cc_img         the cluster center image  specifying  the  cluster
*                   centers.
*    
*    var_img        the cluster variance image  specifying  the  vari-
*                   ances.
*    
*    cn_img         the  cluster   number   image   specifying   which
*                   vector/pixel belongs to which cluster.
*    
*    converge       the convergence parameter.
*    
*    meu            the weight update parameter.
*    
*    border         the border width in pixels of the input image.
*    
*    max_iter       the maximum number of  iterations  until  termina-
*                   tion.
*    
*    prt_mse        the iteration interval for printing the MSE to the
*                   stats file.
*    
*    delta_mse      the minimum change in the MSE between  iterations,
*                   for termination.
*    
*    
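* Output:
*    
*    wt_img         the resulting weight image after training.
*    
*    printdev       the open stream to which the training statistics
*                   are written.
*    
* Returns:
*    
*    TRUE on success, or 0 on failure (an error message is written to
*    stderr).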
*    
*    This routine was written with the help of and ideas from Dr.  Don
*    Hush, University of New Mexico, Dept. of EECE.
*    
*    
*
* Written By: Tom Sauer and Charlie Gage
*    
*    
****************************************************************/


/* -library_def */
int
llrftrain (image, cc_img, var_img, cn_img, wt_img, converge, meu, border, max_iter, prt_mse, delta_mse, printdev)
struct  xvimage  *image,         /* input image */
                 *cc_img,        /* cluster center image */
                 *var_img,       /* cluster variance image */
                 *cn_img,        /* cluster number image */
                 **wt_img;       /* output LRF weight image */

float   converge,                /* convergence parameter */
        meu,                     /* weight update parameter */
        delta_mse;               /* MIN delta MSE value */

int     border,                  /* input image border width */
        max_iter,                /* MAX iterations to perform */
        prt_mse;                 /* interval to print MSE to stats file */

FILE   *printdev;

/* -library_def_end */

/* -library_code */
{
   /*
    * Trains the output-layer weights of a Localized Receptive Field
    * classifier with the LMS rule.
    *
    * Returns TRUE on success and 0 on failure.  Every failure path writes
    * a message to stderr and releases the scratch arrays allocated here.
    * On success *wt_img receives a new float image built by unload_vector()
    * from num_classes weight vectors of length nodes + 1.  Element 0 of
    * each vector is that class's bias weight.
    *
    * The images are validated before any of their contents are used as
    * array indices, so malformed input fails cleanly instead of reading
    * or writing out of bounds.
    *
    * NOTE(review): the arrays returned by load_vector() are not released
    * here (as before), because their allocation layout is not visible
    * from this file.  Confirm how they should be freed.
    */
   struct  xvimage  *weight_image, *createimage();
   int nodes = 0,                /* number of nodes in the LRF layer */
       rfc_nvects,               /* number of receptive field centers */
       rfc_dim,                  /* dimension of receptive field centers */
       var_nvects,               /* number of variance vectors */
       var_dim,                  /* dimension of variance vectors */
       input_nvects,             /* number of input vectors */
       input_dim,                /* dimension of input vectors */
       num_centers,              /* entries per band of the cc image */
       num_classes = 0,          /* number of desired output classes */
       num_patterns,             /* training pixels visited per iteration */
       iterations,               /* number of iterations */
       patrn,                    /* index of the current input pattern */
       prt_flag,                 /* flag indicating print MSE info */
       status = 0,               /* return value, set to TRUE on success */
       nr, 
       nc, 
       i, j, k, x, z, class, node, d, center;
   int *cn_ptr;                  /* pointer to cluster number imagedata */
   float **rf_centers = NULL,    /* receptive field centers */
         **cc_var = NULL,        /* cluster center variances */
         **input_data = NULL,    /* input image data, one vector per pixel */
         *class_ptr,             /* class band of the cluster center image */
         **R = NULL,             /* R[node][pattern] response values */
         **W = NULL,             /* W[class][0] bias, W[class][1..nodes] */
         *Y = NULL,              /* output node values, one per class */
         *error = NULL,          /* desired - Y, one per class */
         *desired = NULL,        /* desired outputs, one per class */
         *image_ptr;
   float a, dis, sum, total_error, urng();
   char  **load_vector(), *unload_vector(), *weight_data;
   double alpha,                 /* MSE update parameters */
          mse, 
          prev_mse;

   prt_flag = (prt_mse > 0) ? 1 : 0;

   nr = image->col_size;
   nc = image->row_size;

      /* seed value for the running MSE estimate */
   mse = 0.5;
   prev_mse = 0.0;
      /* defined even though at least one pattern is guaranteed below */
   total_error = 0.0;

   /*--------------------------------------------------------------------*
    * VALIDATE the image geometry before any of it is used as an index.
    * The cluster center image needs at least one center, plus the class
    * band that is stored as its last data band.
    *--------------------------------------------------------------------*/ 
   if (cc_img->row_size < 1 || cc_img->col_size < 1 ||
       cc_img->num_data_bands < 2)
   {
      (void) fprintf (stderr,
                "lrftrain: cluster center image must hold at least one center and a class band\n");
      goto cleanup;
   }
   if (border < 0 || nr - 2 * border < 1 || nc - 2 * border < 1)
   {
      (void) fprintf (stderr,
                "lrftrain: border width %d leaves no training pixels\n", border);
      goto cleanup;
   }
      /* cn_img is indexed with the input image's row length (x*nc+z) */
   if ((long) cn_img->row_size * (long) cn_img->col_size < (long) nr * (long) nc)
   {
      (void) fprintf (stderr,
                "lrftrain: cluster number image is smaller than the input image\n");
      goto cleanup;
   }
   num_patterns = (nr - 2 * border) * (nc - 2 * border);

   /*--------------------------------------------------------------------*
    * INITIALIZE the NUMBER of NODES in the LRF.  This is simply the number
    * of cluster centers previously determined.
    *--------------------------------------------------------------------*/ 
   nodes = cc_img->row_size;

   /*--------------------------------------------------------------------*
    * Determine the NUMBER of DESIRED OUTPUT CLASSES from the class band
    * (the last data band of the cluster center image).  Class labels are
    * used as array indices, so negative labels are rejected.
    *--------------------------------------------------------------------*/ 
   num_centers = (int) (cc_img->row_size * cc_img->col_size);
   image_ptr = (float *) cc_img->imagedata;
   class_ptr = (float *) &image_ptr[cc_img->row_size * cc_img->col_size 
                                      * (cc_img->num_data_bands-1)];
   for (i = 0; i < num_centers; i++)
   {
      if ((int) class_ptr[i] < 0)
      {
         (void) fprintf (stderr,
                   "lrftrain: negative class label %g in cluster center image\n",
                   (double) class_ptr[i]);
         goto cleanup;
      }
      if (i == 0 || class_ptr[i] > num_classes) 
         num_classes = (int) class_ptr[i];
   }
   num_classes++;     /* add 1, since we start with zero */

   /*--------------------------------------------------------------------*
    * Every cluster number read during training must index the class band.
    *--------------------------------------------------------------------*/ 
   cn_ptr = (int *) cn_img->imagedata;
   for (x = border; x < nr - border; x++) 
   {
      for (z = border; z < nc - border; z++) 
      {
         center = cn_ptr[x*nc+z];
         if (center < 0 || center >= num_centers)
         {
            (void) fprintf (stderr,
                      "lrftrain: cluster number %d at pixel (%d,%d) is out of range\n",
                      center, x, z);
            goto cleanup;
         }
      }
   }

   /*--------------------------------------------------------------------*
    * LOAD the RECEPTIVE FIELD CENTERS from the cluster center image,
    * but trick load_vector so that it does not see the class data
    * band attached to the end of the cluster center image.
    * Then LOAD the CLUSTER CENTER VARIANCES and the INPUT DATA.
    *--------------------------------------------------------------------*/ 
   cc_img->num_data_bands -= 1;
   rf_centers = (float **) load_vector(cc_img, 0, &rfc_nvects, &rfc_dim);
   cc_img->num_data_bands += 1;

   cc_var = (float **) load_vector(var_img, 0, &var_nvects, &var_dim);

   input_data = (float **) load_vector(image, border, &input_nvects, &input_dim);

   if (rf_centers == NULL || cc_var == NULL || input_data == NULL)
   {
      (void) fprintf (stderr,
                "lrftrain: unable to load the image data into vectors\n");
      goto cleanup;
   }

      /* the response loop reads input_dim components of every center and
       * variance vector, and R is indexed by the training pixel number */
   if (rfc_nvects < nodes || var_nvects < nodes ||
       rfc_dim < input_dim || var_dim < input_dim ||
       input_nvects < num_patterns)
   {
      (void) fprintf (stderr,
                "lrftrain: input, cluster center and variance images are inconsistent\n");
      goto cleanup;
   }

   /*--------------------------------------------------------------------*
    * ALLOCATE SPACE for the response functions.
    * There are N * K response functions, where N is the number of nodes
    * and K is the number of input patterns.  The row table is calloc'd
    * so that cleanup can free a partially built table.
    *--------------------------------------------------------------------*/ 
   R = (float **) calloc (nodes, sizeof (float *));
   if (R == NULL)
      goto nomem;
   for (i = 0; i < nodes; i++) 
   {
      R[i] = (float *) malloc (sizeof (float) * input_nvects);
      if (R[i] == NULL) 
         goto nomem;
   }

   /*--------------------------------------------------------------------*
    * COMPUTE the RESPONSE FUNCTIONS for each node and data input
    * pattern.  This is based on using the ratio of the Euclidean distance 
    * of an input data vector to a receptive field center to each 
    * diagonal element of the covariance matrix.
    *--------------------------------------------------------------------*/ 
   for (i = 0; i < input_nvects; i++) 
   {
      for (j = 0; j < nodes; j++) 
      {
            /* variance-scaled squared distance along each axis */
         dis = 0.0;
         for (k = 0; k < input_dim; k++)
         {
            a = input_data[i][k] - rf_centers[j][k];
            dis += (a * a) / cc_var[j][k];
         }
         R[j][i] = exp(-dis);
      }
   }

   /*--------------------------------------------------------------------*
    * ALLOCATE SPACE for the WEIGHTS:
    * num_classes rows of (nodes + 1) weights each, where the +1 is the
    * bias weight stored at index 0.
    *--------------------------------------------------------------------*/ 
   W = (float **) calloc (num_classes, sizeof (float *));
   if (W == NULL)
      goto nomem;
   for (i = 0; i < num_classes; i++) 
   {
      W[i] = (float *) malloc (sizeof (float) * (nodes + 1));
      if (W[i] == NULL) 
         goto nomem;
   }
    
   /*--------------------------------------------------------------------*
    * ALLOCATE SPACE for the OUTPUT NODE EQUATIONS Y, the error for each
    * node equation Y, and the desired outputs.
    *--------------------------------------------------------------------*/ 
   Y = (float *) malloc (sizeof (float) * num_classes);
   error = (float *) malloc (sizeof (float) * num_classes);
   desired = (float *) malloc (sizeof (float) * num_classes);
   if (Y == NULL || error == NULL || desired == NULL)
      goto nomem;
        /* END OF DATA ALLOCATION */

   (void) fprintf (printdev, "lrftrain Statistics\n");
   (void) fprintf (printdev, "===================\n");

   /*--------------------------------------------------------------------*
    *   INITIALIZE the WEIGHTS with a random number between -1 and 1
    *--------------------------------------------------------------------*/ 
   for (class = 0; class < num_classes; class++)
   {
      for (node = 0; node < nodes + 1; node++)
         W[class][node] = (urng() - 0.5) * 2.0;
   }
        
   /*--------------------------------------------------------------------*
    * TRAIN the SINGLE LAYER PERCEPTRON for the weights using the
    *   LMS algorithm 
    *--------------------------------------------------------------------*/ 
   iterations = 0;
   do
   {
      iterations++;
      patrn = 0;
      for (x = border; x < nr - border; x++) 
      {
         for (z = border; z < nc - border; z++) 
         {
               /* compute the output layer node equation */
            for (class = 0; class < num_classes; class++)
            {
               sum = 0.0;
               for (node = 1; node < nodes + 1; node++)
                  sum += R[node-1][patrn] * W[class][node];
               Y[class] = sum + W[class][0];
            }

               /* the class of the current pixel's cluster gets a desired
                * output of 1, all other classes get -1 */
            for (class = 0; class < num_classes; class++)
               desired[class] = -1.0;
            d = (int) class_ptr[cn_ptr[x*nc+z]];
            desired[d] = 1.0;

               /* compute the error terms for each class */
            total_error = 0.0;
            for (class = 0; class < num_classes; class++)
            {
               error[class] = desired[class] - Y[class];
               total_error += error[class] * error[class];
            }

               /* update the weights for the output layer */
            for (class = 0; class < num_classes; class++)
            {
               for (node = 1; node < nodes + 1; node++)
                  W[class][node] += meu * error[class] * R[node - 1][patrn];
               W[class][0] += meu * error[class];
            }
            patrn++;
         }
      }

   /*---------------------------------------------------------------------*
    * COMPUTE the MSE for the current iteration.
    *   This formula will put more weight on the previous mse 
    *   as the number of iterations increases, and less
    *   weight on the total_error for the current iteration.
    *
    *   NOTE(review): total_error here is the summed squared error of the
    *   LAST training pattern only (it is reset for every pattern), and it
    *   is squared again below.  This is kept unchanged because it drives
    *   the convergence tests; confirm whether a per-iteration sum was
    *   intended.
    *---------------------------------------------------------------------*/
      alpha = MIN(ALPHA_MAX, (double) (iterations - 1) /(double) (iterations));
      mse = alpha*mse + ((1.0 - alpha)*(double) total_error*(double) total_error);

   /*---------------------------------------------------------------------*
    * TERMINATE training when the minimum delta MSE value is reached 
    *---------------------------------------------------------------------*/
      if (fabs(mse - prev_mse) < (double)delta_mse)
      {
         (void) fprintf (printdev,"\nTraining was terminated at the following point:\n");
         (void) fprintf (printdev,"\n-----  iter = %d    mse = %.10f -----\n", iterations, mse);
         (void) fprintf (printdev,"\nbecause the MSE reached a plateau.\n");
         break;
      }
      prev_mse = mse;

   /*---------------------------------------------------------------------*
    * PRINT STATS at selected interval
    *---------------------------------------------------------------------*/
      if ((iterations == 1) || ((prt_flag == 1) && ((iterations % prt_mse) == 0)))
         (void) fprintf (printdev,"\n-----  iter = %d    mse = %.10f -----\n", iterations, mse);
   }
   while((mse > converge) && (iterations < max_iter));

   /*---------------------------------------------------------------------*
    * SAVE the WEIGHTS (W) 
    *---------------------------------------------------------------------*/
   weight_image = createimage((unsigned long) 1,       /* number of rows */
                        (unsigned long) num_classes,   /* number of columns */
                        (unsigned long) VFF_TYP_FLOAT,  /* data_storage_type */
                        (unsigned long) 1,                /* num_of_images */
                        (unsigned long) (nodes + 1),      /* num_data_bands */
                        "Weight image created by lrftrain", /* comment */
                        (unsigned long) 0,                /* map_row_size */
                        (unsigned long) 0,                /* map_col_size */
                        (unsigned long) VFF_MS_NONE,      /* map_scheme */
                        (unsigned long) VFF_MAPTYP_NONE, /* map_storage_type */
                        (unsigned long) VFF_LOC_IMPLICIT, /* location_type */
                        (unsigned long) 0);               /* location_dim */

   if (weight_image == NULL)
   {
      (void)fprintf(stderr,"lrftrain: Unable to allocate weight image!\n");
      *wt_img = NULL;
      goto cleanup;
   }

   weight_data = (char *) unload_vector(W, 0, VFF_TYP_FLOAT,
                                        num_classes, nodes + 1, 1, num_classes);
      /* W has been handed to unload_vector().  Whether it releases the
       * vector array is not visible here, so W is not freed after this
       * point (matching the previous behaviour) to avoid a double free. */
   W = NULL;
   if (weight_data == NULL)
   {
      (void)fprintf(stderr,"lrftrain: Unable to store weights in weight image!\n");
      *wt_img = NULL;
         /* NOTE(review): weight_image itself is not released on this path */
      goto cleanup;
   }
   free(weight_image->imagedata);
   weight_image->imagedata = weight_data;
   *wt_img = weight_image;

   /*---------------------------------------------------------------------*
    * WRITE STATS to file
    *---------------------------------------------------------------------*/
   (void) fprintf (printdev,
                  "\n\tTotal Number of Iterations : %d\n\n",iterations);
   (void) fprintf (printdev,
                  "\tFinal Mean Squared Error: %f\n\n",mse);
   (void) fprintf (printdev,
                  "\tConvergence Parameter (cv): %f\n\n",converge);
   (void) fprintf (printdev,
                  "\tWeight Update Parameter (meu): %f\n\n",meu);
   (void) fprintf (printdev,
                  "\tMinimum DELTA MSE value (delta): %g\n\n", delta_mse);
   (void) fprintf (printdev,
                  "\tBorder Width: %d\n\n", border);
   (void) fprintf (printdev,
                  "\tNumber of Response Nodes: %d\n\n",nodes);
   (void) fprintf (printdev,
                  "\tNumber of Output Classes: %d\n\n",num_classes);
   status = TRUE;
   goto cleanup;

nomem:
   (void) fprintf (stderr,
             "lrftrain: insufficient memory available\n");

cleanup:
      /* release the scratch arrays owned by this routine; row tables were
       * calloc'd, so rows that were never allocated are NULL */
   if (R != NULL)
   {
      for (i = 0; i < nodes; i++)
         free(R[i]);
      free(R);
   }
   if (W != NULL)
   {
      for (i = 0; i < num_classes; i++)
         free(W[i]);
      free(W);
   }
   free(Y);
   free(error);
   free(desired);
   return (status);
}
/* -library_code_end */
