// nnfw/include/learningalgorithm.h
// (Doxygen source listing of this file; see the generated documentation for the API reference.)
00001 /******************************************************************************** 00002 * Neural Network Framework. * 00003 * Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it> * 00004 * * 00005 * This program is free software; you can redistribute it and/or modify * 00006 * it under the terms of the GNU General Public License as published by * 00007 * the Free Software Foundation; either version 2 of the License, or * 00008 * (at your option) any later version. * 00009 * * 00010 * This program is distributed in the hope that it will be useful, * 00011 * but WITHOUT ANY WARRANTY; without even the implied warranty of * 00012 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * 00013 * GNU General Public License for more details. * 00014 * * 00015 * You should have received a copy of the GNU General Public License * 00016 * along with this program; if not, write to the Free Software * 00017 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA * 00018 ********************************************************************************/ 00019 00020 #ifndef LEARNINGALGORITHM_H 00021 #define LEARNINGALGORITHM_H 00022 00026 #include "nnfwconfig.h" 00027 #include "neuralnet.h" 00028 #include <QMap> 00029 #include <QVector> 00030 #include <cmath> 00031 #include <parametersettable.h> 00032 #include <configurationparameters.h> 00033 00034 namespace farsa { 00035 00036 class NeuralNet; 00037 00066 class FARSA_NNFW_API Pattern : public ParameterSettableWithConfigureFunction { 00067 public: 00068 class PatternInfo { 00069 public: 00070 DoubleVector inputs; 00071 DoubleVector outputs; 00072 }; 00074 Pattern() : ParameterSettableWithConfigureFunction(), pinfo() { /*nothing to do*/ }; 00076 ~Pattern() { /*nothing to do*/ }; 00078 void setInputsOf( Cluster*, const DoubleVector& ); 00080 void setOutputsOf( Cluster*, const DoubleVector& ); 00082 void setInputsOutputsOf( Cluster*, const DoubleVector& inputs, const DoubleVector& outputs ); 
00084 DoubleVector inputsOf( Cluster* ) const; 00086 DoubleVector outputsOf( Cluster* ) const; 00089 PatternInfo& operator[]( Cluster* ); 00115 virtual void configure(ConfigurationParameters& params, QString prefix); 00123 virtual void save(ConfigurationParameters& params, QString prefix); 00125 static void describe( QString type ); 00126 private: 00127 mutable QMap<Cluster*, PatternInfo> pinfo; 00128 }; 00129 00137 typedef QVector<Pattern> PatternSet; 00138 00143 class FARSA_NNFW_API LearningAlgorithm : public ParameterSettableWithConfigureFunction { 00144 public: 00146 LearningAlgorithm( NeuralNet* net ); 00148 LearningAlgorithm(); 00150 virtual ~LearningAlgorithm(); 00152 void setNeuralNet( NeuralNet* net ) { 00153 netp = net; 00154 this->neuralNetChanged(); 00155 }; 00157 NeuralNet* neuralNet() { 00158 return netp; 00159 }; 00161 virtual void learn() = 0; 00163 virtual void learn( const Pattern& ) = 0; 00165 virtual void learnOnSet( const PatternSet& set ) { 00166 for( int i=0; i<(int)set.size(); i++ ) { 00167 learn( set[i] ); 00168 } 00169 }; 00171 virtual double calculateMSE( const Pattern& ) = 0; 00173 virtual double calculateMSEOnSet( const PatternSet& set ) { 00174 double mseacc = 0.0; 00175 int dim = (int)set.size(); 00176 for( int i=0; i<dim; i++ ) { 00177 mseacc += calculateMSE( set[i] ); 00178 } 00179 return mseacc/dim; 00180 }; 00182 double calculateRMSD( const Pattern& p ) { 00183 return sqrt( calculateMSE( p ) ); 00184 }; 00186 double calculateRMSDOnSet( const PatternSet& p ) { 00187 return sqrt( calculateMSEOnSet( p ) ); 00188 }; 00190 PatternSet loadPatternSet( ConfigurationParameters& params, QString path, QString prefix ); 00192 void savePatternSet( PatternSet& set, ConfigurationParameters& params, QString prefix ); 00193 protected: 00195 virtual void neuralNetChanged() = 0; 00196 private: 00197 NeuralNet* netp; 00198 }; 00199 00200 } 00201 00202 #endif 00203