diff --git a/docs/changes/newsfragments/350.enh b/docs/changes/newsfragments/350.enh new file mode 100644 index 0000000000..1cd260fd10 --- /dev/null +++ b/docs/changes/newsfragments/350.enh @@ -0,0 +1 @@ +Add support for ``UKB_15K_GM`` mask by `Synchon Mandal`_ diff --git a/junifer/data/masks.py b/junifer/data/masks.py index f7b1f503de..227131490b 100644 --- a/junifer/data/masks.py +++ b/junifer/data/masks.py @@ -163,6 +163,10 @@ def compute_brain_mask( "func": compute_epi_mask, "space": "inherit", }, + "UKB_15K_GM": { + "family": "UKB", + "space": "MNI152NLin6Asym", + }, } @@ -567,6 +571,8 @@ def load_mask( elif t_family == "Callable": mask_img = mask_definition["func"] mask_fname = None + elif t_family == "UKB": + mask_fname = _load_ukb_mask(name) else: raise_error(f"I don't know about the {t_family} mask family.") @@ -632,3 +638,33 @@ def _load_vickery_patil_mask( mask_fname = _masks_path / "vickery-patil" / mask_fname return mask_fname + + +def _load_ukb_mask(name: str) -> Path: + """Load UKB mask. + + Parameters + ---------- + name : {"UKB_15K_GM"} + The name of the mask. + + Returns + ------- + pathlib.Path + File path to the mask image. + + Raises + ------ + ValueError + If ``name`` is invalid. + + """ + if name == "UKB_15K_GM": + mask_fname = "UKB_15K_GM_template.nii.gz" + else: + raise_error(f"Cannot find a UKB mask called {name}") + + # Set path for masks + mask_fname = _masks_path / "ukb" / mask_fname + + return mask_fname diff --git a/junifer/data/masks/ukb/UKB_15K_GM_template.nii.gz b/junifer/data/masks/ukb/UKB_15K_GM_template.nii.gz new file mode 100644 index 0000000000..289394f955 Binary files /dev/null and b/junifer/data/masks/ukb/UKB_15K_GM_template.nii.gz differ diff --git a/junifer/data/tests/test_masks.py b/junifer/data/tests/test_masks.py index 5851e7ae70..a4526213c3 100644 --- a/junifer/data/tests/test_masks.py +++ b/junifer/data/tests/test_masks.py @@ -22,6 +22,7 @@ from junifer.data.masks import ( _available_masks, + _load_ukb_mask, _load_vickery_patil_mask, compute_brain_mask, get_mask, @@ -212,6 +213,7 @@ def test_register_mask( [ "GM_prob0.2", "GM_prob0.2_cortex", + "UKB_15K_GM", ], ) def test_list_masks_correct(mask_name: str) -> None: @@ -291,6 +293,21 @@ def test_vickery_patil_error() -> None: _load_vickery_patil_mask(name="wrong", resolution=2.0) +def test_ukb() -> None: + """Test UKB mask.""" + mask, mask_fname, space = load_mask("UKB_15K_GM", resolution=2.0) + assert_array_almost_equal(mask.header["pixdim"][1:4], 2.0) # type: ignore + assert space == "MNI152NLin6Asym" + assert mask_fname is not None + assert mask_fname.name == "UKB_15K_GM_template.nii.gz" + + +def test_ukb_error() -> None: + """Test error for UKB mask.""" + with pytest.raises(ValueError, match=r"find a UKB mask "): + _load_ukb_mask(name="wrong") + + def test_get_mask() -> None: """Test the get_mask function.""" with OasisVBMTestingDataGrabber() as dg: