Class: Rumale::ModelSelection::GroupKFold

Inherits:
Object
Includes:
Base::Splitter
Defined in:
rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb

Overview

GroupKFold generates the sets of data indices for K-fold cross-validation such that data points belonging to the same group are never split across different folds. The number of groups must be greater than or equal to the number of splits; otherwise #split raises an ArgumentError.

Examples:

require 'rumale/model_selection/group_k_fold'

cv = Rumale::ModelSelection::GroupKFold.new(n_splits: 3)
x = Numo::DFloat.new(8, 2).rand
groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
cv.split(x, nil, groups).each do |train_ids, test_ids|
  puts '---'
  pp train_ids
  pp test_ids
end

# ---
# [0, 1, 2, 3, 4]
# [5, 6, 7]
# ---
# [3, 4, 5, 6, 7]
# [0, 1, 2]
# ---
# [0, 1, 2, 5, 6, 7]
# [3, 4]
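
The defining property of the folds printed above is that no group appears on both sides of a split. As a minimal sketch, the check below verifies this in plain Ruby, with the indices copied from the example output and the group labels held in an ordinary Array rather than a Numo::Int32. It does not call Rumale.

```ruby
# Group labels and the folds printed in the example above.
groups = [1, 1, 1, 2, 2, 3, 3, 3]
folds = [
  [[0, 1, 2, 3, 4], [5, 6, 7]],
  [[3, 4, 5, 6, 7], [0, 1, 2]],
  [[0, 1, 2, 5, 6, 7], [3, 4]]
]

# For each fold, collect the groups that occur in both the training and test sets.
leaks = folds.map do |train_ids, test_ids|
  groups.values_at(*train_ids).uniq & groups.values_at(*test_ids).uniq
end

# Every sample should appear in exactly one test set.
all_test_ids = folds.flat_map { |_train_ids, test_ids| test_ids }.sort

puts "leaking groups per fold: #{leaks.inspect}"
puts "test ids cover all samples once: #{all_test_ids == (0...groups.size).to_a}"
```

An empty intersection for every fold means the grouping constraint holds for this split.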


Constructor Details

#initialize(n_splits: 5) ⇒ GroupKFold

Create a new data splitter for grouped K-fold cross-validation.

Parameters:

  • n_splits (Integer) (defaults to: 5)

    The number of folds.



# File 'rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb', line 44

def initialize(n_splits: 5)
  @n_splits = n_splits
end

Instance Attribute Details

#n_splits ⇒ Integer (readonly)

Return the number of folds.

Returns:

  • (Integer)


# File 'rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb', line 39

def n_splits
  @n_splits
end

Instance Method Details

#split(x, y, groups) ⇒ Array

Generate data indices for grouped K-fold cross validation.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset to be used to generate data indices for grouped K-fold cross validation.

  • y (Numo::Int32)

    (shape: [n_samples]) This argument exists only to unify the interface between the K-fold methods; it is not used by this method.

  • groups (Numo::Int32)

    (shape: [n_samples]) The group labels to be used to generate data indices for grouped K-fold cross validation.

Returns:

  • (Array)

    The set of data indices for constructing the training and testing dataset in each fold.
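
Each element of the returned Array is a [train_ids, test_ids] pair of sample indices. The sketch below shows how one such pair selects rows, using plain Ruby arrays as a stand-in for Numo arrays; the feature rows and the labels in `y` are made up for illustration. With Numo arrays the equivalent row selection is usually written `x[train_ids, true]`.

```ruby
# Stand-in data: 8 samples with 2 features each (plain arrays, not Numo).
x = Array.new(8) { |i| [i, i * 10] }
# Illustrative labels only; they are not taken from the documentation.
y = [0, 0, 1, 1, 0, 1, 0, 1]

# One [train_ids, test_ids] pair in the shape returned by #split.
train_ids, test_ids = [[0, 1, 2, 3, 4], [5, 6, 7]]

# Select the rows and labels for each side of the split.
x_train = x.values_at(*train_ids)
y_train = y.values_at(*train_ids)
x_test = x.values_at(*test_ids)
y_test = y.values_at(*test_ids)

puts "train: #{x_train.size} rows, test: #{x_test.size} rows"
```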



# File 'rumale-model_selection/lib/rumale/model_selection/group_k_fold.rb', line 58

def split(x, _y, groups)
  encoder = ::Rumale::Preprocessing::LabelEncoder.new
  groups = encoder.fit_transform(groups)
  n_groups = encoder.classes.size

  if n_groups < @n_splits
    raise ArgumentError,
          'The number of groups should be greater than or equal to the number of splits.'
  end

  n_samples_per_group = groups.bincount
  group_ids = n_samples_per_group.sort_index.reverse
  n_samples_per_group = n_samples_per_group[group_ids]

  n_samples_per_fold = Numo::Int32.zeros(@n_splits)
  group_to_fold = Numo::Int32.zeros(n_groups)

  n_samples_per_group.each_with_index do |weight, id|
    min_sample_fold_id = n_samples_per_fold.min_index
    n_samples_per_fold[min_sample_fold_id] += weight
    group_to_fold[group_ids[id]] = min_sample_fold_id
  end

  n_samples = x.shape[0]
  sample_ids = Array(0...n_samples)
  fold_ids = group_to_fold[groups]

  Array.new(@n_splits) do |fid|
    test_ids = fold_ids.eq(fid).where.to_a
    train_ids = sample_ids - test_ids
    [train_ids, test_ids]
  end
end
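
The method above assigns groups to folds greedily: groups are visited from largest to smallest, and each one goes into the fold that currently holds the fewest samples. The following is a minimal plain-Ruby sketch of that idea, without Numo or LabelEncoder. The helper name `greedy_group_folds` is hypothetical and not part of Rumale, and ties between equally sized groups are broken by label here, so the order of the folds may differ from what the library returns.

```ruby
# Plain-Ruby sketch of greedy group-to-fold assignment.
# Hypothetical helper for illustration; not part of Rumale.
def greedy_group_folds(groups, n_splits)
  labels = groups.uniq.sort
  if labels.size < n_splits
    raise ArgumentError,
          'The number of groups should be greater than or equal to the number of splits.'
  end

  # Count samples per group and visit the largest groups first
  # (ties broken by label, an assumption of this sketch).
  counts = groups.tally
  ordered = labels.sort_by { |label| [-counts[label], label] }

  # Put each group into whichever fold currently holds the fewest samples.
  fold_sizes = Array.new(n_splits, 0)
  group_to_fold = {}
  ordered.each do |label|
    fold = fold_sizes.index(fold_sizes.min)
    fold_sizes[fold] += counts[label]
    group_to_fold[label] = fold
  end

  # Build a [train_ids, test_ids] pair for each fold.
  sample_ids = (0...groups.size).to_a
  Array.new(n_splits) do |fid|
    test_ids = sample_ids.select { |i| group_to_fold[groups[i]] == fid }
    [sample_ids - test_ids, test_ids]
  end
end

folds = greedy_group_folds([1, 1, 1, 2, 2, 3, 3, 3], 3)
folds.each { |train_ids, test_ids| p [train_ids, test_ids] }
```

Placing the largest groups first keeps the fold sizes close to balanced, since the small groups that remain can fill whatever gaps are left.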