From 0ebaf61a46530170a507126d50c7dc0576a7ecbd Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 10 Jan 2022 17:40:26 +0100 Subject: [PATCH 1/2] ENH: Add a workflow for computing Debye scattering --- FOX/classes/multi_mol.py | 98 +++++++++- FOX/data/scattering.yaml | 392 +++++++++++++++++++++++++++++++++++++++ FOX/functions/debye.py | 139 ++++++++++++++ tests/test_multi_mol.py | 42 +++++ 4 files changed, 670 insertions(+), 1 deletion(-) create mode 100644 FOX/data/scattering.yaml create mode 100644 FOX/functions/debye.py diff --git a/FOX/classes/multi_mol.py b/FOX/classes/multi_mol.py index 6e001554..e594f723 100644 --- a/FOX/classes/multi_mol.py +++ b/FOX/classes/multi_mol.py @@ -34,7 +34,7 @@ from scipy.fftpack import fft from scipy.spatial.distance import cdist -from scm.plams import Molecule, Atom, Bond, PeriodicTable +from scm.plams import Molecule, Atom, Bond, PeriodicTable, Units from nanoutils import group_by_values, Literal from ..utils import slice_iter, lattice_to_volume @@ -45,6 +45,7 @@ from ..functions.adf import get_adf_df, _adf_inner_cdktree, _adf_inner from ..functions.molecule_utils import fix_bond_orders, separate_mod from ..functions.periodic import parse_periodic +from ..functions.debye import get_debye_scattering if TYPE_CHECKING: import numpy.typing as npt @@ -1365,6 +1366,101 @@ def init_rdf( df /= n_mol return df + def init_debye_scattering( + self, + half_angle: float | npt.NDArray[np.float64], + wavelength: float | npt.NDArray[np.float64], + mol_subset: MolSubset = None, + atom_subset: AtomSubset = None, + *, + periodic: None | Sequence[Literal["x", "y", "z"]] | Sequence[Literal[0, 1, 2]] = None, + atom_pairs: None | Iterable[tuple[str, str]] = None, + ) -> pd.DataFrame: + """Initialize the calculation of Debye scattering factors. + + Scatering factors are calculated for all possible atom-pairs in **atom_subset** and + returned as a dataframe. + + Parameters + ---------- + half_angle : :class:`float` or :class:`np.ndarray` + One or more half angles. Units should be in radian. + wavelength : :class:`float` + One or wavelengths. Units should be in nanometer. + mol_subset : :class:`slice`, optional + Perform the calculation on a subset of molecules in this instance, as + determined by their moleculair index. + Include all :math:`m` molecules in this instance if :data:`None`. + atom_subset : :class:`Sequence[str] `, optional + Perform the calculation on a subset of atoms in this instance, as + determined by their atomic index or atomic symbol. + Include all :math:`n` atoms per molecule in this instance if :data:`None`. + periodic : :class:`str`, optional + If specified, correct for the systems periodicity if + :attr:`self.lattice is not None `. + Accepts ``"x"``, ``"y"`` and/or ``"z"``. + atom_pairs : :class:`Iterable[tuple[str, str]] ` + An explicit list of atom-pairs for the to-be calculated distances. + Note that **atom_pairs** and **atom_subset** are mutually exclusive. + + Returns + ------- + :class:`pd.DataFrame ` + A dataframe of with the Debye scattering, averaged over all conformations. + Keys are of the form: at_symbol1 + ' ' + at_symbol2 (*e.g.* ``"Cd Cd"``). + + """ + if atom_subset is not None and atom_pairs is not None: + raise TypeError("`atom_subset` and `atom_pairs` are mutually exclusive") + elif atom_pairs is not None: + pair_dict = _parse_atom_pairs(self, atom_pairs) + elif atom_subset is not None: + pair_dict = self.get_pair_dict(atom_subset, r=2) + else: + # If **atom_subset** is None: extract atomic symbols from they keys of **self.atoms** + pair_dict = self.get_pair_dict(sorted(self.atoms, key=str), r=2) + + # Construct an empty dataframe with appropiate dimensions, indices and keys + q = np.atleast_1d(np.abs(4 * np.pi * np.sin(half_angle / wavelength))) + q *= Units.conversion_ratio("nm", "angstrom") + df = pd.DataFrame( + 0.0, + columns=pd.Index(pair_dict.keys(), name='Atom pairs'), + index=pd.Index(q, name="Q / Angstrom**-1"), + ) + + # Define the subset + m_subset = self._get_mol_subset(mol_subset) + m_self = self[m_subset] * Units.conversion_ratio("angstrom", "au") + + # Parse the lattice and periodicty settings + if periodic is not None: + periodic_ar = parse_periodic(periodic) + if self.lattice is None: + raise TypeError("cannot perform periodic calculations if the " + "molecules `lattice` is None") + lattice_ar = self.lattice if self.lattice.ndim == 2 else self.lattice[m_subset] + else: + lattice_ar = _GetNone() + periodic_ar = np.arange(3, dtype=np.int64) + + # Fill the dataframe with Debye scatterings, averaged over all conformations + n_mol = len(m_self) + symbol_ar = self.symbol + for key, (i, j) in pair_dict.items(): + shape = n_mol, len(i), len(j) + iterator = slice_iter(shape, m_self.dtype.itemsize) + for slc in iterator: + dist_mat = m_self.get_dist_mat( + mol_subset=slc, atom_subset=(i, j), + lattice=lattice_ar[slc], periodicity=periodic_ar, + ) + df[key] += get_debye_scattering( + dist_mat, symbol_ar[i], symbol_ar[j], q, validate_param=False + ).sum(axis=0) + df /= n_mol + return df + def get_dist_mat( self, mol_subset: MolSubset = None, diff --git a/FOX/data/scattering.yaml b/FOX/data/scattering.yaml new file mode 100644 index 00000000..89b71f8b --- /dev/null +++ b/FOX/data/scattering.yaml @@ -0,0 +1,392 @@ +H: + a: [0.493002, 0.322912, 0.140191, 0.04081] + b: [10.5109, 26.1257, 3.14236, 57.7997] + c: [0.003038] +He: + a: [0.8734, 0.6309, 0.3112, 0.178] + b: [9.1037, 3.3568, 22.9276, 0.9821] + c: [0.0064] +Li: + a: [1.1282, 0.7508, 0.6175, 0.4653] + b: [3.9546, 1.0524, 85.3905, 168.261] + c: [0.0377] +Be: + a: [1.5919, 1.1278, 0.5391, 0.7029] + b: [43.6427, 1.8623, 103.483, 0.542] + c: [0.0385] +B: + a: [2.0545, 1.3326, 1.0979, 0.7068] + b: [23.2185, 1.021, 60.3498, 0.1403] + c: [-0.1932] +C: + a: [2.31, 1.02, 1.5886, 0.865] + b: [20.8439, 10.2075, 0.5687, 51.6512] + c: [0.2156] +N: + a: [12.2126, 3.1322, 2.0125, 1.1663] + b: [0.0057, 9.8933, 28.9975, 0.5826] + c: [-11.529] +O: + a: [3.0485, 2.2868, 1.5463, 0.867] + b: [13.2771, 5.7011, 0.3239, 32.9089] + c: [0.2508] +F: + a: [3.5392, 2.6412, 1.517, 1.0243] + b: [10.2825, 4.2944, 0.2615, 26.1476] + c: [0.2776] +Ne: + a: [3.9553, 3.1125, 1.4546, 1.1251] + b: [8.4042, 3.4262, 0.2306, 21.7184] + c: [0.3515] +Na: + a: [4.7626, 3.1736, 1.2674, 1.1128] + b: [3.285, 8.8422, 0.3136, 129.424] + c: [0.676] +Mg: + a: [5.4204, 2.1735, 1.2269, 2.3073] + b: [2.8275, 79.2611, 0.3808, 7.1937] + c: [0.8584] +Al: + a: [6.4202, 1.9002, 1.5936, 1.9646] + b: [3.0387, 0.7426, 31.5472, 85.0886] + c: [1.1151] +Si: + a: [6.2915, 3.0353, 1.9891, 1.541] + b: [2.4386, 32.3337, 0.6785, 81.6937] + c: [1.1407] +P: + a: [6.4345, 4.1791, 1.78, 1.4908] + b: [1.9067, 27.157, 0.526, 68.1645] + c: [1.1149] +S: + a: [6.9053, 5.2034, 1.4379, 1.5863] + b: [1.4679, 22.2151, 0.2536, 56.172] + c: [0.8669] +Cl: + a: [11.4604, 7.1964, 6.2556, 1.6455] + b: [0.0104, 1.1662, 18.5194, 47.7784] + c: [-9.5574] +Ar: + a: [7.4845, 6.7723, 0.6539, 1.6442] + b: [0.9072, 14.8407, 43.8983, 33.3929] + c: [1.4445] +K: + a: [8.2186, 7.4398, 1.0519, 0.8659] + b: [12.7949, 0.7748, 213.187, 41.6841] + c: [1.4228] +Ca: + a: [8.6266, 7.3873, 1.5899, 1.0211] + b: [10.4421, 0.6599, 85.7484, 178.437] + c: [1.3751] +Sc: + a: [9.189, 7.3679, 1.6409, 1.468] + b: [9.0213, 0.5729, 136.108, 51.3531] + c: [1.3329] +Ti: + a: [9.7595, 7.3558, 1.6991, 1.9021] + b: [7.8508, 0.5, 35.6338, 116.105] + c: [1.2807] +V: + a: [10.2971, 7.3511, 2.0703, 2.0571] + b: [6.8657, 0.4385, 26.8938, 102.478] + c: [1.2199] +Cr: + a: [10.6406, 7.3537, 3.324, 1.4922] + b: [6.1038, 0.392, 20.2626, 98.7399] + c: [1.1832] +Mn: + a: [11.2819, 7.3573, 3.0193, 2.2441] + b: [5.3409, 0.3432, 17.8674, 83.7543] + c: [1.0896] +Fe: + a: [11.7695, 7.3573, 3.5222, 2.3045] + b: [4.7611, 0.3072, 15.3535, 76.8805] + c: [1.0369] +Co: + a: [12.2841, 7.3409, 4.0034, 2.3488] + b: [4.2791, 0.2784, 13.5359, 71.1692] + c: [1.0118] +Ni: + a: [12.8376, 7.292, 4.4438, 2.38] + b: [3.8785, 0.2565, 12.1763, 66.3421] + c: [1.0341] +Cu: + a: [13.338, 7.1676, 5.6158, 1.6735] + b: [3.5828, 0.247, 11.3966, 64.8126] + c: [1.191] +Zn: + a: [14.0743, 7.0318, 5.1652, 2.41] + b: [3.2655, 0.2333, 10.3163, 58.7097] + c: [1.3041] +Ga: + a: [15.2354, 6.7006, 4.3591, 2.9623] + b: [3.0669, 0.2412, 10.7805, 61.4135] + c: [1.7189] +Ge: + a: [16.0816, 6.3747, 3.7068, 3.683] + b: [2.8509, 0.2516, 11.4468, 54.7625] + c: [2.1313] +As: + a: [16.6723, 6.0701, 3.4313, 4.2779] + b: [2.6345, 0.2647, 12.9479, 47.7972] + c: [2.531] +Se: + a: [17.0006, 5.8196, 3.9731, 4.3543] + b: [2.4098, 0.2726, 15.2372, 43.8163] + c: [2.8409] +Br: + a: [17.1789, 5.2358, 5.6377, 3.9851] + b: [2.1723, 16.5796, 0.2609, 41.4328] + c: [2.9557] +Kr: + a: [17.3555, 6.7286, 5.5493, 3.5375] + b: [1.9384, 16.5623, 0.2261, 39.3972] + c: [2.825] +Rb: + a: [17.1784, 9.6435, 5.1399, 1.5292] + b: [1.7888, 17.3151, 0.2748, 164.934] + c: [3.4873] +Sr: + a: [17.5663, 9.8184, 5.422, 2.6694] + b: [1.5564, 14.0988, 0.1664, 132.376] + c: [2.5064] +Y: + a: [17.776, 10.2946, 5.72629, 3.26588] + b: [1.4029, 12.8006, 0.125599, 104.354] + c: [1.91213] +Zr: + a: [17.8765, 10.948, 5.41732, 3.65721] + b: [1.27618, 11.916, 0.117622, 87.6627] + c: [2.06929] +Nb: + a: [17.6142, 12.0144, 4.04183, 3.53346] + b: [1.18865, 11.766, 0.204785, 69.7957] + c: [3.75591] +Mo: + a: [3.7025, 17.2356, 12.8876, 3.7429] + b: [0.2772, 1.0958, 11.004, 61.6584] + c: [4.3875] +Tc: + a: [19.1301, 11.0948, 4.64901, 2.71263] + b: [0.864132, 8.14487, 21.5707, 86.8472] + c: [5.40428] +Ru: + a: [19.2674, 12.9182, 4.86337, 1.56756] + b: [0.80852, 8.43467, 24.7997, 94.2928] + c: [5.37874] +Rh: + a: [19.2957, 14.3501, 4.73425, 1.28918] + b: [0.751536, 8.21758, 25.8749, 98.6062] + c: [5.328] +Pd: + a: [19.3319, 15.5017, 5.29537, 0.605844] + b: [0.698655, 7.98929, 25.2052, 76.8986] + c: [5.26593] +Ag: + a: [19.2808, 16.6885, 4.8045, 1.0463] + b: [0.6446, 7.4726, 24.6605, 99.8156] + c: [5.179] +Cd: + a: [19.2214, 17.6444, 4.461, 1.6029] + b: [0.5946, 6.9089, 24.7008, 87.4825] + c: [5.0694] +In: + a: [19.1624, 18.5596, 4.2948, 2.0396] + b: [0.5476, 6.3776, 25.8499, 92.8029] + c: [4.9391] +Sn: + a: [19.1889, 19.1005, 4.4585, 2.4663] + b: [5.8303, 0.5031, 26.8909, 83.9571] + c: [4.7821] +Sb: + a: [19.6418, 19.0455, 5.0371, 2.6827] + b: [5.3034, 0.4607, 27.9074, 75.2825] + c: [4.5909] +Te: + a: [19.9644, 19.0138, 6.14487, 2.5239] + b: [4.81742, 0.420885, 28.5284, 70.8403] + c: [4.352] +I: + a: [20.1472, 18.9949, 7.5138, 2.2735] + b: [4.347, 0.3814, 27.766, 66.8776] + c: [4.0712] +Xe: + a: [20.2933, 19.0298, 8.9767, 1.99] + b: [3.9282, 0.344, 26.4659, 64.2658] + c: [3.7118] +Cs: + a: [20.3892, 19.1062, 10.662, 1.4953] + b: [3.569, 0.3107, 24.3879, 213.904] + c: [3.3352] +Ba: + a: [20.3361, 19.297, 10.888, 2.6959] + b: [3.216, 0.2756, 20.2073, 167.202] + c: [2.7731] +La: + a: [20.578, 19.599, 11.3727, 3.28719] + b: [2.94817, 0.244475, 18.7726, 133.124] + c: [2.14678] +Ce: + a: [21.1671, 19.7695, 11.8513, 3.33049] + b: [2.81219, 0.226836, 17.6083, 127.113] + c: [1.86264] +Pr: + a: [22.044, 19.6697, 12.3856, 2.82428] + b: [2.77393, 0.222087, 16.7669, 143.644] + c: [2.0583] +Nd: + a: [22.6845, 19.6847, 12.774, 2.85137] + b: [2.66248, 0.210628, 15.885, 137.903] + c: [1.98486] +Pm: + a: [23.3405, 19.6095, 13.1235, 2.87516] + b: [2.5627, 0.202088, 15.1009, 132.721] + c: [2.02876] +Sm: + a: [24.0042, 19.4258, 13.4396, 2.89604] + b: [2.47274, 0.196451, 14.3996, 128.007] + c: [2.20963] +Eu: + a: [24.6274, 19.0886, 13.7603, 2.9227] + b: [2.3879, 0.1942, 13.7546, 123.174] + c: [2.5745] +Gd: + a: [25.0709, 19.0798, 13.8518, 3.54545] + b: [2.25341, 0.181951, 12.9331, 101.398] + c: [2.4196] +Tb: + a: [25.8976, 18.2185, 14.3167, 2.95354] + b: [2.24256, 0.196143, 12.6648, 115.362] + c: [3.58324] +Dy: + a: [26.507, 17.6383, 14.5596, 2.96577] + b: [2.1802, 0.202172, 12.1899, 111.874] + c: [4.29728] +Ho: + a: [26.9049, 17.294, 14.5583, 3.63837] + b: [2.07051, 0.19794, 11.4407, 92.6566] + c: [4.56796] +Er: + a: [27.6563, 16.4285, 14.9779, 2.98233] + b: [2.07356, 0.223545, 11.3604, 105.703] + c: [5.92046] +Tm: + a: [28.1819, 15.8851, 15.1542, 2.98706] + b: [2.02859, 0.238849, 10.9975, 102.961] + c: [6.75621] +Yb: + a: [28.6641, 15.4345, 15.3087, 2.98963] + b: [1.9889, 0.257119, 10.6647, 100.417] + c: [7.56672] +Lu: + a: [28.9476, 15.2208, 15.1, 3.71601] + b: [1.90182, 9.98519, 0.261033, 84.3298] + c: [7.97628] +Hf: + a: [29.144, 15.1726, 14.7586, 4.30013] + b: [1.83262, 9.5999, 0.275116, 72.029] + c: [8.58154] +Ta: + a: [29.2024, 15.2293, 14.5135, 4.76492] + b: [1.77333, 9.37046, 0.295977, 63.3644] + c: [9.24354] +W: + a: [29.0818, 15.43, 14.4327, 5.11982] + b: [1.72029, 9.2259, 0.321703, 57.056] + c: [9.8875] +Re: + a: [28.7621, 15.7189, 14.5564, 5.44174] + b: [1.67191, 9.09227, 0.3505, 52.0861] + c: [10.472] +Os: + a: [28.1894, 16.155, 14.9305, 5.67589] + b: [1.62903, 8.97948, 0.382661, 48.1647] + c: [11.0005] +Ir: + a: [27.3049, 16.7296, 15.6115, 5.83377] + b: [1.59279, 8.86553, 0.417916, 45.0011] + c: [11.4722] +Pt: + a: [27.0059, 17.7639, 15.7131, 5.7837] + b: [1.51293, 8.81174, 0.424593, 38.6103] + c: [11.6883] +Au: + a: [16.8819, 18.5913, 25.5582, 5.86] + b: [0.4611, 8.6216, 1.4826, 36.3956] + c: [12.0658] +Hg: + a: [20.6809, 19.0417, 21.6575, 5.9676] + b: [0.545, 8.4484, 1.5729, 38.3246] + c: [12.6089] +Tl: + a: [27.5446, 19.1584, 15.538, 5.52593] + b: [0.65515, 8.70751, 1.96347, 45.8149] + c: [13.1746] +Pb: + a: [31.0617, 13.0637, 18.442, 5.9696] + b: [0.6902, 2.3576, 8.618, 47.2579] + c: [13.4118] +Bi: + a: [33.3689, 12.951, 16.5877, 6.4692] + b: [0.704, 2.9238, 8.7937, 48.0093] + c: [13.5782] +Po: + a: [34.6726, 15.4733, 13.1138, 7.02588] + b: [0.700999, 3.55078, 9.55642, 47.0045] + c: [13.677] +At: + a: [35.3163, 19.0211, 9.49887, 7.42518] + b: [0.68587, 3.97458, 11.3824, 45.4715] + c: [13.7108] +Rn: + a: [35.5631, 21.2816, 8.0037, 7.4433] + b: [0.6631, 4.0691, 14.0422, 44.2473] + c: [13.6905] +Fr: + a: [35.9299, 23.0547, 12.1439, 2.11253] + b: [0.646453, 4.17619, 23.1052, 150.645] + c: [13.7247] +Ra: + a: [35.763, 22.9064, 12.4739, 3.21097] + b: [0.616341, 3.87135, 19.9887, 142.325] + c: [13.6211] +Ac: + a: [35.6597, 23.1032, 12.5977, 4.08655] + b: [0.589092, 3.65155, 18.599, 117.02] + c: [13.5266] +Th: + a: [35.5645, 23.4219, 12.7473, 4.80703] + b: [0.563359, 3.46204, 17.8309, 99.1722] + c: [13.4314] +Pa: + a: [35.8847, 23.2948, 14.1891, 4.17287] + b: [0.547751, 3.41519, 16.9235, 105.251] + c: [13.4287] +U: + a: [36.0228, 23.4128, 14.9491, 4.188] + b: [0.5293, 3.3253, 16.0927, 100.613] + c: [13.3966] +Np: + a: [36.1874, 23.5964, 15.6402, 4.1855] + b: [0.511929, 3.25396, 15.3622, 97.4908] + c: [13.3573] +Pu: + a: [36.5254, 23.8083, 16.7707, 3.47947] + b: [0.499384, 3.26371, 14.9455, 105.98] + c: [13.3812] +Am: + a: [36.6706, 24.0992, 17.3415, 3.49331] + b: [0.483629, 3.20647, 14.3136, 102.273] + c: [13.3592] +Cm: + a: [36.6488, 24.4096, 17.399, 4.21665] + b: [0.465154, 3.08997, 13.4346, 88.4834] + c: [13.2887] +Bk: + a: [36.7881, 24.7736, 17.8919, 4.23284] + b: [0.451018, 3.04619, 12.8946, 86.003] + c: [13.2754] +Cf: + a: [36.9185, 25.1995, 18.3317, 4.24391] + b: [0.437533, 3.00775, 12.4044, 83.7881] + c: [13.2674] diff --git a/FOX/functions/debye.py b/FOX/functions/debye.py new file mode 100644 index 00000000..7f582393 --- /dev/null +++ b/FOX/functions/debye.py @@ -0,0 +1,139 @@ +"""A module for computing Debye scattering factors. + +Index +----- +.. currentmodule:: FOX.functions.debye +.. autosummary:: + get_debye_scattering + SCATERING_FACTORS + +API +--- +.. autofunction:: get_debye_scattering +.. autodata:: SCATERING_FACTORS + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from pathlib import Path + +import yaml +import numpy as np +import pandas as pd + +if TYPE_CHECKING: + from numpy.typing import NDArray + from numpy import str_ as U, float64 as f8 + from typing_extensions import TypedDict + + class _CoefDict(TypedDict): + a: list[float] + b: list[float] + c: list[float] + +__all__ = ["SCATERING_FACTORS", "get_debye_scattering"] + + +def _load_scattering_df() -> pd.DataFrame: + """Load the dataframe with scattering coefficients. + + Coefficients are taken from: + International Tables for Crystallography (2006). Vol. C, ch. 6.1, pp. 578-580, table 6.1.1.4 + + .. code-block:: python + + >>> df = _load_scattering_df() + >>> print(df) + coef a b c + i 0 1 2 3 0 1 2 3 0 + symbol + H 0.493002 0.322912 0.140191 0.04081 10.510900 26.12570 3.14236 57.7997 0.003038 + He 0.873400 0.630900 0.311200 0.17800 9.103700 3.35680 22.92760 0.9821 0.006400 + Li 1.128200 0.750800 0.617500 0.46530 3.954600 1.05240 85.39050 168.2610 0.037700 + Be 1.591900 1.127800 0.539100 0.70290 43.642700 1.86230 103.48300 0.5420 0.038500 + B 2.054500 1.332600 1.097900 0.70680 23.218500 1.02100 60.34980 0.1403 -0.193200 + ... ... ... ... ... ... ... ... ... ... + Pu 36.525400 23.808300 16.770700 3.47947 0.499384 3.26371 14.94550 105.9800 13.381200 + Am 36.670600 24.099200 17.341500 3.49331 0.483629 3.20647 14.31360 102.2730 13.359200 + Cm 36.648800 24.409600 17.399000 4.21665 0.465154 3.08997 13.43460 88.4834 13.288700 + Bk 36.788100 24.773600 17.891900 4.23284 0.451018 3.04619 12.89460 86.0030 13.275400 + Cf 36.918500 25.199500 18.331700 4.24391 0.437533 3.00775 12.40440 83.7881 13.267400 + + """ # noqa: E501 + root = Path(__file__).parents[1] + with open(root / "data" / "scattering.yaml", "r") as f: + dct: dict[str, _CoefDict] = yaml.load(f, Loader=yaml.SafeLoader) + + columns = pd.MultiIndex.from_tuples([ + ("a", 0), + ("a", 1), + ("a", 2), + ("a", 3), + ("b", 0), + ("b", 1), + ("b", 2), + ("b", 3), + ("c", 0), + ], names=("Coefficients", "i")) + + index = [] + data = np.empty((len(dct), 9), order="F", dtype=np.float64) + for i, (k, v) in enumerate(dct.items()): + index.append(k) + data[i] = v["a"] + v["b"] + v["c"] + return pd.DataFrame(data, index=pd.Index(index, name="symbol"), columns=columns) + + +#: A dataframe with generalized X-ray scattering coefficients. +#: +#: See Also +#: -------- +#: International Tables for Crystallography (2006). Vol. C, ch. 6.1, pp. 578-580, table 6.1.1.4 +SCATERING_FACTORS = _load_scattering_df() + + +def _get_scattering(symbol: NDArray[U], q: NDArray[f8]) -> NDArray[f8]: + """Computer the scattering factors for the given atomic symbols.""" + stol2 = (q / (4 * np.pi))**2 + a = SCATERING_FACTORS.loc[symbol, "a"].values + b = SCATERING_FACTORS.loc[symbol, "b"].values + ret = SCATERING_FACTORS.loc[symbol, ("c", 0)].values.copy() + ret += (a * np.exp(-b * stol2)).sum(axis=1) + return ret + + +@np.errstate(invalid="ignore") +def get_debye_scattering( + dist_mat: NDArray[f8], + symbols1: NDArray[U], + symbols2: NDArray[U], + scattering_vector: NDArray[f8], + validate_param: bool = True, +) -> NDArray[f8]: + """Placeholder.""" + if validate_param: + dist_mat = np.array(dist_mat, dtype=np.float64, ndmin=3, copy=False) + scattering_vector = np.array(scattering_vector, dtype=np.float64, ndmin=1, copy=False) + symbols1 = np.array(symbols1, dtype=np.str_, ndmin=1, copy=False) + symbols2 = np.array(symbols2, dtype=np.str_, ndmin=1, copy=False) + + try: + assert dist_mat.ndim == 3, "Invalid `dist_mat` dimensionality" + assert symbols1.ndim == 1, "Invalid `symbols1` dimensionality" + assert symbols2.ndim == 1, "Invalid `symbols2` dimensionality" + assert scattering_vector.ndim == 1, "Invalid `scattering_vector` dimensionality" + except AssertionError as ex: + raise ValueError(str(ex)) from None + + q_r_ij = scattering_vector * dist_mat + f_ij = ( + _get_scattering(symbols1, scattering_vector)[..., None] * + _get_scattering(symbols2, scattering_vector)[None, ...] + ) + + ret = np.sin(q_r_ij) + ret /= q_r_ij + ret *= f_ij[None, ...] + return np.nansum(ret, axis=(1, 2)) diff --git a/tests/test_multi_mol.py b/tests/test_multi_mol.py index 63a615c2..0990f570 100644 --- a/tests/test_multi_mol.py +++ b/tests/test_multi_mol.py @@ -243,6 +243,48 @@ def test_raises( mol.init_rdf(**kwargs) +class TestDebye: + """Test :meth:`.MultiMolecule.init_debye_scattering`.""" + + @pytest.mark.parametrize("kwargs", [ + ({'atom_subset': ('Cd', 'Se', 'O')}), + ({'mol_subset': np.s_[::10]}), + ({"atom_pairs": [("Cd", "Se")]}), + ]) + def test_passes(self, kwargs: Mapping[str, Any]) -> None: + debye = MOL[:100].init_debye_scattering(1, 1, **kwargs) + assertion.assert_(np.isfinite, debye, post_process=np.all) + + @pytest.mark.parametrize("mol", [MOL_LATTICE_3D, MOL_LATTICE_2D]) + @pytest.mark.parametrize( + "periodic", + chain( + [None], + combinations("xyz", 1), + combinations("xyz", 2), + combinations("xyz", 3), + ), + ) + def test_lattice( + self, mol: MultiMolecule, periodic: None | Sequence[Literal["x", "y", "z"]] + ) -> None: + assert mol.lattice is not None + debye = mol.init_debye_scattering(1, 1, periodic=periodic) + assertion.assert_(np.isfinite, debye, post_process=np.all) + + @pytest.mark.parametrize("mol,kwargs,exc", [ + (MOL_LATTICE_3D, {"periodic": "bob"}, ValueError), + (MOL, {"periodic": "xyz"}, TypeError), + (MOL, {"atom_subset": "Cd", "atom_pairs": [("Cd", "Cd")]}, TypeError), + (MOL, {"atom_pairs": [("Cd", "Bob")]}, ValueError), + ]) + def test_raises( + self, mol: MultiMolecule, kwargs: Mapping[str, Any], exc: Type[Exception] + ) -> None: + with pytest.raises(exc): + mol.init_debye_scattering(1, 1, **kwargs) + + def test_rmsf(): """Test :meth:`.MultiMolecule.init_rmsf`.""" mol = MOL.copy() From ccf21fd3c2dc664c7b81c5e413192f7b71e8bc3c Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Thu, 27 Jan 2022 16:06:23 +0100 Subject: [PATCH 2/2] BLD: Bump the minimum numpy version to 1.17 1.17 is the first version that made `np.errstate` usable as a decorator --- .github/workflows/pythonpackage.yml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 887049f1..585680c1 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -80,7 +80,7 @@ jobs: pip install --pre -e .[test_no_optional] --upgrade --force-reinstall pip install git+https://github.com/SCM-NV/PLAMS@master --upgrade elif [[ $SPECIAL == '; minimum version' ]]; then - pip install Nano-Utils==1.2.1 schema==0.7.1 AssertionLib==2.2 noodles==0.3.3 pyyaml==5.1 numpy==1.15 h5py==2.10 pandas==0.24 scipy==1.2.0 + pip install Nano-Utils==1.2.1 schema==0.7.1 AssertionLib==2.2 noodles==0.3.3 pyyaml==5.1 numpy==1.17 h5py==2.10 pandas==0.24 scipy==1.2.0 pip install -e .[test_no_optional] else pip install -e .[test_no_optional] diff --git a/setup.py b/setup.py index d9501503..d0daf215 100644 --- a/setup.py +++ b/setup.py @@ -103,7 +103,7 @@ install_requires=[ 'Nano-Utils>=2.3', 'pyyaml>=5.1', - 'numpy>=1.15', + 'numpy>=1.17', 'scipy>=1.2', 'pandas>=0.24', 'schema>=0.7.1,!=0.7.5',