Class: Rumale::ModelSelection::StratifiedKFold

Inherits:
Object
Includes:
Base::Splitter
Defined in:
rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb

Overview

StratifiedKFold is a class that generates the sets of data indices for stratified K-fold cross-validation. Each fold preserves the class proportions of the labels as closely as possible, so the number of samples of each class is almost equal across folds.
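To make "almost equal" concrete, the plain-Ruby sketch below counts how many samples of each class would land in each test fold. It assumes each class is cut into n_splits chunks whose sizes differ by at most one, with the first folds absorbing any remainder. The helper class_counts_per_fold is hypothetical and is not part of Rumale, whose exact distribution rule may differ.

```ruby
# Hypothetical helper (not Rumale API): number of samples of each class per test fold,
# assuming each class is split into n_splits chunks whose sizes differ by at most one.
def class_counts_per_fold(labels, n_splits)
  labels.tally.transform_values do |count|
    base, extra = count.divmod(n_splits)
    # The first `extra` folds receive one additional sample of this class.
    Array.new(n_splits) { |fold| fold < extra ? base + 1 : base }
  end
end

labels = [0] * 7 + [1] * 3
class_counts_per_fold(labels, 3).each do |label, counts|
  puts "class #{label}: #{counts.inspect}"
end
# class 0: [3, 2, 2]
# class 1: [1, 1, 1]
```

With 7 samples of class 0 and 3 of class 1, every fold keeps roughly the same 7:3 ratio instead of, for example, one fold receiving no class-1 samples at all.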

Examples:

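require 'numo/narray'
require 'rumale/model_selection/stratified_k_fold'

samples = Numo::DFloat.new(6, 2).rand
labels = Numo::Int32[0, 0, 0, 1, 1, 1]
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kf.split(samples, labels).each do |train_ids, test_ids|
  train_samples = samples[train_ids, true]
  test_samples = samples[test_ids, true]
  train_labels = labels[train_ids]
  test_labels = labels[test_ids]
  # Fit an estimator on the training fold and evaluate it on the testing fold here.
end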

Constructor Details

#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ StratifiedKFold

Create a new data splitter for stratified K-fold cross validation.

Parameters:

  • n_splits (Integer) (defaults to: 3)

    The number of folds.

  • shuffle (Boolean) (defaults to: false)

    The flag indicating whether to shuffle the dataset.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator. If nil, the value returned by srand is used.

# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 40

def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
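The constructor seeds a Random instance with random_seed, falling back to the value returned by srand when random_seed is nil. The plain-Ruby sketch below (standard library only, not Rumale code) illustrates why passing an explicit seed makes shuffled folds reproducible: two generators built from the same seed yield the same permutation.

```ruby
# Two generators created from the same seed produce identical sequences,
# so shuffling with either yields the same permutation.
ids = (0...10).to_a
perm_a = ids.shuffle(random: Random.new(1))
perm_b = ids.shuffle(random: Random.new(1))
puts perm_a == perm_b # prints "true"

# With random_seed: nil the constructor falls back to `srand`, which reseeds the
# global generator and returns the previous seed; that Integer becomes the seed.
seed = srand
puts seed.is_a?(Integer) # prints "true"
```

The fallback means the splitter still shuffles when no seed is given, but the folds then differ from run to run.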

Instance Attribute Details

#n_splits ⇒ Integer (readonly)

Return the number of folds.

Returns:

  • (Integer)


# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 25

def n_splits
  @n_splits
end

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.

Returns:

  • (Random)


# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 33

def rng
  @rng
end

#shuffle ⇒ Boolean (readonly)

Return the flag indicating whether to shuffle the dataset.

Returns:

  • (Boolean)


# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 29

def shuffle
  @shuffle
end

Instance Method Details

#split(x, y) ⇒ Array

Generate data indices for stratified K-fold cross validation.

Returns an array with one [train_ids, test_ids] pair per fold, holding the data indices for the training and testing datasets.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset to be used to generate data indices for stratified K-fold cross validation. This argument exists only to unify the interface across the K-fold splitters; it is not used by this method.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used to generate data indices for stratified K-fold cross validation.

Returns:

  • (Array)

    The set of data indices for constructing the training and testing datasets in each fold.
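Before splitting, split raises an ArgumentError unless n_splits is at least 2 and no larger than the number of samples in each class. The body of the private helper valid_n_splits? is not shown on this page, so the plain-Ruby version below is an assumption reconstructed from that error message, operating on a Ruby Array of labels rather than a Numo::Int32.

```ruby
# Assumed reconstruction of the precondition described by the ArgumentError in #split:
# n_splits must be at least 2 and at most the size of the smallest class.
def valid_n_splits?(labels, n_splits)
  smallest_class = labels.tally.values.min
  n_splits >= 2 && n_splits <= smallest_class
end

labels = [0, 0, 0, 1, 1, 1, 1]
puts valid_n_splits?(labels, 3) # prints "true"
puts valid_n_splits?(labels, 4) # prints "false" (class 0 has only 3 samples)
puts valid_n_splits?(labels, 1) # prints "false"
```

The upper bound exists because every test fold must receive at least one sample of every class.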


# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 57

def split(_x, y)
  # Check the number of samples in each class.
  unless valid_n_splits?(y)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
  end
  # Splits dataset ids of each class to each fold.
  sub_rng = @rng.dup
  fold_sets_each_class = y.to_a.uniq.map { |label| fold_sets(y, label, sub_rng) }
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
end
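The private helpers fold_sets and train_test_sets are not shown on this page. The plain-Ruby sketch below is a hypothetical stand-in for the overall technique: it groups indices by class, optionally shuffles each class, cuts each class into n_splits nearly equal chunks, and takes chunk k of every class as test fold k. The method name stratified_split and its chunking rule are assumptions, not Rumale's code. Like split, it draws from a copy of the generator (rng.dup), so the caller's generator is never advanced and repeated calls return identical folds.

```ruby
# Hypothetical sketch of stratified K-fold splitting (not Rumale's implementation).
# Assumptions: labels is an Array, each class is divided into n_splits nearly equal
# chunks, and test fold k is the union of chunk k from every class.
def stratified_split(labels, n_splits, shuffle: false, rng: Random.new(1))
  sub_rng = rng.dup # copy the generator so the caller's rng is never advanced
  ids_by_class = labels.each_index.group_by { |i| labels[i] }
  fold_sets = ids_by_class.values.map do |ids|
    ids = ids.shuffle(random: sub_rng) if shuffle
    base, extra = ids.size.divmod(n_splits)
    start = 0
    Array.new(n_splits) do |fold|
      size = fold < extra ? base + 1 : base # first `extra` chunks get one more id
      chunk = ids[start, size]
      start += size
      chunk
    end
  end
  Array.new(n_splits) do |fold_id|
    test_ids = fold_sets.flat_map { |chunks| chunks[fold_id] }.sort
    train_ids = (0...labels.size).to_a - test_ids # every remaining sample trains
    [train_ids, test_ids]
  end
end

labels = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
rng = Random.new(42)
folds = stratified_split(labels, 2, shuffle: true, rng: rng)
folds.each do |_train_ids, test_ids|
  # Each test fold holds two class-0 samples and three class-1 samples.
  p test_ids.map { |i| labels[i] }.tally
end
```

Because only the copy is consumed, calling stratified_split again with the same rng reproduces the same folds, which mirrors the effect of `sub_rng = @rng.dup` in split above.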