From cae4004208a6aa9c5c4ee3d42c19ce35d52d577f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 29 Apr 2024 14:45:14 +0200 Subject: [PATCH] Fix (quant): propagate device and dtype in subinjector (#942) --- src/brevitas/quant/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index b509e6c16..3a5a2ec53 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -435,6 +435,8 @@ class MSESymmetricScaleSubInjector(MSESubInjectorBase): mse_init_op = AbsMax stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + type = (this << 1).type class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): @@ -443,6 +445,8 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): mse_init_op = AbsMinMax stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + type = (this << 1).type class MSEZeroPointSubInjector(MSESubInjectorBase): @@ -453,6 +457,8 @@ class MSEZeroPointSubInjector(MSESubInjectorBase): mse_search_method = 'grid' stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + type = (this << 1).type class MSEAsymmetricScale(ExtendedInjector):