00001
00025 #include <ctype.h>
00026 #include <magic/mobject.h>
00027 #include <magic/mlist.h>
00028 #include <magic/mmap.h>
00029 #include <magic/mregexp.h>
00030 #include <magic/mtextstream.h>
00031
00032 #include "inanna/dataformat.h"
00033 #include "inanna/dataformats.h"
00034
00035
00037
00038
00039
00040
00041
00043
00044 struct PatternRow {
00045 Vector values;
00046 String comment;
00047 PatternRow (int n) {values.make(n);}
00048 };
00049
00050 void SNNSDataFormat::load (TextIStream& in, PatternSet& set) const {
00051
00052
00053 RegExp no_patts ("patterns : (.+)");
00054 RegExp no_ins ("input units : (.+)");
00055 RegExp no_outs ("output units : (.+)");
00056 RegExp inpatt ("Input pattern (.+):");
00057 RegExp outpatt ("Output pattern (.+):");
00058
00059
00060 enum states {ST_NONE = 0,
00061 ST_INS = 1,
00062 ST_OUTS = 2,
00063 ST_NONE_INS = 3,
00064 ST_NONE_OUTS = 4
00065 };
00066 int state = ST_NONE;
00067 int readcnt = 0;
00068 String buff;
00069 Array<String> linesubs;
00070 int pattern = 0;
00071 int patts = 0;
00072 int ins = 0;
00073 int outs = 0;
00074
00075 while (in.readLine(buff)) {
00076
00077 if (isdigit (buff[0])) {
00078 if (!set.inputs || !set.outputs || !set.patterns)
00079 throw invalid_format ("Pattern file dimensions not given anywhere");
00080
00081
00082
00083
00084
00085
00086
00087 switch (state) {
00088 case ST_NONE_OUTS:
00089 if (pattern>0 && readcnt!=set.outputs)
00090 throw invalid_format (format ("Too short output vector #d in SNNS "
00091 "pattern set", pattern));
00092 pattern++;
00093 case ST_NONE:
00094 state = ST_INS;
00095 readcnt=0;
00096 case ST_INS:
00097 if (readcnt<set.inputs)
00098 set.set_input (pattern, readcnt++, double(String (buff).toDouble()));
00099 break;
00100 case ST_NONE_INS:
00101 if (pattern>0 && readcnt!=set.inputs)
00102 throw invalid_format (format ("Too short input vector #d in SNNS "
00103 "pattern set", pattern));
00104 state = ST_OUTS;
00105 readcnt = 0;
00106 case ST_OUTS:
00107 if (readcnt<set.outputs)
00108 set.set_output (pattern, readcnt++, double (String (buff).toDouble()));
00109 }
00110
00111 } else {
00112
00113 in.readLine (buff);
00114
00115
00116 if (no_patts.match (buff, linesubs)) {
00117 patts = linesubs[1].toInt();
00118 if (set.patterns && patts!=set.patterns)
00119 throw invalid_format ("Pattern file has wrong dimensions");
00120 }
00121
00122
00123 if (no_ins.match (buff, linesubs)) {
00124 ins = linesubs[1].toInt();
00125 if (set.inputs && ins!=set.inputs)
00126 throw invalid_format ("Pattern file has wrong dimensions");
00127 }
00128
00129
00130 if (no_outs.match (buff, linesubs)) {
00131 outs = linesubs[1].toInt();
00132 if (set.outputs && outs!=set.outputs)
00133 throw invalid_format ("Pattern file has wrong dimensions");
00134 }
00135
00136
00137
00138 if ((!set.inputs || !set.outputs) && ins && outs && set.patterns)
00139 set.make (patts, ins, outs);
00140
00141
00142 if (state==ST_INS)
00143 state = ST_NONE_INS;
00144 if (state==ST_OUTS)
00145 state = ST_NONE_OUTS;
00146 }
00147 }
00148
00149 if (pattern!=set.patterns-1)
00150 throw invalid_format (format ("Wrong number of patterns in SNNS pattern file "
00151 "(found %d of %d", pattern, set.patterns));
00152 }
00153
00154 void SNNSDataFormat::save (FILE* out, const PatternSet& set) const {
00155 fprintf (out, "SNNS pattern definition file V3.2\n"
00156 "generated at Xxxxx time\n\n\n"
00157 "No. of patterns : %d\n"
00158 "No. of input units : %d\n"
00159 "No. of output units : %d\n\n",
00160 set.patterns, set.inputs, set.outputs);
00161
00162 for (int p=0; p<set.patterns; p++) {
00163 fprintf (out, "# input %d\n", p);
00164 for (int i=0; i<set.inputs; i++) {
00165 if (i>0)
00166 fprintf (out, " ");
00167 if (is_undef (set.input (p,i)))
00168 fprintf (out, "0");
00169 else
00170 fprintf (out, "%g", set.input (p,i));
00171 }
00172 fprintf (out, "\n# target %d\n", p);
00173 for (int i=0; i<set.outputs; i++) {
00174 fprintf (out, " ");
00175 if (is_undef (set.output (p,i)))
00176 fprintf (out, "0");
00177 else
00178 fprintf (out, "%g", set.output (p,i));
00179 }
00180 fprintf (out, "\n");
00181 }
00182 }
00183
00184
00186
00187
00188
00189
00190
00192
00193 void Proben1DataFormat::load (TextIStream& in, PatternSet& set) const
00194 {
00195 String buff;
00196
00197
00198 String header;
00199 header.reserve (256);
00200 for (int i=0; i<7; i++) {
00201 in.readLine (buff);
00202 buff.chop();
00203 ASSERT (buff.length()>0);
00204 ASSERT (buff.find ("=")!=-1);
00205 header += buff;
00206 header += '&';
00207 }
00208
00209
00210 StringMap map;
00211 splitpairs (map, header, '=', '&');
00212
00213
00214 bool kludge=false;
00215 if (isnull(map["bool_in"])) kludge=true;
00216 if (isnull(map["real_in"])) kludge=true;
00217 if (isnull(map["bool_out"])) kludge=true;
00218 if (isnull(map["real_out"])) kludge=true;
00219 if (isnull(map["training_examples"])) kludge=true;
00220 if (isnull(map["validation_examples"])) kludge=true;
00221 if (isnull(map["test_examples"])) kludge=true;
00222 if (kludge)
00223 throw generic_exception ("Some required line(s) were missing from a "
00224 "Proben1 data file header");
00225
00226
00227 int ins = map["bool_in"].toInt() + map["real_in"].toInt();
00228 int outs = map["bool_out"].toInt() + map["real_out"].toInt();
00229 int pats = map["training_examples"].toInt() + map["validation_examples"].toInt()
00230 + map["test_examples"].toInt();
00231
00232
00233 set.make (pats, ins, outs);
00234
00235
00236 for (int p=0; p<pats; p++) {
00237
00238 in.readLine (buff);
00239 buff.chop ();
00240
00241
00242 Array<String> items;
00243 buff.split (items, ' ');
00244
00245
00246 int pos=0;
00247 Array<String> fields (items.size());
00248 for (int i=0; i<items.size(); i++)
00249 if (items[i].length()>0)
00250 fields[pos++] = items[i];
00251 fields.resize (pos);
00252
00253 ASSERTWITH (fields.size()==set.inputs+set.outputs,
00254 format ("Wrong number of fields (%d) in a Proben1 file "
00255 "(%d was expected).\n"
00256 "Propably due to bad formatting in line:\n"
00257 "%s\n"
00258 "Rules: number of fields must be equal to "
00259 "the sum of inputs and outputs and they \n"
00260 "must be separated by spaces (ASC32). No "
00261 "other characters are allowed\n",
00262 fields.size(), set.inputs+set.outputs, (CONSTR) buff));
00263
00264
00265 for (int i=0; i<ins; i++)
00266 set.set_input (p,i, fields[i].toDouble());
00267
00268
00269 for (int i=0; i<outs; i++)
00270 set.set_output (p,i, fields[i+ins].toDouble());
00271 }
00272
00273
00274 set.setAttribute ("training_examples", new Int (map["training_examples"].toInt()));
00275 set.setAttribute ("validation_examples", new Int (map["validation_examples"].toInt()));
00276 set.setAttribute ("test_examples", new Int (map["test_examples"].toInt()));
00277 }
00278
00279
00280
00282
00283
00284
00285
00286
00288
00289 void RawDataFormat::load (TextIStream& in, PatternSet& set) const
00290 {
00291 List<PatternRow*> patlist;
00292 int patNum=0;
00293 bool comments=false;
00294
00295
00296 String buffer;
00297 while (in.readLine (buffer)) {
00298 PatternRow* prow = new PatternRow (set.inputs? set.inputs+set.outputs : 1);
00299 String item;
00300 int itemIndex=0;
00301
00302
00303 for (uint i=0; i <= buffer.length(); i++) {
00304
00305 char c = (i==buffer.length())? '\n' : buffer[i];
00306
00307
00308 if (isspace (c)) {
00309 if (item.length()>0) {
00310
00311
00312 if (itemIndex >= set.inputs)
00313 prow->values.resize (itemIndex+1);
00314
00315
00316 if (set.inputs==0 || itemIndex<set.inputs+set.outputs)
00317 prow->values[itemIndex] = (item=="x")? UNDEFINED_FLOAT : item.toDouble();
00318 item="";
00319
00320 itemIndex++;
00321 }
00322
00323 } else if (item.length()==0 && !isdigit(c) && c!='x' && c!='-' && c!='+' && c!='.') {
00324
00325 prow->comment = buffer.mid (i);
00326 comments = true;
00327 break;
00328 } else
00329 item += c;
00330 }
00331
00332
00333 while (itemIndex>0 && isempty (String(prow->values[itemIndex-1])))
00334 --itemIndex;
00335
00336
00337 if (itemIndex>0) {
00338
00339
00340 if (set.inputs==0)
00341 set.inputs = itemIndex;
00342
00343
00344 if (itemIndex==set.inputs)
00345 patlist.add (prow);
00346 else {
00347 delete prow;
00348 throw invalid_format (strformat ("Pattern set row %d has invalid number of fields (%d out of %d)", patNum+1, itemIndex, set.inputs));
00349 }
00350
00351 patNum++;
00352 } else
00353
00354 delete prow;
00355 }
00356
00357
00358 set.make (patNum, set.inputs, set.outputs);
00359
00360
00361 int p=0;
00362 for (ListIter<PatternRow*> l (patlist); !l.exhausted(); l.next(), p++) {
00363 for (int i=0; i<set.inputs; i++)
00364 set.set_input (p, i, l.get()->values[i]);
00365 for (int i=0; i<set.outputs; i++)
00366 set.set_output (p, i, l.get()->values[i+set.inputs]);
00367 }
00368
00369
00370 if (comments) {
00371 Array<String>* comments = new Array<String> (patNum);
00372 int p=0;
00373 for (ListIter<PatternRow*> l (patlist); !l.exhausted(); l.next(), p++)
00374 (*comments)[p] = l.get()->comment;
00375 set.setAttribute ("comments", comments);
00376 }
00377 }
00378
00379 void RawDataFormat::save (FILE* out, const PatternSet& set) const
00380 {
00381 for (int p=0; p<set.patterns; p++) {
00382 for (int i=0; i<set.inputs; i++) {
00383 if (i>0)
00384 fprintf (out, " ");
00385 if (is_undef (set.input (p,i)))
00386 fprintf (out, "x");
00387 else
00388 fprintf (out, "%g", set.input (p,i));
00389 }
00390 for (int i=0; i<set.outputs; i++) {
00391 fprintf (out, " ");
00392 if (is_undef (set.output (p,i)))
00393 fprintf (out, "x");
00394 else
00395 fprintf (out, "%g", set.output (p,i));
00396 }
00397 if (!isnull(set.getAttribute("comments")))
00398 fprintf (out, " %s", (CONSTR) dynamic_cast<const Array<String>&> (set.getAttribute("comments"))[p]);
00399 fprintf (out, "\n");
00400 }
00401 }
00402
00403