Class: Rumale::Ensemble::VRTreesClassifier
- Inherits: RandomForestClassifier (hierarchy: Object > Base::Estimator > RandomForestClassifier > Rumale::Ensemble::VRTreesClassifier)
- Defined in: rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb
Overview
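VRTreesClassifier implements an ensemble of variable-random (VR) trees for classification. In the fit method, every tree is trained on the full training set, and the trees differ through their own random seed and their own randomness parameter alpha, spaced evenly over [0, 0.5).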
Reference
- Liu, F. T., Ting, K. M., Yu, Y., and Zhou, Z. H., “Spectrum of Variable-Random Trees,” Journal of Artificial Intelligence Research, vol. 32, pp. 355–384, 2008.
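To make the "variable-random" idea concrete, here is a minimal pure-Ruby sketch of how a single node might choose its split. It assumes that alpha is the probability of taking the deterministic (impurity-minimizing) split, so alpha = 1.0 behaves like an ordinary decision tree and alpha = 0.0 like a completely random tree. The helper names (gini, split_impurity, choose_split) are hypothetical, the deterministic branch ignores max_features, and the random branch's rule (random feature, uniform threshold within that feature's observed range) is an illustrative assumption, not Rumale's implementation.

```ruby
# Hypothetical sketch of one VR-tree split decision (not Rumale's actual code).

# Gini impurity of a list of class labels.
def gini(labels)
  return 0.0 if labels.empty?

  n = labels.size.to_f
  1.0 - labels.tally.values.sum { |count| (count / n)**2 }
end

# Weighted Gini impurity after splitting rows on row[feature] <= threshold.
def split_impurity(rows, labels, feature, threshold)
  left, right = rows.each_index.partition { |i| rows[i][feature] <= threshold }
  left_labels = left.map { |i| labels[i] }
  right_labels = right.map { |i| labels[i] }
  (left_labels.size * gini(left_labels) + right_labels.size * gini(right_labels)) / labels.size.to_f
end

# Returns [feature, threshold]. Assumption: alpha is the probability of a
# deterministic split; otherwise a random split is drawn.
def choose_split(rows, labels, alpha, rng)
  n_features = rows.first.size
  if rng.rand < alpha
    # Deterministic: try every (feature, observed value) pair, keep the purest split.
    candidates = (0...n_features).flat_map { |f| rows.map { |row| [f, row[f]] } }
    candidates.min_by { |f, t| split_impurity(rows, labels, f, t) }
  else
    # Random: pick a feature at random and a uniform threshold within its range.
    f = rng.rand(n_features)
    lo, hi = rows.map { |row| row[f] }.minmax
    [f, lo + rng.rand * (hi - lo)]
  end
end

rows = [[0.0, 5.0], [1.0, 5.0], [2.0, 5.0], [3.0, 5.0]]
labels = [0, 0, 1, 1]
p choose_split(rows, labels, 1.0, Random.new(1)) # => [0, 1.0]
p choose_split(rows, labels, 0.0, Random.new(1)) # random feature and threshold
```

Because rng.rand returns a value in [0, 1), alpha = 1.0 always takes the deterministic branch and alpha = 0.0 never does; intermediate values mix the two at every node.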
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly): Return the class labels.
- #estimators ⇒ Array<VRTreeClassifier> (readonly): Return the set of estimators.
- #feature_importances ⇒ Numo::DFloat (readonly): Return the importance for each feature.
- #rng ⇒ Random (readonly): Return the random generator for random selection of feature index.
Attributes inherited from Base::Estimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32: Return the index of the leaf that each sample reached.
- #fit(x, y) ⇒ VRTreesClassifier: Fit the model with the given training data.
- #initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, n_jobs: nil, random_seed: nil) ⇒ VRTreesClassifier (constructor): Create a new classifier with variable-random trees.
- #predict(x) ⇒ Numo::Int32: Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat: Predict class probabilities for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, n_jobs: nil, random_seed: nil) ⇒ VRTreesClassifier
Create a new classifier with variable-random trees.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 57
def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
               max_features: nil, n_jobs: nil, random_seed: nil)
  super
end
```
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 30
def classes
  @classes
end
```
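In the fit code on this page, @classes is built from y.to_a.uniq.sort, so the labels are stored in ascending order regardless of the order in which they appear in y. A small pure-Ruby sketch of that step follows, together with a hypothetical column_index helper that maps each label to a column position, under the common assumption that predict_proba's columns follow the order of classes.

```ruby
# Mirrors the class-label step in fit: unique labels in ascending order.
def class_labels(y)
  y.uniq.sort
end

# Hypothetical helper: label => column position, assuming probability
# columns follow the order of the class labels.
def column_index(classes)
  classes.each_with_index.to_h
end

classes = class_labels([3, 1, 3, 2])
p classes               # => [1, 2, 3]
p column_index(classes) # maps label 1 to column 0, 2 to 1, 3 to 2
```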
#estimators ⇒ Array<VRTreeClassifier> (readonly)
Return the set of estimators.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 26
def estimators
  @estimators
end
```
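Each estimator in this array is planted with its own alpha. The pure-Ruby sketch below reproduces the alpha schedule from fit (alpha_ratio = 0.5 / n_estimators, and tree n receives n * alpha_ratio); the method name alpha_schedule is hypothetical. With the default n_estimators: 10 the values run from 0.0 to 0.45, so no tree reaches 0.5. If alpha is read as the probability of a deterministic split (an assumption here), the ensemble leans toward random splits.

```ruby
# Mirrors the alpha schedule in VRTreesClassifier#fit.
def alpha_schedule(n_estimators)
  alpha_ratio = 0.5 / n_estimators
  Array.new(n_estimators) { |n| n * alpha_ratio }
end

alphas = alpha_schedule(10)
p alphas.map { |a| a.round(2) }
# => [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45]
```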
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 34
def feature_importances
  @feature_importances
end
```
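In fit, @feature_importances is the element-wise sum of the trees' feature_importances vectors, divided by its own total so that the entries sum to 1. The sketch below performs the same arithmetic in plain Ruby, using arrays in place of Numo::DFloat and made-up per-tree values; ensemble_importances is a hypothetical name.

```ruby
# Element-wise sum of per-tree importance vectors, normalized to sum to 1.
def ensemble_importances(per_tree)
  summed = per_tree.transpose.map(&:sum)
  total = summed.sum
  summed.map { |v| v / total }
end

# Hypothetical importances from two trees over two features.
p ensemble_importances([[1.0, 3.0], [2.0, 2.0]]) # => [0.375, 0.625]
```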
#rng ⇒ Random (readonly)
Return the random generator for random selection of feature index.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 38
def rng
  @rng
end
```
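In the fit code on this page, per-tree seeds are drawn from sub_rng = @rng.dup rather than from @rng directly. Within the code shown, @rng's own state therefore does not advance during fit, so refitting the same instance appears to reuse the same seeds. The sketch below illustrates that property with plain Ruby Random objects. HYPOTHETICAL_SEED_BASE stands in for ::Rumale::Ensemble::Value::SEED_BASE, whose actual value is not shown here, and tree_seeds is a hypothetical name.

```ruby
# Stand-in for ::Rumale::Ensemble::Value::SEED_BASE (value assumed).
HYPOTHETICAL_SEED_BASE = 2**31 - 1

# Draws one seed per tree from a copy of rng, leaving rng itself untouched.
def tree_seeds(rng, n_estimators)
  sub_rng = rng.dup
  Array.new(n_estimators) { sub_rng.rand(HYPOTHETICAL_SEED_BASE) }
end

rng = Random.new(42)
first = tree_seeds(rng, 5)
second = tree_seeds(rng, 5)
p first == second # => true, because rng itself was never advanced
```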
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 122
def apply(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  super
end
```
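apply validates x and then delegates to RandomForestClassifier#apply. Assuming the inherited method follows the usual forest convention of one column per tree (shape [n_samples, n_estimators]), the bookkeeping resembles the pure-Ruby sketch below. Here the fitted trees are replaced by hypothetical lambdas that map a sample to a leaf index, and apply_ensemble is a hypothetical name.

```ruby
# Hypothetical stand-ins for fitted trees: each maps a sample to a leaf index.
trees = [
  ->(sample) { sample[0] <= 0.5 ? 1 : 2 },
  ->(sample) { sample[1] <= 0.5 ? 1 : 2 }
]

# One row per sample, one column per tree.
def apply_ensemble(trees, samples)
  samples.map { |sample| trees.map { |tree| tree.call(sample) } }
end

p apply_ensemble(trees, [[0.0, 1.0], [1.0, 0.0]]) # => [[1, 2], [2, 1]]
```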
#fit(x, y) ⇒ VRTreesClassifier
Fit the model with the given training data.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 68
def fit(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)

  # Initialize some variables.
  n_features = x.shape[1]
  @params[:max_features] = n_features if @params[:max_features].nil?
  @params[:max_features] = @params[:max_features].clamp(1, n_features)
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  sub_rng = @rng.dup
  # Construct trees.
  rng_seeds = Array.new(@params[:n_estimators]) { sub_rng.rand(::Rumale::Ensemble::Value::SEED_BASE) }
  alpha_ratio = 0.5 / @params[:n_estimators]
  alphas = Array.new(@params[:n_estimators]) { |v| v * alpha_ratio }
  @estimators = if enable_parallel?
                  parallel_map(@params[:n_estimators]) { |n| plant_tree(alphas[n], rng_seeds[n]).fit(x, y) }
                else
                  Array.new(@params[:n_estimators]) { |n| plant_tree(alphas[n], rng_seeds[n]).fit(x, y) }
                end
  @feature_importances = if enable_parallel?
                           parallel_map(@params[:n_estimators]) { |n| @estimators[n].feature_importances }.sum
                         else
                           @estimators.sum(&:feature_importances)
                         end
  @feature_importances /= @feature_importances.sum
  self
end
```
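Two details of fit are easy to miss. First, max_features defaults to the number of features and is then clamped into 1..n_features, so out-of-range values are silently corrected rather than rejected. Second, every tree is fit on the full x and y (plant_tree(...).fit(x, y)), so in the code shown the ensemble's diversity comes from the per-tree alpha and seed, not from resampling the data. The first point can be sketched in plain Ruby; resolve_max_features is a hypothetical name.

```ruby
# Mirrors the max_features handling in fit.
def resolve_max_features(max_features, n_features)
  max_features = n_features if max_features.nil?
  max_features.clamp(1, n_features)
end

p resolve_max_features(nil, 4) # => 4
p resolve_max_features(10, 4)  # => 4
p resolve_max_features(0, 4)   # => 1
p resolve_max_features(2, 4)   # => 2
```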
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 102
def predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  super
end
```
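predict validates x and defers to RandomForestClassifier#predict, whose body is not shown here. Assuming the inherited method combines the trees by majority vote, the aggregation step can be sketched as follows. The method name majority_vote and the tie-breaking rule (the smallest label wins) are assumptions, and the sketch expects integer labels.

```ruby
# per_tree_predictions: one array of predicted integer labels per tree.
# Returns the most frequent label per sample; ties go to the smallest label.
def majority_vote(per_tree_predictions)
  per_tree_predictions.transpose.map do |votes|
    votes.tally.max_by { |label, count| [count, -label] }.first
  end
end

predictions = [
  [0, 1, 1], # tree 1
  [1, 1, 0], # tree 2
  [1, 0, 0]  # tree 3
]
p majority_vote(predictions) # => [1, 1, 0]
```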
#predict_proba(x) ⇒ Numo::DFloat
Predict class probabilities for samples.
```ruby
# File 'rumale-ensemble/lib/rumale/ensemble/vr_trees_classifier.rb', line 112
def predict_proba(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  super
end
```
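predict_proba likewise defers to RandomForestClassifier#predict_proba. One common way to combine trees, and the one assumed in this sketch, is to sum each tree's class probabilities in the order of classes and then renormalize each row. When every tree's row already sums to 1, this equals averaging. Per-tree outputs are modeled here as hypothetical {label => probability} hashes, and ensemble_proba is a hypothetical name.

```ruby
# classes:  ensemble class labels in ascending order.
# per_tree: one entry per tree; each entry is an array (one element per sample)
#           of {label => probability} hashes produced by that tree.
def ensemble_proba(classes, per_tree)
  n_samples = per_tree.first.size
  Array.new(n_samples) do |i|
    # Sum probabilities per class; a class missing from a tree's output counts as 0.
    summed = classes.map { |c| per_tree.sum { |tree| tree[i].fetch(c, 0.0) } }
    total = summed.sum
    summed.map { |v| v / total }
  end
end

tree_a = [{ 0 => 1.0 }, { 0 => 0.5, 1 => 0.5 }]
tree_b = [{ 0 => 0.5, 1 => 0.5 }, { 1 => 1.0 }]
p ensemble_proba([0, 1], [tree_a, tree_b]) # => [[0.75, 0.25], [0.25, 0.75]]
```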