Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

learning.cc

Go to the documentation of this file.
00001 
00025 #include <magic/mmath.h> 
00026 #include "inanna/patternset.h"
00027 #include "inanna/termination.h"
00028 #include "inanna/annetwork.h"
00029 
00030 
00032 //                                                                           //
00033 //                   |                                                       //
00034 //                   |      ___   ___        _    ___                        //
00035 //                   |     /   )  ___| |/\ |/ \  /   ) |/\                   //
00036 //                   |     |---  (   | |   |   | |---  |                     //
00037 //                   |____  \__   \__| |   |   |  \__  |                     //
00038 //                                                                           //
00040 
00041 Learner::Learner ()
00042 {
00043 }
00044 
00064 double Learner::train (const PatternSet& set, int cycles, int cycint, int vsize)
00065 {
00066     MUST_OVERLOAD;
00067     return 0.0;
00068 }
00069 
00086 double Learner::train (const PatternSet& trainSet,
00087                        const PatternSet& validationSet,
00088                        int cycles,
00089                        int cycint)
00090 {
00091     NOT_IMPLEMENTED;
00092     return 0.0;
00093 }
00094 
00099 double Learner::trainOnce (const PatternSet& trainset)
00100 {
00101     MUST_OVERLOAD;
00102     return 0.0;
00103 }
00104 
00110 Vector Learner::testPattern (const PatternSource& set, 
00111                              int pattern               ) const
00112 {
00113     MUST_OVERLOAD;
00114     return Vector(1);
00115 }
00116 
00121 double Learner::test (const PatternSource& set) const
00122 {
00123     ASSERT (set.patterns>0);
00124     
00125     double errorSum = 0.0; // Sum of squared errors (SSE)
00126     for (int p=0; p<set.patterns; p++) {
00127         Vector res = testPattern (set, p);
00128         ASSERT (res.size() == set.outputs);
00129         
00130         for (int j=0; j<res.size(); j++)
00131             errorSum += sqr (res[j] - set.output (p, j));
00132     }
00133     
00134     return errorSum / (set.patterns * set.outputs); // Mean of squared errors (MSE)
00135 }
00136 
00146 ClassifResults* Learner::testClassify (const PatternSource& set) const
00147 {
00148     ASSERT (set.patterns>0);
00149 
00150     ClassifResults* result = new ClassifResults;
00151 
00152     // Determine the number of classes (not so trivial)
00153     int classes = (set.outputs>1)? set.outputs : 2;
00154                    
00155     // Initialize the class hit counter
00156     result->classcnts.make (classes);
00157     for (int i=0; i<classes; i++)
00158         result->classcnts[i] = 0;
00159 
00160     // Initialize the class instance counter
00161     result->classSizes.make (classes);
00162     for (int i=0; i<classes; i++)
00163         result->classSizes[i] = 0;
00164     
00165     // Classify each pattern in the set
00166     int failures=0;
00167     double errorSum = 0.0; // Sum of squared errors (SSE)
00168     for (int p=0; p<set.patterns; p++) {
00169         // Find the correct class 
00170         int correctClass = set.getClass (p);
00171         
00172         // Determine success
00173         bool success=false;
00174 
00175         Vector res = testPattern (set, p);
00176 
00177         // Record the SSE
00178         for (int j=0; j<res.size(); j++)
00179             errorSum += sqr (res[j] - set.output (p, j));
00180 
00181         if (set.outputs==1) {
00182             if (correctClass == int(res[0]+0.5))
00183                 success = true;
00184         } else {
00185             int highestClass = maxIndex (res);
00186             if (highestClass == correctClass)
00187                 success = true;
00188         }
00189         
00190         // Record the success
00191         if (!success) {
00192             failures++;
00193             
00194             // For the particular class
00195             result->classcnts[correctClass]++;
00196         }
00197 
00198         // And increment the number of instances for this particular class
00199         result->classSizes[correctClass]++;
00200     }
00201     
00202     // Return the mean
00203     result->mse = errorSum/(set.patterns*set.outputs); // Mean of squared errors (MSE)
00204     result->failures = failures;
00205 
00206     return result;
00207 }
00208 
00219 void Learner::copyFreeNet (const ANNetwork& fnet,
00220                            bool             onlyWeights)
00221 {
00222     copy (static_cast<const Learner&> (fnet));
00223 }
00224 
00235 void Learner::copy (const Learner& fnet,
00236                     bool           onlyWeights)
00237 {
00238 }
00239 
00243 ANNetwork* Learner::toANNetwork () const
00244 {
00245     MUST_OVERLOAD;
00246     return NULL;
00247 }
00248 
00249 

Generated on Thu Feb 10 20:06:44 2005 for Inanna by doxygen1.2.18