#include "inanna/backprop.h"
#include "inanna/patternset.h"
// Reads this trainer's parameters from the given configuration map.
void BackpropTrainer::init (const StringMap& params) {
    Trainer::init (params);
    INITPARAMS(params,
        mEta = params["BackpropTrainer.eta"].toDouble();
        mMomentum = params["BackpropTrainer.momentum"].toDouble();
        mDecay = params["BackpropTrainer.decay"].toDouble();
        mBatchLearning = params["BackpropTrainer.batchLearning"].toInt();
    );
}

// Returns the list of tunable parameters of this trainer, together with
// their descriptions, ranges and default values. The caller takes
// ownership of the returned array.
Array<DynParameter>* BackpropTrainer::parameters () const {
    Array<DynParameter>* result = new Array<DynParameter>;
    result->add (new DoubleParameter ("eta", i18n("Learning rate"), 15, 0.0, 1.0, 0.25));
    result->add (new DoubleParameter ("momentum", i18n("Weight momentum"), 15, 0.0, 1.0, 0.9));
    result->add (new DoubleParameter ("decay", i18n("Weight decay multiplier"), 15, 0.5, 1.0, 1.0));
    result->add (new BoolParameter ("batchLearning", i18n("Update weights in batch")));

    return result;
}

// Prepares a training run: allocates the weight-delta buffer used for
// the momentum term and clears it.
void BackpropTrainer::initTrain (ANNetwork& network) const {
    Trainer::initTrain (network);

    // Count all incoming connections in the network.
    int connections = 0;
    for (int i=0; i<network.size(); i++)
        connections += network[i].incomings();

    // One delta per connection weight, plus one per neuron for the bias.
    mWeightDeltas.make (connections + network.size());
    for (int i=0; i<mWeightDeltas.size(); i++)
        mWeightDeltas[i] = 0.0;
}
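
// mWeightDeltas holds the previous weight changes for the momentum term,
// indexed in the exact order updateWeights() walks the network (from the
// last neuron down to the first): for each neuron, first the delta of
// its bias, then one delta per incoming connection.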

// Runs one training epoch over the whole pattern set and returns the
// mean squared error, averaged over all patterns.
double BackpropTrainer::trainOnce (ANNetwork& network, const PatternSet& set) const {
#ifdef CMP_WARNINGS
#warning "TODO: Batch learning is disabled right now."
#endif

    double sse = 0.0;
    for (int p=0; p<set.patterns; p++) {
        sse += trainPattern (network, set, p);

        // Online learning: apply the deltas of this pattern immediately.
        if (!mBatchLearning)
            updateWeights (network);
    }

    // TODO: proper batch learning would accumulate the deltas of all
    // patterns; as it stands, this applies only the last pattern's deltas.
    if (mBatchLearning)
        updateWeights (network);

    return sse / set.patterns;
}
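
// A minimal usage sketch (hypothetical driver code: "net", "trainSet",
// "params" and the stopping threshold are assumptions, not part of this
// file):
//
//   BackpropTrainer trainer;
//   trainer.init (params);
//   trainer.initTrain (net);
//   for (int epoch=0; epoch<1000; epoch++)
//       if (trainer.trainOnce (net, trainSet) < 0.001)
//           break;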

// Trains the network on a single pattern: clamps the inputs, runs a
// forward pass, backpropagates the error, and returns the squared error
// averaged over the output neurons.
double BackpropTrainer::trainPattern (ANNetwork& network, const PatternSet& set, int p) const {
    // Clamp the input pattern onto the input neurons.
    for (int inp=0; inp<set.inputs; inp++)
        network[inp].setActivation (set.input (p, inp));

    // Forward pass.
    network.update ();

    // Backward pass: fills mError with the per-neuron deltas.
    backpropagate (network, set, p);

    // The output neurons are the last set.outputs neurons of the network.
    double sse = 0.0;
    for (int outp=0; outp<set.outputs; outp++)
        sse += sqr (set.output(p,outp) - network[network.size() - set.outputs + outp].activation());

    return sse / set.outputs;
}

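// Backpropagation deltas (a sketch of the math implemented below; the
// o*(1-o) factor assumes logistic/sigmoid activation functions):
//
//   output neuron j:  delta_j = (t_j - o_j) * o_j * (1 - o_j)
//   hidden neuron j:  delta_j = o_j * (1 - o_j) * sum_k delta_k * w_kj
//
// where t_j is the target value, o_j the activation of neuron j, and
// w_kj the weight of the connection from j to a downstream neuron k.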
void BackpropTrainer::backpropagate (ANNetwork& network, const PatternSet& set, int p) const {
    mError.make (network.size());
    int outLayerBase = network.size() - set.outputs;

    // Walk the neurons backwards so that all downstream deltas are
    // already computed when a neuron is visited.
    for (int j=network.size()-1; j>=0; j--) {
        Neuron* neuron_j = &network[j];
        double delta_j;

        if (j >= outLayerBase) {
            // Output neuron: delta from the difference to the target.
            delta_j = (set.output(p,j-outLayerBase) - neuron_j->activation())
                      * neuron_j->activation() * (1.0 - neuron_j->activation());
        }
        else {
            // Hidden neuron: delta from the weighted downstream deltas.
            double sum_k = 0.0;
            for (int k=0; k<neuron_j->outgoings(); k++)
                sum_k += mError[neuron_j->outgoing(k).target().id()] * neuron_j->outgoing(k).weight();

            delta_j = neuron_j->activation() * (1.0 - neuron_j->activation()) * sum_k;
        }
        mError[j] = delta_j;
    }
}

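// Weight update with momentum (the rule implemented below):
//
//   dw(t)  = eta * delta_j * o_i
//   w_ji  += dw(t) + momentum * dw(t-1)
//
// The bias is updated like a weight whose input is constant 1.
// Note: mDecay is read in init() but not applied anywhere in this method.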
void BackpropTrainer::updateWeights (ANNetwork& network) const {
    int ji = 0;
    for (int j=network.size()-1; j>=0; j--) {
        // Bias update: the momentum term uses the previous delta stored
        // at index ji, which is then overwritten with the new delta.
        double deltaw_ji = mEta * mError[j];
        network[j].setBias (network[j].bias() + deltaw_ji + mMomentum*mWeightDeltas[ji]);
        mWeightDeltas[ji++] = deltaw_ji;

        // Incoming connection weights of neuron j.
        for (int i=network[j].incomings()-1; i>=0; i--, ji++) {
            Connection& conn = network[j].incoming(i);
            deltaw_ji = mEta * mError[j] * conn.source().activation();
            conn.setWeight (conn.weight() + deltaw_ji + mMomentum*mWeightDeltas[ji]);
            mWeightDeltas[ji] = deltaw_ji;
        }
    }
}