From 13fc6e27c8065518bbdeb441e16b01712049824a Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 8 Nov 2023 12:47:51 +0100 Subject: [PATCH] Make the parameter in `ScaleTransform` a scalar --- src/transform/scaletransform.jl | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index 18923fcc4..d746dd520 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -13,26 +13,22 @@ true ``` """ struct ScaleTransform{T<:Real} <: Transform - s::Vector{T} + s::T end function ScaleTransform(s::T=1.0) where {T<:Real} - return ScaleTransform{T}([s]) + return ScaleTransform{T}(s) end -@functor ScaleTransform +(t::ScaleTransform)(x) = t.s * x -set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ] +_map(t::ScaleTransform, x::AbstractVector{<:Real}) = t.s .* x +_map(t::ScaleTransform, x::ColVecs) = ColVecs(t.s .* x.X) +_map(t::ScaleTransform, x::RowVecs) = RowVecs(t.s .* x.X) -(t::ScaleTransform)(x) = only(t.s) * x +Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(t.s, only(t2.s)) -_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x -_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X) -_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X) - -Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s)) - -Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", only(t.s), ")") +Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", t.s, ")") # Helpers