nnfw/include/cluster.h

Go to the documentation of this file.
00001 /********************************************************************************
00002  *  Neural Network Framework.                                                   *
00003  *  Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it>                *
00004  *                                                                              *
00005  *  This program is free software; you can redistribute it and/or modify        *
00006  *  it under the terms of the GNU General Public License as published by        *
00007  *  the Free Software Foundation; either version 2 of the License, or           *
00008  *  (at your option) any later version.                                         *
00009  *                                                                              *
00010  *  This program is distributed in the hope that it will be useful,             *
00011  *  but WITHOUT ANY WARRANTY; without even the implied warranty of              *
00012  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               *
00013  *  GNU General Public License for more details.                                *
00014  *                                                                              *
00015  *  You should have received a copy of the GNU General Public License           *
00016  *  along with this program; if not, write to the Free Software                 *
00017  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA  *
00018  ********************************************************************************/
00019 
00020 #ifndef CLUSTER_H
00021 #define CLUSTER_H
00022 
00027 #include "nnfwconfig.h"
00028 #include "updatable.h"
00029 #include "outputfunction.h"
00030 #include <exception>
00031 #include <memory>
00032 
00033 namespace farsa {
00034 
00073 class FARSA_NNFW_API Cluster : public Updatable {
00074 public:
00076     Cluster( unsigned int numNeurons, QString name = "unnamed" );
00078     Cluster( ConfigurationParameters& params, QString prefix );
00080     virtual ~Cluster();
00082     unsigned int numNeurons() const {
00083         return numneurons;
00084     };
00086     bool needReset() {
00087         return needRst;
00088     };
00093     void setAccumulate( bool mode ) {
00094         accOff = !mode;
00095     };
00097     bool isAccumulate() const {
00098         return !accOff;
00099     };
00103     virtual void randomize( double min, double max ) = 0;
00107     void setInput( unsigned int neuron, double value );
00109     void setInputs( const DoubleVector& inputs );
00113     void setAllInputs( double value );
00117     void resetInputs();
00120     double getInput( unsigned int neuron ) const;
00122     DoubleVector& inputs() {
00123         return *inputdataptr;
00124     };
00126     DoubleVector inputs() const {
00127         return *inputdataptr;
00128     };
00130     void setOutput( unsigned int neuron, double value );
00132     void setOutputs( const DoubleVector& outputs );
00134     double getOutput( unsigned int neuron ) const;
00136     DoubleVector& outputs() {
00137         return *outputdataptr;
00138     };
00140     DoubleVector outputs() const {
00141         return *outputdataptr;
00142     };
00147     void setOutFunction( OutputFunction* up );
00149     OutputFunction* outFunction() const {
00150         return updater.get();
00151     };
00160     virtual void save(ConfigurationParameters& params, QString prefix);
00162     static void describe( QString type );
00175     typedef DoubleVector& (*getStateVectorFuncPtr)( Cluster* );
00182     getStateVectorFuncPtr getDelegateFor( QString stateVector ) {
00183         if ( stateDelegates.contains( stateVector ) ) {
00184             return stateDelegates[stateVector];
00185         }
00186         throw ClusterStateVectorNotPresent( (QString("The state vector named ") + stateVector + " is not part of this Cluster").toAscii().data() );
00187     };
00188 protected:
00203     template <class T, DoubleVector& (T::*TMethod)()>
00204     void setDelegateFor( QString vectorName ) {
00205         stateDelegates[vectorName] = &staticDelegateMethod<T, TMethod>;
00206     }
00210     void setNeedReset( bool b ) {
00211         needRst = accOff && b;
00212     };
00214     DoubleVector* inputdataptr;
00216     DoubleVector* outputdataptr;
00217 private:
00219     unsigned int numneurons;
00221     DoubleVector inputdata;
00223     DoubleVector outputdata;
00225     std::auto_ptr<OutputFunction> updater;
00227     bool needRst;
00231     bool accOff;
00232 
00234     QMap<QString, getStateVectorFuncPtr> stateDelegates;
00236     template <class T, DoubleVector& (T::*TMethod)()>
00237     static DoubleVector& staticDelegateMethod( Cluster* cluster_ptr ) {
00238         T* p = static_cast<T*>(cluster_ptr);
00239         //--- call the delegate method using pointer-to-member syntax
00240         return (p->*TMethod)();
00241     }
00242 };
00243 
00244 }
00245 
00246 #endif