Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

learningenv.cc

Go to the documentation of this file.
00001 
00025 #include <magic/mmap.h>
00026 #include <magic/mclass.h>
00027 #include <magic/mtextstream.h>
00028 
00029 #include <nhp/individual.h>
00030 #include <inanna/annetwork.h>
00031 #include <inanna/annfilef.h>
00032 #include <inanna/patternset.h>
00033 #include <inanna/rprop.h>
00034 
00035 #include "annalee/learningenv.h"
00036 #include "annalee/anngenes.h"
00037 #include "annalee/layered.h"
00038 #include "annalee/miller.h"
00039 #include "annalee/cangelosi.h"
00040 #include "annalee/kitano.h"
00041 //#include "chaosenc.h"
00042 
// Run-time class registration for LearningEAEnv, naming EAEnvironment as
// its base class. (impl_dynamic is presumably the magic library's dynamic
// class macro from <magic/mclass.h> — verify against its definition.)
impl_dynamic (LearningEAEnv, {EAEnvironment});
00044 
00045 
00046 
00048 //                                                                           //
00049 //  |                           o              ----   _   -----              //
00050 //  |      ___   ___        _       _         |      / \  |       _          //
00051 //  |     /   )  ___| |/\ |/ \  | |/ \   ___  | --- /   \ |---  |/ \  |   |  //
00052 //  |     |---  (   | |   |   | | |   | (   \ |   \ |---| |     |   |  \ /   //
00053 //  |____  \__   \__| |   |   | | |   |  ---/ |___/ |   | |____ |   |   V    //
00054 //                                       __/                                 //
00056 
00062 LearningEAEnv::LearningEAEnv () :
00063         mTrainData ((PatternSet&) *new PatternSet()),
00064         mTrainSet ((PatternSet&) *new PatternSet()),
00065         mEvaluationSet ((PatternSet&) *new PatternSet()),
00066         mReportSet ((PatternSet&) *new PatternSet()),
00067         mParams((StringMap&) *new StringMap())
00068 {
00069     FORBIDDEN; // But IT'S NEVER CALLED! (whew...)
00070 }
00071 
00095 LearningEAEnv::LearningEAEnv (const PatternSet& trainSet,
00096                               const PatternSet& evalSet,
00097                               const PatternSet& testSet,
00098                               StringMap& params)
00099         : mTrainSet      (dynamic_cast<const PatternSet&>(trainSet)),
00100           mEvaluationSet (dynamic_cast<const PatternSet&>(evalSet)),
00101           mReportSet     (dynamic_cast<const PatternSet&>(testSet)),
00102           mParams        (params)
00103 {
00104     ASSERTWITH (trainSet.patterns>0, "Must have train patterns");
00105     ASSERTWITH (evalSet.patterns>0, "Must have evaluation patterns");
00106     
00107     // Set some default values if not given explicitly
00108     mNEvals         = getOrDefault (mParams, "LearningEAEnv.evals", String(1)).toInt ();
00109     mNoise          = getOrDefault (mParams, "LearningEAEnv.noise", String(0.0)).toDouble ();
00110     mPermutate      = getOrDefault (mParams, "LearningEAEnv.permutate", String(0)).toInt ();
00111     mMaxTrainCycles = getOrDefault (mParams, "LearningEAEnv.maxTrainCycles", String(3000)).toInt ();
00112     mReportCycles   = mMaxTrainCycles;
00113     mValidInterval  = getOrDefault (mParams, "LearningEAEnv.stripLen", String(10)).toInt ();
00114     mTermMethod     = getOrDefault (mParams, "LearningEAEnv.terminator", String("GL5"));
00115     mTermPart       = getOrDefault (mParams, "LearningEAEnv.termPart", String(0.25)).toInt ();
00116     logDir (getOrDefault (params, "logdir", String("log")));
00117 
00118     mEvalPart       = evalSet.patterns/double(evalSet.patterns+trainSet.patterns);
00119     mProblemType    = (trainSet.outputs>1)? CLASSIFICATION : CLASSIFICATION2;
00120     
00121     if (mTermMethod=="none")
00122         mTermPart = 0;
00123 
00124     // Join the given training and evaluation sets
00125     mTrainData.join (trainSet, evalSet);
00126 
00127     // Then split them again, just to be sure that the splitting is
00128     // done exactly the same way if the training data is permutated or
00129     // something
00130     splitTrainData ();
00131 
00132     ASSERTWITH (trainSet.patterns == mTrainSet.patterns,
00133             "Bug in algorithm: Uneven splitting of training data");
00134 
00135     ASSERT (mTrainSet.patterns>0);
00136     ASSERT (mEvaluationSet.patterns>0);
00137     ASSERT (mReportSet.patterns>0);
00138 }
00139 
00142 void LearningEAEnv::addFeaturesTo (Genome& genome) const
00143 {
00144     // TODO: bool optimize_parameters = mParams["LearningEAEnv.optimizeParams"].toInt ();
00145     
00146     mParams.set ("inputs", String(mTrainData.inputs));
00147     mParams.set ("outputs", String(mTrainData.outputs));
00148     
00149     // Brain, according to the selected encoding scheme
00150     String encoding = mParams["LearningEAEnv.encoding"];
00151     if (encoding=="layered")
00152         genome.add (new LayeredEncoding ("brainplan", mParams));
00153     else if (encoding == "miller")
00154         genome.add (new MillerEncoding ("brainplan", mParams));
00155     else if (encoding == "nolfi")
00156         genome.add (new NolfiEncoding ("brainplan", mParams));
00157     else if (encoding == "cangelosi")
00158         genome.add (new CangelosiEncoding ("brainplan", mParams));
00159     else if (encoding == "kitano")
00160         genome.add (new KitanoEncoding ("brainplan", mParams));
00161     //  else if (encoding == "chaos")
00162     //      genome.add (new ChaosEncoding ("brainplan", mParams));
00163     else
00164         ASSERTWITH (false, format ("Unknown EANN encoding '%s'", (CONSTR) encoding));
00165 
00166     // At initialization of an individual, invoke the brainplan
00167     genome.add (new InterGene ("init", "brainplan"));
00168 }
00169 
00170 void LearningEAEnv::permutate () {
00171     FORBIDDEN; // Temporarily, we don't want anyone to use this
00172     
00173     // First permutate the whole dataset
00174     if (mProblemType==CLASSIFICATION2)
00175         mTrainData.recombine2 ();
00176     else
00177         mTrainData.recombine ();
00178 
00179     splitTrainData ();
00180 }
00181 
00187 void LearningEAEnv::splitTrainData ()
00188 {
00189     mTrainSet.copy (mTrainData, 0, int(mTrainData.patterns*(1-mEvalPart))-1);
00190     mEvaluationSet.copy (mTrainData, int(mTrainData.patterns*(1-mEvalPart)),
00191                          mTrainData.patterns-1);
00192 }
00193 
00210 double LearningEAEnv::evaluateg (const Individual& ind)
00211 {
00212     // Shuffle patterns if such is enabled
00213     if (mPermutate)
00214         permutate ();
00215     
00216     // Create a separate training set and GA evaluation set
00217     PatternSet trainSet, terminSet;
00218     trainSet.copy (mTrainSet, 0, int(mTrainSet.patterns*(1-mTermPart))-1);
00219     terminSet.copy (mTrainSet, int(mTrainSet.patterns*(1-mTermPart)), mTrainSet.patterns-1);
00220 
00221     // Get the I/O interface of the individual and set the parameters
00222     // which it doesn't know yet
00223     const ANNetwork* brainplan = dynamic_cast<const ANNetwork*> (ind.getFeature ("brainplan"));
00224     ASSERT (brainplan);
00225     //io.logDir (mLogDir);
00226 
00227     ANNetwork* brain = new ANNetwork (*brainplan);
00228 
00229     Trainer* pTrainer = createTrainer ();
00230 
00231     // Train the individual for a while
00232     double trainMSE = pTrainer->train (*brain,
00233                                       trainSet,
00234                                       mMaxTrainCycles,
00235                                       &terminSet,
00236                                       mValidInterval);
00237     delete pTrainer;
00238     trainMSE = 0.0; // Dispose. WARNING
00239 
00240     // Measure the fitness of the network with several criteria
00241     
00242     // Test with evaluation set
00243     double fitn_MSE = brain->test (mEvaluationSet);
00244     delete brain;
00245 
00246     double fitn_conns   = 0;
00247     double fitn_hiddens = 0;
00248     double fitn_inputs  = 0;
00249 
00250     // Collect the factors together
00251     double fitness = fitn_MSE*1.0 + fitn_conns*0.0 + fitn_hiddens*0.0 + fitn_inputs*0.0;
00252 
00253     // Print some stats
00254     const Object& stats = ind["stats"];
00255     if (!isnull(stats))
00256         sout.printf (", stats=%s", (CONSTR) dynamic_cast<const String&>(stats));
00257     else
00258         sout.printf (", stats=0 0");
00259     const Object& pConn = ind["pConn"];
00260     if (!isnull(pConn))
00261         sout.printf (", pConn=%s", (CONSTR) dynamic_cast<const String&>(pConn));
00262     
00263     return fitness;
00264 }
00265 
00266 /*******************************************************************************
00267  *
00268 **/
00269 Trainer* LearningEAEnv::createTrainer () const
00270 {
00271     StringMap trainParams;
00272     trainParams.set ("RPropTrainer.delta0", "1.0");
00273     RPropTrainer* pTrainer = new RPropTrainer();
00274     pTrainer->init (trainParams);
00275     pTrainer->setTerminator (mTermMethod);
00276     return pTrainer;
00277 }
00278 
00279 /*******************************************************************************
00280  *
00281 **/
00282 void LearningEAEnv::cycle_report (OStream& log, OStream& out) {
00283     
00284     // Get the I/O interface of the individual
00285     // best->execute (GeneticMsg ("IO", *best));
00286     ASSERT (mpBest);
00287     /*
00288     LearningIO& io = static_cast<LearningIO&> ((*best)["IO"]);
00289 
00290     // Determine where to log the picture of the King
00291     String cycleLogDir = mLogDir; // No generational logging
00292     if (0) { // TODO: Make this configurable if we want more logging
00293         cycleLogDir += format("/gens/gen%04d", mCycles);
00294         system (String("mkdir -p ")+cycleLogDir); // Create the directory
00295     }
00296     io.logDir (cycleLogDir);
00297     */
00298     
00299     // Create a separate training set and GA evaluation set
00300     PatternSet trainSet, terminSet;
00301     trainSet.copy (mTrainSet, 0, int(mTrainSet.patterns*(1-mTermPart))-1);
00302     terminSet.copy (mTrainSet, int(mTrainSet.patterns*(1-mTermPart)),
00303                     mTrainSet.patterns-1);
00304     
00305     ANNetwork& brain = dynamic_cast <ANNetwork&> ((*mpBest)["brainplan"]);
00306     
00307     Trainer* pTrainer = createTrainer ();
00308     
00309     // Train the individual for a while
00310     pTrainer->train (brain, trainSet, mMaxTrainCycles, &terminSet, mValidInterval);
00311     
00312     // Save this to a file
00313     ANNFileFormatLib::save (mLogDir + "/einstein.net", brain);
00314     //brain.saveBrain (mLogDir + "/einstein.net", "Best brain found by Annalee");
00315     
00316     // Save the pattern sets _only_ for the first cycle
00317     /*
00318       if (mCycles==1) {
00319       trainSet.save (mLogDir+"/train.pat");
00320       terminSet.save (mLogDir+"/termin.pat");
00321       mEvaluationSet.save (mLogDir+"/eval.pat");
00322       ((PatternSet&)mReportSet).save (mLogDir+"/report.pat");
00323       }
00324     */
00325     
00326     // Excuse me, I'm looking for the Einstein Brain, can you tell me
00327     // where I can find the Einstein Brain, please?
00328     
00329     // Test with each evaluation pattern while counting the correct
00330     // predictions
00331     switch (mProblemType) {
00332       case CLASSIFICATION:
00333       case CLASSIFICATION2: {
00334           //const ANNetwork& net = io.getNet ();
00335           if (!isnull (brain)) {
00336               ClassifResults* clsresults = brain.testClassify (mReportSet);
00337               double perc = double(clsresults->failures) / double(mReportSet.patterns);
00338               out.printf ("Number of incorrect predictions: "
00339                           "%4d out of %4d (%0.2f%%), mse=%f\n",
00340                           clsresults->failures, mReportSet.patterns, perc*100,
00341                           clsresults->mse);
00342               log.printf ("%f %f", clsresults->mse, perc);
00343               delete clsresults;
00344           } else {
00345               out.printf ("Einstein is out of his mind. "
00346                           "Propably his brain doesn't exist at all\n");
00347               log.printf ("666.0 666.0");
00348           }
00349       } break;
00350       
00351       case APPROXIMATION: {
00352           double mse = brain.test (mReportSet);
00353           out.printf ("MSE=%f", mse);
00354       } break;
00355       
00356       default:
00357           throw generic_exception (format ("Unknown problem type %d for LearningEAEnv",
00358                                            mProblemType));
00359     }
00360     log.flush ();
00361     
00362     // Print any network pictures to corresponding log files
00363     
00364     // Invoke the brainplan to make the network pictures
00365     mpBest->execute (TakeBrainPicsMsg ("brainplan", *mpBest));
00366     
00367     // Check if any of these exists in the individual's properties
00368     char pnames[][20]={"brainpic1","brainpic2","brainpic3","braindesc1"};
00369     char fnames[][20]={"/einstein-pic1.eps","/einstein-pic2.eps",
00370                        "/einstein-pic3.eps","/einstein-desc1.txt"};
00371     for (int p=0; p<4; p++) {
00372         const String& pic = static_cast<const String&> ((*mpBest)[pnames[p]]);
00373         if (!isnull(pic)) {
00374             // A property exists -> save it
00375             FILE* fout = fopen (mLogDir + fnames[p], "w");
00376             ASSERT (fout);
00377             fprintf (fout, "%s", (CONSTR) pic);
00378             fclose (fout);
00379         }
00380     }
00381 }
00382 
/*******************************************************************************
 * Writes the environment's scalar settings to @p out as named fields.
 *
 * Only the fields below are written. The pattern sets, mParams, mNEvals
 * and mNoise are not serialized.
 * NOTE(review): the field order presumably matters to whatever reads this
 * stream back — keep it stable.
 *
 * @return @p out, for chaining.
 **/
