-
Notifications
You must be signed in to change notification settings - Fork 237
/
Copy pathSpecialFunctionsExt.jl
58 lines (35 loc) · 3.26 KB
/
SpecialFunctionsExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# math functionality corresponding to SpecialFunctions.jl
module SpecialFunctionsExt

using CUDA
using CUDA: @device_override
isdefined(Base, :get_extension) ? (using SpecialFunctions) : (using ..SpecialFunctions)

# Device-side overrides that make SpecialFunctions.jl calls inside GPU code resolve to
# CUDA libdevice intrinsics (`__nv_*` symbols), bound through `ccall(..., llvmcall, ...)`.
# Methods are generated from the tables below. Each row gives the SpecialFunctions
# function name, then the libdevice symbol for its Float64 method, then the symbol for
# its Float32 method.
#
# NOTE(review): the libdevice routines report out-of-domain inputs through their return
# value (e.g. NaN/Inf), while the host SpecialFunctions methods may throw instead;
# confirm this divergence is intended for each entry.

# Single-argument functions: f(x::Float64) and f(x::Float32).
for (fname, dsym, fsym) in (
        ## error
        (:erf,      "__nv_erf",     "__nv_erff"),
        (:erfinv,   "__nv_erfinv",  "__nv_erfinvf"),
        (:erfc,     "__nv_erfc",    "__nv_erfcf"),
        (:erfcinv,  "__nv_erfcinv", "__nv_erfcinvf"),
        (:erfcx,    "__nv_erfcx",   "__nv_erfcxf"),
        ## gamma function
        (:loggamma, "__nv_lgamma",  "__nv_lgammaf"),
        (:gamma,    "__nv_tgamma",  "__nv_tgammaf"),
        ## Bessel, fixed order
        (:besselj0, "__nv_j0",      "__nv_j0f"),
        (:besselj1, "__nv_j1",      "__nv_j1f"),
        (:bessely0, "__nv_y0",      "__nv_y0f"),
        (:bessely1, "__nv_y1",      "__nv_y1f"),
    )
    # `ccall` with `llvmcall` needs a literal name; interpolating the String into the
    # generated expression makes it one.
    dname = "extern " * dsym
    fname32 = "extern " * fsym
    @eval begin
        @device_override SpecialFunctions.$fname(x::Float64) =
            ccall($dname, llvmcall, Cdouble, (Cdouble,), x)
        @device_override SpecialFunctions.$fname(x::Float32) =
            ccall($fname32, llvmcall, Cfloat, (Cfloat,), x)
    end
end

# Integer-order Bessel functions: f(n::Int32, x::Float64) and f(n::Int32, x::Float32).
# Only the Int32 order is overridden, matching the `int` order parameter passed to
# the libdevice symbol.
for (fname, dsym, fsym) in (
        (:besselj, "__nv_jn", "__nv_jnf"),
        (:bessely, "__nv_yn", "__nv_ynf"),
    )
    dname = "extern " * dsym
    fname32 = "extern " * fsym
    @eval begin
        @device_override SpecialFunctions.$fname(n::Int32, x::Float64) =
            ccall($dname, llvmcall, Cdouble, (Int32, Cdouble), n, x)
        @device_override SpecialFunctions.$fname(n::Int32, x::Float32) =
            ccall($fname32, llvmcall, Cfloat, (Int32, Cfloat), n, x)
    end
end

end