00001
00025 #ifndef __TERMINATION_H__
00026 #define __TERMINATION_H__
00027
00033 #include "annetwork.h"
00034
00035
00036 class ANNetwork;
00037 class PatternSet;
00038 class Trainer;
00039
00052 class Terminator : public Object {
00053 public:
00063 Terminator (const PatternSet& validationset, int striplen);
00064
00065
00073 bool validate (const ANNetwork& net, Trainer& trainer, int cyclesTrained);
00074
00084 virtual bool restore (ANNetwork& net) {return false;}
00085
00092 double generalizationLoss (double min=-666, double opt=-666) const;
00093
00095 double validationError () const {return mLastValidError;}
00096
00098 double minimumError () const {return mMinValidError;}
00099
00103 int howManyTrained () const {return mMinCycle;}
00104
00105 protected:
00107 double mMinValidError;
00108
00110 int mMinCycle;
00111
00113 double mLastValidError;
00114
00116 const PatternSet& mValidationSet;
00117
00119 int mStripLength;
00120
00122 virtual bool check (const ANNetwork& net, int cyclesTrained) = 0;
00123
00124 };
00125
00130 Terminator* buildTerminator (const String& modelName,
00131 const PatternSet& validationset,
00132 int validationInterval);
00133
00135
00143 class TerminatorT800 : virtual public Terminator {
00144 public:
00145
00151 TerminatorT800 (const PatternSet& validationset,
00152 int hits, int striplen);
00153
00154 bool check (const ANNetwork& net, int cyclesTrained);
00155
00156 protected:
00158 int mRaises;
00159
00165 const int mMaxRaises;
00166
00167 };
00168
00172 class DummyTerminator : virtual public Terminator {
00173 public:
00174
00175 DummyTerminator (const PatternSet& validationset, int striplen) : Terminator (validationset, striplen) {}
00176 bool check (const ANNetwork& net, int cyclesTrained) {mMinCycle = cyclesTrained; return false;}
00177 };
00178
00182 class SavingTerminator : virtual public Terminator {
00183 public:
00184 SavingTerminator (const PatternSet& validationset, int striplen);
00185 ~SavingTerminator ();
00186
00190 bool restore (ANNetwork& net);
00191
00192 protected:
00193
00196 void save (const ANNetwork& net, int cyclesTrained);
00197
00198 private:
00203 Vector mBestWeights;
00204 };
00205
00206
00207
00208
00209
00215 class GLTerminator : public SavingTerminator {
00216 public:
00217
00230 GLTerminator (const PatternSet& validationset,
00231 double threshold, int striplen);
00232
00233 bool check (const ANNetwork& net, int cyclesTrained);
00234
00235 protected:
00238 double mThreshold;
00239 };
00240
00248 class PRTerminator : public GLTerminator {
00249 public:
00250 PRTerminator (const PatternSet& validationset, double thrshld,
00251 int striplen, const String& pars);
00252
00253 bool check (const ANNetwork& net, int cyclesTrained);
00254
00255 private:
00256 int mStrips;
00257 double mGLperP;
00258 int mRaises, mMaxRaises;
00259 bool mGLFulfilled, mUPFulfilled, mGLperPFulfilled;
00260 double mK;
00261
00262 double progress (const ANNetwork& net) const;
00263 };
00264
00269 class PQTerminator : public SavingTerminator {
00270 double mThreshold;
00271 public:
00272 PQTerminator (const PatternSet& validationset,
00273 double threshold, int striplen);
00274
00275 bool check (const ANNetwork& net, int cyclesTrained);
00276 };
00277
00283 class UPTerminator : public SavingTerminator, public TerminatorT800 {
00284 public:
00285 UPTerminator (const PatternSet& validationset,
00286 int maxraises, int striplen);
00287
00288 bool check (const ANNetwork& net, int cyclesTrained);
00289 };
00290
00291
00292 #endif
00293