nnfw/include/backpropagationalgo.h

Go to the documentation of this file.
00001 /********************************************************************************
00002  *  Neural Network Framework.                                                   *
00003  *  Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it>                *
00004  *                                                                              *
00005  *  This program is free software; you can redistribute it and/or modify        *
00006  *  it under the terms of the GNU General Public License as published by        *
00007  *  the Free Software Foundation; either version 2 of the License, or           *
00008  *  (at your option) any later version.                                         *
00009  *                                                                              *
00010  *  This program is distributed in the hope that it will be useful,             *
00011  *  but WITHOUT ANY WARRANTY; without even the implied warranty of              *
00012  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               *
00013  *  GNU General Public License for more details.                                *
00014  *                                                                              *
00015  *  You should have received a copy of the GNU General Public License           *
00016  *  along with this program; if not, write to the Free Software                 *
00017  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA  *
00018  ********************************************************************************/
00019 
00020 #ifndef BACKPROPAGATIONALGO_H
00021 #define BACKPROPAGATIONALGO_H
00022 
00026 #include "nnfwconfig.h"
00027 #include "learningalgorithm.h"
00028 #include "biasedcluster.h"
00029 #include "matrixlinker.h"
00030 #include <QMap>
00031 #include <QVector>
00032 
00033 namespace farsa {
00034 
00038 class FARSA_NNFW_API BackPropagationAlgo : public LearningAlgorithm {
00039 public:
00046     BackPropagationAlgo( NeuralNet *n_n, UpdatableList update_order, double l_r = 0.1 );
00048     BackPropagationAlgo();
00049 
00051     ~BackPropagationAlgo( );
00052 
00057     void setUpdateOrder( const UpdatableList& update_order );
00058     
00060     UpdatableList updateOrder() const {
00061         return update_order;
00062     };
00066     void setTeachingInput( Cluster* output, const DoubleVector& ti );
00067 
00068     virtual void learn();
00069 
00071     virtual void learn( const Pattern& );
00072 
00074     virtual double calculateMSE( const Pattern& );
00075 
00077     void setRate( double newrate ) {
00078         learn_rate = newrate;
00079     };
00080 
00082     double rate() const {
00083         return learn_rate;
00084     };
00085 
00087     void setMomentum( double newmom ) {
00088         momentumv = newmom;
00089     };
00090 
00092     double momentum() const {
00093         return momentumv;
00094     };
00095 
00097     void enableMomentum();
00098 
00100     void disableMomentum() {
00101         useMomentum = false;
00102     };
00103 
00125     DoubleVector getError( Cluster* );
00169     virtual void configure(ConfigurationParameters& params, QString prefix);
00177     virtual void save(ConfigurationParameters& params, QString prefix);
00179     static void describe( QString type );
00180 protected:
00182     virtual void neuralNetChanged();
00183 private:
00185     double learn_rate;
00187     double momentumv;
00189     double useMomentum;
00191     UpdatableList update_order;
00192 
00194     class FARSA_NNFW_API cluster_deltas {
00195     public:
00196         BiasedCluster* cluster;
00197         bool isOutput;
00198         DoubleVector deltas_outputs;
00199         DoubleVector deltas_inputs;
00200         DoubleVector last_deltas_inputs;
00201         QList<MatrixLinker*> incoming_linkers_vec;
00202         QVector<DoubleVector> incoming_last_outputs;
00203     };
00205     QMap<Cluster*, int> mapIndex;
00207     QVector<cluster_deltas> cluster_deltas_vec;
00208     // --- propagate delta through the net
00209     void propagDeltas();
00210     // --- add a Cluster into the structures above
00211     void addCluster( Cluster*, bool );
00212     // --- add a Linker into the structures above
00213     void addLinker( Linker* );
00214 
00215 };
00216 
00217 }
00218 
00219 #endif
00220