Module: Rumale::Base::ClusterAnalyzer

Overview

Module for all clustering algorithms in Rumale.

Instance Method Details

#fit_predict ⇒ Object

An abstract method for analyzing clusters and predicting cluster indices.

Raises:

  • (NotImplementedError)


# File 'rumale-core/lib/rumale/base/cluster_analyzer.rb', line 10

def fit_predict
  raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
end
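
The behavior of the abstract method can be sketched without Rumale itself. The example below is only an illustration: `AbstractAnalyzer`, `IncompleteClusterer`, `SingleClusterer`, and `abstract_call_message` are hypothetical names, and the stand-in module merely copies the method body shown above. It shows what happens when an including class does or does not override `fit_predict`.

```ruby
# Stand-in for the module above (hypothetical name, same method body).
module AbstractAnalyzer
  def fit_predict
    raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
  end
end

# Hypothetical class that includes the module but does not override fit_predict.
class IncompleteClusterer
  include AbstractAnalyzer
end

# Hypothetical class that overrides fit_predict.
# It trivially assigns every sample to cluster 0.
class SingleClusterer
  include AbstractAnalyzer

  def fit_predict(x)
    Array.new(x.size, 0)
  end
end

# Calls the un-overridden method and returns the error message.
# NotImplementedError descends from ScriptError, not StandardError,
# so it has to be rescued explicitly; a bare `rescue` would not catch it.
def abstract_call_message
  IncompleteClusterer.new.fit_predict
rescue NotImplementedError => e
  e.message
end

puts abstract_call_message
p SingleClusterer.new.fit_predict([[1.0, 2.0], [3.0, 4.0]])
```

The error message names the including class, which makes a missing override easy to locate.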

#score(x, y) ⇒ Float

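Calculate the purity of the clustering result. The method calls fit_predict on x and compares the predicted cluster indices with the true labels y.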

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) Testing data.

  • y (Numo::Int32)

    (shape: [n_samples]) True labels for testing data.

Returns:

  • (Float)

    Purity



# File 'rumale-core/lib/rumale/base/cluster_analyzer.rb', line 19

def score(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)

  predicted = fit_predict(x)
  cluster_ids = predicted.to_a.uniq
  class_ids = y.to_a.uniq
  cluster_ids.sum do |k|
    pr_sample_ids = predicted.eq(k).where.to_a
    class_ids.map { |j| (pr_sample_ids & y.eq(j).where.to_a).size }.max
  end.fdiv(y.size)
end
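
To make the purity computation concrete, here is a minimal sketch of the same counting logic on plain Ruby arrays instead of Numo arrays. The `purity` helper is a hypothetical name, not part of Rumale, and it takes already-predicted cluster indices rather than calling `fit_predict`. For each cluster it counts the most frequent true label, sums those counts over all clusters, and divides by the number of samples. `Array#tally` requires Ruby 2.7 or later.

```ruby
# Purity on plain arrays (hypothetical helper, not Rumale's API).
# predicted: cluster index per sample; y: true label per sample.
def purity(predicted, y)
  raise ArgumentError, 'size mismatch' unless predicted.size == y.size

  predicted.uniq.sum do |k|
    # True labels of the samples assigned to cluster k.
    labels_in_cluster = y.each_index.select { |i| predicted[i] == k }.map { |i| y[i] }
    # Size of the largest single-label group inside this cluster.
    labels_in_cluster.tally.values.max
  end.fdiv(y.size)
end

predicted = [0, 0, 0, 1, 1, 1]
y         = [0, 0, 1, 1, 1, 1]

# Cluster 0 holds labels [0, 0, 1] (majority count 2),
# cluster 1 holds labels [1, 1, 1] (majority count 3),
# so 5 of the 6 samples agree with their cluster's majority label.
puts purity(predicted, y)
```

Purity lies in (0, 1], and it reaches 1.0 whenever every cluster contains a single class. That includes the degenerate case of one cluster per sample, so it is best compared between clusterings with the same number of clusters.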