Class: Rumale::Clustering::MiniBatchKMeans

Inherits:
Base::Estimator
Includes:
Base::ClusterAnalyzer
Defined in:
rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb

Overview

MiniBatchKMeans is a class that implements K-Means cluster analysis with mini-batch stochastic gradient descent (SGD).

Reference

  • Sculley, D., “Web-scale k-means clustering,” Proc. WWW’10, pp. 1177–1178, 2010.

Examples:

require 'rumale/clustering/mini_batch_k_means'

analyzer = Rumale::Clustering::MiniBatchKMeans.new(n_clusters: 10, max_iter: 50, batch_size: 50, random_seed: 1)
cluster_labels = analyzer.fit_predict(samples)

Instance Attribute Summary

Attributes inherited from Base::Estimator

#params

Instance Method Summary

Methods included from Base::ClusterAnalyzer

#score

Constructor Details

#initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil) ⇒ MiniBatchKMeans

Create a new cluster analyzer that runs K-Means with mini-batch SGD.

Parameters:

  • n_clusters (Integer) (defaults to: 8)

    The number of clusters.

  • init (String) (defaults to: 'k-means++')

    The initialization method for centroids (‘random’ or ‘k-means++’). Any other value is treated as ‘k-means++’.

  • max_iter (Integer) (defaults to: 100)

    The maximum number of iterations, where each iteration is one pass over the shuffled training data in mini-batches.

  • batch_size (Integer) (defaults to: 100)

    The size of the mini batches.

  • tol (Float) (defaults to: 1.0e-4)

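    The tolerance for the termination criterion. Iteration stops once the mean Euclidean distance the centroids move during one iteration is at most this value.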

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator. If nil, a seed is taken from Kernel#srand.



# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 40

def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
  super()
  @params = {
    n_clusters: n_clusters,
    init: (init == 'random' ? 'random' : 'k-means++'),
    max_iter: max_iter,
    batch_size: batch_size,
    tol: tol,
    random_seed: random_seed || srand
  }
  @rng = Random.new(@params[:random_seed])
end
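
The seed handling above can be illustrated with Ruby's standard library alone. The following is a minimal sketch, not Rumale code: `build_rng` is a hypothetical helper that mirrors the `random_seed || srand` fallback, and the shuffles stand in for the sample shuffling that #fit performs on a duplicate of the generator.

```ruby
# Hypothetical stand-in for the constructor's seed handling.
# It uses only the Ruby standard library, not Rumale itself.
def build_rng(random_seed = nil)
  # Kernel#srand reseeds the global generator and returns the previous seed,
  # which serves as the default when no seed is given.
  seed = random_seed || srand
  [seed, Random.new(seed)]
end

seed_a, rng_a = build_rng(1)
_seed_b, rng_b = build_rng(1)

# Two generators built from the same seed yield the same shuffle order.
order_a = (0...10).to_a.shuffle(random: rng_a.dup)
order_b = (0...10).to_a.shuffle(random: rng_b.dup)

# Shuffling with a duplicate leaves the original generator untouched,
# so a second duplicate replays the same sequence.
order_a_again = (0...10).to_a.shuffle(random: rng_a.dup)
```

Because #fit works on `@rng.dup`, this suggests that repeated calls to #fit on the same analyzer with the same data should start from the same random state.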

Instance Attribute Details

#cluster_centers ⇒ Numo::DFloat (readonly)

Return the centroids.

Returns:

  • (Numo::DFloat)

    (shape: [n_clusters, n_features])



# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 26

def cluster_centers
  @cluster_centers
end

#rng ⇒ Random (readonly)

Return the random generator.

Returns:

  • (Random)


# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 30

def rng
  @rng
end

Instance Method Details

#fit(x) ⇒ MiniBatchKMeans

Analyze clusters with the given training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for cluster analysis.

Returns:

  • (MiniBatchKMeans)

    The learned cluster analyzer itself.



# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 58

def fit(x, _y = nil)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  # initialization.
  n_samples = x.shape[0]
  update_counter = Numo::Int32.zeros(@params[:n_clusters])
  sub_rng = @rng.dup
  init_cluster_centers(x, sub_rng)
  # optimization with mini-batch sgd.
  @params[:max_iter].times do |_t|
    sample_ids = Array(0...n_samples).shuffle(random: sub_rng)
    old_centers = @cluster_centers.dup
    until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
      # sub sampling
      sub_x = x[subset_ids, true]
      # assign nearest centroids
      cluster_labels = assign_cluster(sub_x)
      # update centroids
      @params[:n_clusters].times do |c|
        assigned_bits = cluster_labels.eq(c)
        next unless assigned_bits.count.positive?

        update_counter[c] += 1
        learning_rate = 1.fdiv(update_counter[c])
        update = sub_x[assigned_bits.where, true].mean(axis: 0)
        @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
      end
    end
    error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
    break if error <= @params[:tol]
  end
  self
end
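
The centroid update and the stopping test in the source above can be sketched in plain Ruby, with Arrays standing in for Numo::DFloat. The helper names `update_center` and `mean_center_shift` are hypothetical and exist only for illustration. Each cluster keeps its own update counter, so its learning rate shrinks as 1/count.

```ruby
# Plain-Ruby sketch of one centroid update, using Arrays instead of Numo.
# update_center is a hypothetical helper, not part of Rumale.
def update_center(center, batch_points, update_count)
  learning_rate = 1.fdiv(update_count)
  n = batch_points.size.to_f
  # Mean of the mini-batch points assigned to this cluster.
  batch_mean = batch_points.transpose.map { |column| column.sum / n }
  # Convex combination of the old center and the batch mean.
  center.zip(batch_mean).map do |old, new|
    (1 - learning_rate) * old + learning_rate * new
  end
end

# Mean Euclidean distance the centers moved; #fit compares this with tol.
def mean_center_shift(old_centers, new_centers)
  shifts = old_centers.zip(new_centers).map do |old, new|
    Math.sqrt(old.zip(new).sum { |a, b| (a - b)**2 })
  end
  shifts.sum / shifts.size
end

center = [0.0, 0.0]
# First update: learning rate 1/1, so the center jumps to the batch mean.
first = update_center(center, [[2.0, 2.0], [4.0, 4.0]], 1)
# Second update: learning rate 1/2, so the center moves halfway to the mean.
second = update_center(first, [[5.0, 5.0]], 2)

shift = mean_center_shift([first], [second])
```

With these inputs the center moves to [3.0, 3.0] and then to [4.0, 4.0], and the shift between the two is the square root of 2.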

#fit_predict(x) ⇒ Numo::Int32

Analyze clusters and assign samples to clusters.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for cluster analysis.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted cluster label per sample.



# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 106

def fit_predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  fit(x).predict(x)
end

#predict(x) ⇒ Numo::Int32

Predict cluster labels for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples for which to predict cluster labels.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted cluster label per sample.



# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 96

def predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  assign_cluster(x)
end
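
The private `assign_cluster` helper is not shown on this page. The sketch below assumes it assigns each sample to its nearest centroid by Euclidean distance, which is the usual K-Means rule. `nearest_center_labels` is a hypothetical plain-Ruby stand-in, not the library's implementation.

```ruby
# Hypothetical stand-in for assign_cluster, assuming nearest-centroid
# assignment by squared Euclidean distance (Arrays instead of Numo).
def nearest_center_labels(samples, centers)
  samples.map do |sample|
    distances = centers.map do |center|
      sample.zip(center).sum { |a, b| (a - b)**2 }
    end
    # The index of the closest centroid is the cluster label;
    # ties go to the lower index.
    distances.each_with_index.min.last
  end
end

centers = [[0.0, 0.0], [10.0, 10.0]]
samples = [[1.0, -1.0], [9.0, 11.0], [4.0, 4.0]]
labels = nearest_center_labels(samples, centers)
```

Under this assumption, the three samples receive labels 0, 1 and 0.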