Skip to content

Commit 4fc1d23

Browse files
committed
Fixing generate_xdmf and FFTNumpy
1 parent 731a889 commit 4fc1d23

File tree

4 files changed

+43
-33
lines changed

4 files changed

+43
-33
lines changed

mpi4py_fft/io/generate_xdmf.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def get_geometry(kind=0, dim=2):
4040

4141
return """<Geometry Type="VXVY">
4242
<DataItem Format="HDF" NumberType="Float" Precision="{0}" Dimensions="{1}">
43-
{3}:/mesh/{4}
43+
{3}:{6}/mesh/{4}
4444
</DataItem>
4545
<DataItem Format="HDF" NumberType="Float" Precision="{0}" Dimensions="{2}">
46-
{3}:/mesh/{5}
46+
{3}:{6}/mesh/{5}
4747
</DataItem>
4848
</Geometry>"""
4949

@@ -60,13 +60,13 @@ def get_geometry(kind=0, dim=2):
6060

6161
return """<Geometry Type="VXVYVZ">
6262
<DataItem Format="HDF" NumberType="Float" Precision="{0}" Dimensions="{3}">
63-
{4}:/mesh/{5}
63+
{4}:{8}/mesh/{5}
6464
</DataItem>
6565
<DataItem Format="HDF" NumberType="Float" Precision="{0}" Dimensions="{2}">
66-
{4}:/mesh/{6}
66+
{4}:{8}/mesh/{6}
6767
</DataItem>
6868
<DataItem Format="HDF" NumberType="Float" Precision="{0}" Dimensions="{1}">
69-
{4}:/mesh/{7}
69+
{4}:{8}/mesh/{7}
7070
</DataItem>
7171
</Geometry>"""
7272

@@ -233,17 +233,17 @@ def generate_xdmf(h5filename, periodic=True, order='paraview'):
233233

234234
if ndim == 2 and ('slice' not in slices or len(f[group].attrs.get('shape')) > 3):
235235
if order.lower() == 'paraview':
236-
sig = (prec, N[0], N[1], h5filename, cc[0], cc[1])
236+
sig = (prec, N[0], N[1], h5filename, cc[0], cc[1], group)
237237
else:
238-
sig = (prec, N[1], N[0], h5filename, cc[1], cc[0])
238+
sig = (prec, N[1], N[0], h5filename, cc[1], cc[0], group)
239239
else:
240240
if ndim == 2: # 2D slice in 3D domain
241241
pos = f[group+"/mesh/x{}".format(kk)][sl]
242242
z = re.findall(r'<DataItem(.*?)</DataItem>', geo, re.DOTALL)
243243
geo = geo.replace(z[2-kk], ' Format="XML" NumberType="Float" Precision="{0}" Dimensions="{%d}">\n {%d}\n '%(1+kk, 7-kk))
244244
cc = list(cc)
245245
cc.insert(kk, pos)
246-
sig = (prec, N[0], N[1], N[2], h5filename, cc[2], cc[1], cc[0])
246+
sig = (prec, N[0], N[1], N[2], h5filename, cc[2], cc[1], cc[0], group)
247247
geometry[slices] = geo.format(*sig)
248248
topology[slices] = get_topology(N, kind=1)
249249
grid[slices] = ''

mpi4py_fft/libfft.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import numpy as np
23
from . import fftw
34

