Commit

More default
Giuseppe5 committed Oct 8, 2024
1 parent 386584a commit 29e1b32
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions src/brevitas/core/scaling/standalone.py
@@ -61,26 +61,20 @@ class ConstScaling(brevitas.jit.ScriptModule):
     def __init__(
             self,
             scaling_init: Union[float, Tensor],
-            restrict_scaling_impl: Optional[Module] = None,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(ConstScaling, self).__init__()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
         if isinstance(scaling_init, Tensor):
             scaling_init = scaling_init.to(device=device, dtype=dtype)
-            if restrict_scaling_impl is not None:
-                scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
-                self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
-            else:
-                self.restrict_init_module = Identity()
+            scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
             self.value = StatelessBuffer(scaling_init.detach())
         else:
-            if restrict_scaling_impl is not None:
-                scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
-                self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
-            else:
-                self.restrict_init_module = Identity()
+            scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
+            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
             self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
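
For reference, a minimal usage sketch of the new ConstScaling default. The import paths and the PowerOfTwoRestrictValue alternative below are assumptions based on brevitas' restrict_val module and are not part of this diff:

    from brevitas.core.restrict_val import FloatRestrictValue, PowerOfTwoRestrictValue
    from brevitas.core.scaling.standalone import ConstScaling

    # With the new default, omitting restrict_scaling_impl behaves like passing an
    # explicit FloatRestrictValue(): scaling_init always goes through
    # restrict_init_float()/restrict_init_tensor() and restrict_init_module() is
    # always set, with no special None branch.
    scaling = ConstScaling(0.1)

    # Passing a restriction explicitly still works as before, e.g. power-of-two scaling:
    po2_scaling = ConstScaling(0.1, restrict_scaling_impl=PowerOfTwoRestrictValue())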
@@ -138,7 +132,7 @@ def __init__(
             self,
             scaling_init: Union[float, Tensor],
             scaling_shape: Optional[Tuple[int, ...]] = None,
-            restrict_scaling_impl: Optional[Module] = None,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
@@ -153,11 +147,10 @@ def __init__(
             scaling_init = scaling_init.detach()
         else:
             scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
-        if restrict_scaling_impl is not None:
-            scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
-            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
-        else:
-            self.restrict_init_module = Identity()
+
+        scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+        self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+
         if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
             scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
         self.value = Parameter(scaling_init)
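
The same simplification is applied to the parameter-based scaling __init__ above. A rough mental model of why the None branches could be dropped without changing behavior, assuming FloatRestrictValue's init hooks are pass-throughs (this is a sketch, not the brevitas implementation):

    from torch import nn, Tensor

    class FloatRestrictValueSketch(nn.Module):
        # Stand-in illustrating the assumed behavior of the new default.

        def restrict_init_float(self, x: float) -> float:
            return x  # no restriction applied to a float init value

        def restrict_init_tensor(self, x: Tensor) -> Tensor:
            return x  # no restriction applied to a tensor init value

        def restrict_init_module(self) -> nn.Module:
            return nn.Identity()  # same module the removed else-branches installed

        def forward(self, x: Tensor) -> Tensor:
            return x  # values pass through unrestricted

Under this assumption, defaulting to FloatRestrictValue() reproduces what the removed restrict_scaling_impl=None branches did (an untouched scaling_init and an identity restrict_init_module), so the branching was redundant.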
