Class: Rumale::ModelSelection::GroupShuffleSplit

Inherits:
Object
Includes:
Base::Splitter
Defined in:
rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb

Overview

GroupShuffleSplit generates sets of data indices for random permutation cross-validation by randomly selecting group labels. All samples that share a group label are placed on the same side of a split.

Examples:

require 'rumale/model_selection/group_shuffle_split'

cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
x = Numo::DFloat.new(8, 2).rand
groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
cv.split(x, nil, groups).each do |train_ids, test_ids|
  puts '---'
  pp train_ids
  pp test_ids
end

# ---
# [0, 1, 2, 5, 6, 7]
# [3, 4]
# ---
# [3, 4, 5, 6, 7]
# [0, 1, 2]
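Because whole groups are assigned to one side, no group label should appear in both the training and testing indices of a fold. The plain-Ruby check below is a sketch that uses the `groups` labels and the first fold printed above. The helper name `group_disjoint?` is hypothetical and not part of Rumale.

```ruby
# Hypothetical helper (not part of Rumale): true when no group label occurs
# in both the training and the testing index sets of a fold.
def group_disjoint?(groups, train_ids, test_ids)
  train_groups = train_ids.map { |i| groups[i] }.uniq
  test_groups = test_ids.map { |i| groups[i] }.uniq
  (train_groups & test_groups).empty?
end

groups = [1, 1, 1, 2, 2, 3, 3, 3] # same labels as in the example above
train_ids = [0, 1, 2, 5, 6, 7]    # first fold, training indices
test_ids = [3, 4]                 # first fold, testing indices
puts group_disjoint?(groups, train_ids, test_ids) # prints "true"
```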

Constructor Details

#initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil) ⇒ GroupShuffleSplit

Create a new data splitter for random permutation cross-validation with the given group labels.

Parameters:

  • n_splits (Integer) (defaults to: 5)

    The number of folds (train/test splits) to generate.

  • test_size (Float) (defaults to: 0.2)

    The ratio of the number of groups used for test data. The resulting number of test groups is rounded up.

  • train_size (Float/Nil) (defaults to: nil)

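    The ratio of the number of groups used for train data. The resulting number of train groups is rounded down. If nil, all groups not selected for testing are used for training.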

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator. If nil, the seed is taken from the value returned by Ruby's srand.
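The ratio parameters are converted into group counts before any sampling happens. The sketch below reproduces that arithmetic and the range checks from the #split source on this page; the method name `group_counts` is hypothetical and not part of Rumale.

```ruby
# Hypothetical helper: converts the ratio parameters into group counts,
# following the formulas and range checks in GroupShuffleSplit#split.
def group_counts(n_groups, test_size: 0.2, train_size: nil)
  # The test count is rounded up.
  n_test = (test_size * n_groups).ceil.to_i
  # Without train_size, every group not drawn for testing is used for training;
  # otherwise the train count is rounded down.
  n_train = train_size.nil? ? n_groups - n_test : (train_size * n_groups).floor.to_i
  raise RangeError, 'invalid number of test groups' unless n_test.between?(1, n_groups)
  raise RangeError, 'invalid number of train groups' unless n_train.between?(1, n_groups)
  raise RangeError, 'too many groups requested' if n_test + n_train > n_groups

  [n_test, n_train]
end

p group_counts(3)                                    # => [1, 2]
p group_counts(10, test_size: 0.25, train_size: 0.5) # => [3, 5]
```

With the defaults and three groups, as in the example above, one group is held out for testing and the other two are used for training.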

# File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 46

def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
  @n_splits = n_splits
  @test_size = test_size
  @train_size = train_size
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
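The constructor stores Random.new(@random_seed), so a fixed random_seed should make the sampled groups reproducible. The plain-Ruby sketch below (it does not call Rumale) shows two generators built from the same seed drawing identical samples with Array#sample. When random_seed is nil, the seed comes from srand and typically differs between program runs.

```ruby
# Two generators created from the same seed draw identical samples.
labels = [1, 2, 3, 4, 5]
rng_a = Random.new(1)
rng_b = Random.new(1)
draw_a = labels.sample(2, random: rng_a)
draw_b = labels.sample(2, random: rng_b)
puts draw_a == draw_b # prints "true"
```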

Instance Attribute Details

#n_splits ⇒ Integer (readonly)

Return the number of folds.

Returns:

  • (Integer)


# File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 34

def n_splits
  @n_splits
end

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.

Returns:

  • (Random)


# File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 38

def rng
  @rng
end
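The #split source on this page samples from @rng.dup rather than from this generator directly, so the state of rng is not advanced and calling split again on the same splitter should yield the same folds. The plain-Ruby sketch below illustrates that copy behavior with Random#dup; it does not call Rumale.

```ruby
rng = Random.new(42)
labels = (1..10).to_a

# Each draw works on a copy of `rng`, so both copies start from the same state.
first = labels.sample(3, random: rng.dup)
second = labels.sample(3, random: rng.dup)
puts first == second # prints "true"

# `rng` itself was never used directly, so it is still in its initial state.
puts rng.rand == Random.new(42).rand # prints "true"
```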

Instance Method Details

#split(x, y, groups) ⇒ Array

Generate train and test data indices by randomly selecting group labels.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset. This argument is not used in the method; the indices are determined by the group labels.

  • y (Numo::Int32)

    (shape: [n_samples]) This argument exists to unify the interface with the K-fold methods; it is not used in the method.

  • groups (Numo::Int32)

    (shape: [n_samples]) The group labels to be used to generate data indices for random permutation cross validation.

Returns:

  • (Array)

    The set of data indices for constructing the training and testing datasets. It has one element per fold, and each element is a pair [train_ids, test_ids] of sample indices.
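After the group labels are sampled, they are mapped back to sample positions; the source below does this with in1d(groups, ids).where.to_a. The sketch that follows is a plain-Ruby stand-in for that step using Arrays instead of Numo; the name `indices_for_groups` is hypothetical.

```ruby
# Hypothetical plain-Ruby stand-in for `in1d(groups, ids).where.to_a`:
# returns the positions of all samples whose group label is in `selected`.
def indices_for_groups(groups, selected)
  groups.each_index.select { |i| selected.include?(groups[i]) }
end

groups = [1, 1, 1, 2, 2, 3, 3, 3]
p indices_for_groups(groups, [2])    # => [3, 4]
p indices_for_groups(groups, [1, 3]) # => [0, 1, 2, 5, 6, 7]
```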



# File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 65

def split(_x, _y, groups)
  classes = groups.to_a.uniq.sort
  n_groups = classes.size
  n_test_groups = (@test_size * n_groups).ceil.to_i
  n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i

  unless n_test_groups.between?(1, n_groups)
    raise RangeError,
          'The number of groups in test split must be not less than 1 and not more than the number of groups.'
  end
  unless n_train_groups.between?(1, n_groups)
    raise RangeError,
          'The number of groups in train split must be not less than 1 and not more than the number of groups.'
  end
  if (n_test_groups + n_train_groups) > n_groups
    raise RangeError,
          'The total number of groups in test split and train split must be not more than the number of groups.'
  end

  sub_rng = @rng.dup

  Array.new(@n_splits) do
    test_group_ids = classes.sample(n_test_groups, random: sub_rng)
    train_group_ids = if @train_size.nil?
                        classes - test_group_ids
                      else
                        (classes - test_group_ids).sample(n_train_groups, random: sub_rng)
                      end
    test_ids = in1d(groups, test_group_ids).where.to_a
    train_ids = in1d(groups, train_group_ids).where.to_a
    [train_ids, test_ids]
  end
end
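To experiment with the algorithm without Numo, the steps in the source above can be sketched with plain Ruby Arrays. This is an illustrative re-implementation under that assumption, not Rumale's API; the method name `group_shuffle_split` and its `rng:` keyword are hypothetical. It is not guaranteed to print the same folds as the example at the top of this page.

```ruby
# Hypothetical plain-Ruby sketch of GroupShuffleSplit#split using Arrays.
def group_shuffle_split(groups, n_splits:, rng:, test_size: 0.2, train_size: nil)
  classes = groups.uniq.sort
  n_groups = classes.size
  n_test = (test_size * n_groups).ceil.to_i
  n_train = train_size.nil? ? n_groups - n_test : (train_size * n_groups).floor.to_i
  raise RangeError, 'invalid number of test groups' unless n_test.between?(1, n_groups)
  raise RangeError, 'invalid number of train groups' unless n_train.between?(1, n_groups)
  raise RangeError, 'too many groups requested' if n_test + n_train > n_groups

  sub_rng = rng.dup # sample from a copy so the caller's generator is not advanced
  Array.new(n_splits) do
    # Draw the test groups, then take the train groups from the remainder.
    test_groups = classes.sample(n_test, random: sub_rng)
    rest = classes - test_groups
    train_groups = train_size.nil? ? rest : rest.sample(n_train, random: sub_rng)
    # Map the selected group labels back to sample positions.
    test_ids = groups.each_index.select { |i| test_groups.include?(groups[i]) }
    train_ids = groups.each_index.select { |i| train_groups.include?(groups[i]) }
    [train_ids, test_ids]
  end
end

groups = [1, 1, 1, 2, 2, 3, 3, 3]
folds = group_shuffle_split(groups, n_splits: 2, rng: Random.new(1))
folds.each do |train_ids, test_ids|
  puts '---'
  p train_ids
  p test_ids
end
```

Because the generator is duplicated before sampling, passing the same Random instance twice returns the same folds, which matches the @rng.dup design in the original source.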