00001
00025 #include "fstream"
00026
00027 #include <magic/mobject.h>
00028 #include <magic/mregexp.h>
00029
00030 #include "inanna/annfilef.h"
00031 #include "inanna/equalization.h"
00032
00033
00034
00036
00037
00038
00039
00040
00042
00043 void ANNFileFormatLib::load (const String& filename,
00044 ANNetwork& set)
00045 throw (file_not_found, invalid_format, assertion_failed, open_failure)
00046 {
00047 ASSERTWITH (!isempty(filename), "Filename required (was empty)");
00048
00049
00050 TextIStream* in = &stin;
00051 if (filename != "-") {
00052 in = new TextIStream (new File (filename, IO_Readable));
00053 if (!*in)
00054 throw file_not_found (format ("Network file '%s' not found", (CONSTR) filename));
00055 }
00056
00057
00058 load (*in, set, "SNNS");
00059
00060 if (in != &stin)
00061 delete in;
00062 }
00063
00064 void ANNFileFormatLib::load (TextIStream& in,
00065 ANNetwork& net,
00066 const String& filetype)
00067 throw (invalid_format, assertion_failed)
00068 {
00069 ANNFileFormat* handler = create (filetype);
00070 try {
00071 handler->load (in, net);
00072 } catch (...) {
00073 delete handler;
00074 throw;
00075 }
00076 delete handler;
00077 }
00078
00083 void ANNFileFormatLib::save (
00084 const String& filename,
00085 const ANNetwork& net,
00086 const char* fileformat)
00087 throw (invalid_filename, invalid_format, assertion_failed, stream_failure, open_failure)
00088 {
00089
00090 TextOStream* out = &sout;
00091
00092
00093 if (filename != "-")
00094 out = new TextOStream (new File (filename, IO_Writable));
00095
00096
00097 if (! out)
00098 throw open_failure (strformat (i18n("Network file '%s' couldn't be opened for writing"),
00099 (CONSTR) filename));
00100
00101 try {
00102
00103 save (*out, net, fileformat);
00104 } catch (...) {
00105
00106 if (out != &sout)
00107 delete out;
00108
00109 throw;
00110 }
00111
00112
00113
00114 if (out != &sout)
00115 delete out;
00116 }
00117
00121 void ANNFileFormatLib::save (
00122 TextOStream& out,
00123 const ANNetwork& net,
00124 const String& filetype)
00125 throw (assertion_failed, invalid_filename, stream_failure)
00126 {
00127
00128 ANNFileFormat* handler = create (filetype);
00129
00130 try {
00131 handler->save (out, net);
00132 } catch (must_overload e) {
00133 delete handler;
00134 throw invalid_filename (format (i18n("Save-to-file operation not supported for file type\n%s"),
00135 (CONSTR) e.what()));
00136 } catch (...) {
00137 delete handler;
00138 throw;
00139 }
00140 delete handler;
00141 }
00142
00143
00144 ANNFileFormat* ANNFileFormatLib::create (const String& fileformat) throw (invalid_format) {
00145
00146
00147
00148
00149 if (fileformat == "SNNS")
00150 return new SNNS_ANNFormat ();
00151 else
00152 throw invalid_format (i18n("Only SNNS (.net) file format supported currently"));
00153 }
00154
00155
00156
00158
00159
00160
00161
00162
00164
00165 void SNNS_ANNFormat::load (TextIStream& in, ANNetwork& net) const {
00166 RegExp re_units ("units : (.+)");
00167 RegExp re_lfunc ("learning function : (.+)");
00168 RegExp re_ufunc ("update function : (.+)");
00169 RegExp re_unitdefault ("unit default section");
00170 RegExp re_unitdefin ("unit definition section");
00171 RegExp re_conndefin ("connection definition section");
00172 RegExp re_unitline ("^ +[0-9]+");
00173 RegExp re_equalization ("Equalization:");
00174 RegExp re_EON ("# end-of-network");
00175
00176 String linebuf;
00177 Array<String> linesubs;
00178 int state = 0;
00179 int inputs=0, hiddens=0, outputs=0;
00180 while (in.readLine (linebuf)) {
00181
00182
00183 if (re_EON.match(linebuf))
00184 break;
00185
00186 switch (state) {
00187 case 0:
00188
00189 if (re_units.match (linebuf, linesubs))
00190 net.make (linesubs[1]);
00191
00192
00193 if (re_lfunc.match (linebuf, linesubs))
00194 if (linesubs[1] != "Rprop")
00195 throw invalid_format (strformat("SNNS learning function '%s' not supported",
00196 (CONSTR) linesubs[1]));
00197
00198 if (re_unitdefault.match (linebuf))
00199 state = 1;
00200 break;
00201
00202 case 1:
00203 if (re_unitdefin.match (linebuf))
00204 state = 2;
00205 break;
00206
00207 case 2:
00208 if (re_unitline.match (linebuf)) {
00209
00210 Array<String> unitfields;
00211 linebuf.split (unitfields, '|');
00212 for (int i=0; i<unitfields.size(); i++)
00213 unitfields[i] = unitfields[i].stripWhiteSpace();
00214 int unitid = unitfields[0].toInt()-1;
00215
00216
00217 net[unitid].setActivation (unitfields[3].toFloat());
00218 net[unitid].setBias (unitfields[4].toFloat());
00219
00220
00221 if (unitfields[5].stripWhiteSpace()=="i")
00222 inputs++;
00223 else if (unitfields[5].stripWhiteSpace()=="h")
00224 hiddens++;
00225 else if (unitfields[5].stripWhiteSpace()=="o")
00226 outputs++;
00227
00228
00229 Array<String> coords;
00230 unitfields[6].split (coords, ',');
00231 ASSERTWITH (coords.size()==3, "ANNetwork SNNS file must have 3-dimensional coordinates for units");
00232 net[unitid].moveTo (coords[0].stripWhiteSpace().toFloat(),
00233 coords[1].stripWhiteSpace().toFloat(),
00234 coords[2].stripWhiteSpace().toFloat());
00235 }
00236
00237 if (re_conndefin.match (linebuf))
00238 state = 3;
00239 break;
00240
00241 case 3:
00242 if (re_unitline.match (linebuf)) {
00243
00244 Array<String> unitfields;
00245 linebuf.split (unitfields, '|');
00246 int targetid = unitfields[0].stripWhiteSpace().toInt()-1;
00247
00248
00249 Array<String> conns;
00250 unitfields[2].split (conns, ',');
00251 for (int i=0; i<conns.size(); i++) {
00252 Array<String> connpair;
00253 conns[i].split (connpair, ':');
00254
00255
00256 Connection* newconn = net.connect (connpair[0].stripWhiteSpace().toInt()-1, targetid);
00257
00258
00259 newconn->setWeight (connpair[1].stripWhiteSpace().toFloat());
00260 }
00261 }
00262
00263 if (re_equalization.match (linebuf) && !net.getEqualizer()) {
00264
00265 in >> linebuf;
00266
00267
00268 net.setEqualizer (readEqualizer(in));
00269 }
00270 break;
00271 };
00272 }
00273
00274
00275 const_cast<ANNLayering&>(dynamic_cast<const ANNLayering&>(net.getTopology())).make (format ("%d-%d-%d", inputs, hiddens, outputs));
00276 }
00277
00278 void SNNS_ANNFormat::save (TextOStream& out,
00279 const ANNetwork& net) const
00280 throw (stream_failure)
00281 {
00282
00283 out << "SNNS network definition file V1.4-3D\n";
00284 out << "generated at <time>\n\n"
00285 << "network name : Network\n"
00286 << "source files :\n"
00287 << "no. of units : " << net.size() << "\n"
00288 << "no. of connections : 0\n"
00289 << "no. of unit types : 0\n"
00290 << "no. of site types : 0\n\n\n"
00291 << "learning function : Rprop\n"
00292 << "update function : Topological_Order\n\n\n"
00293 << "unit default section :\n\n"
00294 << "act | bias | st | subnet | layer | act func | out func\n"
00295 << "---------|----------|----|--------|-------|--------------|-------------\n"
00296 << " 0.00000 | 0.00000 | h | 0 | 1 | Act_Logistic | Out_Identity\n"
00297 << "---------|----------|----|--------|-------|--------------|-------------\n\n\n";
00298
00299
00300 out << "unit definition section :\n\n"
00301 << "no. | typeName | unitName | act | bias | st | position | act func | out func | sites\n"
00302 << "----|----------|----------|----------|----------|----|----------|----------------------|----------|-------\n";
00303
00304
00305 for (int i=0; i<net.size(); i++) {
00306
00307 char st='h';
00308
00309
00310 if (const ANNLayering* layering = dynamic_cast<const ANNLayering*> (&net.getTopology()))
00311 if (layering->layers() > 1)
00312 if (i < (*layering)[0])
00313 st = 'i';
00314 else if (i >= net.size()-(*layering)[-1])
00315 st = 'o';
00316
00317
00318 out << strformat ("%3d | | unit | % 3.5f | % 3.5f | %c | %2g,%2g,%2g |||\n",
00319 i+1, net[i].activation(), net[i].bias(), st,
00320 net[i].getPlace().x, net[i].getPlace().y, net[i].getPlace().z);
00321 }
00322 out << "----|----------|----------|----------|----------|----|----------|----------------------|----------|-------\n";
00323
00324
00325 out << "\n\nconnection definition section :\n\n"
00326 << "target | site | source:weight\n"
00327 << "-------|------|---------------------------------------------------------------------------------------------------------------------\n";
00328 for (int i=0; i<net.size(); i++)
00329 if (net[i].incomings()>0) {
00330 out << strformat ("%6d | |", i+1);
00331 for (int j=0; j<net[i].incomings(); j++) {
00332 if (j>0)
00333 out << ',';
00334 out << strformat ("%3d:% .5f",
00335 net[i].incoming(j).source().id()+1,
00336 net[i].incoming(j).weight());
00337 }
00338 out << '\n';
00339 }
00340 out << "-------|------|---------------------------------------------------------------------------------------------------------------------\n";
00341
00342
00343 if (const Equalizer* eq = net.getEqualizer()) {
00344 out << "\n# Equalization:\n"
00345 << "# " << (*eq) << "\n";
00346 }
00347
00348 out << "# end-of-network\n";
00349 }
00350