Skip to content

Commit

Permalink
Revert "[Auto Parallel] change lookup_table_v2 auto cast for align mo…
Browse files Browse the repository at this point in the history
…de (#68231)"

This reverts commit 5d2a1fd.
  • Loading branch information
From00 committed Dec 25, 2024
1 parent c6c2f23 commit e8c3d12
Showing 1 changed file with 0 additions and 12 deletions.
12 changes: 0 additions & 12 deletions python/paddle/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,18 +695,6 @@ def master_grad_hook():

# set amp op list
original_white_list, original_black_list = tracer._get_amp_op_list()

# TODO(zhangyuqin1998): In auto parallel align mode, ensure lookup_table_v2 runs in FP32.
# By default, lookup_table_v2 is in the white_list, and runs in BF16/BF16.
# Users can add lookup_table_v2 to the amp_custom_black_list but cannot remove it from the default white_list.
# If lookup_table_v2 appears in both the white_list and black_list, AMP will select it in BF16/BF16.
# Therefore, in auto parallel align mode, add lookup_table_v2 to the black_list and ensure it is not in the white_list.
from paddle.distributed import in_auto_parallel_align_mode

if in_auto_parallel_align_mode():
_black_list.add("lookup_table_v2")
if "lookup_table_v2" in _white_list:
_white_list.remove("lookup_table_v2")
tracer._set_amp_op_list(_white_list, _black_list)

# TODO(zhiqiu) set amp related flags automatically in this guard
Expand Down

0 comments on commit e8c3d12

Please # to comment.