Skip to content

Commit

Permalink
add override QuantLinear
Browse files Browse the repository at this point in the history
  • Loading branch information
xly committed Aug 12, 2024
1 parent 5e096b4 commit 4db8fb5
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions moe_infinity/runtime/model_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import math
import torch.distributed as dist
from torch.distributed import rpc
from auto_gptq.nn_modules.qlinear.qlinear_cuda import QuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearOld

import torch
import functools
Expand Down Expand Up @@ -280,6 +282,13 @@ def archer_cast_classifier(cls, *args, **kwargs):
self.offload_set.add(cls.classifier.weight.data.data_ptr())

return archer_cast_classifier


# GPTQ Override
QuantLinear._old_init = QuantLinear.__init__
QuantLinear.__init__ = param_init_decorator(QuantLinear.__init__)
QuantLinearOld._old_init = QuantLinearOld.__init__
QuantLinearOld.__init__ = param_init_decorator(QuantLinearOld.__init__)

self.cls._old_init = self.cls.__init__
self.cls.__init__ = init_decorator(self.cls._old_init)
Expand Down Expand Up @@ -605,6 +614,11 @@ def archer_from_pretrained(cls, *args, **kwargs):

# clean up initialization hooks
def __exit__(self, exc_type, exc_value, traceback):

# GPTQ Override
QuantLinear.__init__ = QuantLinear._old_init
QuantLinearOld.__init__ = QuantLinearOld._old_init

self.cls.__init__ = self.cls._old_init
self.cls.from_pretrained = self.cls._old_from_pretrained
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
Expand Down

0 comments on commit 4db8fb5

Please sign in to comment.