Class: Rumale::NeuralNetwork::RBFClassifier
- Inherits:
- BaseRBF
  - Object
  - Base::Estimator
  - BaseRBF
  - Rumale::NeuralNetwork::RBFClassifier
- Includes:
- Base::Classifier
- Defined in:
- rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb
Overview
RBFClassifier is a class that implements a classifier based on (k-means) radial basis function (RBF) networks.
Reference
- Bugmann, G., “Normalized Gaussian Radial Basis Function networks,” Neural Computation, vol. 20, pp. 97–110, 1998.
- Que, Q., and Belkin, M., “Back to the Future: Radial Basis Function Networks Revisited,” Proc. of AISTATS’16, pp. 1375–1383, 2016.
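As a rough illustration of what the hidden layer of a Gaussian RBF network computes, the following plain-Ruby sketch evaluates exp(-gamma * ||x - c||^2) for every sample and center, with optional row normalization in the spirit of the normalized networks cited above. The helper names `squared_distance` and `rbf_hidden_output` are hypothetical, and this is an assumption about the general technique, not a copy of the library's internal code.

```ruby
# Minimal plain-Ruby sketch of a Gaussian RBF hidden layer (illustrative only;
# not Rumale's implementation). Both helpers below are hypothetical.

# Squared Euclidean distance between two equal-length vectors.
def squared_distance(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Returns one row per sample and one column per center.
# Each entry is exp(-gamma * ||x - c||^2). With normalize: true, every row
# is divided by its sum so the activations of a sample add up to 1.
def rbf_hidden_output(samples, centers, gamma:, normalize: false)
  samples.map do |x|
    row = centers.map { |c| Math.exp(-gamma * squared_distance(x, c)) }
    next row unless normalize

    total = row.sum
    row.map { |v| v / total }
  end
end

samples = [[0.0, 0.0], [1.0, 1.0]]
centers = [[0.0, 0.0], [2.0, 2.0]]
h = rbf_hidden_output(samples, centers, gamma: 0.5, normalize: true)
```

In this sketch the second sample is equidistant from both centers, so its normalized activations are equal.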
Instance Attribute Summary
- #centers ⇒ Numo::DFloat (readonly)
  Return the centers in the hidden layer of the RBF network.
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #rng ⇒ Random (readonly)
  Return the random generator.
- #weight_vec ⇒ Numo::DFloat (readonly)
  Return the weight vector.
Attributes inherited from Base::Estimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ RBFClassifier
  Fit the model with given training data.
- #initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil) ⇒ RBFClassifier (constructor)
  Create a new classifier with (k-means) RBF networks.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil) ⇒ RBFClassifier
Create a new classifier with (k-means) RBF networks.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 51
    def initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil)
      super
    end
Instance Attribute Details
#centers ⇒ Numo::DFloat (readonly)
Return the centers in the hidden layer of the RBF network.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 32
    def centers
      @centers
    end
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 28
    def classes
      @classes
    end
#rng ⇒ Random (readonly)
Return the random generator.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 40
    def rng
      @rng
    end
#weight_vec ⇒ Numo::DFloat (readonly)
Return the weight vector.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 36
    def weight_vec
      @weight_vec
    end
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
Calculate confidence scores for samples.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 78
    def decision_function(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)

      h = hidden_output(x)
      h.dot(@weight_vec)
    end
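The last step of #decision_function is a matrix product between the hidden-layer output and the weight matrix. A minimal plain-Ruby sketch of that product follows; it assumes the hidden output has one row per sample and the weights have one column per class. The `dot` helper and the example values are hypothetical and only stand in for Numo's own `dot`.

```ruby
# Plain-Ruby sketch of the product h.dot(@weight_vec) (illustrative only).
# h:      n_samples x hidden_units (hidden-layer activations)
# weight: hidden_units x n_classes
def dot(h, weight)
  columns = weight.transpose
  h.map { |row| columns.map { |col| row.zip(col).sum { |a, b| a * b } } }
end

h = [[1.0, 0.0], [0.5, 0.5]]        # two samples, two hidden units
weight = [[2.0, -2.0], [-1.0, 1.0]] # two hidden units, two classes
scores = dot(h, weight)             # one row of class scores per sample
```

Each row of `scores` holds one confidence score per class for the corresponding sample.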
#fit(x, y) ⇒ RBFClassifier
Fit the model with given training data.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 61
    def fit(x, y)
      x = ::Rumale::Validation.check_convert_sample_array(x)
      y = ::Rumale::Validation.check_convert_label_array(y)
      ::Rumale::Validation.check_sample_size(x, y)
      raise 'RBFClassifier#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?(warning: false)

      @classes = Numo::NArray[*y.to_a.uniq.sort]

      partial_fit(x, one_hot_encode(y))

      self
    end
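The label handling in #fit can be sketched in plain Ruby: the sorted unique labels become the class list, and each label is turned into an indicator row. The helper `one_hot_encode_sketch` is a hypothetical stand-in; the library's `one_hot_encode` is not shown in this section, so the 0/1 target values used here are an assumption.

```ruby
# Sketch of the label handling in #fit (illustrative only).
# `one_hot_encode_sketch` is a hypothetical stand-in for the library helper,
# whose exact encoding is not shown here; 0/1 indicators are assumed.
def one_hot_encode_sketch(labels, classes)
  labels.map { |label| classes.map { |c| c == label ? 1.0 : 0.0 } }
end

y = [2, 0, 1, 2]
classes = y.uniq.sort # mirrors y.to_a.uniq.sort in #fit
targets = one_hot_encode_sketch(y, classes)
```

The columns of `targets` follow the order of `classes`, which is why #predict can later map a column index back to a label.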
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
    # File 'rumale-neural_network/lib/rumale/neural_network/rbf_classifier.rb', line 89
    def predict(x)
      x = ::Rumale::Validation.check_convert_sample_array(x)

      scores = decision_function(x)
      n_samples, n_classes = scores.shape
      label_ids = scores.max_index(axis: 1) - Numo::Int32.new(n_samples).seq * n_classes
      @classes[label_ids].dup
    end
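The subtraction in #predict is easier to follow with a small example. The sketch below assumes that `max_index(axis: 1)` yields positions in the flattened score array, which would explain why a per-row offset of `row * n_classes` is removed to obtain a column index. It mimics that arithmetic with plain Ruby arrays; the score values and class labels are made up for illustration.

```ruby
# Plain-Ruby sketch of the index arithmetic in #predict (illustrative only).
# Assumption: Numo's max_index(axis: 1) returns positions in the flattened
# array; that behavior is mimicked here with explicit offsets.
scores = [[0.1, 0.7, 0.2],
          [0.9, 0.05, 0.05]]
classes = [3, 5, 8] # hypothetical class labels
n_classes = scores.first.size

# Flat index of each row's maximum, as if scores were one long array.
flat_ids = scores.each_with_index.map do |row, i|
  i * n_classes + row.index(row.max)
end

# Subtracting the row offset (i * n_classes) recovers the column index.
label_ids = flat_ids.each_with_index.map { |flat, i| flat - i * n_classes }

# Column indices are then mapped back to the stored class labels.
predicted = label_ids.map { |id| classes[id] }
```

Here the first sample scores highest in the second column and the second sample in the first column, so the column indices select the second and first class labels respectively.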