Skip to content

Commit 704f56d

Browse files
authored
Support mul! with complex vectors (#2)
* Support mul! with complex vectors * add tests, reorg * v0.0.2
1 parent e75f3f6 commit 704f56d

File tree

4 files changed

+28
-21
lines changed

4 files changed

+28
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FastTransformsForwardDiff"
22
uuid = "77fa7db0-1c81-401d-9fde-3592fc42b8bc"
33
authors = ["Sheehan Olver <solver@mac.com>"]
4-
version = "0.0.1"
4+
version = "0.0.2"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/FastTransformsForwardDiff.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,29 @@ import FFTW: r2r, r2r!, plan_r2r, mul!, Plan
88
@inline tagtype(::Complex{T}) where T = tagtype(T)
99
@inline tagtype(::Type{Complex{T}}) where T = tagtype(T)
1010

11+
dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x)
12+
dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x)))
13+
array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x))
14+
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x)))
15+
16+
value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value)
17+
18+
partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n))
19+
20+
npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N
21+
npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N
22+
23+
24+
for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities
25+
@eval begin
26+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x))
27+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x))
28+
end
29+
end
30+
31+
mul!(y::AbstractArray{<:Union{Dual,Complex{<:Dual}}}, p::Plan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) = copyto!(y, p*x)
32+
33+
1134
include("fft.jl")
1235

1336
end # module FastTransformsForwardDiff

src/fft.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x)
2-
dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x)))
3-
array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x))
4-
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x)))
5-
6-
value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value)
7-
8-
partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n))
9-
10-
npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N
11-
npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N
121

132
# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im)
143
AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x)
@@ -37,12 +26,3 @@ end
3726
r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x
3827
r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x
3928

40-
41-
for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities
42-
@eval begin
43-
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x))
44-
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x))
45-
end
46-
end
47-
48-
mul!(y::AbstractArray{<:Dual}, p::Plan, x::AbstractArray{<:Dual}) = copyto!(y, p*x)

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ end
6060
@test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2)
6161
@test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2)
6262
end
63+
64+
c1 = complex.(x1)
65+
@test mul!(similar(c1), FFTW.plan_fft(x1), x1) == fft(x1)
66+
@test mul!(similar(c1), FFTW.plan_fft(c1), c1) == fft(c1)
6367
end
6468

6569
@testset "r2r" begin

0 commit comments

Comments
 (0)