Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

termination.cc

Go to the documentation of this file.
00001 
00025 #include "inanna/termination.h"
00026 #include "inanna/annetwork.h"
00027 #include "inanna/patternset.h"
00028 #include "inanna/trainer.h"
00029 
00030 Terminator* buildTerminator (const String& modelName, const PatternSet& validationset, int interval) {
00031     if (modelName == "none")
00032         return NULL;
00033     if (modelName == "dummy")
00034         return new DummyTerminator (validationset, interval);
00035         
00036     ASSERT (modelName.length()>=3);
00037     String  prefix = modelName.left(2);
00038     int     intParam = modelName.mid (2).toInt();
00039     double  doubleParam = modelName.mid (2).toDouble();
00040     ASSERT (intParam>0);
00041     ASSERT (doubleParam>0.0);
00042     
00043     if (prefix=="FT")
00044         return new TerminatorT800 (validationset, intParam, interval);
00045     if (prefix=="GL")
00046         return new GLTerminator (validationset, doubleParam, interval);
00047     if (prefix=="PQ")
00048         return new PQTerminator (validationset, doubleParam, interval);
00049     if (prefix=="UP")
00050         return new UPTerminator (validationset, intParam, interval);
00051     if (prefix=="PR")
00052         return new PRTerminator (validationset, doubleParam, interval, modelName.mid(2));
00053     
00054     throw generic_exception (format ("Invalid terminator model name '%s'",
00055                                      (CONSTR) modelName));
00056 }
00057 
00058 
00060 //                                                                           //
00061 //      -----                 o         |   |           |             |      //
00062 //        |    ___                _     |\ /|  ___   |  | _           |      //
00063 //        |   /   ) |/\ |/|/| | |/ \    | V | /   ) -+- |/ |  __   ---|      //
00064 //        |   |---  |   | | | | |   |   | | | |---   |  |  | /  \ (   |      //
00065 //        |    \__  |   | | | | |   | O |   |  \__    \ |  | \__/  ---|      //
00066 //                                                                           //
00068 
00069 Terminator::Terminator (const PatternSet& vset, int striplen) : mValidationSet (vset), mStripLength (striplen) {
00070     mMinValidError = 666;
00071     mLastValidError = 666;
00072     mMinCycle = -1;
00073 }
00074 
00075 double Terminator::generalizationLoss (double last, double opt) const {
00076     if (last==-666 || opt==-666)
00077         return 100*((mLastValidError/mMinValidError) - 1);
00078     else
00079         return 100*((last/opt)-1);
00080 }
00081 
00082 bool Terminator::validate (const ANNetwork& net, Trainer& trainer, int cyclesTrained) {
00083     mLastValidError = net.test (mValidationSet);
00084     trainer.setGeneralizLoss (generalizationLoss());
00085     return check (net, cyclesTrained);
00086 }
00087 
00088 
00089 
00091 //                                                                           //
00092 //           -----                 -----                 o                   //
00093 //           |      ___   ____  |    |    ___                _               //
00094 //           |---   ___| (     -+-   |   /   ) |/\ |/|/| | |/ \              //
00095 //           |     (   |  \__   |    |   |---  |   | | | | |   |             //
00096 //           |      \__| ____)   \   |    \__  |   | | | | |   | O           //
00097 //                                                                           //
00099 
00100 TerminatorT800::TerminatorT800 (const PatternSet& vset, int hits, int striplen) : Terminator (vset, striplen), mMaxRaises (hits) {
00101 }
00102     
00103 bool TerminatorT800::check (const ANNetwork& net, int cyclesTrained) {
00104     if (mLastValidError<=mMinValidError) {
00105         mMinValidError = mLastValidError;
00106         mRaises = 0;
00107     } else
00108         mRaises++;
00109     
00110     return (mRaises>=mMaxRaises);
00111 }
00112 
00113 
00114 
00116 //                                                                           //
00117 //       ----             o             -----                 o              //
00118 //      (      ___            _           |    ___                _          //
00119 //       ---   ___| |   | | |/ \   ___    |   /   ) |/\ |/|/| | |/ \         //
00120 //          ) (   |  \ /  | |   | (   \   |   |---  |   | | | | |   |        //
00121 //      ___/   \__|   V   | |   |  ---/   |    \__  |   | | | | |   | O      //
00122 //                                 __/                                       //
00124 
00125 SavingTerminator::SavingTerminator (const PatternSet& validationset, int striplen) : Terminator (validationset, striplen) {
00126 }
00127 
00128 SavingTerminator::~SavingTerminator () {
00129 }
00130 
00131 void SavingTerminator::save (const ANNetwork& network, int cyclesTrained) {
00132     // On the first call, create storage for weights and biases
00133     if (mBestWeights.size()==0) {
00134         // Count the total number of connections in the entire network
00135         int connections=0;
00136         for (int i=0; i<network.size(); i++)
00137             connections += network[i].incomings();
00138         
00139         mBestWeights.make (connections + network.size());
00140     }
00141 
00142     // Copy weights and biases to the storate
00143     for (register int j=network.size()-1, ji=0; j>=0; j--)
00144         for (register int i=-1; i<network[j].incomings(); i++, ji++)
00145             if (i==-1) // Bias
00146                 mBestWeights[ji] = network[j].bias ();
00147             else // Weight
00148                 mBestWeights[ji] = network[j].incoming(i).weight();
00149     
00150     mMinCycle = cyclesTrained;
00151 }
00152 
00153 bool SavingTerminator::restore (ANNetwork& network) {
00154     // Restore weights and biases from the storage
00155     if (mBestWeights.size()>0)
00156         for (register int j=network.size()-1, ji=0; j>=0; j--)
00157             for (register int i=-1; i<network[j].incomings(); i++, ji++)
00158                 if (i==-1) // Bias
00159                     network[j].setBias (mBestWeights[ji]);
00160                 else // Weight
00161                     network[j].incoming(i).setWeight (mBestWeights[ji]);
00162 
00163     return true;
00164 }
00165 
00166 
00167 
00169 //                                                                          //
00170 //      ---- |     -----                 o                 o                //
00171 //     |     |       |    ___                _    ___   |           _       //
00172 //     | --- |       |   /   ) |/\ |/|/| | |/ \   ___| -+- |  __  |/ \      //
00173 //     |   \ |       |   |---  |   | | | | |   | (   |  |  | /  \ |   |     //
00174 //     |___/ |____   |    \__  |   | | | | |   |  \__|   \ | \__/ |   |     //
00175 //                                                                          //
00177 
00178 /*******************************************************************************
00179 *
00180 **/
00181 GLTerminator::GLTerminator (
00182     const PatternSet& validationset,
00183     double            threshold,
00184     int               striplen)
00185         : Terminator (validationset, striplen),
00186           SavingTerminator (validationset, striplen),
00187           mThreshold (threshold) {
00188 }
00189 
00190 bool GLTerminator::check (
00191     const ANNetwork& net,
00192     int              cyclesTrained)
00193 {
00194     if (mLastValidError<=mMinValidError) {
00195         mMinValidError = mLastValidError;
00196         save (net, cyclesTrained);
00197     }
00198     
00199     //TRACE3 ("Valid-MSE=%f, GL=%f, threshold=%f",
00200     //      mLastValidError, generalizationLoss(), mThreshold);
00201     return generalizationLoss() >= mThreshold;
00202 }
00203 
00204 /*******************************************************************************
00205 *
00206 **/
00207 PRTerminator::PRTerminator (
00208     const PatternSet& validationset,
00209     double            threshold,
00210     int               striplen,
00211     const String&     desc)
00212         : Terminator (validationset, striplen),
00213           GLTerminator (validationset, threshold, striplen)
00214 {
00215     mThreshold       = threshold;
00216     mGLperP          = 3.0;
00217     mK               = 5;
00218     mStripLength     = 5;
00219     mMaxRaises       = 8;
00220 
00221     // States
00222     mGLFulfilled     = false;
00223     mUPFulfilled     = false;
00224     mGLperPFulfilled = false;
00225     mRaises          = 0;
00226 }
00227 
00228 double PRTerminator::progress (const ANNetwork& net) const
00229 {
00230 #ifdef CMP_WARNINGS
00231 #warning "TODO: Convert to Trainer"
00232 #endif
00233     /*
00234     int endpos=net.cyclesTrained;
00235     double sum=0, min=1E30;
00236     const Vector& trainMSEs=net.trainingRecord();
00237     for (int i=endpos-mStripLength; i<endpos; i++) {
00238         sum += trainMSEs[i];
00239         if (trainMSEs[i]<min)
00240             min = trainMSEs[i];
00241     }
00242     return 1000*(sum/(mK*min)-1);
00243     */
00244     return 0.0;
00245 }
00246 
00247 bool PRTerminator::check (const ANNetwork& net, int cyclesTrained)
00248 {
00249     bool terminate = false;
00250     
00251     if (mLastValidError<=mMinValidError) {
00252         mMinValidError = mLastValidError;
00253         mRaises = 0;
00254         save (net, cyclesTrained);
00255     } else
00256         mRaises++;
00257 
00258     if (progress (net) < 0.1)
00259         terminate = true;
00260 
00261     // GL-criterion
00262     if (generalizationLoss() >= mThreshold)
00263         mGLFulfilled = true;
00264 
00265     // UP-criterion
00266     if (mRaises>=mMaxRaises)
00267         mUPFulfilled = true;
00268 
00269     // Progress ratio criterion
00270     if (generalizationLoss()/progress(net) > mGLperP)
00271         mGLperPFulfilled = true;
00272 
00273     if (mGLFulfilled && mUPFulfilled && mGLperPFulfilled)
00274         terminate = true;
00275 
00276     return terminate;
00277 }
00278 
00279 
00280 
00282 //                                                                          //
00283 //     ----   ___  -----                 o                 o                //
00284 //     |   ) |   |   |    ___                _    ___   |           _       //
00285 //     |---  |   |   |   /   ) |/\ |/|/| | |/ \   ___| -+- |  __  |/ \      //
00286 //     |     | \ |   |   |---  |   | | | | |   | (   |  |  | /  \ |   |     //
00287 //     |     `__X´   |    \__  |   | | | | |   |  \__|   \ | \__/ |   |     //
00288 //               \                                                          //
00290 
00291 PQTerminator::PQTerminator (
00292     const PatternSet& validationset,
00293     double            threshold,
00294     int               striplen)
00295         : Terminator (validationset, striplen),
00296           SavingTerminator (validationset, striplen),
00297           mThreshold (threshold)
00298 {
00299 }
00300 
00301 bool PQTerminator::check (const ANNetwork& net, int cyclesTrained)
00302 {
00303     double GL = generalizationLoss ();
00304     return GL>=mThreshold;
00305 }
00306 
00307 
00308 
00310 //                                                                          //
00311 //     |   | ----  -----                 o                 o                //
00312 //     |   | |   )   |    ___                _    ___   |           _       //
00313 //     |   | |---    |   /   ) |/\ |/|/| | |/ \   ___| -+- |  __  |/ \      //
00314 //     |   | |       |   |---  |   | | | | |   | (   |  |  | /  \ |   |     //
00315 //     `___´ |       |    \__  |   | | | | |   |  \__|   \ | \__/ |   |     //
00316 //                                                                          //
00318 
00319 UPTerminator::UPTerminator (
00320     const PatternSet& validationset,
00321     int               maxraises,
00322     int               striplen)
00323         : Terminator (validationset, striplen),
00324           SavingTerminator (validationset, striplen),
00325           TerminatorT800 (validationset, maxraises, striplen)
00326 {
00327 }
00328 
00329 bool UPTerminator::check (const ANNetwork& net, int cyclesTrained)
00330 {
00331     if (mLastValidError<=mMinValidError) {
00332         mMinValidError = mLastValidError;
00333         mRaises = 0;
00334         save (net, cyclesTrained);
00335     } else
00336         mRaises++;
00337     
00338     return (mRaises>=mMaxRaises);
00339 }
00340 

Generated on Thu Feb 10 20:06:45 2005 for Inanna by doxygen1.2.18