00001 00025 #ifndef __BACKPROPTRAINER_H__ 00026 #define __BACKPROPTRAINER_H__ 00027 00028 #include "trainer.h" 00029 00030 00032 // ---- | ----- o // 00033 // | ) ___ ___ | -- -- | ___ _ ___ // 00034 // |--- ___| | \ | / | ) |/\ __ | ) | |/\ ___| | |/ \ / ) |/\ // 00035 // | ) ( | | |/ |-- | / \ |-- | | ( | | | | |--- | // 00036 // |___ \__| \__/ | \ | | \__/ | | | \__| | | | \__ | // 00038 00044 class BackpropTrainer : public Trainer { 00045 public: 00046 virtual Array<DynParameter>* parameters () const; 00047 virtual void init (const StringMap& params); 00048 00049 protected: 00051 virtual void initTrain (ANNetwork& network) const; 00052 00054 virtual double trainOnce (ANNetwork& network, const PatternSet& set) const; 00055 00057 virtual double trainPattern (ANNetwork& network, const PatternSet& set, int p) const; 00058 00063 virtual void backpropagate (ANNetwork& network, const PatternSet& set, int p) const; 00064 00066 virtual void updateWeights (ANNetwork& network) const; 00067 00068 protected: 00070 double mEta; 00071 00073 double mMomentum; 00074 00076 double mDecay; 00077 00079 bool mBatchLearning; 00080 00086 mutable Vector mWeightDeltas; 00087 00093 mutable Vector mError; 00094 }; 00095 00096 #endif 00097