Class: Rumale::MetricLearning::FisherDiscriminantAnalysis

Inherits:
Base::Estimator
  • Object
show all
Includes:
Base::Transformer
Defined in:
rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb

Overview

FisherDiscriminantAnalysis is a class that implements Fisher Discriminant Analysis.

Reference

  • Fisher, R. A., “The use of multiple measurements in taxonomic problems,” Annals of Eugenics, vol. 7, pp. 179–188, 1936.

  • Sugiyama, M., “Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction,” Proc. ICML’06, pp. 905–912, 2006.

Examples:

require 'rumale/metric_learning/fisher_discriminant_analysis'

transformer = Rumale::MetricLearning::FisherDiscriminantAnalysis.new
transformer.fit(training_samples, training_labels)
low_samples = transformer.transform(testing_samples)

Instance Attribute Summary collapse

Attributes inherited from Base::Estimator

#params

Instance Method Summary collapse

Constructor Details

#initialize(n_components: nil) ⇒ FisherDiscriminantAnalysis

Create a new transformer with FisherDiscriminantAnalysis.

Parameters:

  • n_components (Integer) (defaults to: nil)

    The number of components. If nil is given, the number of components will be set to [n_features, n_classes - 1].min



44
45
46
47
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 44

# Create a new transformer with FisherDiscriminantAnalysis.
#
# @param n_components [Integer, nil] The number of components. When nil, #fit
#   uses [n_features, n_classes - 1].min.
def initialize(n_components: nil)
  super()
  # Start from a fresh hash so this class alone decides which hyperparameters are stored.
  @params = {}
  @params[:n_components] = n_components
end

Instance Attribute Details

#class_means ⇒ Numo::DFloat (readonly)

Returns the class mean vectors.

Returns:

  • (Numo::DFloat)

    (shape: [n_classes, n_features])



34
35
36
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 34

# Returns the class mean vectors, filled in row by row by #fit.
# @return [Numo::DFloat] (shape: [n_classes, n_features])
def class_means
  @class_means
end

#classes ⇒ Numo::Int32 (readonly)

Returns the class labels.

Returns:

  • (Numo::Int32)

    (shape: [n_classes])



38
39
40
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 38

# Returns the class labels (the sorted unique labels seen by #fit).
# @return [Numo::Int32] (shape: [n_classes])
def classes
  @classes
end

#components ⇒ Numo::DFloat (readonly)

Returns the transform matrix.

Returns:

  • (Numo::DFloat)

    (shape: [n_components, n_features])



26
27
28
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 26

# Returns the transform matrix learned by #fit.
# @return [Numo::DFloat] (shape: [n_components, n_features])
def components
  @components
end

#mean ⇒ Numo::DFloat (readonly)

Returns the mean vector.

Returns:

  • (Numo::DFloat)

    (shape: [n_features])



30
31
32
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 30

# Returns the mean vector of the training samples given to #fit.
# @return [Numo::DFloat] (shape: [n_features])
def mean
  @mean
end

Instance Method Details

#fit(x, y) ⇒ FisherDiscriminantAnalysis

Fit the model with given training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used for fitting the model.

Returns:

  • (FisherDiscriminantAnalysis)

    The learned transformer itself.



54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 54

# Fit the model with given training data.
#
# Accumulates the within-class and between-class scatter matrices, then solves the
# generalized symmetric eigenproblem (between, within) and keeps the eigenvectors
# belonging to the largest eigenvalues as the transform matrix.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [FisherDiscriminantAnalysis] The learned transformer itself.
# @raise [RuntimeError] if enable_linalg? reports that Numo::Linalg is not loaded.
def fit(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)
  unless enable_linalg?(warning: false)
    raise 'FisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
  end

  # initialize some variables.
  n_features = x.shape[1]
  @classes = Numo::Int32[*y.to_a.uniq.sort]
  n_classes = @classes.size
  # An explicitly requested number of components is capped at n_features.
  n_components = if @params[:n_components].nil?
                   [n_features, n_classes - 1].min
                 else
                   [n_features, @params[:n_components]].min
                 end

  # calculate within and between scatter matrices.
  within_mat = Numo::DFloat.zeros(n_features, n_features)
  between_mat = Numo::DFloat.zeros(n_features, n_features)
  @class_means = Numo::DFloat.zeros(n_classes, n_features)
  @mean = x.mean(0)
  @classes.each_with_index do |label, i|
    mask_vec = y.eq(label)
    sz_class = mask_vec.count
    class_samples = x[mask_vec, true]
    class_mean = class_samples.mean(0)
    # Compute each difference once; both scatter updates use it twice.
    centered = class_samples - class_mean
    mean_diff = class_mean - @mean
    within_mat += centered.transpose.dot(centered)
    # Outer product of mean_diff with itself via broadcasting ([n_features, 1] * [n_features]).
    between_mat += sz_class * mean_diff.expand_dims(1) * mean_diff
    @class_means[i, true] = class_mean
  end

  # calculate components.
  # vals_range selects the n_components largest eigenvalues (indexed in ascending order),
  # so the columns are reversed to put the most discriminative direction first.
  _, evecs = Numo::Linalg.eigh(between_mat, within_mat, vals_range: (n_features - n_components)...n_features)
  comps = evecs.reverse(1).transpose.dup
  # A single component is stored as a vector; comps is already a fresh copy otherwise.
  @components = n_components == 1 ? comps[0, true].dup : comps
  self
end

#fit_transform(x, y) ⇒ Numo::DFloat

Fit the model with training data, and then transform them with the learned model.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used for fitting the model.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_components]) The transformed data



99
100
101
102
103
104
105
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 99

# Fit the model with training data, and then transform them with the learned model.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
def fit_transform(x, y)
  samples = ::Rumale::Validation.check_convert_sample_array(x)
  labels = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(samples, labels)

  # Learn from the validated data, then project that same data.
  fitted = fit(samples, labels)
  fitted.transform(samples)
end

#transform(x) ⇒ Numo::DFloat

Transform the given data with the learned model.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The data to be transformed with the learned model.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_components]) The transformed data.



111
112
113
114
115
# File 'rumale-metric_learning/lib/rumale/metric_learning/fisher_discriminant_analysis.rb', line 111

# Transform the given data with the learned model.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
# @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
def transform(x)
  samples = ::Rumale::Validation.check_convert_sample_array(x)
  # Components are stored one per row, so transpose before projecting.
  projection = @components.transpose
  samples.dot(projection)
end