Class: Rumale::SVM::ClusteredSVC
- Inherits: Base::Estimator (ancestry: Object > Base::Estimator > Rumale::SVM::ClusteredSVC)
- Includes: Base::Classifier
- Defined in: lib/rumale/svm/clustered_svc.rb
Overview
ClusteredSVC is a class that implements the Clustered Support Vector Classifier.
Reference
- Gu, Q., and Han, J., “Clustered Support Vector Machines,” in Proc. AISTATS’13, pp. 307–315, 2013.
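Reading the `transform` listing under Instance Method Details as equations gives a compact picture of what the classifier does. The notation here is an assumption of this sketch, not the library's: K = `n_clusters`, λ_g = `reg_param_global`, x ∈ R^d a sample (bias column aside), and k(x) ∈ {1, …, K} its assigned cluster, numbered from 1 (the code numbers clusters from 0). Block 0 of the transformed vector is shared by all samples; x is copied into block k(x) and every other block is zero:

```latex
\tilde{\mathbf{x}} =
\begin{bmatrix}
\lambda_g^{-1/2}\,\mathbf{x} \\ \mathbf{0} \\ \vdots \\ \mathbf{x} \\ \vdots \\ \mathbf{0}
\end{bmatrix} \in \mathbb{R}^{d(K+1)},
\qquad
\mathbf{w}^\top \tilde{\mathbf{x}}
  = \lambda_g^{-1/2}\,\mathbf{w}_0^\top \mathbf{x} + \mathbf{w}_{k(\mathbf{x})}^\top \mathbf{x},
\quad
\mathbf{w} = (\mathbf{w}_0, \mathbf{w}_1, \dots, \mathbf{w}_K).
```

A single linear SVM trained on these features therefore learns one global weight vector w_0 plus a local weight vector per cluster, with λ_g controlling how strongly the shared part contributes.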
Instance Attribute Summary
- #cluster_centers ⇒ Numo::DFloat
  Return the centroids.
- #model ⇒ LinearSVC (readonly)
  Return the classifier.
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ ClusteredSVC
  Fit the model with the given training data.
- #initialize(n_clusters: 8, reg_param_global: 1.0, max_iter_kmeans: 100, tol_kmeans: 1e-6, penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0, fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil) ⇒ ClusteredSVC (constructor)
  Create a new classifier with Clustered Support Vector Machine.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #transform(x) ⇒ Numo::DFloat
  Transform the given data with the learned model.
Constructor Details
#initialize(n_clusters: 8, reg_param_global: 1.0, max_iter_kmeans: 100, tol_kmeans: 1e-6, penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0, fit_bias: true, bias_scale: 1.0, tol: 1e-3, verbose: false, random_seed: nil) ⇒ ClusteredSVC
Create a new classifier with Clustered Support Vector Machine.
# File 'lib/rumale/svm/clustered_svc.rb', line 52
def initialize(n_clusters: 8, reg_param_global: 1.0, max_iter_kmeans: 100, tol_kmeans: 1e-6, # rubocop:disable Metrics/ParameterLists
               penalty: 'l2', loss: 'squared_hinge', dual: true, reg_param: 1.0, fit_bias: true, bias_scale: 1.0,
               tol: 1e-3, verbose: false, random_seed: nil)
  super()
  @params = {
    n_clusters: n_clusters,
    reg_param_global: reg_param_global,
    max_iter_kmeans: max_iter_kmeans,
    tol_kmeans: tol_kmeans,
    penalty: penalty == 'l1' ? 'l1' : 'l2',
    loss: loss == 'hinge' ? 'hinge' : 'squared_hinge',
    dual: dual,
    reg_param: reg_param.to_f,
    fit_bias: fit_bias,
    bias_scale: bias_scale.to_f,
    tol: tol.to_f,
    verbose: verbose,
    random_seed: random_seed || Random.rand(4_294_967_295)
  }
  @rng = Random.new(@params[:random_seed])
  @cluster_centers = nil
end
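The constructor does not validate `penalty` or `loss`; it coerces them. The standalone sketch below copies that coercion into a hypothetical helper, `normalize_params` (not part of Rumale), so the fallback behavior can be checked without installing the gem. Only a subset of the keyword arguments is reproduced.

```ruby
# Hypothetical helper mirroring the coercion in ClusteredSVC#initialize.
# Not Rumale API; it only reproduces how a few keyword arguments are normalized.
def normalize_params(penalty: 'l2', loss: 'squared_hinge', reg_param: 1.0,
                     bias_scale: 1.0, tol: 1e-3, random_seed: nil)
  {
    # Anything other than exactly 'l1' silently becomes 'l2'.
    penalty: penalty == 'l1' ? 'l1' : 'l2',
    # Anything other than exactly 'hinge' silently becomes 'squared_hinge'.
    loss: loss == 'hinge' ? 'hinge' : 'squared_hinge',
    # Numeric settings are converted to Float.
    reg_param: reg_param.to_f,
    bias_scale: bias_scale.to_f,
    tol: tol.to_f,
    # Without an explicit seed, one is drawn from [0, 4_294_967_295).
    random_seed: random_seed || Random.rand(4_294_967_295)
  }
end

# The comparison is case-sensitive, so 'L1' is not recognized.
puts normalize_params(penalty: 'L1')[:penalty] # prints l2
```

Because unrecognized values fall back silently, a typo such as `loss: 'hinge '` trains with the squared hinge loss without any warning. Passing `random_seed` explicitly is the way to make the k-means step reproducible.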
Instance Attribute Details
#cluster_centers ⇒ Numo::DFloat
Return the centroids.
# File 'lib/rumale/svm/clustered_svc.rb', line 31
def cluster_centers
  @cluster_centers
end
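The centroids are what `transform` uses to route each sample to a cluster block. The private `assign_cluster_id` method is not shown on this page, so the sketch below assumes the usual k-means rule (nearest centroid by squared Euclidean distance). `nearest_centroid_ids` is a hypothetical name, and plain Ruby arrays stand in for `Numo::DFloat`.

```ruby
# Assumed nearest-centroid rule (k-means style); not Rumale's assign_cluster_id.
# samples: Array of rows, centers: Array of centroids with the same width.
def nearest_centroid_ids(samples, centers)
  samples.map do |row|
    # Squared Euclidean distance from this sample to every centroid.
    dists = centers.map { |c| row.zip(c).sum { |a, b| (a - b)**2 } }
    # Index of the closest centroid (the first one wins on ties).
    dists.each_with_index.min_by { |d, _| d }.last
  end
end

centers = [[0.0, 0.0], [10.0, 10.0]]
samples = [[1.0, -1.0], [9.0, 11.0], [4.0, 4.0]]
p nearest_centroid_ids(samples, centers) # => [0, 1, 0]
```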
#model ⇒ LinearSVC (readonly)
Return the classifier.
# File 'lib/rumale/svm/clustered_svc.rb', line 27
def model
  @model
end
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Calculate confidence scores for samples.
# File 'lib/rumale/svm/clustered_svc.rb', line 90
def decision_function(x)
  z = transform(x)
  @model.decision_function(z)
end
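`decision_function` scores the transformed features with the inner `LinearSVC`. Because of the block layout that `transform` builds, a linear score over the transformed vector splits into a shared part and a cluster-specific part. The sketch below checks this numerically for one sample. The weights are made up, the bias term is ignored, and `dot`/`augment` are hypothetical helpers rather than Rumale API.

```ruby
# Plain dot product of two equal-length arrays.
def dot(a, b)
  a.zip(b).sum { |u, v| u * v }
end

# Feature map for one sample x assigned to cluster k (0-based), no bias column:
# block 0 holds x / sqrt(reg_param_global), block k + 1 holds x, the rest are zero.
def augment(x, k, n_clusters, reg_param_global)
  d = x.size
  z = Array.new(d * (1 + n_clusters), 0.0)
  scale = 1.fdiv(Math.sqrt(reg_param_global))
  x.each_with_index do |v, j|
    z[j] = scale * v         # shared block
    z[d * (k + 1) + j] = v   # block of the assigned cluster
  end
  z
end

x = [1.0, 2.0]
lam = 4.0                                 # plays the role of reg_param_global
w_global = [0.5, -1.0]                    # weights on the shared block
w_clusters = [[1.0, 0.0], [0.0, 3.0]]     # weights on each cluster block
w = w_global + w_clusters.flatten
k = 1

via_z = dot(w, augment(x, k, 2, lam))
via_split = dot(w_global, x) / Math.sqrt(lam) + dot(w_clusters[k], x)
p [via_z, via_split] # => [5.25, 5.25]
```

Both routes give the same score, which is why the model behaves like one global linear classifier corrected by a per-cluster linear term.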
#fit(x, y) ⇒ ClusteredSVC
Fit the model with the given training data.
# File 'lib/rumale/svm/clustered_svc.rb', line 80
def fit(x, y)
  z = transform(x)
  @model = LinearSVC.new(**linear_svc_params).fit(z, y)
  self
end
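As listed, `fit` learns the centroids only indirectly: it calls `transform`, and `transform` runs `clustering` only while `@cluster_centers` is nil. One consequence is worth verifying against your installed version. A second `fit` on different data would reuse the first centroids and refit only the linear model. The toy class below is hypothetical and replaces k-means with a trivial stand-in. It reproduces only that control flow.

```ruby
# Toy model of the control flow in ClusteredSVC#fit/#transform (not Rumale code).
class LazyClusterer
  attr_reader :cluster_centers, :clustering_calls

  def initialize
    @cluster_centers = nil
    @clustering_calls = 0
  end

  # Same guard as the listing: cluster only if no centroids exist yet.
  def transform(x)
    clustering(x) if @cluster_centers.nil?
    x # the real method builds the augmented feature matrix here
  end

  def fit(x)
    transform(x)
    self
  end

  private

  # Stand-in for k-means: take the first sample as the only centroid.
  def clustering(x)
    @clustering_calls += 1
    @cluster_centers = [x.first]
  end
end

m = LazyClusterer.new
m.fit([[0.0, 0.0]])
m.fit([[5.0, 5.0]])
p m.cluster_centers  # => [[0.0, 0.0]]
p m.clustering_calls # => 1
```

Creating a fresh estimator for each new training set avoids relying on this behavior.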
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
# File 'lib/rumale/svm/clustered_svc.rb', line 99
def predict(x)
  z = transform(x)
  @model.predict(z)
end
#transform(x) ⇒ Numo::DFloat
Transform the given data with the learned model.
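# File 'lib/rumale/svm/clustered_svc.rb', line 108
def transform(x)
  # Learn the centroids on the first call only.
  clustering(x) if @cluster_centers.nil?

  cluster_ids = assign_cluster_id(x)

  # A helper call (its name is missing from this listing) presumably appends a bias column here.
  x = (x) if fit_bias?

  n_samples, n_features = x.shape
  # One shared block plus one block per cluster.
  z = Numo::DFloat.zeros(n_samples, n_features * (1 + @params[:n_clusters]))
  # Shared block, scaled by 1 / sqrt(reg_param_global).
  z[true, 0...n_features] = 1.fdiv(Math.sqrt(@params[:reg_param_global])) * x
  @params[:n_clusters].times do |n|
    # Copy each sample into the block of the cluster it is assigned to.
    assigned_bits = cluster_ids.eq(n)
    z[assigned_bits.where, n_features * (n + 1)...n_features * (n + 2)] = x[assigned_bits.where, true]
  end
  z
end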
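The same feature map can be written over plain Ruby arrays, which makes the output layout easy to inspect. This is a sketch under stated assumptions, not Rumale's code. `clustered_features` is a hypothetical name, cluster ids are supplied rather than computed, and, because the bias helper is missing from the listing above, `fit_bias` is assumed to append a constant column equal to `bias_scale`.

```ruby
# Pure-Ruby sketch of the ClusteredSVC feature map (hypothetical helper).
# x: Array of rows (Floats); cluster_ids: 0-based cluster index per row.
# Assumption: fit_bias appends a constant column equal to bias_scale.
def clustered_features(x, cluster_ids, n_clusters:, reg_param_global: 1.0,
                       fit_bias: true, bias_scale: 1.0)
  x = x.map { |row| row + [bias_scale] } if fit_bias
  n_features = x.first.size
  scale = 1.fdiv(Math.sqrt(reg_param_global))
  x.each_with_index.map do |row, i|
    z = Array.new(n_features * (1 + n_clusters), 0.0)
    offset = n_features * (cluster_ids[i] + 1)
    row.each_with_index do |v, j|
      z[j] = scale * v    # shared block, scaled by 1 / sqrt(reg_param_global)
      z[offset + j] = v   # block of the assigned cluster
    end
    z
  end
end

x = [[1.0, 2.0], [3.0, 4.0]]
z = clustered_features(x, [0, 1], n_clusters: 2, reg_param_global: 4.0)
p z.first # => [0.5, 1.0, 0.5, 1.0, 2.0, 1.0, 0.0, 0.0, 0.0]
```

Each output row has (n_features + 1) * (n_clusters + 1) columns when the bias column is added, so memory grows linearly with `n_clusters`.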