diff --git a/libspu/kernel/hal/shape_ops.cc b/libspu/kernel/hal/shape_ops.cc index 6337719f..ffcfed60 100644 --- a/libspu/kernel/hal/shape_ops.cc +++ b/libspu/kernel/hal/shape_ops.cc @@ -43,10 +43,10 @@ Value update_slice(SPUContext* ctx, const Value& in, const Value& update, SPU_TRACE_HAL_DISP(ctx, in, start_indices); if (in.storage_type() != update.storage_type()) { - auto u = - _cast_type(ctx, update, in.storage_type()).setDtype(update.dtype()); - - return update_slice(ctx, in, u, start_indices); + auto ct = _common_type(ctx, update.storage_type(), in.storage_type()); + auto u = _cast_type(ctx, update, ct).setDtype(update.dtype()); + auto i = _cast_type(ctx, in, ct).setDtype(in.dtype()); + return update_slice(ctx, i, u, start_indices); } return _update_slice(ctx, in, update, start_indices).setDtype(in.dtype());