Class: Rumale::Clustering::GaussianMixture
- Inherits: Base::Estimator
  - Object
    - Base::Estimator
      - Rumale::Clustering::GaussianMixture
- Includes: Base::ClusterAnalyzer
- Defined in: rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb
Overview
GaussianMixture is a class that implements cluster analysis with a Gaussian mixture model.
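In a Gaussian mixture model each cluster k is a Gaussian with weight w_k, mean mu_k and covariance S_k, and the membership of a sample x in cluster k is its posterior probability: w_k N(x | mu_k, S_k) divided by the sum of w_j N(x | mu_j, S_j) over all clusters j. The plain-Ruby sketch below computes these memberships for diagonal covariances. The helpers `diag_gaussian_pdf` and `memberships` are hypothetical names written for illustration; they are not Rumale's private `calc_memberships`, which works on Numo arrays.

```ruby
# Hypothetical helpers (not Rumale's API) illustrating Gaussian mixture
# memberships (posterior probabilities) with diagonal covariances.

# Density at point x of a Gaussian with per-feature means and variances.
def diag_gaussian_pdf(x, mean, var)
  x.each_index.reduce(1.0) do |prod, j|
    diff = x[j] - mean[j]
    prod * Math.exp(-0.5 * diff * diff / var[j]) / Math.sqrt(2.0 * Math::PI * var[j])
  end
end

# Returns an n_samples x n_clusters array; each row sums to 1.
def memberships(samples, weights, means, vars)
  samples.map do |x|
    # Joint probability of x and each cluster: weight times density.
    joint = weights.each_index.map { |k| weights[k] * diag_gaussian_pdf(x, means[k], vars[k]) }
    total = joint.sum
    # Normalize to obtain posterior probabilities.
    joint.map { |q| q / total }
  end
end

samples = [[0.0, 0.1], [5.0, 4.9]]
resp = memberships(samples, [0.5, 0.5], [[0.0, 0.0], [5.0, 5.0]], [[1.0, 1.0], [1.0, 1.0]])
resp.each { |row| puts row.map { |q| q.round(4) }.inspect }
```

Each sample ends up with almost all of its membership in the cluster whose mean it sits next to.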
Instance Attribute Summary
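- #covariances ⇒ Numo::DFloat (readonly)
  Return the covariance of each cluster: the diagonal elements of each covariance matrix when covariance_type is 'diag', or the full covariance matrices when it is 'full'.
- #means ⇒ Numo::DFloat (readonly)
  Return the mean of each cluster.
- #n_iter ⇒ Integer (readonly)
  Return the number of iterations to convergence (the zero-based index of the final EM iteration).
- #weights ⇒ Numo::DFloat (readonly)
  Return the weight of each cluster.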
Attributes inherited from Base::Estimator: #params
Instance Method Summary
- #fit(x) ⇒ GaussianMixture
  Analyze clusters with the given training data.
- #fit_predict(x) ⇒ Numo::Int32
  Analyze clusters and assign samples to clusters.
- #initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil) ⇒ GaussianMixture (constructor)
  Create a new cluster analyzer with a Gaussian mixture model.
- #predict(x) ⇒ Numo::Int32
  Predict cluster labels for samples.
Methods included from Base::ClusterAnalyzer
Constructor Details
#initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil) ⇒ GaussianMixture
Create a new cluster analyzer with a Gaussian mixture model.
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 54

def initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag', max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6,
               random_seed: nil)
  super()
  @params = {
    n_clusters: n_clusters,
    init: (init == 'random' ? 'random' : 'k-means++'),
    covariance_type: (covariance_type == 'full' ? 'full' : 'diag'),
    max_iter: max_iter,
    tol: tol,
    reg_covar: reg_covar,
    random_seed: random_seed || srand
  }
end
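The constructor silently normalizes its string options: any `init` other than 'random' becomes 'k-means++', any `covariance_type` other than 'full' becomes 'diag', and a nil `random_seed` is replaced by the Integer returned by `Kernel#srand` (the previous seed). The sketch below reproduces only that option handling as a hypothetical `normalize_gmm_params` method, so the fallback behaviour can be checked in plain Ruby without Rumale installed.

```ruby
# Hypothetical helper mirroring the option handling in
# GaussianMixture#initialize; it is not part of Rumale's API.
def normalize_gmm_params(n_clusters: 8, init: 'k-means++', covariance_type: 'diag',
                         max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
  {
    n_clusters: n_clusters,
    # Anything other than 'random' falls back to 'k-means++'.
    init: (init == 'random' ? 'random' : 'k-means++'),
    # Anything other than 'full' falls back to 'diag'.
    covariance_type: (covariance_type == 'full' ? 'full' : 'diag'),
    max_iter: max_iter,
    tol: tol,
    reg_covar: reg_covar,
    # Kernel#srand reseeds the generator and returns the previous seed.
    random_seed: random_seed || srand
  }
end

params = normalize_gmm_params(init: 'kmeans', covariance_type: 'spherical', random_seed: 42)
p params[:init]            # prints "k-means++"
p params[:covariance_type] # prints "diag"
```

A misspelled option therefore does not raise an error; it quietly selects the default behaviour.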
Instance Attribute Details
#covariances ⇒ Numo::DFloat (readonly)
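Return the covariance of each cluster: the diagonal elements of each covariance matrix when covariance_type is 'diag', or the full covariance matrices when it is 'full'.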
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 43

def covariances
  @covariances
end
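The 'diag' covariances are the diagonal of the membership-weighted 'full' covariance matrix, so a 'diag' model stores d numbers per cluster instead of d × d. The plain-Ruby sketch below computes both forms for a single cluster and adds `reg_covar` to keep the variances positive. The helper names are hypothetical, and adding `reg_covar` to the diagonal only is an assumption suggested by the parameter name, not taken from Rumale's private `calc_covariances`.

```ruby
# Hypothetical helpers (not Rumale's API): membership-weighted covariance
# of one cluster, in full and diagonal form.
# ASSUMPTION: reg_covar is added to the diagonal entries only.

# samples: array of feature arrays; resp: membership of each sample in this cluster.
def weighted_mean(samples, resp)
  total = resp.sum
  Array.new(samples.first.size) do |j|
    samples.each_index.sum { |i| resp[i] * samples[i][j] } / total
  end
end

def full_covariance(samples, resp, reg_covar)
  mean = weighted_mean(samples, resp)
  total = resp.sum
  d = mean.size
  Array.new(d) do |a|
    Array.new(d) do |b|
      cov = samples.each_index.sum { |i| resp[i] * (samples[i][a] - mean[a]) * (samples[i][b] - mean[b]) } / total
      a == b ? cov + reg_covar : cov
    end
  end
end

def diag_covariance(samples, resp, reg_covar)
  mean = weighted_mean(samples, resp)
  total = resp.sum
  Array.new(mean.size) do |j|
    samples.each_index.sum { |i| resp[i] * (samples[i][j] - mean[j])**2 } / total + reg_covar
  end
end

samples = [[1.0, 2.0], [3.0, 1.0], [2.0, 4.0]]
resp = [1.0, 0.5, 0.5]
full = full_covariance(samples, resp, 1.0e-6)
diag = diag_covariance(samples, resp, 1.0e-6)
```

The off-diagonal entries of `full` capture correlations between features that the 'diag' form discards.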
#means ⇒ Numo::DFloat (readonly)
Return the mean of each cluster.
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 39

def means
  @means
end
#n_iter ⇒ Integer (readonly)
Return the number of iterations to convergence (the zero-based index of the final EM iteration).
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 31

def n_iter
  @n_iter
end
#weights ⇒ Numo::DFloat (readonly)
Return the weight of each cluster.
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 35

def weights
  @weights
end
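The `weights` and `means` attributes correspond to the standard EM update: a cluster's weight is the average membership of all samples in it, and its mean is the membership-weighted average of the samples. The plain-Ruby sketch below shows that update with a hypothetical `m_step_weights_means` helper, assuming memberships are an n_samples × n_clusters array whose rows sum to 1. Rumale's private `calc_weights` and `calc_means` receive the same inputs, but this is not their code.

```ruby
# Hypothetical M-step helper (not Rumale's API).
# memberships: n_samples x n_clusters array, each row summing to 1.
def m_step_weights_means(samples, memberships)
  n_samples = samples.size
  n_clusters = memberships.first.size
  n_features = samples.first.size
  # Total membership mass assigned to each cluster.
  mass = Array.new(n_clusters) { |k| memberships.sum { |row| row[k] } }
  # Weight = average membership, so the weights sum to 1.
  weights = mass.map { |m| m / n_samples }
  # Mean = membership-weighted average of the samples.
  means = Array.new(n_clusters) do |k|
    Array.new(n_features) do |j|
      samples.each_index.sum { |i| memberships[i][k] * samples[i][j] } / mass[k]
    end
  end
  [weights, means]
end

samples = [[0.0], [2.0], [10.0]]
memberships = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
weights, means = m_step_weights_means(samples, memberships)
```

With hard 0/1 memberships, as here, the update reduces to cluster proportions and per-cluster averages.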
Instance Method Details
#fit(x) ⇒ GaussianMixture
Analyze clusters with the given training data.
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 73

def fit(x, _y = nil)
  check_enable_linalg('fit')

  x = ::Rumale::Validation.check_convert_sample_array(x)
  n_samples = x.shape[0]
  memberships = init_memberships(x)
  @params[:max_iter].times do |t|
    @n_iter = t
    @weights = calc_weights(n_samples, memberships)
    @means = calc_means(x, memberships)
    @covariances = calc_covariances(x, @means, memberships, @params[:reg_covar], @params[:covariance_type])
    new_memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
    error = (memberships - new_memberships).abs.max
    break if error <= @params[:tol]

    memberships = new_memberships.dup
  end
  self
end
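`fit` alternates an M-step (weights, means, covariances) with an E-step (new memberships) and stops as soon as the largest absolute change in any membership is at most `tol`. The self-contained 1-D sketch below follows the same loop structure and stopping rule in plain Ruby. `fit_gmm_1d` and `normal_pdf` are hypothetical names; Rumale derives the starting memberships from `init_memberships`, whereas this sketch takes a fixed starting array as an argument (an assumption made to keep it deterministic).

```ruby
# Hypothetical 1-D Gaussian mixture EM loop (not Rumale's implementation),
# mirroring the structure and stopping rule of GaussianMixture#fit.
def normal_pdf(x, mean, var)
  Math.exp(-0.5 * (x - mean)**2 / var) / Math.sqrt(2.0 * Math::PI * var)
end

def fit_gmm_1d(xs, memberships, max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6)
  n = xs.size
  k_range = 0...memberships.first.size
  n_iter = weights = means = vars = nil
  max_iter.times do |t|
    n_iter = t
    # M-step: estimate parameters from the current memberships.
    mass = k_range.map { |k| memberships.sum { |row| row[k] } }
    weights = mass.map { |m| m / n }
    means = k_range.map { |k| xs.each_index.sum { |i| memberships[i][k] * xs[i] } / mass[k] }
    vars = k_range.map do |k|
      xs.each_index.sum { |i| memberships[i][k] * (xs[i] - means[k])**2 } / mass[k] + reg_covar
    end
    # E-step: recompute memberships as posterior probabilities.
    new_memberships = xs.map do |x|
      joint = k_range.map { |k| weights[k] * normal_pdf(x, means[k], vars[k]) }
      total = joint.sum
      joint.map { |q| q / total }
    end
    # Same convergence test as fit: largest absolute change <= tol.
    error = memberships.each_index.flat_map do |i|
      k_range.map { |k| (memberships[i][k] - new_memberships[i][k]).abs }
    end.max
    break if error <= tol

    memberships = new_memberships
  end
  { n_iter: n_iter, weights: weights, means: means, vars: vars }
end

xs = [0.1, -0.2, 0.0, 0.3, 5.1, 4.8, 5.0, 5.2]
init = xs.map { |x| x < 2.5 ? [0.9, 0.1] : [0.1, 0.9] }
result = fit_gmm_1d(xs, init)
```

Because the iteration counter is stored at the start of each pass, as `@n_iter = t` is in `fit`, a run that converges on its first pass reports `n_iter` as 0.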
#fit_predict(x) ⇒ Numo::Int32
Analyze clusters and assign samples to clusters.
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 109

def fit_predict(x)
  check_enable_linalg('fit_predict')

  x = ::Rumale::Validation.check_convert_sample_array(x)
  fit(x).predict(x)
end
#predict(x) ⇒ Numo::Int32
Predict cluster labels for samples.
# File 'rumale-clustering/lib/rumale/clustering/gaussian_mixture.rb', line 97

def predict(x)
  check_enable_linalg('predict')

  x = ::Rumale::Validation.check_convert_sample_array(x)
  memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
  assign_cluster(memberships)
end
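`predict` recomputes the memberships with the fitted parameters and hands them to `assign_cluster`. That helper's body is not shown here; the sketch below assumes it labels each sample with the cluster of largest membership, and implements that argmax step in plain Ruby as a hypothetical `argmax_labels` method.

```ruby
# Hypothetical stand-in for the private assign_cluster step (not Rumale's code).
# ASSUMPTION: each sample is labelled with the index of its largest membership.
def argmax_labels(memberships)
  memberships.map { |row| row.each_index.max_by { |k| row[k] } }
end

memberships = [[0.7, 0.2, 0.1], [0.05, 0.15, 0.8], [0.3, 0.6, 0.1]]
labels = argmax_labels(memberships)
p labels # prints [0, 2, 1]
```

This hard assignment discards the soft memberships; a sample near a cluster boundary receives a single label even if its memberships are nearly equal.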