00001
00025 #include <magic/mmath.h>
00026 #include "inanna/patternset.h"
00027 #include "inanna/termination.h"
00028 #include "inanna/annetwork.h"
00029
00030
00032
00033
00034
00035
00036
00037
00038
00040
00041 Learner::Learner ()
00042 {
00043 }
00044
00064 double Learner::train (const PatternSet& set, int cycles, int cycint, int vsize)
00065 {
00066 MUST_OVERLOAD;
00067 return 0.0;
00068 }
00069
00086 double Learner::train (const PatternSet& trainSet,
00087 const PatternSet& validationSet,
00088 int cycles,
00089 int cycint)
00090 {
00091 NOT_IMPLEMENTED;
00092 return 0.0;
00093 }
00094
00099 double Learner::trainOnce (const PatternSet& trainset)
00100 {
00101 MUST_OVERLOAD;
00102 return 0.0;
00103 }
00104
00110 Vector Learner::testPattern (const PatternSource& set,
00111 int pattern ) const
00112 {
00113 MUST_OVERLOAD;
00114 return Vector(1);
00115 }
00116
00121 double Learner::test (const PatternSource& set) const
00122 {
00123 ASSERT (set.patterns>0);
00124
00125 double errorSum = 0.0;
00126 for (int p=0; p<set.patterns; p++) {
00127 Vector res = testPattern (set, p);
00128 ASSERT (res.size() == set.outputs);
00129
00130 for (int j=0; j<res.size(); j++)
00131 errorSum += sqr (res[j] - set.output (p, j));
00132 }
00133
00134 return errorSum / (set.patterns * set.outputs);
00135 }
00136
00146 ClassifResults* Learner::testClassify (const PatternSource& set) const
00147 {
00148 ASSERT (set.patterns>0);
00149
00150 ClassifResults* result = new ClassifResults;
00151
00152
00153 int classes = (set.outputs>1)? set.outputs : 2;
00154
00155
00156 result->classcnts.make (classes);
00157 for (int i=0; i<classes; i++)
00158 result->classcnts[i] = 0;
00159
00160
00161 result->classSizes.make (classes);
00162 for (int i=0; i<classes; i++)
00163 result->classSizes[i] = 0;
00164
00165
00166 int failures=0;
00167 double errorSum = 0.0;
00168 for (int p=0; p<set.patterns; p++) {
00169
00170 int correctClass = set.getClass (p);
00171
00172
00173 bool success=false;
00174
00175 Vector res = testPattern (set, p);
00176
00177
00178 for (int j=0; j<res.size(); j++)
00179 errorSum += sqr (res[j] - set.output (p, j));
00180
00181 if (set.outputs==1) {
00182 if (correctClass == int(res[0]+0.5))
00183 success = true;
00184 } else {
00185 int highestClass = maxIndex (res);
00186 if (highestClass == correctClass)
00187 success = true;
00188 }
00189
00190
00191 if (!success) {
00192 failures++;
00193
00194
00195 result->classcnts[correctClass]++;
00196 }
00197
00198
00199 result->classSizes[correctClass]++;
00200 }
00201
00202
00203 result->mse = errorSum/(set.patterns*set.outputs);
00204 result->failures = failures;
00205
00206 return result;
00207 }
00208
00219 void Learner::copyFreeNet (const ANNetwork& fnet,
00220 bool onlyWeights)
00221 {
00222 copy (static_cast<const Learner&> (fnet));
00223 }
00224
00235 void Learner::copy (const Learner& fnet,
00236 bool onlyWeights)
00237 {
00238 }
00239
00243 ANNetwork* Learner::toANNetwork () const
00244 {
00245 MUST_OVERLOAD;
00246 return NULL;
00247 }
00248
00249