Class: Rumale::ModelSelection::ShuffleSplit

Inherits:
Object
Includes:
Base::Splitter
Defined in:
rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb

Overview

ShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.

Examples:

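require 'rumale/model_selection/shuffle_split'

# samples is assumed to be a Numo::DFloat of shape [n_samples, n_features].
# labels is passed for symmetry with other splitters; #split ignores its second argument.
ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
ss.split(samples, labels).each do |train_ids, test_ids|
  train_samples = samples[train_ids, true]
  test_samples = samples[test_ids, true]
  # ...
end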

Constructor Details

#initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ ShuffleSplit

Create a new data splitter for random permutation cross-validation.

Parameters:

  • n_splits (Integer) (defaults to: 3)

    The number of folds.

  • test_size (Float) (defaults to: 0.1)

    The proportion of samples to use as test data.

  • train_size (Float) (defaults to: nil)

    The proportion of samples to use as training data. If nil, all samples not selected for testing are used for training.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator.
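
As a rough illustration of how the two ratios translate into sample counts, the sketch below mirrors the rounding used in #split: the test count is rounded up and the train count is rounded down. The helper name `split_sizes` is hypothetical and is not part of Rumale.

```ruby
# Hypothetical helper (not part of Rumale) mirroring the size arithmetic in #split.
def split_sizes(n_samples, test_size: 0.1, train_size: nil)
  # The test count is rounded up, so a positive test_size always yields at least one test sample.
  n_test = (test_size * n_samples).ceil
  # Without train_size, every sample not used for testing goes to training.
  n_train = train_size.nil? ? n_samples - n_test : (train_size * n_samples).floor
  [n_train, n_test]
end

p split_sizes(25)                                  # => [22, 3]
p split_sizes(25, test_size: 0.2, train_size: 0.5) # => [12, 5]
```

In the second call only 17 of the 25 samples are used in each split; #split permits this as long as the two counts together do not exceed the number of samples.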



# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 36

def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
  @n_splits = n_splits
  @test_size = test_size
  @train_size = train_size
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
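
The effect of a fixed seed can be seen with Ruby's standard library alone. This minimal sketch assumes only that indices are drawn with Array#sample and a seeded Random, as #split does; the variable names are illustrative.

```ruby
ids = Array(0...10)

# Two generators created from the same seed produce the same draw.
a = ids.sample(3, random: Random.new(1))
b = ids.sample(3, random: Random.new(1))

p a == b # => true
```

When random_seed is left as nil, the constructor takes a seed from Kernel#srand instead, so results will generally differ from run to run.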

Instance Attribute Details

#n_splits ⇒ Integer (readonly)

Return the number of folds.

Returns:

  • (Integer)


# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 24

def n_splits
  @n_splits
end

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.

Returns:

  • (Random)


# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 28

def rng
  @rng
end

Instance Method Details

#split(x, _y = nil) ⇒ Array

Generate data indices for random permutation cross-validation.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset to be used to generate data indices for random permutation cross-validation.

Returns:

  • (Array)

    The set of data indices for constructing the training and testing dataset in each fold.



# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 50

def split(x, _y = nil)
  # Initialize and check some variables.
  n_samples = x.shape[0]
  n_test_samples = (@test_size * n_samples).ceil.to_i
  n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
  unless @n_splits.between?(1, n_samples)
    raise ArgumentError,
          'The value of n_splits must be not less than 1 and not more than the number of samples.'
  end
  unless n_test_samples.between?(1, n_samples)
    raise RangeError,
          'The number of samples in test split must be not less than 1 and not more than the number of samples.'
  end
  unless n_train_samples.between?(1, n_samples)
    raise RangeError,
          'The number of samples in train split must be not less than 1 and not more than the number of samples.'
  end
  if (n_test_samples + n_train_samples) > n_samples
    raise RangeError,
          'The total number of samples in test split and train split must be not more than the number of samples.'
  end
  sub_rng = @rng.dup
  # Returns array consisting of the training and testing ids for each fold.
  dataset_ids = Array(0...n_samples)
  Array.new(@n_splits) do
    test_ids = dataset_ids.sample(n_test_samples, random: sub_rng)
    train_ids = if @train_size.nil?
                  dataset_ids - test_ids
                else
                  (dataset_ids - test_ids).sample(n_train_samples, random: sub_rng)
                end
    [train_ids, test_ids]
  end
end
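
Two properties of the listing above may be easier to see in a plain-Ruby stand-in: train and test indices never overlap within a split, and because the generator is duplicated before sampling, calling the method again yields the same splits. The helper `sketch_split` is hypothetical (not part of Rumale); it takes sample counts directly instead of ratios and omits the train_size branch and the argument checks.

```ruby
# Hypothetical stand-in for the sampling loop in #split (not part of Rumale).
def sketch_split(n_samples, n_splits:, n_test:, rng:)
  sub_rng = rng.dup # copy the generator so the caller's state is not advanced
  ids = Array(0...n_samples)
  Array.new(n_splits) do
    test_ids = ids.sample(n_test, random: sub_rng)
    [ids - test_ids, test_ids]
  end
end

rng = Random.new(1)
folds = sketch_split(10, n_splits: 3, n_test: 2, rng: rng)

folds.each do |train_ids, test_ids|
  # Prints "train=8 test=2 overlap=0" for every split.
  puts "train=#{train_ids.size} test=#{test_ids.size} overlap=#{(train_ids & test_ids).size}"
end

# A second call with the same generator object reproduces the same splits.
p folds == sketch_split(10, n_splits: 3, n_test: 2, rng: rng) # => true
```

Test indices are drawn independently for each split, so the same sample may appear in the test set of more than one split, unlike in K-fold cross-validation.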