00001 00025 #ifndef __INANNA_TRAINER_H__ 00026 #define __INANNA_TRAINER_H__ 00027 00028 #include <magic/mparameter.h> 00029 #include "inanna/annetwork.h" 00030 00031 // Local predeclarations 00032 class Trainer; 00033 class TrainingObserver; 00034 00035 00037 // ----- o // 00038 // | ___ _ ___ // 00039 // | |/\ ___| | |/ \ / ) |/\ // 00040 // | | ( | | | | |--- | // 00041 // | | \__| | | | \__ | // 00043 00051 class Trainer : public Object, public IParameterized { 00052 public: 00053 Trainer (); 00054 00057 virtual void init (const StringMap& params); 00058 00067 virtual double train (ANNetwork& network, 00068 const PatternSet& trainset, 00069 int cycles, 00070 const PatternSet* pValidationSet=NULL, 00071 int validationInterval=0); 00072 00078 void setTerminator (const String& name) {mTerminatorName=name;} 00079 00080 // Informative methods 00081 00089 int cyclesTrained () const {return mTrained;} 00090 00096 int totalCycles () const {return mTotalTrained;} 00097 00101 double generalizLoss () const {return mGeneralizationLoss;} 00102 00106 void setGeneralizLoss(double gl) {mGeneralizationLoss = gl;} 00107 00112 const Vector& trainingRecord () const {return mTrainingProfile;} 00113 00118 const Vector& validationRecord() const {return mValidationProfile;} 00119 00122 void setObserver (TrainingObserver* observer) {pTrainingObserver=observer;} 00123 00124 protected: 00125 00127 virtual void initTrain (ANNetwork& network) const; 00128 00130 virtual double trainOnce (ANNetwork& network, const PatternSet& set) const {MUST_OVERLOAD; return 0.0;} 00131 00132 protected: 00133 00135 String mTerminatorName; 00136 00142 int mTrained; 00143 00149 int mTotalTrained; 00150 00160 double mGeneralizationLoss; 00161 00166 Vector mTrainingProfile; 00167 00172 Vector mValidationProfile; 00173 00176 TrainingObserver* pTrainingObserver; 00177 00178 friend class Terminator; 00179 }; 00180 00181 00182 00184 // ----- o o ___ // 00185 // | ___ _ _ | | | ____ ___ ___ // 00186 // | |/\ ___| | |/ \ | |/ \ ___ | | |--- ( / ) |/\ | | / ) |/\ // 00187 // | | ( | | | | | | | ( \ | | | ) \__ |--- | \ / |--- | // 00188 // | | \__| | | | | | | ---/ `___´ |__/ ____) \__ | V \__ | // 00189 // __/ // 00191 00196 class TrainingObserver : public Object { 00197 public: 00198 TrainingObserver () {mStop = false;} 00199 00201 void initTraining () {mStop = false;} 00202 00209 bool wantsToStop () const {return mStop;} 00210 00220 virtual void cycleTrained (const Trainer& trainer, int totalCycles)=0; 00221 00222 protected: 00229 void stopTraining () {mStop=true;} 00230 00231 private: 00232 bool mStop; 00233 }; 00234 00235 #endif 00236