Class: Rumale::NeuralNetwork::RBFClassifier
- Inherits: BaseRBF
  (Rumale::NeuralNetwork::RBFClassifier < BaseRBF < Base::Estimator < Object)
- Includes: Base::Classifier
- Defined in: rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb
Overview
RBFClassifier is a class that implements a classifier based on radial basis function (RBF) networks whose hidden-layer centers are determined by k-means.
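As a rough guide to what the attributes hold, here is the usual form of a Gaussian RBF network classifier. This is a sketch under assumptions, since the kernel is not stated on this page: the hidden units are taken to be Gaussian with width parameter γ (presumably `gamma`), the M centers c_j to be `centers` (M presumably `hidden_units`), and W to be `weight_vec` with one column per class.

```latex
\begin{aligned}
h_j(\mathbf{x}) &= \exp\!\left(-\gamma \,\lVert \mathbf{x} - \mathbf{c}_j \rVert^2\right), && j = 1, \dots, M \\
s_k(\mathbf{x}) &= \sum_{j=1}^{M} h_j(\mathbf{x})\, W_{jk}, && k = 1, \dots, K \\
\hat{y}(\mathbf{x}) &= \mathrm{classes}\!\left[\arg\max_{k} s_k(\mathbf{x})\right]
\end{aligned}
```

Under this reading, s_k corresponds to the scores returned by `decision_function`, and the final line to `predict`.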
Reference
- Bugmann, G., “Normalized Gaussian Radial Basis Function networks,” Neurocomputing, vol. 20, pp. 97–110, 1998.
- Que, Q., and Belkin, M., “Back to the Future: Radial Basis Function Networks Revisited,” Proc. of AISTATS’16, pp. 1375–1383, 2016.
Instance Attribute Summary
- #centers ⇒ Numo::DFloat (readonly)
  Return the centers in the hidden layer of the RBF network.
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #rng ⇒ Random (readonly)
  Return the random generator.
- #weight_vec ⇒ Numo::DFloat (readonly)
  Return the weight vector.
Attributes inherited from Base::Estimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ RBFClassifier
  Fit the model with the given training data.
- #initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil) ⇒ RBFClassifier (constructor)
  Create a new classifier with (k-means) RBF networks.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil) ⇒ RBFClassifier
Create a new classifier with (k-means) RBF networks.
```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 53
def initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil)
  super
end
```
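The `normalize:` flag presumably switches to the normalized Gaussian RBF network of Bugmann (1998), in which each hidden activation is divided by the sum of that sample's activations so every row of the hidden layer sums to one. That mapping is an assumption, as is the Gaussian kernel exp(-gamma * ||x - c||^2); the plain-Ruby sketch below only illustrates the normalization itself and is not Rumale's code.

```ruby
# Gaussian RBF activation of sample x for center c (assumed kernel form).
def rbf(x, c, gamma)
  sq_dist = x.zip(c).sum { |a, b| (a - b)**2 }
  Math.exp(-gamma * sq_dist)
end

# Hidden-layer outputs, one row per sample and one column per center.
# With normalize: true, each row is divided by its sum (normalized Gaussian RBF).
def hidden_output(samples, centers, gamma, normalize: false)
  samples.map do |x|
    row = centers.map { |c| rbf(x, c, gamma) }
    next row unless normalize

    total = row.sum
    row.map { |v| v / total }
  end
end

samples = [[0.0, 0.0], [1.0, 1.0], [3.0, 0.5]]
centers = [[0.0, 0.0], [2.0, 1.0]]
h = hidden_output(samples, centers, 1.0, normalize: true)
h.each { |row| puts row.map { |v| v.round(4) }.inspect }
```

If every activation in a row underflows to zero (a sample far from all centers with a large `gamma`), this naive division yields NaN, so a real implementation would need to guard against it.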
Instance Attribute Details
#centers ⇒ Numo::DFloat (readonly)
Return the centers in the hidden layer of the RBF network.

```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 34
def centers
  @centers
end
```
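Per the overview, the centers come from k-means. The sketch below is plain Lloyd's k-means on Ruby arrays: assign each sample to its nearest center, move each center to the mean of its samples, and stop when the centers barely move. Reusing the names `max_iter`, `tol`, and a seeded `Random` (cf. `rng`) is an assumption about how the constructor arguments are used; this is not Rumale's implementation.

```ruby
# Squared Euclidean distance between two points.
def sq_dist(a, b)
  a.zip(b).sum { |p, q| (p - q)**2 }
end

# Lloyd's k-means returning k centers. The max_iter/tol/rng roles are assumed.
def kmeans_centers(samples, k, max_iter: 50, tol: 1e-4, rng: Random.new(1))
  # Initialize with k distinct samples chosen at random.
  centers = samples.sample(k, random: rng).map(&:dup)
  max_iter.times do
    # Assignment step: index of the nearest center for each sample.
    labels = samples.map { |x| (0...k).min_by { |j| sq_dist(x, centers[j]) } }
    # Update step: move each center to the mean of its assigned samples.
    new_centers = (0...k).map do |j|
      members = samples.each_index.select { |i| labels[i] == j }.map { |i| samples[i] }
      next centers[j] if members.empty?

      members.transpose.map { |col| col.sum / members.size.to_f }
    end
    # Total squared movement of the centers in this iteration.
    shift = centers.zip(new_centers).sum { |a, b| sq_dist(a, b) }
    centers = new_centers
    break if shift < tol
  end
  centers
end

samples = [[0.0, 0.0], [0.2, 0.1], [0.1, 0.2], [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]]
centers = kmeans_centers(samples, 2)
p centers.sort_by(&:first)
```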
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 30
def classes
  @classes
end
```
#rng ⇒ Random (readonly)
Return the random generator.
```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 42
def rng
  @rng
end
```
#weight_vec ⇒ Numo::DFloat (readonly)
Return the weight vector.
```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 38
def weight_vec
  @weight_vec
end
```
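This page does not show how `weight_vec` is computed. A common way to train the output layer of an RBF network, and one consistent with `fit` requiring Numo::Linalg, is regularized least squares on the hidden outputs H and one-hot targets Y: W = (HᵀH + λI)⁻¹HᵀY, with λ presumably `reg_param`. The sketch below solves that system in plain Ruby with Gauss-Jordan elimination; treat the formula as an assumption, not as Rumale's code.

```ruby
# Matrix product of nested Arrays (rows of Floats).
def matmul(a, b)
  b_cols = b.transpose
  a.map { |row| b_cols.map { |col| row.zip(col).sum { |p, q| p * q } } }
end

# Solve A X = B (A square, B may have several columns) by Gauss-Jordan
# elimination with partial pivoting.
def solve(a, b)
  n = a.size
  aug = a.each_with_index.map { |row, i| row.dup + b[i].dup }
  n.times do |col|
    pivot = (col...n).max_by { |r| aug[r][col].abs }
    aug[col], aug[pivot] = aug[pivot], aug[col]
    div = aug[col][col]
    aug[col] = aug[col].map { |v| v / div }
    n.times do |r|
      next if r == col

      factor = aug[r][col]
      aug[r] = aug[r].each_with_index.map { |v, j| v - factor * aug[col][j] }
    end
  end
  aug.map { |row| row[n..-1] }
end

# Hypothetical ridge solve for the output weights: (H^T H + reg_param * I) W = H^T Y.
def ridge_weights(h, y, reg_param)
  ht = h.transpose
  gram = matmul(ht, h)
  gram.each_with_index { |row, i| row[i] += reg_param }
  solve(gram, matmul(ht, y))
end

h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] # hidden outputs: 3 samples x 2 centers
y = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]] # one-hot targets: 3 samples x 2 classes
w = ridge_weights(h, y, 1.0)             # 2 centers x 2 classes
p w
```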
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Calculate confidence scores for samples.
```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 80
def decision_function(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  h = hidden_output(x)
  h.dot(@weight_vec)
end
```
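`decision_function` computes the hidden-layer output and multiplies it by `weight_vec`, giving one score per sample and class. The plain-Ruby equivalent below assumes that `hidden_output` applies a Gaussian kernel exp(-gamma * ||x - c||^2) against the centers and that `weight_vec` has one row per center and one column per class; both are assumptions, since `hidden_output` is not documented on this page.

```ruby
# Gaussian RBF activations: one row per sample, one column per center.
def hidden_output(samples, centers, gamma)
  samples.map do |x|
    centers.map do |c|
      sq_dist = x.zip(c).sum { |p, q| (p - q)**2 }
      Math.exp(-gamma * sq_dist)
    end
  end
end

# Scores = H . W, where W has one row per center and one column per class.
def decision_function(samples, centers, gamma, weights)
  h = hidden_output(samples, centers, gamma)
  w_cols = weights.transpose
  h.map { |row| w_cols.map { |col| row.zip(col).sum { |p, q| p * q } } }
end

centers = [[0.0, 0.0], [4.0, 4.0]]
weights = [[1.0, 0.0], [0.0, 1.0]] # identity: each center votes for one class
scores = decision_function([[0.0, 0.0], [4.0, 4.0]], centers, 0.5, weights)
p scores
```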
#fit(x, y) ⇒ RBFClassifier
Fit the model with the given training data.

```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 63
def fit(x, y)
  x = ::Rumale::Validation.check_convert_sample_array(x)
  y = ::Rumale::Validation.check_convert_label_array(y)
  ::Rumale::Validation.check_sample_size(x, y)
  raise 'RBFClassifier#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?(warning: false)

  @classes = Numo::NArray[*y.to_a.uniq.sort]

  partial_fit(x, one_hot_encode(y))

  self
end
```
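`fit` stores the sorted unique labels in `@classes` and trains on `one_hot_encode(y)`. That helper is not documented here; the plain-Ruby sketch below assumes it produces one column per entry of `@classes`, in the same sorted order, with 1.0 marking each sample's class. The helper names are hypothetical.

```ruby
# Sorted unique class labels, mirroring @classes in fit.
def unique_classes(labels)
  labels.uniq.sort
end

# Hypothetical one-hot encoder: row i has 1.0 in the column of labels[i]'s class.
def one_hot_encode(labels, classes)
  index = classes.each_with_index.to_h # label => column index
  labels.map do |label|
    row = Array.new(classes.size, 0.0)
    row[index[label]] = 1.0
    row
  end
end

labels = [3, 1, 3, 2]
classes = unique_classes(labels) # [1, 2, 3]
targets = one_hot_encode(labels, classes)
p targets
```

Keeping the column order tied to the sorted `@classes` is what lets `predict` map a winning column index straight back to a label.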
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
```ruby
# File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 91
def predict(x)
  x = ::Rumale::Validation.check_convert_sample_array(x)

  scores = decision_function(x)
  n_samples, n_classes = scores.shape
  # max_index(axis: 1) yields indices into the flattened score array;
  # subtracting each row's offset (row * n_classes) leaves the column index.
  label_ids = scores.max_index(axis: 1) - Numo::Int32.new(n_samples).seq * n_classes
  @classes[label_ids].dup
end
```
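As the subtraction in `predict` implies, `max_index(axis: 1)` reports positions in the flattened score array, so row `i`'s winner lies at `i * n_classes + column`. The plain-Ruby sketch below reproduces that arithmetic with a flattened array standing in for Numo's output; `predict_labels` is a hypothetical name used only for illustration.

```ruby
# scores: n_samples x n_classes decision values; classes: sorted labels.
def predict_labels(scores, classes)
  n_classes = scores.first.size
  flat = scores.flatten
  scores.each_index.map do |row|
    offset = row * n_classes
    # Flat index of this row's maximum, as max_index(axis: 1) would report it.
    flat_idx = (offset...(offset + n_classes)).max_by { |i| flat[i] }
    # Remove the row offset to get the column index, then look up the label.
    classes[flat_idx - offset]
  end
end

scores = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]]
p predict_labels(scores, [10, 20, 30]) # => [20, 10]
```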