Class: Rumale::ModelSelection::ShuffleSplit
- Inherits: Object
  - Object
  - Rumale::ModelSelection::ShuffleSplit
- Includes: Base::Splitter
- Defined in: rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb
Overview
ShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.
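To make "random permutation cross-validation" concrete, here is a minimal stdlib-only sketch of the idea. Unlike k-fold cross-validation, each split draws its test set independently from all samples, so test sets of different splits may overlap and some samples may never be tested. `shuffle_split_sketch` is a hypothetical helper written for illustration; it is not part of Rumale's API, and it takes an absolute test count rather than a fraction.

```ruby
# Stdlib-only sketch of random permutation ("shuffle-split") cross-validation.
# shuffle_split_sketch is a hypothetical helper, not part of Rumale's API.
def shuffle_split_sketch(n_samples, n_splits:, n_test:, seed:)
  rng = Random.new(seed)
  ids = (0...n_samples).to_a
  Array.new(n_splits) do
    # Every split draws its test ids afresh from ALL samples (no replacement
    # within a split), so test sets of different splits may overlap.
    test_ids = ids.sample(n_test, random: rng)
    train_ids = ids - test_ids # the remaining samples form the training set
    [train_ids, test_ids]
  end
end

folds = shuffle_split_sketch(10, n_splits: 3, n_test: 3, seed: 1)
folds.each_with_index do |(train_ids, test_ids), i|
  puts "split #{i}: train=#{train_ids.sort.inspect} test=#{test_ids.sort.inspect}"
end
```

Within one split the training and test ids are disjoint and together cover every sample; across splits, nothing prevents the same id from appearing in several test sets.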
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly): Return the number of folds.
- #rng ⇒ Random (readonly): Return the random generator for shuffling the dataset.
Instance Method Summary
- #initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ ShuffleSplit (constructor): Create a new data splitter for random permutation cross-validation.
- #split(x, _y = nil) ⇒ Array: Generate data indices for random permutation cross-validation.
Constructor Details
#initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil) ⇒ ShuffleSplit
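Create a new data splitter for random permutation cross-validation. If random_seed is nil, the seed is taken from Kernel#srand, and #rng is a Random instance initialized with that seed.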
# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 36
def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
  @n_splits = n_splits
  @test_size = test_size
  @train_size = train_size
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
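The line `@random_seed ||= srand` means an explicit seed is kept as-is, while a nil seed is replaced by the value returned by Kernel#srand (called without an argument, srand reseeds Ruby's global generator and returns the previous seed, an Integer). The stdlib-only sketch below shows that rule and why passing a fixed seed makes the generator reproducible. `pick_seed` is a hypothetical helper name used only for illustration.

```ruby
# Sketch of the constructor's seeding rule, stdlib only.
# pick_seed is a hypothetical helper, not part of Rumale's API.
def pick_seed(random_seed)
  # Kernel#srand with no argument reseeds the global PRNG and returns the
  # previous seed, so a nil seed becomes some Integer instead of staying nil.
  random_seed || srand
end

seed = pick_seed(42)
rng1 = Random.new(seed)
rng2 = Random.new(seed)
draws1 = Array.new(5) { rng1.rand(100) }
draws2 = Array.new(5) { rng2.rand(100) }
puts draws1 == draws2 # => true: the same seed yields the same sequence
puts pick_seed(nil).is_a?(Integer) # => true
```

In practice, passing random_seed explicitly is the way to get the same folds across program runs.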
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 24
def n_splits
  @n_splits
end
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 28
def rng
  @rng
end
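#split never draws from this generator directly: it samples from `@rng.dup`. Because Random#dup copies the generator's internal state, each call starts from the same state and the stored generator is never advanced, so calling split twice on the same data should return identical folds. The stdlib-only sketch below demonstrates that copy-then-draw pattern outside of Rumale.

```ruby
# Stdlib-only sketch of the copy-then-draw pattern used with @rng in #split.
base = Random.new(7)
reference = Random.new(7) # same seed, never copied from
ids = (0...10).to_a

# Each "call" samples from a fresh copy, so both copies start from the same state.
first_call  = ids.sample(3, random: base.dup)
second_call = ids.sample(3, random: base.dup)
puts first_call == second_call # => true

# The original generator was not advanced by the copies.
puts base.rand(1000) == reference.rand(1000) # => true
```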
Instance Method Details
#split(x, _y = nil) ⇒ Array
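Generate data indices for random permutation cross-validation. Returns an array of n_splits pairs [train_ids, test_ids]. Only the number of rows of x (x.shape[0]) is used, and _y is ignored.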
# File 'rumale-model_selection/lib/rumale/model_selection/shuffle_split.rb', line 50
def split(x, _y = nil)
  # Initialize and check some variables.
  n_samples = x.shape[0]
  n_test_samples = (@test_size * n_samples).ceil.to_i
  n_train_samples = @train_size.nil? ? n_samples - n_test_samples : (@train_size * n_samples).floor.to_i
  unless @n_splits.between?(1, n_samples)
    raise ArgumentError, 'The value of n_splits must be not less than 1 and not more than the number of samples.'
  end
  unless n_test_samples.between?(1, n_samples)
    raise RangeError, 'The number of samples in test split must be not less than 1 and not more than the number of samples.'
  end
  unless n_train_samples.between?(1, n_samples)
    raise RangeError, 'The number of samples in train split must be not less than 1 and not more than the number of samples.'
  end
  if (n_test_samples + n_train_samples) > n_samples
    raise RangeError, 'The total number of samples in test split and train split must be not more than the number of samples.'
  end
  sub_rng = @rng.dup
  # Returns array consisting of the training and testing ids for each fold.
  dataset_ids = Array(0...n_samples)
  Array.new(@n_splits) do
    test_ids = dataset_ids.sample(n_test_samples, random: sub_rng)
    train_ids = if @train_size.nil?
                  dataset_ids - test_ids
                else
                  (dataset_ids - test_ids).sample(n_train_samples, random: sub_rng)
                end
    [train_ids, test_ids]
  end
end
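The size arithmetic in #split rounds the test size up (ceil) and an explicit train size down (floor); without train_size, every non-test sample goes to training. When train_size is given, the training ids are a random subset of the non-test ids, so some samples may belong to neither split. The sketch below isolates that arithmetic and the validation checks in a hypothetical helper, `split_sizes`, which is not part of Rumale; its error messages are shortened. The example fractions (0.25, 0.5, 0.75) are exactly representable in binary floating point, which keeps the results easy to verify by hand.

```ruby
# Hypothetical helper mirroring the size arithmetic and checks in ShuffleSplit#split.
def split_sizes(n_samples, n_splits:, test_size:, train_size: nil)
  n_test = (test_size * n_samples).ceil.to_i # test size is rounded up
  n_train = if train_size.nil?
              n_samples - n_test # default: all remaining samples
            else
              (train_size * n_samples).floor.to_i # train size is rounded down
            end
  raise ArgumentError, 'n_splits out of range' unless n_splits.between?(1, n_samples)
  raise RangeError, 'test split out of range' unless n_test.between?(1, n_samples)
  raise RangeError, 'train split out of range' unless n_train.between?(1, n_samples)
  raise RangeError, 'train + test exceed n_samples' if n_test + n_train > n_samples

  [n_train, n_test]
end

p split_sizes(10, n_splits: 3, test_size: 0.25)                  # => [7, 3]
p split_sizes(10, n_splits: 3, test_size: 0.25, train_size: 0.5) # => [5, 3]
begin
  split_sizes(10, n_splits: 3, test_size: 0.5, train_size: 0.75) # 5 + 7 > 10
rescue RangeError => e
  puts "RangeError: #{e.message}"
end
```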