Skip to content

Revert D74293458 #2974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def _generate_rec_metrics(
kwargs = metric_def.arguments

kwargs["enable_pt2_compile"] = metrics_config.enable_pt2_compile
kwargs["should_clone_update_inputs"] = metrics_config.should_clone_update_inputs

rec_tasks: List[RecTaskInfo] = []
if metric_def.rec_tasks and metric_def.rec_task_indices:
Expand Down
3 changes: 0 additions & 3 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ class MetricsConfig:
update if the inputs are invalid. Invalid inputs include the case where all
examples have 0 weights for a batch.
enable_pt2_compile (bool): whether to enable PT2 compilation for metrics.
should_clone_update_inputs (bool): whether to clone the inputs of update(). This
prevents CUDAGraph error on overwriting tensor outputs by subsequent runs.
"""

rec_tasks: List[RecTaskInfo] = field(default_factory=list)
Expand All @@ -186,7 +184,6 @@ class MetricsConfig:
compute_on_all_ranks: bool = False
should_validate_update: bool = False
enable_pt2_compile: bool = False
should_clone_update_inputs: bool = False


DefaultTaskInfo = RecTaskInfo(
Expand Down
42 changes: 0 additions & 42 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,6 @@ def __init__(
if "enable_pt2_compile" in kwargs:
del kwargs["enable_pt2_compile"]

# pyre-fixme[8]: Attribute has type `bool`; used as `Union[bool,
# Dict[str, Any]]`.
self._should_clone_update_inputs: bool = kwargs.get(
"should_clone_update_inputs", False
)
if "should_clone_update_inputs" in kwargs:
del kwargs["should_clone_update_inputs"]

if self._window_size < self._batch_size:
raise ValueError(
f"Local window size must be larger than batch size. Got local window size {self._window_size} and batch size {self._batch_size}."
Expand Down Expand Up @@ -549,35 +541,6 @@ def _create_default_weights(self, predictions: torch.Tensor) -> torch.Tensor:
def _check_nonempty_weights(self, weights: torch.Tensor) -> torch.Tensor:
return torch.gt(torch.count_nonzero(weights, dim=-1), 0)

def clone_update_inputs(
    self,
    predictions: RecModelOutput,
    labels: RecModelOutput,
    weights: Optional[RecModelOutput],
    **kwargs: Dict[str, Any],
) -> tuple[
    RecModelOutput, RecModelOutput, Optional[RecModelOutput], Dict[str, Any]
]:
    """Return cloned copies of the update() inputs.

    A RecModelOutput is either a bare tensor or a mapping of name ->
    tensor; both forms are cloned. Any "required_inputs" entry in
    ``kwargs`` is cloned value-by-value as well, so that later
    in-place writes cannot alias the tensors passed in here.
    """

    def _cloned(output: RecModelOutput) -> RecModelOutput:
        # Tensor case: a direct clone suffices; otherwise clone each
        # tensor value of the mapping.
        if isinstance(output, torch.Tensor):
            return output.clone()
        return {name: tensor.clone() for name, tensor in output.items()}

    cloned_predictions = _cloned(predictions)
    cloned_labels = _cloned(labels)
    cloned_weights = _cloned(weights) if weights is not None else None

    if "required_inputs" in kwargs:
        kwargs["required_inputs"] = {
            name: tensor.clone()
            for name, tensor in kwargs["required_inputs"].items()
        }

    return cloned_predictions, cloned_labels, cloned_weights, kwargs

def _update(
self,
*,
Expand All @@ -587,11 +550,6 @@ def _update(
**kwargs: Dict[str, Any],
) -> None:
with torch.no_grad():
if self._should_clone_update_inputs:
predictions, labels, weights, kwargs = self.clone_update_inputs(
predictions, labels, weights, **kwargs
)

if self._compute_mode in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
Expand Down
Loading