Class: Rumale::ModelSelection::GridSearchCV
- Inherits:
- Base::Estimator (Object > Base::Estimator > Rumale::ModelSelection::GridSearchCV)
- Defined in:
- rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb
Overview
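GridSearchCV performs hyperparameter optimization with the grid search method: it evaluates every combination of the given parameter values by cross-validation and then refits the estimator with the best-scoring combination.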
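Grid search starts by enumerating every combination of candidate values. The sketch below shows that expansion step in plain Ruby, assuming a grid given as a Hash (or an Array of Hashes) that maps parameter names to arrays of candidate values. `expand_grid` is a hypothetical helper for illustration, not Rumale's internal `param_combinations`.

```ruby
# Minimal sketch of grid expansion (not Rumale's internal code):
# each grid maps parameter names to candidate values, and every
# combination is produced via the Cartesian product.
# Accepting either a Hash or an Array of Hashes is an assumption.
def expand_grid(param_grid)
  grids = param_grid.is_a?(Array) ? param_grid : [param_grid]
  grids.flat_map do |grid|
    next [{}] if grid.empty?

    keys = grid.keys
    first, *rest = grid.values.map { |v| Array(v) }
    first.product(*rest).map { |combo| keys.zip(combo).to_h }
  end
end

grid = { reg_param: [0.1, 1.0, 10.0], fit_bias: [true, false] }
combos = expand_grid(grid)
puts combos.size # => 6
```

Three values for one parameter and two for the other yield six candidate sets, which is why the cost of grid search grows multiplicatively with each added parameter.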
Instance Attribute Summary
- #best_estimator ⇒ Estimator (readonly)
  Return the estimator fitted with the best parameter set.
- #best_index ⇒ Integer (readonly)
  Return the index of the best parameter set among the evaluated parameter sets.
- #best_params ⇒ Hash (readonly)
  Return the best parameter set.
- #best_score ⇒ Float (readonly)
  Return the score of the estimator fitted with the best parameter set.
- #cv_results ⇒ Hash (readonly)
  Return the cross-validation results for each parameter set.
Attributes inherited from Base::Estimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Call the decision_function method of the estimator fitted with the best parameter set.
- #fit(x, y) ⇒ GridSearchCV
  Evaluate every parameter set by cross-validation on the given training data, then fit the estimator with the best set.
- #initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true) ⇒ GridSearchCV (constructor)
  Create a new grid search object.
- #predict(x) ⇒ Numo::NArray
  Call the predict method of the estimator fitted with the best parameter set.
- #predict_log_proba(x) ⇒ Numo::DFloat
  Call the predict_log_proba method of the estimator fitted with the best parameter set.
- #predict_proba(x) ⇒ Numo::DFloat
  Call the predict_proba method of the estimator fitted with the best parameter set.
- #score(x, y) ⇒ Float
  Call the score method of the estimator fitted with the best parameter set.
Constructor Details
#initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true) ⇒ GridSearchCV
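Create a new grid search object. The estimator, splitter, and evaluator are deep-copied with Marshal, so later changes to the objects passed in do not affect the search.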
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 65
def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
  super()
  @params = {
    param_grid: valid_param_grid(param_grid),
    estimator: Marshal.load(Marshal.dump(estimator)),
    splitter: Marshal.load(Marshal.dump(splitter)),
    evaluator: Marshal.load(Marshal.dump(evaluator)),
    greater_is_better: greater_is_better
  }
end
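The constructor stores deep copies made with `Marshal.load(Marshal.dump(obj))`. The sketch below uses a hypothetical `ToyEstimator` class (not part of Rumale) to contrast that idiom with a shallow `dup` and to show why later changes to the caller's object do not leak into the stored copy.

```ruby
# ToyEstimator is a hypothetical stand-in, not a Rumale class.
class ToyEstimator
  attr_reader :params

  def initialize(reg_param: 1.0)
    @params = { reg_param: reg_param }
  end
end

# Marshal.dump serializes the whole object graph and Marshal.load rebuilds it,
# so the copy shares no mutable state with the original.
def deep_copy(obj)
  Marshal.load(Marshal.dump(obj))
end

original = ToyEstimator.new(reg_param: 1.0)
shallow = original.dup        # copies the object but shares the @params Hash
copy = deep_copy(original)    # copies the object and its @params Hash

# Mutating the caller's object afterwards only affects the shallow copy.
original.params[:reg_param] = 100.0
puts shallow.params[:reg_param] # => 100.0
puts copy.params[:reg_param]    # => 1.0
```

Because the constructor calls Marshal.dump on these arguments, passing an object that holds a proc, an IO handle, or a singleton method would raise a TypeError at construction time.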
Instance Attribute Details
#best_estimator ⇒ Estimator (readonly)
Return the estimator fitted on the whole training data with the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 53
def best_estimator
  @best_estimator
end
#best_index ⇒ Integer (readonly)
Return the index of the best parameter set among the evaluated parameter sets.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 49
def best_index
  @best_index
end
#best_params ⇒ Hash (readonly)
Return the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 45
def best_params
  @best_params
end
#best_score ⇒ Float (readonly)
Return the score of the estimator fitted with the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 41
def best_score
  @best_score
end
#cv_results ⇒ Hash (readonly)
Return the cross-validation results for each parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 37
def cv_results
  @cv_results
end
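The `best_index`, `best_params`, and `best_score` attributes all describe one winning entry among the evaluated parameter sets. A minimal sketch of how such values can be derived, with `greater_is_better` choosing between maximizing a score and minimizing an error. The `:params`/`:mean_test_score` layout, the `select_best` helper, and the toy scores are assumptions for illustration, not the documented contents of `cv_results`.

```ruby
# Hypothetical result layout: parallel arrays of parameter sets and mean scores.
def select_best(results, greater_is_better: true)
  scores = results[:mean_test_score]
  indices = scores.each_index
  # Larger is better for accuracy-like scores; smaller is better for errors.
  best_index = greater_is_better ? indices.max_by { |i| scores[i] } : indices.min_by { |i| scores[i] }
  {
    best_index: best_index,
    best_params: results[:params][best_index],
    best_score: scores[best_index]
  }
end

results = {
  params: [{ reg_param: 0.1 }, { reg_param: 1.0 }, { reg_param: 10.0 }],
  mean_test_score: [0.82, 0.91, 0.87] # toy accuracy values
}
puts select_best(results)[:best_index] # => 1

# For an error metric such as mean squared error, smaller is better.
errors = results.merge(mean_test_score: [0.30, 0.25, 0.12]) # toy error values
puts select_best(errors, greater_is_better: false)[:best_index] # => 2
```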
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Call the decision_function method of the estimator fitted with the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 102
def decision_function(x)
  @best_estimator.decision_function(x)
end
#fit(x, y) ⇒ GridSearchCV
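Evaluate every parameter set by cross-validation on the given training data, select the best-scoring set, and then fit the estimator with that set on the whole training data. Returns self.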
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 81
def fit(x, y)
  init_attrs

  param_combinations.each do |prm_set|
    prm_set.each do |prms|
      report = perform_cross_validation(x, y, prms)
      store_cv_result(prms, report)
    end
  end

  find_best_params

  @best_estimator = configurated_estimator(@best_params)
  @best_estimator.fit(x, y)
  self
end
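The `fit` method loops over every parameter combination, cross-validates each one, picks the best, and refits on the full data. The plain-Ruby sketch below reproduces that control flow end to end. `ToyRidge` (one-dimensional ridge regression without an intercept), the contiguous k-fold split, mean squared error as the metric (hence `greater_is_better: false`), and all helper names are assumptions standing in for Rumale's estimator, splitter, and evaluator.

```ruby
# Hypothetical estimator: 1-D ridge regression through the origin.
class ToyRidge
  attr_reader :weight

  def initialize(reg_param: 1.0)
    @reg_param = reg_param
  end

  # Closed form: w = sum(x * y) / (sum(x^2) + reg_param).
  def fit(x, y)
    @weight = x.zip(y).sum { |xi, yi| xi * yi } / (x.sum { |xi| xi * xi } + @reg_param)
    self
  end

  def predict(x)
    x.map { |xi| @weight * xi }
  end
end

def mean_squared_error(y_true, y_pred)
  y_true.zip(y_pred).sum { |t, p| (t - p)**2 } / y_true.size
end

# Contiguous k-fold split over sample indices (no shuffling, for determinism).
def k_fold_indices(n_samples, n_splits)
  folds = (0...n_samples).each_slice((n_samples.to_f / n_splits).ceil).to_a
  folds.map { |test| [(0...n_samples).to_a - test, test] }
end

def grid_search_fit(x, y, combos, n_splits: 3, greater_is_better: false)
  # Mean cross-validation score for every parameter set.
  mean_scores = combos.map do |prms|
    scores = k_fold_indices(x.size, n_splits).map do |train, test|
      est = ToyRidge.new(**prms).fit(x.values_at(*train), y.values_at(*train))
      mean_squared_error(y.values_at(*test), est.predict(x.values_at(*test)))
    end
    scores.sum / scores.size
  end
  pick = greater_is_better ? :max_by : :min_by
  best_index = mean_scores.each_index.public_send(pick) { |i| mean_scores[i] }
  best_params = combos[best_index]
  best_estimator = ToyRidge.new(**best_params).fit(x, y) # refit on all data
  { best_index: best_index, best_params: best_params,
    best_score: mean_scores[best_index], best_estimator: best_estimator }
end

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = x.map { |xi| 2.0 * xi }
combos = [{ reg_param: 0.0 }, { reg_param: 10.0 }, { reg_param: 100.0 }]

result = grid_search_fit(x, y, combos)
puts result[:best_index]            # => 0
puts result[:best_estimator].weight # => 2.0
```

With y = 2x exactly, the unregularized setting recovers the true weight on every fold and wins with a mean error of 0.0; larger `reg_param` values shrink the weight toward zero and lose.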
#predict(x) ⇒ Numo::NArray
Call the predict method of the estimator fitted with the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 110
def predict(x)
  @best_estimator.predict(x)
end
#predict_log_proba(x) ⇒ Numo::DFloat
Call the predict_log_proba method of the estimator fitted with the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 118
def predict_log_proba(x)
  @best_estimator.predict_log_proba(x)
end
#predict_proba(x) ⇒ Numo::DFloat
Call the predict_proba method of the estimator fitted with the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 126
def predict_proba(x)
  @best_estimator.predict_proba(x)
end
#score(x, y) ⇒ Float
Call the score method of the estimator fitted with the best parameter set.
# File 'rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb', line 135
def score(x, y)
  @best_estimator.score(x, y)
end
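The prediction and scoring methods above simply forward to `@best_estimator`. A minimal sketch of that delegation in plain Ruby, using hypothetical `SearchWrapper` and `ConstantModel` classes. Unlike the documented methods, which would raise NoMethodError on nil if called before `fit`, the sketch checks for a missing estimator and raises a descriptive error.

```ruby
# Hypothetical model that always predicts the same label.
class ConstantModel
  def initialize(value)
    @value = value
  end

  def predict(x)
    x.map { @value }
  end

  # Accuracy: fraction of predictions equal to the true labels.
  def score(x, y)
    hits = predict(x).zip(y).count { |p, t| p == t }
    hits.fdiv(y.size)
  end
end

# Hypothetical wrapper that forwards calls to the stored best estimator.
class SearchWrapper
  attr_reader :best_estimator

  def initialize(best_estimator = nil)
    @best_estimator = best_estimator
  end

  %i[predict score].each do |name|
    define_method(name) do |*args|
      raise 'call fit before using the search object' if @best_estimator.nil?

      @best_estimator.public_send(name, *args)
    end
  end
end

search = SearchWrapper.new(ConstantModel.new(1))
p search.predict([10, 20, 30])           # => [1, 1, 1]
puts search.score([0, 0, 0], [1, 0, 1])  # => 0.6666666666666666
```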