Class: Rumale::NeuralNetwork::BaseRBF

Inherits:
Base::Estimator
Defined in:
rumale-neural_network/lib/rumale/neural_network/base_rbf.rb

Overview

BaseRBF is an abstract base class for implementing radial basis function (RBF) network estimators. It is used internally.
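To make the term "RBF network" concrete, the sketch below computes the hidden-layer activations of a Gaussian RBF network for a single sample. It is not Rumale's code: it uses plain Ruby arrays instead of the Numo::NArray objects Rumale works with, the helper name rbf_hidden_output is hypothetical, and the normalization step assumes the normalized RBF form of Bugmann (dividing each activation by the sum of all activations).

```ruby
# Hypothetical helper (not part of Rumale): Gaussian RBF activations of one sample
# against a list of centers, phi_j(x) = exp(-gamma * ||x - c_j||^2).
def rbf_hidden_output(x, centers, gamma: nil, normalize: false)
  # Default gamma mirrors the documented fallback of 1 / n_features.
  gamma ||= 1.0 / x.size
  activations = centers.map do |c|
    dist2 = x.zip(c).sum { |xi, ci| (xi - ci)**2 }
    Math.exp(-gamma * dist2)
  end
  return activations unless normalize

  # Assumed normalization (normalized RBF in the sense of Bugmann, 1998):
  # divide each activation by their sum so the outputs add up to 1.
  total = activations.sum
  activations.map { |a| a / total }
end

x = [1.0, 2.0]
centers = [[1.0, 2.0], [3.0, 2.0]]
p rbf_hidden_output(x, centers)                  # first entry is 1.0 since x equals the first center
p rbf_hidden_output(x, centers, normalize: true) # entries sum to 1
```

The output layer of an RBF network then combines these activations linearly; how the subclasses fit those output weights is not described on this page.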

Reference

  • Bugmann, G., “Normalized Gaussian Radial Basis Function networks,” Neurocomputing, vol. 20, pp. 97–110, 1998.

  • Que, Q., and Belkin, M., “Back to the Future: Radial Basis Function Networks Revisited,” Proc. of AISTATS’16, pp. 1375–1383, 2016.

Direct Known Subclasses

RBFClassifier, RBFRegressor

Instance Attribute Summary

Attributes inherited from Base::Estimator

#params

Instance Method Summary

Constructor Details

#initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false, max_iter: 50, tol: 1e-4, random_seed: nil) ⇒ BaseRBF

Create a radial basis function network estimator.

Parameters:

  • hidden_units (Integer) (defaults to: 128)

    The number of units in the hidden layer.

  • gamma (Float) (defaults to: nil)

    The parameter for the radial basis function. If nil, it is set to 1 / n_features.

  • reg_param (Float) (defaults to: 100.0)

    The regularization parameter.

  • normalize (Boolean) (defaults to: false)

    The flag indicating whether to normalize the hidden layer output or not.

  • max_iter (Integer) (defaults to: 50)

    The maximum number of iterations for finding centers.

  • tol (Float) (defaults to: 1e-4)

    The tolerance for the termination criterion when finding centers.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random number generator. If nil, the value returned by srand is used.
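This page only states that max_iter and tol control "finding centers"; the procedure itself is not shown. As an illustration, the sketch below assumes a k-means (Lloyd) style loop, a common way to place RBF centers, bounded by max_iter and stopped once no center moves more than tol. The method names find_centers and sq_dist, the random initialization from samples, and the max-shift stopping rule are all assumptions, not Rumale's implementation.

```ruby
# Hypothetical sketch (not Rumale's source): choose RBF centers with Lloyd's k-means,
# stopping after max_iter rounds or once centers move less than tol.

# Squared Euclidean distance between two equal-length arrays.
def sq_dist(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

def find_centers(samples, hidden_units:, max_iter: 50, tol: 1e-4, rng: Random.new(1))
  # Initialize centers with randomly picked samples.
  centers = Array.new(hidden_units) { samples[rng.rand(samples.size)].dup }
  max_iter.times do
    # Assignment step: group samples by their nearest center.
    groups = Array.new(hidden_units) { [] }
    samples.each do |s|
      nearest = (0...hidden_units).min_by { |j| sq_dist(s, centers[j]) }
      groups[nearest] << s
    end
    # Update step: move each center to the mean of its group (empty groups keep their center).
    new_centers = groups.each_with_index.map do |group, j|
      next centers[j] if group.empty?

      group.transpose.map { |col| col.sum / group.size }
    end
    # Convergence check: largest Euclidean shift of any center.
    shift = centers.zip(new_centers).map { |a, b| Math.sqrt(sq_dist(a, b)) }.max
    centers = new_centers
    break if shift <= tol
  end
  centers
end

data = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]]
p find_centers(data, hidden_units: 2, rng: Random.new(42))
```

In this sketch, hidden_units sets the number of centers, which is why it is the number of units in the hidden layer; the exact centers returned depend on the random initialization.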



# File 'rumale-neural_network/lib/rumale/neural_network/base_rbf.rb', line 24

def initialize(hidden_units: 128, gamma: nil, reg_param: 100.0, normalize: false,
               max_iter: 50, tol: 1e-4, random_seed: nil)
  super()
  @params = {
    hidden_units: hidden_units,
    gamma: gamma,
    reg_param: reg_param,
    normalize: normalize,
    max_iter: max_iter,
    tol: tol,
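    # Kernel#srand reseeds the global generator and returns the previous seed,
    # which is stored here when no random_seed is given.
    random_seed: random_seed || srand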
  }
  @rng = Random.new(@params[:random_seed])
end
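The constructor stores either the given seed or the value returned by srand, then builds a dedicated Random instance from it. The standalone sketch below (the helper name resolve_seed is hypothetical) reproduces that logic to show why passing a fixed random_seed makes the random draws repeatable.

```ruby
# Hypothetical helper mirroring `random_seed || srand` from #initialize.
def resolve_seed(random_seed)
  # With nil, Kernel#srand reseeds the global generator and returns its previous seed.
  random_seed || srand
end

rng_a = Random.new(resolve_seed(42))
rng_b = Random.new(resolve_seed(42))
draws_a = Array.new(3) { rng_a.rand }
draws_b = Array.new(3) { rng_b.rand }
puts draws_a == draws_b # prints "true": the same seed yields the same sequence

auto_seed = resolve_seed(nil) # an Integer picked without caller input
```

Because the resolved seed is kept in @params, the seed actually used is recorded even when the caller passes nil, so a run can be repeated later by passing that value explicitly.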