@@ -371,28 +372,37 @@ class FFTNumPy(FFTBase): #pragma: no cover
371372
372373
"""
373374

374-
def __init__(self, shape, axes=None, dtype=float, padding=False, **kw):
375+
def __init__(self, shape, axes=None, dtype=float, padding=False,
376+
transforms=None, **kw):
375377
FFTBase.__init__(self, shape, axes, dtype, padding)
376378
typecode = self.dtype.char
377379

378380
self.sizes = list(np.take(self.shape, self.axes))
379381
arrayA = np.zeros(self.shape, self.dtype)
380-
if self.real_transform:
381-
axis = self.axes[-1]
382-
self.shape[axis] = self.shape[axis]//2 + 1
383-
arrayB = np.zeros(self.shape, typecode.upper())
384-
fwd = np.fft.rfftn
385-
bck = np.fft.irfftn
382+
transforms = {} if transforms is None else transforms
383+
if tuple(self.axes) in transforms:
384+
fwd, bck = transforms[tuple(self.axes)]
385+
arrayB = fwd(arrayA, axes=self.axes).astype(typecode)
386+
self.fwd = functools.partial(fwd, shape=self.sizes)
387+
self.bck = functools.partial(bck, shape=self.sizes)
388+
386389
else:
387-
arrayB = np.zeros(self.shape, typecode)
388-
fwd = np.fft.fftn
389-
bck = np.fft.ifftn
390+
if self.real_transform:
391+
fwd = np.fft.rfftn
392+
bck = np.fft.irfftn
393+
arrayB = fwd(arrayA, s=self.sizes, axes=self.axes).astype(typecode.upper())
394+
self.shape = arrayB.shape
395+
else:
396+
fwd = np.fft.fftn
397+
bck = np.fft.ifftn
398+
arrayB = fwd(arrayA, s=self.sizes, axes=self.axes).astype(typecode)
399+
self.fwd = functools.partial(fwd, s=self.sizes)
400+
self.bck = functools.partial(bck, s=self.sizes)
390401

391-
fwd.input_array = arrayA
392-
fwd.output_array = arrayB
393-
bck.input_array = arrayB
394-
bck.output_array = arrayA
395-
self.fwd, self.bck = fwd, bck
402+
self.fwd_input_array = arrayA
403+
self.fwd_output_array = arrayB
404+
self.bck_input_array = arrayB
405+
self.bck_output_array = arrayA
396406

397407
self.padding_factor = 1
398408
if padding is not False:
@@ -407,14 +417,14 @@ def __init__(self, shape, axes=None, dtype=float, padding=False, **kw):
407417
self.backward = _Xfftn_wrap(self._backward, arrayB, arrayA)
408418

409419
def _forward(self, **kw):
410-
self.fwd.output_array[:] = self.fwd(self.fwd.input_array, s=self.sizes,
420+
self.fwd_output_array[:] = self.fwd(self.fwd_input_array,
411421
axes=self.axes, **kw)
412-
self._truncation_forward(self.fwd.output_array, self.forward.output_array)
422+
self._truncation_forward(self.fwd_output_array, self.forward.output_array)
413423
return self.forward.output_array
414424

415425
def _backward(self, **kw):
416-
self._padding_backward(self.backward.input_array, self.bck.input_array)
417-
self.backward.output_array[:] = self.bck(self.bck.input_array, s=self.sizes,
426+
self._padding_backward(self.backward.input_array, self.bck_input_array)
427+
self.backward.output_array[:] = self.bck(self.bck_input_array,
418428
axes=self.axes, **kw)
419429
return self.backward.output_array
420430

mpi4py_fft/mpifft.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,8 @@ def __init__(self, comm, shape=None, axes=None, dtype=float, slab=False,
288288

289289
axes = self.axes[-1]
290290
pencil = Pencil(self.subcomm, shape, axes[-1])
291-
xfftn = FFT(pencil.subshape, axes, dtype, padding, use_pyfftw,
292-
transforms, **kw)
291+
xfftn = FFT(pencil.subshape, axes, dtype, padding, use_pyfftw=use_pyfftw,
292+
transforms=transforms, **kw)
293293
self.xfftn.append(xfftn)
294294
self.pencil[0] = pencilA = pencil
295295
if not shape[axes[-1]] == xfftn.forward.output_array.shape[axes[-1]]:
@@ -300,8 +300,8 @@ def __init__(self, comm, shape=None, axes=None, dtype=float, slab=False,
300300
for axes in reversed(self.axes[:-1]):
301301
pencilB = pencilA.pencil(axes[-1])
302302
transAB = pencilA.transfer(pencilB, dtype)
303-
xfftn = FFT(pencilB.subshape, axes, dtype, padding, use_pyfftw,
304-
transforms, **kw)
303+
xfftn = FFT(pencilB.subshape, axes, dtype, padding, use_pyfftw=use_pyfftw,
304+
transforms=transforms, **kw)
305305
self.xfftn.append(xfftn)
306306
self.transfer.append(transAB)
307307
pencilA = pencilB

tests/test_libfft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
except ImportError:
1111
has_pyfftw = False
1212

13-
abstol = dict(f=5e-5, d=1e-14, g=1e-15)
13+
abstol = dict(f=5e-5, d=1e-14, g=1e-14)
1414

1515
def allclose(a, b):
1616
atol = abstol[a.dtype.char.lower()]
@@ -91,7 +91,7 @@ def test_libfft():
9191

9292
B.fill(0)
9393
B = fft.forward(A, B)
94-
assert allclose(B, X)
94+
assert allclose(B, X), np.linalg.norm(B-X)
9595

9696

9797
if __name__ == '__main__':

0 commit comments

Comments
 (0)