Class: Rumale::Clustering::MiniBatchKMeans
- Inherits: Base::Estimator
  - Object
  - Base::Estimator
  - Rumale::Clustering::MiniBatchKMeans
- Includes: Base::ClusterAnalyzer
- Defined in: rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb
Overview
MiniBatchKMeans is a class that implements K-Means cluster analysis with mini-batch stochastic gradient descent (SGD).
Reference
- Sculley, D., “Web-scale k-means clustering,” Proc. WWW’10, pp. 1177–1178, 2010.
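In the `fit` source on this page, each centroid is moved toward the mean of its assigned mini-batch samples with a learning rate of `1.fdiv(update_counter[c])`, so the step size shrinks as that centroid receives more updates. A minimal stdlib-only sketch (the helper `apply_updates` and the one-dimensional data are illustrative assumptions, not part of Rumale) shows the consequence: the first update replaces the initial value outright, and afterwards the centroid is the running average of the batch means it has received.

```ruby
# Hypothetical helper: apply the update rule
#   center <- (1 - eta) * center + eta * batch_mean, with eta = 1 / count
# to a single 1-D centroid, one batch mean at a time.
def apply_updates(initial_center, batch_means)
  center = initial_center.to_f
  count = 0
  batch_means.each do |batch_mean|
    count += 1
    eta = 1.0 / count # per-centroid learning rate shrinks as 1/count
    center = (1 - eta) * center + eta * batch_mean
  end
  center
end

# The initial value (100.0) is discarded on the first update (eta = 1);
# the result is the plain average (2 + 4 + 9) / 3, up to floating-point rounding.
puts apply_updates(100.0, [2.0, 4.0, 9.0])
```

Because the weight of any early, poorly placed batch decays as 1/count, later batches gradually correct it.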
Instance Attribute Summary
- #cluster_centers ⇒ Numo::DFloat (readonly)
  Return the centroids.
- #rng ⇒ Random (readonly)
  Return the random generator.
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x) ⇒ MiniBatchKMeans
  Analyze clusters with the given training data.
- #fit_predict(x) ⇒ Numo::Int32
  Analyze clusters and assign samples to clusters.
- #initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil) ⇒ MiniBatchKMeans (constructor)
  Create a new cluster analyzer using the K-Means method with mini-batch SGD.
- #predict(x) ⇒ Numo::Int32
  Predict cluster labels for samples.
Methods included from Base::ClusterAnalyzer
Constructor Details
#initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil) ⇒ MiniBatchKMeans
Create a new cluster analyzer using the K-Means method with mini-batch SGD.
# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 40

def initialize(n_clusters: 8, init: 'k-means++', max_iter: 100, batch_size: 100, tol: 1.0e-4, random_seed: nil)
  super()
  @params = {
    n_clusters: n_clusters,
    init: (init == 'random' ? 'random' : 'k-means++'),
    max_iter: max_iter,
    batch_size: batch_size,
    tol: tol,
    random_seed: random_seed || srand
  }
  @rng = Random.new(@params[:random_seed])
end
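Two details of the constructor are easy to miss: any `init` value other than `'random'` (including a misspelling) silently becomes `'k-means++'`, and when `random_seed` is `nil`, the value stored is whatever `Kernel#srand` returns (the previous global seed; calling it also reseeds Ruby's global generator). The sketch below restates that parameter handling in plain Ruby; `build_params` is a hypothetical name and the code does not use Rumale itself.

```ruby
# Hypothetical restatement of the constructor's parameter handling.
def build_params(n_clusters: 8, init: 'k-means++', max_iter: 100,
                 batch_size: 100, tol: 1.0e-4, random_seed: nil)
  {
    n_clusters: n_clusters,
    # Any value other than 'random' falls back to 'k-means++'.
    init: (init == 'random' ? 'random' : 'k-means++'),
    max_iter: max_iter,
    batch_size: batch_size,
    tol: tol,
    # Kernel#srand reseeds the global generator and returns the previous
    # seed, so an unseeded estimator still records an Integer seed.
    random_seed: random_seed || srand
  }
end

params = build_params(init: 'kmeans++', random_seed: 42)
puts params[:init] # the misspelled value has become 'k-means++'

# Generators built from the same seed produce the same stream,
# which is what makes runs with a fixed random_seed reproducible.
a = Random.new(params[:random_seed])
b = Random.new(params[:random_seed])
puts a.rand == b.rand
```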
Instance Attribute Details
#cluster_centers ⇒ Numo::DFloat (readonly)
Return the centroids, one row per cluster (shape: [n_clusters, n_features]).
# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 26

def cluster_centers
  @cluster_centers
end
#rng ⇒ Random (readonly)
Return the random generator.
# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 30

def rng
  @rng
end
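The `fit` source on this page draws its shuffles from `sub_rng = @rng.dup` rather than from `@rng` directly. Ruby's `Random#dup` copies the generator state, so the stored generator is never advanced and repeated calls to `fit` on the same estimator replay the same random sequence. A stdlib-only sketch of that effect, with the hypothetical helper `shuffled_order` standing in for the relevant part of `fit`:

```ruby
# Hypothetical helper mirroring how fit shuffles sample indices
# with a copy of the stored generator instead of the generator itself.
def shuffled_order(stored_rng, n_samples)
  sub_rng = stored_rng.dup # copies the state; stored_rng is not advanced
  Array(0...n_samples).shuffle(random: sub_rng)
end

stored_rng = Random.new(1) # plays the role of @rng
first_fit = shuffled_order(stored_rng, 10)
second_fit = shuffled_order(stored_rng, 10)
puts first_fit == second_fit # true: both calls start from the same state
```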
Instance Method Details
#fit(x) ⇒ MiniBatchKMeans
Analyze clusters with the given training data and return the estimator itself.
# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 58

def fit(x, _y = nil)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  # initialization.
  n_samples = x.shape[0]
  update_counter = Numo::Int32.zeros(@params[:n_clusters])
  sub_rng = @rng.dup
  init_cluster_centers(x, sub_rng)
  # optimization with mini-batch sgd.
  @params[:max_iter].times do |_t|
    sample_ids = Array(0...n_samples).shuffle(random: sub_rng)
    old_centers = @cluster_centers.dup
    until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
      # sub sampling
      sub_x = x[subset_ids, true]
      # assign nearest centroids
      cluster_labels = assign_cluster(sub_x)
      # update centroids
      @params[:n_clusters].times do |c|
        assigned_bits = cluster_labels.eq(c)
        next unless assigned_bits.count.positive?

        update_counter[c] += 1
        learning_rate = 1.fdiv(update_counter[c])
        update = sub_x[assigned_bits.where, true].mean(axis: 0)
        @cluster_centers[c, true] = (1 - learning_rate) * @cluster_centers[c, true] + learning_rate * update
      end
    end
    error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
    break if error <= @params[:tol]
  end
  self
end
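For readers without Numo at hand, here is a stdlib-only sketch that follows the same loop on plain arrays of arrays: shuffle per epoch, take `batch_size` samples at a time, label the whole batch once, move each centroid toward its batch mean with learning rate 1/count, and stop when the mean Euclidean shift of the centroids over an epoch is at most `tol`. It is not Rumale's code. The names `mini_batch_k_means`, `nearest`, and `sq_dist` are hypothetical; initial centers are drawn at random from the samples instead of by k-means++; nearest-centroid assignment by Euclidean distance is assumed to match the private `assign_cluster`; and it returns the centers rather than an estimator object.

```ruby
# Squared Euclidean distance between two equal-length arrays.
def sq_dist(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Index of the center closest to point (assumed stand-in for assign_cluster).
def nearest(point, centers)
  (0...centers.size).min_by { |c| sq_dist(point, centers[c]) }
end

# Hypothetical mini-batch K-Means on arrays of arrays, mirroring fit above.
def mini_batch_k_means(x, n_clusters:, max_iter: 100, batch_size: 100, tol: 1.0e-4, seed: 1)
  rng = Random.new(seed)
  # Assumption: random initial centers (Rumale defaults to k-means++).
  centers = x.sample(n_clusters, random: rng).map { |pt| pt.map(&:to_f) }
  counts = Array.new(n_clusters, 0) # number of updates each center has received
  max_iter.times do
    sample_ids = Array(0...x.size).shuffle(random: rng)
    old_centers = centers.map(&:dup)
    until (subset_ids = sample_ids.shift(batch_size)).empty?
      batch = subset_ids.map { |i| x[i] }
      labels = batch.map { |pt| nearest(pt, centers) } # labeled once per batch
      n_clusters.times do |c|
        members = batch.zip(labels).select { |_, l| l == c }.map(&:first)
        next if members.empty?

        counts[c] += 1
        eta = 1.0 / counts[c] # per-center learning rate
        mean = members.transpose.map { |col| col.sum / members.size.to_f }
        centers[c] = centers[c].zip(mean).map { |ci, mi| (1 - eta) * ci + eta * mi }
      end
    end
    # Stop when the mean Euclidean shift of the centers over one epoch is small.
    shift = centers.zip(old_centers).sum { |a, b| Math.sqrt(sq_dist(a, b)) } / n_clusters
    break if shift <= tol
  end
  centers
end

data = [[0.0, 0.1], [0.2, -0.1], [-0.1, 0.0], [10.0, 10.1], [9.9, 10.0], [10.1, 9.8]]
p mini_batch_k_means(data, n_clusters: 2, batch_size: 2, seed: 0)
```

Labeling the whole batch before any centroid moves, as `fit` does, keeps every update within a batch based on the same snapshot of the centroids.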
#fit_predict(x) ⇒ Numo::Int32
Analyze clusters and assign the given samples to clusters (equivalent to calling fit(x).predict(x)).
# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 106

def fit_predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  fit(x).predict(x)
end
#predict(x) ⇒ Numo::Int32
Predict cluster labels for samples.
# File 'rumale-clustering/lib/rumale/clustering/mini_batch_k_means.rb', line 96

def predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  assign_cluster(x)
end
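`predict` delegates to the private `assign_cluster` helper, whose source is not shown on this page. Assuming it uses the standard K-Means rule (each sample goes to the centroid at the smallest Euclidean distance), a stdlib-only sketch with the hypothetical name `assign_labels` looks like this; the labels are integer indices into the rows of `cluster_centers`.

```ruby
# Hypothetical stand-in for assign_cluster: nearest centroid by squared
# Euclidean distance (same argmin as the plain Euclidean distance).
def assign_labels(samples, centers)
  samples.map do |point|
    (0...centers.size).min_by do |c|
      point.zip(centers[c]).sum { |pi, ci| (pi - ci)**2 }
    end
  end
end

centers = [[0.0, 0.0], [10.0, 10.0]]
new_samples = [[1.0, -0.5], [9.0, 11.0], [6.0, 6.0]]
p assign_labels(new_samples, centers) # → [0, 1, 1]
```

Comparing squared distances avoids a square root per pair without changing which centroid is nearest.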