Class: Rumale::ModelSelection::GroupShuffleSplit
- Inherits: Object
- Includes: Base::Splitter
- Defined in: rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb
Overview
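GroupShuffleSplit generates train and test data indices for random permutation cross-validation by randomly selecting group labels. Because whole groups are assigned to one side of a split, samples that share a group label never appear in both the training set and the test set of the same split.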
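The idea behind a single group-based split can be sketched in plain Ruby. This is a minimal illustration, not the library's code: it uses plain Arrays instead of Numo::NArray, the helper name one_group_split is hypothetical, and it only covers the default case where every group not chosen for testing is used for training (train_size: nil).

```ruby
# Hypothetical helper: one group-level shuffle split on plain Ruby arrays.
# Whole groups are sent to the test side; all remaining groups go to the
# train side, mirroring the train_size: nil case.
def one_group_split(groups, n_test_groups, rng)
  labels = groups.uniq.sort
  test_labels = labels.sample(n_test_groups, random: rng)
  train_labels = labels - test_labels
  # Map the chosen group labels back to sample indices.
  test_ids = groups.each_index.select { |i| test_labels.include?(groups[i]) }
  train_ids = groups.each_index.select { |i| train_labels.include?(groups[i]) }
  [train_ids, test_ids]
end

groups = [1, 1, 2, 2, 2, 3, 4, 4]
train_ids, test_ids = one_group_split(groups, 1, Random.new(0))
# No group label appears on both sides of the split.
p(train_ids.map { |i| groups[i] }.uniq & test_ids.map { |i| groups[i] }.uniq) # => []
```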
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly): Return the number of folds.
- #rng ⇒ Random (readonly): Return the random generator for shuffling the dataset.
Instance Method Summary
- #initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil) ⇒ GroupShuffleSplit (constructor): Create a new data splitter for random permutation cross-validation with given group labels.
- #split(x, y, groups) ⇒ Array: Generate train and test data indices by randomly selecting group labels.
Constructor Details
#initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil) ⇒ GroupShuffleSplit
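Create a new data splitter for random permutation cross-validation with given group labels.
- n_splits: the number of train/test splits to generate.
- test_size: the fraction of groups placed in the test split; the resulting group count is rounded up.
- train_size: the fraction of groups placed in the train split, rounded down; if nil, every group not selected for testing is used for training.
- random_seed: the seed for the random generator; if nil, the value returned by srand is used.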
    # File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 46
    def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
      @n_splits = n_splits
      @test_size = test_size
      @train_size = train_size
      @random_seed = random_seed
      @random_seed ||= srand
      @rng = Random.new(@random_seed)
    end
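The constructor only stores test_size and train_size; they are converted to group counts and validated later, inside #split. The arithmetic can be sketched as follows. The helper group_counts and its short error messages are hypothetical; the rounding (ceil for the test count, floor for the train count) and the three range checks follow the #split listing.

```ruby
# Hypothetical helper mirroring the group-count arithmetic in #split.
# Returns [n_train_groups, n_test_groups].
def group_counts(n_groups, test_size, train_size = nil)
  n_test = (test_size * n_groups).ceil.to_i
  n_train = train_size.nil? ? n_groups - n_test : (train_size * n_groups).floor.to_i
  raise RangeError, 'test group count out of range' unless n_test.between?(1, n_groups)
  raise RangeError, 'train group count out of range' unless n_train.between?(1, n_groups)
  raise RangeError, 'too many groups in total' if n_test + n_train > n_groups

  [n_train, n_test]
end

p group_counts(10, 0.25)      # => [7, 3]  (2.5 test groups rounded up)
p group_counts(10, 0.25, 0.5) # => [5, 3]  (2 groups end up in neither split)
```

With 10 groups, test_size: 0.5 and train_size: 0.75 would request 5 + 7 groups and raise RangeError, as would test_size: 1.0 with train_size: nil, because no group would be left for training.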
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
    # File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 34
    def n_splits
      @n_splits
    end
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
    # File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 38
    def rng
      @rng
    end
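In the #split listing, sampling is done with sub_rng = @rng.dup rather than with @rng itself. Judging from that listing, the stored generator is never advanced, so calling #split repeatedly on the same instance should produce the same splits. The plain-Ruby sketch below illustrates this property of Random#dup. It uses only core Ruby, not Rumale, and the labels are made up.

```ruby
# A seeded generator, as the constructor builds from random_seed.
rng = Random.new(42)
labels = %w[a b c d e f]

# Each draw uses a fresh copy of the generator, like sub_rng in #split.
first = labels.sample(2, random: rng.dup)
second = labels.sample(2, random: rng.dup)
p first == second # => true

# The original generator has not been advanced by either draw.
p rng.rand == Random.new(42).rand # => true
```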
Instance Method Details
#split(x, y, groups) ⇒ Array
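Generate train and test data indices by randomly selecting group labels. Only groups determines the split; the x and y arguments are not used (the implementation names them _x and _y). Returns an Array of n_splits pairs [train_ids, test_ids], where each element is an Array of sample indices. Raises RangeError if the number of test groups or train groups is less than 1 or greater than the number of groups, or if together they exceed the number of groups.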
    # File 'rumale-model_selection/lib/rumale/model_selection/group_shuffle_split.rb', line 65
    def split(_x, _y, groups)
      classes = groups.to_a.uniq.sort
      n_groups = classes.size
      n_test_groups = (@test_size * n_groups).ceil.to_i
      n_train_groups = @train_size.nil? ? n_groups - n_test_groups : (@train_size * n_groups).floor.to_i

      unless n_test_groups.between?(1, n_groups)
        raise RangeError,
              'The number of groups in test split must be not less than 1 and not more than the number of groups.'
      end
      unless n_train_groups.between?(1, n_groups)
        raise RangeError,
              'The number of groups in train split must be not less than 1 and not more than the number of groups.'
      end
      if (n_test_groups + n_train_groups) > n_groups
        raise RangeError,
              'The total number of groups in test split and train split must be not more than the number of groups.'
      end

      sub_rng = @rng.dup
      Array.new(@n_splits) do
        test_group_ids = classes.sample(n_test_groups, random: sub_rng)
        train_group_ids = if @train_size.nil?
                            classes - test_group_ids
                          else
                            (classes - test_group_ids).sample(n_train_groups, random: sub_rng)
                          end
        test_ids = in1d(groups, test_group_ids).where.to_a
        train_ids = in1d(groups, train_group_ids).where.to_a
        [train_ids, test_ids]
      end
    end
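When train_size is given, #split first samples the test labels and then samples the train labels from the groups that remain, so some groups can land in neither split for a given iteration. Each of the n_splits iterations samples independently, so test groups may repeat across iterations. The label-selection step is sketched below on plain Ruby arrays. The helper pick_labels is hypothetical, and the mapping from labels to sample indices (done with in1d in the listing) is left out.

```ruby
# Hypothetical sketch of the label selection inside #split when train_size is set.
# Returns [train_labels, test_labels].
def pick_labels(classes, n_test, n_train, rng)
  test_labels = classes.sample(n_test, random: rng)
  # Train labels are drawn only from the groups not chosen for testing.
  train_labels = (classes - test_labels).sample(n_train, random: rng)
  [train_labels, test_labels]
end

classes = (1..10).to_a
rng = Random.new(7)
3.times do
  train_labels, test_labels = pick_labels(classes, 3, 5, rng)
  unused = classes - train_labels - test_labels
  p [train_labels.size, test_labels.size, unused.size] # => [5, 3, 2]
end
```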