[Observers] group size + channel wise + per token (vllm-project#32)
* group size

* add logic in base observer

* group size full lifecycle run

* before vectorize the for loop

* comments, todo add channelwise

* chan wise impl

* comments

* fix channel wise

* comments, validators

* fix typo

* tensor return error fix

* fix sparseml-side of code and add per channel

* pydantic defaults

* token wise quant

* Update src/compressed_tensors/quantization/quant_args.py

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>

* comments

* update dim

* shape consistency

* Update src/compressed_tensors/quantization/lifecycle/forward.py

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>

* comments

* pass test_quant_args

---------

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
horheynm and bfineran authored May 3, 2024
1 parent 774da35 commit 05c1487
Showing 4 changed files with 182 additions and 13 deletions.
94 changes: 86 additions & 8 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -13,9 +13,13 @@
# limitations under the License.

from functools import wraps
from math import ceil

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module
@@ -32,10 +36,9 @@ def quantize(
q_min: torch.Tensor,
q_max: torch.Tensor,
) -> torch.Tensor:

return torch.clamp(
torch.round(
x / scale + zero_point,
),
torch.round(x / scale + zero_point),
q_min,
q_max,
)
@@ -57,12 +60,88 @@ def fake_quantize(
zero_point: torch.Tensor,
args: QuantizationArgs,
) -> torch.Tensor:
"""
Fake quantize the input tensor x depending on the group_size.
if group_size is greater than 0, then q/dq by groups. The groups
must be divisible by the column size
if group_size is -1, then channel wise q/dq. THe input scale and
zero_points are reshaped to support vectorization (Assumes 1 is
the channel dimension)
:param x: Input tensor
:param scale: scale tensor
:param zero_point: zero point tensor
:param args: quantization args that contain group_size info
:return: fake quantized tensor
"""
bit_range = 2**args.num_bits
max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
min_q = torch.tensor(-bit_range / 2, device=x.device)
Q = torch.zeros_like(x)
Q = quantize(x, scale, zero_point, min_q, max_q)
return dequantize(Q, scale, zero_point)

group_size = args.group_size

# group
if args.strategy == QuantizationStrategy.GROUP:

DQ = torch.zeros_like(x)

# TODO: vectorize the for loop
# TODO: fix generic assumption about the tensor size for computing group

# TODO: make validation step for inputs

while scale.ndim < 2:
# pad scale and zero point dims for slicing
scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

columns = x.shape[1]
if columns >= group_size:
if columns % group_size != 0:
raise ValueError(
"tesnor column shape must be divisble "
f"by the given group_size {group_size}"
)
for i in range(ceil(columns / group_size)):
# scale.shape should be [nchan, ngroups]
# sc.shape should be [nchan, 1] after unsqueeze

sc = scale[:, i].unsqueeze(1)
zp = zero_point[:, i].unsqueeze(1)

idx = i * group_size
Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)

# channel-wise
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
# before: scale shape = [channel_size]
# after: scale shape = [1, channel_size]
scale = scale.unsqueeze(0)
zero_point = zero_point.unsqueeze(0)

Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [num_tokens]
# after: scale shape = [num_tokens, 1]
# x.shape = [batch, num_tokens, hidden_dim]
# scale gets broadcast as expected without needing a [1, num_tokens, 1] shape

scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

else:
Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

return DQ


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
Expand Down Expand Up @@ -139,5 +218,4 @@ def maybe_calibrate_or_quantize(
device = next(module.parameters()).device
scale.data = updated_scale.to(device)
zero_point.data = updated_zero_point.to(device)

return fake_quantize(value, scale, zero_point, args)
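
A minimal sketch (not part of this commit) of how the group-wise branch of fake_quantize behaves, assuming a [2, 4] weight, group_size=2, 8-bit symmetric quantization, and a dequantize of the form (q - zero_point) * scale; all tensor values below are made up for illustration:

import torch

x = torch.tensor([[0.10, -0.40, 2.00, 1.50],
                  [0.30, 0.20, -1.00, 0.80]])
group_size = 2
scale = torch.tensor([[0.01, 0.02],        # one scale per (row, group)
                      [0.01, 0.01]])
zero_point = torch.zeros_like(scale)       # symmetric -> zero points are 0
max_q = torch.tensor(2**8 / 2 - 1)         # 127
min_q = torch.tensor(-(2**8) / 2)          # -128

DQ = torch.zeros_like(x)
for i in range(x.shape[1] // group_size):
    sc = scale[:, i].unsqueeze(1)          # [2, 1], broadcasts over the group
    zp = zero_point[:, i].unsqueeze(1)
    idx = i * group_size
    q = torch.clamp(torch.round(x[:, idx : idx + group_size] / sc + zp),
                    min_q, max_q)
    DQ[:, idx : idx + group_size] = (q - zp) * sc   # dequantize each group

Each group of two columns shares one scale and zero point per row, which mirrors the slicing done inside the loop in fake_quantize.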
67 changes: 64 additions & 3 deletions src/compressed_tensors/quantization/observers/base.py
@@ -14,7 +14,11 @@

from typing import Optional, Tuple

from compressed_tensors.quantization.quant_args import QuantizationArgs
import torch
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.registry.registry import RegistryMixin
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module
@@ -52,6 +56,12 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
"""
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

def post_calculate_qparams(self) -> None:
"""
Run any logic specific to this observer after calculate_qparams has been called
"""
...

def get_qparams(
self, observed: Optional[Tensor] = None
) -> Tuple[FloatTensor, IntTensor]:
@@ -64,6 +74,57 @@ def get_qparams(
:return: tuple of scale and zero point based on last observed value
"""
if observed is not None:
# re-calcualte scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)
group_size = self.quantization_args.group_size

if self.quantization_args.strategy == QuantizationStrategy.TENSOR:

# re-calculate scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)

elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
columns = observed.shape[1]
scales, zero_points = [], []
for i in range(0, columns, self.quantization_args.group_size):
scale, zero_point = self.get_qparams_along_dim(
observed[:, i : (i + group_size)],
0,
)
scales.append(scale)
zero_points.append(zero_point)

self._scale = torch.stack(scales, dim=1)
self._zero_point = torch.stack(zero_points, dim=1)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed because it's the output, hence use dim 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:

# use dim 1, assume observed.shape = [batch, token, hidden]
# the resulting scale shape should be [batch, token]

self._scale, self._zero_point = self.get_qparams_along_dim(
observed, dim=1
)

return self._scale, self._zero_point

def get_qparams_along_dim(self, observed, dim: int):
# TODO: add documentation that specifies the shape must
# be padded with 1-dims so the scales are along the right channel
# TODO: generalize the logic for reduce_dims
scales, zero_points = [], []

# TODO: make a more generic way to get the channel
num_dims = observed.shape[dim]

for dim_idx in range(num_dims):
scale, zero_point = self.calculate_qparams(
observed.select(dim=dim, index=dim_idx)
)

scales.append(scale)
zero_points.append(zero_point)
return torch.stack(scales), torch.stack(zero_points)
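
A minimal sketch (not from this commit) of what get_qparams_along_dim computes for the CHANNEL strategy; the symmetric min/max rule below is an assumption standing in for a concrete Observer's calculate_qparams:

import torch

def calculate_qparams(observed: torch.Tensor, num_bits: int = 8):
    # assumed symmetric rule: scale from the max absolute value, zero point 0
    scale = observed.abs().max() / (2 ** (num_bits - 1) - 1)
    zero_point = torch.zeros_like(scale, dtype=torch.int64)
    return scale, zero_point

weight = torch.randn(4, 16)          # [out_channels, in_features]
scales, zero_points = [], []
for idx in range(weight.shape[0]):   # dim 0 == channel for the CHANNEL strategy
    s, z = calculate_qparams(weight.select(dim=0, index=idx))
    scales.append(s)
    zero_points.append(z)

scale = torch.stack(scales)          # shape [4], one scale per output channel
zero_point = torch.stack(zero_points)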
33 changes: 31 additions & 2 deletions src/compressed_tensors/quantization/quant_args.py
@@ -15,7 +15,7 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator


__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
TOKEN = "token"


class QuantizationArgs(BaseModel):
@@ -63,8 +64,8 @@ class QuantizationArgs(BaseModel):
num_bits: int = 8
type: QuantizationType = QuantizationType.INT
symmetric: bool = True
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
dynamic: bool = False
observer: str = Field(
@@ -94,3 +95,31 @@ def get_observer(self):
self.observer = "memoryless"

return Observer.load_from_registry(self.observer, quantization_args=self)

@validator("strategy", pre=True, always=True)
def validate_strategy(cls, value, values):
group_size = values.get("group_size")

# use group_size to determine strategy if not given explicitly
if group_size is not None and value is None:
if group_size > 0:
return QuantizationStrategy.GROUP

elif group_size == -1:
return QuantizationStrategy.CHANNEL

else:
raise ValueError(
f"group_size={group_size} with strategy {value} is invald. "
"group_size > 0 for strategy='group' and "
"group_size = -1 for 'channel'"
)

if value == QuantizationStrategy.GROUP:
if group_size is None:
raise ValueError(f"strategy {value} requires group_size to be set.")

if value is None:
return QuantizationStrategy.TENSOR

return value
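
A short usage sketch of the validator above, showing how strategy is inferred from group_size when it is not passed explicitly (behavior as read from this diff):

from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)

args = QuantizationArgs(num_bits=4, group_size=128)
assert args.strategy == QuantizationStrategy.GROUP

args = QuantizationArgs(group_size=-1)
assert args.strategy == QuantizationStrategy.CHANNEL

args = QuantizationArgs()            # no group_size and no strategy
assert args.strategy == QuantizationStrategy.TENSOR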
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -108,6 +108,7 @@ def calculate_compression_ratio(model: Module) -> float:
compressed_bits = uncompressed_bits
if is_module_quantized(submodule):
compressed_bits = submodule.quantization_scheme.weights.num_bits

num_weights = parameter.numel()
total_compressed += compressed_bits * num_weights
total_uncompressed += uncompressed_bits * num_weights
