Hi, I am currently trying to fix the problem that pp offload does not support gradient_accumulation_fusion. When pp offload and ga fusion are enabled at the same time, an error is raised in megatron/core/tensor_parallel/layers.py. The cause is that weight.grad is None, i.e. the tensor loses its grad during the offload/onload round trip.
```python
if ctx.gradient_accumulation_fusion:
    if wgrad_compute:
        if hasattr(weight, 'main_grad'):
            tmp_grad = weight.main_grad
        else:
            tmp_grad = weight.grad  # weight.grad is None here
        if tmp_grad.dtype == torch.float32:
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                total_input, grad_output, tmp_grad
            )
        elif tmp_grad.dtype in (torch.float16, torch.bfloat16):
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                total_input, grad_output, tmp_grad
            )
        else:
            raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
```
I have tried several ways to preserve the grad, but all of them failed. Even after commenting out the step that frees GPU memory after offload and the onload step itself, the same error still occurs. I have two questions and would appreciate your help:

1. Is support for ga fusion together with pp offload planned? And why does the error still occur even when I don't free the tensor during offload?
2. What do the ForwardEmptyBackwardIdentityFunction and ForwardLeftBackwardRightFunction classes in offload.py do? I don't understand their purpose; are they related to the lost grad?
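For reference, here is a minimal standalone sketch of my current understanding of how the grad might be getting lost. This is not the actual offload.py logic, just an assumption I am checking: `.grad` is attached to the Parameter object itself, so an offload path that rebuilds the parameter (instead of only swapping its `.data`) comes back with `grad=None`, which is exactly the symptom the fused wgrad path in layers.py then hits.

```python
import torch

# Sketch only, NOT the real offload.py code: compares two ways of moving a
# parameter off the GPU and back, and what happens to its .grad.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

weight = torch.nn.Parameter(torch.randn(4, 4, device=device))
weight.grad = torch.zeros_like(weight)   # pretend a backward pass already ran

# (a) offload by swapping .data only -> the Parameter object survives, grad stays
weight.data = weight.data.to('cpu')
weight.data = weight.data.to(device)
print(weight.grad is None)               # False

# (b) offload by rebuilding the Parameter -> the new object has no grad
cpu_copy = weight.detach().to('cpu')
weight = torch.nn.Parameter(cpu_copy.to(device))
print(weight.grad is None)               # True, same symptom as in layers.py
```

If the offload/onload path (or the two autograd functions I asked about) ends up replacing the parameter object rather than its storage, that would explain why the grad is gone even when I don't free the tensor, but I am not sure this is what actually happens.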