00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #ifndef MODEL_HPP
00020 #define MODEL_HPP
00021
#include <iosfwd>

#include <boost/program_options.hpp>

#include "types.hpp"
00025
00031 class Model {
00032
00034 int R;
00035
00037 int N;
00038
00047
00052 Types::RParam lambda;
00053
00058 Types::RParam phi;
00059
00061
00062
00063
00072
00074 Types::vector x_0;
00075
00077 Types::RValue beta_0;
00078
00080 Types::vector MSE;
00081
00082 Types::vector a_zz, a_zres;
00083
00084 Types::Matrix a_xz;
00085
00086
00088 Types::vector data_count;
00089
00091 mutable Types::vector derivatives_ok;
00092
00094
00096 Types::Matrix u, p, u_temp;
00097
00099 Types::RValue e;
00100
00102 Types::vector e_cv;
00103
00105 Types::vector y_target;
00106
00107
00108 Types::Matrix x_res;
00109 Types::vector beta, z;
00110 Types::RValue y_hat, res;
00111
00112 Types::vector x_temp;
00113
00114
00115 public:
00119 Model();
00120
00124 Model(int input_dim);
00125
00129 Model(int input_dim, const boost::program_options::variables_map& vm);
00130
00131
00141 void learn(Types::Input x, Types::Output y, Types::RValue w, Types::RValue W, Types::RValue W_old);
00142
00146 bool trustworthy() const;
00147
00149 Types::Input check_derivatives() const;
00150
00154 bool updateNumProjections();
00155
00164 void project(Types::Input x, Types::vector& z) const;
00165
00173 Types::OutputT predict(Types::Input x) const;
00174
00176 Types::RValue get_e_cv() const { return e_cv(e_cv.size()-1); };
00177
00179 Types::RValue get_e() const { return e; };
00180
00182 Types::RValue get_MSE() const { return MSE(R-1); };
00183
00185 Types::Input get_a_zz() const { return a_zz; };
00186
00188 Types::Input get_z() const { return z; };
00189
00193 friend std::ostream& operator<<(std::ostream& out, const Model& m);
00194
00195
00196 };
00197
00198 #endif