diff --git a/fairseq/utils.py b/fairseq/utils.py index 739ba49f9e..d4ed89e357 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -297,10 +297,11 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if multi_tensor_l2norm_available: total_norm = multi_tensor_total_norm(grads) else: - warnings.warn( - "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " - "you may get better performance by installing NVIDIA's apex library" - ) + if torch.cuda.is_available(): + warnings.warn( + "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " + "you may get better performance by installing NVIDIA's apex library" + ) total_norm = torch.norm( torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in grads]) )