Module: Rumale::Base::Classifier

Overview

Module for all classifiers in Rumale.

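Instance Method Summary

  • #fit ⇒ Object: An abstract method for fitting a model.
  • #predict ⇒ Object: An abstract method for predicting labels.
  • #score(x, y) ⇒ Float: Calculate the mean accuracy of the given testing data.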

Instance Method Details

#fit ⇒ Object

An abstract method for fitting a model.

Raises:

  • (NotImplementedError)

# File 'rumale-core/lib/rumale/base/classifier.rb', line 12

def fit
  raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
end
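
The abstract method uses `__method__` and `self.class` to name both the missing method and the concrete class in the error message. The sketch below reproduces that pattern in plain Ruby without loading Rumale. `AbstractClassifierSketch` and `IncompleteClassifier` are hypothetical names. They stand in for `Rumale::Base::Classifier` and for an including class that forgets to override `fit`.

```ruby
# Hypothetical stand-in that mirrors the abstract methods documented here.
module AbstractClassifierSketch
  def fit
    raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
  end

  def predict
    raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
  end
end

# Hypothetical including class that does not override fit.
class IncompleteClassifier
  include AbstractClassifierSketch
end

begin
  IncompleteClassifier.new.fit
rescue NotImplementedError => e
  puts e.message # prints "fit has to be implemented in IncompleteClassifier."
end
```

`NotImplementedError` inherits from `ScriptError`, not `StandardError`. A bare `rescue` clause therefore does not catch it, so the error class has to be named explicitly.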

#predict ⇒ Object

An abstract method for predicting labels.

Raises:

  • (NotImplementedError)

# File 'rumale-core/lib/rumale/base/classifier.rb', line 17

def predict
  raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
end
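
A concrete classifier satisfies this interface by overriding both abstract methods. The sketch below is a toy majority-class classifier. `AbstractClassifierSketch` and `MajorityClassifier` are hypothetical. Plain Ruby arrays replace the Numo arrays that real Rumale estimators work with, which keeps the example self-contained. Returning `self` from `fit` is an assumption here; it lets calls be chained.

```ruby
# Hypothetical stand-in for the abstract interface documented here.
module AbstractClassifierSketch
  def fit
    raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
  end

  def predict
    raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
  end
end

# Toy classifier that always predicts the most frequent training label.
class MajorityClassifier
  include AbstractClassifierSketch

  # Record the most frequent label in y; the features x are accepted but unused.
  def fit(_x, y)
    @majority = y.tally.max_by { |_label, count| count }.first
    self
  end

  # Return the stored label once for each sample (row) in x.
  def predict(x)
    Array.new(x.size, @majority)
  end
end

clf = MajorityClassifier.new.fit([[1.0], [2.0], [3.0]], [0, 1, 1])
p clf.predict([[4.0], [5.0]]) # prints [1, 1]
```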

#score(x, y) ⇒ Float

Calculate the mean accuracy of the given testing data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) Testing data.

  • y (Numo::Int32)

    (shape: [n_samples]) True labels for testing data.

Returns:

  • (Float)

    Mean accuracy: the fraction of samples whose predicted label equals the true label, in the range [0, 1].

# File 'rumale-core/lib/rumale/base/classifier.rb', line 26

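def score(x, y)
  # Validate the inputs and convert them to Numo arrays.
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  # Ensure x and y contain the same number of samples.
  ::Rumale::Validation.check_sample_size(x, y)

  predicted = predict(x)
  # Count correct predictions and divide by the number of samples.
  (y.to_a.map.with_index { |label, n| label == predicted[n] ? 1 : 0 }).sum.fdiv(y.size)
end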
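
The final expression in `score` counts positions where the true label matches the prediction and divides that count by the number of samples. The sketch below isolates that arithmetic on plain Ruby arrays. It mirrors the `y.to_a` conversion in `score` and does not use Numo or Rumale. `mean_accuracy` is a hypothetical helper, and its size check stands in for `check_sample_size`.

```ruby
# Fraction of positions where the true label equals the predicted label,
# computed the same way as the last line of score.
def mean_accuracy(y, predicted)
  raise ArgumentError, "y and predicted must have the same size" unless y.size == predicted.size

  (y.map.with_index { |label, n| label == predicted[n] ? 1 : 0 }).sum.fdiv(y.size)
end

y_true = [0, 1, 1, 0]
y_pred = [0, 1, 0, 0]
puts mean_accuracy(y_true, y_pred) # prints 0.75 (3 of 4 labels match)
```

An equivalent formulation is `y.zip(predicted).count { |a, b| a == b }.fdiv(y.size)`. `fdiv` guarantees a Float result even though the match count is an Integer.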