Class: Rumale::LinearModel::Ridge
- Inherits: BaseEstimator
  - Object
    - Base::Estimator
      - BaseEstimator
        - Rumale::LinearModel::Ridge
- Includes: Base::Regressor
- Defined in: rumale-linear_model/lib/rumale/linear_model/ridge.rb
Overview
Ridge implements ridge regression, that is, linear least-squares regression with an L2 penalty on the weights, solved either by singular value decomposition (SVD) or by L-BFGS optimization.
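For reference, the textbook ridge objective and its closed-form solution through a thin SVD of the design matrix are shown below. This is a sketch of the standard formulation, with λ standing for reg_param; this page does not state whether Rumale scales the loss (for example, by the number of samples) or regularizes the bias term, so the library's exact objective may differ.

```latex
\hat{w} = \operatorname*{arg\,min}_{w} \; \lVert y - Xw \rVert_2^2 + \lambda \lVert w \rVert_2^2
        = \left(X^\top X + \lambda I\right)^{-1} X^\top y

% With the thin SVD X = U \Sigma V^\top, \Sigma = \operatorname{diag}(s_1, \dots, s_r):
\hat{w} = V \operatorname{diag}\!\left(\frac{s_i}{s_i^2 + \lambda}\right) U^\top y
```

The SVD form needs a linear algebra backend (the constructor checks for one with enable_linalg?), whereas L-BFGS minimizes the objective iteratively; max_iter and tol presumably apply only to that iterative path.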
Instance Attribute Summary
Attributes inherited from BaseEstimator
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x, y) ⇒ Ridge
  Fit the model with the given training data.
- #initialize(reg_param: 1.0, fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false) ⇒ Ridge (constructor)
  Create a new Ridge regressor.
- #predict(x) ⇒ Numo::DFloat
  Predict values for samples.
Methods included from Base::Regressor
Constructor Details
#initialize(reg_param: 1.0, fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false) ⇒ Ridge
Create a new Ridge regressor.
```ruby
# File 'rumale-linear_model/lib/rumale/linear_model/ridge.rb', line 48

def initialize(reg_param: 1.0, fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false)
  super()
  @params = {
    reg_param: reg_param,
    fit_bias: fit_bias,
    bias_scale: bias_scale,
    max_iter: max_iter,
    tol: tol,
    verbose: verbose
  }
  @params[:solver] = if solver == 'auto'
                       enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
                     else
                       solver.match?(/^svd$|^lbfgs$/) ? solver : 'lbfgs'
                     end
end
```
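The solver branch in the constructor can be isolated into a small pure function, which makes its edge cases easy to check. The helper names resolve_solver and solver_used_at_fit are hypothetical, and the linalg_available argument stands in for enable_linalg?(warning: false); the bodies reproduce the branch logic of the constructor and #fit listings. Two consequences follow: an unrecognized or miscased name such as 'SVD' silently becomes 'lbfgs', and an explicit 'svd' is kept even without a linear algebra backend, in which case the check in #fit falls back to L-BFGS at fit time.

```ruby
# Hypothetical pure-function version of the solver branch in Ridge#initialize.
# `linalg_available` stands in for enable_linalg?(warning: false).
def resolve_solver(solver, linalg_available)
  if solver == 'auto'
    # 'auto' prefers SVD whenever a linear algebra backend is available.
    linalg_available ? 'svd' : 'lbfgs'
  else
    # Only the exact strings 'svd' and 'lbfgs' pass; anything else becomes 'lbfgs'.
    solver.match?(/^svd$|^lbfgs$/) ? solver : 'lbfgs'
  end
end

# Hypothetical helper mirroring the condition checked in Ridge#fit.
def solver_used_at_fit(configured, linalg_available)
  configured == 'svd' && linalg_available ? 'svd' : 'lbfgs'
end

p resolve_solver('auto', false)                            # prints "lbfgs"
p resolve_solver('SVD', true)                              # prints "lbfgs"
p solver_used_at_fit(resolve_solver('svd', false), false)  # prints "lbfgs"
```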
Instance Method Details
#fit(x, y) ⇒ Ridge
Fit the model with the given training data.
```ruby
# File 'rumale-linear_model/lib/rumale/linear_model/ridge.rb', line 70

def fit(x, y)
  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_target_value_array(y)
  Rumale::Validation.check_sample_size(x, y)

  @weight_vec, @bias_term = if @params[:solver] == 'svd' && enable_linalg?(warning: false)
                              partial_fit_svd(x, y)
                            else
                              partial_fit_lbfgs(x, y)
                            end

  self
end
```
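To make the fitting step concrete without Numo, the sketch below solves the ridge normal equations (XᵀX + λI)w = Xᵀy for a single target by Gaussian elimination with partial pivoting. This is a swapped-in technique, neither of the library's solvers; for λ > 0 it gives the same solution as the textbook SVD closed form. The names ridge_fit and solve_linear are hypothetical, and the bias handling (appending a constant column equal to bias_scale, regularized along with the weights, then scaling the last coefficient back) is an assumption about one common approach rather than a description of partial_fit_svd.

```ruby
# Solve a small dense system a * z = b by Gaussian elimination with partial
# pivoting. `a` is an Array of row Arrays; the inputs are not modified.
# Assumes `a` is nonsingular (true for X^T X + reg_param * I with reg_param > 0).
def solve_linear(a, b)
  n = b.size
  m = a.each_with_index.map { |row, i| row.map(&:to_f) + [b[i].to_f] } # augmented matrix
  n.times do |col|
    pivot = (col...n).max_by { |r| m[r][col].abs }
    m[col], m[pivot] = m[pivot], m[col]
    (col + 1...n).each do |r|
      factor = m[r][col] / m[col][col]
      (col..n).each { |c| m[r][c] -= factor * m[col][c] }
    end
  end
  z = Array.new(n, 0.0)
  (n - 1).downto(0) do |r| # back substitution
    s = (r + 1...n).sum { |c| m[r][c] * z[c] }
    z[r] = (m[r][n] - s) / m[r][r]
  end
  z
end

# Hypothetical ridge fit for a single target; returns [weights, bias].
# When fit_bias is true, a constant column equal to bias_scale is appended
# (assumption), so the bias is recovered as w_last * bias_scale.
def ridge_fit(x, y, reg_param: 1.0, fit_bias: true, bias_scale: 1.0)
  rows = fit_bias ? x.map { |r| r + [bias_scale] } : x
  dim = rows.first.size
  # Gram matrix X^T X with reg_param added to the diagonal.
  gram = Array.new(dim) do |i|
    Array.new(dim) { |j| rows.sum { |r| r[i] * r[j] } + (i == j ? reg_param : 0.0) }
  end
  rhs = Array.new(dim) { |i| rows.zip(y).sum { |r, yk| r[i] * yk } } # X^T y
  w = solve_linear(gram, rhs)
  fit_bias ? [w[0...-1], w[-1] * bias_scale] : [w, 0.0]
end

w, b = ridge_fit([[1.0], [2.0], [3.0]], [2.0, 4.0, 6.0], reg_param: 1.0, fit_bias: false)
puts w.first.round(4) # prints 1.8667 (28/15, shrunk from the least-squares slope 2.0)
```

With one feature and no bias, the solution reduces to Σxy / (Σx² + λ) = 28 / (14 + 1), which shows the shrinkage directly.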
#predict(x) ⇒ Numo::DFloat
Predict values for samples.
```ruby
# File 'rumale-linear_model/lib/rumale/linear_model/ridge.rb', line 88

def predict(x)
  x = Rumale::Validation.check_convert_sample_array(x)
  x.dot(@weight_vec.transpose) + @bias_term
end
```
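Prediction is a single affine map. The sketch below mirrors x.dot(@weight_vec.transpose) + @bias_term with plain Ruby arrays so it runs without Numo. The name ridge_predict is hypothetical, and the multi-target branch assumes, as the transpose suggests, that a model fitted on a 2-D target stores one weight row and one bias per output.

```ruby
# Hypothetical plain-Ruby analogue of Ridge#predict.
# Single target: weight_vec is a flat Array, bias_term a Float -> Array of Floats.
# Multi-target (assumption): weight_vec is an Array of rows, one per output,
# and bias_term an Array with one entry per output -> Array of rows.
def ridge_predict(x, weight_vec, bias_term)
  dot = ->(a, b) { a.zip(b).sum { |ai, bi| ai * bi } }
  if weight_vec.first.is_a?(Array)
    x.map do |row|
      weight_vec.each_with_index.map { |w, k| dot.(row, w) + bias_term[k] }
    end
  else
    x.map { |row| dot.(row, weight_vec) + bias_term }
  end
end

p ridge_predict([[1.0, 2.0], [0.0, 1.0]], [2.0, 3.0], 1.0) # prints [9.0, 4.0]
```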