Class: Rumale::Tree::DecisionTreeClassifier
- Inherits: Object > Base::Estimator > BaseDecisionTree > Rumale::Tree::DecisionTreeClassifier
- Includes: Base::Classifier, ExtDecisionTreeClassifier
- Defined in: rumale-tree/lib/rumale/tree/decision_tree_classifier.rb
Overview
DecisionTreeClassifier is a class that implements a decision tree for classification.
Direct Known Subclasses
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #feature_importances ⇒ Numo::DFloat (readonly)
  Return the importance for each feature.
- #leaf_labels ⇒ Numo::Int32 (readonly)
  Return the labels assigned to each leaf.
- #rng ⇒ Random (readonly)
  Return the random generator for random selection of feature index.
- #tree ⇒ Node (readonly)
  Return the learned tree.
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x, y) ⇒ DecisionTreeClassifier
  Fit the model with given training data.
- #initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier (constructor)
  Create a new classifier with decision tree algorithm.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat
  Predict probability for samples.
Methods included from Base::Classifier
Methods inherited from BaseDecisionTree
Constructor Details
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
Create a new classifier with decision tree algorithm.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 55
    def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                   random_seed: nil)
      super
    end
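The default `criterion: 'gini'` refers to Gini impurity. As a rough illustration of that measure, here is a plain-Ruby sketch. The helper names `gini_impurity` and `split_impurity` are made up for this example, and the weighted-average split score is a common CART-style convention rather than something this page confirms about Rumale's internals.

```ruby
# Gini impurity of a list of class labels: 1 - sum over classes of p_k^2.
# Illustration only; this is not the library's internal code.
def gini_impurity(labels)
  return 0.0 if labels.empty?

  n = labels.size.to_f
  1.0 - labels.tally.values.sum(0.0) { |count| (count / n)**2 }
end

# Hypothetical helper: size-weighted impurity of a two-way split.
def split_impurity(left, right)
  n = (left.size + right.size).to_f
  (left.size / n) * gini_impurity(left) + (right.size / n) * gini_impurity(right)
end

puts gini_impurity([1, 1, 1, 1])       # a pure node has impurity 0.0
puts gini_impurity([0, 0, 1, 1])       # an evenly mixed two-class node has impurity 0.5
puts split_impurity([0, 0], [1, 1])    # a split that separates the classes perfectly scores 0.0
```

Lower values mean purer nodes, so a split search would prefer the candidate with the smallest weighted impurity.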
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 25
    def classes
      @classes
    end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 29
    def feature_importances
      @feature_importances
    end
#leaf_labels ⇒ Numo::Int32 (readonly)
Return the labels assigned to each leaf.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 41
    def leaf_labels
      @leaf_labels
    end
#rng ⇒ Random (readonly)
Return the random generator for random selection of feature index.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 37
    def rng
      @rng
    end
#tree ⇒ Node (readonly)
Return the learned tree.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 33
    def tree
      @tree
    end
Instance Method Details
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 65
    def fit(x, y)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      y = ::Rumale::Validation.check_convert_label_array(y)
      ::Rumale::Validation.check_sample_size(x, y)

      n_samples, n_features = x.shape
      @params[:max_features] = n_features if @params[:max_features].nil?
      @params[:max_features] = [@params[:max_features], n_features].min
      y = Numo::Int32.cast(y) unless y.is_a?(Numo::Int32)
      uniq_y = y.to_a.uniq.sort
      @classes = Numo::Int32.asarray(uniq_y)
      @n_leaves = 0
      @leaf_labels = []
      @feature_ids = Array.new(n_features) { |v| v }
      @sub_rng = @rng.dup
      build_tree(x, y.map { |v| uniq_y.index(v) })
      eval_importance(n_samples, n_features)
      @leaf_labels = Numo::Int32[*@leaf_labels]
      self
    end
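Two bookkeeping steps in the listing above are easy to miss: `max_features` is capped at the number of features, and the raw labels are mapped to contiguous indices `0...n_classes` before the tree is built. The sketch below reproduces both with plain Ruby arrays instead of Numo. The helper names `encode_labels` and `effective_max_features` are hypothetical, and the hash lookup is a swap-in for the `uniq_y.index(v)` call in the source.

```ruby
# Map arbitrary integer labels to indices into the sorted list of unique labels.
# Returns [classes, encoded], mirroring @classes and the array passed to build_tree.
def encode_labels(y)
  classes = y.uniq.sort
  index_of = classes.each_with_index.to_h # { label => position in classes }
  [classes, y.map { |v| index_of[v] }]
end

# nil means "use every feature"; larger values are capped at n_features.
def effective_max_features(max_features, n_features)
  [max_features || n_features, n_features].min
end

classes, encoded = encode_labels([10, 3, 10, 7])
p classes                             # sorted unique labels: [3, 7, 10]
p encoded                             # index of each label:  [2, 0, 2, 1]
p effective_max_features(nil, 4)      # 4
p effective_max_features(10, 4)       # 4
p effective_max_features(2, 4)        # 2
```

Because the tree is trained on the encoded indices, the original label for index `i` can be recovered as `classes[i]`.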
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 90
    def predict(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)

      @leaf_labels[apply(x)].dup
    end
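Prediction here is a table lookup: each sample is routed to a leaf, and the label stored for that leaf is returned. `apply` is not defined on this page; it is presumably inherited from BaseDecisionTree, and the sketch below assumes it yields one leaf index per sample. The helper `labels_for_leaves` and all values are made up for illustration, with plain arrays standing in for Numo::Int32.

```ruby
# Look up the stored label for each sample's leaf.
# leaf_labels[i] is the label assigned to leaf i; leaf_ids[n] is the leaf reached by sample n.
def labels_for_leaves(leaf_labels, leaf_ids)
  leaf_labels.values_at(*leaf_ids)
end

leaf_labels = [2, 0, 1] # labels for leaves 0, 1, 2 (made-up values)
leaf_ids = [1, 1, 0, 2] # leaf reached by each of four samples (made-up values)

p labels_for_leaves(leaf_labels, leaf_ids) # [0, 0, 2, 1]
```

`values_at` builds a new array, which loosely corresponds to the `.dup` in the listing: callers get a copy rather than a view into the model's stored labels.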
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
    # File 'rumale-tree/lib/rumale/tree/decision_tree_classifier.rb', line 100
    def predict_proba(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)

      Numo::DFloat[*Array.new(x.shape[0]) { |n| partial_predict_proba(@tree, x[n, true]) }]
    end
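The body of `partial_predict_proba` is not shown on this page. A plausible reading is that it walks the tree for one sample and returns the class probabilities stored at the leaf it reaches. The sketch below illustrates that idea under stated assumptions: `SketchNode` is a hypothetical simplification of the library's Node, and sending samples left when `value <= threshold` is an assumed convention.

```ruby
# Hypothetical, simplified tree node; not the library's Node class.
SketchNode = Struct.new(:feature_id, :threshold, :left, :right, :probs, keyword_init: true) do
  def leaf?
    left.nil? && right.nil?
  end
end

# Walk from the given node to a leaf and return that leaf's class probabilities.
def partial_predict_proba_sketch(node, sample)
  return node.probs if node.leaf?

  # Assumed convention: values at or below the threshold go left.
  child = sample[node.feature_id] <= node.threshold ? node.left : node.right
  partial_predict_proba_sketch(child, sample)
end

# One probability row per sample, analogous to the Array.new(...) call in the listing.
def predict_proba_sketch(tree, samples)
  samples.map { |sample| partial_predict_proba_sketch(tree, sample) }
end

tree = SketchNode.new(
  feature_id: 0, threshold: 0.5,
  left: SketchNode.new(probs: [0.9, 0.1]),
  right: SketchNode.new(probs: [0.2, 0.8])
)

p predict_proba_sketch(tree, [[0.1, 3.0], [0.7, 1.0]]) # [[0.9, 0.1], [0.2, 0.8]]
```

Each row has one entry per class, in the same order as the `classes` attribute in the real classifier.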