00001 
00025 #include <ctype.h>
00026 #include <magic/mobject.h>
00027 #include <magic/mlist.h>
00028 #include <magic/mmap.h>
00029 #include <magic/mregexp.h>
00030 #include <magic/mtextstream.h>
00031 
00032 #include "inanna/dataformat.h"
00033 #include "inanna/dataformats.h"
00034 
00035 
00037 
00038 
00039 
00040 
00041 
00043 
00044 struct PatternRow {
00045     Vector  values;
00046     String  comment;
00047     PatternRow (int n) {values.make(n);}
00048 };
00049 
00050 void SNNSDataFormat::load (TextIStream& in, PatternSet& set) const {
00051 
00052     
00053     RegExp no_patts ("patterns : (.+)");
00054     RegExp no_ins   ("input units : (.+)");
00055     RegExp no_outs  ("output units : (.+)");
00056     RegExp inpatt   ("Input pattern (.+):");
00057     RegExp outpatt  ("Output pattern (.+):");
00058     
00059     
00060     enum states {ST_NONE        = 0,
00061                  ST_INS         = 1,
00062                  ST_OUTS        = 2,
00063                  ST_NONE_INS    = 3,
00064                  ST_NONE_OUTS   = 4
00065                 };
00066     int             state   = ST_NONE;  
00067     int             readcnt = 0;        
00068     String          buff;               
00069     Array<String>   linesubs;           
00070     int             pattern = 0;        
00071     int             patts   = 0;
00072     int             ins     = 0;
00073     int             outs    = 0;
00074     
00075     while (in.readLine(buff)) {
00076 
00077         if (isdigit (buff[0])) {
00078             if (!set.inputs || !set.outputs || !set.patterns)
00079                 throw invalid_format ("Pattern file dimensions not given anywhere");
00080 
00081             
00082             
00083             
00084             
00085             
00086             
00087             switch (state) { 
00088               case ST_NONE_OUTS:
00089                   if (pattern>0 && readcnt!=set.outputs)
00090                       throw invalid_format (format ("Too short output vector #d in SNNS "
00091                                                     "pattern set", pattern));
00092                   pattern++;
00093               case ST_NONE:
00094                   state = ST_INS;
00095                   readcnt=0;
00096               case ST_INS:
00097                   if (readcnt<set.inputs)
00098                       set.set_input (pattern, readcnt++, double(String (buff).toDouble()));
00099                   break;
00100               case ST_NONE_INS:
00101                   if (pattern>0 && readcnt!=set.inputs)
00102                       throw invalid_format (format ("Too short input vector #d in SNNS "
00103                                                     "pattern set", pattern));
00104                   state = ST_OUTS;
00105                   readcnt = 0;
00106               case ST_OUTS:
00107                   if (readcnt<set.outputs)
00108                       set.set_output (pattern, readcnt++, double (String (buff).toDouble()));
00109             }
00110             
00111         } else {
00112             
00113             in.readLine (buff);
00114 
00115             
00116             if (no_patts.match (buff, linesubs)) {
00117                 patts = linesubs[1].toInt();
00118                 if (set.patterns && patts!=set.patterns)
00119                     throw invalid_format ("Pattern file has wrong dimensions");
00120             }
00121 
00122             
00123             if (no_ins.match (buff, linesubs)) {
00124                 ins = linesubs[1].toInt();
00125                 if (set.inputs && ins!=set.inputs)
00126                     throw invalid_format ("Pattern file has wrong dimensions");
00127             }
00128 
00129             
00130             if (no_outs.match (buff, linesubs)) {
00131                 outs = linesubs[1].toInt();
00132                 if (set.outputs && outs!=set.outputs)
00133                     throw invalid_format ("Pattern file has wrong dimensions");
00134             }
00135 
00136             
00137             
00138             if ((!set.inputs || !set.outputs) && ins && outs && set.patterns)
00139                 set.make (patts, ins, outs);
00140 
00141             
00142             if (state==ST_INS)
00143                 state = ST_NONE_INS;
00144             if (state==ST_OUTS)
00145                 state = ST_NONE_OUTS;
00146         }
00147     }
00148 
00149     if (pattern!=set.patterns-1)
00150         throw invalid_format (format ("Wrong number of patterns in SNNS pattern file "
00151                                       "(found %d of %d", pattern, set.patterns));
00152 }
00153 
00154 void SNNSDataFormat::save (FILE* out, const PatternSet& set) const {
00155     fprintf (out, "SNNS pattern definition file V3.2\n"
00156              "generated at Xxxxx time\n\n\n"
00157              "No. of patterns : %d\n"
00158              "No. of input units : %d\n"
00159              "No. of output units : %d\n\n",
00160              set.patterns, set.inputs, set.outputs);
00161 
00162     for (int p=0; p<set.patterns; p++) {
00163         fprintf (out, "# input %d\n", p);
00164         for (int i=0; i<set.inputs; i++) {
00165             if (i>0)
00166                 fprintf (out, " ");
00167             if (is_undef (set.input (p,i)))
00168                 fprintf (out, "0");
00169             else 
00170                 fprintf (out, "%g", set.input (p,i));
00171         }
00172         fprintf (out, "\n# target %d\n", p);
00173         for (int i=0; i<set.outputs; i++) {
00174             fprintf (out, " ");
00175             if (is_undef (set.output (p,i)))
00176                 fprintf (out, "0");
00177             else 
00178                 fprintf (out, "%g", set.output (p,i));
00179         }
00180         fprintf (out, "\n");
00181     }
00182 }
00183 
00184 
00186 
00187 
00188 
00189 
00190 
00192 
00193 void Proben1DataFormat::load (TextIStream& in, PatternSet& set) const
00194 {
00195     String buff;    
00196     
00197     
00198     String header;
00199     header.reserve (256);
00200     for (int i=0; i<7; i++) {
00201         in.readLine (buff);
00202         buff.chop();
00203         ASSERT (buff.length()>0);
00204         ASSERT (buff.find ("=")!=-1);
00205         header += buff;
00206         header += '&';
00207     }
00208 
00209     
00210     StringMap map;
00211     splitpairs (map, header, '=', '&');
00212 
00213     
00214     bool kludge=false;
00215     if (isnull(map["bool_in"])) kludge=true;
00216     if (isnull(map["real_in"])) kludge=true;
00217     if (isnull(map["bool_out"])) kludge=true;
00218     if (isnull(map["real_out"])) kludge=true;
00219     if (isnull(map["training_examples"])) kludge=true;
00220     if (isnull(map["validation_examples"])) kludge=true;
00221     if (isnull(map["test_examples"])) kludge=true;
00222     if (kludge)
00223         throw generic_exception ("Some required line(s) were missing from a "
00224                                  "Proben1 data file header");
00225 
00226     
00227     int ins  = map["bool_in"].toInt() + map["real_in"].toInt();
00228     int outs = map["bool_out"].toInt() + map["real_out"].toInt();
00229     int pats = map["training_examples"].toInt() + map["validation_examples"].toInt()
00230         + map["test_examples"].toInt();
00231 
00232     
00233     set.make (pats, ins, outs);
00234 
00235     
00236     for (int p=0; p<pats; p++) {
00237         
00238         in.readLine (buff);
00239         buff.chop ();
00240 
00241         
00242         Array<String> items;
00243         buff.split (items, ' ');
00244 
00245         
00246         int pos=0;
00247         Array<String> fields (items.size());
00248         for (int i=0; i<items.size(); i++)
00249             if (items[i].length()>0)
00250                 fields[pos++] = items[i];
00251         fields.resize (pos);
00252         
00253         ASSERTWITH (fields.size()==set.inputs+set.outputs,
00254                     format ("Wrong number of fields (%d) in a Proben1 file "
00255                             "(%d was expected).\n"
00256                             "Propably due to bad formatting in line:\n"
00257                             "%s\n"
00258                             "Rules: number of fields must be equal to "
00259                             "the sum of inputs and outputs and they \n"
00260                             "must be separated by spaces (ASC32). No "
00261                             "other characters are allowed\n",
00262                             fields.size(), set.inputs+set.outputs, (CONSTR) buff));
00263 
00264         
00265         for (int i=0; i<ins; i++)
00266             set.set_input (p,i, fields[i].toDouble());
00267 
00268         
00269         for (int i=0; i<outs; i++)
00270             set.set_output (p,i, fields[i+ins].toDouble());
00271     }
00272 
00273     
00274     set.setAttribute ("training_examples", new Int (map["training_examples"].toInt()));
00275     set.setAttribute ("validation_examples", new Int (map["validation_examples"].toInt()));
00276     set.setAttribute ("test_examples", new Int (map["test_examples"].toInt()));
00277 }
00278 
00279 
00280 
00282 
00283 
00284 
00285 
00286 
00288 
00289 void RawDataFormat::load (TextIStream& in, PatternSet& set) const
00290 {
00291     List<PatternRow*> patlist;  
00292     int patNum=0;               
00293     bool comments=false;        
00294 
00295     
00296     String buffer;
00297     while (in.readLine (buffer)) { 
00298         PatternRow* prow = new PatternRow (set.inputs? set.inputs+set.outputs : 1);
00299         String item;        
00300         int itemIndex=0;    
00301 
00302         
00303         for (uint i=0; i <= buffer.length(); i++) {
00304             
00305             char c = (i==buffer.length())? '\n' : buffer[i];
00306 
00307             
00308             if (isspace (c)) {
00309                 if (item.length()>0) { 
00310                     
00311                     
00312                     if (itemIndex >= set.inputs)
00313                         prow->values.resize (itemIndex+1);
00314                     
00315                     
00316                     if (set.inputs==0 || itemIndex<set.inputs+set.outputs)
00317                         prow->values[itemIndex] = (item=="x")? UNDEFINED_FLOAT : item.toDouble();
00318                     item="";
00319                     
00320                     itemIndex++;
00321                 }
00322                 
00323             } else if (item.length()==0 && !isdigit(c) && c!='x' && c!='-' && c!='+' && c!='.') {
00324                 
00325                 prow->comment = buffer.mid (i);
00326                 comments = true;
00327                 break;
00328             } else 
00329                 item += c;
00330         }
00331 
00332         
00333         while (itemIndex>0 && isempty (String(prow->values[itemIndex-1])))
00334             --itemIndex;
00335 
00336         
00337         if (itemIndex>0) {
00338             
00339             
00340             if (set.inputs==0)
00341                 set.inputs = itemIndex;
00342 
00343             
00344             if (itemIndex==set.inputs)
00345                 patlist.add (prow);
00346             else {
00347                 delete prow;
00348                 throw invalid_format (strformat ("Pattern set row %d has invalid number of fields (%d out of %d)", patNum+1, itemIndex, set.inputs));
00349             }
00350             
00351             patNum++;
00352         } else
00353             
00354             delete prow;
00355     }
00356 
00357     
00358     set.make (patNum, set.inputs, set.outputs);
00359 
00360     
00361     int p=0;
00362     for (ListIter<PatternRow*> l (patlist); !l.exhausted(); l.next(), p++) {
00363         for (int i=0; i<set.inputs; i++)
00364             set.set_input (p, i, l.get()->values[i]);
00365         for (int i=0; i<set.outputs; i++)
00366             set.set_output (p, i, l.get()->values[i+set.inputs]);
00367     }
00368 
00369     
00370     if (comments) {
00371         Array<String>* comments = new Array<String> (patNum);
00372         int p=0;
00373         for (ListIter<PatternRow*> l (patlist); !l.exhausted(); l.next(), p++)
00374             (*comments)[p] = l.get()->comment;
00375         set.setAttribute ("comments", comments);
00376     }
00377 }
00378  
00379 void RawDataFormat::save (FILE* out, const PatternSet& set) const
00380 {
00381     for (int p=0; p<set.patterns; p++) {
00382         for (int i=0; i<set.inputs; i++) {
00383             if (i>0)
00384                 fprintf (out, " ");
00385             if (is_undef (set.input (p,i)))
00386                 fprintf (out, "x");
00387             else 
00388                 fprintf (out, "%g", set.input (p,i));
00389         }
00390         for (int i=0; i<set.outputs; i++) {
00391             fprintf (out, " ");
00392             if (is_undef (set.output (p,i)))
00393                 fprintf (out, "x");
00394             else 
00395                 fprintf (out, "%g", set.output (p,i));
00396         }
00397         if (!isnull(set.getAttribute("comments")))
00398             fprintf (out, " %s", (CONSTR) dynamic_cast<const Array<String>&> (set.getAttribute("comments"))[p]);
00399         fprintf (out, "\n");
00400     }
00401 }
00402 
00403