Class: Rumale::MetricLearning::NeighbourhoodComponentAnalysis
- Inherits: Base::Estimator
  - Object
    - Base::Estimator
      - Rumale::MetricLearning::NeighbourhoodComponentAnalysis
- Includes: Base::Transformer
- Defined in: rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb
Overview
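NeighbourhoodComponentAnalysis is a transformer that implements Neighbourhood Component Analysis (NCA). NCA is a supervised metric learning method: it learns a linear transformation of the feature space from labelled data so that nearest-neighbour classification works well in the transformed space.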
Reference
- Goldberger, J., Roweis, S., Hinton, G., and Salakhutdinov, R., “Neighbourhood Component Analysis,” Advances in Neural Information Processing Systems 17 (NIPS 2004), pp. 513–520, 2005.
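The objective described in the reference can be illustrated in plain Ruby. The sketch follows the formulation of Goldberger et al. Each sample picks another sample as its neighbour, with probabilities given by a softmax over negative squared distances in the projected space. The objective is the expected number of samples whose picked neighbour shares their label. Plain Arrays stand in for Numo::DFloat, and the method names are hypothetical. This is not Rumale's own (private) optimization code.

```ruby
# Minimal sketch of the NCA objective from Goldberger et al. (2005):
#   p_ij = exp(-||A x_i - A x_j||^2) / sum_{k != i} exp(-||A x_i - A x_k||^2), p_ii = 0
#   f(A) = sum_i sum_{j : y_j == y_i} p_ij
# Plain Arrays stand in for Numo::DFloat; all method names here are hypothetical.
# No log-sum-exp stabilization is applied, so this is meant for small toy inputs.

# Multiply matrix a (rows x cols) by vector v (length cols).
def mat_vec(a, v)
  a.map { |row| row.zip(v).sum { |r, e| r * e } }
end

def squared_distance(u, v)
  u.zip(v).sum { |a, b| (a - b)**2 }
end

# a: transformation with shape [n_components, n_features]
# x: samples with shape [n_samples, n_features]; y: labels of length n_samples
def nca_objective(a, x, y)
  z = x.map { |xi| mat_vec(a, xi) } # project every sample
  n = z.size
  n.times.sum do |i|
    # Unnormalized similarity of sample i to every other sample (itself excluded).
    weights = Array.new(n) { |k| k == i ? 0.0 : Math.exp(-squared_distance(z[i], z[k])) }
    total = weights.sum
    # Probability that sample i picks a neighbour sharing its label.
    n.times.select { |j| j != i && y[j] == y[i] }.sum { |j| weights[j] / total }
  end
end

x = [[0.0, 0.0], [0.1, 0.0], [3.0, 0.0], [3.1, 0.0]]
y = [0, 0, 1, 1]
puts nca_objective([[1.0, 0.0], [0.0, 1.0]], x, y) # close to 4, the maximum for 4 samples
puts nca_objective([[0.0, 0.0], [0.0, 0.0]], x, y) # 4/3: every neighbour is equally likely
```

A larger objective means that same-label samples are more likely to be picked as neighbours under the transformation.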
Instance Attribute Summary
- #components ⇒ Numo::DFloat (readonly)
  Returns the neighbourhood components.
- #n_iter ⇒ Integer (readonly)
  Returns the number of iterations run for optimization.
- #rng ⇒ Random (readonly)
  Returns the random generator.
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x, y) ⇒ NeighbourhoodComponentAnalysis
  Fit the model to the given training data.
- #fit_transform(x, y) ⇒ Numo::DFloat
  Fit the model to the training data, then transform that data with the learned model.
- #initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil) ⇒ NeighbourhoodComponentAnalysis (constructor)
  Create a new transformer with NeighbourhoodComponentAnalysis.
- #transform(x) ⇒ Numo::DFloat
  Transform the given data with the learned model.
Constructor Details
#initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil) ⇒ NeighbourhoodComponentAnalysis
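Create a new transformer with NeighbourhoodComponentAnalysis. The keyword arguments are stored in the estimator's parameters. When random_seed is nil, the value returned by srand is used as the seed, and that seed initializes the random generator exposed as rng.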
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb', line 49
def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil)
  super()
  @params = {
    n_components: n_components,
    init: init,
    max_iter: max_iter,
    tol: tol,
    verbose: verbose,
    random_seed: random_seed || srand
  }
  @rng = Random.new(@params[:random_seed])
end
```
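The constructor stores random_seed || srand and seeds rng with it. Kernel#srand called without an argument reseeds Ruby's global generator and returns the previous seed. A seed value is therefore always recorded, and passing an explicit random_seed makes the draws reproducible. The following minimal sketch shows that behaviour using core Ruby only. build_rng is a hypothetical helper and not part of Rumale.

```ruby
# Hypothetical helper mirroring two lines of the constructor:
#   random_seed: random_seed || srand
#   @rng = Random.new(@params[:random_seed])
def build_rng(random_seed = nil)
  seed = random_seed || srand # srand with no argument returns the previous global seed
  [seed, Random.new(seed)]
end

# An explicit seed makes the generator's draws reproducible.
_, rng_a = build_rng(42)
_, rng_b = build_rng(42)
draws_a = Array.new(3) { rng_a.rand }
draws_b = Array.new(3) { rng_b.rand }
p draws_a == draws_b # prints true

# Without a seed, a value is still recorded and can be reused later.
seed, rng_c = build_rng
rng_d = Random.new(seed)
p rng_c.rand == rng_d.rand # prints true
```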
Instance Attribute Details
#components ⇒ Numo::DFloat (readonly)
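Returns the neighbourhood components, a matrix of shape [n_components, n_features]. #transform projects samples by multiplying them with its transpose.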
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb', line 29
def components
  @components
end
```
#n_iter ⇒ Integer (readonly)
Returns the number of iterations run for optimization.
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb', line 33
def n_iter
  @n_iter
end
```
#rng ⇒ Random (readonly)
Returns the random generator.
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb', line 37
def rng
  @rng
end
```
Instance Method Details
#fit(x, y) ⇒ NeighbourhoodComponentAnalysis
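Fit the model to the given training data. x holds the samples (one row per sample) and y holds the corresponding labels. Both are validated and must contain the same number of samples. When n_components is nil, the number of components equals the number of features. Otherwise, it is the smaller of n_components and the number of features. Returns the fitted transformer itself.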
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb', line 67
def fit(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)

  n_features = x.shape[1]
  n_components = if @params[:n_components].nil?
                   n_features
                 else
                   [n_features, @params[:n_components]].min
                 end
  @components, @n_iter = optimize_components(x, y, n_features, n_components)
  self
end
```
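The choice of output dimensionality in fit can be isolated into a small function. This sketch copies only that branch. resolve_n_components is a hypothetical name, and the optimization itself (the private optimize_components) is not reproduced.

```ruby
# Hypothetical standalone version of the n_components branch in #fit.
def resolve_n_components(requested, n_features)
  requested.nil? ? n_features : [n_features, requested].min
end

p resolve_n_components(nil, 5) # prints 5 (nil falls back to the number of features)
p resolve_n_components(2, 5)   # prints 2
p resolve_n_components(8, 5)   # prints 5 (requests above n_features are capped)
```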
#fit_transform(x, y) ⇒ Numo::DFloat
Fit the model to the training data, then transform that data with the learned model. This is equivalent to calling fit(x, y).transform(x).
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb', line 87
def fit_transform(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)

  fit(x, y).transform(x)
end
```
#transform(x) ⇒ Numo::DFloat
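Transform the given data with the learned model by projecting each sample onto the neighbourhood components, that is, by multiplying x by the transpose of components.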
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/neighbourhood_component_analysis.rb', line 99
def transform(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  x.dot(@components.transpose)
end
```
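transform is a single matrix product. Samples of shape [n_samples, n_features] times the transposed components of shape [n_features, n_components] give a result of shape [n_samples, n_components]. The sketch below reproduces that product with plain Arrays instead of Numo::DFloat. It assumes components is laid out as [n_components, n_features], and the component values are made up rather than learned.

```ruby
# Plain-Array version of x.dot(@components.transpose):
#   z[i][c] = sum over features f of x[i][f] * components[c][f]
def project(x, components)
  x.map { |row| components.map { |comp| row.zip(comp).sum { |a, b| a * b } } }
end

x = [[1.0, 2.0, 3.0],   # 3 samples x 3 features
     [0.0, 1.0, 0.0],
     [2.0, 0.0, 1.0]]
components = [[1.0, 0.0, 0.0], # 2 components x 3 features (illustrative values, not learned)
              [0.0, 0.5, 0.5]]

p project(x, components) # prints [[1.0, 2.5], [0.0, 0.5], [2.0, 0.5]]
```

Each output row has one entry per component. With fewer components than features, transform therefore also reduces dimensionality.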