Class: Rumale::MetricLearning::LocalFisherDiscriminantAnalysis

Inherits:
Base::Estimator
  • Object
Includes:
Base::Transformer
Defined in:
rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb

Overview

LocalFisherDiscriminantAnalysis is a class that implements Local Fisher Discriminant Analysis (LFDA), a supervised linear dimensionality reduction method.
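As a compact summary of what #fit computes, the quantities below are read off its source listing (the notation is ours, not the library's). X holds the samples x_i as rows, y_i is the label of x_i, n is n_samples, and n_l is the number of samples with label l. The affinity here is the RBF kernel with parameter gamma, as in the code; the cited paper also considers other affinities, such as local scaling.

```latex
A_{ij} = \exp\bigl(-\gamma \lVert x_i - x_j \rVert^2\bigr)

W^{(w)}_{ij} =
  \begin{cases} A_{ij} / n_l & \text{if } y_i = y_j = l \\ 0 & \text{if } y_i \neq y_j \end{cases}
\qquad
W^{(m)}_{ij} =
  \begin{cases} A_{ij} / n & \text{if } y_i = y_j \\ 1 / n & \text{if } y_i \neq y_j \end{cases}

S^{(\ast)} = X^\top \bigl(D^{(\ast)} - W^{(\ast)}\bigr) X
  = \frac{1}{2} \sum_{i,j} W^{(\ast)}_{ij} (x_i - x_j)(x_i - x_j)^\top,
\qquad D^{(\ast)} = \operatorname{diag}\Bigl(\sum_j W^{(\ast)}_{ij}\Bigr),
\quad \ast \in \{w, m\}

S^{(m)} v = \lambda \, S^{(w)} v
```

The rows of components are the generalized eigenvectors v belonging to the n_components largest eigenvalues, in decreasing order. Using the mixture matrix S^(m) instead of the paper's local between-class matrix S^(b) yields the same directions: the local between-class weights are W^(b)_ij = A_ij (1/n - 1/n_l) when y_i = y_j = l and 1/n otherwise, so W^(m) = W^(b) + W^(w) and S^(m) = S^(b) + S^(w). Hence S^(m) v = λ S^(w) v holds exactly when S^(b) v = (λ - 1) S^(w) v.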

Reference

  • Sugiyama, M., “Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction,” Proc. ICML’06, pp. 905–912, 2006.

Examples:

require 'numo/linalg/autoloader' # LocalFisherDiscriminantAnalysis#fit requires Numo::Linalg.
require 'rumale/metric_learning/local_fisher_discriminant_analysis'

transformer = Rumale::MetricLearning::LocalFisherDiscriminantAnalysis.new
transformer.fit(training_samples, training_labels)
low_samples = transformer.transform(testing_samples)

Instance Attribute Summary

Attributes inherited from Base::Estimator

#params

Instance Method Summary

Constructor Details

#initialize(n_components: nil, gamma: nil) ⇒ LocalFisherDiscriminantAnalysis

Create a new transformer with Local Fisher Discriminant Analysis.

Parameters:

  • n_components (Integer) (defaults to: nil)

    The number of components (the dimensionality of the transformed data). If nil, it is set to the number of features.

  • gamma (Float) (defaults to: nil)

    The parameter of the RBF kernel used to compute the affinity matrix. If nil, it is set to 1 / n_features.



# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 35

def initialize(n_components: nil, gamma: nil)
  super()
  @params = {
    n_components: n_components,
    gamma: gamma
  }
end
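The constructor only stores the two parameters; their defaults are resolved later, inside #fit, from the shape of the training data. The sketch below mirrors those two lines of #fit in plain Ruby. The helper name resolve_lfda_params is hypothetical and not part of Rumale.

```ruby
# Hypothetical helper (not part of Rumale) mirroring how #fit resolves defaults.
def resolve_lfda_params(params, n_features)
  # n_components falls back to n_features; params itself is left untouched.
  n_components = params[:n_components] || n_features
  # gamma is assigned with ||=, so the resolved value is written back into params.
  params[:gamma] ||= 1.fdiv(n_features)
  [n_components, params[:gamma]]
end

params = { n_components: nil, gamma: nil }
n_components, gamma = resolve_lfda_params(params, 4)
p n_components   # => 4
p gamma          # => 0.25
p params[:gamma] # => 0.25
```

Judging from the source listing, because gamma is stored back into params, refitting the same instance on data with a different number of features would reuse the first resolved gamma, whereas n_components is recomputed on every call. It may be worth checking this against the Rumale version you use.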

Instance Attribute Details

#classes ⇒ Numo::Int32 (readonly)

Returns the class labels.

Returns:

  • (Numo::Int32)

    (shape: [n_classes])



# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 29

def classes
  @classes
end
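In #fit, classes is built as Numo::Int32[*y.to_a.uniq.sort], that is, the distinct labels in ascending order. The plain-Ruby illustration below uses Arrays in place of Numo::Int32 and also shows the per-class sizes n_l that #fit uses while looping over the classes.

```ruby
# Plain-Ruby illustration (Arrays instead of Numo::Int32) of how #fit derives
# the class labels: the distinct values of y in ascending order.
labels = [2, 0, 1, 0, 2]
classes = labels.uniq.sort
p classes # => [0, 1, 2]

# #fit then visits the classes in this order, using each class size n_l.
counts = classes.map { |label| labels.count(label) }
p counts # => [2, 1, 2]
```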

#components ⇒ Numo::DFloat (readonly)

Returns the transformation matrix.

Returns:

  • (Numo::DFloat)

    (shape: [n_components, n_features]). If n_components is 1, it is a vector of shape [n_features].



# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 25

def components
  @components
end

Instance Method Details

#fit(x, y) ⇒ LocalFisherDiscriminantAnalysis

Fit the model with given training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used for fitting the model.

Returns:

  • (LocalFisherDiscriminantAnalysis)

    The learned transformer itself.



# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 48

def fit(x, y)
  unless enable_linalg?(warning: false)
    raise 'LocalFisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.'
  end

  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_label_array(y)
  Rumale::Validation.check_sample_size(x, y)

  # initialize some variables.
  n_samples, n_features = x.shape
  @classes = Numo::Int32[*y.to_a.uniq.sort]
  n_components = @params[:n_components] || n_features
  @params[:gamma] ||= 1.fdiv(n_features)
  affinity_mat = Rumale::PairwiseMetric.rbf_kernel(x, nil, @params[:gamma])
  affinity_mat[affinity_mat.diag_indices] = 1.0

  # calculate within-class and mixture weight matrices.
  class_mat = Numo::DFloat.zeros(n_samples, n_samples)
  within_weight_mat = Numo::DFloat.zeros(n_samples, n_samples)
  @classes.each do |label|
    pos = y.eq(label)
    n_class_samples = pos.count
    pos_vec = Numo::DFloat.cast(pos)
    pos_mat = pos_vec.outer(pos_vec)
    class_mat += pos_mat
    within_weight_mat += pos_mat * 1.fdiv(n_class_samples)
  end

  mixture_weight_mat = ((affinity_mat - 1) / n_samples) * class_mat + 1.fdiv(n_samples)
  within_weight_mat *= affinity_mat
  mixture_weight_mat = mixture_weight_mat.sum(axis: 1).diag - mixture_weight_mat
  within_weight_mat = within_weight_mat.sum(axis: 1).diag - within_weight_mat

  # calculate components.
  mixture_mat = x.transpose.dot(mixture_weight_mat.dot(x))
  within_mat = x.transpose.dot(within_weight_mat.dot(x))
  _, evecs = Numo::Linalg.eigh(mixture_mat, within_mat, vals_range: (n_features - n_components)...n_features)
  comps = evecs.reverse(1).transpose.dup
  @components = n_components == 1 ? comps[0, true].dup : comps.dup
  self
end
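The weight-matrix part of #fit can be hard to follow in Numo's vectorized form. Below is a plain-Ruby sketch, assuming Arrays of Arrays in place of Numo::DFloat, that builds the same RBF affinity, within-class and mixture weights, graph Laplacians D - W, and scatter matrices X^T (D - W) X. All method names are hypothetical, not Rumale API. The final step, the generalized symmetric eigenproblem that the library hands to Numo::Linalg.eigh, is omitted because plain Ruby has no such solver.

```ruby
# Plain-Ruby sketch (no Numo) of the matrices #fit builds before eigh.
# All method names here are hypothetical, not part of Rumale.

# RBF affinity: A[i][j] = exp(-gamma * ||x_i - x_j||^2), so the diagonal is 1.
def rbf_affinity(x, gamma)
  x.map do |xi|
    x.map do |xj|
      sq_dist = xi.zip(xj).sum { |a, b| (a - b)**2 }
      Math.exp(-gamma * sq_dist)
    end
  end
end

# Within-class weights: A[i][j] / n_l if y_i == y_j == l, else 0.
# Mixture weights:      A[i][j] / n   if y_i == y_j,      else 1 / n.
def lfda_weights(affinity, y)
  n = y.size
  counts = y.tally # number of samples per label
  within = Array.new(n) do |i|
    Array.new(n) { |j| y[i] == y[j] ? affinity[i][j] / counts[y[i]] : 0.0 }
  end
  mixture = Array.new(n) do |i|
    Array.new(n) { |j| y[i] == y[j] ? affinity[i][j] / n : 1.0 / n }
  end
  [within, mixture]
end

# Graph Laplacian L = D - W, where D holds the row sums of W on its diagonal.
def laplacian(w)
  w.each_with_index.map do |row, i|
    row_sum = row.sum
    row.each_with_index.map { |v, j| (i == j ? row_sum : 0.0) - v }
  end
end

# Scatter matrix S = X^T L X, of shape [n_features, n_features].
def scatter(x, lap)
  n_features = x.first.size
  Array.new(n_features) do |r|
    Array.new(n_features) do |c|
      total = 0.0
      x.each_index do |i|
        x.each_index { |j| total += x[i][r] * lap[i][j] * x[j][c] }
      end
      total
    end
  end
end

# Two well-separated classes with two features each.
x = [[0.0, 1.0], [0.5, 1.5], [3.0, 0.0], [3.5, 0.5]]
y = [0, 0, 1, 1]
gamma = 1.fdiv(x.first.size) # the default: 1 / n_features
w_within, w_mixture = lfda_weights(rbf_affinity(x, gamma), y)
s_within = scatter(x, laplacian(w_within))
s_mixture = scatter(x, laplacian(w_mixture))
```

Working with the Laplacian rather than the pairwise sum is what lets the library write each scatter matrix as two matrix products; for a symmetric W, X^T (D - W) X equals one half of the sum over i, j of W_ij (x_i - x_j)(x_i - x_j)^T.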

#fit_transform(x, y) ⇒ Numo::DFloat

Fit the model with training data, and then transform them with the learned model.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used for fitting the model.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_components]) The transformed data.



# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 96

def fit_transform(x, y)
  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_label_array(y)
  Rumale::Validation.check_sample_size(x, y)

  fit(x, y).transform(x)
end

#transform(x) ⇒ Numo::DFloat

Transform the given data with the learned model.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The data to be transformed with the learned model.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_components]) The transformed data. If n_components is 1, the shape is [n_samples].



# File 'rumale-metric_learning/lib/rumale/metric_learning/local_fisher_discriminant_analysis.rb', line 108

def transform(x)
  x = Rumale::Validation.check_convert_sample_array(x)

  x.dot(@components.transpose)
end
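The transform is a single matrix product, x.dot(@components.transpose): each sample is projected onto every learned direction. The plain-Ruby sketch below spells that product out with Arrays; the helper name project and the component values are made up for illustration and stand in for the matrix that #fit learns.

```ruby
# Plain-Ruby sketch of #transform: X * components^T, one dot product per
# (sample, component) pair. project is a hypothetical helper, not Rumale API.
def project(x, components)
  x.map do |xi|
    components.map { |comp| xi.zip(comp).sum { |a, b| a * b } }
  end
end

x = [[1.0, 2.0, 3.0], [0.0, -1.0, 4.0]]            # shape: [n_samples, n_features]
components = [[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]]    # shape: [n_components, n_features]
p project(x, components) # => [[1.0, 2.5], [0.0, 1.5]]
```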