Skip to content

Turn resolution into a parameter (currently hardcoded) #48

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions pipeline/download_gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ def parse_args():
Path to output directory to save results.
""",
)
parser.add_argument(
"--downsample-ref",
type=int,
default=25,
help="""\
Downsampling factor of the reference space.
""",
)
parser.add_argument(
"--downsample-img",
type=int,
Expand Down Expand Up @@ -79,6 +87,7 @@ def postprocess_dataset(
None,
],
n_images: int,
downsample_ref: int,
) -> Tuple[np.ndarray, np.ndarray, dict[str, Any]]:
"""Post process given dataset.

Expand Down Expand Up @@ -116,7 +125,7 @@ def postprocess_dataset(
# TODO: maybe notify the user somehow?
continue

section_numbers.append(section_coordinate // 25)
section_numbers.append(section_coordinate // downsample_ref)
image_ids.append(img_id)
warped_img = 255 - df.warp(img, border_mode="constant", c=img[0, 0, :].tolist())
dataset_np.append(warped_img)
Expand All @@ -140,6 +149,7 @@ def postprocess_dataset(
def main(
experiment_id: int,
output_dir: Path | str,
downsample_ref: int,
downsample_img: int,
expression: bool = True,
) -> int:
Expand All @@ -151,6 +161,10 @@ def main(
Gene ID to download.
output_dir
Directory when results are going to be saved.
downsample_ref
Downscaling of the reference space grid. If set to 1 no
downsampling takes place. The higher the value the smaller the grid
in the reference space and the faster the matrix multiplication.
downsample_img
Downsampling factor given to Allen API when downloading the images.
This factor is going to reduce the size.
Expand All @@ -174,19 +188,23 @@ def main(

logger.info(f"Start downloading experiment ID {experiment_id}")
dataset = DatasetDownloader(
experiment_id, downsample_img=downsample_img, include_expression=expression
experiment_id,
downsample_img=downsample_img,
include_expression=expression,
downsample_ref=downsample_ref,
)
dataset.fetch_metadata()
dataset_gen = dataset.run()
axis = CommonQueries.get_axis(experiment_id)
dataset_np, expression_np, metadata_dict = postprocess_dataset(
dataset_gen, len(dataset)
dataset_gen, len(dataset), downsample_ref
)
metadata_dict["axis"] = axis
metadata_dict["downsample-ref"] = downsample_ref

logger.info(f"Saving results of experiment ID {experiment_id}")
np.save(output_dir / f"{experiment_id}.npy", dataset_np)
with open(output_dir / f"{experiment_id}.json", "w") as f:
np.save(output_dir / f"{experiment_id}-{downsample_ref}.npy", dataset_np)
with open(output_dir / f"{experiment_id}-{downsample_ref}.json", "w") as f:
json.dump(metadata_dict, f, indent=True, sort_keys=True)

if expression_np is not None:
Expand Down
2 changes: 1 addition & 1 deletion pipeline/gene_to_nissl.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def registration(
except IndexError:
logger.warn(
f"One of the gene slice has a section number ({section_number})"
f"out of nissl volume shape {nissl_volume.shape}. This slice is"
f"out of nissl volume shape {nissl_volume.shape}. This slice is "
"removed from the pipeline."
)
section_numbers_kept.append(False)
Expand Down
5 changes: 3 additions & 2 deletions pipeline/interpolate_gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,12 @@ def main(
section_numbers = [int(s) for s in metadata["section_numbers"]]
axis = metadata["axis"]

volume_shape = [num // metadata["downsample-ref"] for num in [13200, 8000, 11400]]
# Wrap the data into a GeneDataset class
gene_dataset = GeneDataset(
section_images,
section_numbers,
volume_shape=(528, 320, 456, 3),
volume_shape=(*volume_shape, 3),
axis=axis,
)

Expand Down Expand Up @@ -197,7 +198,7 @@ def main(

np.save(
output_dir
/ f"{experiment_id}-{interpolator_name}-interpolated-{image_type}.npy",
/ f"{experiment_id}-{metadata['downsample-ref']}-{interpolator_name}-interpolated-{image_type}.npy",
predicted_volume,
)

Expand Down
26 changes: 19 additions & 7 deletions pipeline/nissl_to_ccfv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,28 @@ def registration(
Nissl volume once the registration transformation are applied.
"""
logger.info("Compute the registration...")
nii_data = register(reference_volume, moving_volume)
logger.info(f"Max displacements: {np.abs(nii_data).max(axis=(0, 1, 2, 3))}")
nissl_warped = []
warped_volume = []

logger.info("Apply transformation to Moving Volume...")
warped_volume = transform(moving_volume, nii_data, interpolator="genericLabel")
for i, (reference, moving, nissl) in enumerate(zip(reference_volume, moving_volume, nissl_volume)):
try:
nii_data = register(reference, moving)
logger.info(f"Max displacements: {np.abs(nii_data).max(axis=(0, 1, 2, 3))}")

logger.info("Apply transformation to Nissl Volume...")
nissl_warped = transform(nissl_volume, nii_data)
logger.info("Apply transformation to Moving Volume...")
warped_volume.append(transform(moving, nii_data, interpolator="genericLabel"))

return warped_volume, nissl_warped
logger.info("Apply transformation to Nissl Volume...")
nissl_warped.append(transform(nissl, nii_data))
except RuntimeError:
logger.info(f"Registration for slice {i} went wrong...")
warped_volume.append(moving)
nissl_warped.append(nissl)

if (i + 1) % 5 == 0:
logger.info(f" {i + 1} / {reference_volume.shape[0]} registrations done")

return np.array(warped_volume), np.array(nissl_warped)


def main(
Expand Down