Class: Rumale::NaiveBayes::BaseNaiveBayes
- Inherits: Base::Estimator (hierarchy: Object > Base::Estimator > Rumale::NaiveBayes::BaseNaiveBayes)
- Includes: Base::Classifier
- Defined in: rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb
Overview
BaseNaiveBayes is an internal base class that implements the prediction methods shared by the naive Bayes classifiers. Subclasses supply the decision_function method and the @classes array that these methods rely on.
Direct Known Subclasses
BernoulliNB, ComplementNB, GaussianNB, MultinomialNB, NegationNB
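The method listings on this page call decision_function and read @classes without defining either, so each subclass is expected to provide both. That split of responsibilities can be sketched in plain Ruby as follows. ToyBaseNaiveBayes and ToyPriorOnlyNB are hypothetical, dependency-free stand-ins (nested Arrays instead of Numo arrays, and per-class log priors instead of a real likelihood model); this is not Rumale's code.

```ruby
# Hypothetical sketch of the base/subclass split; not Rumale's implementation.
class ToyBaseNaiveBayes
  # Shared logic: for each sample, pick the class with the largest decision value.
  # Relies on the subclass to define #decision_function and to set @classes.
  def predict(x)
    decision_function(x).map do |row|
      best = row.each_index.max_by { |j| row[j] }
      @classes[best]
    end
  end
end

# Hypothetical subclass that scores every sample with the same per-class log priors.
class ToyPriorOnlyNB < ToyBaseNaiveBayes
  def fit(_x, y)
    @classes = y.uniq.sort
    n = y.size.to_f
    @log_priors = @classes.map { |c| Math.log(y.count(c) / n) }
    self
  end

  # Returns one row per sample and one column per class, ordered like @classes.
  def decision_function(x)
    x.map { @log_priors.dup }
  end
end

model = ToyPriorOnlyNB.new.fit([[0], [1], [2]], [1, 1, 2])
p model.predict([[5], [6]]) # => [1, 1]
```

Because predict only depends on the decision_function contract, each real subclass can change how the scores are computed without touching the shared prediction code.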
Instance Attribute Summary
Attributes inherited from Base::Estimator
Instance Method Summary
- #initialize ⇒ BaseNaiveBayes
  Constructor; delegates to the superclass constructor via super.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #predict_log_proba(x) ⇒ Numo::DFloat
  Predict log-probability for samples.
- #predict_proba(x) ⇒ Numo::DFloat
  Predict probability for samples.
Methods included from Base::Classifier
Constructor Details
#initialize ⇒ BaseNaiveBayes
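Creates a new BaseNaiveBayes. The constructor only calls super; the inline rubocop:disable Lint/UselessMethodDefinition directive in the source silences RuboCop's warning about this otherwise redundant override.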
# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 14
def initialize # rubocop:disable Lint/UselessMethodDefinition
  super
end
Instance Method Details
#predict(x) ⇒ Numo::Int32
Predict class labels for samples. Each sample receives the class whose decision value from decision_function is largest.
# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 22
def predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  n_samples = x.shape.first
  decision_values = decision_function(x)
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
end
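The key step in predict is that max_index returns a column position, which is then looked up in @classes to recover the original label. A minimal plain-Ruby sketch of that lookup, assuming nested Arrays in place of Numo arrays; predict_labels is a hypothetical helper, not part of Rumale:

```ruby
# Hypothetical helper mirroring the label lookup in #predict.
# decision_values: one row per sample, one column per class, ordered like `classes`.
def predict_labels(decision_values, classes)
  decision_values.map do |row|
    best = row.each_index.max_by { |j| row[j] } # column of the largest score
    classes[best]                               # map column index back to the label
  end
end

classes = [3, 7, 9]
scores = [
  [-2.0, -0.5, -1.0],
  [-0.1, -3.0, -2.5]
]
p predict_labels(scores, classes) # => [7, 3]
```

The indirection through classes matters whenever labels are not the integers 0 to n_classes - 1, as in this example.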
#predict_log_proba(x) ⇒ Numo::DFloat
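Predict log-probability for samples. The log-likelihoods from decision_function are normalized per sample by subtracting the logarithm of the sum of their exponentials, so that the exponentiated values in each row sum to 1 over the classes.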
# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 34
def predict_log_proba(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  n_samples, = x.shape
  log_likelihoods = decision_function(x)
  log_likelihoods - Numo::NMath.log(Numo::NMath.exp(log_likelihoods).sum(axis: 1)).reshape(n_samples, 1)
end
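The normalization in predict_log_proba is log p(c | x) = ll_c - log(sum_k exp(ll_k)). The listing exponentiates the raw log-likelihoods directly, so very negative values can underflow to 0. The sketch below, assuming nested Arrays in place of Numo arrays, contrasts that direct form with the common max-shifted (log-sum-exp) variant. Both helper names are hypothetical, and the stable variant is an alternative technique, not what the listing does.

```ruby
# Direct form, mirroring the listing: exponentiate, sum, take the log.
def log_proba_direct(log_likelihoods)
  log_likelihoods.map do |row|
    log_norm = Math.log(row.sum { |v| Math.exp(v) })
    row.map { |v| v - log_norm }
  end
end

# Max-shifted variant (not the listed code): subtracting the row maximum makes the
# largest term exp(0) = 1, so the sum cannot underflow to 0.
def log_proba_stable(log_likelihoods)
  log_likelihoods.map do |row|
    m = row.max
    log_norm = m + Math.log(row.sum { |v| Math.exp(v - m) })
    row.map { |v| v - log_norm }
  end
end

ll = [[-1.0, -2.0], [-1000.0, -1001.0]]
p log_proba_direct(ll) # second row: exp underflows to 0, giving [Infinity, Infinity]
p log_proba_stable(ll) # second row: about [-0.3133, -1.3133]
```

For moderate inputs such as the first row, both forms agree; they only diverge when every log-likelihood in a row is far below zero.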
#predict_proba(x) ⇒ Numo::DFloat
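Predict probability for samples, computed by exponentiating the output of predict_log_proba; each row sums to 1 over the classes (up to floating-point rounding).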
# File 'rumale-naive_bayes/lib/rumale/naive_bayes/base_naive_bayes.rb', line 46
def predict_proba(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  Numo::NMath.exp(predict_log_proba(x)).abs
end
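predict_proba only maps log-probabilities back to probabilities. A minimal plain-Ruby sketch, assuming nested Arrays in place of Numo arrays; proba_from_log_proba is a hypothetical helper. Since Math.exp never returns a negative number, the trailing .abs in the listing does not change the values.

```ruby
# Hypothetical helper mirroring #predict_proba: exponentiate each log-probability.
def proba_from_log_proba(log_proba)
  log_proba.map { |row| row.map { |v| Math.exp(v) } }
end

log_proba = [[Math.log(0.25), Math.log(0.75)], [Math.log(0.9), Math.log(0.1)]]
probs = proba_from_log_proba(log_proba)
probs.each { |row| puts row.sum.round(12) } # each row sums to 1
```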