Class: Rumale::Clustering::MeanShift

Inherits:
Base::Estimator
Includes:
Base::ClusterAnalyzer
Defined in:
rumale-clustering/lib/rumale/clustering/mean_shift.rb

Overview

MeanShift is a class that implements mean-shift clustering with a flat kernel.

Reference

  • Carreira-Perpinan, M A., “A review of mean-shift algorithms for clustering,” arXiv:1503.00687v1.

  • Sheikh, Y A., Khan, E A., and Kanade, T., “Mode-seeking by Medoidshifts,” Proc. ICCV’07, pp. 1–8, 2007.

  • Vedaldi, A., and Soatto, S., “Quick Shift and Kernel Methods for Mode Seeking,” Proc. ECCV’08, pp. 705–718, 2008.

Examples:

require 'rumale/clustering/mean_shift'

analyzer = Rumale::Clustering::MeanShift.new(bandwidth: 1.5)
cluster_labels = analyzer.fit_predict(samples)

Instance Attribute Summary

Attributes inherited from Base::Estimator

#params

Instance Method Summary

Methods included from Base::ClusterAnalyzer

#score

Constructor Details

#initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4) ⇒ MeanShift

Create a new cluster analyzer with the mean-shift algorithm.

Parameters:

  • bandwidth (Float) (defaults to: 1.0)

    The bandwidth (radius) of the flat kernel. Samples within this Euclidean distance of a point contribute equally to its shifted mean.

  • max_iter (Integer) (defaults to: 500)

    The maximum number of iterations.

  • tol (Float) (defaults to: 1e-4)

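    The tolerance for the termination criterion. Iteration stops once the largest per-point shift (sum of absolute coordinate changes) is at most this value.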
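
As a reading aid, the roles of bandwidth and tol can be written out from the #fit source listed further down this page. The symbols h (bandwidth), x_i (samples), and z_j (current mode estimates) are notation chosen here, not names used by the library:

```latex
K_h(x_i, z_j) =
\begin{cases}
1 & \text{if } \lVert x_i - z_j \rVert_2 \le h \\
0 & \text{otherwise}
\end{cases}
\qquad
z_j^{\text{new}} = \frac{\sum_i K_h(x_i, z_j)\, x_i}{\sum_i K_h(x_i, z_j)}
\qquad
\text{stop when } \max_j \lVert z_j - z_j^{\text{new}} \rVert_1 \le \text{tol}
```

If the stopping condition is never met, the loop ends after max_iter updates.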



# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 34

def initialize(bandwidth: 1.0, max_iter: 500, tol: 1e-4)
  super()
  @params = {
    bandwidth: bandwidth,
    max_iter: max_iter,
    tol: tol
  }
end

Instance Attribute Details

#cluster_centers ⇒ Numo::DFloat (readonly)

Return the centroids.

Returns:

  • (Numo::DFloat)

    (shape: [n_clusters, n_features])



# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 27

def cluster_centers
  @cluster_centers
end
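
The number of rows in cluster_centers is not set in advance. #fit fills this attribute from a private connect_components helper whose source is not shown on this page. One plausible reading, stated here as an assumption and not as the library's implementation, is a greedy merge: a converged point becomes a new center unless it already lies within the bandwidth of a kept center. The helper name merge_modes is hypothetical, and plain Arrays stand in for Numo::DFloat:

```ruby
# Hypothetical helper (assumption, not Rumale's code): greedily keep a
# converged point as a new center unless it lies within `bandwidth`
# of a center that was already kept.
def merge_modes(modes, bandwidth)
  modes.each_with_object([]) do |z, centers|
    close = centers.any? do |c|
      Math.sqrt(z.zip(c).sum { |p, q| (p - q)**2 }) <= bandwidth
    end
    centers << z.dup unless close
  end
end

# Four converged points that collapse into two centers.
modes = [[0.07, 0.07], [0.07, 0.07], [5.1, 5.0], [5.1, 5.0]]
centers = merge_modes(modes, 1.0)
```

Under this reading, a larger bandwidth merges more modes and so tends to produce fewer clusters.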

Instance Method Details

#fit(x) ⇒ MeanShift

Analyze clusters with the given training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for cluster analysis.

Returns:

  • (MeanShift)

    The learned cluster analyzer itself.



# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 48

def fit(x, _y = nil)
  x = Rumale::Validation.check_convert_sample_array(x)

  z = x.dup
  @params[:max_iter].times do
    distance_mat = Rumale::PairwiseMetric.euclidean_distance(x, z)
    kernel_mat = Numo::DFloat.cast(distance_mat.le(@params[:bandwidth]))
    sum_kernel = kernel_mat.sum(axis: 0)
    weight_mat = kernel_mat.dot((1 / sum_kernel).diag)
    updated = weight_mat.transpose.dot(x)
    break if (z - updated).abs.sum(axis: 1).max <= @params[:tol]

    z = updated
  end

  @cluster_centers = connect_components(z)

  self
end
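
The matrix form above can be hard to follow at a glance. The following is a minimal plain-Ruby sketch of the same update loop, using nested Arrays in place of Numo::DFloat. The helper names (euclid, mean_shift_step, mean_shift_modes) are hypothetical, and the guard for an empty neighborhood is an addition made here for safety, not something the listing above does:

```ruby
# Hypothetical helper: Euclidean distance between two points given as Arrays.
def euclid(a, b)
  Math.sqrt(a.zip(b).sum { |p, q| (p - q)**2 })
end

# One flat-kernel mean-shift step: each mode estimate moves to the mean of
# the samples lying within `bandwidth` of it.
def mean_shift_step(samples, modes, bandwidth)
  modes.map do |z|
    neighbors = samples.select { |x| euclid(x, z) <= bandwidth }
    next z if neighbors.empty? # assumption: leave the point where it is

    Array.new(z.size) { |d| neighbors.sum { |x| x[d] } / neighbors.size.to_f }
  end
end

# Repeat the step until the largest L1 shift is at most `tol`,
# or until `max_iter` updates have been made.
def mean_shift_modes(samples, bandwidth: 1.0, max_iter: 500, tol: 1e-4)
  modes = samples.map(&:dup)
  max_iter.times do
    updated = mean_shift_step(samples, modes, bandwidth)
    shift = modes.zip(updated).map { |z, u| z.zip(u).sum { |p, q| (p - q).abs } }.max
    break if shift <= tol

    modes = updated
  end
  modes
end

# Two well-separated groups of points.
samples = [[0.0, 0.0], [0.2, 0.0], [0.0, 0.2], [5.0, 5.0], [5.2, 5.0]]
modes = mean_shift_modes(samples, bandwidth: 1.0)
```

With this data, the first three points settle near one location and the last two near another, so each group ends up sharing a single mode.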

#fit_predict(x) ⇒ Numo::Int32

Analyze clusters and assign samples to clusters.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for cluster analysis.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted cluster label per sample.



# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 82

def fit_predict(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  fit(x).predict(x)
end

#predict(x) ⇒ Numo::Int32

Predict cluster labels for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples whose cluster labels are to be predicted.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted cluster label per sample.



# File 'rumale-clustering/lib/rumale/clustering/mean_shift.rb', line 72

def predict(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  assign_cluster(x)
end
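
#predict delegates to a private assign_cluster helper whose source is not shown on this page. A common way to label samples against a set of cluster centers is nearest-center assignment by Euclidean distance. The sketch below illustrates that idea under this assumption; nearest_center_labels is a hypothetical name, and plain Arrays stand in for Numo arrays:

```ruby
# Hypothetical helper (assumption, not Rumale's code): label each sample
# with the index of its nearest center, comparing squared Euclidean distances.
def nearest_center_labels(samples, centers)
  samples.map do |x|
    centers.each_index.min_by do |k|
      x.zip(centers[k]).sum { |p, q| (p - q)**2 }
    end
  end
end

centers = [[0.0, 0.0], [5.0, 5.0]]
labels = nearest_center_labels([[0.1, -0.2], [4.8, 5.3], [2.0, 2.0]], centers)
```

Here the first and third samples are closer to the first center and the second sample is closer to the second center, so the labels are 0, 1, and 0.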