diff --git a/Project.toml b/Project.toml index 734ab0a..4ebba3c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FastTransformsForwardDiff" uuid = "77fa7db0-1c81-401d-9fde-3592fc42b8bc" authors = ["Sheehan Olver "] -version = "0.0.1" +version = "0.0.2" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/FastTransformsForwardDiff.jl b/src/FastTransformsForwardDiff.jl index 25b9472..fb72d6a 100644 --- a/src/FastTransformsForwardDiff.jl +++ b/src/FastTransformsForwardDiff.jl @@ -8,6 +8,29 @@ import FFTW: r2r, r2r!, plan_r2r, mul!, Plan @inline tagtype(::Complex{T}) where T = tagtype(T) @inline tagtype(::Type{Complex{T}}) where T = tagtype(T) +dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) +dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) +array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x)) +array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) + +value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value) + +partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n)) + +npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N +npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N + + +for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities + @eval begin + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + end +end + +mul!(y::AbstractArray{<:Union{Dual,Complex{<:Dual}}}, p::Plan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) = copyto!(y, p*x) + + include("fft.jl") end # module FastTransformsForwardDiff diff --git a/src/fft.jl b/src/fft.jl index 262f762..9b35e82 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1,14 +1,3 @@ -dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) -dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) -array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x)) -array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) - -value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value) - -partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n)) - -npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N -npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N # AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im) AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) @@ -37,12 +26,3 @@ end r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x - -for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities - @eval begin - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) - end -end - -mul!(y::AbstractArray{<:Dual}, p::Plan, x::AbstractArray{<:Dual}) = copyto!(y, p*x) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9d1dce2..0c2eb04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,6 +60,10 @@ end @test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) @test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) end + + c1 = complex.(x1) + @test mul!(similar(c1), FFTW.plan_fft(x1), x1) == fft(x1) + @test mul!(similar(c1), FFTW.plan_fft(c1), c1) == fft(c1) end @testset "r2r" begin