Class: Rumale::ModelSelection::TimeSeriesSplit
- Inherits:
- Object
- Includes:
- Base::Splitter
- Defined in:
- rumale-model_selection/lib/rumale/model_selection/time_series_split.rb
Overview
TimeSeriesSplit generates sets of data indices for time series cross-validation. It assumes that the given dataset is already ordered by time.
Instance Attribute Summary
- #max_train_size ⇒ Integer/Nil (readonly): Return the maximum number of training samples in a split.
- #n_splits ⇒ Integer (readonly): Return the number of splits.
Instance Method Summary
- #initialize(n_splits: 5, max_train_size: nil) ⇒ TimeSeriesSplit (constructor): Create a new data splitter for time series cross-validation.
- #split(x, y) ⇒ Array: Generate data indices for time series cross-validation.
Constructor Details
#initialize(n_splits: 5, max_train_size: nil) ⇒ TimeSeriesSplit
Create a new data splitter for time series cross-validation.
# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 52
def initialize(n_splits: 5, max_train_size: nil)
  @n_splits = n_splits
  @max_train_size = max_train_size
end
Instance Attribute Details
#max_train_size ⇒ Integer/Nil (readonly)
Return the maximum number of training samples in a split.
# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 46
def max_train_size
  @max_train_size
end
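The effect of max_train_size can be read off the training-index branch of the `split` source listed on this page: the cap is applied only when it is smaller than the per-fold test size, otherwise every index before the test block is used. The following is a minimal sketch of that branch alone. The helper name `train_window` and its arguments are hypothetical and are not part of Rumale.

```ruby
# Hypothetical helper (not part of Rumale): mirrors the training-index branch
# of the `split` source listed on this page.
def train_window(start, test_size, max_train_size = nil)
  if !max_train_size.nil? && max_train_size < test_size
    # Cap applies: keep only the last max_train_size indices before the test block.
    Array((start - max_train_size)...start)
  else
    # No cap, or cap not smaller than test_size: use every index before the test block.
    Array(0...start)
  end
end

capped   = train_window(6, 3, 2) # cap smaller than test_size
uncapped = train_window(6, 3, 5) # cap not smaller than test_size
p capped   # => [4, 5]
p uncapped # => [0, 1, 2, 3, 4, 5]
```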
#n_splits ⇒ Integer (readonly)
Return the number of splits.
# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 42
def n_splits
  @n_splits
end
Instance Method Details
#split(x, y) ⇒ Array
Generate data indices for time series cross-validation.
# File 'rumale-model_selection/lib/rumale/model_selection/time_series_split.rb', line 66
def split(x, _y)
  n_samples = x.shape[0]
  unless (@n_splits + 1).between?(2, n_samples)
    raise ArgumentError,
          'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
  end

  test_size = n_samples / (@n_splits + 1)
  offset = test_size + n_samples % (@n_splits + 1)

  Array.new(@n_splits) do |n|
    start = offset * (n + 1)
    train_ids = if !@max_train_size.nil? && @max_train_size < test_size
                  Array((start - @max_train_size)...start)
                else
                  Array(0...start)
                end
    test_ids = Array(start...(start + test_size))
    [train_ids, test_ids]
  end
end
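To make the returned structure concrete, here is a stand-alone sketch that reproduces the index arithmetic of the listed `split` method using plain Integers and Arrays, so it runs without Rumale or Numo::NArray. The function name `time_series_folds` is hypothetical, and it takes the sample count directly instead of reading `x.shape[0]`. The example uses 12 samples with `n_splits: 3`, so the sample count is divisible by `n_splits + 1` and the remainder term in `offset` is zero.

```ruby
# Hypothetical stand-alone sketch (not Rumale itself): reproduces the index
# arithmetic of the listed `split` method with plain Integers and Arrays.
def time_series_folds(n_samples, n_splits: 5, max_train_size: nil)
  n_folds = n_splits + 1
  unless n_folds.between?(2, n_samples)
    raise ArgumentError, 'n_splits + 1 must be between 2 and n_samples'
  end

  test_size = n_samples / n_folds
  offset = test_size + n_samples % n_folds

  Array.new(n_splits) do |n|
    start = offset * (n + 1)
    train_ids = if !max_train_size.nil? && max_train_size < test_size
                  Array((start - max_train_size)...start)
                else
                  Array(0...start)
                end
    # Each element is a [train_ids, test_ids] pair, as in the listed source.
    [train_ids, Array(start...(start + test_size))]
  end
end

folds = time_series_folds(12, n_splits: 3)
folds.each_with_index do |(train_ids, test_ids), i|
  puts "fold #{i}: train=#{train_ids.first}..#{train_ids.last} test=#{test_ids.inspect}"
end
# fold 0: train=0..2 test=[3, 4, 5]
# fold 1: train=0..5 test=[6, 7, 8]
# fold 2: train=0..8 test=[9, 10, 11]
```

Each fold trains only on indices that precede its test block, which is why the input is expected to be ordered by time.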