How does DeepAR validation work (in detail)? #3207
Unanswered
Serendipity31 asked this question in Q&A
Replies: 1 comment
@Serendipity31, if you find the answers to any of your questions, please post them here. Very useful questions.
The Issue - I want to understand the validation process in detail
I am struggling to feel confident that I have found the answers to several questions about how validation works in DeepAR (GluonTS version 0.14.0). This post contains my questions, my attempts to work out the answers, and a hypothetical scenario to make all of this a bit more concrete. If anyone is able to look at any of these questions and check or correct my understanding, I would greatly appreciate it.
Hypothetical Scenario - Suppose for the sake of these questions that:
- prediction_length = 1
- context_length = 2
- lags_seq = 3
- batch_size = 20
- DeepAREstimator has:
Question 1: When average validation loss is calculated, what values are included in the average loss calculation?
context_length --> 3 loss values per series x 100 series)
Answer based on my current understanding: Option 2
Reasoning
Since the purpose of observing validation loss is to see how well the model is generalising to unseen data, it would seem like option 1 is the way it should be. And in this thread, Iostella's answer suggests the answer to my question is option 1.
However, the [definition] of validation_step() does not pass future_only = True to loss(). This means the default value (future_only = False) remains. With nothing overriding future_only = False, when the loss is calculated, the target values provided to the loss function are a concatenation of context_target and future_target_reshaped. The concatenation happens on line 569, and the loss values are computed on line 579. This really makes it seem like the answer to my question is actually option 2.
Is this correct?
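To make the two readings of Question 1 concrete, here is a minimal pure-Python sketch (not GluonTS code) of how a future_only-style flag changes which targets get scored. The function loss_targets and the series values are hypothetical; the sketch follows the reading above, where future_only = False prepends the last context_length past values to the future values. The exact slice of the past target taken at line 569 may differ (for example, it might start one step later), so the context count is worth checking against the source.

```python
# Hypothetical sketch (not GluonTS code) of how a `future_only` flag could
# change which target values are scored by the validation loss.


def loss_targets(past_target, future_target, context_length, future_only=False):
    """Return the target values that would enter the loss for one series."""
    if future_only:
        # Only the forecast horizon (prediction_length values) is scored.
        return list(future_target)
    # Assumption: the last `context_length` past values are concatenated in
    # front of the future values. The real slice at line 569 may differ.
    context_target = past_target[-context_length:]
    return list(context_target) + list(future_target)


context_length, num_series = 2, 100
past = [10.0, 11.0, 12.0, 13.0, 14.0]  # made-up past window (past_length = 5)
future = [15.0]                          # prediction_length = 1

n_all = len(loss_targets(past, future, context_length))
n_future = len(loss_targets(past, future, context_length, future_only=True))
print(n_all, n_all * num_series)        # 3 300
print(n_future, n_future * num_series)  # 1 100
```

Under this reading, the averaged validation loss mixes in-sample (context) and out-of-sample (future) terms whenever future_only is left at False.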
Question 2: (Per epoch) For what number of series is validation loss calculated?
num_batches_per_epoch (even if this means the validation loss for some series is calculated more than once)
Answer based on my current understanding: Option 1
Reasoning
I think option 1 is both what should be happening and what is happening. However, I am struggling to understand create_validation_data_loader() [code], and would be grateful for someone to verify my understanding. Here is what I understand about create_validation_data_loader()...
Within create_validation_data_loader():
- _create_instance_splitter() [code] is called with self.validation_sampler (for which the default is an instance of ValidationSplitSampler)
- the resulting instances are passed to as_stacked_batches() [code]
as_stacked_batches() returns an instance of IterableSlice. IterableSlice [code] takes two inputs: an iterable version of the dataset and num_batches_per_epoch (which will be either an int or the default of None). Because nothing explicitly sets a new value for num_batches_per_epoch, the default remains. Therefore, in the hypothetical example, this would result in:
- 5 validation batches (each with 20 sliced series)
- Each series would show up in exactly one of these batches
limit_val_batches, but defaults to 1.0. Therefore, in my hypothetical example, unless I were to explicitly override this argument, whenever validation loss is calculated, it will be calculated using all 5 validation batches.
Is this an accurate description of events?
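The batching behaviour described above can be sketched in pure Python. This is a simplified stand-in, not the GluonTS or Lightning implementation: stacked_batches and val_batches_used are hypothetical helpers, the sketch assumes exactly one validation instance per series (one split point per series from ValidationSplitSampler), and it relies on itertools.islice treating a stop of None as "no limit", which mirrors how a None num_batches_per_epoch would leave the stream uncut. The limit_val_batches handling is simplified to "float means fraction, int means count".

```python
import itertools


def stacked_batches(series_ids, batch_size, num_batches_per_epoch=None):
    """Hypothetical stand-in for as_stacked_batches() + IterableSlice.

    Groups one validation instance per series into batches of `batch_size`.
    itertools.islice(..., None) yields everything, so a None
    num_batches_per_epoch leaves the stream uncut; an int truncates it.
    """
    batches = (
        series_ids[i:i + batch_size] for i in range(0, len(series_ids), batch_size)
    )
    return list(itertools.islice(batches, num_batches_per_epoch))


def val_batches_used(num_val_batches, limit_val_batches=1.0):
    """Simplified, hypothetical reading of limit_val_batches:
    a float is a fraction of the batches, an int is an absolute count."""
    if isinstance(limit_val_batches, float):
        return int(num_val_batches * limit_val_batches)
    return min(num_val_batches, limit_val_batches)


series = list(range(100))  # 100 series, one validation instance each
batches = stacked_batches(series, batch_size=20)
print(len(batches))                                     # 5
print(sorted(s for b in batches for s in b) == series)  # True (each series once)
print(val_batches_used(len(batches)))                   # 5
```

Because the stream here is not cyclic, an int num_batches_per_epoch can only shorten it; a series would only be visited twice if the underlying iterable repeated.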
Question 3: After how many iterations is the validation loss calculated?
num_batches_per_epoch)
Answer based on my current understanding: Option 1
Reasoning
The PyTorch Lightning trainer has a flag called check_val_every_n_epoch, which defaults to 1. Therefore, unless I were to explicitly override this default (and in this example I have not), I would expect validation loss to be calculated at the end of each epoch.
Is this correct?
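A small sketch of the schedule implied by check_val_every_n_epoch, assuming validation runs only at the end of an epoch (that is, val_check_interval is left at its default) and that each epoch is exactly num_batches_per_epoch training iterations. validation_iterations is a hypothetical helper, and the numbers (4 epochs, 50 batches per epoch) are illustrative only.

```python
def validation_iterations(num_epochs, num_batches_per_epoch, check_val_every_n_epoch=1):
    """Hypothetical helper: cumulative training-iteration counts after which
    validation would run, assuming it runs only at the end of every
    `check_val_every_n_epoch`-th epoch."""
    return [
        epoch * num_batches_per_epoch
        for epoch in range(1, num_epochs + 1)
        if epoch % check_val_every_n_epoch == 0
    ]


print(validation_iterations(num_epochs=4, num_batches_per_epoch=50))
# [50, 100, 150, 200]
print(validation_iterations(4, 50, check_val_every_n_epoch=2))
# [100, 200]
```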
Question 4: When checking validation loss, how many times is the LSTM network unrolled?
context_length (i.e. 2 in my example)
past_length (context_length plus max(lags_seq), i.e. 5 in my example)
Answer based on my current understanding: Option 2
Reasoning
Considering option 1
If I look at the definition of ValidationSplitSampler [code], it returns an instance of PredictionSplitSampler with min_past = 0 and min_future = prediction_length. When given a validation time series, the __call__ function in PredictionSplitSampler [code] returns the last time point for splitting (e.g. 48 in my example).
This split point then gets used within the InstanceSplitter that is returned by _create_instance_splitter [code]. More specifically, it gets used in flatmap_transform() [code]. This function uses the indices from the ValidationSplitSampler in a call to _split_instance() [code]. In turn, two objects (past_piece and future_piece) are returned from a call to _split_array() [code]. Within _split_array(), past_piece goes back past_length steps in time from the point where the data are sliced [code], and past_length covers the part of the series made up of context_length and max(lags_seq).
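The slicing described above can be sketched as follows. split_array is a hypothetical pure-Python re-creation of the behaviour as described, not the GluonTS _split_array(): it ignores any lead time, assumes zero left-padding when the series is shorter than past_length, and takes the split index directly instead of getting it from a sampler. The 10-point series is made up.

```python
def split_array(array, idx, past_length, future_length, pad_value=0.0):
    """Hypothetical sketch of the slicing described for _split_array().

    past piece:   the `past_length` values ending just before `idx`
                  (left-padded with `pad_value` if the series is too short,
                  which is an assumption of this sketch)
    future piece: the `future_length` values starting at `idx`
    """
    if idx >= past_length:
        past_piece = array[idx - past_length:idx]
    else:
        past_piece = [pad_value] * (past_length - idx) + array[:idx]
    future_piece = array[idx:idx + future_length]
    return past_piece, future_piece


context_length, max_lag, prediction_length = 2, 3, 1
past_length = context_length + max_lag   # 5, as in the scenario
series = [float(v) for v in range(10)]   # made-up series of length 10

past, future = split_array(series, idx=9, past_length=past_length,
                           future_length=prediction_length)
print(past)    # [4.0, 5.0, 6.0, 7.0, 8.0]
print(future)  # [9.0]
```

Nothing older than past_length steps before the split point survives into the instance, which is the basis for ruling out option 1.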
Therefore it seems like the validation batches do not include any data older than past_length, and so option 1 cannot be true. That leaves options 2 and 3.
Considering Option 2
The number of times a PyTorch LSTM network unrolls is determined by the sequence length of the input tensor. The input to the LSTM network in DeepAR is called rnn_input, and it is the first output of the prepare_rnn_input() function [code]. A closer look at prepare_rnn_input() shows that the first output is formed from the concatenation of lags and features [code]. Assuming this concatenation is along the feature dimension, lags and features must share the same sequence length, so the sequence length of the tensor passed to the LSTM network can be read off lags. lags is the output of a call to lagged_sequence_values(), which takes as inputs self.lags_seq, prior_input, and input (i.e. the scaled context window) [see here]. The output from lagged_sequence_values() is a tensor [code] whose shape depends on input (and not on prior_input or past_length).
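To check the shape argument, here is a pure-Python sketch of what a lagged_sequence_values()-style function computes. It borrows the name of the GluonTS function but is a hypothetical re-implementation on lists (the library version operates on tensors along a chosen dimension); the lags [1, 3] and the values are chosen for illustration, with a prior of 3 values matching past_length - context_length in the scenario. The point is that the number of output rows equals len(sequence), i.e. the length of input, while prior_sequence only supplies look-back values.

```python
def lagged_sequence_values(lags, prior_sequence, sequence):
    """Hypothetical list-based sketch: for each position t of `sequence`,
    read the value `lag` steps back, looking into `prior_sequence` when
    needed. Returns one row per position of `sequence`, one column per lag."""
    offset = len(prior_sequence)
    if max(lags) > offset:
        raise ValueError("prior_sequence is too short for the largest lag")
    full = list(prior_sequence) + list(sequence)
    return [
        [full[offset + t - lag] for lag in lags]
        for t in range(len(sequence))
    ]


prior = [1.0, 2.0, 3.0]  # plays the role of prior_input
context = [4.0, 5.0]     # plays the role of input (context_length = 2)

lags = lagged_sequence_values([1, 3], prior, context)
print(lags)       # [[3.0, 1.0], [4.0, 2.0]]
print(len(lags))  # 2, i.e. the sequence length follows `input`
```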
Therefore, although I am not 100% confident that I understand the full shape of rnn_input, I think this means that the sequence length passed to the LSTM network during validation is the same as context_length (which would make option 2 the answer to this question).
Question 5: (Given the answer to Q4) why is the network unrolled context_length times during validation, rather than making use of the whole available history of each series?
Answer based on my current understanding: Option 1
Reasoning
I have not managed to find any discussions related to options 2 or 3, so this is more a process-of-elimination guess than a confident assertion based on a deep understanding of the issue.
Is this right? Have I missed something?