Class: Rumale::LinearModel::LinearRegression

Inherits:
BaseEstimator
Includes:
Base::Regressor
Defined in:
rumale-linear_model/lib/rumale/linear_model/linear_regression.rb

Overview

LinearRegression is a class that implements ordinary least squares linear regression, solved either by singular value decomposition (SVD) or by L-BFGS optimization.
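For reference, the objective below is the standard textbook statement of ordinary least squares in our own notation ($\mathbf{x}_i$ is the $i$-th sample, $y_i$ its target, $\mathbf{w}$ and $b$ the weight vector and bias). It sketches the problem both solvers address and is not a transcription of Rumale's internal loss. The SVD route reaches a minimizer in closed form through the Moore-Penrose pseudoinverse (the minimum-norm one if $X$ is rank-deficient), whereas L-BFGS minimizes the same sum of squares iteratively.

```latex
\min_{\mathbf{w},\,b}\; \sum_{i=1}^{n} \bigl( y_i - \mathbf{w}^{\top}\mathbf{x}_i - b \bigr)^2
% With b absorbed into w by appending a constant column to X, and X = U \Sigma V^{\top}:
\hat{\mathbf{w}} = X^{+}\mathbf{y} = V\,\Sigma^{+} U^{\top}\mathbf{y}
```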

Examples:

require 'rumale/linear_model/linear_regression'

estimator = Rumale::LinearModel::LinearRegression.new
estimator.fit(training_samples, training_values)
results = estimator.predict(testing_samples)

# If Numo::Linalg is installed, you can specify 'svd' for the solver option.
require 'numo/linalg/autoloader'
require 'rumale/linear_model/linear_regression'

estimator = Rumale::LinearModel::LinearRegression.new(solver: 'svd')
estimator.fit(training_samples, training_values)
results = estimator.predict(testing_samples)

Instance Attribute Summary

Attributes inherited from BaseEstimator

#bias_term, #weight_vec

Attributes inherited from Base::Estimator

#params

Instance Method Summary

Methods included from Base::Regressor

#score
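#score is inherited from Base::Regressor; to our understanding it reports the coefficient of determination (R²) of the model's predictions on the given data. The plain-Ruby sketch below computes R² for a single output to show what that number means. `r2_score` is a hypothetical helper, not Rumale's implementation.

```ruby
# Hypothetical helper: R^2 = 1 - SS_res / SS_tot for a single output.
# Not meaningful when y_true is constant (SS_tot would be zero).
def r2_score(y_true, y_pred)
  mean = y_true.sum.fdiv(y_true.size)
  ss_res = y_true.zip(y_pred).sum { |t, pr| (t - pr)**2 }
  ss_tot = y_true.sum { |t| (t - mean)**2 }
  1.0 - ss_res.fdiv(ss_tot)
end

y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]
puts r2_score(y_true, y_pred).round(4) # => 0.9486
```

A perfect fit gives 1.0, and always predicting the mean of y_true gives 0.0.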

Constructor Details

#initialize(fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false) ⇒ LinearRegression

Create a new ordinary least squares linear regressor.

Parameters:

  • fit_bias (Boolean) (defaults to: true)

    The flag indicating whether to fit the bias term.

  • bias_scale (Float) (defaults to: 1.0)

    The scale of the bias term.

  • max_iter (Integer) (defaults to: 1000)

    The maximum number of epochs, that is, how many times the whole training data is passed to the optimizer. If solver is ‘svd’, this parameter is ignored.

  • tol (Float) (defaults to: 1e-4)

    The loss tolerance for terminating the optimization. If solver is ‘svd’, this parameter is ignored.

  • solver (String) (defaults to: 'auto')

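    The algorithm used to calculate the weights (‘auto’, ‘svd’, or ‘lbfgs’). ‘auto’ chooses the ‘svd’ solver if Numo::Linalg is loaded; otherwise, it chooses the ‘lbfgs’ solver. ‘svd’ performs singular value decomposition of the samples. ‘lbfgs’ uses the L-BFGS method for optimization. Any other value falls back to ‘lbfgs’, and ‘svd’ also falls back to ‘lbfgs’ if Numo::Linalg is not loaded when #fit is called.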

  • verbose (Boolean) (defaults to: false)

    The flag indicating whether to output loss during iteration. If solver is ‘svd’, this parameter is ignored.
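The solver resolution performed by the constructor, together with the dispatch in #fit, can be summarized with the plain-Ruby sketch below. It does not call Rumale: the helper names `resolve_solver` and `solver_used_at_fit` are hypothetical, and the boolean `linalg_loaded` stands in for the library's `enable_linalg?(warning: false)` check.

```ruby
# Hypothetical helper mirroring the solver choice in #initialize.
# linalg_loaded stands in for Rumale's enable_linalg?(warning: false).
def resolve_solver(solver, linalg_loaded)
  if solver == 'auto'
    linalg_loaded ? 'svd' : 'lbfgs'
  else
    # Names that do not match 'svd' or 'lbfgs' (the match is case-sensitive) fall back to 'lbfgs'.
    solver.match?(/^svd$|^lbfgs$/) ? solver : 'lbfgs'
  end
end

# Hypothetical helper mirroring the dispatch in #fit: 'svd' is honored only
# if Numo::Linalg is still available when #fit runs.
def solver_used_at_fit(resolved_solver, linalg_loaded)
  resolved_solver == 'svd' && linalg_loaded ? 'svd' : 'lbfgs'
end

# Show how each requested name resolves when Numo::Linalg is not loaded.
%w[auto svd lbfgs SVD sgd].each do |name|
  resolved = resolve_solver(name, false)
  puts "#{name} -> #{resolved} -> #{solver_used_at_fit(resolved, false)}"
end
```

Because the match is case-sensitive, a value such as 'SVD' quietly becomes 'lbfgs' rather than raising an error.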



# File 'rumale-linear_model/lib/rumale/linear_model/linear_regression.rb', line 48

def initialize(fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, solver: 'auto', verbose: false)
  super()
  @params = {
    fit_bias: fit_bias,
    bias_scale: bias_scale,
    max_iter: max_iter,
    tol: tol,
    verbose: verbose
  }
  @params[:solver] = if solver == 'auto'
                       enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
                     else
                       solver.match?(/^svd$|^lbfgs$/) ? solver : 'lbfgs'
                     end
end

Instance Method Details

#fit(x, y) ⇒ LinearRegression

Fit the model with the given training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::DFloat)

    (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.

Returns:

  • (LinearRegression)

    The learned regressor itself.


# File 'rumale-linear_model/lib/rumale/linear_model/linear_regression.rb', line 69

def fit(x, y)
  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_target_value_array(y)
  Rumale::Validation.check_sample_size(x, y)

  @weight_vec, @bias_term = if @params[:solver] == 'svd' && enable_linalg?(warning: false)
                              partial_fit_svd(x, y)
                            else
                              partial_fit_lbfgs(x, y)
                            end

  self
end
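The plain-Ruby sketch below imitates what #fit computes, under stated assumptions. It swaps in the normal equations solved by Gaussian elimination (numerically weaker than SVD, and not what either Rumale solver does), handles a single output only, and assumes the bias is learned by appending a constant column equal to bias_scale, a common construction that we have not checked against Rumale's private helpers. All method names here are hypothetical.

```ruby
# Hypothetical helper: append a constant column (value = bias_scale) to every sample.
def append_bias_column(x, bias_scale)
  x.map { |row| row + [bias_scale] }
end

# Hypothetical helper: solve A w = b by Gaussian elimination with partial pivoting.
def solve_linear_system(a, b)
  m = a.map { |row| row.map(&:to_f) }
  v = b.map(&:to_f)
  n = v.size
  n.times do |col|
    # Move the row with the largest |entry| in this column into the pivot position.
    pivot = (col...n).max_by { |r| m[r][col].abs }
    m[col], m[pivot] = m[pivot], m[col]
    v[col], v[pivot] = v[pivot], v[col]
    ((col + 1)...n).each do |r|
      factor = m[r][col] / m[col][col]
      (col...n).each { |c| m[r][c] -= factor * m[col][c] }
      v[r] -= factor * v[col]
    end
  end
  # Back substitution on the resulting upper-triangular system.
  w = Array.new(n, 0.0)
  (n - 1).downto(0) do |r|
    w[r] = (v[r] - ((r + 1)...n).sum { |c| m[r][c] * w[c] }) / m[r][r]
  end
  w
end

# Hypothetical single-output OLS fit returning [weights, bias].
def fit_ols(x, y, fit_bias: true, bias_scale: 1.0)
  xe = fit_bias ? append_bias_column(x, bias_scale) : x
  d = xe.first.size
  # Normal equations (X^T X) w = X^T y; assumes X^T X is invertible.
  xtx = Array.new(d) { |i| Array.new(d) { |j| xe.sum { |row| row[i] * row[j] } } }
  xty = Array.new(d) { |i| xe.zip(y).sum { |row, t| row[i] * t } }
  w = solve_linear_system(xtx, xty)
  return [w, 0.0] unless fit_bias

  # The last weight multiplies the constant column, so the bias is that weight times bias_scale.
  [w[0...-1], w.last * bias_scale]
end

x = [[1.0, 2.0], [2.0, 1.0], [3.0, 5.0], [4.0, 3.0], [0.0, 1.0]]
y = x.map { |x1, x2| 2.0 * x1 - 3.0 * x2 + 5.0 } # noise-free targets
weights, bias = fit_ols(x, y)
p weights.map { |v| v.round(6) }, bias.round(6) # approximately [2.0, -3.0] and 5.0
```

Because this least-squares problem has no regularization, scaling the constant column only rescales the matching weight, so bias_scale: 1.0 and bias_scale: 2.0 recover the same bias in this sketch.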

#predict(x) ⇒ Numo::DFloat

Predict values for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples for which to predict values.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_outputs]) Predicted values per sample.



# File 'rumale-linear_model/lib/rumale/linear_model/linear_regression.rb', line 87

def predict(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  x.dot(@weight_vec.transpose) + @bias_term
end
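Prediction is a single matrix product plus the bias. Assuming weight_vec is laid out as [n_outputs, n_features] (inferred from `x.dot(@weight_vec.transpose)` yielding [n_samples, n_outputs], not stated explicitly above), bias_term holds one value per output and Numo's broadcasting adds it to every row. The plain-Ruby sketch below spells out that arithmetic with nested arrays; `predict_linear` is a hypothetical helper.

```ruby
# Hypothetical helper.
# x: [n_samples][n_features], weight_vec: [n_outputs][n_features], bias_term: [n_outputs]
def predict_linear(x, weight_vec, bias_term)
  x.map do |row|
    # One dot product per output (one row of x.dot(weight_vec.transpose)), plus that output's bias.
    weight_vec.each_with_index.map do |w, k|
      row.zip(w).sum { |xi, wi| xi * wi } + bias_term[k]
    end
  end
end

x = [[1.0, 2.0], [3.0, 4.0]]
weight_vec = [[1.0, 0.0], [0.5, -1.0]]
bias_term = [10.0, 0.0]
p predict_linear(x, weight_vec, bias_term) # => [[11.0, -1.5], [13.0, -2.5]]
```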