Class: Rumale::ModelSelection::StratifiedKFold
- Inherits: Object
- Includes: Base::Splitter
- Defined in: rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb
Overview
StratifiedKFold generates the sets of data indices for stratified K-fold cross-validation. The samples are divided into K folds so that the proportion of samples from each class is almost equal in every fold.
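To make "almost equal" concrete, the plain-Ruby sketch below measures the share of class 0 in each test fold of a small split. The helper `class_share`, the labels, and the hand-made folds are hypothetical illustration data, not Rumale output. With 7 samples of class 0 and 3 of class 1 divided into 3 folds, the class-0 count cannot be split evenly, so the per-fold shares differ slightly from the overall 0.7 but stay within 0.05 of it.

```ruby
# Hypothetical helper: fraction of the samples in `ids` whose label is `label`.
def class_share(labels, ids, label)
  ids.count { |i| labels[i] == label }.fdiv(ids.size)
end

labels = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1] # 70% class 0, 30% class 1
# A hand-made stratified split: each class is dealt across the 3 test folds.
test_folds = [[0, 3, 6, 7], [1, 4, 8], [2, 5, 9]]

overall = class_share(labels, labels.each_index.to_a, 0)
per_fold = test_folds.map { |ids| class_share(labels, ids, 0).round(3) }
p overall  # => 0.7
p per_fold # => [0.75, 0.667, 0.667]
```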
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly): Return the number of folds.
- #rng ⇒ Random (readonly): Return the random generator for shuffling the dataset.
- #shuffle ⇒ Boolean (readonly): Return the flag indicating whether to shuffle the dataset.
Instance Method Summary
- #initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ StratifiedKFold (constructor): Create a new data splitter for stratified K-fold cross-validation.
- #split(x, y) ⇒ Array: Generate data indices for stratified K-fold cross-validation.
Constructor Details
#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ StratifiedKFold
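Create a new data splitter for stratified K-fold cross-validation. n_splits is the number of folds, shuffle controls whether the dataset is shuffled, and random_seed seeds the random generator. If random_seed is nil, the constructor uses the value returned by Kernel#srand (the previous global seed), which also reseeds Ruby's global random generator; pass an explicit random_seed to make shuffled splits reproducible.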
```ruby
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 40
def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
```
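The seeding logic in the constructor can be reproduced with Ruby's standard library alone. The sketch below wraps it in a hypothetical `build_rng` helper (not part of Rumale) to show two things: generators created from the same seed shuffle identically, and the fallback seed obtained from `Kernel#srand` is an Integer.

```ruby
# Hypothetical helper mirroring the constructor's seed handling.
def build_rng(random_seed = nil)
  # Kernel#srand reseeds the global generator and returns the previous seed.
  random_seed ||= srand
  [random_seed, Random.new(random_seed)]
end

seed, rng_a = build_rng(7)
_, rng_b = build_rng(7)
ids = (0...10).to_a
# Generators built from the same seed produce the same permutation.
p ids.shuffle(random: rng_a) == ids.shuffle(random: rng_b) # => true
p seed # => 7

fallback_seed, _ = build_rng
p fallback_seed.is_a?(Integer) # => true
```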
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
```ruby
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 25
def n_splits
  @n_splits
end
```
#rng ⇒ Random (readonly)
Return the random generator for shuffling the dataset.
```ruby
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 33
def rng
  @rng
end
```
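The split method below draws its randomness from `@rng.dup` rather than from `@rng` itself. A minimal plain-Ruby sketch of why that matters, assuming only the standard `Random` class: `#dup` copies the generator's state, so two copies shuffle identically and the original is never advanced. Assuming the helpers that split calls use only the copied generator, repeated calls to split on the same object should return the same folds.

```ruby
# Copying a Random with #dup copies its internal state.
base = Random.new(3)
ids = (0...8).to_a

first = ids.shuffle(random: base.dup)
second = ids.shuffle(random: base.dup)
p first == second # => true

# The original generator was never advanced, so it still matches a fresh one.
p base.rand(1000) == Random.new(3).rand(1000) # => true
```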
#shuffle ⇒ Boolean (readonly)
Return the flag indicating whether to shuffle the dataset.
```ruby
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 29
def shuffle
  @shuffle
end
```
Instance Method Details
#split(x, y) ⇒ Array
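Generate data indices for stratified K-fold cross-validation. The samples x are not used (the implementation names the parameter _x); the folds are computed from the labels y alone. The return value is an Array with n_splits elements, each holding the training and testing indices for one fold. An ArgumentError is raised unless n_splits is at least 2 and no greater than the number of samples in each class.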
```ruby
# File 'rumale-model_selection/lib/rumale/model_selection/stratified_k_fold.rb', line 57
def split(_x, y)
  # Check the number of samples in each class.
  unless valid_n_splits?(y)
    raise ArgumentError, 'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
  end
  # Splits dataset ids of each class to each fold.
  sub_rng = @rng.dup
  fold_sets_each_class = y.to_a.uniq.map { |label| fold_sets(y, label, sub_rng) }
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
end
```
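The three steps in split (validate n_splits, assign each class's indices to folds, then assemble the training and testing sets) can be sketched in plain Ruby, with Arrays in place of Numo arrays. `ToyStratifiedKFold` and its private helpers are hypothetical: the listing calls `valid_n_splits?`, `fold_sets` and `train_test_sets` without showing their bodies, so the validity rule (taken from the error message) and the round-robin assignment used here are assumptions and may distribute samples differently from Rumale.

```ruby
# Hypothetical stand-in for StratifiedKFold; not Rumale's implementation.
class ToyStratifiedKFold
  attr_reader :n_splits, :shuffle, :rng

  def initialize(n_splits: 3, shuffle: false, random_seed: nil)
    @n_splits = n_splits
    @shuffle = shuffle
    @random_seed = random_seed || srand
    @rng = Random.new(@random_seed)
  end

  def split(_x, y)
    unless valid_n_splits?(y)
      raise ArgumentError, 'n_splits must be between 2 and the smallest class size.'
    end

    # Shuffle with a copy so @rng is not advanced between calls.
    sub_rng = @rng.dup
    fold_sets_each_class = y.uniq.map { |label| fold_sets(y, label, sub_rng) }
    Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
  end

  private

  # Assumed rule: at least 2 folds, and no more folds than the smallest class size.
  def valid_n_splits?(y)
    @n_splits >= 2 && y.tally.values.min >= @n_splits
  end

  # Indices of one class, optionally shuffled, dealt round-robin into the folds.
  def fold_sets(y, label, rng)
    ids = y.each_index.select { |i| y[i] == label }
    ids = ids.shuffle(random: rng) if @shuffle
    folds = Array.new(@n_splits) { [] }
    ids.each_with_index { |id, j| folds[j % @n_splits] << id }
    folds
  end

  # Fold fold_id of every class is held out for testing; the other folds form the training set.
  def train_test_sets(fold_sets_each_class, fold_id)
    test_ids = fold_sets_each_class.flat_map { |folds| folds[fold_id] }
    train_ids = fold_sets_each_class.flat_map do |folds|
      folds.reject.with_index { |_, k| k == fold_id }.flatten
    end
    [train_ids.sort, test_ids.sort]
  end
end

y = [0, 0, 0, 0, 0, 0, 1, 1, 1]
kf = ToyStratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kf.split(nil, y).each do |train_ids, test_ids|
  puts "train=#{train_ids.inspect} test=#{test_ids.inspect}"
end
```

Dealing indices round-robin gives each fold either floor(n_c / K) or ceil(n_c / K) samples of a class with n_c samples, which is what keeps the class proportions close across folds.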