Skip to content

Commit

Permalink
scratch
Browse files — browse the repository at this point in the history
  • Loading branch information
horheynm committed Jul 12, 2024
1 parent f203537 commit 9967a4a
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ def compress(
W[:, dead] = 0

g_idx = None
# if hasattr(self.layer, "quantization_scheme"):
# quant_scheme = self.layer.quantization_scheme
# actorder = quant_scheme.weights.actorder

# if actorder:
# group_size = quant_scheme.weights.group_size
# perm = torch.argsort(torch.diag(self.H), descending=True)
# # W = W[:, perm]
# # self.H = self.H[perm][:, perm]
# invperm = torch.argsort(perm)

# # g_idx = torch.Tensor(
# # [perm[i] // group_size for i in range(self.columns)]
# # ).to(device=invperm.device)
# g_idx = torch.Tensor(
# [i // group_size for i in range(self.columns)]
# ).to(device=invperm.device)
# self.layer.weight_g_idx.data = g_idx
if hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
actorder = quant_scheme.weights.actorder
Expand All @@ -114,8 +132,11 @@ def compress(
g_idx = torch.Tensor(
[perm[i] // group_size for i in range(self.columns)]
).to(device=invperm.device)
# g_idx = torch.Tensor(
# [i // group_size for i in range(self.columns)]
# ).to(device=invperm.device)
self.layer.weight_g_idx.data = g_idx

Losses = torch.zeros(self.rows, device=self.dev)

damp = percdamp * torch.mean(torch.diag(self.H))
Expand Down Expand Up @@ -200,24 +221,22 @@ def compress(
altered_qargs = copy(quant_scheme.weights)
altered_qargs.strategy = QuantizationStrategy.CHANNEL

# apply g_idx
if g_idx is not None:
# scale and zp already transformed by group_size
# extract first index of group_size
indices_to_extract = torch.arange(
0, g_idx.shape[0], group_size
q = fake_quantize(
q,
scale[:, int(g_idx[column_idx])],
zero_point[:, int(g_idx[column_idx])],
altered_qargs,
)

else:

q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
altered_qargs,
)
grouped_indicies = g_idx[indices_to_extract].int()

scale = scale[:, grouped_indicies]
zero_point = zero_point[:, grouped_indicies]

q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
altered_qargs,
)

Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
Expand All @@ -244,6 +263,7 @@ def compress(

if actorder:
W = W[:, invperm]
self.H = self.H[perm][:, perm]

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
Expand Down

0 comments on commit 9967a4a

Please sign in to comment.