Class: Rumale::NeuralNetwork::RBFClassifier
- Inherits: BaseRBF (hierarchy: Object > Base::Estimator > BaseRBF > Rumale::NeuralNetwork::RBFClassifier)
- Includes: Base::Classifier
- Defined in: rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb
Overview
RBFClassifier implements a classifier based on (k-means) radial basis function (RBF) networks.
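To make the overview concrete, here is a plain-Ruby sketch of what the hidden layer of such a network computes for one sample. It assumes a Gaussian kernel exp(-gamma * ||x - c||^2) for each center c and, when normalization is on, divides the activations by their sum (the normalized RBF network of Bugmann). The helpers `squared_distance` and `rbf_hidden` are hypothetical; this is not Rumale's implementation and uses no Numo.

```ruby
# Squared Euclidean distance between two feature vectors (Arrays of Floats).
def squared_distance(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Hidden-layer output for one sample: one activation per center.
# ASSUMPTION: Gaussian kernel exp(-gamma * ||x - c||^2); with normalize: true the
# activations are divided by their sum (normalized RBF). Not Rumale's actual code.
def rbf_hidden(x, centers, gamma, normalize: false)
  h = centers.map { |c| Math.exp(-gamma * squared_distance(x, c)) }
  return h unless normalize

  total = h.sum # can underflow to 0.0 for samples far from every center
  h.map { |v| v / total }
end

centers = [[0.0, 0.0], [1.0, 1.0]]
h = rbf_hidden([0.0, 0.0], centers, 0.5)                   # [1.0, Math.exp(-1.0)]
hn = rbf_hidden([0.0, 0.0], centers, 0.5, normalize: true) # same ratios, sums to 1.0
```

A sample lying exactly on a center gets activation 1.0 from that center, and the activation decays with squared distance at a rate set by gamma.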
References
- Bugmann, G., “Normalized Gaussian Radial Basis Function networks,” Neurocomputing, vol. 20, pp. 97–110, 1998.
- Que, Q., and Belkin, M., “Back to the Future: Radial Basis Function Networks Revisited,” Proc. of AISTATS’16, pp. 1375–1383, 2016.
Instance Attribute Summary
- #centers ⇒ Numo::DFloat (readonly)
  Return the centers in the hidden layer of the RBF network.
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #rng ⇒ Random (readonly)
  Return the random generator.
- #weight_vec ⇒ Numo::DFloat (readonly)
  Return the weight vector.
Attributes inherited from Base::Estimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ RBFClassifier
  Fit the model with the given training data.
- #initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil) ⇒ RBFClassifier (constructor)
  Create a new classifier with (k-means) RBF networks.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil) ⇒ RBFClassifier
Create a new classifier with (k-means) RBF networks.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 51
    def initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false,
                   max_iter: 50, tol: 1e-4, random_seed: nil)
      super
    end
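The constructor only forwards its keyword arguments to BaseRBF through super. Judging by the names, hidden_units sets the number of centers, and max_iter, tol, and random_seed control the k-means search for them; that reading is an assumption, as is the random-sample initialization used below. A minimal Lloyd's k-means sketch in plain Ruby (the helpers `sq_dist` and `kmeans_centers` are hypothetical):

```ruby
# Squared Euclidean distance between two points (Arrays of Floats).
def sq_dist(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Hypothetical Lloyd's k-means for choosing RBF centers.
# ASSUMPTION: hidden_units = number of centers; max_iter and tol bound the
# iterations; random_seed seeds the choice of initial centers (random samples).
def kmeans_centers(samples, hidden_units:, max_iter: 50, tol: 1e-4, random_seed: 1)
  rng = Random.new(random_seed)
  centers = samples.sample(hidden_units, random: rng).map(&:dup)
  max_iter.times do
    # Assignment step: group samples by their nearest center.
    clusters = Array.new(hidden_units) { [] }
    samples.each do |s|
      nearest = (0...hidden_units).min_by { |k| sq_dist(s, centers[k]) }
      clusters[nearest] << s
    end
    # Update step: move each center to its cluster mean (keep it if the cluster is empty).
    new_centers = clusters.each_with_index.map do |members, k|
      next centers[k] if members.empty?

      (0...members.first.size).map { |d| members.sum { |m| m[d] } / members.size.to_f }
    end
    # Stop once the total squared movement of the centers falls below tol.
    shift = centers.zip(new_centers).sum { |c, n| sq_dist(c, n) }
    centers = new_centers
    break if shift < tol
  end
  centers
end

points = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
centers = kmeans_centers(points, hidden_units: 2, random_seed: 1)
```

With these four points the two centers settle at the means of the two obvious clusters, whichever pair of samples is drawn first.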
Instance Attribute Details
#centers ⇒ Numo::DFloat (readonly)
Return the centers in the hidden layer of the RBF network.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 32
    def centers
      @centers
    end
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 28
    def classes
      @classes
    end
#rng ⇒ Random (readonly)
Return the random generator.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 40
    def rng
      @rng
    end
#weight_vec ⇒ Numo::DFloat (readonly)
Return the weight vector.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 36
    def weight_vec
      @weight_vec
    end
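Because decision_function computes hidden_output(x).dot(@weight_vec) and predict reads the result as an [n_samples, n_classes] matrix, weight_vec holds one column of output weights per class. The source shown here does not reveal how partial_fit computes it. A common choice for RBF networks is regularized least squares against the one-hot targets, W = (H^T H + lam * I)^-1 H^T Y; the sketch below implements that textbook form in plain Ruby. Whether BaseRBF uses reg_param as lam, as its reciprocal, or in some other way is not confirmed here, so `lam` and the helpers `matmul`, `solve`, and `ridge_weights` are hypothetical.

```ruby
# Matrix product of nested Arrays: (n x k) * (k x m) -> (n x m).
def matmul(a, b)
  bt = b.transpose
  a.map { |row| bt.map { |col| row.zip(col).sum { |x, y| x * y } } }
end

# Solve A * X = B by Gauss-Jordan elimination with partial pivoting.
# A is n x n, B is n x m; returns X as an n x m nested Array.
def solve(a, b)
  n = a.size
  aug = a.each_with_index.map { |row, i| row.map(&:to_f) + b[i].map(&:to_f) }
  n.times do |col|
    # Swap in the row with the largest pivot to limit rounding error.
    pivot = (col...n).max_by { |r| aug[r][col].abs }
    aug[col], aug[pivot] = aug[pivot], aug[col]
    div = aug[col][col]
    aug[col] = aug[col].map { |v| v / div }
    # Eliminate this column from every other row.
    n.times do |r|
      next if r == col

      factor = aug[r][col]
      aug[r] = aug[r].zip(aug[col]).map { |v, p| v - factor * p }
    end
  end
  aug.map { |row| row[n..] }
end

# Textbook ridge least squares for output weights: W = (H^T H + lam * I)^-1 H^T Y.
# H: n_samples x n_centers hidden outputs; Y: n_samples x n_classes one-hot targets.
# ASSUMPTION: lam is a hypothetical stand-in; how BaseRBF uses reg_param is not shown.
def ridge_weights(h, y, lam)
  ht = h.transpose
  gram = matmul(ht, h)
  gram.each_with_index { |row, i| row[i] += lam } # add lam to the diagonal
  solve(gram, matmul(ht, y))
end

w = ridge_weights([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], 1.0)
```

With H = I and lam = 1.0 the system is 2I * W = I, so every weight is shrunk to half of its unregularized value.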
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Calculate confidence scores for samples. The result has one row per sample and one column per class.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 78
    def decision_function(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)

      h = hidden_output(x)
      h.dot(@weight_vec)
    end
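decision_function is a linear readout: each row of the hidden-layer output h (one activation per center) is multiplied by weight_vec (one column per class). The plain-Ruby sketch below spells out that product with nested arrays instead of Numo; `decision_scores` is a hypothetical name.

```ruby
# Linear readout of an RBF network, as in decision_function: scores = h . w.
# h: n_samples x n_centers hidden activations; w: n_centers x n_classes weights.
def decision_scores(h, w)
  columns = w.transpose # one weight column per class
  h.map { |row| columns.map { |col| row.zip(col).sum { |a, b| a * b } } }
end

h = [[1.0, 0.5], [0.2, 1.0]]
w = [[1.0, 0.0], [0.0, 2.0]]
scores = decision_scores(h, w) # => [[1.0, 1.0], [0.2, 2.0]]
```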
#fit(x, y) ⇒ RBFClassifier
Fit the model with the given training data: x holds the samples and y the class labels. Raises an error if Numo::Linalg is not loaded; returns self.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 61
    def fit(x, y)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      y = ::Rumale::Validation.check_convert_label_array(y)
      ::Rumale::Validation.check_sample_size(x, y)
      raise 'RBFClassifier#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?(warning: false)

      @classes = Numo::NArray[*y.to_a.uniq.sort]

      partial_fit(x, one_hot_encode(y))

      self
    end
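fit stores the sorted unique labels in @classes and passes one_hot_encode(y) to partial_fit. Since predict maps a column index back through @classes, column k of the targets must stand for @classes[k]. The sketch below reproduces that encoding in plain Ruby; the 0.0/1.0 coding and the helper names `class_list` and `one_hot` are assumptions, because the private one_hot_encode is not shown.

```ruby
# Class list as built in RBFClassifier#fit: unique labels in sorted order.
def class_list(labels)
  labels.uniq.sort
end

# Hypothetical one-hot encoder: row i has 1.0 in the column of labels[i]'s class
# and 0.0 elsewhere. ASSUMPTION: Rumale's private one_hot_encode may use another coding.
def one_hot(labels, classes)
  column_of = classes.each_with_index.to_h # class label => column index
  labels.map do |label|
    row = Array.new(classes.size, 0.0)
    row[column_of.fetch(label)] = 1.0
    row
  end
end

y = [3, 1, 3, 2]
classes = class_list(y)       # => [1, 2, 3]
targets = one_hot(y, classes)
# targets == [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```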
#predict(x) ⇒ Numo::Int32
Predict class labels for samples: each sample gets the label in #classes whose confidence score is highest.

    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 89
    def predict(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)

      scores = decision_function(x)
      n_samples, n_classes = scores.shape
      label_ids = scores.max_index(axis: 1) - Numo::Int32.new(n_samples).seq * n_classes
      @classes[label_ids].dup
    end
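The subtraction in predict is there because Numo's max_index(axis: 1) returns flat, row-major indices: the winning column j of row i is reported as i * n_classes + j, so subtracting Numo::Int32.new(n_samples).seq * n_classes recovers j. The plain-Ruby sketch below (hypothetical `predict_labels`) mimics that arithmetic on nested arrays.

```ruby
# Mimics the index arithmetic of RBFClassifier#predict on nested Arrays.
# scores: n_samples x n_classes; classes: labels in column order.
def predict_labels(scores, classes)
  n_classes = scores.first.size
  scores.each_with_index.map do |row, i|
    # Flat row-major index of the row maximum, the form max_index(axis: 1) returns.
    flat = i * n_classes + row.index(row.max)
    # Subtract the row offset (i * n_classes) to recover the column index.
    classes[flat - i * n_classes]
  end
end

scores = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]]
labels = predict_labels(scores, [1, 2, 3]) # => [2, 1]
```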