Class: Rumale::ModelSelection::KFold
- Inherits:
- Object
- Includes:
- Base::Splitter
- Defined in:
- rumale-model_selection/lib/rumale/model_selection/k_fold.rb
Overview
KFold generates the training and test data indices for each fold of K-fold cross-validation.
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly)
  Return the number of folds.
- #rng ⇒ Random (readonly)
  Return the random generator for shuffling the dataset.
- #shuffle ⇒ Boolean (readonly)
  Return the flag indicating whether to shuffle the dataset.
Instance Method Summary
- #initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold (constructor)
  Create a new data splitter for K-fold cross validation.
- #split(x, _y = nil) ⇒ Array
  Generate data indices for K-fold cross validation.
Constructor Details
#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold
Create a new data splitter for K-fold cross-validation. If random_seed is nil, the seed is taken from srand.
# File 'rumale-model_selection/lib/rumale/model_selection/k_fold.rb', line 40
def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
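The constructor seeds a Random instance, and #split shuffles with a duplicate of it (@rng.dup), so the shuffled order depends only on the seed. The following is a minimal plain-Ruby sketch of that behavior using only the standard library. The helper name shuffled_ids is hypothetical and not part of Rumale.

```ruby
# Hypothetical helper (not part of Rumale): shuffles 0...n_samples with a
# duplicate of the given generator, as #split does with @rng.dup.
def shuffled_ids(rng, n_samples)
  sub_rng = rng.dup # the original generator state is left untouched
  Array(0...n_samples).shuffle(random: sub_rng)
end

rng = Random.new(42) # stands in for Random.new(@random_seed)

first = shuffled_ids(rng, 10)
second = shuffled_ids(rng, 10)

# Both calls start from the same generator state, so the order repeats.
puts first == second
# A generator built from the same seed reproduces the order as well.
puts first == shuffled_ids(Random.new(42), 10)
```

Under these assumptions, calling #split repeatedly on one splitter with shuffle: true should yield the same folds each time, and two splitters created with the same random_seed should agree.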
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
# File 'rumale-model_selection/lib/rumale/model_selection/k_fold.rb', line 25
def n_splits
  @n_splits
end
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
# File 'rumale-model_selection/lib/rumale/model_selection/k_fold.rb', line 33
def rng
  @rng
end
#shuffle ⇒ Boolean (readonly)
Return the flag indicating whether to shuffle the dataset.
# File 'rumale-model_selection/lib/rumale/model_selection/k_fold.rb', line 29
def shuffle
  @shuffle
end
Instance Method Details
#split(x, _y = nil) ⇒ Array
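Generate data indices for K-fold cross-validation. Returns an array of n_splits pairs, each of the form [train_ids, test_ids]. Raises ArgumentError unless n_splits is between 2 and the number of samples.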
# File 'rumale-model_selection/lib/rumale/model_selection/k_fold.rb', line 53
def split(x, _y = nil)
  # Initialize and check some variables.
  n_samples, = x.shape
  unless @n_splits.between?(2, n_samples)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples.'
  end
  sub_rng = @rng.dup
  # Splits dataset ids to each fold.
  dataset_ids = Array(0...n_samples)
  dataset_ids.shuffle!(random: sub_rng) if @shuffle
  fold_sets = Array.new(@n_splits) do |n|
    n_fold_samples = n_samples / @n_splits
    n_fold_samples += 1 if n < n_samples % @n_splits
    dataset_ids.shift(n_fold_samples)
  end
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) do |n|
    train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
    test_ids = fold_sets[n]
    [train_ids, test_ids]
  end
end
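To see how the fold sizes work out when the sample count is not divisible by n_splits, the partitioning step can be sketched in plain Ruby without shuffling. The helper name k_fold_ids is hypothetical and not part of Rumale. It takes a sample count directly instead of an object responding to shape.

```ruby
# Hypothetical stand-alone helper (not part of Rumale) mirroring the
# unshuffled partitioning in #split.
def k_fold_ids(n_samples, n_splits)
  unless n_splits.between?(2, n_samples)
    raise ArgumentError, 'n_splits must be between 2 and the number of samples.'
  end

  dataset_ids = Array(0...n_samples)
  # The first (n_samples % n_splits) folds receive one extra sample.
  fold_sets = Array.new(n_splits) do |n|
    n_fold_samples = n_samples / n_splits
    n_fold_samples += 1 if n < n_samples % n_splits
    dataset_ids.shift(n_fold_samples)
  end

  # Pair each fold (test ids) with the remaining folds (training ids).
  Array.new(n_splits) do |n|
    train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
    [train_ids, fold_sets[n]]
  end
end

k_fold_ids(10, 3).each do |train_ids, test_ids|
  puts "train: #{train_ids.inspect} test: #{test_ids.inspect}"
end
# train: [4, 5, 6, 7, 8, 9] test: [0, 1, 2, 3]
# train: [0, 1, 2, 3, 7, 8, 9] test: [4, 5, 6]
# train: [0, 1, 2, 3, 4, 5, 6] test: [7, 8, 9]
```

With 10 samples and 3 folds, the remainder of 1 goes to the first fold, so the test sets have sizes 4, 3, and 3, and every index appears in exactly one test set.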