Class: Rumale::MetricLearning::FisherDiscriminantAnalysis
- Inherits: Base::Estimator (Object > Base::Estimator > Rumale::MetricLearning::FisherDiscriminantAnalysis)
- Includes: Base::Transformer
- Defined in: rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb
Overview
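FisherDiscriminantAnalysis implements Fisher Discriminant Analysis (FDA), a supervised linear dimensionality reduction method. It learns projection directions that maximize the scatter between classes relative to the scatter within each class, and then maps samples onto those directions.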
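The quantities that #fit computes can be written as follows. The notation is ours, chosen to match the variables in the #fit source: class c has n_c samples and mean vector m_c, and m is the mean of all samples. within_mat corresponds to S_W and between_mat to S_B. The rows of #components are the generalized eigenvectors w with the largest eigenvalues λ, sorted in descending order.

```latex
\begin{aligned}
S_W &= \sum_{c=1}^{C} \sum_{i:\,y_i = c} (\mathbf{x}_i - \mathbf{m}_c)(\mathbf{x}_i - \mathbf{m}_c)^\top, \\
S_B &= \sum_{c=1}^{C} n_c\,(\mathbf{m}_c - \mathbf{m})(\mathbf{m}_c - \mathbf{m})^\top, \\
S_B \mathbf{w} &= \lambda\, S_W \mathbf{w}.
\end{aligned}
```

The weighted differences n_c (m_c - m) sum to zero, so S_B has rank at most C - 1 and at most C - 1 eigenvalues are nonzero. This is why n_components defaults to min(n_features, n_classes - 1). A generalized symmetric eigensolver also expects S_W to be positive definite. Strongly collinear features, or too few samples relative to the number of features, can therefore make the fit fail.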
Reference
- Fisher, R. A., “The use of multiple measurements in taxonomic problems,” Annals of Eugenics, vol. 7, pp. 179–188, 1936.
- Sugiyama, M., “Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction,” Proc. ICML’06, pp. 905–912, 2006.
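Instance Attribute Summary
- #class_means ⇒ Numo::DFloat (readonly)
  Returns the class mean vectors (shape: [n_classes, n_features]).
- #classes ⇒ Numo::Int32 (readonly)
  Returns the class labels, sorted in ascending order.
- #components ⇒ Numo::DFloat (readonly)
  Returns the transform matrix (shape: [n_components, n_features], or [n_features] when n_components is 1).
- #mean ⇒ Numo::DFloat (readonly)
  Returns the mean vector of all training samples (shape: [n_features]).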
Attributes inherited from Base::Estimator
Instance Method Summary
- #fit(x, y) ⇒ FisherDiscriminantAnalysis
  Fit the model with the given training data.
- #fit_transform(x, y) ⇒ Numo::DFloat
  Fit the model with training data, and then transform them with the learned model.
- #initialize(n_components: nil) ⇒ FisherDiscriminantAnalysis (constructor)
  Create a new transformer with FisherDiscriminantAnalysis.
- #transform(x) ⇒ Numo::DFloat
  Transform the given data with the learned model.
Constructor Details
#initialize(n_components: nil) ⇒ FisherDiscriminantAnalysis
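Create a new transformer with FisherDiscriminantAnalysis.
Parameters:
- n_components (Integer, nil): The number of components. If nil, #fit uses min(n_features, n_classes - 1); otherwise the value is capped at n_features.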
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 44

def initialize(n_components: nil)
  super()
  @params = { n_components: n_components }
end
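The constructor only stores n_components. The default is resolved later, inside #fit. The helper below, resolve_n_components, is hypothetical and not part of Rumale's API. It reproduces that branch of #fit in plain Ruby so the rule can be checked by hand.

```ruby
# Hypothetical helper (not part of Rumale) reproducing how #fit resolves
# the number of components from the constructor's n_components parameter.
def resolve_n_components(n_features, n_classes, n_components = nil)
  if n_components.nil?
    # Default: at most n_classes - 1 discriminant directions.
    [n_features, n_classes - 1].min
  else
    # An explicit value is capped only at the number of features.
    [n_features, n_components].min
  end
end

resolve_n_components(4, 3)      # => 2
resolve_n_components(4, 3, 10)  # => 4
resolve_n_components(2, 5)      # => 2
```

An explicit n_components larger than n_classes - 1 is accepted. Because S_B has rank at most n_classes - 1, the extra directions carry no between-class information.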
Instance Attribute Details
#class_means ⇒ Numo::DFloat (readonly)
Returns the class mean vectors.
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 34

def class_means
  @class_means
end
#classes ⇒ Numo::Int32 (readonly)
Returns the class labels.
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 38

def classes
  @classes
end
#components ⇒ Numo::DFloat (readonly)
Returns the transform matrix.
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 26

def components
  @components
end
#mean ⇒ Numo::DFloat (readonly)
Returns the mean vector.
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 30

def mean
  @mean
end
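The sketch below shows how #classes, #class_means and #mean relate to one another. It uses plain Ruby with nested Arrays instead of Numo arrays, and the name fitted_statistics is hypothetical. Row i of the class means belongs to the i-th sorted label. The overall mean averages all samples, so it differs from the average of the class means when classes are unbalanced.

```ruby
# Hypothetical plain-Ruby sketch (not Rumale code) of the statistics that
# #fit stores in @classes, @class_means and @mean. Rows are Arrays of Floats.
def fitted_statistics(x, y)
  column_mean = ->(rows) { rows.transpose.map { |col| col.sum / rows.size.to_f } }
  # Like Numo::Int32[*y.to_a.uniq.sort]: unique labels in ascending order.
  classes = y.uniq.sort
  # Row i holds the mean of the samples labeled classes[i].
  class_means = classes.map do |label|
    column_mean.call(x.zip(y).select { |_, l| l == label }.map(&:first))
  end
  # Like x.mean(0): the mean over all samples, not over the class means.
  { classes: classes, class_means: class_means, mean: column_mean.call(x) }
end

x = [[1.0, 2.0], [3.0, 2.0], [5.0, 0.0], [5.0, 2.0], [5.0, 4.0]]
y = [1, 1, 0, 0, 0]
stats = fitted_statistics(x, y)
stats[:classes]      # => [0, 1]
stats[:class_means]  # => [[5.0, 2.0], [2.0, 2.0]]
stats[:mean]         # => [3.8, 2.0]
```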
Instance Method Details
#fit(x, y) ⇒ FisherDiscriminantAnalysis
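Fit the model with the given training data.
Parameters:
- x (Numo::DFloat): (shape: [n_samples, n_features]) The training data.
- y (Numo::Int32): (shape: [n_samples]) The class labels of the training data.
Returns:
- (FisherDiscriminantAnalysis): The learned transformer itself.
Raises a RuntimeError if Numo::Linalg is not loaded.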
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 54

def fit(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)
  unless enable_linalg?(warning: false)
    raise 'FisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
  end

  # initialize some variables.
  n_features = x.shape[1]
  @classes = Numo::Int32[*y.to_a.uniq.sort]
  n_classes = @classes.size
  n_components = if @params[:n_components].nil?
                   [n_features, n_classes - 1].min
                 else
                   [n_features, @params[:n_components]].min
                 end

  # calculate within and between scatter matrices.
  within_mat = Numo::DFloat.zeros(n_features, n_features)
  between_mat = Numo::DFloat.zeros(n_features, n_features)
  @class_means = Numo::DFloat.zeros(n_classes, n_features)
  @mean = x.mean(0)
  @classes.each_with_index do |label, i|
    mask_vec = y.eq(label)
    sz_class = mask_vec.count
    class_samples = x[mask_vec, true]
    class_mean = class_samples.mean(0)
    within_mat += (class_samples - class_mean).transpose.dot(class_samples - class_mean)
    # outer product via broadcasting: [n_features, 1] * [n_features] => [n_features, n_features]
    between_mat += sz_class * (class_mean - @mean).expand_dims(1) * (class_mean - @mean)
    @class_means[i, true] = class_mean
  end

  # calculate components.
  _, evecs = Numo::Linalg.eigh(between_mat, within_mat, vals_range: (n_features - n_components)...n_features)
  comps = evecs.reverse(1).transpose.dup
  @components = n_components == 1 ? comps[0, true].dup : comps.dup
  self
end
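The scatter-matrix loop in #fit can be mirrored in plain Ruby to check small cases by hand. This sketch uses nested Arrays in place of Numo::DFloat. The helper names (vec_sub, outer, add!, column_mean, scatter_matrices) are hypothetical. The sketch stops before the eigen-decomposition, which Rumale delegates to Numo::Linalg.eigh.

```ruby
# Hypothetical plain-Ruby sketch (not Rumale code) of the within_mat and
# between_mat computation in #fit. Samples are Arrays of Floats.
def vec_sub(a, b)
  a.zip(b).map { |ai, bi| ai - bi }
end

# Outer product u v^T as a nested Array.
def outer(u, v)
  u.map { |ui| v.map { |vj| ui * vj } }
end

# Adds matrix b into matrix a in place and returns a.
def add!(a, b)
  a.each_index { |i| a[i].each_index { |j| a[i][j] += b[i][j] } }
  a
end

def column_mean(rows)
  rows.transpose.map { |col| col.sum / rows.size.to_f }
end

# Returns [within, between]:
#   within  += (x_i - m_c)(x_i - m_c)^T for every sample x_i of class c
#   between += n_c * (m_c - m)(m_c - m)^T for every class c
def scatter_matrices(x, y)
  d = x.first.size
  within = Array.new(d) { Array.new(d, 0.0) }
  between = Array.new(d) { Array.new(d, 0.0) }
  total_mean = column_mean(x)
  y.uniq.sort.each do |label|
    samples = x.zip(y).select { |_, l| l == label }.map(&:first)
    class_mean = column_mean(samples)
    samples.each do |s|
      diff = vec_sub(s, class_mean)
      add!(within, outer(diff, diff))
    end
    dm = vec_sub(class_mean, total_mean)
    add!(between, outer(dm, dm).map { |row| row.map { |v| v * samples.size } })
  end
  [within, between]
end

x = [[1.0, 2.0], [3.0, 2.0], [5.0, 0.0], [5.0, 2.0]]
y = [0, 0, 1, 1]
within, between = scatter_matrices(x, y)
within   # => [[2.0, 0.0], [0.0, 2.0]]
between  # => [[9.0, -3.0], [-3.0, 1.0]]
```

With only two classes, S_B has rank one. The single discriminant direction is then proportional to S_W^{-1}(m_1 - m_0), which is (3, -1) in this example. The component Rumale learns should match it up to scale and sign, because eigh normalizes the eigenvectors.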
#fit_transform(x, y) ⇒ Numo::DFloat
Fit the model with training data, and then transform them with the learned model. This is equivalent to calling fit(x, y).transform(x).
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 99

def fit_transform(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)

  fit(x, y).transform(x)
end
#transform(x) ⇒ Numo::DFloat
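Transform the given data with the learned model. Samples are projected with x.dot(components.transpose) and are not centered by #mean beforehand.
Parameters:
- x (Numo::DFloat): (shape: [n_samples, n_features]) The data to be transformed.
Returns:
- (Numo::DFloat): (shape: [n_samples, n_components], or [n_samples] when n_components is 1) The transformed data.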
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 111

def transform(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  x.dot(@components.transpose)
end
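#transform is a plain matrix product: each sample is dotted with each row of #components. The sketch below uses the hypothetical name project and nested Arrays in place of Numo arrays. It covers both shapes that #components can take: a matrix, or a flat vector when n_components is 1.

```ruby
# Hypothetical plain-Ruby sketch (not Rumale code) of x.dot(components.transpose).
# components is either an Array of rows (n_components x n_features) or, as in
# Rumale's n_components == 1 case, a single flat row of length n_features.
def project(x, components)
  dot = ->(u, v) { u.zip(v).sum { |a, b| a * b } }
  if components.first.is_a?(Array)
    # One output column per component: shape [n_samples, n_components].
    x.map { |row| components.map { |c| dot.call(row, c) } }
  else
    # Single component: one value per sample, shape [n_samples].
    x.map { |row| dot.call(row, components) }
  end
end

x = [[1.0, 2.0], [3.0, 4.0]]
project(x, [[1.0, 0.0], [0.0, 1.0]])  # => [[1.0, 2.0], [3.0, 4.0]]
project(x, [1.5, -0.5])               # => [0.5, 2.5]
```

Like the Rumale method, this sketch does not subtract the training mean. Projections of new samples are therefore directly comparable with projections of the training data.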