Class: Rumale::ModelSelection::GroupKFold
- Inherits:
- Object
- Includes:
- Base::Splitter
- Defined in:
- rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb
Overview
GroupKFold generates the sets of data indices for K-fold cross-validation. Data points belonging to the same group are never split across different folds. The number of groups must be greater than or equal to the number of splits; otherwise #split raises ArgumentError.
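The grouping guarantee can be checked on any list of [train_ids, test_ids] pairs. The following is a minimal sketch in plain Ruby; `groups_kept_together?` is a hypothetical helper written for illustration and is not part of Rumale, and the index lists are made-up sample data.

```ruby
# Hypothetical helper (not part of Rumale): returns true when no group label
# appears on both the training side and the test side of any split.
# splits: array of [train_ids, test_ids] pairs; groups: one label per sample.
def groups_kept_together?(splits, groups)
  splits.all? do |train_ids, test_ids|
    train_groups = train_ids.map { |i| groups[i] }.uniq
    test_groups = test_ids.map { |i| groups[i] }.uniq
    (train_groups & test_groups).empty?
  end
end

groups = %w[a a b b c c]

# Each test fold holds exactly one whole group.
good = [[[2, 3, 4, 5], [0, 1]], [[0, 1, 4, 5], [2, 3]], [[0, 1, 2, 3], [4, 5]]]

# Group "a" is split: sample 0 is tested while sample 1 is used for training.
bad = [[[1, 2, 3, 4, 5], [0]]]

puts groups_kept_together?(good, groups)
puts groups_kept_together?(bad, groups)
```

A check like this is useful in tests for any custom splitter that is supposed to honour group boundaries.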
Instance Attribute Summary
- #n_splits ⇒ Integer (readonly)
  Return the number of folds.
Instance Method Summary
- #initialize(n_splits: 5) ⇒ GroupKFold (constructor)
  Create a new data splitter for grouped K-fold cross validation.
- #split(x, y, groups) ⇒ Array
  Generate data indices for grouped K-fold cross validation.
Constructor Details
#initialize(n_splits: 5) ⇒ GroupKFold
Create a new data splitter for grouped K-fold cross validation.
    # File 'rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb', line 44
    def initialize(n_splits: 5)
      @n_splits = n_splits
    end
Instance Attribute Details
#n_splits ⇒ Integer (readonly)
Return the number of folds.
    # File 'rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb', line 39
    def n_splits
      @n_splits
    end
Instance Method Details
#split(x, y, groups) ⇒ Array
Generate data indices for grouped K-fold cross validation.
    # File 'rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb', line 58
    def split(x, _y, groups)
      encoder = ::Rumale::Preprocessing::LabelEncoder.new
      groups = encoder.fit_transform(groups)
      n_groups = encoder.classes.size

      if n_groups < @n_splits
        raise ArgumentError,
              'The number of groups should be greater than or equal to the number of splits.'
      end

      n_samples_per_group = groups.bincount
      group_ids = n_samples_per_group.sort_index.reverse
      n_samples_per_group = n_samples_per_group[group_ids]

      n_samples_per_fold = Numo::Int32.zeros(@n_splits)
      group_to_fold = Numo::Int32.zeros(n_groups)

      n_samples_per_group.each_with_index do |weight, id|
        min_sample_fold_id = n_samples_per_fold.min_index
        n_samples_per_fold[min_sample_fold_id] += weight
        group_to_fold[group_ids[id]] = min_sample_fold_id
      end

      n_samples = x.shape[0]
      sample_ids = Array(0...n_samples)
      fold_ids = group_to_fold[groups]

      Array.new(@n_splits) do |fid|
        test_ids = fold_ids.eq(fid).where.to_a
        train_ids = sample_ids - test_ids
        [train_ids, test_ids]
      end
    end
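The source above assigns groups to folds greedily: groups are visited from largest to smallest, and each one goes to the fold that currently holds the fewest samples. The sketch below restates that idea in plain Ruby arrays and hashes, without Numo or Rumale. `grouped_k_fold` is a hypothetical name used only for illustration, and the order in which equally sized groups are visited may differ from the library's, so fold contents can differ when sizes tie.

```ruby
# Pure-Ruby sketch of the greedy group-to-fold assignment.
# `grouped_k_fold` is a hypothetical helper, not part of Rumale.
# groups: one label per sample; returns an array of [train_ids, test_ids].
def grouped_k_fold(groups, n_splits)
  sizes = groups.tally # group label => number of samples (Ruby 2.7+)
  if sizes.size < n_splits
    raise ArgumentError,
          'The number of groups should be greater than or equal to the number of splits.'
  end

  fold_sizes = Array.new(n_splits, 0)
  group_to_fold = {}

  # Visit groups from largest to smallest and place each one in the fold
  # that currently has the fewest samples.
  sizes.sort_by { |_label, size| -size }.each do |label, size|
    fold = fold_sizes.index(fold_sizes.min)
    fold_sizes[fold] += size
    group_to_fold[label] = fold
  end

  sample_ids = (0...groups.size).to_a
  Array.new(n_splits) do |fid|
    test_ids = sample_ids.select { |i| group_to_fold[groups[i]] == fid }
    [sample_ids - test_ids, test_ids]
  end
end

groups = %w[a a a b b c c d]
splits = grouped_k_fold(groups, 2)
splits.each_with_index do |(train_ids, test_ids), fid|
  puts "fold #{fid}: train=#{train_ids.inspect} test=#{test_ids.inspect}"
end
```

With group sizes 3, 2, 2 and 1 and two folds, the greedy rule pairs the largest group with the smallest one, so both test folds end up with four samples each.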