Class: Rumale::Clustering::MeanShift
- Inherits:
- Object
- Base::Estimator
- Rumale::Clustering::MeanShift
- Includes:
- Base::ClusterAnalyzer
- Defined in:
- rumale-clustering/lib/rumale/clustering/mean_shift.rb
Overview
MeanShift is a class that implements mean-shift clustering with a flat kernel.
Reference
- Carreira-Perpinan, M A., “A review of mean-shift algorithms for clustering,” arXiv:1503.00687v1.
- Sheikh, Y A., Khan, E A., and Kanade, T., “Mode-seeking by Medoidshifts,” Proc. ICCV’07, pp. 1–8, 2007.
- Vedaldi, A., and Soatto, S., “Quick Shift and Kernel Methods for Mode Seeking,” Proc. ECCV’08, pp. 705–718, 2008.
Instance Attribute Summary
- #cluster_centers ⇒ Numo::DFloat (readonly)
  Return the centroids.
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x) ⇒ MeanShift
  Analyze clusters with the given training data.
- #fit_predict(x) ⇒ Numo::Int32
  Analyze clusters and assign samples to clusters.
- #initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4) ⇒ MeanShift (constructor)
  Create a new cluster analyzer with the mean-shift algorithm.
- #predict(x) ⇒ Numo::Int32
  Predict cluster labels for samples.
Methods included from Base::ClusterAnalyzer
Constructor Details
#initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4) ⇒ MeanShift
Create a new cluster analyzer with the mean-shift algorithm.
# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 34
def initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4)
  super()
  @params = {
    bandwidth: bandwidth,
    max_iter: max_iter,
    tol: tol
  }
end
Instance Attribute Details
#cluster_centers ⇒ Numo::DFloat (readonly)
Return the centroids.
# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 27
def cluster_centers
  @cluster_centers
end
Instance Method Details
#fit(x) ⇒ MeanShift
Analyze clusters with the given training data.
# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 48
def fit(x, _y = nil)
  x = Rumale::Validation.check_convert_sample_array(x)

  z = x.dup
  @params[:max_iter].times do
    distance_mat = Rumale::PairwiseMetric.euclidean_distance(x, z)
    kernel_mat = Numo::DFloat.cast(distance_mat.le(@params[:bandwidth]))
    sum_kernel = kernel_mat.sum(axis: 0)
    weight_mat = kernel_mat.dot((1 / sum_kernel).diag)
    updated = weight_mat.transpose.dot(x)
    break if (z - updated).abs.sum(axis: 1).max <= @params[:tol]

    z = updated
  end

  @cluster_centers = connect_components(z)

  self
end
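The loop in #fit can be read as follows: each shifted point moves to the mean of the training samples within bandwidth of it, and iteration stops once the largest per-point L1 movement is at most tol. The following is a minimal plain-Ruby sketch of that reading, using nested arrays instead of Numo. It omits the final connect_components step, whose implementation is not shown on this page. The names euclidean, shift_once, and mean_shift_modes are hypothetical and are not part of Rumale.

```ruby
# Illustrative sketch only; not Rumale's implementation.
# All method names here are hypothetical.

# Euclidean distance between two points given as arrays of floats.
def euclidean(a, b)
  Math.sqrt(a.zip(b).sum { |p, q| (p - q)**2 })
end

# Flat kernel step: move `point` to the mean of the samples
# that lie within `bandwidth` of it.
def shift_once(point, samples, bandwidth)
  neighbors = samples.select { |s| euclidean(s, point) <= bandwidth }
  return point if neighbors.empty? # defensive guard for this sketch

  Array.new(point.size) { |d| neighbors.sum { |s| s[d] } / neighbors.size.to_f }
end

# Shift all points until the largest L1 movement is at most `tol`
# or `max_iter` iterations have run.
def mean_shift_modes(samples, bandwidth: 1.0, max_iter: 500, tol: 1e-4)
  z = samples.map(&:dup)
  max_iter.times do
    updated = z.map { |p| shift_once(p, samples, bandwidth) }
    max_move = z.zip(updated).map { |p, q| p.zip(q).sum { |a, b| (a - b).abs } }.max
    break if max_move <= tol

    z = updated
  end
  z
end

samples = [[0.0, 0.0], [0.2, 0.0], [0.0, 0.2], [5.0, 5.0], [5.2, 5.0]]
modes = mean_shift_modes(samples, bandwidth: 1.0)
```

In this toy data the first three points converge to one location and the last two to another, so a merging step such as connect_components would presumably reduce the five shifted points to two cluster centers.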
#fit_predict(x) ⇒ Numo::Int32
Analyze clusters and assign samples to clusters.
# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 82
def fit_predict(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  fit(x).predict(x)
end
#predict(x) ⇒ Numo::Int32
Predict cluster labels for samples.
# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 72
def predict(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  assign_cluster(x)
end
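The body of assign_cluster is not shown on this page. Assuming it labels each sample with the index of the nearest cluster center by Euclidean distance, the idea could be sketched in plain Ruby as below. The name assign_to_nearest is hypothetical, and the real method may differ.

```ruby
# Illustrative sketch only; assumes nearest-center assignment,
# which this page does not confirm. `assign_to_nearest` is hypothetical.
def assign_to_nearest(samples, centers)
  samples.map do |s|
    # Distance from this sample to every center.
    distances = centers.map { |c| Math.sqrt(s.zip(c).sum { |p, q| (p - q)**2 }) }
    # Label is the index of the closest center (first one on ties).
    distances.index(distances.min)
  end
end

centers = [[0.0, 0.0], [5.0, 5.0]]
labels = assign_to_nearest([[0.1, -0.2], [4.8, 5.3], [2.0, 2.0]], centers)
```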