00001
00025 #ifndef __TRAINSET_H__
00026 #define __TRAINSET_H__
00027
00028 #include <magic/mobject.h>
00029 #include <magic/mmath.h>
00030 #include <magic/mmatrix.h>
00031 #include <magic/mattribute.h>
00032 #include <magic/mpararr.h>
00033
00034
00036
00037
00038
00039
00040
00042
00045 class PatternSource : public Object, public Attributed {
00046 decl_dynamic (PatternSource);
00047 public:
00048 PatternSource ();
00049
00051 PatternSource (const PatternSource& orig);
00052
00054 virtual void print (FILE* out = stdout) const {MUST_OVERLOAD}
00055
00057 virtual double input (int p, int i) const {MUST_OVERLOAD; return 0.0;}
00058
00060 virtual double output (int p, int j) const {MUST_OVERLOAD; return 0.0;}
00061 virtual void set_input (int p, int i, double value) {MUST_OVERLOAD;}
00062 virtual void set_output (int p, int j, double value) {MUST_OVERLOAD;}
00063
00067 virtual int getClass (int p) const;
00068
00074 virtual void recombine (int startp=-1, int endp=-1) {MUST_OVERLOAD;}
00075
00086 virtual void recombine2 (int startp=-1, int endp=-1) {MUST_OVERLOAD;}
00087
00088
00089
00095 virtual void copy (const PatternSource& other, int startp=-1, int endp=-1);
00096
00098 void operator= (const PatternSource& other) {copy (other);}
00099
00101 const String& name () const {return mName;}
00102
00104 void setName (const String& name) {mName=name;}
00105
00108 void join (const PatternSource& a, const PatternSource& b);
00109
00113 void split (PatternSource& a, PatternSource& b, double ratio) const;
00114
00119 void filter (const PatternSource& source, const String& bits);
00120
00127 int classes () const {return (outputs==1)? 2:outputs;}
00128
00138 Array<int> countClasses () const;
00139
00141 void check () const;
00142
00144 int inputs;
00145
00147 int outputs;
00148
00150 int patterns;
00151
00153 enum psetflags {INPUTS=1, OUTPUTS=2};
00154
00155 protected:
00157 String mName;
00158
00162 virtual void make (int patterns, int inputs, int outputs) {MUST_OVERLOAD}
00163 void make1 (int patterns, int inputs, int outputs);
00164
00165 };
00166
// Legacy alias: the preprocessor rewrites every occurrence of the token
// TrainingSet as PatternSource (presumably kept so older code written
// against the TrainingSet name still compiles). Because this is a macro and
// not a typedef, it renames ANY identifier spelled TrainingSet in code that
// includes this header.
#define TrainingSet PatternSource
00168
00169
00170
00172
00173
00174
00175
00176
00178
00181 class PatternSet : public PatternSource {
00182 public:
00183
00184 PatternSet () {}
00185
00188 PatternSet (int patts, int ins, int outs);
00189
00190 PatternSet (const String& fname, int ins=0, int outs=0);
00191
00193 PatternSet (const Matrix& m, int ins=-1, int outs=0);
00194
00196 PatternSet (const PatternSet& orig);
00197
00198 ~PatternSet () {}
00199
00206 void make (int patts=0, int ins=0, int outs=0);
00207
00208 enum filetypes {FT_SNNS=0, FT_RAW=1};
00209
00210 virtual void load (const char* filename);
00211 virtual void load (TextIStream& in, const String& extension=".raw");
00212 void save (const String& filename, int filetype=FT_SNNS);
00213
00214
00215
00217 virtual void print (FILE* out = stdout) const;
00218
00220 virtual double input (int p, int i) const {return mInps.get (p,i);}
00221
00223 virtual double output (int p, int j) const {return mOutps.get (p,j);}
00224
00226 virtual void set_input (int p, int i, double value) {mInps.get (p,i) = value;}
00227
00229 virtual void set_output (int p, int j, double value) {mOutps.get (p,j) = value;}
00230
00232 Ref<Matrix> getMatrix () const;
00233
00238 void mutate (int errcnt);
00239
00243 void recombine (int startp=-1, int endp=-1);
00244
00246 void recombine2 (int startp=-1, int endp=-1);
00247
00254 virtual void copy (const PatternSet& orig, int start=-1, int end=-1);
00255
00257 void check () const;
00258
00259 protected:
00261 Matrix mInps;
00262
00264 Matrix mOutps;
00265
00267
00268
00270 void calc_means ();
00271
00273 void calc_stddevs ();
00274
00275 private:
00276 void operator= (const PatternSet& orig) {FORBIDDEN}
00277 };
00278
00279
00280
00282
00283
00284
00285
00286
00287
00288
00290
00300 class ArrayTrainSet : public PatternSource {
00301 public:
00302
00309 ArrayTrainSet (int inputs, int outputs, char* filename);
00310 ~ArrayTrainSet () {;}
00311
00313 virtual void make (int patterns, int inputs, int outputs) {FORBIDDEN}
00314
00315
00316
00317 virtual void print (FILE* out = stdout) const;
00318 virtual double input (int p, int i) const {return data[p+i];}
00319 virtual double output (int p, int j) const {return data[p+inputs+j];}
00320
00321
00322
00325 void multiply (double k);
00326
00329 void truncate (int size);
00330
00331 public:
00336 Array<double> data;
00337
00338 };
00339
00340 #endif
00341