Class: Rumale::MetricLearning::LocalFisherDiscriminantAnalysis
- Inherits: Base::Estimator
  (Rumale::MetricLearning::LocalFisherDiscriminantAnalysis < Base::Estimator < Object)
- Includes: Base::Transformer
- Defined in: rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb
Overview
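LocalFisherDiscriminantAnalysis is a transformer that implements Local Fisher Discriminant Analysis (LFDA), a supervised linear dimensionality reduction method. It learns a projection matrix by solving a generalized eigenvalue problem between a locally weighted mixture scatter matrix and a locally weighted within-class scatter matrix, where the local weights are pairwise affinities computed with an RBF kernel. Numo::Linalg must be loaded to call #fit.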
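The quantities computed by #fit can be summarized as follows. This summary is reconstructed from the #fit source shown under Instance Method Details, not taken from Rumale's own documentation. Here $A$ is the RBF affinity matrix, $n$ the number of samples, $n_l$ the number of samples in class $l$, and $X$ the $n \times d$ data matrix whose rows are the samples $x_i$.

```latex
\begin{aligned}
A_{ij} &= \exp\bigl(-\gamma \lVert x_i - x_j \rVert^2\bigr), \\
W^{(w)}_{ij} &= \begin{cases} A_{ij}/n_l & \text{if } y_i = y_j = l, \\ 0 & \text{otherwise,} \end{cases} \\
W^{(m)}_{ij} &= \begin{cases} A_{ij}/n & \text{if } y_i = y_j, \\ 1/n & \text{otherwise,} \end{cases} \\
L^{(\cdot)} &= \operatorname{diag}\Bigl(\textstyle\sum_j W^{(\cdot)}_{ij}\Bigr) - W^{(\cdot)}, \\
S^{(w)} &= X^\top L^{(w)} X, \qquad S^{(m)} = X^\top L^{(m)} X, \\
S^{(m)} v &= \lambda\, S^{(w)} v .
\end{aligned}
```

The rows of #components are the generalized eigenvectors $v$ for the n_components largest eigenvalues $\lambda$, ordered from largest to smallest. Because the mixture weights equal the local between-class weights plus the within-class weights, $S^{(m)} = S^{(b)} + S^{(w)}$, so $S^{(b)} v = (\lambda - 1) S^{(w)} v$ and the eigenvectors are the same as those of the between-class versus within-class problem.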
Reference
- Sugiyama, M., “Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction,” Proc. ICML’06, pp. 905–912, 2006.
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Returns the class labels.
- #components ⇒ Numo::DFloat (readonly)
  Returns the transform matrix.
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x, y) ⇒ LocalFisherDiscriminantAnalysis
  Fit the model with given training data.
- #fit_transform(x, y) ⇒ Numo::DFloat
  Fit the model with training data, and then transform them with the learned model.
- #initialize(n_components: nil, gamma: nil) ⇒ LocalFisherDiscriminantAnalysis (constructor)
  Create a new transformer with LocalFisherDiscriminantAnalysis.
- #transform(x) ⇒ Numo::DFloat
  Transform the given data with the learned model.
Constructor Details
#initialize(n_components: nil, gamma: nil) ⇒ LocalFisherDiscriminantAnalysis
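Create a new transformer with LocalFisherDiscriminantAnalysis.
- n_components (Integer, nil): the number of components. If nil, #fit uses the number of features.
- gamma (Float, nil): the parameter of the RBF kernel used to compute the affinity matrix. If nil, #fit sets it to 1 / n_features.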
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 35

def initialize(n_components: nil, gamma: nil)
  super()
  @params = {
    n_components: n_components,
    gamma: gamma
  }
end
```
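The constructor only stores its arguments; the defaults are resolved later inside #fit. The helper below, resolve_lfda_params, is hypothetical (it is not part of Rumale) and simply mirrors that default logic in plain Ruby for a data set with n_features columns.

```ruby
# Hypothetical helper (not part of Rumale) mirroring how #fit resolves the
# constructor defaults for a data set with n_features columns.
def resolve_lfda_params(n_features, n_components: nil, gamma: nil)
  {
    n_components: n_components || n_features, # nil => keep every feature
    gamma: gamma || 1.fdiv(n_features)         # nil => 1 / n_features
  }
end

defaults = resolve_lfda_params(4)
custom = resolve_lfda_params(4, n_components: 2, gamma: 0.1)
puts "default: n_components=#{defaults[:n_components]}, gamma=#{defaults[:gamma]}"
# prints "default: n_components=4, gamma=0.25"
puts "custom:  n_components=#{custom[:n_components]}, gamma=#{custom[:gamma]}"
# prints "custom:  n_components=2, gamma=0.1"
```

One difference worth noting in the original: #fit applies `||=` to `@params[:gamma]`, so the resolved gamma is written back into the stored parameters, whereas n_components is resolved into a local variable and `@params[:n_components]` stays nil.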
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Returns the class labels, i.e. the sorted unique values of the labels given to #fit.
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 29

def classes
  @classes
end
```
#components ⇒ Numo::DFloat (readonly)
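Returns the transform matrix, with shape [n_components, n_features]. When n_components is 1, it is a 1-D array of length n_features.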
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 25

def components
  @components
end
```
Instance Method Details
#fit(x, y) ⇒ LocalFisherDiscriminantAnalysis
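Fit the model with given training data.
- x (Numo::DFloat): the training data, with shape [n_samples, n_features].
- y (Numo::Int32): the class labels, with shape [n_samples].
- Returns self (the fitted LocalFisherDiscriminantAnalysis).
Requires Numo::Linalg; if it is not loaded, #fit raises a RuntimeError.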
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 48

def fit(x, y)
  unless enable_linalg?(warning: false)
    raise 'LocalFisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
  end

  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_label_array(y)
  Rumale::Validation.check_sample_size(x, y)

  # initialize some variables.
  n_samples, n_features = x.shape
  @classes = Numo::Int32[*y.to_a.uniq.sort]
  n_components = @params[:n_components] || n_features
  @params[:gamma] ||= 1.fdiv(n_features)
  affinity_mat = Rumale::PairwiseMetric.rbf_kernel(x, nil, @params[:gamma])
  affinity_mat[affinity_mat.diag_indices] = 1.0

  # calculate within-class and mixture weight matrices.
  class_mat = Numo::DFloat.zeros(n_samples, n_samples)
  within_weight_mat = Numo::DFloat.zeros(n_samples, n_samples)
  @classes.each do |label|
    pos = y.eq(label)
    n_class_samples = pos.count
    pos_vec = Numo::DFloat.cast(pos)
    pos_mat = pos_vec.outer(pos_vec)
    class_mat += pos_mat
    within_weight_mat += pos_mat * 1.fdiv(n_class_samples)
  end

  # mixture weights: A_ij / n for same-class pairs, 1 / n otherwise.
  mixture_weight_mat = ((affinity_mat - 1) / n_samples) * class_mat + 1.fdiv(n_samples)
  # within weights: A_ij / n_l for same-class pairs, 0 otherwise.
  within_weight_mat *= affinity_mat
  # turn both weight matrices into graph Laplacians (D - W).
  mixture_weight_mat = mixture_weight_mat.sum(axis: 1).diag - mixture_weight_mat
  within_weight_mat = within_weight_mat.sum(axis: 1).diag - within_weight_mat

  # calculate components.
  mixture_mat = x.transpose.dot(mixture_weight_mat.dot(x))
  within_mat = x.transpose.dot(within_weight_mat.dot(x))
  # generalized eigenproblem mixture_mat v = lambda * within_mat v, restricted to
  # the n_components largest eigenvalues (eigh returns them in ascending order).
  _, evecs = Numo::Linalg.eigh(mixture_mat, within_mat, vals_range: (n_features - n_components)...n_features)
  # reverse the columns so the largest eigenvalue comes first; rows are components.
  comps = evecs.reverse(1).transpose.dup
  @components = n_components == 1 ? comps[0, true].dup : comps.dup
  self
end
```
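To follow the weight-matrix step without Numo, here is a minimal pure-Ruby sketch of what #fit appears to compute before the eigenvalue solve: RBF affinities, the within-class and mixture weights, their graph Laplacians, and the scatter matrices X^T L X. All method names (rbf_affinity, lfda_weights, laplacian, scatter) are hypothetical, arrays of arrays stand in for Numo::DFloat, and the generalized eigenproblem itself is left out; the original delegates it to Numo::Linalg.eigh.

```ruby
# Pure-Ruby sketch of the weight and scatter matrices built in #fit.
# Hypothetical helpers; arrays of arrays stand in for Numo::DFloat.

# RBF affinity: A[i][j] = exp(-gamma * ||x_i - x_j||^2).
def rbf_affinity(x, gamma)
  x.map do |xi|
    x.map do |xj|
      sq_dist = xi.zip(xj).sum { |a, b| (a - b)**2 }
      Math.exp(-gamma * sq_dist)
    end
  end
end

# Within-class weights: A[i][j] / n_l for same-class pairs, 0 otherwise.
# Mixture weights:      A[i][j] / n   for same-class pairs, 1 / n otherwise.
def lfda_weights(x, y, gamma)
  n = x.size
  aff = rbf_affinity(x, gamma)
  class_sizes = y.tally # label => number of samples with that label
  within = Array.new(n) { Array.new(n, 0.0) }
  mixture = Array.new(n) { Array.new(n, 1.0 / n) }
  n.times do |i|
    n.times do |j|
      next unless y[i] == y[j]

      within[i][j] = aff[i][j] / class_sizes[y[i]]
      mixture[i][j] = aff[i][j] / n
    end
  end
  [within, mixture]
end

# Graph Laplacian D - W, where D is the diagonal matrix of row sums of W.
def laplacian(w)
  w.each_with_index.map do |row, i|
    row_sum = row.sum
    row.each_with_index.map { |v, j| (i == j ? row_sum : 0.0) - v }
  end
end

# Scatter matrix X^T L X, of size n_features x n_features.
def scatter(x, lap)
  d = x.first.size
  Array.new(d) do |p|
    Array.new(d) do |q|
      total = 0.0
      x.each_index do |i|
        x.each_index { |j| total += x[i][p] * lap[i][j] * x[j][q] }
      end
      total
    end
  end
end

x = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]
y = [0, 0, 1]
within, mixture = lfda_weights(x, y, 0.5)
within_mat = scatter(x, laplacian(within))
mixture_mat = scatter(x, laplacian(mixture))
```

Building the Laplacian first and then forming X^T L X matches the order used in #fit; it equals one half of the sum over all pairs of W_ij (x_i - x_j)(x_i - x_j)^T, which is how the local scatter matrices are usually written.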
#fit_transform(x, y) ⇒ Numo::DFloat
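Fit the model with training data, and then transform them with the learned model. This is equivalent to fit(x, y).transform(x) and returns a Numo::DFloat of shape [n_samples, n_components], or [n_samples] when n_components is 1.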
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 96

def fit_transform(x, y)
  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_label_array(y)
  Rumale::Validation.check_sample_size(x, y)

  fit(x, y).transform(x)
end
```
#transform(x) ⇒ Numo::DFloat
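Transform the given data with the learned model. The input x (Numo::DFloat) has shape [n_samples, n_features]; the result is x multiplied by the transpose of #components, with shape [n_samples, n_components], or [n_samples] when n_components is 1.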
```ruby
# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 108

def transform(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  x.dot(@components.transpose)
end
```
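The projection in #transform is a single matrix product. The sketch below uses a hypothetical project helper with plain Ruby arrays in place of Numo::DFloat, and also reproduces the one-component case, where #fit stores the components as a flat vector and the result is therefore a flat array with one value per sample.

```ruby
# Hypothetical pure-Ruby stand-in for #transform: returns x . components^T.
def project(x, components)
  dot = ->(u, v) { u.zip(v).sum { |a, b| a * b } }
  if components.first.is_a?(Array)
    # components has n_components rows of length n_features.
    x.map { |xi| components.map { |c| dot.call(xi, c) } }
  else
    # n_components == 1: #fit stores a flat vector, so each sample maps to a scalar.
    x.map { |xi| dot.call(xi, components) }
  end
end

x = [[1.0, 2.0], [3.0, 4.0]]
p project(x, [[1.0, 0.0], [0.5, 0.5]]) # => [[1.0, 1.5], [3.0, 3.5]]
p project(x, [0.0, 1.0])               # => [2.0, 4.0]
```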