00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #include <boost/numeric/ublas/io.hpp>
00020 #include <boost/program_options.hpp>
00021
00022 #include "lwpr.hpp"
00023
00024 #include "rf_diag.hpp"
00025 #include "rf_full.hpp"
00026
00027 LWPR::LWPR() : vm(boost::program_options::variables_map())
00028 {
00029
00030
00031 w_gen = 0.2;
00032 w_update = 0.001;
00033 w_predict = 0.001;
00034 diag = true;
00035
00036 }
00037
00038 LWPR::LWPR(const boost::program_options::variables_map& v) : vm(v){
00039 w_gen = vm["w_gen"].as<Types::RParam>();
00040 w_update = vm["w_update"].as<Types::RParam>();
00041 w_predict = vm["w_predict"].as<Types::RParam>();
00042 diag = vm["diag"].as<bool>();
00043 }
00044
00045 LWPR::~LWPR()
00046 {
00047 for(RFList::iterator it = receptiveFields.begin(); it != receptiveFields.end(); ++it)
00048 delete *it;
00049 }
00050
00051
00052 void LWPR::learn(Types::Input x, Types::Output y)
00053 {
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065 Types::RValue maxActivation = 0.0;
00066 Types::RValue tempActivaction = 0.0;
00067 RFList::iterator maxActivationIndex = receptiveFields.begin();
00068
00069 for(RFList::iterator it = receptiveFields.begin(); it != receptiveFields.end(); ++it){
00070 tempActivaction = (*it)->getActivation(x);
00071
00072 if(tempActivaction > w_update)
00073 (*it)->learn(x,y,tempActivaction);
00074
00075 if(tempActivaction > maxActivation){
00076 maxActivation = tempActivaction;
00077 maxActivationIndex = it;
00078 }
00079 }
00080
00081 if(maxActivation < w_gen){
00082
00083 if((maxActivation > (0.1 * w_gen)) && (*maxActivationIndex)->trustworthy()){
00084 if(diag){
00085 receptiveFields.push_back(new RFDiag(*(RFDiag*)(*maxActivationIndex),x,vm));
00086 }
00087 else{
00088 receptiveFields.push_back(new RFFull(*(RFFull*)(*maxActivationIndex),x,vm));
00089 }
00090 }
00091 else{
00092 if(diag){
00093 receptiveFields.push_back(new RFDiag(x,vm));
00094 }
00095 else{
00096 receptiveFields.push_back(new RFFull(x,vm));
00097 }
00098 }
00099 }
00100 }
00101
00102 Types::OutputT LWPR::predict(Types::Input x) const
00103 {
00104
00105
00106 Types::RValue sumActivations = 0.0;
00107 Types::RValue tempActivaction = 0.0;
00108 Types::RValue sumPredictions = 0.0;
00109
00110 for(RFList::const_iterator it = receptiveFields.begin(); it != receptiveFields.end(); ++it){
00111 tempActivaction = (*it)->getActivation(x);
00112
00113 if(tempActivaction > w_predict && (*it)->trustworthy()){
00114 sumPredictions += tempActivaction * (*it)->predict(x);
00115 sumActivations += tempActivaction;
00116 }
00117 }
00118
00119
00120 return (sumActivations==0)?0:sumPredictions/sumActivations;
00121 }
00122
00123
00124 std::ostream& operator<<(std::ostream& out, const LWPR& lwpr)
00125 {
00126 int i=0;
00127
00128
00129 for(LWPR::RFList::const_iterator it = lwpr.receptiveFields.begin(); it != lwpr.receptiveFields.end(); ++it){
00130
00131
00132 if(lwpr.diag)
00133 out << *((RFDiag*)*it) << std::endl;
00134 else
00135 out << *((RFFull*)*it) << std::endl;
00136
00137 }
00138 }
00139
00140
00141
00142
00143