Class: Rumale::SVM::LogisticRegression
- Inherits: Base::Estimator
  - Object
    - Base::Estimator
      - Rumale::SVM::LogisticRegression
- Includes: Base::Classifier
- Defined in: lib/rumale/svm/logistic_regression.rb
Overview
LogisticRegression is a class that provides the logistic regression of LIBLINEAR through a Rumale interface.
Instance Attribute Summary
- #bias_term ⇒ Numo::DFloat (readonly)
  Return the bias term (a.k.a. intercept) for LogisticRegression.
- #weight_vec ⇒ Numo::DFloat (readonly)
  Return the weight vector for LogisticRegression.
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ LogisticRegression
  Fit the model with given training data.
- #initialize(penalty: 'l2', dual: true, reg_param: 1.0, fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil) ⇒ LogisticRegression (constructor)
  Create a new classifier with Logistic Regression.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat
  Predict class probability for samples.
Constructor Details
#initialize(penalty: 'l2', dual: true, reg_param: 1.0, fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil) ⇒ LogisticRegression
Create a new classifier with Logistic Regression.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 39
def initialize(penalty: 'l2', dual: true, reg_param: 1.0, fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil)
  super()
  @params = {}
  @params[:penalty] = penalty == 'l1' ? 'l1' : 'l2'
  @params[:dual] = dual
  @params[:reg_param] = reg_param.to_f
  @params[:fit_bias] = fit_bias
  @params[:bias_scale] = bias_scale.to_f
  @params[:tol] = tol.to_f
  @params[:verbose] = verbose
  @params[:random_seed] = random_seed.nil? ? nil : random_seed.to_i
end
```
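The constructor performs no training; it only normalizes its arguments into `@params`. The plain-Ruby sketch below reproduces those coercions in a hypothetical `normalize_params` helper (not part of Rumale) so their effect is easy to see, for example that any penalty other than the exact string `'l1'` falls back to `'l2'`.

```ruby
# Hypothetical helper that mirrors the argument coercions done in #initialize.
def normalize_params(penalty: 'l2', dual: true, reg_param: 1.0, fit_bias: true,
                     bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil)
  {
    penalty: penalty == 'l1' ? 'l1' : 'l2', # anything except 'l1' becomes 'l2'
    dual: dual,
    reg_param: reg_param.to_f,              # e.g. 10 becomes 10.0
    fit_bias: fit_bias,
    bias_scale: bias_scale.to_f,
    tol: tol.to_f,
    verbose: verbose,
    random_seed: random_seed.nil? ? nil : random_seed.to_i # '42' becomes 42
  }
end

params = normalize_params(penalty: 'elasticnet', reg_param: 10, random_seed: '42')
p params[:penalty]     # => "l2"
p params[:reg_param]   # => 10.0
p params[:random_seed] # => 42
```

Because the comparison is case-sensitive, `penalty: 'L1'` would also select the L2 penalty.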
Instance Attribute Details
#bias_term ⇒ Numo::DFloat (readonly)
Return the bias term (a.k.a. intercept) for LogisticRegression.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 24
def bias_term
  @bias_term
end
```
#weight_vec ⇒ Numo::DFloat (readonly)
Return the weight vector for LogisticRegression.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 20
def weight_vec
  @weight_vec
end
```
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Calculate confidence scores for samples.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 72
def decision_function(x)
  raise "#{self.class.name}##{__method__} expects to be called after training the model with the fit method." unless trained?
  x = Rumale::Validation.check_convert_sample_array(x)
  xx = fit_bias? ? (x) : x
  Numo::Liblinear.decision_function(xx, liblinear_params, @model)
end
```
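`decision_function` delegates the scoring to `Numo::Liblinear.decision_function`. For a linear model, the score of a sample is the dot product of a weight vector with the sample. The stdlib-only sketch below assumes LIBLINEAR's bias convention, in which each sample x is extended to [x, bias] so the last weight acts as the intercept coefficient; treating `bias_scale` as that constant is an assumption, and `expand_with_bias` and `decision_scores` are hypothetical helpers, not Rumale or Numo methods.

```ruby
# Hypothetical: append the constant bias feature to every sample (assumed
# to mirror LIBLINEAR's "x becomes [x; bias]" convention).
def expand_with_bias(x, bias_scale)
  x.map { |row| row + [bias_scale] }
end

# Hypothetical: one score per weight row (one row per class, or a single
# row for a binary problem), computed as a plain dot product.
def decision_scores(x, weight_rows)
  x.map do |row|
    weight_rows.map { |w| w.zip(row).sum { |wi, xi| wi * xi } }
  end
end

x = [[1.0, 2.0], [0.0, -1.0]]
w = [[0.5, -0.25, 2.0]] # last entry is the bias coefficient
p decision_scores(expand_with_bias(x, 1.0), w) # => [[2.0], [2.25]]
```

Raising `bias_scale` enlarges the constant feature, so the same bias coefficient contributes more to every score: with `bias_scale = 2.0`, the first sample above would score 4.0.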
#fit(x, y) ⇒ LogisticRegression
Fit the model with given training data.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 58
def fit(x, y)
  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_label_array(y)
  Rumale::Validation.check_sample_size(x, y)
  xx = fit_bias? ? (x) : x
  @model = Numo::Liblinear.train(xx, y, liblinear_params)
  @weight_vec, @bias_term = weight_and_bias(@model[:w])
  self
end
```
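After training, `fit` passes `@model[:w]` to a `weight_and_bias` helper whose body is not shown on this page, and stores the result in `weight_vec` and `bias_term`. The sketch below is a hypothetical stand-in under the bias-expansion assumption: it treats the last entry of each weight row as the bias coefficient and reports the intercept as that coefficient times `bias_scale`, so that w·x + b reproduces the score of the expanded sample. Whether Rumale scales the intercept this way is an assumption.

```ruby
# Hypothetical stand-in for the unshown weight_and_bias helper.
# Assumption: with fit_bias, each weight row ends with the coefficient of
# the constant feature whose value is bias_scale.
def split_weight_and_bias(weight_rows, fit_bias:, bias_scale:)
  return [weight_rows, Array.new(weight_rows.size, 0.0)] unless fit_bias

  weights = weight_rows.map { |w| w[0...-1] }         # drop the bias coefficient
  biases = weight_rows.map { |w| w[-1] * bias_scale } # intercept on the original scale
  [weights, biases]
end

weights, biases = split_weight_and_bias([[0.5, -0.25, 2.0]], fit_bias: true, bias_scale: 2.0)
p weights # => [[0.5, -0.25]]
p biases  # => [4.0]
```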
#marshal_dump ⇒ Hash
Dump marshal data.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 104
def marshal_dump
  { params: @params,
    model: @model,
    weight_vec: @weight_vec,
    bias_term: @bias_term }
end
```
#marshal_load(obj) ⇒ nil
Load marshal data.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 113
def marshal_load(obj)
  @params = obj[:params]
  @model = obj[:model]
  @weight_vec = obj[:weight_vec]
  @bias_term = obj[:bias_term]
  nil
end
```
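`marshal_dump` and `marshal_load` are the hooks Ruby's `Marshal` module calls: `Marshal.dump` serializes the Hash returned by `marshal_dump`, and `Marshal.load` allocates a new object without running `initialize` and then hands that Hash to `marshal_load`. The minimal stand-in class below (hypothetical, stdlib only) follows the same pattern to show the round trip.

```ruby
# Hypothetical stand-in class following the same marshal_dump/marshal_load pattern.
class TinyModel
  attr_reader :params, :weight_vec, :bias_term

  def initialize(params = {})
    @params = params
    @weight_vec = nil
    @bias_term = nil
  end

  # Stand-in for training: just store the given state.
  def assign_state(weights, bias)
    @weight_vec = weights
    @bias_term = bias
    self
  end

  # Marshal.dump serializes whatever this Hash contains.
  def marshal_dump
    { params: @params, weight_vec: @weight_vec, bias_term: @bias_term }
  end

  # Marshal.load calls this on a freshly allocated object (initialize is skipped).
  def marshal_load(obj)
    @params = obj[:params]
    @weight_vec = obj[:weight_vec]
    @bias_term = obj[:bias_term]
    nil
  end
end

original = TinyModel.new({ reg_param: 1.0 }).assign_state([0.5, -0.25], 4.0)
restored = Marshal.load(Marshal.dump(original))
p restored.weight_vec # => [0.5, -0.25]
p restored.bias_term  # => 4.0
```

The real class also dumps `@model`, which is what lets a restored estimator keep predicting without retraining.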
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 83
def predict(x)
  raise "#{self.class.name}##{__method__} expects to be called after training the model with the fit method." unless trained?
  x = Rumale::Validation.check_convert_sample_array(x)
  xx = fit_bias? ? (x) : x
  Numo::Int32.cast(Numo::Liblinear.predict(xx, liblinear_params, @model))
end
```
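`predict` returns the labels LIBLINEAR selects from the decision values, cast to `Numo::Int32`. LIBLINEAR's selection rule is short: a two-class model has a single score, and a strictly positive value picks the first stored label; with more classes, the label with the largest score wins. The hypothetical plain-Ruby helper below sketches that rule; `labels` stands for the class labels in the order the model stores them.

```ruby
# Hypothetical sketch of LIBLINEAR's label selection from decision scores.
def scores_to_labels(scores, labels)
  scores.map do |row|
    if labels.size == 2
      row[0] > 0 ? labels[0] : labels[1] # single score for binary models
    else
      best_index = row.each_with_index.max_by { |score, _| score }[1]
      labels[best_index]                 # argmax over per-class scores
    end
  end
end

p scores_to_labels([[0.3], [-1.2]], [1, -1])                      # => [1, -1]
p scores_to_labels([[0.1, 2.0, -0.5], [1.5, 0.2, 0.3]], [0, 1, 2]) # => [1, 0]
```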
#predict_proba(x) ⇒ Numo::DFloat
Predict class probabilities for samples. LIBLINEAR produces probability estimates for logistic regression models directly, so no extra option is needed (the constructor has no probability parameter).
```ruby
# File 'lib/rumale/svm/logistic_regression.rb', line 95
def predict_proba(x)
  raise "#{self.class.name}##{__method__} expects to be called after training the model with the fit method." unless trained?
  x = Rumale::Validation.check_convert_sample_array(x)
  xx = fit_bias? ? (x) : x
  Numo::Liblinear.predict_proba(xx, liblinear_params, @model)
end
```
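For logistic regression, LIBLINEAR derives probabilities by passing each decision value through the logistic sigmoid 1 / (1 + e^(-z)). With two classes the second probability is 1 minus the first; with more classes the per-class sigmoids are normalized to sum to 1. The hypothetical `scores_to_proba` helper below reimplements that rule in plain Ruby to illustrate it; it is not the `Numo::Liblinear.predict_proba` code itself.

```ruby
# Logistic sigmoid.
def sigmoid(z)
  1.0 / (1.0 + Math.exp(-z))
end

# Hypothetical sketch of LIBLINEAR-style probability estimates for one sample.
def scores_to_proba(row, n_classes)
  if n_classes == 2
    p_first = sigmoid(row[0]) # binary models have a single decision value
    [p_first, 1.0 - p_first]
  else
    raw = row.map { |z| sigmoid(z) }
    total = raw.sum
    raw.map { |v| v / total } # normalize so the probabilities sum to 1
  end
end

p scores_to_proba([0.0], 2) # => [0.5, 0.5]
```

Since the sigmoid is monotonic, the class with the largest decision value also receives the largest probability, so `predict_proba` agrees with `predict` on the most likely class.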