Class: Rumale::NearestNeighbors::KNeighborsClassifier
- Inherits: Base::Estimator (Object > Base::Estimator > Rumale::NearestNeighbors::KNeighborsClassifier)
- Includes: Base::Classifier
- Defined in: rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb
Overview
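KNeighborsClassifier implements a classifier based on the k-nearest neighbors rule: each sample is assigned the class that occurs most often among its k nearest training samples. Neighbors are found with the Euclidean distance by default. Alternatively, with metric: 'precomputed', the methods accept a distance matrix instead of feature vectors.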
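To make the two input modes concrete, here is a minimal plain-Ruby sketch (no Numo, no Rumale) of the n_testing_samples-by-n_training_samples Euclidean distance matrix that the 'precomputed' metric expects. The helper euclidean_distance_matrix is hypothetical; Rumale itself uses ::Rumale::PairwiseMetric.euclidean_distance, whose floating-point results may differ slightly from this direct formula.

```ruby
# Hypothetical helper (not part of Rumale): builds the
# n_testing_samples-by-n_training_samples Euclidean distance matrix,
# i.e. the layout the 'precomputed' metric expects.
def euclidean_distance_matrix(test_samples, train_samples)
  test_samples.map do |a|
    train_samples.map do |b|
      # Square root of the sum of squared coordinate differences.
      Math.sqrt(a.zip(b).sum { |ai, bi| (ai - bi)**2 })
    end
  end
end

train = [[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]]
test  = [[0.0, 0.0], [3.0, 0.0]]
p euclidean_distance_matrix(test, train) # → [[0.0, 5.0, 1.0], [3.0, 4.0, 2.0]]
```

Row i holds the distances from test sample i to every training sample. Under the assumption above, passing such a matrix with metric: 'precomputed' should select the same neighbors as metric: 'euclidean' on the raw features.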
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #labels ⇒ Numo::Int32 (readonly)
  Return the labels of the prototypes.
- #prototypes ⇒ Numo::DFloat (readonly)
  Return the prototypes for the nearest neighbor classifier.
Attributes inherited from Base::Estimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ KNeighborsClassifier
  Fit the model with the given training data.
- #initialize(n_neighbors: 5, metric: 'euclidean') ⇒ KNeighborsClassifier (constructor)
  Create a new classifier with the nearest neighbor rule.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(n_neighbors: 5, metric: 'euclidean') ⇒ KNeighborsClassifier
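Create a new classifier with the k-nearest neighbors rule. n_neighbors is the number of neighbors that vote on each prediction (capped at the number of training samples when scoring). metric selects how distances are obtained: 'euclidean' (the default) computes Euclidean distances from feature vectors, while 'precomputed' makes fit and predict expect distance matrices. Any other value is treated as 'euclidean'.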
# File 'rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 44

def initialize(n_neighbors: 5, metric: 'euclidean')
  super()
  @params = {
    n_neighbors: n_neighbors,
    metric: (metric == 'precomputed' ? 'precomputed' : 'euclidean')
  }
end
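The constructor's parameter handling can be mirrored in plain Ruby. knn_params is a hypothetical helper, not part of Rumale. It shows that an unsupported metric such as 'manhattan' is not rejected but silently stored as 'euclidean'.

```ruby
# Hypothetical stand-in (not Rumale code) mirroring the constructor's
# parameter handling: any metric other than 'precomputed' falls back to
# 'euclidean', and n_neighbors is stored unchanged.
def knn_params(n_neighbors: 5, metric: 'euclidean')
  {
    n_neighbors: n_neighbors,
    metric: (metric == 'precomputed' ? 'precomputed' : 'euclidean')
  }
end

p knn_params                                        # defaults
p knn_params(n_neighbors: 3, metric: 'precomputed') # distance-matrix mode
p knn_params(metric: 'manhattan')                   # silently becomes 'euclidean'
```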
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels (the sorted unique labels seen in fit).
# File 'rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 36

def classes
  @classes
end
#labels ⇒ Numo::Int32 (readonly)
Return the labels of the prototypes (the training labels passed to fit).
# File 'rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 32

def labels
  @labels
end
#prototypes ⇒ Numo::DFloat (readonly)
Return the prototypes (the training samples stored by fit) for the nearest neighbor classifier. If the metric is 'precomputed', this returns nil.
# File 'rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 28

def prototypes
  @prototypes
end
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
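Calculate confidence scores for samples. x is an n_samples-by-n_features array or, when the metric is 'precomputed', an n_samples-by-n_training_samples distance matrix. The result is an n_samples-by-n_classes matrix whose entry for a class counts how many of the sample's nearest neighbors carry that label.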
# File 'rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 77

def decision_function(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  if @params[:metric] == 'precomputed' && x.shape[1] != @labels.size
    raise ArgumentError, 'Expect the size input matrix to be n_testing_samples-by-n_training_samples.'
  end

  n_prototypes = @labels.size
  n_neighbors = [@params[:n_neighbors], n_prototypes].min
  n_samples = x.shape[0]
  n_classes = @classes.size
  scores = Numo::DFloat.zeros(n_samples, n_classes)

  distance_matrix = @params[:metric] == 'precomputed' ? x : ::Rumale::PairwiseMetric.euclidean_distance(x, @prototypes)
  n_samples.times do |m|
    neighbor_ids = distance_matrix[m, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
    neighbor_ids.each { |n| scores[m, @classes.to_a.index(@labels[n])] += 1.0 }
  end

  scores
end
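The vote counting in decision_function can be sketched with plain Ruby arrays in place of Numo. knn_scores is a hypothetical helper that takes an already computed distance matrix (as in the 'precomputed' case) and follows the same steps as the source above: cap k at the number of training samples, sort each row's [distance, index] pairs, and add one vote per neighbor to its class column.

```ruby
# Hypothetical plain-Ruby sketch (no Numo) of the vote counting done by
# decision_function. `distance_matrix` is n_samples-by-n_training_samples,
# `labels` holds one label per training sample, `classes` is labels.uniq.sort.
def knn_scores(distance_matrix, labels, classes, n_neighbors)
  # Never ask for more neighbors than there are training samples.
  k = [n_neighbors, labels.size].min
  distance_matrix.map do |row|
    scores = Array.new(classes.size, 0.0)
    # Sorting [distance, index] pairs orders by distance, breaking ties by index.
    neighbor_ids = row.each_with_index.sort.map(&:last)[0...k]
    neighbor_ids.each { |n| scores[classes.index(labels[n])] += 1.0 }
    scores
  end
end

dist = [[0.0, 5.0, 1.0, 2.0],
        [3.0, 0.2, 2.0, 0.5]]
labels  = [1, 2, 1, 2]
classes = labels.uniq.sort
p knn_scores(dist, labels, classes, 3) # → [[2.0, 1.0], [1.0, 2.0]]
```

With k = 3, the first sample's nearest training samples are 0, 2 and 3 (labels 1, 1, 2), giving scores [2.0, 1.0].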
#fit(x, y) ⇒ KNeighborsClassifier
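Fit the model with the given training data. x holds the training samples (n_samples-by-n_features) or, when the metric is 'precomputed', a square n_training_samples-by-n_training_samples distance matrix; y holds their labels. Returns self.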
# File 'rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 58

def fit(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)
  if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
    raise ArgumentError, 'Expect the input distance matrix to be square.'
  end

  @prototypes = x.dup if @params[:metric] == 'euclidean'
  @labels = Numo::Int32.asarray(y.to_a)
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  self
end
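What fit stores can likewise be sketched with plain arrays. knn_fit is hypothetical and returns a Hash instead of setting instance variables. Its sample-size error message is made up, while the square-matrix check mirrors the one in the source.

```ruby
# Hypothetical sketch (not Rumale's API) of what fit stores, using plain
# Ruby arrays: labels are copied, classes are the sorted unique labels,
# and prototypes are kept only for the 'euclidean' metric.
def knn_fit(x, y, metric: 'euclidean')
  # Made-up message; Rumale delegates this check to Validation.check_sample_size.
  raise ArgumentError, 'x and y must have the same number of samples.' unless x.size == y.size
  if metric == 'precomputed' && x.any? { |row| row.size != x.size }
    raise ArgumentError, 'Expect the input distance matrix to be square.'
  end

  {
    prototypes: (metric == 'euclidean' ? x.map(&:dup) : nil),
    labels: y.dup,
    classes: y.uniq.sort
  }
end

model = knn_fit([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]], [3, 1, 3])
p model[:classes] # → [1, 3]
```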
#predict(x) ⇒ Numo::Int32
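Predict class labels for samples. x has the same layout as for decision_function, and each sample receives the class with the highest score in decision_function's output.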
# File 'rumale-nearest_neighbors/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 103

def predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  if @params[:metric] == 'precomputed' && x.shape[1] != @labels.size
    raise ArgumentError, 'Expect the size input matrix to be n_samples-by-n_training_samples.'
  end

  decision_values = decision_function(x)
  n_samples = x.shape[0]
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
end
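predict then reduces each score row to a single label. The hypothetical helper below uses Array#index(row.max), which returns the first position of the maximum, so a tie goes to the smallest label because classes is sorted ascending. That Numo's max_index breaks ties the same way is an assumption here, not something the source states.

```ruby
# Hypothetical plain-Ruby sketch of predict's final step: pick, for each
# row of scores, the class with the highest vote count. Array#index returns
# the first position of the maximum, so ties go to the earliest (smallest)
# class in the sorted `classes` array. Assumed, not confirmed, to match
# Numo's max_index tie-breaking.
def knn_predict_from_scores(scores, classes)
  scores.map { |row| classes[row.index(row.max)] }
end

scores  = [[2.0, 1.0], [1.0, 2.0], [1.0, 1.0]]
classes = [1, 2]
p knn_predict_from_scores(scores, classes) # → [1, 2, 1]
```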