#include <model.hpp>
Public Member Functions | |
Model () | |
The default constructor. | |
Model (int input_dim) | |
The constructor. | |
Model (int input_dim, const boost::program_options::variables_map &vm) | |
A constructor taking a program_options::variables_map from which parameter values are read. | |
void | learn (Types::Input x, Types::Output y, Types::RValue w, Types::RValue W, Types::RValue W_old) |
Update the model given a weighted data point and output. | |
bool | trustworthy () const |
Check whether enough data has been seen to trust this model. | |
Types::Input | check_derivatives () const |
Check whether derivatives with respect to projections are ok. | |
bool | updateNumProjections () |
Check if a new projection is needed. | |
void | project (Types::Input x, Types::vector &z) const |
Project an input to local low dimensional space. | |
Types::OutputT | predict (Types::Input x) const |
Predict a value from a given input. | |
Types::RValue | get_e_cv () const |
Return the last estimate of e_cv. | |
Types::RValue | get_e () const |
Return the last estimate of e. | |
Types::RValue | get_MSE () const |
Return the last estimate of MSE. | |
Types::Input | get_a_zz () const |
Return the last estimate of a_zz. | |
Types::Input | get_z () const |
Return the last projected x. | |
Private Attributes | |
int | R |
the number of projections. | |
int | N |
the dimension of the input | |
Types::Matrix | u |
Projection Directions. | |
Types::Matrix | p |
Types::Matrix | u_temp |
Types::RValue | e |
error variables | |
Types::vector | e_cv |
cross validation error for each dimension | |
Types::vector | y_target |
store the errors for each dimension | |
Types::Matrix | x_res |
Types::vector | beta |
Types::vector | z |
Types::RValue | y_hat |
Types::RValue | res |
Types::vector | x_temp |
Algorithm Parameters | |
This group of variables controls the behavior of the algorithm.
All parameters are defined in the paper (see PAPER). | |
Types::RParam | lambda |
A forgetting factor for sufficient statistics. | |
Types::RParam | phi |
A threshold for adding new projections. | |
Sufficient Statistics | |
Types::vector | x_0 |
Weighted Input History. | |
Types::RValue | beta_0 |
Weighted Output History. | |
Types::vector | MSE |
MSE memory. | |
Types::vector | a_zz |
Types::vector | a_zres |
Types::Matrix | a_xz |
Types::vector | data_count |
Amount of data seen for each projection dimension. | |
Types::vector | derivatives_ok |
Flags whether the derivatives are valid; uses data_count. | |
Friends | |
std::ostream & | operator<< (std::ostream &out, const Model &m) |
Print the model to a stream in simple format. |
The weights come from a receptive field, RF. The model's basic interface is learn(input, output) and predict(input).
Definition at line 31 of file model.hpp.
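The weights w, W and W_old passed to learn() are produced by the receptive field. The following is a minimal sketch, assuming a Gaussian RF with a diagonal distance metric and the LWPR-style recursion W = lambda * W_old + w; both assumptions and the function names rf_activation and update_activation_history are hypothetical and not taken from model.hpp or the RFDiag/RFFull sources.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical sketch: Gaussian receptive-field activation
//   w = exp(-0.5 * sum_i d_i * (x_i - c_i)^2)
// for a diagonal distance metric D = diag(d). Assumed, not taken from the RF classes.
double rf_activation(const std::vector<double>& x,
                     const std::vector<double>& center,
                     const std::vector<double>& diag_metric) {
    assert(x.size() == center.size() && x.size() == diag_metric.size());
    double dist2 = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        const double d = x[i] - center[i];
        dist2 += diag_metric[i] * d * d;  // weighted squared distance to the RF centre
    }
    return std::exp(-0.5 * dist2);
}

// Hypothetical sketch: decayed activation history W = lambda * W_old + w,
// where lambda is the forgetting factor in [0,1].
double update_activation_history(double W_old, double w, double lambda) {
    assert(lambda >= 0.0 && lambda <= 1.0);
    return lambda * W_old + w;
}
```

A point at the RF centre gets activation 1, and the activation decays smoothly with distance, so far-away points contribute almost nothing to the model's statistics.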
void Model::learn(Types::Input x, Types::Output y, Types::RValue w, Types::RValue W, Types::RValue W_old)
Update the model given a weighted data point and output.
x | the input vector | |
y | the true output | |
w | the current activation (from RF) | |
W | the activation history (from RF) | |
W_old | the old activation history (from RF) |
Definition at line 122 of file model.cpp.
References a_xz, a_zres, a_zz, beta, beta_0, data_count, e, e_cv, lambda, MSE, p, R, res, u, u_temp, x_0, x_res, y_hat, y_target, and z.
Referenced by RFFull::learn(), and RFDiag::learn().
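One plausible use of w, W and W_old inside learn() is the incremental update of the weighted input and output means. The sketch below assumes, as in LWPR, that x_0 and beta_0 hold weighted means and that W == lambda * W_old + w; the struct WeightedMeans and the function update_means are hypothetical and are not the code in model.cpp.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container mirroring the roles of x_0 and beta_0.
struct WeightedMeans {
    std::vector<double> x_0;  // weighted input mean
    double beta_0 = 0.0;      // weighted output mean
};

// Hypothetical sketch: forgetting-factor weighted mean update
//   x_0    <- (lambda * W_old * x_0    + w * x) / W
//   beta_0 <- (lambda * W_old * beta_0 + w * y) / W
// assuming W == lambda * W_old + w, so the result stays a proper weighted mean.
void update_means(WeightedMeans& s, const std::vector<double>& x, double y,
                  double w, double W, double W_old, double lambda) {
    assert(W > 0.0);
    if (s.x_0.empty()) s.x_0.assign(x.size(), 0.0);  // first sample: start from zero
    assert(s.x_0.size() == x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        s.x_0[i] = (lambda * W_old * s.x_0[i] + w * x[i]) / W;
    s.beta_0 = (lambda * W_old * s.beta_0 + w * y) / W;
}
```

With W_old = 0 the first call simply copies the sample into the means; later calls blend the old means and the new sample in proportion to their (decayed) weights.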
bool Model::updateNumProjections()
Check if a new projection is needed.
If it is, add it. This requires the parameter phi.
Definition at line 197 of file model.cpp.
References a_xz, a_zres, a_zz, beta, data_count, derivatives_ok, e, e_cv, lambda, MSE, N, p, phi, R, u, u_temp, x_0, x_res, y_target, and z.
Referenced by RFFull::learn(), and RFDiag::learn().
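The decision made by updateNumProjections() can be sketched as follows, assuming the LWPR-style rule that a projection is added while the newest one still cuts the error by at least the factor phi, the projection count stays below the input dimension N, and the newest projection has seen enough data. The function should_add_projection, the indexing of mse, and the warm-up threshold min_data are all hypothetical; the exact test in model.cpp is not shown in this reference.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of an LWPR-style "add a projection?" test.
//   mse[r]        : error estimate when using r+1 projections (R = mse.size())
//   data_count[r] : amount of weighted data seen by projection r
//   N             : input dimension, phi : threshold in [0,1]
//   min_data      : hypothetical warm-up amount before the estimate is trusted
bool should_add_projection(const std::vector<double>& mse,
                           const std::vector<double>& data_count,
                           std::size_t N, double phi, double min_data) {
    assert(phi >= 0.0 && phi <= 1.0);
    assert(mse.size() == data_count.size());
    const std::size_t R = mse.size();
    if (R < 2 || R >= N) return false;               // need two estimates; cannot exceed input dim
    if (data_count[R - 1] < min_data) return false;  // newest projection not yet trustworthy
    // The last projection still reduced the error enough, so another may help.
    return mse[R - 1] < phi * mse[R - 2];
}
```

A smaller phi makes the model more reluctant to grow, since the newest projection must then produce a larger relative drop in error before another one is added.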
void Model::project(Types::Input x, Types::vector &z) const
Types::OutputT Model::predict(Types::Input x) const
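project() and predict() can be illustrated with a PLS-style sketch, assuming the LWPR scheme: the input is centred by x_0, each projection coordinate is z_r = u_r^T x_res, the residual is deflated by z_r * p_r, and the output is beta_0 + sum_r beta_r * z_r. The names project_sketch and predict_sketch are hypothetical, and plain std::vector stands in for Types::Matrix and Types::vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical sketch of PLS-style projection.
//   u[r] : projection direction r, p[r] : loading vector r, x_0 : weighted input mean.
Vec project_sketch(const Vec& x, const Vec& x_0,
                   const std::vector<Vec>& u, const std::vector<Vec>& p) {
    assert(x.size() == x_0.size() && u.size() == p.size());
    Vec x_res(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) x_res[i] = x[i] - x_0[i];  // centre the input
    Vec z(u.size(), 0.0);
    for (std::size_t r = 0; r < u.size(); ++r) {
        for (std::size_t i = 0; i < x.size(); ++i) z[r] += u[r][i] * x_res[i];  // z_r = u_r^T x_res
        for (std::size_t i = 0; i < x.size(); ++i) x_res[i] -= z[r] * p[r][i];   // deflate residual
    }
    return z;
}

// Hypothetical sketch of prediction from projected coordinates: y = beta_0 + sum_r beta_r * z_r.
double predict_sketch(const Vec& z, const Vec& beta, double beta_0) {
    assert(z.size() == beta.size());
    double y = beta_0;
    for (std::size_t r = 0; r < z.size(); ++r) y += beta[r] * z[r];
    return y;
}
```

Because each projection works on the residual left by the previous ones, the coordinates z_r capture successively smaller parts of the input variation, which is what lets the model stop at R < N projections.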
Types::RParam Model::lambda [private]
A forgetting factor for sufficient statistics.
This parameter must have a value in [0,1]. It may be updated over time.
Definition at line 52 of file model.hpp.
Referenced by check_derivatives(), learn(), Model(), and updateNumProjections().
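Since lambda may be updated over time, one common choice is to anneal it geometrically towards a final value close to 1, so that early, unreliable statistics are forgotten quickly while later ones are kept. This is an assumption: the schedule used by this class is not documented here, and anneal_lambda, lambda_final and tau are hypothetical names.

```cpp
#include <cassert>

// Hypothetical sketch of a forgetting-factor schedule:
//   lambda <- tau * lambda + (1 - tau) * lambda_final
// A convex combination of two values in [0,1] stays in [0,1].
double anneal_lambda(double lambda, double lambda_final, double tau) {
    assert(lambda >= 0.0 && lambda <= 1.0);
    assert(lambda_final >= 0.0 && lambda_final <= 1.0);
    assert(tau >= 0.0 && tau <= 1.0);
    return tau * lambda + (1.0 - tau) * lambda_final;
}
```

Calling this once per update moves lambda from its initial value towards lambda_final, with tau controlling how fast the gap closes.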
Types::RParam Model::phi [private]
A threshold for adding new projections.
If then add a projection. Must take a value in [0,1]
Definition at line 58 of file model.hpp.
Referenced by Model(), and updateNumProjections().