00001
00025 #include <stdio.h>
00026 #include <ctype.h>
00027 #include <magic/mobject.h>
00028 #include <magic/mmath.h>
00029 #include <magic/mmap.h>
00030 #include <magic/mclass.h>
00031 #include "inanna/patternset.h"
00032 #include "inanna/dataformat.h"
00033
00034
00035
00036
// Class-registration macro invocation for PatternSource, listing CObject
// as its parent class.
// NOTE(review): impl_abstract is presumably provided by magic/mclass.h and
// marks the class as abstract (not directly instantiable) -- confirm.
00037 impl_abstract (PatternSource, {CObject});
00038
00039 PatternSource::PatternSource () {
00040 patterns = 0;
00041 inputs = 0;
00042 outputs = 0;
00043 }
00044
00045 PatternSource::PatternSource (const PatternSource& orig) {
00046 patterns = orig.patterns;
00047 inputs = orig.inputs;
00048 outputs = orig.outputs;
00049 mName = orig.mName;
00050 }
00051
00052 void PatternSource::make1 (int patts, int ins, int outs) {
00053 patterns = patts;
00054 inputs = ins;
00055 outputs = outs;
00056 }
00057
00058 int PatternSource::getClass (int p) const {
00059 if (outputs>1) {
00060
00061 double max=-666;
00062 int maxI=0;
00063 for (int i=0; i<outputs; i++)
00064 if (output(p,i) > max) {
00065 max = output(p,i);
00066 maxI = i;
00067 }
00068 return maxI;
00069 } else {
00070
00071 return (output(p,0)>0.5)? 1:0;
00072 }
00073 }
00074
00075 void PatternSource::copy (const PatternSource& other, int startp, int endp) {
00076 ASSERTWITH (startp<other.patterns && endp<other.patterns || endp==startp-1,
00077 format ("Copy range (%d-%d) out of source set limits (%d patterns)",
00078 startp, endp, other.patterns));
00079 ASSERTWITH (startp>=-1, format ("Invalid copy range %d", startp));
00080
00081
00082 if (endp==startp-1)
00083 return;
00084
00085 ASSERTWITH (startp<=endp, "Start pattern must be greater than end pattern. "
00086 "Note: end point can be start-1, in which case no copying is done.");
00087
00088
00089 if (startp==-1)
00090 startp = 0;
00091 if (endp==-1)
00092 endp = other.patterns-1;
00093
00094
00095 make (endp-startp+1, other.inputs, other.outputs);
00096
00097
00098 for (int p=0; p<=endp-startp; p++) {
00099 for (int i=0; i<inputs; i++)
00100 set_input (p, i, other.input (p+startp, i));
00101 for (int i=0; i<outputs; i++)
00102 set_output (p, i, other.output (p+startp, i));
00103 }
00104 }
00105
00106 void PatternSource::join (const PatternSource& a, const PatternSource& b) {
00107
00108 if (a.patterns==0 || b.patterns==0) {
00109
00110 make (0,0,0);
00111 if (a.patterns>0)
00112 copy (a);
00113 if (b.patterns>0)
00114 copy (b);
00115 return;
00116 }
00117
00118
00119
00120 ASSERTWITH (a.inputs==b.inputs && a.outputs==b.outputs,
00121 format ("Tsets to be joined must be of equal dimensions (was: %d->%d . %d->%d)",
00122 a.inputs, a.outputs, b.inputs, b.outputs));
00123 ASSERTWITH (a.patterns==0 || b.patterns==0 ||
00124 a.inputs>0 && a.outputs>0 && b.inputs>0 && b.outputs>0,
00125 "Tsets to be joined may not have null dimension");
00126 make (a.patterns+b.patterns, a.inputs, a.outputs);
00127
00128
00129 for (int p=0; p<a.patterns; p++) {
00130 for (int i=0; i<inputs; i++)
00131 set_input (p, i, a.input (p, i));
00132 for (int i=0; i<outputs; i++)
00133 set_output (p, i, a.output (p, i));
00134 }
00135
00136
00137 for (int p=0; p<b.patterns; p++) {
00138 for (int i=0; i<inputs; i++)
00139 set_input (p+a.patterns, i, b.input (p, i));
00140 for (int i=0; i<outputs; i++)
00141 set_output (p+a.patterns, i, b.output (p, i));
00142 }
00143 }
00144
00145 void PatternSource::split (PatternSource& a, PatternSource& b, double ratio) const {
00146 ASSERT (ratio>=0 && ratio<=1);
00147 ASSERT (patterns>0);
00148 ASSERT (inputs>0);
00149
00150 a.make (int(patterns*ratio), inputs, outputs);
00151 b.make (patterns-int(patterns*ratio), inputs, outputs);
00152
00153 a.copy (*this, 0, int(patterns*ratio)-1);
00154 b.copy (*this, int(patterns*ratio), patterns-1);
00155
00156
00157
00158
00159
00160
00161 }
00162
00163 void PatternSource::filter (const PatternSource& src, const String& bits) {
00164 ASSERTWITH (src.patterns>0 && src.inputs+src.outputs>0,
00165 "Filter source TSet must have dimensions");
00166 ASSERTWITH (uint(src.inputs) == bits.length() || bits.length()==0,
00167 format ("Filter mask must have length (was %d) "
00168 "equal to source TSet input vector dimension (was %d)",
00169 bits.length(), src.inputs));
00170
00171
00172 int features=0;
00173 if (bits.length()==0)
00174 features = src.inputs;
00175 else
00176 for (uint i=0; i<bits.length(); i++)
00177 if (bits[i] == '1')
00178 features++;
00179
00180 ASSERTWITH (features>0, "Training set filter may not be empty");
00181
00182
00183 make (src.patterns, features, src.outputs);
00184
00185
00186 for (int p=0; p<patterns; p++) {
00187
00188
00189 int feature=0;
00190 for (int i=0; i<src.inputs; i++)
00191 if (bits.length()==0 || bits[i]=='1')
00192 set_input (p, feature++, src.input (p, i));
00193
00194
00195 for (int j=0; j<src.outputs; j++)
00196 set_output (p, j, src.output (p, j));
00197 }
00198 }
00199
00200
00201 Array<int> PatternSource::countClasses () const {
00202 Array<int> result;
00203 result.make (classes());
00204
00205 for (int c=0; c<result.size(); c++)
00206 result[c] = 0;
00207 for (int p=0; p<patterns; p++)
00208 result[getClass (p)]++;
00209 return result;
00210 }
00211
00212 void PatternSource::check () const {
00213 ASSERT (inputs>=0);
00214 ASSERT (inputs<1000);
00215 ASSERT (outputs>=0);
00216 ASSERT (outputs<1000);
00217 ASSERT (patterns>=0);
00218 ASSERT (patterns<100000);
00219 }
00220
00221
00222
00224
00225
00226
00227
00228
00230
00231 PatternSet::PatternSet (int patts, int ins, int outs) {
00232 patterns = patts;
00233 inputs = ins;
00234 outputs = outs;
00235
00236
00237 if (patts && ins && outs)
00238 make (patts, ins, outs);
00239 }
00240
00248 PatternSet::PatternSet (const String& filename,
00249 int ins,
00250 int outs )
00251 {
00252 patterns = 0;
00253 inputs = ins;
00254 outputs = outs;
00255
00256 load (filename);
00257 }
00258
00259 PatternSet::PatternSet (const Matrix& m, int mins, int mouts)
00260 {
00261 ASSERT (mins+mouts == m.cols || mins==-1);
00262 if (mins==-1)
00263 mins = m.cols;
00264 make (m.rows, mins, mouts);
00265 for (int r=0; r<m.rows; r++) {
00266 for (int c=0; c<mins; c++)
00267 set_input (r,c,m.get(r,c));
00268 for (int c=0; c<mouts; c++)
00269 set_output (r,c+mins,m.get(r,c+mins));
00270 }
00271 }
00272
00273 PatternSet::PatternSet (const PatternSet& orig) : PatternSource (orig), mInps(orig.mInps), mOutps(orig.mOutps) {
00274 check ();
00275 }
00276
00277 void PatternSet::make (int patts, int ins, int outs) {
00278
00279 PatternSource::make1 (patts, ins, outs);
00280 mInps.make (patts, ins);
00281 mOutps.make (patts, outs);
00282 }
00283
00289 void PatternSet::load (const char* filename)
00290 {
00291 DataFormatLib::load (filename, *this);
00292 }
00293
00298 void PatternSet::load (TextIStream& in, const String& extension)
00299 {
00300 DataFormatLib::load (in, *this, extension);
00301 }
00302
00307 void PatternSet::save (const String& filename,
00308 int filetype )
00309 {
00310 DataFormatLib::save (filename, *this);
00311 }
00312
00313 void PatternSet::print (FILE* out) const
00314 {
00315 if (!out)
00316 out=stdout;
00317
00318 for (int p=0; p<mInps.rows; p++) {
00319 fprintf (out, "# Input pattern %d:\n", p);
00320 for (int i=0; i<mInps.cols; i++)
00321 fprintf (out, "%f ", mInps.get (p,i));
00322 fprintf (out, "\n");
00323 fprintf (out, "# Output pattern %d:\n", p);
00324 for (int i=0; i<mOutps.cols; i++)
00325 fprintf (out, "%f ", mOutps.get (p,i));
00326 fprintf (out, "\n");
00327 }
00328 }
00329
00330 void PatternSet::mutate (int errcnt) {
00331 int mutated[errcnt];
00332 int r,
00333 clear;
00334
00335 for (int p=0; p<mInps.rows; p++)
00336 for (int i=0; i<errcnt; i++) {
00337 clear = 0;
00338 while (!clear) {
00339 r = rnd (mInps.cols)+1;
00340 clear=1;
00341 for (int j=0; j<i-1; j++)
00342 if (mutated[j]==r)
00343 clear=0;
00344 }
00345 mutated[i] = r;
00346 mInps.get (p,r) = 1-mInps.get (p,r);
00347 }
00348 }
00349
00350 void PatternSet::recombine (int startp, int endp) {
00351 if (startp==-1)
00352 startp=0;
00353 if (endp==-1)
00354 endp=patterns-1;
00355
00356 ASSERTWITH (startp<=endp, "Range error in recombination");
00357
00358
00359 for (int p=0; p<patterns; p++) {
00360
00361
00362
00363 int o=rnd (patterns);
00364
00365
00366 for (int i=0; i<inputs; i++)
00367 swap (mInps.get (p, i), mInps.get (o, i));
00368
00369
00370 for (int i=0; i<outputs; i++)
00371 swap (mOutps.get (p, i), mOutps.get (o, i));
00372 }
00373 }
00374
00375 void PatternSet::recombine2 (int startp, int endp) {
00376 if (startp==-1)
00377 startp=0;
00378 if (endp==-1)
00379 endp=patterns-1;
00380
00381 ASSERTWITH (startp<=endp, "Range error in recombination");
00382
00383
00384 for (int p=0; p<patterns; p++) {
00385
00386
00387
00388 int odd = p%2;
00389 int o=0;
00390 do {
00391 o=rnd (patterns/2)*2+odd;
00392 } while (o>=patterns);
00393
00394
00395 for (int i=0; i<inputs; i++)
00396 swap (mInps.get (p, i), mInps.get (o, i));
00397
00398
00399 for (int i=0; i<outputs; i++)
00400 swap (mOutps.get (p, i), mOutps.get (o, i));
00401 }
00402 }
00403
00404 void PatternSet::check () const {
00405 PatternSource::check ();
00406
00407 }
00408
00409 void PatternSet::copy (const PatternSet& orig, int start, int end) {
00410 PatternSource::copy (orig, start, end);
00411 }
00412
00413 Ref<Matrix> PatternSet::getMatrix () const {
00414 Ref<Matrix> result = new Matrix (patterns, inputs+outputs);
00415 for (int p=0; p<result->rows; p++)
00416 for (int c=0; c<result->cols; c++)
00417 if (c<inputs)
00418 result->get(p,c) = input(p,c);
00419 else
00420 result->get(p,c) = output(p,c-inputs);
00421 return result;
00422 }
00423
00424
00426
00427
00428
00429
00430
00431
00432
00434
00435 ArrayTrainSet::ArrayTrainSet (int ins, int outs, char* filename) {
00436 PatternSource::make1 (0, ins, outs);
00437
00438 FILE* fin = fopen (filename, "r");
00439 if (!fin)
00440 throw new runtime_error ("file not found");
00441
00442 while (!feof (fin)) {
00443 double x;
00444 fscanf (fin, "%lf", &x);
00445 data.add (new double (x));
00446 }
00447 fclose (fin);
00448
00449 patterns = data.size()-inputs-outputs;
00450 }
00451
00452 void ArrayTrainSet::print (FILE* out) const {
00453
00454
00455
00456
00457
00458
00459 for (int i=0; i<data.size(); i++) {
00460 fprintf (out, "%f ", data[i]);
00461 }
00462 fprintf (out, "\n");
00463 }
00464
00465 void ArrayTrainSet::multiply (double k) {
00466 for (int i=0; i<data.size(); i++)
00467 data[i] *= k;
00468 }
00469
00470 void ArrayTrainSet::truncate (int newsize) {
00471 PRE (newsize<data.size() && newsize>=(inputs+outputs));
00472 data.resize (newsize);
00473 patterns = data.size() - inputs - outputs;
00474 }
00475