Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

backprop.cc

Go to the documentation of this file.
00001 
00025 #include "inanna/backprop.h"
00026 #include "inanna/patternset.h"
00027 
00028 
00030 // ----              |                      -----           o                 //
00031 // |   )  ___   ___  |    --            --    |        ___      _    ___      //
00032 // |---   ___| |   \ | / |  ) |/\  __  |  )   |   |/\  ___| | |/ \  /   ) |/\ //
00033 // |   ) (   | |     |/  |--  |   /  \ |--    |   |   (   | | |   | |---  |   //
00034 // |___   \__|  \__/ | \ |    |   \__/ |      |   |    \__| | |   |  \__  |   //
00036 
/// Reads this trainer's hyperparameters from the given parameter map and
/// forwards the map to the base class first.
///
/// Keys read (all under the "BackpropTrainer." prefix):
///   eta           - learning rate            (double)
///   momentum      - weight momentum term     (double)
///   decay         - weight decay multiplier  (double)
///   batchLearning - batch-update flag        (int, used as bool)
///
/// NOTE(review): INITPARAMS is a project macro wrapping the assignments —
/// presumably it guards against missing keys; verify its expansion before
/// relying on behavior for absent parameters.
/*virtual*/ void BackpropTrainer::init (const StringMap& params) {
    Trainer::init (params);
    INITPARAMS(params, 
               mEta             = params["BackpropTrainer.eta"].toDouble();
               mMomentum        = params["BackpropTrainer.momentum"].toDouble();
               mDecay           = params["BackpropTrainer.decay"].toDouble();
               mBatchLearning   = params["BackpropTrainer.batchLearning"].toInt();
        );
}
00046 
00047 /*virtual*/ Array<DynParameter>* BackpropTrainer::parameters () const {
00048     Array<DynParameter>* result = new Array<DynParameter>;
00049     result->add (new DoubleParameter    ("eta", i18n("Learning rate"), 15, 0.0, 1.0, 0.25));
00050     result->add (new DoubleParameter    ("momentum", i18n("Weight momentum"), 15, 0.0, 1.0, 0.9));
00051     result->add (new DoubleParameter    ("decay", i18n("Weight decay multiplier"), 15, 0.5, 1.0, 1.0));
00052     result->add (new BoolParameter      ("batchLearning", i18n("Update weights in batch")));
00053 
00054     return result;
00055 }
00056 
00057 /*virtual*/ void BackpropTrainer::initTrain (ANNetwork& network) const {
00058     Trainer::initTrain (network);
00059     
00060     // Count the total number of connections in the entire network
00061     int connections=0;
00062     for (int i=0; i<network.size(); i++)
00063         connections += network[i].incomings();
00064 
00065     // Create weight delta data for connections and biases
00066     mWeightDeltas.make (connections + network.size());
00067     for (int i=0; i<mWeightDeltas.size(); i++)
00068         mWeightDeltas[i] = 0.0;
00069 }
00070 
00071 /*virtual*/ double BackpropTrainer::trainOnce (ANNetwork& network, const PatternSet& set) const {
00072 #ifdef CMP_WARNINGS
00073 #warning "TODO: Batch learning is disabled right now."
00074 #endif
00075     /*    
00076     if (mBatchLearning)
00077         for (int i=0; i<mWeightDeltas.size(); i++)
00078             mWeightDeltas[i] = 0.0;
00079     */
00080     
00081     // Train each pattern once
00082     double sse=0.0;
00083     for (int p=0; p<set.patterns; p++)
00084         sse += trainPattern (network, set, p);
00085 
00086     if (true || mBatchLearning)
00087         updateWeights (network);
00088 
00089     // Actualize weight adjustments
00090     /*
00091     if (mBatchLearning)
00092         for (int n=0, oc=0; n<network.size(); n++)
00093             for (int c=0; c<network[n].incomings(); c++, oc++)
00094                 network[n].incoming(c).setWeight (mWeightDeltas[oc]);
00095     */
00096     return sse/set.patterns; // Return MSE
00097 }
00098 
00099 /*virtual*/ double BackpropTrainer::trainPattern (ANNetwork& network, const PatternSet& set, int p) const {
00100     // Feed the pattern to the network
00101     for (int inp=0; inp<set.inputs; inp++)
00102         network[inp].setActivation (set.input (p, inp));
00103 
00104     // Forward pass
00105     network.update ();
00106 
00107     // Backward pass
00108     backpropagate (network, set, p);
00109 
00110     // Calculate error
00111     double sse=0.0;
00112     for (int outp=0; outp<set.outputs; outp++)
00113         sse += sqr(set.output(p,outp) - network[network.size() - set.outputs + outp].activation());
00114 
00115     return sse / set.outputs; // Return MSE
00116 }
00117 
00118 void BackpropTrainer::backpropagate (register ANNetwork& network, register const PatternSet& set, int p) const {
00119     mError.make (network.size());
00120     int outLayerBase = network.size() - set.outputs;
00121     //register Connection* conn;
00122     register double sum_k;
00123     register Neuron* neuron_j;
00124     register double delta_j;
00125     register int j;
00126     //register int k;
00127     
00128     // Iterate backwards
00129     for (j=network.size()-1; j>=0; j--) {
00130         neuron_j = &network[j];
00131         // Calculate error at a neuron
00132         if (j >= outLayerBase) { // Output neuron
00133             delta_j = (set.output(p,j-outLayerBase) - neuron_j->activation())
00134                 * neuron_j->activation() * (1.0 - neuron_j->activation());
00135         }
00136         else { // A hidden or input neuron
00137             sum_k=0.0;
00138             for (int k=0; k<neuron_j->outgoings(); k++)
00139                 sum_k += mError [neuron_j->outgoing(k).target().id()] * neuron_j->outgoing(k).weight();
00140             
00141             delta_j = neuron_j->activation()*(1.0-neuron_j->activation()) * sum_k;
00142         }
00143         mError[j] = delta_j;
00144     }
00145 }
00146 
00147 void BackpropTrainer::updateWeights (register ANNetwork& network) const {
00148     register int j, i, ji;
00149     register double deltaw_ji;
00150     
00151     for (j=network.size()-1, ji=0; j>=0; j--) {
00152 
00153         // Update bias
00154         deltaw_ji = mEta * mError[j];
00155         network[j].setBias (network[j].bias() + deltaw_ji + mMomentum*mWeightDeltas[ji]);
00156         mWeightDeltas[ji++] = deltaw_ji;
00157 
00158         // Update weights
00159         for (i=network[j].incomings()-1; i>=0; i--, ji++) {
00160             register Connection& conn = network[j].incoming(i);
00161             deltaw_ji = mEta * mError[j] * conn.source().activation();
00162             conn.setWeight (conn.weight() + deltaw_ji + mMomentum*mWeightDeltas[ji]);
00163             mWeightDeltas[ji] = deltaw_ji;
00164         }
00165     }
00166 }
00167 

Generated on Thu Feb 10 20:06:44 2005 for Inanna by doxygen 1.2.18