nnfw/include/backpropagationalgo.h
Go to the documentation of this file.
00001 /******************************************************************************** 00002 * Neural Network Framework. * 00003 * Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it> * 00004 * * 00005 * This program is free software; you can redistribute it and/or modify * 00006 * it under the terms of the GNU General Public License as published by * 00007 * the Free Software Foundation; either version 2 of the License, or * 00008 * (at your option) any later version. * 00009 * * 00010 * This program is distributed in the hope that it will be useful, * 00011 * but WITHOUT ANY WARRANTY; without even the implied warranty of * 00012 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * 00013 * GNU General Public License for more details. * 00014 * * 00015 * You should have received a copy of the GNU General Public License * 00016 * along with this program; if not, write to the Free Software * 00017 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA * 00018 ********************************************************************************/ 00019 00020 #ifndef BACKPROPAGATIONALGO_H 00021 #define BACKPROPAGATIONALGO_H 00022 00026 #include "nnfwconfig.h" 00027 #include "learningalgorithm.h" 00028 #include "biasedcluster.h" 00029 #include "matrixlinker.h" 00030 #include <QMap> 00031 #include <QVector> 00032 00033 namespace farsa { 00034 00038 class FARSA_NNFW_API BackPropagationAlgo : public LearningAlgorithm { 00039 public: 00046 BackPropagationAlgo( NeuralNet *n_n, UpdatableList update_order, double l_r = 0.1 ); 00048 BackPropagationAlgo(); 00049 00051 ~BackPropagationAlgo( ); 00052 00057 void setUpdateOrder( const UpdatableList& update_order ); 00058 00060 UpdatableList updateOrder() const { 00061 return update_order; 00062 }; 00066 void setTeachingInput( Cluster* output, const DoubleVector& ti ); 00067 00068 virtual void learn(); 00069 00071 virtual void learn( const Pattern& ); 00072 00074 virtual double calculateMSE( const Pattern& ); 00075 00077 void setRate( double newrate ) { 00078 learn_rate = newrate; 00079 }; 00080 00082 double rate() const { 00083 return learn_rate; 00084 }; 00085 00087 void setMomentum( double newmom ) { 00088 momentumv = newmom; 00089 }; 00090 00092 double momentum() const { 00093 return momentumv; 00094 }; 00095 00097 void enableMomentum(); 00098 00100 void disableMomentum() { 00101 useMomentum = false; 00102 }; 00103 00125 DoubleVector getError( Cluster* ); 00169 virtual void configure(ConfigurationParameters& params, QString prefix); 00177 virtual void save(ConfigurationParameters& params, QString prefix); 00179 static void describe( QString type ); 00180 protected: 00182 virtual void neuralNetChanged(); 00183 private: 00185 double learn_rate; 00187 double momentumv; 00189 double useMomentum; 00191 UpdatableList update_order; 00192 00194 class FARSA_NNFW_API cluster_deltas { 00195 public: 00196 BiasedCluster* cluster; 00197 bool isOutput; 00198 DoubleVector deltas_outputs; 00199 DoubleVector deltas_inputs; 00200 DoubleVector last_deltas_inputs; 00201 QList<MatrixLinker*> incoming_linkers_vec; 00202 QVector<DoubleVector> incoming_last_outputs; 00203 }; 00205 QMap<Cluster*, int> mapIndex; 00207 QVector<cluster_deltas> cluster_deltas_vec; 00208 // --- propagate delta through the net 00209 void propagDeltas(); 00210 // --- add a Cluster into the structures above 00211 void addCluster( Cluster*, bool ); 00212 // --- add a Linker into the structures above 00213 void addLinker( Linker* ); 00214 00215 }; 00216 00217 } 00218 00219 #endif 00220