Class: Rumale::ModelSelection::TimeSeriesSplit

Inherits:
Object
Includes:
Base::Splitter
Defined in:
rumale-model_selection/lib/rumale/model_selection/time_series_split.rb

Overview

TimeSeriesSplit is a class that generates the sets of data indices for time series cross-validation. The given dataset is assumed to be already ordered by time.

Examples:

require 'numo/narray'
require 'rumale/model_selection/time_series_split'

cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
x = Numo::DFloat.new(6, 2).rand
cv.split(x, nil).each do |train_ids, test_ids|
  puts '---'
  pp train_ids
  pp test_ids
end

# ---
# [0]
# [1]
# ---
# [0, 1]
# [2]
# ---
# [0, 1, 2]
# [3]
# ---
# [0, 1, 2, 3]
# [4]
# ---
# [0, 1, 2, 3, 4]
# [5]
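
The index pairs returned by split can be consumed with ordinary array indexing. The following is a minimal sketch in plain Ruby rather than Rumale code: the series values, the `walk_forward_errors` helper, and the "repeat the last training value" forecaster are all made up for illustration, and the fold indices are copied from the output shown above.

```ruby
# Toy series, already in chronological order (made-up values).
series = [10.0, 11.0, 13.0, 12.0, 15.0, 16.0]

# Index pairs copied from the documented output (n_splits: 5 on 6 samples).
folds = [
  [[0], [1]],
  [[0, 1], [2]],
  [[0, 1, 2], [3]],
  [[0, 1, 2, 3], [4]],
  [[0, 1, 2, 3, 4], [5]]
]

# Hypothetical helper: for each fold, forecast every test sample with the
# last training value and return the mean absolute error of that fold.
def walk_forward_errors(series, folds)
  folds.map do |train_ids, test_ids|
    train = series.values_at(*train_ids)
    test = series.values_at(*test_ids)
    forecast = train.last
    test.sum { |v| (v - forecast).abs } / test.size
  end
end

pp walk_forward_errors(series, folds)
# [1.0, 2.0, 1.0, 3.0, 1.0]
```

Because every training index precedes every test index within a fold, the forecaster never sees values from the future of the samples it is scored on.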


Constructor Details

#initialize(n_splits: 5, max_train_size: nil) ⇒ TimeSeriesSplit

Create a new data splitter for time series cross-validation.

Parameters:

  • n_splits (Integer) (defaults to: 5)

    The number of splits.

  • max_train_size (Integer/Nil) (defaults to: nil)

    The maximum number of training samples in a split.



# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 52

def initialize(n_splits: 5, max_train_size: nil)
  @n_splits = n_splits
  @max_train_size = max_train_size
end

Instance Attribute Details

#max_train_size ⇒ Integer/Nil (readonly)

Return the maximum number of training samples in a split.

Returns:

  • (Integer/Nil)


# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 46

def max_train_size
  @max_train_size
end

#n_splits ⇒ Integer (readonly)

Return the number of splits.

Returns:

  • (Integer)


# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 42

def n_splits
  @n_splits
end

Instance Method Details

#split(x, y) ⇒ Array

Generate data indices for time series cross-validation.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset used to generate data indices for time series cross-validation. The samples are expected to be ordered by time.

  • y (Numo::Int32)

    (shape: [n_samples]) This argument exists only to keep the interface consistent with the K-fold splitters; it is not used by the method.

Returns:

  • (Array)

    The set of data indices for constructing the training and testing dataset in each fold.



# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 66

def split(x, _y)
  n_samples = x.shape[0]
  unless (@n_splits + 1).between?(2, n_samples)
    raise ArgumentError,
          'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
  end

  test_size = n_samples / (@n_splits + 1)
  offset = test_size + n_samples % (@n_splits + 1)

  Array.new(@n_splits) do |n|
    start = offset * (n + 1)
    train_ids = if !@max_train_size.nil? && @max_train_size < test_size
                  Array((start - @max_train_size)...start)
                else
                  Array(0...start)
                end
    test_ids = Array(start...(start + test_size))
    [train_ids, test_ids]
  end
end
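
To make the index arithmetic in the listing easier to follow, here is a hedged sketch that repeats the same calculation with plain integers, so it runs without Rumale or Numo. The helper name `sketch_time_series_folds` is hypothetical and not part of the library. The sample counts are chosen to be divisible by n_splits + 1, so the remainder term in `offset` is zero.

```ruby
# Hypothetical helper mirroring the index arithmetic of #split shown above,
# taking the sample count directly instead of a Numo array.
def sketch_time_series_folds(n_samples, n_splits, max_train_size: nil)
  n_folds = n_splits + 1
  unless n_folds.between?(2, n_samples)
    raise ArgumentError, 'n_splits + 1 must be between 2 and n_samples'
  end

  test_size = n_samples / n_folds           # integer division
  offset = test_size + n_samples % n_folds  # remainder is zero in the calls below

  Array.new(n_splits) do |n|
    start = offset * (n + 1)
    # As in the listing, the cap is applied only when it is smaller than test_size.
    train_ids = if !max_train_size.nil? && max_train_size < test_size
                  Array((start - max_train_size)...start)
                else
                  Array(0...start)
                end
    [train_ids, Array(start...(start + test_size))]
  end
end

pp sketch_time_series_folds(12, 3)
# [[[0, 1, 2], [3, 4, 5]],
#  [[0, 1, 2, 3, 4, 5], [6, 7, 8]],
#  [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11]]]

pp sketch_time_series_folds(12, 3, max_train_size: 2)
# [[[1, 2], [3, 4, 5]], [[4, 5], [6, 7, 8]], [[7, 8], [9, 10, 11]]]
```

Without a cap the training window expands from index 0 in every fold. With `max_train_size: 2`, which is smaller than the test size of 3 here, each training set keeps only the two samples immediately before the test block.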