Commit 3389e4f: add ability to return logits in eval

sdtblck committed May 11, 2021 (1 parent: c7c2063)
Showing 1 changed file with 11 additions and 3 deletions.

deepspeed/runtime/pipe/engine.py (11 additions, 3 deletions)
@@ -65,6 +65,8 @@ def __init__(self, *super_args, **super_kwargs):
 
         # We schedule the all-reduces, so disable it in super().backward()
         self.enable_backward_allreduce = False
+        self.eval_return_logits = False
+        self.outputs = None
 
         # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
         self.pipeline_enable_backward_allreduce = True
@@ -336,7 +338,7 @@ def train_batch(self, data_iter=None):
         # TODO: should return precisely what loss returned and allow others to be queried?
         return self.agg_train_loss
 
-    def eval_batch(self, data_iter):
+    def eval_batch(self, data_iter, return_logits=False):
         """Evaluate the pipeline on a batch of data from ``data_iter``. The
         engine will evaluate ``self.train_batch_size()`` total samples
         collectively across all workers.
@@ -363,7 +365,7 @@ def eval_batch(self, data_iter):
         Returns:
             The arithmetic mean of the losses computed this batch.
         """
-
+        self.eval_return_logits = return_logits
         self.module.eval()
         self.total_loss = None
 
@@ -393,7 +395,11 @@ def eval_batch(self, data_iter):
 
         # Reset any buffers that may have been populated during the forward passes.
         # ds_checkpointing.reset()
-
+        self.eval_return_logits = False
+        if return_logits:
+            outputs = self.outputs
+            self.outputs = None
+            return self.agg_eval_loss, outputs
         return self.agg_eval_loss
 
     def inference_batch(self, data_iter):
@@ -666,6 +672,8 @@ def _exec_forward_pass(self, buffer_id):
             else:
                 # Some models just return loss from forward()
                 self.loss = outputs
+            if self.eval_return_logits:
+                self.outputs = outputs
 
             if isinstance(self.loss, torch.Tensor):
                 if self.total_loss is None:
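For context, a minimal usage sketch of the new flag. This is illustrative, not part of the commit: `engine` is assumed to be an already-initialized DeepSpeed pipeline engine and `eval_iter` an evaluation data iterator.

    # Default behavior is unchanged: eval_batch returns only the aggregated loss.
    loss = engine.eval_batch(eval_iter)

    # With return_logits=True, the outputs stashed by _exec_forward_pass are
    # returned alongside the loss, and the engine drops its reference to them.
    loss, logits = engine.eval_batch(eval_iter, return_logits=True)

Since `self.outputs` is captured inside `_exec_forward_pass` where the loss is computed, the diff suggests the second return value is only populated on the pipeline stage that runs the final forward pass; on other stages it would remain None.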
