Module: Rumale::ModelSelection

Defined in:
rumale-model_selection/lib/rumale/model_selection/k_fold.rb,
rumale-model_selection/lib/rumale/model_selection/version.rb,
rumale-model_selection/lib/rumale/model_selection/function.rb,
rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb,
rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb,
rumale-model_selection/lib/rumale/model_selection/grid_search_cv.rb,
rumale-model_selection/lib/rumale/model_selection/cross_validation.rb,
rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb,
rumale-model_selection/lib/rumale/model_selection/time_series_split.rb,
rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb,
rumale-model_selection/lib/rumale/model_selection/stratified_shuffle_split.rb

Overview

This module consists of classes and functions for model validation and selection techniques.

Defined Under Namespace

Classes: CrossValidation, GridSearchCV, GroupKFold, GroupShuffleSplit, KFold, ShuffleSplit, StratifiedKFold, StratifiedShuffleSplit, TimeSeriesSplit

Class Method Summary

Class Method Details

.train_test_split(x, y = nil, test_size: 0.1, train_size: nil, stratify: false, random_seed: nil) ⇒ Array<Numo::NArray>

Randomly split a dataset into training and test sets.

Examples:

require 'rumale/model_selection/function'

x_train, x_test, y_train, y_test = Rumale::ModelSelection.train_test_split(x, y, test_size: 0.2, stratify: true, random_seed: 1)

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset to be split. Its number of rows determines the data indices that are generated.

  • y (Numo::Int32) (defaults to: nil)

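    (shape: [n_samples]) The labels to be split along with x. When stratify is true, they are also used to generate the stratified random permutation; when stratify is false, they do not affect how samples are assigned. Although the default is nil, the method indexes y unconditionally, so y must be given.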

  • test_size (Float) (defaults to: 0.1)

    The fraction of samples to assign to the test set.

  • train_size (Float) (defaults to: nil)

    The fraction of samples to assign to the training set. If nil is given, it is set to 1 - test_size.

  • stratify (Boolean) (defaults to: false)

    Whether to perform a stratified split, which keeps the class proportions of y similar in the training and test sets.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random number generator.

Returns:

  • (Array<Numo::NArray>)

    The split data, in the order [x_train, x_test, y_train, y_test].
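The non-stratified case can be illustrated with a minimal sketch on plain Ruby Arrays, showing what the parameters control: sample counts derived from test_size and train_size, a permutation drawn from a seeded generator, and rows selected by index. The helper name simple_train_test_split is hypothetical and this is not Rumale's implementation. The rounding (test count rounded up, training count rounded down, training count defaulting to the remainder) is an assumption borrowed from scikit-learn's ShuffleSplit; Rumale may round differently.

```ruby
# Hypothetical helper, not part of Rumale: a plain-Ruby sketch of a
# non-stratified random train/test split over Arrays.
def simple_train_test_split(x, y, test_size: 0.1, train_size: nil, random_seed: nil)
  n_samples = x.size
  raise ArgumentError, 'x and y must have the same length' unless y.size == n_samples

  # Assumed rounding (scikit-learn style): test count rounded up,
  # train count rounded down; with train_size nil, train gets the remainder.
  n_test = (test_size * n_samples).ceil
  n_train = train_size.nil? ? n_samples - n_test : (train_size * n_samples).floor
  raise ArgumentError, 'train and test sizes exceed n_samples' if n_train + n_test > n_samples

  # A seeded generator makes the permutation reproducible.
  rng = random_seed.nil? ? Random.new : Random.new(random_seed)
  perm = (0...n_samples).to_a.shuffle(random: rng)
  test_ids = perm.first(n_test)
  train_ids = perm[n_test, n_train]

  # Same return order as train_test_split: [x_train, x_test, y_train, y_test].
  [x.values_at(*train_ids), x.values_at(*test_ids),
   y.values_at(*train_ids), y.values_at(*test_ids)]
end
```

For example, with 8 samples and test_size: 0.25, the sketch returns 6 training rows and 2 test rows, and passing the same random_seed again reproduces the same split.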



# File 'rumale-model_selection/lib/rumale/model_selection/function.rb', line 29

def train_test_split(x, y = nil, test_size: 0.1, train_size: nil, stratify: false, random_seed: nil)
  splitter = if stratify
               ::Rumale::ModelSelection::StratifiedShuffleSplit.new(
                 n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed
               )
             else
               ::Rumale::ModelSelection::ShuffleSplit.new(
                 n_splits: 1, test_size: test_size, train_size: train_size, random_seed: random_seed
               )
             end
  train_ids, test_ids = splitter.split(x, y).first
  x_train = x[train_ids, true].dup
  y_train = y[train_ids].dup
  x_test = x[test_ids, true].dup
  y_test = y[test_ids].dup
  [x_train, x_test, y_train, y_test]
end
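When stratify is true, the method above delegates to StratifiedShuffleSplit, whose goal is to keep the label proportions of the training and test sets close to those of the full dataset. The idea can be sketched by shuffling and splitting each class separately. The helper stratified_split_indices is hypothetical, not Rumale's StratifiedShuffleSplit, and the per-class rounding (test count rounded up within each class) is an assumption; the library may allocate samples differently, particularly for small classes.

```ruby
# Hypothetical helper, not part of Rumale: a plain-Ruby sketch of stratified
# train/test index selection. Each class is shuffled and split on its own, so
# the test set keeps roughly the same label proportions as the whole dataset.
def stratified_split_indices(labels, test_size: 0.1, random_seed: nil)
  rng = random_seed.nil? ? Random.new : Random.new(random_seed)
  train_ids = []
  test_ids = []
  # Group sample indices by their class label: { label => [indices...] }.
  labels.each_index.group_by { |i| labels[i] }.each_value do |ids|
    # Assumed per-class rounding; the real splitter may allocate differently.
    n_test = (test_size * ids.size).ceil
    shuffled = ids.shuffle(random: rng)
    test_ids.concat(shuffled.first(n_test))
    train_ids.concat(shuffled.drop(n_test))
  end
  [train_ids.sort, test_ids.sort]
end
```

Passing the returned index arrays to Array#values_at selects the corresponding rows and labels, mirroring x[train_ids, true] and y[train_ids] in the listing above.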