Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

actorder #16

Closed
wants to merge 17 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def compress(
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H

actorder = False
invperm = False

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
Expand Down Expand Up @@ -140,14 +143,38 @@ def compress(
q = torch.dequantize(q)
elif hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
actorder = quant_scheme.weights.actorder

if quant_scheme.weights is not None:
scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
fake_quantize,
)

scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point

if actorder:
perm = torch.argsort(torch.diag(self.H), descending=True)
W = W[:, perm]
self.H = self.H[perm][:, perm]
invperm = torch.argsort(perm)

group_size = quant_scheme.weights.group_size
if group_size is None or group_size == -1:
group_size = self.layer.weight.shape[1]

if actorder:
indices = torch.arange(self.columns, device=invperm.device)
g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
g_idx = g_idx[invperm]
self.layer.weight_g_idx.data = g_idx
else:
indices = torch.arange(
self.columns, device=W.device, dtype=torch.int32
)
g_idx = indices // group_size

strategy = quant_scheme.weights.strategy

if strategy == QuantizationStrategy.TENSOR:
Expand Down Expand Up @@ -176,6 +203,17 @@ def compress(
# ends up being a channelwise application
altered_qargs = copy(quant_scheme.weights)
altered_qargs.strategy = QuantizationStrategy.CHANNEL

# apply g_idx
if g_idx is not None:
# scale and zp already transformed by group_size
# extract first index of each group (one scale/zp column per group)
indices_to_extract = torch.arange(
0, g_idx.shape[0], group_size
)
scale = scale[:, g_idx[indices_to_extract]]
zero_point = zero_point[:, g_idx[indices_to_extract]]

q = fake_quantize(
q,
scale[:, input_dim_group],
Expand Down Expand Up @@ -206,6 +244,9 @@ def compress(
logger.info("time %.2f" % (time.time() - tick))
logger.info("error %.2f" % torch.sum(Losses).item())

if actorder:
W = W[:, invperm]

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.reshape(final_shape).to(final_dtype)
Expand Down
Loading