Class: Rumale::Tree::BaseDecisionTree
- Inherits: Base::Estimator
  - Object
  - Base::Estimator
  - Rumale::Tree::BaseDecisionTree
- Defined in: rumale-tree/lib/rumale/tree/base_decision_tree.rb
Overview
BaseDecisionTree is an abstract class for implementing decision tree-based estimators. This class is used internally.
Direct Known Subclasses
Instance Attribute Summary
Attributes inherited from Base::Estimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32
  Return the index of the leaf that each sample reached.
- #initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ BaseDecisionTree
  (constructor) Initialize a decision tree-based estimator.
Constructor Details
#initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ BaseDecisionTree
Initialize a decision tree-based estimator.
    # File 'rumale-tree/lib/rumale/tree/base_decision_tree.rb', line 25
    def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
      super()
      @params = {
        criterion: criterion,
        max_depth: max_depth,
        max_leaf_nodes: max_leaf_nodes,
        min_samples_leaf: min_samples_leaf,
        max_features: max_features,
        # When random_seed is nil, Kernel#srand reseeds the global generator and returns the previous seed.
        random_seed: random_seed || srand
      }
      # Per-estimator random number generator, seeded from the stored seed.
      @rng = Random.new(@params[:random_seed])
    end
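The constructor only records the hyperparameters and creates a private random number generator. The pure-Ruby sketch below mirrors that logic in a hypothetical `SeededEstimatorSketch` class (not part of Rumale; the `sample_draws` helper is also an illustrative assumption) to show why passing the same `random_seed` makes runs reproducible, while leaving it `nil` falls back to the Integer returned by `Kernel#srand`.

```ruby
# Hypothetical class mirroring the BaseDecisionTree constructor; it is NOT part of Rumale.
class SeededEstimatorSketch
  attr_reader :params

  def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
                 max_features: nil, random_seed: nil)
    @params = {
      criterion: criterion,
      max_depth: max_depth,
      max_leaf_nodes: max_leaf_nodes,
      min_samples_leaf: min_samples_leaf,
      max_features: max_features,
      # Kernel#srand reseeds the global generator and returns the previous seed (an Integer).
      random_seed: random_seed || srand
    }
    # Each estimator owns its own generator, so its draws do not depend on other estimators.
    @rng = Random.new(@params[:random_seed])
  end

  # Hypothetical helper: draw n floats from the estimator's private generator.
  def sample_draws(n = 3)
    Array.new(n) { @rng.rand }
  end
end

a = SeededEstimatorSketch.new(max_depth: 3, random_seed: 42)
b = SeededEstimatorSketch.new(max_depth: 3, random_seed: 42)
p a.sample_draws == b.sample_draws             # → true
p SeededEstimatorSketch.new.params[:random_seed].class # → Integer
```

Because the fallback seed is stored in `params`, an estimator built without `random_seed` can still, in principle, be reproduced later by passing that stored value back in.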
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
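    # File 'rumale-tree/lib/rumale/tree/base_decision_tree.rb', line 43
    def apply(x)
      # Validate x and convert it to a 2-D sample array.
      x = ::Rumale::Validation.check_convert_sample_array(x)
      # Traverse the tree once per row (x[n, true] is row n) and collect the leaf indices.
      Numo::Int32[*(Array.new(x.shape[0]) { |n| partial_apply(@tree, x[n, true]) })]
    end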
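`#apply` delegates the per-sample work to a private `partial_apply` helper that walks the fitted tree from the root to a leaf. The sketch below illustrates that traversal under stated assumptions: the `SketchNode` struct, its field names, and the rule that a sample goes left when its feature value is less than or equal to the node's threshold are all hypothetical stand-ins, not Rumale's internal node class. It uses nested Ruby arrays instead of Numo arrays so it runs without Numo.

```ruby
# Hypothetical node type for illustration; Rumale's internal tree node may differ.
SketchNode = Struct.new(:leaf, :leaf_id, :feature_id, :threshold, :left, :right, keyword_init: true)

# Walk from the root to a leaf for one sample and return that leaf's id.
# Assumption: a sample goes left when its feature value is <= the node's threshold.
def leaf_index_for(node, sample)
  until node.leaf
    node = sample[node.feature_id] <= node.threshold ? node.left : node.right
  end
  node.leaf_id
end

# Plain-Ruby analogue of #apply: one leaf id per row of a 2-D array.
def apply_sketch(tree, x)
  x.map { |row| leaf_index_for(tree, row) }
end

# Depth-2 tree: split on feature 0 at 0.5, then on feature 1 at 2.0 in the right branch.
tree = SketchNode.new(
  leaf: false, feature_id: 0, threshold: 0.5,
  left: SketchNode.new(leaf: true, leaf_id: 0),
  right: SketchNode.new(
    leaf: false, feature_id: 1, threshold: 2.0,
    left: SketchNode.new(leaf: true, leaf_id: 1),
    right: SketchNode.new(leaf: true, leaf_id: 2)
  )
)

x = [[0.1, 5.0], [0.9, 1.0], [0.9, 3.0]]
p apply_sketch(tree, x) # → [0, 1, 2]
```

Each row is routed independently, so the cost of this traversal grows with the number of samples times the tree depth.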