Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

patternset.cc

Go to the documentation of this file.
00001 
00025 #include <stdio.h>
00026 #include <ctype.h>
00027 #include <magic/mobject.h>
00028 #include <magic/mmath.h>
00029 #include <magic/mmap.h>
00030 #include <magic/mclass.h>
00031 #include "inanna/patternset.h"
00032 #include "inanna/dataformat.h"
00033 
00034 
00035 // #define rnd(range) (rand()%range)
00036 
00037 impl_abstract (PatternSource, {CObject});
00038 
00039 PatternSource::PatternSource () {
00040     patterns    = 0;
00041     inputs      = 0;
00042     outputs     = 0;
00043 }
00044 
00045 PatternSource::PatternSource (const PatternSource& orig) {
00046     patterns    = orig.patterns;
00047     inputs      = orig.inputs;
00048     outputs     = orig.outputs;
00049     mName       = orig.mName;
00050 }
00051 
00052 void PatternSource::make1 (int patts, int ins, int outs) {
00053     patterns    = patts;
00054     inputs      = ins;
00055     outputs     = outs;
00056 }
00057 
00058 int PatternSource::getClass (int p) const {
00059     if (outputs>1) {
00060         // Multiple class indicators -> scan trough them to find the highest
00061         double max=-666;
00062         int maxI=0;
00063         for (int i=0; i<outputs; i++)
00064             if (output(p,i) > max) {
00065                 max = output(p,i);
00066                 maxI = i;
00067             }
00068         return maxI;
00069     } else {
00070         // Single, binary class indicator
00071         return (output(p,0)>0.5)? 1:0;
00072     }
00073 }
00074 
00075 void PatternSource::copy (const PatternSource& other, int startp, int endp) {
00076     ASSERTWITH (startp<other.patterns && endp<other.patterns || endp==startp-1,
00077                 format ("Copy range (%d-%d) out of source set limits (%d patterns)",
00078                         startp, endp, other.patterns));
00079     ASSERTWITH (startp>=-1, format ("Invalid copy range %d", startp));
00080 
00081     // If the range is empty, do nothing
00082     if (endp==startp-1)
00083         return;
00084     
00085     ASSERTWITH (startp<=endp, "Start pattern must be greater than end pattern. "
00086                 "Note: end point can be start-1, in which case no copying is done.");
00087 
00088     // Set parameters to default of not given
00089     if (startp==-1)
00090         startp = 0;
00091     if (endp==-1)
00092         endp = other.patterns-1;
00093 
00094     // Resize self
00095     make (endp-startp+1, other.inputs, other.outputs);
00096 
00097     // Copy
00098     for (int p=0; p<=endp-startp; p++) {
00099         for (int i=0; i<inputs; i++)
00100             set_input (p, i, other.input (p+startp, i));
00101         for (int i=0; i<outputs; i++)
00102             set_output (p, i, other.output (p+startp, i));
00103     }
00104 }
00105 
00106 void PatternSource::join (const PatternSource& a, const PatternSource& b) {
00107     // First check the possibility that either or both of the sets might be empty
00108     if (a.patterns==0 || b.patterns==0) {
00109         // These cases we handle just by copying. For some reason.
00110         make (0,0,0);
00111         if (a.patterns>0)
00112             copy (a);
00113         if (b.patterns>0)
00114             copy (b);
00115         return;
00116     }
00117 
00118     // Not empty
00119     
00120     ASSERTWITH (a.inputs==b.inputs && a.outputs==b.outputs,
00121                 format ("Tsets to be joined must be of equal dimensions (was: %d->%d . %d->%d)",
00122                         a.inputs, a.outputs, b.inputs, b.outputs));
00123     ASSERTWITH (a.patterns==0 || b.patterns==0 ||
00124                 a.inputs>0 && a.outputs>0 && b.inputs>0 && b.outputs>0,
00125                 "Tsets to be joined may not have null dimension");
00126     make (a.patterns+b.patterns, a.inputs, a.outputs);
00127 
00128     // Copy a
00129     for (int p=0; p<a.patterns; p++) {
00130         for (int i=0; i<inputs; i++)
00131             set_input (p, i, a.input (p, i));
00132         for (int i=0; i<outputs; i++)
00133             set_output (p, i, a.output (p, i));
00134     }
00135 
00136     // Copy b
00137     for (int p=0; p<b.patterns; p++) {
00138         for (int i=0; i<inputs; i++)
00139             set_input (p+a.patterns, i, b.input (p, i));
00140         for (int i=0; i<outputs; i++)
00141             set_output (p+a.patterns, i, b.output (p, i));
00142     }
00143 }
00144 
00145 void PatternSource::split (PatternSource& a, PatternSource& b, double ratio) const {
00146     ASSERT (ratio>=0 && ratio<=1);
00147     ASSERT (patterns>0);
00148     ASSERT (inputs>0);
00149 
00150     a.make (int(patterns*ratio), inputs, outputs);
00151     b.make (patterns-int(patterns*ratio), inputs, outputs);
00152 
00153     a.copy (*this, 0, int(patterns*ratio)-1);
00154     b.copy (*this, int(patterns*ratio), patterns-1);
00155 
00156     /*
00157     ASSERTWITH (a.patterns>0 && b.patterns>0,
00158                 "Resulting pattern set was empty (ratio was maybe too near 0.0 or 1.0,\n"
00159                 "    or the source set was too small?)");
00160                 */
00161 }
00162 
00163 void PatternSource::filter (const PatternSource& src, const String& bits) {
00164     ASSERTWITH (src.patterns>0 && src.inputs+src.outputs>0,
00165                 "Filter source TSet must have dimensions");
00166     ASSERTWITH (uint(src.inputs) == bits.length() || bits.length()==0,
00167                 format ("Filter mask must have length (was %d) "
00168                         "equal to source TSet input vector dimension (was %d)",
00169                         bits.length(), src.inputs));
00170     
00171     // Count features
00172     int features=0;
00173     if (bits.length()==0)
00174         features = src.inputs;
00175     else
00176         for (uint i=0; i<bits.length(); i++)
00177             if (bits[i] == '1')
00178                 features++;
00179     
00180     ASSERTWITH (features>0, "Training set filter may not be empty");
00181 
00182     // Re-make self to this size
00183     make (src.patterns, features, src.outputs);
00184 
00185     // For each pattern
00186     for (int p=0; p<patterns; p++) {
00187         
00188         // Copy the input features by filtering
00189         int feature=0;
00190         for (int i=0; i<src.inputs; i++)
00191             if (bits.length()==0 || bits[i]=='1')
00192                 set_input (p, feature++, src.input (p, i));
00193         
00194         // Copy the outputs
00195         for (int j=0; j<src.outputs; j++)
00196             set_output (p, j, src.output (p, j));
00197     }
00198 }
00199 
00200 
00201 Array<int> PatternSource::countClasses () const {
00202     Array<int> result;
00203     result.make (classes());
00204     
00205     for (int c=0; c<result.size(); c++)
00206         result[c] = 0;
00207     for (int p=0; p<patterns; p++)
00208         result[getClass (p)]++;
00209     return result;
00210 }
00211 
00212 void PatternSource::check () const {
00213     ASSERT (inputs>=0);
00214     ASSERT (inputs<1000);
00215     ASSERT (outputs>=0);
00216     ASSERT (outputs<1000);
00217     ASSERT (patterns>=0);
00218     ASSERT (patterns<100000);
00219 }
00220 
00221 
00222 
00224 //            ----                                 ----                      //
00225 //            |   )  ___   |   |   ___        _   (      ___   |             //
00226 //            |---   ___| -+- -+- /   ) |/\ |/ \   ---  /   ) -+-            //
00227 //            |     (   |  |   |  |---  |   |   |     ) |---   |             //
00228 //            |      \__|   \   \  \__  |   |   | ___/   \__    \            //
00230 
00231 PatternSet::PatternSet (int patts, int ins, int outs) {
00232     patterns = patts;
00233     inputs   = ins;
00234     outputs  = outs;
00235 
00236     // If we got the size parameters, build the structure here
00237     if (patts && ins && outs)
00238         make (patts, ins, outs);
00239 }
00240 
00248 PatternSet::PatternSet (const String& filename, 
00249                         int           ins,      
00250                         int           outs      )
00251 {
00252     patterns = 0;
00253     inputs   = ins;
00254     outputs  = outs;
00255 
00256     load (filename);
00257 }
00258 
00259 PatternSet::PatternSet (const Matrix& m, int mins, int mouts)
00260 {
00261     ASSERT (mins+mouts == m.cols || mins==-1);
00262     if (mins==-1)
00263         mins = m.cols;
00264     make (m.rows, mins, mouts);
00265     for (int r=0; r<m.rows; r++) {
00266         for (int c=0; c<mins; c++) 
00267             set_input (r,c,m.get(r,c));
00268         for (int c=0; c<mouts; c++)
00269             set_output (r,c+mins,m.get(r,c+mins));
00270     }
00271 }
00272 
00273 PatternSet::PatternSet (const PatternSet& orig) : PatternSource (orig), mInps(orig.mInps), mOutps(orig.mOutps) {
00274     check ();
00275 }
00276 
00277 void PatternSet::make (int patts, int ins, int outs) {
00278     // ASSERTWITH (patts>0 && ins>0 && outs>0, "Zero argument not allowed");
00279     PatternSource::make1 (patts, ins, outs);
00280     mInps.make (patts, ins);
00281     mOutps.make (patts, outs);
00282 }
00283 
00289 void PatternSet::load (const char* filename)
00290 {
00291     DataFormatLib::load (filename, *this);
00292 }
00293 
00298 void PatternSet::load (TextIStream& in, const String& extension)
00299 {
00300     DataFormatLib::load (in, *this, extension);
00301 }
00302 
00307 void PatternSet::save (const String& filename, 
00308                        int           filetype  )
00309 {
00310     DataFormatLib::save (filename, *this);
00311 }
00312 
00313 void PatternSet::print (FILE* out) const
00314 {
00315     if (!out)
00316         out=stdout;
00317 
00318     for (int p=0; p<mInps.rows; p++) {
00319         fprintf (out, "# Input pattern %d:\n", p);
00320         for (int i=0; i<mInps.cols; i++)
00321             fprintf (out, "%f ", mInps.get (p,i));
00322         fprintf (out, "\n");
00323         fprintf (out, "# Output pattern %d:\n", p);
00324         for (int i=0; i<mOutps.cols; i++)
00325             fprintf (out, "%f ", mOutps.get (p,i));
00326         fprintf (out, "\n");
00327     }
00328 }
00329 
00330 void PatternSet::mutate (int errcnt) {
00331     int mutated[errcnt];
00332     int r,
00333         clear;
00334 
00335     for (int p=0; p<mInps.rows; p++)
00336         for (int i=0; i<errcnt; i++) {
00337             clear = 0;
00338             while (!clear) {
00339                 r = rnd (mInps.cols)+1;
00340                 clear=1;
00341                 for (int j=0; j<i-1; j++)
00342                     if (mutated[j]==r)
00343                         clear=0;
00344             }
00345             mutated[i]      = r;
00346             mInps.get (p,r) = 1-mInps.get (p,r);
00347         }
00348 }
00349 
00350 void PatternSet::recombine (int startp, int endp) {
00351     if (startp==-1)
00352         startp=0;
00353     if (endp==-1)
00354         endp=patterns-1;
00355     
00356     ASSERTWITH (startp<=endp, "Range error in recombination");
00357 
00358     // For each pattern
00359     for (int p=0; p<patterns; p++) {
00360 
00361         // Swap with another, random pattern
00362 
00363         int o=rnd (patterns);
00364 
00365         // Swap inputs
00366         for (int i=0; i<inputs; i++)
00367             swap (mInps.get (p, i), mInps.get (o, i));
00368 
00369         // Swap outputs
00370         for (int i=0; i<outputs; i++)
00371             swap (mOutps.get (p, i), mOutps.get (o, i));
00372     }
00373 }
00374 
00375 void PatternSet::recombine2 (int startp, int endp) {
00376     if (startp==-1)
00377         startp=0;
00378     if (endp==-1)
00379         endp=patterns-1;
00380     
00381     ASSERTWITH (startp<=endp, "Range error in recombination");
00382 
00383     // For each pattern
00384     for (int p=0; p<patterns; p++) {
00385 
00386         // Swap with another, random pattern
00387 
00388         int odd = p%2;
00389         int o=0;
00390         do {
00391             o=rnd (patterns/2)*2+odd;
00392         } while (o>=patterns);
00393 
00394         // Swap inputs
00395         for (int i=0; i<inputs; i++)
00396             swap (mInps.get (p, i), mInps.get (o, i));
00397 
00398         // Swap outputs
00399         for (int i=0; i<outputs; i++)
00400             swap (mOutps.get (p, i), mOutps.get (o, i));
00401     }
00402 }
00403 
00404 void PatternSet::check () const {
00405     PatternSource::check ();
00406     
00407 }
00408 
00409 void PatternSet::copy (const PatternSet& orig, int start, int end) {
00410     PatternSource::copy (orig, start, end);
00411 }
00412 
00413 Ref<Matrix> PatternSet::getMatrix () const {
00414     Ref<Matrix> result = new Matrix (patterns, inputs+outputs);
00415     for (int p=0; p<result->rows; p++)
00416         for (int c=0; c<result->cols; c++)
00417             if (c<inputs)
00418                 result->get(p,c) = input(p,c);
00419             else
00420                 result->get(p,c) = output(p,c-inputs);
00421     return result;
00422 }
00423 
00424 
00426 //                                                                           //
00427 //       _                       -----           o        ----               //
00428 //      / \           ___          |        ___      _   (      ___   |      //
00429 //     /   \ |/\ |/\  ___| \   |   |   |/\  ___| | |/ \   ---  /   ) -+-     //
00430 //     |---| |   |   (   |  \  |   |   |   (   | | |   |     ) |---   |      //
00431 //     |   | |   |    \__|   \_/   |   |    \__| | |   | ___/   \__    \     //
00432 //                          \_/                                              //
00434 
00435 ArrayTrainSet::ArrayTrainSet (int ins, int outs, char* filename) {
00436     PatternSource::make1 (0, ins, outs);
00437 
00438     FILE* fin = fopen (filename, "r");
00439     if (!fin)
00440         throw new runtime_error ("file not found");
00441 
00442     while (!feof (fin)) {
00443         double  x;
00444         fscanf (fin, "%lf", &x);
00445         data.add (new double (x));
00446     }
00447     fclose (fin);
00448 
00449     patterns = data.size()-inputs-outputs;
00450 }
00451 
00452 void ArrayTrainSet::print (FILE* out) const {
00453 /*  out << data.upper << " data items\n"
00454         << patterns << " patterns\n"
00455         << inputs   << " inputs\n"
00456         << outputs  << " outputs\n";
00457     out << "t = {";
00458 */
00459     for (int i=0; i<data.size(); i++) {
00460         fprintf (out, "%f ", data[i]);
00461     }
00462     fprintf (out, "\n");
00463 }
00464 
00465 void ArrayTrainSet::multiply (double k) {
00466     for (int i=0; i<data.size(); i++)
00467         data[i] *= k;
00468 }
00469 
00470 void ArrayTrainSet::truncate (int newsize) {
00471     PRE (newsize<data.size() && newsize>=(inputs+outputs));
00472     data.resize (newsize);
00473     patterns = data.size() - inputs - outputs;
00474 }
00475 

Generated on Thu Feb 10 20:06:45 2005 for Inanna by doxygen1.2.18