DataOStream& LearningEAEnv::operator>> (DataOStream& out) const {
    out.name("mProblemType") << mProblemType;
    out.name("mEvalPart") << mEvalPart;
    out.name("mTermPart") << mTermPart;
    out.name("mMaxTrainCycles") << mMaxTrainCycles;
    out.name("mReportCycles") << mReportCycles;
    out.name("mValidInterval") << mValidInterval;
    out.name("mTermMethod") << mTermMethod;
    out.name("mPermutate") << mPermutate;
    return out;
}
00397 
/*******************************************************************************
 * Consistency check.
 *
 * Runs the base-class check and the check() of each member data object.
 * Then asserts that the numeric settings are in range:
 * - mProblemType is 0..2 (presumably the three problem-type enum values —
 *   verify against the header);
 * - the fractions mEvalPart and mTermPart are in [0,1);
 * - the cycle and interval counts are in [0,100000).
 **/
void LearningEAEnv::check () const {
    EAEnvironment::check ();
    mTrainData.check ();
    mTrainSet.check ();
    mEvaluationSet.check ();
    mReportSet.check ();
    mParams.check ();

    ASSERT (mProblemType>=0 && mProblemType<=2);
    ASSERT (mEvalPart>=0 && mEvalPart<1);
    ASSERT (mTermPart>=0 && mTermPart<1);
    ASSERT (mMaxTrainCycles>=0 && mMaxTrainCycles<100000);
    ASSERT (mReportCycles>=0 && mReportCycles<100000);
    ASSERT (mValidInterval>=0 && mValidInterval<100000);
}
00416 

Generated on Thu Feb 10 20:21:26 2005 for Annalee by Doxygen 1.2.18