Class: Rumale::KernelMachine::KernelRidge
- Inherits: Base::Estimator (Rumale::KernelMachine::KernelRidge < Base::Estimator < Object)
- Includes: Base::Regressor
- Defined in: rumale-kernel_machine/lib/rumale/kernel_machine/kernel_ridge.rb
Overview
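KernelRidge is a class that implements kernel ridge regression. It works on precomputed kernel matrices rather than raw features: #fit expects the square kernel matrix of the training samples, and #predict expects the kernel matrix between the samples to predict and the training samples. Numo::Linalg must be loaded before calling #fit.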
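For reference, kernel ridge regression in its usual dual form can be written as follows. The notation is ours, not the library's: K is the n × n training kernel matrix, λ is reg_param, y holds the targets, α is the weight vector, and K_* is the test-versus-training kernel matrix. The second line covers the case where reg_param is an array with one value λ_j per target column j; both lines correspond to what the #fit and #predict source listings on this page compute.

```latex
\alpha = (K + \lambda I_n)^{-1} y, \qquad \hat{y} = K_* \alpha, \qquad \hat{f}(x) = \sum_{i=1}^{n} \alpha_i \, k(x, x_i)
\alpha_{:,j} = (K + \lambda_j I_n)^{-1} y_{:,j}, \qquad j = 1, \dots, n_{\text{outputs}}
```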
Instance Attribute Summary
- #weight_vec ⇒ Numo::DFloat (readonly)
  Return the weight vector.
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x, y) ⇒ KernelRidge
  Fit the model with given training data.
- #initialize(reg_param: 1.0) ⇒ KernelRidge (constructor)
  Create a new regressor with kernel ridge regression.
- #predict(x) ⇒ Numo::DFloat
  Predict values for samples.
Methods included from Base::Regressor
Constructor Details
#initialize(reg_param: 1.0) ⇒ KernelRidge
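Create a new regressor with kernel ridge regression.
Parameters:
- reg_param (Float/Numo::DFloat): The regularization parameter λ added to the diagonal of the kernel matrix. A Float applies one value to all targets; an array supplies one value per target column. Pass a Float literal such as 1.0: #fit checks reg_param.is_a?(Float), so an Integer such as 1 is not treated as a scalar and takes the per-target branch instead, which does not do what you intend.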
    # File 'rumale-kernel_machine/lib/rumale/kernel_machine/kernel_ridge.rb', line 32
    def initialize(reg_param: 1.0)
      super()
      @params = { reg_param: reg_param }
    end
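To make the two accepted forms of reg_param concrete, here is a minimal pure-Ruby sketch of how the regularized kernel matrix K + λI is built in each case. It assumes Ruby's `matrix` library as a stand-in for Numo::DFloat, and the helper names `regularized_kernel` and `regularized_kernels` are hypothetical, not part of Rumale.

```ruby
require 'matrix'

# Hypothetical helper (not part of Rumale): build K + lam * I,
# the matrix that KernelRidge#fit solves against.
def regularized_kernel(kernel, lam)
  raise ArgumentError, 'kernel matrix must be square' unless kernel.square?

  kernel + Matrix.identity(kernel.row_count) * lam
end

# Hypothetical helper mirroring the Float-vs-array dispatch in #fit:
# a Float yields one shared matrix, an array yields one matrix per output.
def regularized_kernels(kernel, reg_param, n_outputs)
  if reg_param.is_a?(Float)
    [regularized_kernel(kernel, reg_param)]
  else
    raise ArgumentError, 'need one reg_param per output' unless reg_param.size == n_outputs

    reg_param.map { |lam| regularized_kernel(kernel, lam) }
  end
end

# Made-up 2 x 2 kernel matrix for illustration.
k = Matrix[[1.0, 0.5], [0.5, 1.0]]
shared = regularized_kernels(k, 0.1, 2)             # one matrix for both outputs
per_output = regularized_kernels(k, [0.1, 10.0], 2) # one matrix per output
```

Only the diagonal changes; off-diagonal kernel values are left untouched, which is why a larger λ shrinks the fitted weights toward zero.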
Instance Attribute Details
#weight_vec ⇒ Numo::DFloat (readonly)
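Return the weight vector (the dual coefficients learned by #fit). Its shape is [n_training_samples] for a single target, or [n_training_samples, n_outputs] for multiple targets.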
    # File 'rumale-kernel_machine/lib/rumale/kernel_machine/kernel_ridge.rb', line 27
    def weight_vec
      @weight_vec
    end
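Because weight_vec holds one coefficient per training sample rather than per feature, it can be helpful to see how it relates to ordinary ridge weights. The sketch below is an illustration only, using Ruby's `matrix` library and made-up data: with a linear kernel K = X Xᵀ, the dual weights α give primal weights w = Xᵀα, which coincide with the closed-form ridge solution (XᵀX + λI)⁻¹Xᵀy.

```ruby
require 'matrix'

# Made-up toy data for illustration: 3 samples, 2 features.
x = Matrix[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = Vector[1.0, 2.0, 3.0]
lam = 0.5

# Linear kernel K = X X^T and dual weights alpha = (K + lam I)^-1 y,
# i.e. what weight_vec would hold for this kernel matrix.
k = x * x.transpose
alpha = (k + Matrix.identity(3) * lam).inverse * y

# With a linear kernel, the primal weights are w = X^T alpha ...
w = x.transpose * alpha

# ... which equal the closed-form ridge weights (X^T X + lam I)^-1 X^T y.
w_ridge = (x.transpose * x + Matrix.identity(2) * lam).inverse * x.transpose * y
```

For nonlinear kernels no such finite primal vector is generally available, which is why KernelRidge stores the per-sample weights.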
Instance Method Details
#fit(x, y) ⇒ KernelRidge
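Fit the model with given training data.
Parameters:
- x (Numo::DFloat): Shape [n_training_samples, n_training_samples]. The kernel matrix of the training data. It must be square, otherwise an ArgumentError is raised.
- y (Numo::DFloat): Shape [n_training_samples] or [n_training_samples, n_outputs]. The target values. When reg_param is an array, y must be two-dimensional and reg_param needs one value per column of y.
Returns:
- KernelRidge: The learned regressor itself.
A RuntimeError is raised if Numo::Linalg is not loaded.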
    # File 'rumale-kernel_machine/lib/rumale/kernel_machine/kernel_ridge.rb', line 45
    def fit(x, y)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      y = ::Rumale::Validation.check_convert_target_value_array(y)
      ::Rumale::Validation.check_sample_size(x, y)
      raise ArgumentError, 'Expect the kernel matrix of training data to be square.' unless x.shape[0] == x.shape[1]
      raise 'KernelRidge#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?(warning: false)

      n_samples = x.shape[0]

      if @params[:reg_param].is_a?(Float)
        reg_kernel_mat = x + Numo::DFloat.eye(n_samples) * @params[:reg_param]
        @weight_vec = Numo::Linalg.solve(reg_kernel_mat, y, driver: 'sym')
      else
        n_outputs = y.shape[1]
        @weight_vec = Numo::DFloat.zeros(n_samples, n_outputs)
        n_outputs.times do |n|
          reg_kernel_mat = x + Numo::DFloat.eye(n_samples) * @params[:reg_param][n]
          @weight_vec[true, n] = Numo::Linalg.solve(reg_kernel_mat, y[true, n], driver: 'sym')
        end
      end

      self
    end
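The listing above boils down to solving the linear system (K + λI)α = y, once for all targets or once per target column. The following is a minimal pure-Ruby sketch of that step under stated assumptions: Ruby's `matrix` library replaces Numo, targets are passed as an Array of Vectors (one per output), the helper name `fit_dual_weights` is hypothetical, and an LU decomposition (`Matrix#lup`) is swapped in for the symmetric solver (`driver: 'sym'`) the real code uses through Numo::Linalg.

```ruby
require 'matrix'

# Hypothetical sketch of the solve in KernelRidge#fit (not Rumale's API).
# kernel:    n x n Matrix (training kernel matrix)
# targets:   Array of n-element Vectors, one per output
# reg_param: Float (shared) or Array of Floats (one per output)
def fit_dual_weights(kernel, targets, reg_param)
  raise ArgumentError, 'Expect the kernel matrix of training data to be square.' unless kernel.square?

  n = kernel.row_count
  targets.each_with_index.map do |y, i|
    raise ArgumentError, 'sample size mismatch' unless y.size == n

    lam = reg_param.is_a?(Float) ? reg_param : reg_param[i]
    # LU solve stands in for the symmetric solver used by the real code.
    (kernel + Matrix.identity(n) * lam).lup.solve(y)
  end
end

# Made-up example: two outputs, each with its own regularization value.
k = Matrix[[2.0, 1.0], [1.0, 2.0]]
alphas = fit_dual_weights(k, [Vector[1.0, 0.0], Vector[0.0, 1.0]], [1.0, 1.0])
```

A symmetric solver is the natural choice in the real code because K + λI is symmetric (and positive definite for a valid kernel with λ > 0); LU gives the same answer here, just without exploiting that structure.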
#predict(x) ⇒ Numo::DFloat
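Predict values for samples.
Parameters:
- x (Numo::DFloat): Shape [n_testing_samples, n_training_samples]. The kernel matrix between the samples to predict and the training samples.
Returns:
- Numo::DFloat: Shape [n_testing_samples] or [n_testing_samples, n_outputs]. The predicted values.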
    # File 'rumale-kernel_machine/lib/rumale/kernel_machine/kernel_ridge.rb', line 74
    def predict(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      x.dot(@weight_vec)
    end
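Since #predict is just a matrix product, the main thing to get right is the shape of its input: rows are the samples to predict, columns are the training samples. The end-to-end sketch below shows that workflow in pure Ruby. It is an illustration only: Ruby's `matrix` library stands in for Numo, the data are made up, a tiny λ of 1e-8 is chosen so the fit nearly interpolates the training targets, and `rbf_kernel` and `predict_values` are hypothetical helpers, not Rumale functions.

```ruby
require 'matrix'

# Hypothetical RBF kernel helper: k(a, b) = exp(-gamma * ||a - b||^2).
# Builds the (a.size x b.size) kernel matrix between two lists of samples.
def rbf_kernel(a, b, gamma: 1.0)
  Matrix.build(a.size, b.size) do |i, j|
    diff = a[i] - b[j]
    Math.exp(-gamma * diff.inner_product(diff))
  end
end

# Mirrors KernelRidge#predict: (n_test x n_train) kernel times the
# n_train-element weight vector gives one prediction per test sample.
def predict_values(test_kernel, weight_vec)
  test_kernel * weight_vec
end

# Made-up 1-D data; each sample is a one-element Vector.
x_train = [Vector[0.0], Vector[1.0], Vector[2.0]]
y_train = Vector[0.0, 1.0, 4.0]
x_test = [Vector[0.5], Vector[1.5]]

k_train = rbf_kernel(x_train, x_train)                              # square, as #fit requires
weight_vec = (k_train + Matrix.identity(3) * 1e-8).lup.solve(y_train) # fit step
k_test = rbf_kernel(x_test, x_train)                                # test-vs-train, as #predict requires
y_pred = predict_values(k_test, weight_vec)
```

Note that the test kernel is computed against the training samples, not against other test samples; passing a square test-vs-test matrix would only work by accident when the counts happen to match.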