Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

prediction.h

Go to the documentation of this file.
00001 
00025 #ifndef __PREDICTION_H__
00026 #define __PREDICTION_H__
00027 
00028 #include <magic/mmath.h>
00029 #include <magic/mmatrix.h>
00030 #include <magic/mmap.h>
00031 #include <magic/mclass.h>
00032 #include <magic/mpararr.h>
00033 #include <inanna/annetwork.h>
00034 
00035 class PatternSet;       // In patternset.h
00036 class MatrixEqualizer;  // In equalization.h
00037 class TrainingObserver; // In trainer.h
00038 
00039 
00040 
00042 //                                                                           //
00043 //  |   |                |    |       ___                    ----            //
00044 //  |\ /|        _    |  | _  |       |  \   ___   |   ___  (      ___   |   //
00045 //  | V |  __  |/ \  -+- |/ | | \   | |   |  ___| -+-  ___|  ---  /   ) -+-  //
00046 //  | | | /  \ |   |  |  |  | |  \  | |   | (   |  |  (   |     ) |---   |   //
00047 //  |   | \__/ |   |   \ |  | |   \_/ |__/   \__|   \  \__| ___/   \__    \  //
00048 //                               \_/                                         //
00050 
00053 class MonthlyDataSet : public Matrix {
00054   public:
00055 
00057             MonthlyDataSet  () {mFirstMonth0=mFirstYear=-1;}
00058 
00060             MonthlyDataSet  (const Matrix& matrix, int dateYYMM=0, int lag=0) : Matrix (matrix) {
00061                 if (const MonthlyDataSet* set = dynamic_cast<const MonthlyDataSet*>(&matrix)) {
00062                     mFirstMonth0 = set->mFirstMonth0;
00063                     mFirstYear = set->mFirstYear;
00064                     mLag = set->mLag;
00065                 } else {
00066                     mFirstMonth0 = (dateYYMM%100)-1;
00067                     mFirstYear = dateYYMM/100;
00068                     mLag = lag;
00069                 }
00070             }
00071 
00073     Ref<MonthlyDataSet> sub             (int row0, int row1, int col0, int col1) const;
00074 
00076     int                 lag             () const {return mLag;}
00077 
00079     void                setLag          (int lag) {mLag=lag;}
00080 
00082     int                 firstMonth0     () const {return mFirstMonth0;}
00083 
00085     void                setFirstMonth0  (int m0) {mFirstMonth0=m0;}
00086 
00087   protected:
00088     int     mFirstMonth0;
00089     int     mFirstYear;
00090     int     mLag;
00091 };
00092 
00093 
00095 //                                                                                                          //
00096 // ----                | o           o            -----                 ----                    |           //
00097 // |   )      ___      |    ___   |           _     |    ___   ____  |  |   )  ___   ____       |  |   ____ //
00098 // |---  |/\ /   )  ---| | |   \ -+- |  __  |/ \    |   /   ) (     -+- |---  /   ) (     |   | | -+- (     //
00099 // |     |   |---  (   | | |      |  | /  \ |   |   |   |---   \__   |  | \   |---   \__  |   | |  |   \__  //
00100 // |     |    \__   ---| |  \__/   \ | \__/ |   |   |    \__  ____)   \ |  \   \__  ____)  \__! |   \ ____) //
00101 //                                                                                                          //
00103 
00106 class PredictionTestResults : public Object {
00107   public:
00108                 PredictionTestResults () {}
00109 
00110     void        make        (int variables);
00111     void        addValue    (int variable, double value);
00112     void        calculate   ();
00113     
00114     double      averageError;   // For all variables
00115     double      MSE;            // For all variables
00116     double      RMSE;           // For all variables
00117     Vector      averageErrors;  // For each variable
00118     Vector      MSEs;           // For each variable
00119     Vector      RMSEs;          // For each variable
00120 
00121     OStream&    operator>>	(OStream& out) const;
00122 
00123   private:
00124     int     mValues; 
00126     void operator = (const PredictionTestResults& o) {FORBIDDEN}
00127     PredictionTestResults (const PredictionTestResults& o) {FORBIDDEN}
00128 };
00129 
00130 
00131 
00133 //                                                                                          //
00134 // ----                | o           o             ----                                     //
00135 // |   )      ___      |    ___   |           _   (      |       ___   |   ___              //
00136 // |---  |/\ /   )  ---| | |   \ -+- |  __  |/ \   ---  -+- |/\  ___| -+- /   )  ___  \   | //
00137 // |     |   |---  (   | | |      |  | /  \ |   |     )  |  |   (   |  |  |---  (   \  \  | //
00138 // |     |    \__   ---| |  \__/   \ | \__/ |   | ___/    \ |    \__|   \  \__   ---/   \_/ //
00139 //                                                                               __/   \_/  //
00141 
00146 class PredictionStrategy : public Object {
00147     decl_dynamic (PredictionStrategy);
00148   public:
00149                         PredictionStrategy  () {FORBIDDEN}
00150                         PredictionStrategy  (const String& name) : mName (name) {}
00151 
00152     virtual void        make                (const StringMap& params);
00153     virtual void        train               (const Matrix& traindata, int startmonth);
00154     virtual Ref<Matrix> predict             (const Matrix& testdata, int startmonth) const;
00155     virtual PredictionTestResults*  test    (const Matrix& testdata, int startmonth) const;
00156     virtual void        testCurve           (const Matrix& testdata, int startmonth, const String& filename) const;
00157 
00159     const String&       name                () const {return mName;}
00160     virtual int         inputMonths         () const {return mInputMonths;}
00161     
00162   protected:
00163     String  mName;
00164 
00169     int     mInputMonths;
00170 
00171     static Ref<Matrix>  rowDeltas           (const Matrix& matrix);
00172 };
00173 
00181 class PreviousYear : public PredictionStrategy {
00182     decl_dynamic (PreviousYear);
00183   public:
00184                         PreviousYear            () : PredictionStrategy ("PreviousYear") {mInputMonths=0;}
00185                         PreviousYear            (const StringMap& params);
00186     virtual void        train                   (const Matrix& traindata, int startmonth);
00187     virtual Ref<Matrix> predict                 (const Matrix& testdata, int startmonth) const;
00188     virtual int         inputMonths             () const {return 0;}
00189   protected:
00190     Matrix          mData;
00191 };
00192 
00200 class PreviousYearsAvg : public PredictionStrategy {
00201     decl_dynamic (PreviousYearsAvg);
00202   public:
00203                         PreviousYearsAvg        () : PredictionStrategy ("PreviousYearsAvg") {mInputMonths=0;}
00204                         PreviousYearsAvg        (const StringMap& params);
00205     virtual void        train                   (const Matrix& traindata, int startmonth);
00206     virtual Ref<Matrix> predict                 (const Matrix& testdata, int startmonth) const;
00207     virtual int         inputMonths             () const {return 0;}
00208 
00209   protected:
00210     Matrix          mMonthlyAvg;
00211 };
00212 
00221 class AverageDeltaPrediction : public PredictionStrategy {
00222     decl_dynamic (AverageDeltaPrediction);
00223   public:
00224                         AverageDeltaPrediction  () : PredictionStrategy ("AverageDelta") {mInputMonths=1;}
00225                         AverageDeltaPrediction  (const StringMap& params);
00226     virtual void        train                   (const Matrix& traindata, int startmonth);
00227     virtual Ref<Matrix> predict                 (const Matrix& testdata, int startmonth) const;
00228     const Matrix&       deltas                  () const {return mMonthAvg;}
00229     virtual int         inputMonths             () const {return 1;}
00230   protected:
00231     Matrix          mMonthAvg;
00232     PackArray<int>  mMonthCnt;
00233 };
00234 
00243 class CombinedPrediction : public PredictionStrategy {
00244     decl_dynamic (CombinedPrediction);
00245   public:
00246                         CombinedPrediction      () : PredictionStrategy ("CombinedPrediction") {}
00247                         CombinedPrediction      (const StringMap& params);
00248     virtual void        make                    (const StringMap& params);
00249     virtual void        train                   (const Matrix& traindata, int startmonth);
00250     virtual Ref<Matrix> predict                 (const Matrix& testdata, int startmonth) const;
00251 
00252   protected:
00253     Array<PredictionStrategy>   mPredictors;
00254     PackTable<int>          mPredictorChoises;
00255 
00256   private:
00257     int                 determineInputMonths    () const;
00258 };
00259 
00267 class ZeroDeltaPrediction : public PredictionStrategy {
00268     decl_dynamic (ZeroDeltaPrediction);
00269   public:
00270                         ZeroDeltaPrediction     () : PredictionStrategy ("ZeroDelta") {mInputMonths=1;}
00271                         ZeroDeltaPrediction     (const StringMap& params);
00272     virtual void        train                   (const Matrix& traindata, int startmonth);
00273     virtual Ref<Matrix> predict                 (const Matrix& testdata, int startmonth) const;
00274     virtual int         inputMonths             () const {return 1;}
00275   protected:
00276 };
00277 
00278 
00280 // |   |                       | ----                | o           o            //
00281 // |\  |  ___             ___  | |   )      ___      |    ___   |           _   //
00282 // | \ | /   ) |   | |/\  ___| | |---  |/\ /   )  ---| | |   \ -+- |  __  |/ \  //
00283 // |  \| |---  |   | |   (   | | |     |   |---  (   | | |      |  | /  \ |   | //
00284 // |   |  \__   \__! |    \__| | |     |    \__   ---| |  \__/   \ | \__/ |   | //
00286 
00292 class AbsoluteNeuralPrediction : public PredictionStrategy {
00293     decl_dynamic (AbsoluteNeuralPrediction);
00294   public:
00295                         AbsoluteNeuralPrediction    (const char* name=NULL) : PredictionStrategy (name? name:"AbsoluteNeural")  {mpNetwork=NULL; rpObserver=NULL;}
00296                         ~AbsoluteNeuralPrediction   ();
00297     virtual void        make                        (const StringMap& params);
00298     virtual void        train                       (const Matrix& traindata, int startmonth);
00299     virtual Ref<Matrix> predict                     (const Matrix& testdata, int startmonth) const;
00300 
00301     void                setObserver                 (TrainingObserver* observer) {rpObserver=observer;}
00302     virtual void        load                        (TextIStream& in);
00303     virtual void        save                        (TextOStream& out) const;
00304 
00305   protected:
00307     ANNetwork*          mpNetwork;
00308 
00312     bool                mUseAllOutputs;
00313 
00318     bool                mUseAllInputs;
00319 
00323     int                 mVariable;
00324 
00326     bool                mGlobalEqualization;
00327 
00331     String              mHiddenTopology;
00332 
00334     StringMap           mParams;
00335 
00339     TrainingObserver*   rpObserver;
00340 
00341   protected:
00343     PatternSet*         makeSet                 (const Matrix& data, int startmonth) const;
00344 
00345   private:
00346     int                 inputVariables          (int datacolumns) const {return 12+mInputMonths*(mUseAllInputs? datacolumns : 1);}
00347     int                 outputVariables         (int datacolumns) const {return mUseAllOutputs? datacolumns : 1;}
00348 };
00349 
00350 class SingleNeuralPrediction : public AbsoluteNeuralPrediction {
00351     decl_dynamic (SingleNeuralPrediction);
00352   public:
00353                         SingleNeuralPrediction      () : AbsoluteNeuralPrediction ("SingleNeural")  {mpNetwork=NULL;}
00354     virtual void        make                        (const StringMap& params);
00355     virtual void        train                       (const Matrix& traindata, int startmonth);
00356     virtual Ref<Matrix> predict                     (const Matrix& testdata, int startmonth) const;
00357 
00358     virtual void        load                        (TextIStream& in);
00359     virtual void        save                        (TextOStream& out) const;
00360 
00361   protected:
00362     Array<ANNetwork>    mNetworks;
00363 };
00364 
00368 class DeltaNeuralPrediction : public PredictionStrategy {
00369     decl_dynamic (DeltaNeuralPrediction);
00370   public:
00371                         DeltaNeuralPrediction   () : PredictionStrategy ("DeltaNeural") {}
00372                         DeltaNeuralPrediction   (const StringMap& params);
00373     virtual void        train                   (const Matrix& traindata, int startmonth);
00374     virtual Ref<Matrix> predict                 (const Matrix& testdata, int startmonth) const;
00375   protected:
00376 };
00377 
00379 class StochasticPrediction : public PredictionStrategy {
00380   public:
00381     virtual void        make                        (const StringMap& params);
00382     virtual void        train                       (const Matrix& traindata, int startmonth);
00383     virtual Ref<Matrix> predict                     (const Matrix& testdata, int startmonth) const;
00384 
00385   protected:
00386     virtual Matrix  predictRun                  ();
00387 };
00388 
/** Placeholder class with no members.
 *  NOTE(review): appears unused and unimplemented — confirm before relying on it.
 **/
