backpropagationalgo.h
1 /********************************************************************************
2  * Neural Network Framework. *
3  * Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it> *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the Free Software *
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA *
18  ********************************************************************************/
19 
20 #ifndef BACKPROPAGATIONALGO_H
21 #define BACKPROPAGATIONALGO_H
22 
23 #include "nnfwconfig.h"
24 #include "learningalgorithm.h"
25 #include "biasedcluster.h"
26 #include "matrixlinker.h"
27 #include <QMap>
28 #include <QVector>
29 
30 namespace farsa {
31 
35 class FARSA_NNFW_API BackPropagationAlgo : public LearningAlgorithm {
36 public:
43  BackPropagationAlgo( NeuralNet *n_n, UpdatableList update_order, double l_r = 0.1 );
46 
49 
54  void setUpdateOrder( const UpdatableList& update_order );
55 
57  UpdatableList updateOrder() const {
58  return update_order;
59  };
63  void setTeachingInput( Cluster* output, const DoubleVector& ti );
64 
65  virtual void learn();
66 
68  virtual void learn( const Pattern& );
69 
71  virtual double calculateMSE( const Pattern& );
72 
74  void setRate( double newrate ) {
75  learn_rate = newrate;
76  };
77 
79  double rate() const {
80  return learn_rate;
81  };
82 
84  void setMomentum( double newmom ) {
85  momentumv = newmom;
86  };
87 
89  double momentum() const {
90  return momentumv;
91  };
92 
94  void enableMomentum();
95 
97  void disableMomentum() {
98  useMomentum = false;
99  };
100 
122  DoubleVector getError( Cluster* );
166  virtual void configure(ConfigurationParameters& params, QString prefix);
174  virtual void save(ConfigurationParameters& params, QString prefix);
176  static void describe( QString type );
177 protected:
179  virtual void neuralNetChanged();
180 private:
182  double learn_rate;
184  double momentumv;
186  double useMomentum;
188  UpdatableList update_order;
189 
191  class FARSA_NNFW_API cluster_deltas {
192  public:
193  BiasedCluster* cluster;
194  bool isOutput;
195  DoubleVector deltas_outputs;
196  DoubleVector deltas_inputs;
197  DoubleVector last_deltas_inputs;
198  QList<MatrixLinker*> incoming_linkers_vec;
199  QVector<DoubleVector> incoming_last_outputs;
200  };
202  QMap<Cluster*, int> mapIndex;
204  QVector<cluster_deltas> cluster_deltas_vec;
205  // --- propagate delta through the net
206  void propagDeltas();
207  // --- add a Cluster into the structures above
208  void addCluster( Cluster*, bool );
209  // --- add a Linker into the structures above
210  void addLinker( Linker* );
211 
212 };
213 
214 }
215 
216 #endif
217