Class: Rumale::ModelSelection::GridSearchCV

Inherits:
Base::Estimator
Defined in:
rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb

Overview

GridSearchCV performs hyperparameter optimization by grid search: it evaluates every parameter combination in the grid with cross validation and keeps the best one.

Examples:

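require 'rumale/ensemble'
require 'rumale/model_selection/stratified_k_fold'
require 'rumale/model_selection/grid_search_cv'

# samples and labels are assumed to be loaded beforehand
# (a Numo::DFloat feature matrix and a Numo::NArray of labels).

# Search random forest hyperparameters with 5-fold stratified cross validation.
rfc = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
gs = Rumale::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
gs.fit(samples, labels)
p gs.cv_results
p gs.best_params

# Search pipeline hyperparameters. Keys take the form '<step name>__<parameter name>'.
# The require paths below assume the corresponding Rumale gems are installed.
require 'rumale/kernel_approximation/rbf'
require 'rumale/linear_model/svc'
require 'rumale/pipeline/pipeline'

rbf = Rumale::KernelApproximation::RBF.new(random_seed: 1)
svc = Rumale::LinearModel::SVC.new
pipe = Rumale::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
gs = Rumale::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
gs.fit(samples, labels)
p gs.cv_results
p gs.best_params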
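
The pipeline example addresses each step's parameters through keys such as rbf__gamma: the step name, a double underscore, then the parameter name. The following pure-Ruby sketch shows one way such keys could be split and grouped by step. Both split_pipeline_key and group_by_step are hypothetical helpers written for illustration; they are not part of Rumale, and this page does not show how the library resolves these keys internally.

```ruby
# Hypothetical helper (not part of Rumale): split a grid key such as
# :rbf__gamma into its step name and its parameter name.
def split_pipeline_key(key)
  step, param = key.to_s.split('__', 2)
  raise ArgumentError, "expected '<step>__<param>', got #{key.inspect}" if param.nil? || param.empty?

  [step.to_sym, param.to_sym]
end

# Group one parameter combination by the pipeline step it belongs to.
def group_by_step(params)
  params.each_with_object({}) do |(key, value), grouped|
    step, param = split_pipeline_key(key)
    (grouped[step] ||= {})[param] = value
  end
end

p group_by_step({ rbf__gamma: 32.0, rbf__n_components: 4, svc__reg_param: 16.0 })
```

Splitting with a limit of 2 keeps any later double underscore inside the parameter name, which is a design choice of this sketch rather than documented Rumale behavior.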

Instance Attribute Summary

Attributes inherited from Base::Estimator

#params

Constructor Details

#initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true) ⇒ GridSearchCV

Create a new GridSearchCV instance.

Parameters:

  • estimator (Classifier/Regressor) (defaults to: nil)

    The estimator whose hyperparameters are optimized by grid search.

  • param_grid (Array<Hash>) (defaults to: nil)

    The parameter sets, given as an array of hashes. Each hash maps a parameter name to an array of candidate values. The examples above pass a single hash.

  • splitter (Splitter) (defaults to: nil)

    The splitter that divides the dataset into training and testing sets for cross validation.

  • evaluator (Evaluator) (defaults to: nil)

    The evaluator that scores the estimator's results during cross validation. If nil is given, the estimator's own score method is used for evaluation.

  • greater_is_better (Boolean) (defaults to: true)

    Whether a larger evaluation score indicates a better estimator.
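
A grid describes candidate values per parameter, and the search visits every combination of them. The sketch below expands a grid into its combinations with a Cartesian product. expand_grid and expand_param_grid are hypothetical helpers, not Rumale's internal code, and the sketch assumes that both a single hash (as in the examples) and an array of hashes (the documented type) are acceptable inputs.

```ruby
# Hypothetical helper (not Rumale's internal code): expand one grid hash
# into every combination of its candidate values (Cartesian product).
def expand_grid(grid)
  return [{}] if grid.empty?

  keys = grid.keys
  first, *rest = grid.values
  first.product(*rest).map { |values| keys.zip(values).to_h }
end

# Assumption: param_grid is either one hash or an array of hashes.
# Each hash is expanded on its own and the results are concatenated.
def expand_param_grid(param_grid)
  grids = param_grid.is_a?(Hash) ? [param_grid] : param_grid
  grids.flat_map { |grid| expand_grid(grid) }
end

combos = expand_param_grid({ n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] })
puts combos.size
p combos.first
```

Three parameters with two candidates each give 2 x 2 x 2 combinations, so the cost of a search grows multiplicatively with every parameter added to the grid.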



# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 65

def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
  super()
  @params = {
    param_grid: valid_param_grid(param_grid),
    estimator: Marshal.load(Marshal.dump(estimator)),
    splitter: Marshal.load(Marshal.dump(splitter)),
    evaluator: Marshal.load(Marshal.dump(evaluator)),
    greater_is_better: greater_is_better
  }
end
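
The constructor stores Marshal.load(Marshal.dump(...)) copies of the estimator, splitter, and evaluator, so later changes to the objects passed in do not reach the search. The sketch below shows that round trip on a stand-in object. ToyEstimator and deep_copy are hypothetical names used only for this illustration.

```ruby
# A stand-in for an estimator; ToyEstimator is hypothetical, not a Rumale class.
class ToyEstimator
  attr_reader :params

  def initialize(reg_param: 1.0)
    @params = { reg_param: reg_param }
  end
end

# Deep copy through a Marshal round trip, the same idiom the constructor uses.
def deep_copy(obj)
  Marshal.load(Marshal.dump(obj))
end

original = ToyEstimator.new(reg_param: 1.0)
stored = deep_copy(original)

# Mutating the original afterwards leaves the stored copy untouched.
original.params[:reg_param] = 100.0
puts stored.params[:reg_param]
puts stored.equal?(original)
```

One consequence of this idiom is that every object passed to the constructor must be serializable with Marshal; objects holding procs or IO handles would raise a TypeError when dumped.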

Instance Attribute Details

#best_estimator ⇒ Estimator (readonly)

Return the estimator trained with the best parameter set.

Returns:

  • (Estimator)


# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 53

def best_estimator
  @best_estimator
end

#best_index ⇒ Integer (readonly)

Return the index of the best parameter set.

Returns:

  • (Integer)


# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 49

def best_index
  @best_index
end

#best_params ⇒ Hash (readonly)

Return the best parameter set.

Returns:

  • (Hash)


# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 45

def best_params
  @best_params
end

#best_score ⇒ Float (readonly)

Return the score of the estimator trained with the best parameter set.

Returns:

  • (Float)


# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 41

def best_score
  @best_score
end

#cv_results ⇒ Hash (readonly)

Return the cross validation results for each parameter set.

Returns:

  • (Hash)


# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 37

def cv_results
  @cv_results
end
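
The readers above are related: the best index points at one entry of the cross validation results, and the best parameters and best score are the values found there. The sketch below illustrates that relationship together with the greater_is_better flag. The layout of the results hash, including the key names :params and :mean_test_score, is an assumption made for this example and is not confirmed by this page. The scores are made-up illustrative values, and select_best is a hypothetical helper.

```ruby
# Assumed layout, for illustration only: parallel arrays with one entry per
# parameter combination. The key names are assumptions, not Rumale's API.
SAMPLE_RESULTS = {
  params: [{ max_depth: 3 }, { max_depth: 5 }, { max_depth: 7 }],
  mean_test_score: [0.81, 0.90, 0.86] # made-up scores
}.freeze

# Hypothetical helper: pick the best entry, taking the largest score when
# greater_is_better is true and the smallest one otherwise.
def select_best(cv_results, greater_is_better: true)
  scores = cv_results[:mean_test_score]
  best_score = greater_is_better ? scores.max : scores.min
  best_index = scores.index(best_score)
  { best_index: best_index, best_score: best_score, best_params: cv_results[:params][best_index] }
end

p select_best(SAMPLE_RESULTS)
p select_best(SAMPLE_RESULTS, greater_is_better: false)
```

Setting greater_is_better to false suits evaluators that report an error, such as a mean squared error, where the smallest value marks the best parameter set.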

Instance Method Details

#decision_function(x) ⇒ Numo::DFloat

Call the decision_function method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to compute the scores.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples]) Confidence score per sample.



# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 102

def decision_function(x)
  @best_estimator.decision_function(x)
end

#fit(x, y) ⇒ GridSearchCV

Run cross validation on the given training data for every parameter combination, then refit the estimator with the best parameter set on all of the training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::NArray)

    (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.

Returns:

  • (GridSearchCV)

    The fitted GridSearchCV itself.



# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 81

def fit(x, y)
  init_attrs

  param_combinations.each do |prm_set|
    prm_set.each do |prms|
      report = perform_cross_validation(x, y, prms)
      store_cv_result(prms, report)
    end
  end

  find_best_params

  @best_estimator = configurated_estimator(@best_params)
  @best_estimator.fit(x, y)
  self
end
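
The listing above follows a common pattern: score each parameter combination with cross validation, pick the best one, then refit on all the data. The following self-contained sketch reproduces that flow in plain Ruby arrays. ThresholdClassifier, each_fold, and toy_grid_search are hypothetical stand-ins written for this example. They do not use Rumale, the fold layout is a simple interleaving rather than a stratified split, and the sketch always treats a larger score as better.

```ruby
# Toy estimator (hypothetical, not a Rumale class): predicts 1 when x > threshold.
class ThresholdClassifier
  def initialize(threshold:)
    @threshold = threshold
  end

  # Nothing to learn in this toy model; return self like a real estimator would.
  def fit(_x, _y)
    self
  end

  def predict(x)
    x.map { |v| v > @threshold ? 1 : 0 }
  end

  # Accuracy: fraction of predictions that match the labels.
  def score(x, y)
    predict(x).zip(y).count { |pred, label| pred == label }.fdiv(y.size)
  end
end

# Split indices 0...n into k interleaved folds and yield train and test indices.
def each_fold(n, k)
  (0...k).each do |fold|
    test_ids = (0...n).select { |i| i % k == fold }
    train_ids = (0...n).to_a - test_ids
    yield train_ids, test_ids
  end
end

# Score every combination by mean test accuracy, pick the best, refit on all data.
def toy_grid_search(x, y, combinations, n_splits: 2)
  mean_scores = combinations.map do |prms|
    scores = []
    each_fold(x.size, n_splits) do |train_ids, test_ids|
      est = ThresholdClassifier.new(**prms).fit(x.values_at(*train_ids), y.values_at(*train_ids))
      scores << est.score(x.values_at(*test_ids), y.values_at(*test_ids))
    end
    scores.sum / scores.size
  end
  best_index = mean_scores.index(mean_scores.max)
  best_params = combinations[best_index]
  best_estimator = ThresholdClassifier.new(**best_params).fit(x, y)
  { best_params: best_params, best_score: mean_scores[best_index], best_estimator: best_estimator }
end

x = [0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]
y = [0, 0, 0, 0, 1, 1, 1, 1]
result = toy_grid_search(x, y, [{ threshold: 0.25 }, { threshold: 0.5 }, { threshold: 0.75 }])
p result[:best_params]
```

The final refit matters because the estimators trained inside the folds each saw only part of the data; the one returned to the caller is trained on all of it.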

#predict(x) ⇒ Numo::NArray

Call the predict method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to obtain prediction result.

Returns:

  • (Numo::NArray)

    Predicted results.



# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 110

def predict(x)
  @best_estimator.predict(x)
end

#predict_log_proba(x) ⇒ Numo::DFloat

Call the predict_log_proba method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to predict the log-probabilities.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.



# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 118

def predict_log_proba(x)
  @best_estimator.predict_log_proba(x)
end

#predict_proba(x) ⇒ Numo::DFloat

Call the predict_proba method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to predict the probabilities.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted probability of each class per sample.



# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 126

def predict_proba(x)
  @best_estimator.predict_proba(x)
end

#score(x, y) ⇒ Float

Call the score method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) Testing data.

  • y (Numo::NArray)

    (shape: [n_samples, n_outputs]) True target values or labels for testing data.

Returns:

  • (Float)

    The score of estimator.



# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 135

def score(x, y)
  @best_estimator.score(x, y)
end
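
Each of the methods above forwards straight to @best_estimator, which is only assigned inside fit. Calling them before fit therefore ends in a NoMethodError on nil, and the same error appears if the best estimator does not implement the forwarded method. The sketch below reproduces the delegation pattern with two hypothetical classes, MeanRegressor and ToySearch, which are stand-ins and not Rumale classes.

```ruby
# Toy estimator (hypothetical): always predicts the mean of the training targets.
class MeanRegressor
  def fit(_x, y)
    @mean = y.sum.fdiv(y.size)
    self
  end

  def predict(x)
    Array.new(x.size, @mean)
  end
end

# Hypothetical wrapper that mirrors the delegation pattern shown above.
class ToySearch
  def initialize(estimator)
    @estimator = estimator
  end

  # @best_estimator exists only after fit has run.
  def fit(x, y)
    @best_estimator = @estimator.fit(x, y)
    self
  end

  def predict(x)
    @best_estimator.predict(x)
  end
end

search = ToySearch.new(MeanRegressor.new)

begin
  search.predict([[1.0]])
rescue NoMethodError => e
  puts "predict before fit failed: #{e.class}"
end

p search.fit([[1.0], [2.0]], [2.0, 4.0]).predict([[3.0]])
```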