Class: Rumale::Clustering::KMeans
- Inherits: Base::Estimator (Rumale::Clustering::KMeans < Base::Estimator < Object)
- Includes: Base::ClusterAnalyzer
- Defined in: rumale-clustering/lib/rumale/clustering/k_means.rb
Overview
KMeans is a class that implements K-Means cluster analysis. The current implementation uses the Euclidean distance for analyzing the clusters.
Reference
- Arthur, D., and Vassilvitskii, S., “k-means++: the advantages of careful seeding,” Proc. SODA’07, pp. 1027–1035, 2007.
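The seeding scheme from the reference above can be sketched in plain Ruby. This is an illustration on ordinary arrays, not Rumale's own initialization code (which is not shown on this page); the names `squared_distance` and `kmeans_pp_seeds` are hypothetical. Each new center is drawn with probability proportional to its squared distance from the nearest center already chosen.

```ruby
# Illustrative k-means++ seeding on plain Ruby arrays (not Rumale's implementation).
# `samples` is an array of equal-length numeric arrays; `rng` is a Random instance.
def squared_distance(a, b)
  a.zip(b).sum { |p, q| (p - q)**2 }
end

def kmeans_pp_seeds(samples, n_clusters, rng)
  # The first center is drawn uniformly at random.
  centers = [samples[rng.rand(samples.size)]]
  while centers.size < n_clusters
    # Squared distance from each sample to its nearest chosen center.
    weights = samples.map { |s| centers.map { |c| squared_distance(s, c) }.min }
    total = weights.sum
    # Every sample coincides with a chosen center, so fall back to a uniform draw.
    if total.zero?
      centers << samples[rng.rand(samples.size)]
      next
    end
    # Draw the next center with probability proportional to its weight.
    threshold = rng.rand * total
    cumulative = 0.0
    index = weights.each_index.find { |i| (cumulative += weights[i]) > threshold }
    # Guard against rounding: fall back to the last sample with a positive weight.
    centers << samples[index || weights.rindex(&:positive?)]
  end
  centers
end

samples = [[0.0, 0.0], [0.0, 0.1], [10.0, 10.0], [10.0, 10.1]]
seeds = kmeans_pp_seeds(samples, 2, Random.new(1))
```

Because already chosen samples have weight zero, the sketch never picks the same sample twice unless the data contain duplicates.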
Instance Attribute Summary
- #cluster_centers ⇒ Numo::DFloat (readonly)
  Return the centroids.
- #rng ⇒ Random (readonly)
  Return the random generator.
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x) ⇒ KMeans
  Analyze clusters with the given training data.
- #fit_predict(x) ⇒ Numo::Int32
  Analyze clusters and assign samples to clusters.
- #initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil) ⇒ KMeans (constructor)
  Create a new cluster analyzer with the K-Means method.
- #predict(x) ⇒ Numo::Int32
  Predict cluster labels for samples.
Methods included from Base::ClusterAnalyzer
Constructor Details
#initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil) ⇒ KMeans
Create a new cluster analyzer with the K-Means method.
    # File 'rumale-clustering/lib/rumale/clustering/k_means.rb', line 39
    def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
      super()
      @params = {
        n_clusters: n_clusters,
        init: (init == 'random' ? 'random' : 'k-means++'),
        max_iter: max_iter,
        tol: tol,
        random_seed: random_seed || srand
      }
      @rng = Random.new(@params[:random_seed])
    end
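Two details of the constructor are easy to miss: any `init` value other than 'random' falls back to 'k-means++', and a missing `random_seed` is replaced by the return value of `srand`. The sketch below mirrors that logic in a standalone helper so it can be run without the gem; `normalize_kmeans_params` is a hypothetical name, not part of Rumale.

```ruby
# Hypothetical helper mirroring how the constructor normalizes its arguments.
def normalize_kmeans_params(init: 'k-means++', random_seed: nil)
  {
    # Any value other than 'random' selects k-means++ seeding.
    init: (init == 'random' ? 'random' : 'k-means++'),
    # Kernel#srand reseeds the global generator and returns the previous seed.
    random_seed: random_seed || srand
  }
end

params = normalize_kmeans_params(init: 'kmeans', random_seed: 42)
# params[:init] is 'k-means++' because 'kmeans' is not 'random'.

# Two generators built from the same seed yield the same sequence.
first = Random.new(params[:random_seed])
second = Random.new(params[:random_seed])
same_draws = Array.new(3) { first.rand } == Array.new(3) { second.rand }
```

Passing an explicit `random_seed` is therefore the way to get repeatable initial centers across runs.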
Instance Attribute Details
#cluster_centers ⇒ Numo::DFloat (readonly)
Return the centroids.
    # File 'rumale-clustering/lib/rumale/clustering/k_means.rb', line 26
    def cluster_centers
      @cluster_centers
    end
#rng ⇒ Random (readonly)
Return the random generator.
    # File 'rumale-clustering/lib/rumale/clustering/k_means.rb', line 30
    def rng
      @rng
    end
Instance Method Details
#fit(x) ⇒ KMeans
Analyze clusters with the given training data.
    # File 'rumale-clustering/lib/rumale/clustering/k_means.rb', line 56
    def fit(x, _y = nil)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      init_cluster_centers(x)
      @params[:max_iter].times do |_t|
        cluster_labels = assign_cluster(x)
        old_centers = @cluster_centers.dup
        @params[:n_clusters].times do |n|
          assigned_bits = cluster_labels.eq(n)
          @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count.positive?
        end
        error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
        break if error <= @params[:tol]
      end
      self
    end
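The loop in #fit alternates between assigning samples to their nearest center and recomputing each center as the mean of its members. It stops once the mean Euclidean shift of the centers drops to `tol` or below, or after `max_iter` passes. A minimal sketch of the same idea on plain Ruby arrays follows; it assumes the initial centers are supplied by the caller, and the names `euclidean_distance`, `nearest_center`, and `lloyd_iterations` are hypothetical.

```ruby
# Illustrative Lloyd iteration on plain Ruby arrays (Rumale itself uses Numo::NArray).
def euclidean_distance(a, b)
  Math.sqrt(a.zip(b).sum { |p, q| (p - q)**2 })
end

def nearest_center(sample, centers)
  centers.each_index.min_by { |k| euclidean_distance(sample, centers[k]) }
end

def lloyd_iterations(samples, centers, max_iter: 50, tol: 1.0e-4)
  centers = centers.map(&:dup)
  max_iter.times do
    labels = samples.map { |s| nearest_center(s, centers) }
    old_centers = centers.map(&:dup)
    centers.each_index do |k|
      members = samples.each_index.select { |i| labels[i] == k }.map { |i| samples[i] }
      # An empty cluster keeps its previous center, as in #fit.
      next if members.empty?

      centers[k] = members.transpose.map { |column| column.sum / members.size.to_f }
    end
    # Mean Euclidean shift of the centers between two consecutive iterations.
    shift = centers.each_index.sum { |k| euclidean_distance(old_centers[k], centers[k]) } / centers.size
    break if shift <= tol
  end
  centers
end

samples = [[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]]
final_centers = lloyd_iterations(samples, [[0.0, 0.0], [10.0, 10.0]])
# final_centers == [[0.0, 1.0], [10.0, 11.0]]
```

In this example the centers settle after the first update, so the second pass measures a shift of zero and the loop exits early.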
#fit_predict(x) ⇒ Numo::Int32
Analyze clusters and assign samples to clusters.
    # File 'rumale-clustering/lib/rumale/clustering/k_means.rb', line 87
    def fit_predict(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      fit(x).predict(x)
    end
#predict(x) ⇒ Numo::Int32
Predict cluster labels for samples.
    # File 'rumale-clustering/lib/rumale/clustering/k_means.rb', line 77
    def predict(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      assign_cluster(x)
    end
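#predict delegates to the private `assign_cluster` method, whose body is not shown on this page. Given the Euclidean distance mentioned in the overview, labeling presumably amounts to picking the nearest centroid for each sample. The sketch below shows that rule on plain Ruby arrays; `predict_labels` is a hypothetical name, and the tie-breaking behavior is a property of the sketch, not a documented guarantee of Rumale.

```ruby
# Illustrative nearest-centroid labeling on plain Ruby arrays (not Rumale's assign_cluster).
def predict_labels(samples, centers)
  samples.map do |sample|
    # Index of the center with the smallest squared Euclidean distance.
    # Enumerable#min_by keeps the first minimum, so ties go to the lower index.
    centers.each_index.min_by do |k|
      sample.zip(centers[k]).sum { |p, q| (p - q)**2 }
    end
  end
end

centers = [[0.0, 1.0], [10.0, 11.0]]
labels = predict_labels([[1.0, 1.0], [9.0, 12.0], [4.0, 5.0]], centers)
# labels == [0, 1, 0]
```

Comparing squared distances avoids the square root and gives the same ordering as the Euclidean distance itself.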