nnfw/include/learningalgorithm.h

Go to the documentation of this file.
00001 /********************************************************************************
00002  *  Neural Network Framework.                                                   *
00003  *  Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it>                *
00004  *                                                                              *
00005  *  This program is free software; you can redistribute it and/or modify        *
00006  *  it under the terms of the GNU General Public License as published by        *
00007  *  the Free Software Foundation; either version 2 of the License, or           *
00008  *  (at your option) any later version.                                         *
00009  *                                                                              *
00010  *  This program is distributed in the hope that it will be useful,             *
00011  *  but WITHOUT ANY WARRANTY; without even the implied warranty of              *
00012  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               *
00013  *  GNU General Public License for more details.                                *
00014  *                                                                              *
00015  *  You should have received a copy of the GNU General Public License           *
00016  *  along with this program; if not, write to the Free Software                 *
00017  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA  *
00018  ********************************************************************************/
00019 
00020 #ifndef LEARNINGALGORITHM_H
00021 #define LEARNINGALGORITHM_H
00022 
00026 #include "nnfwconfig.h"
00027 #include "neuralnet.h"
00028 #include <QMap>
00029 #include <QVector>
00030 #include <cmath>
00031 #include <parametersettable.h>
00032 #include <configurationparameters.h>
00033 
00034 namespace farsa {
00035 
// Forward declaration of the network class used (by pointer) below.
// NOTE(review): neuralnet.h is already included above, so this is presumably
// redundant; it is kept so the pointer-only uses below never depend on it.
class NeuralNet;
00037 
00066 class FARSA_NNFW_API Pattern : public ParameterSettableWithConfigureFunction {
00067 public:
00068     class PatternInfo {
00069     public:
00070         DoubleVector inputs;
00071         DoubleVector outputs;
00072     };
00074     Pattern() : ParameterSettableWithConfigureFunction(), pinfo() { /*nothing to do*/ };
00076     ~Pattern() { /*nothing to do*/ };
00078     void setInputsOf( Cluster*, const DoubleVector& );
00080     void setOutputsOf( Cluster*, const DoubleVector& );
00082     void setInputsOutputsOf( Cluster*, const DoubleVector& inputs, const DoubleVector& outputs );
00084     DoubleVector inputsOf( Cluster* ) const;
00086     DoubleVector outputsOf( Cluster* ) const;
00089     PatternInfo& operator[]( Cluster* );
00115     virtual void configure(ConfigurationParameters& params, QString prefix);
00123     virtual void save(ConfigurationParameters& params, QString prefix);
00125     static void describe( QString type );
00126 private:
00127     mutable QMap<Cluster*, PatternInfo> pinfo;
00128 };
00129 
00137 typedef QVector<Pattern> PatternSet;
00138 
00143 class FARSA_NNFW_API LearningAlgorithm : public ParameterSettableWithConfigureFunction {
00144 public:
00146     LearningAlgorithm( NeuralNet* net );
00148     LearningAlgorithm();
00150     virtual ~LearningAlgorithm();
00152     void setNeuralNet( NeuralNet* net ) {
00153         netp = net;
00154         this->neuralNetChanged();
00155     };
00157     NeuralNet* neuralNet() {
00158         return netp;
00159     };
00161     virtual void learn() = 0;
00163     virtual void learn( const Pattern& ) = 0;
00165     virtual void learnOnSet( const PatternSet& set ) {
00166         for( int i=0; i<(int)set.size(); i++ ) {
00167             learn( set[i] );
00168         }
00169     };
00171     virtual double calculateMSE( const Pattern& ) = 0;
00173     virtual double calculateMSEOnSet( const PatternSet& set ) {
00174         double mseacc = 0.0;
00175         int dim = (int)set.size();
00176         for( int i=0; i<dim; i++ ) {
00177             mseacc += calculateMSE( set[i] );
00178         }
00179         return mseacc/dim;
00180     };
00182     double calculateRMSD( const Pattern& p ) {
00183         return sqrt( calculateMSE( p ) );
00184     };
00186     double calculateRMSDOnSet( const PatternSet& p ) {
00187         return sqrt( calculateMSEOnSet( p ) );
00188     };
00190     PatternSet loadPatternSet( ConfigurationParameters& params, QString path, QString prefix );
00192     void savePatternSet( PatternSet& set, ConfigurationParameters& params, QString prefix );
00193 protected:
00195     virtual void neuralNetChanged() = 0;
00196 private:
00197     NeuralNet* netp;
00198 };
00199 
00200 }
00201 
00202 #endif
00203