00001
00025 #include "inanna/trainer.h"
00026 #include "inanna/termination.h"
00027 #include "inanna/patternset.h"
00028
00030
00031
00032
00033
00034
00036
00037 Trainer::Trainer () {
00038 mTerminatorName = "GL2";
00039 mGeneralizationLoss = 0.0;
00040 mTrained = 0;
00041 mTotalTrained = 0;
00042 pTrainingObserver = NULL;
00043 }
00044
00045 void Trainer::init (const StringMap& params) {
00046 INITPARAMS(params,
00047 );
00048 }
00049
00050 void Trainer::initTrain (ANNetwork& network) const
00051 {
00052
00053 network.init (0.5);
00054 }
00055
00056 double Trainer::train (ANNetwork& network,
00057 const PatternSet& trainset,
00058 int cycles,
00059 const PatternSet* validationSet,
00060 int validationInterval)
00061 {
00062 ASSERT (trainset.patterns>0);
00063 ASSERT (cycles>0);
00064
00065
00066 initTrain (network);
00067
00068
00069 mTrainingProfile.make (cycles);
00070 for (int i=0; i<mTrainingProfile.size(); i++)
00071 mTrainingProfile[i] = 0.0;
00072
00073
00074 Terminator* arnold=NULL;
00075 bool ensureValidGTTrain = true;
00076 if (validationSet) {
00077 mValidationProfile.make (cycles/validationInterval+1);
00078 for (int i=0; i<mValidationProfile.size(); i++)
00079 mValidationProfile[i]=0.0;
00080
00082
00083
00084 String terminator = mTerminatorName;
00085
00086
00087 if (terminator.left(1) == "-") {
00088 ensureValidGTTrain = false;
00089 terminator = terminator.mid (1);
00090 }
00091
00092
00093 if (validationSet->patterns==0)
00094 terminator = "none";
00095
00096
00097 arnold = buildTerminator (terminator, *validationSet, validationInterval);
00098 }
00099
00101
00102
00103 double trainMSE;
00104 double GL = 0;
00105 int validations = 0;
00106 bool terminate = false;
00107 for (mTotalTrained=0; mTotalTrained<cycles;) {
00108
00109 trainMSE = mTrainingProfile[mTotalTrained] = trainOnce (network, trainset);
00110 mTotalTrained++;
00111
00112
00113 if (false)
00114 printf ("Cycle %d MSE=%f\n", mTotalTrained, trainMSE);
00115
00116
00117 if (arnold && mTotalTrained>0 && !(mTotalTrained%validationInterval)) {
00118
00119
00120
00121 terminate = arnold->validate (network, *this, mTotalTrained);
00122 mValidationProfile[validations++] = arnold->validationError ();
00123
00124
00125
00126 GL = arnold->generalizationLoss ();
00127
00128
00129
00130 if (ensureValidGTTrain && arnold->validationError() < trainMSE)
00131 break;
00132 }
00133
00134
00135 if (pTrainingObserver) {
00136 pTrainingObserver->cycleTrained (*this, mTotalTrained);
00137
00138
00139
00140
00141 if (pTrainingObserver->wantsToStop())
00142 break;
00143 }
00144
00145 }
00146
00147
00148
00149 if (arnold && (!ensureValidGTTrain || arnold->minimumError() > trainMSE)) {
00150 arnold->restore (network);
00151 mTrained = arnold->howManyTrained ();
00152 } else
00153
00154 mTrained = mTotalTrained;
00155
00157
00158
00159 mGeneralizationLoss = GL;
00160 if (mTrained>0)
00161 mTrainingProfile.resize (mTrained);
00162 else
00163 mTrainingProfile.make (0);
00164 if (arnold)
00165 mValidationProfile.resize (validations);
00166 else
00167 mValidationProfile.make (0);
00168
00169
00170
00171 delete arnold;
00172
00173 return trainMSE;
00174 }
00175
00176
00177