backpropagationalgo.h
Go to the documentation of this file.
1 /********************************************************************************
2  * Neural Network Framework. *
3  * Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it> *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the Free Software *
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA *
18  ********************************************************************************/
19 
20 #ifndef BACKPROPAGATIONALGO_H
21 #define BACKPROPAGATIONALGO_H
22 
26 #include "nnfwconfig.h"
27 #include "learningalgorithm.h"
28 #include "biasedcluster.h"
29 #include "matrixlinker.h"
30 #include <QMap>
31 #include <QVector>
32 
33 namespace farsa {
34 
38 class FARSA_NNFW_API BackPropagationAlgo : public LearningAlgorithm {
39 public:
46  BackPropagationAlgo( NeuralNet *n_n, UpdatableList update_order, double l_r = 0.1 );
49 
52 
57  void setUpdateOrder( const UpdatableList& update_order );
58 
60  UpdatableList updateOrder() const {
61  return update_order;
62  };
66  void setTeachingInput( Cluster* output, const DoubleVector& ti );
67 
68  virtual void learn();
69 
71  virtual void learn( const Pattern& );
72 
74  virtual double calculateMSE( const Pattern& );
75 
77  void setRate( double newrate ) {
78  learn_rate = newrate;
79  };
80 
82  double rate() const {
83  return learn_rate;
84  };
85 
87  void setMomentum( double newmom ) {
88  momentumv = newmom;
89  };
90 
92  double momentum() const {
93  return momentumv;
94  };
95 
97  void enableMomentum();
98 
101  useMomentum = false;
102  };
103 
125  DoubleVector getError( Cluster* );
169  virtual void configure(ConfigurationParameters& params, QString prefix);
177  virtual void save(ConfigurationParameters& params, QString prefix);
179  static void describe( QString type );
180 protected:
182  virtual void neuralNetChanged();
183 private:
185  double learn_rate;
187  double momentumv;
189  double useMomentum;
191  UpdatableList update_order;
192 
194  class FARSA_NNFW_API cluster_deltas {
195  public:
196  BiasedCluster* cluster;
197  bool isOutput;
198  DoubleVector deltas_outputs;
199  DoubleVector deltas_inputs;
200  DoubleVector last_deltas_inputs;
201  QList<MatrixLinker*> incoming_linkers_vec;
202  QVector<DoubleVector> incoming_last_outputs;
203  };
205  QMap<Cluster*, int> mapIndex;
207  QVector<cluster_deltas> cluster_deltas_vec;
208  // --- propagate delta through the net
209  void propagDeltas();
210  // --- add a Cluster into the structures above
211  void addCluster( Cluster*, bool );
212  // --- add a Linker into the structures above
213  void addLinker( Linker* );
214 
215 };
216 
217 }
218 
219 #endif
220