Class: Rumale::ModelSelection::StratifiedShuffleSplit
- Inherits:
- Object
- Includes:
- Base::Splitter
- Defined in:
- rumale-model_selection/lib/rumale/model_selection/stratified_shuffle_split.rb
Overview
StratifiedShuffleSplit generates sets of data indices for random permutation cross-validation. Samples are drawn class by class, so the proportion of each class in every fold is approximately the same as in the full dataset.
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly)
  Return the number of folds.
- #rng ⇒ Random (readonly)
  Return the random generator for shuffling the dataset.
Instance Method Summary
- #initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ StratifiedShuffleSplit (constructor)
  Create a new data splitter for random permutation cross validation.
- #split(x, y) ⇒ Array
  Generate data indices for stratified random permutation cross validation.
Constructor Details
#initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ StratifiedShuffleSplit
Create a new data splitter for random permutation cross validation.
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_shuffle_split.rb', line 37

def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
  @n_splits = n_splits
  @test_size = test_size
  @train_size = train_size
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
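The constructor wraps the seed in a `Random` instance, so two generators built from the same `random_seed` produce the same draws. The following is a minimal sketch of that property in plain Ruby, without Rumale; `seeded_sample` is a hypothetical helper introduced for illustration and is not part of the library.

```ruby
# Hypothetical helper (not part of Rumale): draw n ids using a generator
# built from an explicit seed, mirroring Random.new(@random_seed).
def seeded_sample(ids, n, seed)
  rng = Random.new(seed)
  ids.sample(n, random: rng)
end

ids = (0...10).to_a
first = seeded_sample(ids, 3, 42)
second = seeded_sample(ids, 3, 42)

# The same seed yields the same draw.
puts first == second # prints true
```

When `random_seed` is nil, the constructor falls back to the value returned by `srand`, so results will generally differ between runs unless a seed is passed explicitly.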
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_shuffle_split.rb', line 25

def n_splits
  @n_splits
end
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_shuffle_split.rb', line 29

def rng
  @rng
end
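The `#split` source on this page samples from a copy of this generator (`@rng.dup`) rather than from the generator itself. As a rough illustration of what that implies, the sketch below uses only Ruby's standard `Random` class; `draws_from_copy` is a hypothetical helper, not a Rumale method.

```ruby
# Hypothetical helper: draw n integers from a copy of rng,
# leaving the state of rng itself untouched.
def draws_from_copy(rng, n)
  copy = rng.dup
  Array.new(n) { copy.rand(100) }
end

rng = Random.new(1)
first = draws_from_copy(rng, 3)
second = draws_from_copy(rng, 3)

# Each copy starts from the same state, so the draws repeat.
puts first == second
```

Under this reading, calling `#split` repeatedly on the same splitter instance would be expected to return the same folds each time, since the stored generator is never advanced.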
Instance Method Details
#split(x, y) ⇒ Array
Generate data indices for stratified random permutation cross validation.
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_shuffle_split.rb', line 55

def split(_x, y)
  # Initialize and check some variables.
  train_sz = @train_size.nil? ? 1.0 - @test_size : @train_size
  sub_rng = @rng.dup
  # Check the number of samples in each class.
  unless valid_n_splits?(y)
    raise ArgumentError,
          'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
  end
  # rubocop:disable Layout/LineLength
  unless enough_data_size_each_class?(y, @test_size, 'test')
    raise RangeError,
          'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, train_sz, 'train')
    raise RangeError,
          'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
  end
  unless enough_data_size_each_class?(y, train_sz + @test_size, 'train')
    raise RangeError,
          'The total number of samples in test split and train split must be not more than the number of samples in each class.'
  end
  # rubocop:enable Layout/LineLength
  # Returns array consisting of the training and testing ids for each fold.
  sample_ids_each_class = y.to_a.uniq.map { |label| y.eq(label).where.to_a }
  Array.new(@n_splits) do
    train_ids = []
    test_ids = []
    sample_ids_each_class.each do |sample_ids|
      n_samples = sample_ids.size
      n_test_samples = (@test_size * n_samples).ceil.to_i
      test_ids += sample_ids.sample(n_test_samples, random: sub_rng)
      train_ids += if @train_size.nil?
                     sample_ids - test_ids
                   else
                     n_train_samples = (train_sz * n_samples).floor.to_i
                     (sample_ids - test_ids).sample(n_train_samples, random: sub_rng)
                   end
    end
    [train_ids, test_ids]
  end
end
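To make the per-class sampling easier to follow, here is a simplified sketch of the same idea in plain Ruby. It is not the library's implementation: it assumes `labels` is an ordinary Array rather than a Numo::NArray, skips the validation checks, and uses a hypothetical method name and `seed:` keyword.

```ruby
# Hypothetical pure-Ruby sketch of the per-class sampling in #split.
# labels is a plain Array here, not a Numo::NArray as in Rumale,
# and no argument validation is performed.
def stratified_shuffle_split(labels, n_splits: 3, test_size: 0.1, train_size: nil, seed: 1)
  rng = Random.new(seed)
  # Group sample indices by class label.
  ids_each_class = labels.each_index.group_by { |i| labels[i] }.values
  Array.new(n_splits) do
    train_ids = []
    test_ids = []
    ids_each_class.each do |ids|
      # Round the per-class test count up, as the source does with ceil.
      class_test = ids.sample((test_size * ids.size).ceil, random: rng)
      rest = ids - class_test
      test_ids += class_test
      # Train on the remainder, or on a rounded-down subsample of it.
      train_ids += train_size.nil? ? rest : rest.sample((train_size * ids.size).floor, random: rng)
    end
    [train_ids, test_ids]
  end
end

labels = [0] * 8 + [1] * 4
folds = stratified_shuffle_split(labels, n_splits: 2, test_size: 0.25)
folds.each do |train_ids, test_ids|
  puts "train=#{train_ids.size} test=#{test_ids.size}"
end
```

With 8 samples of class 0 and 4 of class 1, a `test_size` of 0.25 puts 2 and 1 samples of the respective classes in each test split, so every fold keeps the 2:1 class ratio of the full label set.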