Main Page   Class Hierarchy   Compound List   File List   Compound Members   File Members  

annfilef.cc

Go to the documentation of this file.
00001 
00025 #include "fstream"
00026 
00027 #include <magic/mobject.h>
00028 #include <magic/mregexp.h>
00029 
00030 #include "inanna/annfilef.h"
00031 #include "inanna/equalization.h"
00032 
00033 
00034 
00036 //   _   |   | |   | ----- o |       -----                          |     o       //
00037 //  / \  |\  | |\  | |       |  ___  |                     ___   |  |       |     //
00038 // /   \ | \ | | \ | |---  | | /   ) |---   __  |/\ |/|/|  ___| -+- |     | |---  //
00039 // |---| |  \| |  \| |     | | |---  |     /  \ |   | | | (   |  |  |     | |   ) //
00040 // |   | |   | |   | |     | |  \__  |     \__/ |   | | |  \__|   \ |____ | |__/  //
00042 
00043 void ANNFileFormatLib::load (const String& filename,
00044                              ANNetwork& set)
00045     throw (file_not_found, invalid_format, assertion_failed, open_failure)
00046 {
00047     ASSERTWITH (!isempty(filename), "Filename required (was empty)");
00048 
00049     // Open the file
00050     TextIStream* in = &stin;
00051     if (filename != "-") {
00052         in = new TextIStream (new File (filename, IO_Readable));
00053         if (!*in)
00054             throw file_not_found (format ("Network file '%s' not found", (CONSTR) filename));
00055     }
00056 
00057     // TODO: Support other filetypes than SNNS
00058     load (*in, set, "SNNS");
00059 
00060     if (in != &stin)
00061         delete in;
00062 }
00063 
00064 void ANNFileFormatLib::load (TextIStream& in,
00065                              ANNetwork& net,
00066                              const String& filetype)
00067     throw (invalid_format, assertion_failed)
00068 {
00069     ANNFileFormat* handler = create (filetype);
00070     try {
00071         handler->load (in, net);
00072     } catch (...) {
00073         delete handler; // Clean object
00074         throw; // Rethrow
00075     }
00076     delete handler;
00077 }
00078 
00083 void ANNFileFormatLib::save (
00084     const String&    filename,      
00085     const ANNetwork& net,           
00086     const char*      fileformat)    
00087     throw (invalid_filename, invalid_format, assertion_failed, stream_failure, open_failure)
00088 {
00089     // Default to standard output stream
00090     TextOStream* out = &sout;
00091 
00092     // If a filename is given, open it for writing.
00093     if (filename != "-")
00094         out = new TextOStream (new File (filename, IO_Writable));
00095 
00096     // TODO: This is a silly check.
00097     if (! out)
00098         throw open_failure (strformat (i18n("Network file '%s' couldn't be opened for writing"),
00099                                        (CONSTR) filename));
00100 
00101     try {
00102         // TODO: Make file format dynamic
00103         save (*out, net, fileformat);
00104     } catch (...) {
00105         // Delete the stream object, if created earlier.
00106         if (out != &sout)
00107             delete out;
00108 
00109         throw; // Re-throw
00110     }
00111         
00112 
00113     // Delete the stream object, if created earlier.
00114     if (out != &sout)
00115         delete out;
00116 }
00117 
00121 void ANNFileFormatLib::save (
00122     TextOStream&     out,       
00123     const ANNetwork& net,       
00124     const String&    filetype)  
00125     throw (assertion_failed, invalid_filename, stream_failure)
00126 {
00127     // Create the file format handler.
00128     ANNFileFormat* handler = create (filetype);
00129     
00130     try {
00131         handler->save (out, net);
00132     } catch (must_overload e) {
00133         delete handler;
00134         throw invalid_filename (format (i18n("Save-to-file operation not supported for file type\n%s"),
00135                                         (CONSTR) e.what()));
00136     } catch (...) {
00137         delete handler; // Clean object
00138         throw; // Rethrow
00139     }
00140     delete handler;
00141 }
00142 
00143 
00144 ANNFileFormat* ANNFileFormatLib::create (const String& fileformat) throw (invalid_format) {
00145     // Factory method, very trivial and rigid implementation.
00146     // TODO: Make more dynamic
00147     
00148     // Parse the contents according to the file name extension
00149     if (fileformat == "SNNS")
00150         return new SNNS_ANNFormat ();
00151     else
00152         throw invalid_format (i18n("Only SNNS (.net) file format supported currently"));
00153 }
00154 
00155 
00156 
00158 //   _   |   | |   | ----- o |       -----                           //
00159 //  / \  |\  | |\  | |       |  ___  |                     ___   |   //
00160 // /   \ | \ | | \ | |---  | | /   ) |---   __  |/\ |/|/|  ___| -+-  //
00161 // |---| |  \| |  \| |     | | |---  |     /  \ |   | | | (   |  |   //
00162 // |   | |   | |   | |     | |  \__  |     \__/ |   | | |  \__|   \  //
00164 
00165 void SNNS_ANNFormat::load (TextIStream& in, ANNetwork& net) const {
00166     RegExp re_units         ("units : (.+)");
00167     RegExp re_lfunc         ("learning function : (.+)");
00168     RegExp re_ufunc         ("update function   : (.+)");
00169     RegExp re_unitdefault   ("unit default section");
00170     RegExp re_unitdefin     ("unit definition section");
00171     RegExp re_conndefin     ("connection definition section");
00172     RegExp re_unitline      ("^ +[0-9]+");
00173     RegExp re_equalization  ("Equalization:");
00174     RegExp re_EON           ("# end-of-network");
00175 
00176     String linebuf;
00177     Array<String> linesubs;
00178     int state = 0;
00179     int inputs=0, hiddens=0, outputs=0;
00180     while (in.readLine (linebuf)) {
00181 
00182         // Stop reading on End-Of-Network
00183         if (re_EON.match(linebuf))
00184             break;
00185 
00186         switch (state) {
00187           case 0:
00188               // Network size
00189               if (re_units.match (linebuf, linesubs))
00190                   net.make (linesubs[1]);
00191               
00192               // Learning function
00193               if (re_lfunc.match (linebuf, linesubs))
00194                   if (linesubs[1] != "Rprop")
00195                       throw invalid_format (strformat("SNNS learning function '%s' not supported",
00196                                                       (CONSTR) linesubs[1]));
00197 
00198               if (re_unitdefault.match (linebuf))
00199                   state = 1;
00200               break;
00201               
00202           case 1: // Reading unit default section
00203               if (re_unitdefin.match (linebuf))
00204                   state = 2;
00205               break;
00206               
00207           case 2: // Reading unit definition section
00208               if (re_unitline.match (linebuf)) {
00209                   // Extract unit fields
00210                   Array<String> unitfields;
00211                   linebuf.split (unitfields, '|');
00212                   for (int i=0; i<unitfields.size(); i++)
00213                       unitfields[i] = unitfields[i].stripWhiteSpace();
00214                   int unitid = unitfields[0].toInt()-1;
00215 
00216                   // Read activation and bias
00217                   net[unitid].setActivation (unitfields[3].toFloat());
00218                   net[unitid].setBias (unitfields[4].toFloat());
00219 
00220                   // Read unit type
00221                   if (unitfields[5].stripWhiteSpace()=="i")
00222                       inputs++;
00223                   else if (unitfields[5].stripWhiteSpace()=="h")
00224                       hiddens++;
00225                   else if (unitfields[5].stripWhiteSpace()=="o")
00226                       outputs++;
00227 
00228                   // Extract unit coordinates
00229                   Array<String> coords;
00230                   unitfields[6].split (coords, ',');
00231                   ASSERTWITH (coords.size()==3, "ANNetwork SNNS file must have 3-dimensional coordinates for units");
00232                   net[unitid].moveTo (coords[0].stripWhiteSpace().toFloat(),
00233                                       coords[1].stripWhiteSpace().toFloat(),
00234                                       coords[2].stripWhiteSpace().toFloat());
00235               }
00236               
00237               if (re_conndefin.match (linebuf))
00238                   state = 3;
00239               break;
00240               
00241           case 3: // Reading connection definition section
00242               if (re_unitline.match (linebuf)) {
00243                   // Extract connection fields
00244                   Array<String> unitfields;
00245                   linebuf.split (unitfields, '|');
00246                   int targetid = unitfields[0].stripWhiteSpace().toInt()-1;
00247 
00248                   // Extract connections to this target unit
00249                   Array<String> conns;
00250                   unitfields[2].split (conns, ',');
00251                   for (int i=0; i<conns.size(); i++) {
00252                       Array<String> connpair;
00253                       conns[i].split (connpair, ':');
00254 
00255                       // Make connection
00256                       Connection* newconn = net.connect (connpair[0].stripWhiteSpace().toInt()-1, targetid);
00257 
00258                       // Set conenction weight
00259                       newconn->setWeight (connpair[1].stripWhiteSpace().toFloat());
00260                   }
00261               }
00262               
00263               if (re_equalization.match (linebuf) && !net.getEqualizer()) {
00264                   // Read the '# ' in the beginning of the next row
00265                   in >> linebuf;
00266 
00267                   // Read the equalization object
00268                   net.setEqualizer (readEqualizer(in));
00269               }
00270               break;
00271         };
00272     }
00273 
00274     // Make topology handler according to parameters acquired
00275     const_cast<ANNLayering&>(dynamic_cast<const ANNLayering&>(net.getTopology())).make (format ("%d-%d-%d", inputs, hiddens, outputs));
00276 }
00277 
00278 void SNNS_ANNFormat::save (TextOStream& out,
00279                            const ANNetwork& net) const
00280     throw (stream_failure)
00281 {
00282     // SNNS Version
00283     out << "SNNS network definition file V1.4-3D\n";
00284     out << "generated at <time>\n\n"
00285         << "network name : Network\n"
00286         << "source files :\n"
00287         << "no. of units : " << net.size() << "\n"
00288         << "no. of connections : 0\n"
00289         << "no. of unit types : 0\n"
00290         << "no. of site types : 0\n\n\n"
00291         << "learning function : Rprop\n"
00292         << "update function   : Topological_Order\n\n\n"
00293         << "unit default section :\n\n"
00294         << "act      | bias     | st | subnet | layer | act func     | out func\n"
00295         << "---------|----------|----|--------|-------|--------------|-------------\n"
00296         << " 0.00000 |  0.00000 | h  |      0 |     1 | Act_Logistic | Out_Identity\n"
00297         << "---------|----------|----|--------|-------|--------------|-------------\n\n\n";
00298 
00299     // List units
00300     out << "unit definition section :\n\n"
00301         << "no. | typeName | unitName | act      | bias     | st | position | act func             | out func | sites\n"
00302         << "----|----------|----------|----------|----------|----|----------|----------------------|----------|-------\n";
00303 
00304     // Go through all units
00305     for (int i=0; i<net.size(); i++) {
00306         // Determine unit status
00307         char st='h'; // Hidden
00308 
00309         // The layering object tells the unit type (input/output/hidden)
00310         if (const ANNLayering* layering = dynamic_cast<const ANNLayering*> (&net.getTopology()))
00311             if (layering->layers() > 1)
00312                 if (i < (*layering)[0])
00313                     st = 'i';
00314                 else if (i >= net.size()-(*layering)[-1])
00315                     st = 'o';
00316 
00317         // Print unit line
00318         out << strformat ("%3d |          | unit     | % 3.5f | % 3.5f | %c  | %2g,%2g,%2g |||\n",
00319                           i+1, net[i].activation(), net[i].bias(), st,
00320                           net[i].getPlace().x, net[i].getPlace().y, net[i].getPlace().z);
00321     }
00322     out << "----|----------|----------|----------|----------|----|----------|----------------------|----------|-------\n";
00323 
00324     // List connections
00325     out << "\n\nconnection definition section :\n\n"
00326         << "target | site | source:weight\n"
00327         << "-------|------|---------------------------------------------------------------------------------------------------------------------\n";
00328     for (int i=0; i<net.size(); i++)
00329         if (net[i].incomings()>0) {
00330             out << strformat ("%6d |      |", i+1); // Target neuron ID
00331             for (int j=0; j<net[i].incomings(); j++) {
00332                 if (j>0)
00333                     out << ',';
00334                 out << strformat ("%3d:% .5f",
00335                                   net[i].incoming(j).source().id()+1, // Source neuron ID
00336                                   net[i].incoming(j).weight()); // Connection weight
00337             }
00338             out << '\n';
00339         }
00340     out << "-------|------|---------------------------------------------------------------------------------------------------------------------\n";
00341 
00342     // Store equalization object
00343     if (const Equalizer* eq = net.getEqualizer()) {
00344         out << "\n# Equalization:\n"
00345             << "# " << (*eq) << "\n";
00346     }
00347 
00348     out << "# end-of-network\n";
00349 }
00350 

Generated on Thu Feb 10 20:06:44 2005 for Inanna by doxygen1.2.18