Class: Rumale::Torch::NeuralNetClassifier
- Inherits: Base::Estimator (Object > Base::Estimator > Rumale::Torch::NeuralNetClassifier)
- Includes: Base::Classifier
- Defined in: lib/rumale/torch/neural_net_classifier.rb
Overview
NeuralNetClassifier provides training and inference for a neural network defined with torch.rb, through an interface similar to that of Rumale's classifiers.
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly): Return the class labels.
- #device ⇒ Torch::Device: Return the compute device.
- #loss ⇒ Torch::NN: Return the loss function.
- #model ⇒ Torch::NN::Module: Return the neural network defined with torch.rb.
- #optimizer ⇒ Torch::Optim: Return the optimizer.
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat: Calculate confidence scores for samples.
- #fit(x, y) ⇒ NeuralNetClassifier: Fit the model with the given training data.
- #initialize(model:, device: nil, optimizer: nil, loss: nil, batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0, verbose: false, random_seed: nil) ⇒ NeuralNetClassifier (constructor): Create a new classifier with a neural network defined by torch.rb.
- #predict(x) ⇒ Numo::Int32: Predict class labels for samples.
Constructor Details
#initialize(model:, device: nil, optimizer: nil, loss: nil, batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0, verbose: false, random_seed: nil) ⇒ NeuralNetClassifier
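Create a new classifier with a neural network defined by torch.rb. When omitted, device defaults to the CPU, optimizer to Adam over model.parameters, and loss to CrossEntropyLoss.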
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 79
def initialize(model:, device: nil, optimizer: nil, loss: nil, batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0, verbose: false, random_seed: nil)
  super()
  @model = model
  @device = device || ::Torch.device('cpu')
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
  @loss = loss || ::Torch::NN::CrossEntropyLoss.new
  @params = {}
  @params[:batch_size] = batch_size
  @params[:max_epoch] = max_epoch
  @params[:shuffle] = shuffle
  @params[:validation_split] = validation_split
  @params[:verbose] = verbose
  @params[:random_seed] = random_seed || srand
  define_parameter_accessors
end
```
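The constructor stores batch_size, shuffle, validation_split and random_seed, but the private helper that turns them into training and validation data is not shown on this page. The following is a minimal plain-Ruby sketch, assuming shuffled sample indices are split into a validation part and mini-batches; the method name split_and_batch and the rounding of the validation size are hypothetical, not rumale-torch's actual implementation.

```ruby
# Hypothetical sketch, not rumale-torch's private dataset helper: it only shows
# how batch_size, shuffle, validation_split and random_seed could shape the
# sample indices used during training.
def split_and_batch(n_samples, batch_size: 128, shuffle: true, validation_split: 0, random_seed: nil)
  # Like the constructor, fall back to a generated seed when none is given.
  rng = Random.new(random_seed || Random.new_seed)
  ids = (0...n_samples).to_a
  ids = ids.shuffle(random: rng) if shuffle
  # Assumption: the validation size is the split fraction rounded up.
  n_val = (validation_split * n_samples).ceil
  val_ids = ids.first(n_val)
  train_ids = ids.drop(n_val)
  # Mini-batches of at most batch_size indices; the last one may be smaller.
  [train_ids.each_slice(batch_size).to_a, val_ids]
end

train_batches, val_ids = split_and_batch(8, batch_size: 4, validation_split: 0.25, random_seed: 42)
p train_batches.map(&:size) # => [4, 2]
p val_ids.size              # => 2
```

With validation_split left at 0 (the default), the validation part is empty and every sample is used for training.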
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 46
def classes
  @classes
end
```
#device ⇒ Torch::Device
Return the compute device.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 54
def device
  @device
end
```
#loss ⇒ Torch::NN
Return the loss function.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 62
def loss
  @loss
end
```
#model ⇒ Torch::NN::Module
Return the neural network defined with torch.rb.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 50
def model
  @model
end
```
#optimizer ⇒ Torch::Optim
Return the optimizer.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 58
def optimizer
  @optimizer
end
```
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
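Calculate confidence scores for samples. The scores are the raw outputs of the model, computed without gradient tracking and returned as a Numo::DFloat with one row per sample.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 135
def decision_function(x)
  Numo::DFloat.cast(::Torch.no_grad { model.call(::Torch.from_numo(x).to(:float32)) }.numo)
end
```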
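The scores from decision_function are not probabilities. If probabilities are wanted and the model ends in a plain linear layer (the usual setup with the default CrossEntropyLoss, which expects unnormalized logits), a row-wise softmax converts them. A minimal plain-Ruby sketch on nested arrays (for example scores.to_a on the returned Numo::DFloat); softmax_rows is a hypothetical helper, not part of this class.

```ruby
# Hypothetical helper: row-wise softmax over decision scores, assuming the
# scores are unnormalized logits.
def softmax_rows(scores)
  scores.map do |row|
    m = row.max                          # subtract the row max for numerical stability
    exps = row.map { |s| Math.exp(s - m) }
    total = exps.sum
    exps.map { |e| e / total }
  end
end

probs = softmax_rows([[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]])
p probs[1] # => [0.3333333333333333, 0.3333333333333333, 0.3333333333333333]
```

Subtracting the row maximum leaves the result unchanged mathematically but keeps Math.exp from overflowing on large scores.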
#fit(x, y) ⇒ NeuralNetClassifier
Fit the model with the given training data. The labels in y are first encoded as integer indices (the original labels are kept in #classes), and every direct child layer of the model that defines reset_parameters is re-initialized before the training epochs run.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 102
def fit(x, y)
  encoder = Rumale::Preprocessing::LabelEncoder.new
  y_encoded = encoder.fit_transform(y)
  @classes = Numo::NArray[*encoder.classes]

  train_loader, test_loader = prepare_dataset(x, y_encoded)

  model.children.each do |layer|
    layer.reset_parameters if layer.class.method_defined?(:reset_parameters)
  end

  1.upto(max_epoch) do |epoch|
    train(train_loader)
    display_epoch(train_loader, test_loader, epoch) if verbose
  end

  self
end
```
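fit trains on encoded labels because CrossEntropyLoss expects class indices in 0...n_classes as targets, while #classes keeps the original labels so predictions can be mapped back. A plain-Ruby sketch of that round trip, assuming the classes are kept in sorted order; encode_labels is a hypothetical name, not part of Rumale.

```ruby
# Hypothetical sketch of the label round trip used by #fit and #predict:
# original labels -> contiguous indices for training, indices -> labels afterwards.
def encode_labels(y)
  classes = y.uniq.sort                      # assumption: classes are kept in sorted order
  index_of = classes.each_with_index.to_h    # label => index
  [classes, y.map { |label| index_of[label] }]
end

classes, y_encoded = encode_labels([3, 1, 3, 7])
p classes   # => [1, 3, 7]
p y_encoded # => [1, 0, 1, 2]
```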
#predict(x) ⇒ Numo::Int32
Predict class labels for samples. For each sample, the label in #classes whose model output is largest is returned.
```ruby
# File 'lib/rumale/torch/neural_net_classifier.rb', line 125
def predict(x)
  output = ::Torch.no_grad { model.call(::Torch.from_numo(x).to(:float32)) }
  _, indices = ::Torch.max(output, 1)
  @classes[indices.numo].dup
end
```
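predict reduces the model outputs to one label per sample: Torch.max(output, 1) returns the row-wise maxima and their column indices, and the indices are looked up in #classes. A plain-Ruby sketch of the same reduction on nested arrays; predict_labels is a hypothetical name, and its tie-breaking is not guaranteed to match Torch.max.

```ruby
# Hypothetical sketch of #predict: take the index of the largest score in each
# row and look up the corresponding label.
def predict_labels(scores, classes)
  scores.map do |row|
    best = row.each_index.max_by { |j| row[j] }
    classes[best]
  end
end

p predict_labels([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], [1, 3, 7]) # => [3, 1]
```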