class NeuralPrediction {
};
00394 
00395 
00396 
00398 //   _             | o     o              ----                                     |     o       //
00399 //  / \            |    |      _         (      |       ___   |   ___              |       |     //
00400 // /   \ |   |  ---| | -+- | |/ \   ___   ---  -+- |/\  ___| -+- /   )  ___  \   | |     | |---  //
00401 // |---| |   | (   | |  |  | |   | (   \     )  |  |   (   |  |  |---  (   \  \  | |     | |   ) //
00402 // |   |  \__!  ---| |   \ | |   |  ---/ ___/    \ |    \__|   \  \__   ---/   \_/ |____ | |__/  //
00403 //                                  __/                                 __/   \_/                //
00405 
00408 class PredictionStrategyLib {
00409   public:
00410                                     PredictionStrategyLib ();
00411     PredictionStrategy*             create      (int i) const;
00412     int                             strategies  () const {return mClassNames.size();}
00413     void                            registerCls (const String& classname) {mClassNames.add(classname);}
00414     static  PredictionStrategyLib&  instance    () {return sInstance;}
00415   protected:
00416     static PredictionStrategyLib    sInstance;
00417     Array<String>                   mClassNames;
00418 };
00419 
00420 
00421 
00423 // -----           o       ----                                                  //
00424 //   |        ___      _   |   )  ___       ___         ___   |   ___       ____ //
00425 //   |   |/\  ___| | |/ \  |---   ___| |/\  ___| |/|/| /   ) -+- /   ) |/\ (     //
00426 //   |   |   (   | | |   | |     (   | |   (   | | | | |---   |  |---  |    \__  //
00427 //   |   |    \__| | |   | |      \__| |    \__| | | |  \__    \  \__  |   ____) //
00429 
00433 class TrainParameters {
00434   public:
00435                     TrainParameters () {defaults();}
00436 
00437     void            defaults        ();
00438     String          hiddenString    () const;
00439     Ref<StringMap>  getParams       () const;
00440     void            setParams       (const StringMap& map);
00441     void            write           (TextOStream& out) const;
00442     void            read            (TextIStream& in);
00443 
00444   public:
00445     int     trainCycles;
00446     float   delta0;
00447     float   deltaMax;
00448     float   weightDecay;
00449     bool    useWeightDecay;
00450     int     hidden1;
00451     int     hidden2;
00452     int     hidden3;
00453     int     runs;
00454     bool    allInputs;
00455     bool    allOutputs;
00456     bool    monthInputs;
00457     bool    equalizeGlobal;
00458     int     inputMonths;
00459 };
00460 
00461 #endif
00462 
00463 

Generated on Thu Feb 10 20:06:45 2005 for Inanna by doxygen1.2.18