Skip to content

Commit

Permalink
don't clone in transform_into
Browse files Browse the repository at this point in the history
  • Loading branch information
LouisChourakiSonos authored and kali committed Feb 19, 2025
1 parent fb6a126 commit f7eaef3
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ impl Parameters {
{
stage!("metal", typed_model -> typed_model, |m:TypedModel| {
tract_metal::transform::MetalTransform::default()
.transform_into(&m)
.transform_into(m)
});
}
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
Expand All @@ -768,7 +768,7 @@ impl Parameters {
for transform in transform {
stage!(transform, typed_model -> typed_model, |m:TypedModel| {
let transform = tract_core::transform::get_transform(transform).with_context(|| format!("Could not find transform named {}", transform))?;
transform.transform_into(&m)
transform.transform_into(m)
});
stage!(&format!("{}-declutter", transform), typed_model -> typed_model, |m:TypedModel| m.into_decluttered());
}
Expand Down
6 changes: 3 additions & 3 deletions core/src/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ mod test {
// Execution in F16 with returns NaN
let runnable_model = &crate::transform::get_transform("f32-to-f16")
.unwrap()
.transform_into(&model)?
.transform_into(model.clone())?
.into_runnable()?;
assert!(runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
.to_scalar::<f16>()?
Expand All @@ -258,7 +258,7 @@ mod test {
// Execution in F16 with filter that returns the good output.
let runnable_model = &crate::transform::get_transform("f32-to-f16!=layer.1")
.unwrap()
.transform_into(&model)?
.transform_into(model.clone())?
.into_runnable()?;
assert_eq!(
runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
Expand All @@ -268,7 +268,7 @@ mod test {
// Execution in F16 with returns NaN despite the filter.
let runnable_model = &crate::transform::get_transform("f32-to-f16!=layer.0")
.unwrap()
.transform_into(&model)?
.transform_into(model)?
.into_runnable()?;
assert!(runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
.to_scalar::<f16>()?
Expand Down
3 changes: 1 addition & 2 deletions core/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ pub fn build_float_translator<T1: Datum + Float, T2: Datum + Float>(
pub trait ModelTransform: Debug {
fn name(&self) -> Cow<str>;
fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
fn transform_into(&self, model: &TypedModel) -> TractResult<TypedModel> {
let mut model = model.clone();
fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
self.transform(&mut model)?;
Ok(model)
}
Expand Down
2 changes: 1 addition & 1 deletion metal/src/ops/change_axes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ pub mod tests {

let expected = model.clone().into_runnable()?.run(tvec![input.clone().into()])?;

let metal_model = MetalTransform::default().transform_into(&model)?;
let metal_model = MetalTransform::default().transform_into(model)?;
let output = metal_model.clone().into_runnable()?.run(tvec![input.clone().into()])?;

let _ = &output[0].close_enough(&expected[0], Approximation::Close)?;
Expand Down

0 comments on commit f7eaef3

Please # to comment.