00001
00025 #include "inanna/termination.h"
00026 #include "inanna/annetwork.h"
00027 #include "inanna/patternset.h"
00028 #include "inanna/trainer.h"
00029
00030 Terminator* buildTerminator (const String& modelName, const PatternSet& validationset, int interval) {
00031 if (modelName == "none")
00032 return NULL;
00033 if (modelName == "dummy")
00034 return new DummyTerminator (validationset, interval);
00035
00036 ASSERT (modelName.length()>=3);
00037 String prefix = modelName.left(2);
00038 int intParam = modelName.mid (2).toInt();
00039 double doubleParam = modelName.mid (2).toDouble();
00040 ASSERT (intParam>0);
00041 ASSERT (doubleParam>0.0);
00042
00043 if (prefix=="FT")
00044 return new TerminatorT800 (validationset, intParam, interval);
00045 if (prefix=="GL")
00046 return new GLTerminator (validationset, doubleParam, interval);
00047 if (prefix=="PQ")
00048 return new PQTerminator (validationset, doubleParam, interval);
00049 if (prefix=="UP")
00050 return new UPTerminator (validationset, intParam, interval);
00051 if (prefix=="PR")
00052 return new PRTerminator (validationset, doubleParam, interval, modelName.mid(2));
00053
00054 throw generic_exception (format ("Invalid terminator model name '%s'",
00055 (CONSTR) modelName));
00056 }
00057
00058
00060
00061
00062
00063
00064
00065
00066
00068
00069 Terminator::Terminator (const PatternSet& vset, int striplen) : mValidationSet (vset), mStripLength (striplen) {
00070 mMinValidError = 666;
00071 mLastValidError = 666;
00072 mMinCycle = -1;
00073 }
00074
00075 double Terminator::generalizationLoss (double last, double opt) const {
00076 if (last==-666 || opt==-666)
00077 return 100*((mLastValidError/mMinValidError) - 1);
00078 else
00079 return 100*((last/opt)-1);
00080 }
00081
00082 bool Terminator::validate (const ANNetwork& net, Trainer& trainer, int cyclesTrained) {
00083 mLastValidError = net.test (mValidationSet);
00084 trainer.setGeneralizLoss (generalizationLoss());
00085 return check (net, cyclesTrained);
00086 }
00087
00088
00089
00091
00092
00093
00094
00095
00096
00097
00099
00100 TerminatorT800::TerminatorT800 (const PatternSet& vset, int hits, int striplen) : Terminator (vset, striplen), mMaxRaises (hits) {
00101 }
00102
00103 bool TerminatorT800::check (const ANNetwork& net, int cyclesTrained) {
00104 if (mLastValidError<=mMinValidError) {
00105 mMinValidError = mLastValidError;
00106 mRaises = 0;
00107 } else
00108 mRaises++;
00109
00110 return (mRaises>=mMaxRaises);
00111 }
00112
00113
00114
00116
00117
00118
00119
00120
00121
00122
00124
00125 SavingTerminator::SavingTerminator (const PatternSet& validationset, int striplen) : Terminator (validationset, striplen) {
00126 }
00127
00128 SavingTerminator::~SavingTerminator () {
00129 }
00130
00131 void SavingTerminator::save (const ANNetwork& network, int cyclesTrained) {
00132
00133 if (mBestWeights.size()==0) {
00134
00135 int connections=0;
00136 for (int i=0; i<network.size(); i++)
00137 connections += network[i].incomings();
00138
00139 mBestWeights.make (connections + network.size());
00140 }
00141
00142
00143 for (register int j=network.size()-1, ji=0; j>=0; j--)
00144 for (register int i=-1; i<network[j].incomings(); i++, ji++)
00145 if (i==-1)
00146 mBestWeights[ji] = network[j].bias ();
00147 else
00148 mBestWeights[ji] = network[j].incoming(i).weight();
00149
00150 mMinCycle = cyclesTrained;
00151 }
00152
00153 bool SavingTerminator::restore (ANNetwork& network) {
00154
00155 if (mBestWeights.size()>0)
00156 for (register int j=network.size()-1, ji=0; j>=0; j--)
00157 for (register int i=-1; i<network[j].incomings(); i++, ji++)
00158 if (i==-1)
00159 network[j].setBias (mBestWeights[ji]);
00160 else
00161 network[j].incoming(i).setWeight (mBestWeights[ji]);
00162
00163 return true;
00164 }
00165
00166
00167
00169
00170
00171
00172
00173
00174
00175
00177
00178
00179
00180
00181 GLTerminator::GLTerminator (
00182 const PatternSet& validationset,
00183 double threshold,
00184 int striplen)
00185 : Terminator (validationset, striplen),
00186 SavingTerminator (validationset, striplen),
00187 mThreshold (threshold) {
00188 }
00189
00190 bool GLTerminator::check (
00191 const ANNetwork& net,
00192 int cyclesTrained)
00193 {
00194 if (mLastValidError<=mMinValidError) {
00195 mMinValidError = mLastValidError;
00196 save (net, cyclesTrained);
00197 }
00198
00199
00200
00201 return generalizationLoss() >= mThreshold;
00202 }
00203
00204
00205
00206
00207 PRTerminator::PRTerminator (
00208 const PatternSet& validationset,
00209 double threshold,
00210 int striplen,
00211 const String& desc)
00212 : Terminator (validationset, striplen),
00213 GLTerminator (validationset, threshold, striplen)
00214 {
00215 mThreshold = threshold;
00216 mGLperP = 3.0;
00217 mK = 5;
00218 mStripLength = 5;
00219 mMaxRaises = 8;
00220
00221
00222 mGLFulfilled = false;
00223 mUPFulfilled = false;
00224 mGLperPFulfilled = false;
00225 mRaises = 0;
00226 }
00227
00228 double PRTerminator::progress (const ANNetwork& net) const
00229 {
00230 #ifdef CMP_WARNINGS
00231 #warning "TODO: Convert to Trainer"
00232 #endif
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244 return 0.0;
00245 }
00246
00247 bool PRTerminator::check (const ANNetwork& net, int cyclesTrained)
00248 {
00249 bool terminate = false;
00250
00251 if (mLastValidError<=mMinValidError) {
00252 mMinValidError = mLastValidError;
00253 mRaises = 0;
00254 save (net, cyclesTrained);
00255 } else
00256 mRaises++;
00257
00258 if (progress (net) < 0.1)
00259 terminate = true;
00260
00261
00262 if (generalizationLoss() >= mThreshold)
00263 mGLFulfilled = true;
00264
00265
00266 if (mRaises>=mMaxRaises)
00267 mUPFulfilled = true;
00268
00269
00270 if (generalizationLoss()/progress(net) > mGLperP)
00271 mGLperPFulfilled = true;
00272
00273 if (mGLFulfilled && mUPFulfilled && mGLperPFulfilled)
00274 terminate = true;
00275
00276 return terminate;
00277 }
00278
00279
00280
00282
00283
00284
00285
00286
00287
00288
00290
00291 PQTerminator::PQTerminator (
00292 const PatternSet& validationset,
00293 double threshold,
00294 int striplen)
00295 : Terminator (validationset, striplen),
00296 SavingTerminator (validationset, striplen),
00297 mThreshold (threshold)
00298 {
00299 }
00300
00301 bool PQTerminator::check (const ANNetwork& net, int cyclesTrained)
00302 {
00303 double GL = generalizationLoss ();
00304 return GL>=mThreshold;
00305 }
00306
00307
00308
00310
00311
00312
00313
00314
00315
00316
00318
00319 UPTerminator::UPTerminator (
00320 const PatternSet& validationset,
00321 int maxraises,
00322 int striplen)
00323 : Terminator (validationset, striplen),
00324 SavingTerminator (validationset, striplen),
00325 TerminatorT800 (validationset, maxraises, striplen)
00326 {
00327 }
00328
00329 bool UPTerminator::check (const ANNetwork& net, int cyclesTrained)
00330 {
00331 if (mLastValidError<=mMinValidError) {
00332 mMinValidError = mLastValidError;
00333 mRaises = 0;
00334 save (net, cyclesTrained);
00335 } else
00336 mRaises++;
00337
00338 return (mRaises>=mMaxRaises);
00339 }
00340