
Trainer Class Reference

Abstract base class for neural network training algorithms (strategies).

#include <trainer.h>

Inheritance diagram for Trainer: Trainer is the base class of BackpropTrainer and RPropTrainer.

Public Methods

 Trainer ()
virtual void init (const StringMap &params)
virtual double train (ANNetwork &network, const PatternSet &trainset, int cycles, const PatternSet *pValidationSet=NULL, int validationInterval=0)
void setTerminator (const String &name)
int cyclesTrained () const
int totalCycles () const
double generalizLoss () const
void setGeneralizLoss (double gl)
const Vector & trainingRecord () const
const Vector & validationRecord () const
void setObserver (TrainingObserver *observer)

Protected Methods

virtual void initTrain (ANNetwork &network) const
virtual double trainOnce (ANNetwork &network, const PatternSet &set) const

Protected Attributes

String mTerminatorName
int mTrained
int mTotalTrained
double mGeneralizationLoss
Vector mTrainingProfile
Vector mValidationProfile
TrainingObserver * pTrainingObserver

Detailed Description

Abstract base class for neural network training algorithms (strategies).

A trainer has three functions: to train a network with data, to store temporary training data, and to keep a record of the training process.

Design Patterns: Strategy.

Definition at line 51 of file trainer.h.
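The Strategy pattern can be illustrated with a minimal sketch. It uses hypothetical stand-ins (`Network`, `SimpleTrainer`, `HalvingTrainer`, `runTraining`) instead of Inanna's real ANNetwork, PatternSet and Trainer declarations, which this page does not reproduce; only the shape of the design is meant to carry over. The client code depends on the abstract interface alone, so one training algorithm can replace another without changes to the caller.

```cpp
#include <vector>

// Hypothetical stand-in for ANNetwork: just a flat weight vector.
struct Network {
    std::vector<double> weights;
};

// Abstract strategy, loosely modelled on Trainer::trainOnce().
// Runs one training cycle and returns an error measure.
class SimpleTrainer {
public:
    virtual ~SimpleTrainer() = default;
    virtual double trainOnce(Network& net) = 0;
};

// One concrete strategy (hypothetical): halve every weight and report
// the sum of squared weights as the "error".
class HalvingTrainer : public SimpleTrainer {
public:
    double trainOnce(Network& net) override {
        double error = 0.0;
        for (double& w : net.weights) {
            w *= 0.5;
            error += w * w;
        }
        return error;
    }
};

// Client code: works with any strategy through the base-class reference.
double runTraining(SimpleTrainer& trainer, Network& net, int cycles) {
    double error = 0.0;
    for (int i = 0; i < cycles; ++i)
        error = trainer.trainOnce(net);
    return error;
}
```

Calling `runTraining` with a `HalvingTrainer` and weights {2.0, -4.0} for two cycles leaves the weights at {0.5, -1.0}.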


Constructor & Destructor Documentation

Trainer ()

This file is part of the Inanna library.

Copyright (C) 1997-2002 Marko Grönroos <magi@iki.fi>

This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details.

You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LIB. If not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.

Definition at line 37 of file trainer.cc.

References mGeneralizationLoss, mTerminatorName, mTotalTrained, mTrained, and pTrainingObserver.


Member Function Documentation

int cyclesTrained () const [inline]

Returns the number of training cycles the network has been trained so far.

The network may actually have been trained more than this, but an earlier weight state may have been restored by a Terminator (early stopping method).

Definition at line 89 of file trainer.h.

References mTrained.
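The relationship between cyclesTrained() and totalCycles() can be sketched as follows, assuming validation after every cycle and a Terminator that restores the weights from the cycle with the lowest validation error. `CycleCounts` and `countCycles` are hypothetical names, and Inanna's actual restore logic is not shown on this page.

```cpp
#include <cstddef>
#include <vector>

// Illustrative pair mirroring Trainer::mTrained and Trainer::mTotalTrained.
struct CycleCounts {
    int trained;       // cycles that still count after restoring weights
    int totalTrained;  // all cycles actually run
};

// Assumes validationErrors[i] is the validation error after cycle i + 1
// and that early stopping restores the weights of the best cycle.
CycleCounts countCycles(const std::vector<double>& validationErrors) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < validationErrors.size(); ++i)
        if (validationErrors[i] < validationErrors[best])
            best = i;
    int total = static_cast<int>(validationErrors.size());
    int trained = validationErrors.empty() ? 0 : static_cast<int>(best) + 1;
    return CycleCounts{trained, total};
}
```

For the validation errors {0.9, 0.5, 0.4, 0.45, 0.6}, the sketch reports 3 relevant cycles out of 5 run.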

double generalizLoss () const [inline]

Returns the percentage growth of the validation error at the end of training, relative to the lowest validation error.

Definition at line 101 of file trainer.h.

References mGeneralizationLoss.

void init (const StringMap & params) [virtual]

Initialize the algorithm with the given parameters.

Reimplemented in BackpropTrainer, and RPropTrainer.

Definition at line 45 of file trainer.cc.

Referenced by RPropTrainer::init(), and BackpropTrainer::init().

void initTrain (ANNetwork & network) const [protected, virtual]

Initializes training.

Reimplemented in BackpropTrainer, and RPropTrainer.

Definition at line 50 of file trainer.cc.

References ANNetwork::init().

Referenced by BackpropTrainer::initTrain(), and train().

void setGeneralizLoss (double gl) [inline]

Used by the Terminator to store the generalization loss value in this trainer.

Definition at line 106 of file trainer.h.

References mGeneralizationLoss.

Referenced by Terminator::validate().

void setObserver (TrainingObserver * observer) [inline]

Sets the observer object that tracks the training progress.

Definition at line 122 of file trainer.h.

References pTrainingObserver.

Referenced by AbsoluteNeuralPrediction::train().
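A training observer might look roughly like the sketch below. The method names echo TrainingObserver::cycleTrained() and TrainingObserver::wantsToStop(), which train() references, but their parameters here are assumptions, and `Observer`, `ThresholdObserver` and `runWithObserver` are hypothetical. The loop uses a simulated error sequence in place of real training.

```cpp
#include <vector>

// Hypothetical observer interface; the method names echo TrainingObserver,
// but these signatures are assumptions.
class Observer {
public:
    virtual ~Observer() = default;
    virtual void cycleTrained(int cycle, double error) = 0;
    virtual bool wantsToStop() const { return false; }
};

// Example observer: records each cycle's error and asks to stop once
// the error falls below a threshold.
class ThresholdObserver : public Observer {
public:
    explicit ThresholdObserver(double threshold) : mThreshold(threshold) {}
    void cycleTrained(int /*cycle*/, double error) override {
        mErrors.push_back(error);
    }
    bool wantsToStop() const override {
        return !mErrors.empty() && mErrors.back() < mThreshold;
    }
    const std::vector<double>& errors() const { return mErrors; }

private:
    double mThreshold;
    std::vector<double> mErrors;
};

// Training-loop fragment: notify the observer after every cycle and stop
// early if it asks to. Errors are simulated. Returns the cycles run.
int runWithObserver(const std::vector<double>& simulatedErrors, Observer* observer) {
    int cycle = 0;
    for (double error : simulatedErrors) {
        ++cycle;
        if (observer) {
            observer->cycleTrained(cycle, error);
            if (observer->wantsToStop())
                break;
        }
    }
    return cycle;
}
```

Keeping the observer optional (a null pointer simply skips the callbacks) matches the idea that setObserver() is an add-on rather than a requirement for training.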

void setTerminator (const String & name) [inline]

Sets the termination method by name (UP2, GL5, etc.).

See Terminator, its subclasses, and the global buildTerminator() function for more information about the parameters.

Definition at line 78 of file trainer.h.

References mTerminatorName.

Referenced by AbsoluteNeuralPrediction::train().
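The names UP2 and GL5 appear to follow Lutz Prechelt's early-stopping criteria: GL followed by a threshold alpha stops once the generalization loss exceeds alpha percent, and UP followed by s stops after the validation error has risen in s successive strips. The sketch below only shows one hypothetical way to split such a name into a criterion code and a numeric parameter; `TerminatorSpec` and `parseTerminatorName` are illustrative, and buildTerminator() may parse names differently.

```cpp
#include <cctype>
#include <cstddef>
#include <cstdlib>
#include <string>

// Hypothetical decomposition of a terminator name such as "GL5" or "UP2".
struct TerminatorSpec {
    std::string criterion;  // e.g. "GL" or "UP"
    double parameter;       // e.g. 5 (percent) or 2 (strips)
};

// Splits the leading letters from the trailing number.
// Returns false if either part is missing or the number is malformed.
bool parseTerminatorName(const std::string& name, TerminatorSpec& spec) {
    std::size_t i = 0;
    while (i < name.size() && std::isalpha(static_cast<unsigned char>(name[i])))
        ++i;
    if (i == 0 || i == name.size())
        return false;  // needs both a letter prefix and a number
    const std::string number = name.substr(i);
    char* end = nullptr;
    double value = std::strtod(number.c_str(), &end);
    if (*end != '\0')
        return false;  // unparsed characters after the number
    spec.criterion = name.substr(0, i);
    spec.parameter = value;
    return true;
}
```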

int totalCycles () const [inline]

Total number of training cycles, including any cycles cut out by a Terminator.

See also:
LearningMapping::cyclesTrained

Definition at line 96 of file trainer.h.

References mTotalTrained.

double train (ANNetwork & network, const PatternSet & trainset, int cycles, const PatternSet * pValidationSet = NULL, int validationInterval = 0) [virtual]

Train the given network with the given training set.

Note that the trainer should change only the weights of the network; it must not otherwise alter the network objects, for example by storing algorithm-specific data in neurons or connections. Such data should be stored in the trainer object itself.

Definition at line 56 of file trainer.cc.

References buildTerminator(), TrainingObserver::cycleTrained(), Terminator::generalizationLoss(), Terminator::howManyTrained(), initTrain(), mGeneralizationLoss, Terminator::minimumError(), mTerminatorName, mTotalTrained, mTrained, mTrainingProfile, mValidationProfile, PatternSource::patterns, pTrainingObserver, Terminator::restore(), trainOnce(), Terminator::validate(), Terminator::validationError(), and TrainingObserver::wantsToStop().

Referenced by AbsoluteNeuralPrediction::train().
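Judging from the members train() references (initTrain(), trainOnce(), the training and validation profiles, and the Terminator's validate() and restore()), its loop might be outlined as below. This is a hypothetical sketch: `TrainResult` and `trainLoop` are illustrative names, the callbacks stand in for trainOnce() and validation, `shouldStop` stands in for a Terminator, and restoring the best weights is only simulated by recording the best cycle. A validationInterval of 0 is assumed to disable validation.

```cpp
#include <functional>
#include <vector>

// Results collected by the sketch; field names echo Trainer's members.
struct TrainResult {
    std::vector<double> trainingProfile;    // cf. mTrainingProfile
    std::vector<double> validationProfile;  // cf. mValidationProfile
    int trained = 0;                        // cf. mTrained
    int totalTrained = 0;                   // cf. mTotalTrained
};

TrainResult trainLoop(int cycles, int validationInterval,
                      const std::function<double()>& trainOnce,
                      const std::function<double()>& validate,
                      const std::function<bool(const std::vector<double>&)>& shouldStop) {
    TrainResult result;
    int bestCycle = 0;  // 0 means no validation has been performed yet
    double bestError = 0.0;
    for (int cycle = 1; cycle <= cycles; ++cycle) {
        result.trainingProfile.push_back(trainOnce());
        result.totalTrained = cycle;
        if (validationInterval > 0 && cycle % validationInterval == 0) {
            double error = validate();
            result.validationProfile.push_back(error);
            if (bestCycle == 0 || error < bestError) {
                bestError = error;
                bestCycle = cycle;  // a real trainer would save the weights here
            }
            if (shouldStop(result.validationProfile))
                break;
        }
    }
    // Simulate restoring the best weights: only cycles up to bestCycle count.
    result.trained = bestCycle > 0 ? bestCycle : result.totalTrained;
    return result;
}
```

Without validation, every cycle counts, so `trained` equals `totalTrained`; with validation, the gap between the two is what cyclesTrained() and totalCycles() expose.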

const Vector & trainingRecord () const [inline]

Returns the history of the training-set error recorded during training.

Useful both for statistical analysis of learning curves and possibly for some Terminator methods.

Definition at line 112 of file trainer.h.

References mTrainingProfile.

virtual double trainOnce (ANNetwork & network, const PatternSet & set) const [inline, protected, virtual]

Trains the network once over the given pattern set.

Reimplemented in BackpropTrainer.

Definition at line 130 of file trainer.h.

Referenced by train().
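The rule stated for train(), that algorithm-specific data belongs in the trainer rather than in the network, can be sketched with gradient descent with momentum, where the previous weight changes are per-weight state. This is a hypothetical illustration (`Net`, `MomentumTrainer`, `step`), not BackpropTrainer's actual trainOnce() implementation, and the gradient is assumed to be supplied by the caller with one entry per weight.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical network stand-in: only the weights live here.
struct Net {
    std::vector<double> weights;
};

// Gradient descent with momentum. The previous weight changes are
// algorithm-specific data, so they are kept in the trainer, not in Net.
class MomentumTrainer {
public:
    MomentumTrainer(double rate, double momentum)
        : mRate(rate), mMomentum(momentum) {}

    // Applies one update; gradient must have one entry per weight.
    void step(Net& net, const std::vector<double>& gradient) {
        if (mPrevDelta.size() != net.weights.size())
            mPrevDelta.assign(net.weights.size(), 0.0);  // lazily sized trainer state
        for (std::size_t i = 0; i < net.weights.size(); ++i) {
            double delta = -mRate * gradient[i] + mMomentum * mPrevDelta[i];
            net.weights[i] += delta;  // the only change made to the network
            mPrevDelta[i] = delta;
        }
    }

private:
    double mRate;
    double mMomentum;
    std::vector<double> mPrevDelta;  // per-weight state owned by the trainer
};
```

With rate 0.5, momentum 0.5, an initial weight of 1.0 and a constant gradient of 1.0, two steps move the weight to 0.5 and then to -0.25.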

const Vector & validationRecord () const [inline]

Returns the history of the validation-set error recorded during training.

Useful both for statistical analysis of learning curves and possibly for some Terminator methods.

Definition at line 118 of file trainer.h.

References mValidationProfile.
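As an example of a Terminator method that could work from the validation record, the sketch below applies a UP-style test, in the spirit of Prechelt's UP criteria, that reports whether the validation error rose in each of the last s recorded validations. It assumes one record entry per training strip (that is, one validation every validationInterval cycles); `upCriterion` is a hypothetical helper, not part of Inanna's documented API.

```cpp
#include <cstddef>
#include <vector>

// Returns true when the validation error rose in each of the last s
// successive strips, i.e. the last s + 1 entries are strictly increasing.
bool upCriterion(const std::vector<double>& record, int s) {
    if (s <= 0)
        return false;
    const std::size_t strips = static_cast<std::size_t>(s);
    if (record.size() < strips + 1)
        return false;  // not enough history yet
    for (std::size_t i = record.size() - strips; i < record.size(); ++i)
        if (!(record[i] > record[i - 1]))
            return false;
    return true;
}
```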


Member Data Documentation

double mGeneralizationLoss [protected]
 

Current loss of generalization ability of the network with respect to the state with the lowest validation error.

This value changes during training if a Terminator that uses the GL measure is in use.

Generalization loss is calculated by dividing the last fitness by the best validation error.

Definition at line 160 of file trainer.h.

Referenced by generalizLoss(), setGeneralizLoss(), train(), and Trainer().
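Read together with the "percentual error growth" wording of generalizLoss(), the division described above matches Prechelt's definition GL(t) = 100 * (E(t) / Emin(t) - 1), where E(t) is the latest validation error and Emin(t) the lowest seen so far. The sketch below computes that quantity after each entry of a validation record; the subtraction of 1 and the scaling by 100 are assumptions about Inanna's exact formula, and `generalizationLossCurve` is a hypothetical helper.

```cpp
#include <cstddef>
#include <vector>

// Generalization loss in percent after each validation entry:
// GL(t) = 100 * (E(t) / min over t' <= t of E(t') - 1).
// Assumes all validation errors are positive.
std::vector<double> generalizationLossCurve(const std::vector<double>& record) {
    std::vector<double> gl;
    gl.reserve(record.size());
    double best = 0.0;
    for (std::size_t t = 0; t < record.size(); ++t) {
        if (t == 0 || record[t] < best)
            best = record[t];  // lowest validation error so far
        gl.push_back(100.0 * (record[t] / best - 1.0));
    }
    return gl;
}
```

For the record {0.5, 0.25, 0.375, 0.5} it yields {0, 0, 50, 100}; under this reading, a GL5 terminator would stop at the third entry, the first whose loss exceeds 5.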

String mTerminatorName [protected]
 

Name of the current termination method.

Definition at line 135 of file trainer.h.

Referenced by setTerminator(), train(), and Trainer().

int mTotalTrained [protected]
 

Total number of training cycles, including any cycles cut out by a Terminator.

See also:
LearningMapping::mTrained

Definition at line 149 of file trainer.h.

Referenced by totalCycles(), train(), and Trainer().

int mTrained [protected]
 

Number of relevant training cycles trained so far.

The network may have been trained more than this, but an earlier weight state may have been restored by a Terminator (early stopping method).

Definition at line 142 of file trainer.h.

Referenced by cyclesTrained(), train(), and Trainer().

Vector mTrainingProfile [protected]
 

History of the training-set error recorded during training.

Useful both for statistical analysis of learning curves and possibly for some Terminator methods.

Definition at line 166 of file trainer.h.

Referenced by train(), and trainingRecord().

Vector mValidationProfile [protected]
 

History of the validation-set error recorded during training.

Useful both for statistical analysis of learning curves and possibly for some Terminator methods.

Definition at line 172 of file trainer.h.

Referenced by train(), and validationRecord().

TrainingObserver* pTrainingObserver [protected]
 

Observer object that gets called after every cycle.

Definition at line 176 of file trainer.h.

Referenced by setObserver(), train(), and Trainer().


The documentation for this class was generated from the following files: trainer.h, trainer.cc
Generated on Thu Feb 10 20:06:46 2005 for Inanna by doxygen 1.2.18