Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

trainer.cc

Go to the documentation of this file.
00001 
00025 #include "inanna/trainer.h"
00026 #include "inanna/termination.h"
00027 #include "inanna/patternset.h"
00028 
00030 //                     -----           o                                     //
00031 //                       |        ___      _    ___                          //
00032 //                       |   |/\  ___| | |/ \  /   ) |/\                     //
00033 //                       |   |   (   | | |   | |---  |                       //
00034 //                       |   |    \__| | |   |  \__  |                       //
00036 
00037 Trainer::Trainer () {
00038     mTerminatorName     = "GL2";
00039     mGeneralizationLoss = 0.0;
00040     mTrained            = 0;
00041     mTotalTrained       = 0;
00042     pTrainingObserver   = NULL;
00043 }
00044 
/** Initializes the trainer from a parameter map.
 *
 *  The INITPARAMS parameter list below is empty, so the base Trainer
 *  declares no parameters of its own. NOTE(review): presumably INITPARAMS
 *  reads/validates the listed parameters from @p params; confirm against
 *  the macro definition. Being virtual, subclasses may override this to
 *  accept their own parameters.
 *
 *  @param params Name/value parameter map.
 */
/*virtual*/ void Trainer::init (const StringMap& params) {
    INITPARAMS(params, 
               );
}
00049 
00050 /*virtual*/ void Trainer::initTrain (ANNetwork& network) const
00051 {
00052     // Initialize weights
00053     network.init (0.5);
00054 }
00055 
00056 double Trainer::train (ANNetwork&        network,
00057                        const PatternSet& trainset,
00058                        int               cycles,
00059                        const PatternSet* validationSet,
00060                        int               validationInterval)
00061 {
00062     ASSERT (trainset.patterns>0);
00063     ASSERT (cycles>0);
00064 
00065     // Initialize training method
00066     initTrain (network);
00067 
00068     // Initialize recording
00069     mTrainingProfile.make (cycles);
00070     for (int i=0; i<mTrainingProfile.size(); i++)
00071         mTrainingProfile[i] = 0.0;
00072 
00073     // If validation set is present, build a terminator that monitors it.
00074     Terminator* arnold=NULL;
00075     bool ensureValidGTTrain = true;
00076     if (validationSet) {
00077         mValidationProfile.make (cycles/validationInterval+1);
00078         for (int i=0; i<mValidationProfile.size(); i++)
00079             mValidationProfile[i]=0.0;
00080 
00082         // Determine termination method
00083         
00084         String terminator = mTerminatorName;
00085         
00086         // Check if we should _not_ check for that validError>trainError
00087         if (terminator.left(1) == "-") {
00088             ensureValidGTTrain = false;
00089             terminator = terminator.mid (1);
00090         }
00091         
00092         // Can't use termination if there are no validation patterns...
00093         if (validationSet->patterns==0)
00094             terminator = "none";
00095         
00096         // Order a terminator from the factory
00097         arnold = buildTerminator (terminator, *validationSet, validationInterval);
00098     }
00099 
00101     // Train and validate
00102     
00103     double  trainMSE;
00104     double  GL          = 0;
00105     int     validations = 0;
00106     bool    terminate   = false;
00107     for (mTotalTrained=0; mTotalTrained<cycles;) {
00108         // Train all patterns once
00109         trainMSE = mTrainingProfile[mTotalTrained] = trainOnce (network, trainset);
00110         mTotalTrained++;
00111 
00112         // Streamed output
00113         if (false)
00114             printf ("Cycle %d MSE=%f\n", mTotalTrained, trainMSE);
00115         
00116         // Validate for early stopping
00117         if (arnold && mTotalTrained>0 && !(mTotalTrained%validationInterval)) {
00118 
00119             // Calculate the validation error for the current network
00120             // state
00121             terminate = arnold->validate (network, *this, mTotalTrained);
00122             mValidationProfile[validations++] = arnold->validationError ();
00123 
00124             // Let the terminator calculate the GL value. Some
00125             // terminators use this value to determine termination.
00126             GL = arnold->generalizationLoss ();
00127 
00128             // Do not terminate if the validation error is lower than
00129             // the training error
00130             if (ensureValidGTTrain && arnold->validationError() < trainMSE)
00131                 break;
00132         }
00133 
00134         // Report the cycle to the training observer, if present
00135         if (pTrainingObserver) {
00136             pTrainingObserver->cycleTrained (*this, mTotalTrained);
00137 
00138             // The observer has the power to stop training. This is
00139             // typically a cancel command given interactively by a
00140             // user.
00141             if (pTrainingObserver->wantsToStop())
00142                 break;
00143         }
00144 
00145     }
00146 
00147     // Restore the state with the lowest error on validation set. Do
00148     // not restore if the validation error is smaller than the training error
00149     if (arnold && (!ensureValidGTTrain || arnold->minimumError() > trainMSE)) {
00150         arnold->restore (network);
00151         mTrained = arnold->howManyTrained ();
00152     } else
00153         // No validation, keep the latest state
00154         mTrained = mTotalTrained;
00155 
00157     // Record and clean up some statistics (truncate vectors)
00158     
00159     mGeneralizationLoss = GL;
00160     if (mTrained>0)
00161         mTrainingProfile.resize (mTrained);
00162     else
00163         mTrainingProfile.make (0);
00164     if (arnold)
00165         mValidationProfile.resize (validations);
00166     else
00167         mValidationProfile.make (0);
00168 
00169     
00170     // You will now be terminated
00171     delete arnold;
00172     
00173     return trainMSE; // Return final training MSE
00174 }
00175 
00176 
00177 

Generated on Thu Feb 10 20:06:45 2005 for Inanna by doxygen1.2.18