Class: Rumale::LinearModel::LinearRegression
- Inherits: BaseEstimator
  - Object
    - Base::Estimator
      - BaseEstimator
        - Rumale::LinearModel::LinearRegression
- Includes: Base::Regressor
- Defined in: rumale-linear_model/lib/rumale/linear_model/linear_regression.rb
Overview
LinearRegression is a class that implements ordinary least squares linear regression, fitted with either singular value decomposition (SVD) or L-BFGS optimization.
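For reference, the objective below is textbook ordinary least squares and is not taken from Rumale's source. Treating the bias as the weight of an appended constant column s = bias_scale is an assumption about how fit_bias and bias_scale interact. Under that assumption, an SVD solver corresponds to the pseudo-inverse solution, while L-BFGS minimizes the same objective iteratively.

```latex
\min_{\mathbf{w},\,b}\ \sum_{i=1}^{n}\left(y_i-\mathbf{w}^{\top}\mathbf{x}_i-b\right)^2
\;=\;\min_{\tilde{\mathbf{w}}}\ \bigl\lVert \mathbf{y}-\tilde{X}\tilde{\mathbf{w}}\bigr\rVert_2^2,
\qquad \tilde{X}=\begin{bmatrix} X & s\,\mathbf{1}\end{bmatrix},\quad
\tilde{\mathbf{w}}=\begin{bmatrix}\mathbf{w}\\ b/s\end{bmatrix}

\tilde{X}=U\Sigma V^{\top}
\;\Longrightarrow\;
\hat{\tilde{\mathbf{w}}}=V\Sigma^{+}U^{\top}\mathbf{y}
```

Here Σ⁺ inverts the nonzero singular values and leaves the rest at zero, which yields the minimum-norm least-squares solution even when X̃ is rank-deficient.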
Instance Attribute Summary
Attributes inherited from BaseEstimator
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x, y) ⇒ LinearRegression
  Fit the model with the given training data.
- #initialize(fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false) ⇒ LinearRegression (constructor)
  Create a new ordinary least squares linear regressor.
- #predict(x) ⇒ Numo::DFloat
  Predict target values for the given samples.
Methods included from Base::Regressor
Constructor Details
#initialize(fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false) ⇒ LinearRegression
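Create a new ordinary least squares linear regressor. When solver is 'auto', 'svd' is selected if enable_linalg? reports that Numo::Linalg is available, and 'lbfgs' otherwise; any value other than 'svd' or 'lbfgs' is replaced by 'lbfgs'.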
# File 'rumale-linear_model/lib/rumale/linear_model/linear_regression.rb', line 48

def initialize(fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false)
  super()
  @params = {
    fit_bias: fit_bias,
    bias_scale: bias_scale,
    max_iter: max_iter,
    tol: tol,
    verbose: verbose
  }
  @params[:solver] = if solver == 'auto'
                       enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
                     else
                       solver.match?(/^svd$|^lbfgs$/) ? solver : 'lbfgs'
                     end
end
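The solver-selection branch can be exercised in isolation with the hypothetical helper below. The names resolve_solver and linalg_available are illustrative only; linalg_available stands in for enable_linalg?(warning: false), and the rest of the body copies the branch from the constructor.

```ruby
# Hypothetical standalone copy of the solver-selection branch in
# LinearRegression#initialize. linalg_available stands in for
# enable_linalg?(warning: false), which is not reproduced here.
def resolve_solver(solver, linalg_available)
  if solver == 'auto'
    # 'auto' prefers SVD only when the linear-algebra backend is usable.
    linalg_available ? 'svd' : 'lbfgs'
  else
    # Any value other than exactly 'svd' or 'lbfgs' falls back to 'lbfgs'.
    solver.match?(/^svd$|^lbfgs$/) ? solver : 'lbfgs'
  end
end

p resolve_solver('auto', true)   # "svd"
p resolve_solver('auto', false)  # "lbfgs"
p resolve_solver('SVD', true)    # "lbfgs" (the match is case-sensitive)
p resolve_solver('svd', false)   # "svd"
```

The last case is worth noting: requesting 'svd' explicitly without Numo::Linalg is accepted by the constructor, and it is #fit that re-checks enable_linalg? and falls back to partial_fit_lbfgs.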
Instance Method Details
#fit(x, y) ⇒ LinearRegression
Fit the model with the given training data. Returns self, so calls such as fit(x, y).predict(x) can be chained.
# File 'rumale-linear_model/lib/rumale/linear_model/linear_regression.rb', line 69

def fit(x, y)
  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_target_value_array(y)
  Rumale::Validation.check_sample_size(x, y)

  @weight_vec, @bias_term = if @params[:solver] == 'svd' && enable_linalg?(warning: false)
                              partial_fit_svd(x, y)
                            else
                              partial_fit_lbfgs(x, y)
                            end

  self
end
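partial_fit_svd and partial_fit_lbfgs are not shown on this page. As a rough stand-in, the sketch below computes the same least-squares weights for well-conditioned data by solving the normal equations with Gauss-Jordan elimination. This is a deliberately different technique (neither SVD nor L-BFGS), and both ols_fit and the handling of the bias as the weight of an appended bias_scale column are hypothetical illustrations, not Rumale's code. Like #fit, it produces a weight vector and a bias term.

```ruby
# Hypothetical OLS sketch (not Rumale's implementation): solves the normal
# equations (A^T A) w = A^T y by Gauss-Jordan elimination, where A is x with
# an appended column of bias_scale when fit_bias is true.
def ols_fit(x, y, fit_bias: true, bias_scale: 1.0)
  # Convert to floats so the divisions below never truncate.
  rows = x.map { |row| row.map(&:to_f) }
  targets = y.map(&:to_f)
  # Append a constant bias_scale feature so the bias is learned as an extra weight.
  a = fit_bias ? rows.map { |row| row + [bias_scale.to_f] } : rows
  n = a.first.size
  # Augmented system [A^T A | A^T y], one row per unknown.
  m = Array.new(n) do |i|
    lhs = Array.new(n) { |j| a.sum { |r| r[i] * r[j] } }
    lhs + [a.zip(targets).sum { |r, t| r[i] * t }]
  end
  # Gauss-Jordan elimination with partial pivoting.
  n.times do |col|
    pivot = (col...n).max_by { |r| m[r][col].abs }
    m[col], m[pivot] = m[pivot], m[col]
    raise ArgumentError, 'normal matrix is singular' if m[col][col].abs < 1e-12
    n.times do |r|
      next if r == col
      factor = m[r][col] / m[col][col]
      (col..n).each { |c| m[r][c] -= factor * m[col][c] }
    end
  end
  w = Array.new(n) { |i| m[i][n] / m[i][i] }
  # Split off the bias weight; scaling by bias_scale keeps predictions as x.w + b.
  fit_bias ? [w[0...-1], w[-1] * bias_scale] : [w, 0.0]
end

# Data generated exactly by y = 2*x1 - x2 + 3.
w, b = ols_fit([[0, 0], [1, 0], [0, 1], [1, 1], [2, 3]], [3, 5, 2, 4, 4])
p w.map { |v| v.round(6) }  # [2.0, -1.0]
p b.round(6)                # 3.0
```

Forming A^T A squares the condition number of A, which is one reason a library would prefer SVD or an iterative optimizer; this sketch is only meant for checking small examples by hand.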
#predict(x) ⇒ Numo::DFloat
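Predict target values for the given samples: each prediction is the linear combination of the sample's features with the learned weights, plus the bias term.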
# File 'rumale-linear_model/lib/rumale/linear_model/linear_regression.rb', line 87

def predict(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  x.dot(@weight_vec.transpose) + @bias_term
end
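For a single target, the hypothetical plain-Ruby helper below computes the same quantity as the final expression of #predict on nested arrays. It does not reproduce Numo broadcasting or the multi-target case. The transpose in the original presumably exists for a two-dimensional @weight_vec with one row per output, which this sketch does not handle.

```ruby
# Hypothetical plain-Ruby equivalent of x.dot(@weight_vec.transpose) + @bias_term
# for a single target: each prediction is a dot product plus the bias.
def predict_sketch(x, weight_vec, bias_term)
  x.map do |row|
    raise ArgumentError, 'feature count mismatch' unless row.size == weight_vec.size
    row.zip(weight_vec).sum { |xi, wi| xi * wi } + bias_term
  end
end

p predict_sketch([[1.0, 2.0], [3.0, 4.0]], [0.5, -1.0], 2.0)  # [0.5, -0.5]
```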