Skip to content

Commit 9d55b94

Browse files
committed
wrap mcintegration
1 parent b91f3da commit 9d55b94

File tree

5 files changed

+72
-4
lines changed

5 files changed

+72
-4
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "4.1.0"
77
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
88
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
1011
MonteCarloIntegration = "4886b29c-78c9-11e9-0a6e-41e1f4161f7b"
1112
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1213
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -41,6 +42,7 @@ FastGaussQuadrature = "0.5"
4142
ForwardDiff = "0.10"
4243
HCubature = "1.4"
4344
LinearAlgebra = "1.9"
45+
MCIntegration = "0.4"
4446
MonteCarloIntegration = "0.0.1, 0.0.2, 0.0.3, 0.1"
4547
QuadGK = "2.5"
4648
Reexport = "0.2, 1.0"

docs/src/solvers/IntegralSolvers.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ The following algorithms are available:
44

55
- `QuadGKJL`: Uses QuadGK.jl. Requires `nout=1` and `batch=0`, in-place is not allowed.
66
- `HCubatureJL`: Uses HCubature.jl. Requires `batch=0`.
7-
- `VEGAS`: Uses MonteCarloIntegration.jl. Requires `nout=1`. Works only for `>1`-dimensional integrations.
7+
- `VEGAS`: Uses MonteCarloIntegration.jl. Requires `nout=1`. Works only for
8+
`>1`-dimensional integrations.
9+
- `VEGASMC`: Uses MCIntegration.jl. Requires `nout=1`. Works only for `>1`-dimensional integrations.
810
- `CubatureJLh`: h-Cubature from Cubature.jl. Requires `using Cubature`.
911
- `CubatureJLp`: p-Cubature from Cubature.jl. Requires `using Cubature`.
1012
- `CubaVegas`: Vegas from Cuba.jl. Requires `using Cuba`, `nout=1`.
@@ -20,6 +22,7 @@ The following algorithms are available:
2022
QuadGKJL
2123
HCubatureJL
2224
VEGAS
25+
VEGASMC
2326
CubaVegas
2427
CubaSUAVE
2528
CubaDivonne

src/Integrals.jl

+46-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ if !isdefined(Base, :get_extension)
44
using Requires
55
end
66

7-
using Reexport, MonteCarloIntegration, QuadGK, HCubature
7+
using Reexport, MonteCarloIntegration, QuadGK, HCubature, MCIntegration
88
@reexport using SciMLBase
99
using LinearAlgebra
1010

@@ -174,7 +174,51 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
174174
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
175175
end
176176

177-
export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule, TrapezoidalRule
177+
178+
function __solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg, domain, p;
179+
reltol = nothing, abstol = nothing, maxiters = 1000)
180+
lb, ub = domain
181+
mid = collect((lb + ub) / 2)
182+
vars = Continuous(vec([tuple(a,b) for (a,b) in zip(lb, ub)]))
183+
184+
if prob.f isa BatchIntegralFunction
185+
error("VEGASMC doesn't support batching. See https://github.com/numericalEFT/MCIntegration.jl/issues/29")
186+
else
187+
if isinplace(prob)
188+
f0 = similar(prob.f.integrand_prototype)
189+
f_ = (x, f, c) -> begin
190+
n = 0
191+
for v in x
192+
mid[n+=1] = first(v)
193+
end
194+
prob.f(f0, mid, p)
195+
f .= vec(f0)
196+
end
197+
else
198+
f0 = prob.f(mid, p)
199+
f_ = (x, c) -> begin
200+
n = 0
201+
for v in x
202+
mid[n+=1] = first(v)
203+
end
204+
prob.f(mid, p)
205+
end
206+
end
207+
dof = f0 isa Number ? 1 : ones(Int, length(f0))
208+
res = integrate(f_, inplace=isinplace(prob), var=vars, dof=dof, solver=:vegasmc,
209+
neval=alg.neval, niter=min(maxiters,alg.niter), block=alg.block, adapt=alg.adapt,
210+
gamma=alg.gamma, verbose=alg.verbose, debug=alg.debug)
211+
out, err, chi = if f0 isa Number
212+
map(only, (res.mean, res.stdev, res.chi2))
213+
else
214+
map(a -> reshape(a, size(f0)), (res.mean, res.stdev, res.chi2))
215+
end
216+
SciMLBase.build_solution(prob, VEGASMC(), out, err, chi=chi, retcode = ReturnCode.Success)
217+
end
218+
end
219+
220+
221+
export QuadGKJL, HCubatureJL, VEGAS, VEGASMC, GaussLegendre, QuadratureRule, TrapezoidalRule
178222
export CubaVegas, CubaSUAVE, CubaDivonne, CubaCuhre
179223
export CubatureJLh, CubatureJLp
180224
export ArblibJL

src/algorithms.jl

+18
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,24 @@ struct VEGAS <: SciMLBase.AbstractIntegralAlgorithm
8787
end
8888
VEGAS(; nbins = 100, ncalls = 1000, debug = false) = VEGAS(nbins, ncalls, debug)
8989

90+
91+
"""
92+
VEGASMC()
93+
94+
Markov-chain based Vegas algorithm from MCIntegration.jl
95+
"""
96+
struct VEGASMC <: SciMLBase.AbstractIntegralAlgorithm
97+
neval::Int
98+
niter::Int
99+
block::Int
100+
adapt::Bool
101+
gamma::Float64
102+
verbose::Int
103+
debug::Bool
104+
end
105+
VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false) =
106+
VEGASMC(neval, niter, block, adapt, gamma, verbose, debug)
107+
90108
"""
91109
GaussLegendre{C, N, W}
92110

test/interface_tests.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ max_nout_test = 2
77
reltol = 1e-3
88
abstol = 1e-3
99

10-
algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, VEGAS, #CubaVegas,
10+
algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, VEGAS, VEGASMC, #CubaVegas,
1111
CubaSUAVE, CubaDivonne, CubaCuhre, ArblibJL]
1212

1313
alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = false, min_dim = 1, max_dim = 1,
@@ -16,6 +16,7 @@ alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = false, min_dim = 1, max_dim
1616
max_dim = Inf, allows_iip = true),
1717
VEGAS => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf,
1818
allows_iip = true),
19+
VEGASMC => (nout = Inf, allows_batch = false, min_dim = 1, max_dim = Inf, allows_iip = true),
1920
CubatureJLh => (nout = Inf, allows_batch = true, min_dim = 1,
2021
max_dim = Inf, allows_iip = true),
2122
CubatureJLp => (nout = Inf, allows_batch = true, min_dim = 1,

0 commit comments

Comments
 (0)