Class: Rumale::NaiveBayes::BaseNaiveBayes

Inherits:
Base::Estimator
Includes:
Base::Classifier
Defined in:
rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb

Overview

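BaseNaiveBayes is a base class that provides the prediction methods shared by the naive Bayes classifiers. It does not define decision_function or @classes itself; the concrete classifiers that inherit from it are expected to supply both. This class is used internally.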

Instance Attribute Summary

Attributes inherited from Base::Estimator

#params

Instance Method Summary

Methods included from Base::Classifier

#fit, #score

Constructor Details

#initialize ⇒ BaseNaiveBayes

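Returns a new instance of BaseNaiveBayes. The constructor only calls super; the trailing rubocop:disable Lint/UselessMethodDefinition comment silences RuboCop's warning about a method that does nothing but delegate to its parent.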



# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 14

def initialize # rubocop:disable Lint/UselessMethodDefinition
  super
end

Instance Method Details

#predict(x) ⇒ Numo::Int32

Predict class labels for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples whose labels are to be predicted.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted class label per sample.



# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 22

def predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  n_samples = x.shape.first
  decision_values = decision_function(x)
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
end
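BaseNaiveBayes#predict never computes scores itself: it calls decision_function, takes the column with the largest value in each row, and maps that column index through @classes to the original label. The sketch below imitates that template-method arrangement in plain Ruby, under stated assumptions: ToyBaseNaiveBayes and ToyFixedNaiveBayes are hypothetical classes, nested Arrays stand in for Numo arrays, and the subclass returns fixed, precomputed scores instead of fitting anything.

```ruby
# Hypothetical classes sketching the pattern behind BaseNaiveBayes#predict.
# Nested Arrays stand in for Numo::DFloat / Numo::Int32.
class ToyBaseNaiveBayes
  # For each sample (row), find the column with the largest decision value
  # and map that column index to the stored class label.
  def predict(x)
    decision_function(x).map { |row| @classes[row.index(row.max)] }
  end

  # Subclasses must return an [n_samples][n_classes] array of scores.
  def decision_function(_x)
    raise NotImplementedError, "#{self.class} must implement decision_function"
  end
end

# Hypothetical subclass that returns precomputed joint log-likelihoods
# instead of learning them from data.
class ToyFixedNaiveBayes < ToyBaseNaiveBayes
  def initialize(classes, scores)
    super()
    @classes = classes # labels need not be 0...n_classes
    @scores = scores
  end

  def decision_function(_x)
    @scores
  end
end

model = ToyFixedNaiveBayes.new([-1, 1], [[-0.2, -1.5], [-3.0, -0.1]])
p model.predict(nil) # prints [-1, 1]
```

Because the winning index is looked up in @classes, labels such as -1 and 1 come back unchanged rather than as column positions 0 and 1.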

#predict_log_proba(x) ⇒ Numo::DFloat

Predict log-probability for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples whose log-probabilities are to be predicted.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.



# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 34

def predict_log_proba(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  n_samples, = x.shape
  log_likelihoods = decision_function(x)
  log_likelihoods - Numo::NMath.log(Numo::NMath.exp(log_likelihoods).sum(axis: 1)).reshape(n_samples, 1)
end
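predict_log_proba turns the joint log-likelihoods L_c returned by decision_function into log-posteriors with log P(c | x) = L_c - log(sum over k of exp(L_k)). The source above evaluates exp directly, so if every L_k in a row is very negative (for example around -1000) each exp underflows to 0.0 and the normalizer becomes log(0) = -Infinity. A common remedy, shown here as a swapped-in technique rather than what the library does, is the log-sum-exp trick: subtract the row maximum m before exponentiating and add it back afterwards. The method names below are hypothetical, and plain Arrays replace Numo::DFloat.

```ruby
# Normalizes one row of joint log-likelihoods into log-posteriors using the
# log-sum-exp trick: log(sum(exp(l))) = m + log(sum(exp(l - m))), with m = max(l).
# This is a numerically stable variant, not a copy of Rumale's implementation.
def normalize_log_likelihoods(row)
  m = row.max
  log_norm = m + Math.log(row.sum { |l| Math.exp(l - m) })
  row.map { |l| l - log_norm }
end

# Applies the normalization to every sample (row) of an [n_samples][n_classes] array.
def predict_log_proba_sketch(log_likelihoods)
  log_likelihoods.map { |row| normalize_log_likelihoods(row) }
end

# exp(-1000.0) underflows to 0.0, but the shifted version stays finite.
p predict_log_proba_sketch([[-1000.0, -1000.0]]) # both entries are approximately -0.6931 (log 0.5)
```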

#predict_proba(x) ⇒ Numo::DFloat

Predict probability for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples whose probabilities are to be predicted.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted probability of each class per sample.



# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 46

def predict_proba(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  Numo::NMath.exp(predict_log_proba(x)).abs
end
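predict_proba exponentiates the output of predict_log_proba, so each row of the result is a probability distribution over the classes that sums to 1 up to rounding. A minimal plain-Ruby sketch of that last step, using a hypothetical predict_proba_sketch helper and hand-written log-probabilities in place of a fitted model:

```ruby
# Converts log-posteriors ([n_samples][n_classes]) into probabilities by
# exponentiating element-wise, mirroring Numo::NMath.exp(predict_log_proba(x)).
def predict_proba_sketch(log_proba)
  log_proba.map { |row| row.map { |v| Math.exp(v) } }
end

log_proba = [[Math.log(0.25), Math.log(0.75)], [Math.log(0.9), Math.log(0.1)]]
proba = predict_proba_sketch(log_proba)
proba.each { |row| puts row.sum.round(12) } # prints 1.0 twice
```

Since exp is monotonic and predict_log_proba only subtracts a per-row constant from the decision values, the largest probability in each row sits in the same column that predict selects.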