diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index b509e6c16..3a5a2ec53 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -435,6 +435,8 @@ class MSESymmetricScaleSubInjector(MSESubInjectorBase): mse_init_op = AbsMax stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + type = (this << 1).type class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): @@ -443,6 +445,8 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): mse_init_op = AbsMinMax stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + type = (this << 1).type class MSEZeroPointSubInjector(MSESubInjectorBase): @@ -453,6 +457,8 @@ class MSEZeroPointSubInjector(MSESubInjectorBase): mse_search_method = 'grid' stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + type = (this << 1).type class MSEAsymmetricScale(ExtendedInjector):