lwpr.cpp

00001 //
00002 //  Copyright (c) 2005-2007
00003 //  James N Knight
00004 //
00005 //  Permission to use, copy, modify, distribute and sell this software
00006 //  and its documentation for any purpose is hereby granted without fee,
00007 //  provided that the above copyright notice appear in all copies and
00008 //  that both that copyright notice and this permission notice appear
00009 //  in supporting documentation.  The authors make no representations
00010 //  about the suitability of this software for any purpose.
00011 //  It is provided "as is" without express or implied warranty.
00012 //
00013 //  
00014 // See http://homepages.inf.ed.ac.uk/svijayak/publications/vijayakumar-NeuCom2005.pdf
00015 // for the original publication of this algorithm.
00016 
00017 
00018 
00019 #include <boost/numeric/ublas/io.hpp>
00020 #include <boost/program_options.hpp>
00021 
00022 #include "lwpr.hpp"
00023 
00024 #include "rf_diag.hpp"
00025 #include "rf_full.hpp"
00026 
00027 LWPR::LWPR() : vm(boost::program_options::variables_map())
00028 {
00029 
00030     // Default parameter vaules taken from paper and other implementations
00031     w_gen = 0.2;
00032     w_update = 0.001;
00033     w_predict = 0.001;
00034     diag = true;
00035     
00036 }
00037 
00038 LWPR::LWPR(const boost::program_options::variables_map& v) : vm(v){
00039     w_gen = vm["w_gen"].as<Types::RParam>();
00040     w_update = vm["w_update"].as<Types::RParam>();
00041     w_predict = vm["w_predict"].as<Types::RParam>();
00042     diag = vm["diag"].as<bool>();
00043 }
00044 
00045 LWPR::~LWPR()
00046 {
00047     for(RFList::iterator it = receptiveFields.begin(); it != receptiveFields.end(); ++it)
00048         delete *it;
00049 }
00050 
00051 
00052 void LWPR::learn(Types::Input x, Types::Output y)
00053 {
00054     // Taken From Table 5 in PAPER
00055 
00056     // For Each Receptive field
00057     //         1) Compute the activation of x at each receptive field
00058     //         2) Update the Distance Matrix
00059     //         3) Update the Local Model (projections and regression)
00060     //         4) Check if more projections are needed
00061     //
00062     // If no RF was activated by more than w_gen create a new RF with
00063     // R=2, c=x, and D=d_def*I;
00064     
00065     Types::RValue maxActivation = 0.0;
00066     Types::RValue tempActivaction = 0.0;
00067     RFList::iterator maxActivationIndex = receptiveFields.begin();
00068 
00069     for(RFList::iterator it = receptiveFields.begin(); it != receptiveFields.end(); ++it){
00070         tempActivaction = (*it)->getActivation(x);
00071 
00072         if(tempActivaction > w_update)
00073             (*it)->learn(x,y,tempActivaction);
00074 
00075         if(tempActivaction > maxActivation){
00076             maxActivation = tempActivaction;
00077             maxActivationIndex = it;
00078         }
00079     }
00080 
00081     if(maxActivation < w_gen){
00082         // make the new RF have similar shape to maximally actiavted rf if it is trustworthy.
00083         if((maxActivation > (0.1 * w_gen)) && (*maxActivationIndex)->trustworthy()){
00084             if(diag){
00085                 receptiveFields.push_back(new RFDiag(*(RFDiag*)(*maxActivationIndex),x,vm));
00086             }
00087             else{
00088                 receptiveFields.push_back(new RFFull(*(RFFull*)(*maxActivationIndex),x,vm));
00089             }
00090         }
00091         else{
00092             if(diag){
00093                 receptiveFields.push_back(new RFDiag(x,vm));
00094             }
00095             else{
00096                 receptiveFields.push_back(new RFFull(x,vm));
00097             }
00098         }
00099     }
00100 }
00101 
00102 Types::OutputT LWPR::predict(Types::Input x) const
00103 {
00104     // Taken from Equation 3.2 in PAPER
00105     
00106     Types::RValue sumActivations = 0.0;
00107     Types::RValue tempActivaction = 0.0;
00108     Types::RValue sumPredictions = 0.0;
00109 
00110     for(RFList::const_iterator it = receptiveFields.begin(); it != receptiveFields.end(); ++it){
00111         tempActivaction = (*it)->getActivation(x);
00112 
00113         if(tempActivaction > w_predict && (*it)->trustworthy()){
00114             sumPredictions += tempActivaction * (*it)->predict(x);
00115             sumActivations += tempActivaction;
00116         }
00117     }
00118 
00119     // really shouldn't compare to 0 should do < 2*eps 
00120     return (sumActivations==0)?0:sumPredictions/sumActivations;
00121 }
00122 
00123 
00124 std::ostream& operator<<(std::ostream& out, const LWPR& lwpr)
00125 {
00126     int i=0;
00127 
00128     // print out the receptive field centers and distance matrices
00129     for(LWPR::RFList::const_iterator it = lwpr.receptiveFields.begin(); it != lwpr.receptiveFields.end(); ++it){
00130 
00131         // this is terrible, got to be a better way
00132         if(lwpr.diag)
00133             out << *((RFDiag*)*it) << std::endl;
00134         else
00135             out << *((RFFull*)*it) << std::endl;
00136 
00137     }
00138 }
00139 
00140 //std::istream& operator>>(std::istream& in, LWPR& lwpr)
00141 //{
00142 //
00143 //}

Generated on Fri Jul 27 00:24:01 2007 for LWPR by  doxygen 1.5.1