Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Augmentation Pipeline #17

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion chabud/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def on_validation_end(

for i in range(batch_size):
log_image = wandb.Image(
post_img[i].permute(1, 2, 0).detach().numpy() / 6000,
post_img[i][[3, 2, 1], ...].detach().numpy().transpose(1, 2, 0)
/ 3000,
masks={
"prediction": {
"mask_data": mask[i].detach().cpu().numpy(),
Expand Down
112 changes: 95 additions & 17 deletions chabud/datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from typing import Iterator

import albumentations as A
import datatree
import lightning as L
import numpy as np
Expand All @@ -13,7 +14,6 @@
import xarray as xr


# %%
def _path_fn(urlpath: str) -> str:
"""
Get the filename from a urlpath and prepend it with 'data' so that it is
Expand Down Expand Up @@ -66,8 +66,10 @@ def _train_val_fold(chip: xr.Dataset) -> int:
Fold 0 is used for validation, Fold 1 and above is for training.
See https://huggingface.co/datasets/chabud-team/chabud-ecml-pkdd2023/discussions/3
"""
if "fold" not in chip.attrs: # no 'fold' attribute, use for training too
return 1 # Training set
if (
"fold" not in chip.attrs
): # no 'fold' attribute, split between train,val with 70/30 split
return np.random.rand() > 0.3
if chip.attrs["fold"] == 0:
return 0 # Validation set
elif chip.attrs["fold"] >= 1:
Expand All @@ -76,44 +78,119 @@ def _train_val_fold(chip: xr.Dataset) -> int:

def _pre_post_mask_tuple(
    dataset: "xr.Dataset",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """
    From a single xarray.Dataset, split it into a tuple containing the
    pre/post/target arrays and a dictionary object containing metadata.

    Parameters
    ----------
    dataset : xr.Dataset
        A dataset with ``pre_fire``, ``post_fire`` and ``mask`` data
        variables, metadata in its ``attrs``, and the source filename in
        ``encoding["source"]``.

    Returns
    -------
    data_tuple : tuple[np.ndarray, np.ndarray, np.ndarray, dict]
        A tuple with 4 objects, the pre-event image, the post-event image, the
        mask image, and a Python dict containing metadata (e.g. filename, UUID,
        fold, comments).
    """
    # Keep all spectral bands. Images go to float32 and the mask to uint8;
    # numpy arrays (not tensors) are returned so that albumentations-based
    # augmentation can run before the collate step converts to torch.
    pre = dataset.pre_fire.data.astype(dtype="float32")
    post = dataset.post_fire.data.astype(dtype="float32")
    mask = dataset.mask.data.astype(dtype="uint8")

    return (
        pre,
        post,
        mask,
        {
            # Attach provenance so each sample can be traced back to its chip.
            "filename": os.path.basename(dataset.encoding["source"]),
            **dataset.attrs,
        },
    )


# Built once at module import time: the transform definition never changes,
# so re-building the Compose pipeline for every sample is pure overhead.
# `post` is registered as an additional image target so the post-fire image
# receives the exact same random spatial transform as the pre-fire image.
_AUGMENTATIONS = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(
            p=0.5, shift_limit=0.05, scale_limit=0.05, rotate_limit=10
        ),
    ],
    additional_targets={"post": "image"},
)


def _apply_augmentation(
    sample: tuple[np.ndarray, np.ndarray, np.ndarray, dict]
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """
    Apply random flip/shift/scale/rotate augmentations to a single sample.

    Parameters
    ----------
    sample : tuple[np.ndarray, np.ndarray, np.ndarray, dict]
        The (pre, post, mask, metadata) tuple; images are channel-first
        (C, H, W) and the mask is (H, W).

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray, dict]
        The augmented (pre, post, mask) arrays in the same layout, with the
        metadata dict passed through unchanged.
    """
    pre, post, mask, metadata = sample

    # albumentations expects channel-last (H, W, C) images, so transpose in
    # and back out again afterwards.
    auged = _AUGMENTATIONS(
        image=pre.transpose(1, 2, 0), post=post.transpose(1, 2, 0), mask=mask
    )
    pre, post, mask = (
        auged["image"].transpose(2, 0, 1),
        auged["post"].transpose(2, 0, 1),
        auged["mask"],
    )
    return (pre, post, mask, metadata)


def _stack_tensor_collate_fn(
    samples: list[tuple[np.ndarray, np.ndarray, np.ndarray, dict]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[dict]]:
    """
    Stack a list of (pre, post, mask) numpy arrays into batched
    torch.Tensor objects, and combine metadata attributes into a list of
    dicts.

    Parameters
    ----------
    samples : list[tuple[np.ndarray, np.ndarray, np.ndarray, dict]]
        Per-sample tuples as produced by the augmentation step (numpy
        arrays, not tensors — hence the ``torch.as_tensor`` conversion).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[dict]]
        Batched pre, post and mask tensors (stacked on a new leading batch
        dimension) plus the list of per-sample metadata dicts.
    """
    # torch.as_tensor shares memory with the numpy array when possible,
    # avoiding an extra copy before the stack.
    pre_tensor: torch.Tensor = torch.stack(
        tensors=[torch.as_tensor(sample[0]) for sample in samples]
    )
    post_tensor: torch.Tensor = torch.stack(
        tensors=[torch.as_tensor(sample[1]) for sample in samples]
    )
    mask_tensor: torch.Tensor = torch.stack(
        tensors=[torch.as_tensor(sample[2]) for sample in samples]
    )
    metadata: list[dict] = [sample[3] for sample in samples]

    return pre_tensor, post_tensor, mask_tensor, metadata
Expand All @@ -139,8 +216,8 @@ def __init__(
# From https://huggingface.co/datasets/chabud-team/chabud-ecml-pkdd2023/tree/main
"https://huggingface.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/main/train_eval.hdf5",
# From https://huggingface.co/datasets/chabud-team/chabud-extra/tree/main
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_0.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_1.hdf5",
"https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_0.hdf5",
"https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_1.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_2.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_3.hdf5",
# "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_4.hdf5",
Expand Down Expand Up @@ -222,8 +299,9 @@ def setup(
# Step 4 - Convert from xarray.Dataset to tuple of torch.Tensor objects
# Also do shuffling (for train set only), batching, and tensor stacking
self.datapipe_train = (
dp_train.shuffle(buffer_size=100)
dp_train.shuffle(buffer_size=2000)
.map(fn=_pre_post_mask_tuple)
.map(fn=_apply_augmentation)
.batch(batch_size=self.batch_size)
.collate(collate_fn=_stack_tensor_collate_fn)
)
Expand Down
93 changes: 93 additions & 0 deletions chabud/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os
from pathlib import Path

import albumentations as A
from albumentations.pytorch import ToTensorV2
import lightning as L
import numpy as np
from torch.utils.data import Dataset, DataLoader


class ChaBuDDataset(Dataset):
    """
    Map-style dataset of pre/post fire image pairs and burn masks.

    Each ``<uuid>.npz`` file in ``data_dir`` must contain ``pre``, ``post``
    and ``mask`` arrays. Images are returned as float32 and the mask as
    uint8, along with the file's stem as an identifier.
    """

    def __init__(self, data_dir: Path, transform=None):
        """
        Parameters
        ----------
        data_dir : Path
            Directory containing the ``.npz`` chip files.
        transform : callable, optional
            An albumentations-style transform called as
            ``transform(image=..., post=..., mask=...)`` — TODO confirm
            the ``post`` additional-target is registered by the caller.
        """
        self.data_dir = data_dir
        # Sort so index -> file mapping is deterministic across runs and
        # machines (raw glob order is filesystem-dependent).
        self.uuids = sorted(data_dir.glob("*.npz"))
        self.transform = transform

    def __getitem__(self, idx):
        uuid = self.uuids[idx]
        # np.load on an .npz keeps the underlying file handle open for lazy
        # reads; the context manager guarantees it is closed.
        with np.load(uuid) as event:
            pre = event["pre"].astype(np.float32)
            post = event["post"].astype(np.float32)
            mask = event["mask"].astype(np.uint8)

        if self.transform:
            # albumentations expects channel-last (H, W, C) input; the
            # transform's ToTensorV2 step handles conversion back.
            tfmed = self.transform(
                image=pre.transpose(1, 2, 0), post=post.transpose(1, 2, 0), mask=mask
            )
            pre, post, mask = tfmed["image"], tfmed["post"], tfmed["mask"]

        return (pre, post, mask, uuid.stem)

    def __len__(self):
        return len(self.uuids)


class ChaBuDDataModule(L.LightningDataModule):
    """
    LightningDataModule serving train/val/test ChaBuDDataset splits from
    ``.npz`` chips laid out under ``data_dir/{trn,val,val_orig}``.
    """

    def __init__(
        self,
        data_dir: Path,
        batch_size: int = 16,
        num_workers: int = 4,
    ):
        """
        Parameters
        ----------
        data_dir : Path
            Root directory containing ``trn``, ``val`` and ``val_orig``
            sub-directories of ``.npz`` chips.
        batch_size : int
            Number of samples per batch (default 16).
        num_workers : int
            Worker subprocesses per DataLoader (default 4).
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Random flips/shift-scale-rotate for training only. `post` is
        # registered as an extra image target so the post-fire image gets
        # the same spatial transform as the pre-fire image.
        self.trn_tfm = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.ShiftScaleRotate(
                    p=0.5, shift_limit=0.05, scale_limit=0.05, rotate_limit=10
                ),
                ToTensorV2(),
            ],
            additional_targets={"post": "image"},
        )
        # Validation and test are tensor-conversion only (no randomness).
        self.val_tfm = self._eval_transform()
        self.tst_tfm = self._eval_transform()

    @staticmethod
    def _eval_transform():
        """Build the deterministic (tensor-conversion only) transform."""
        return A.Compose([ToTensorV2()], additional_targets={"post": "image"})

    def setup(self, stage: str | None = None) -> None:
        """Instantiate datasets for every split (``stage`` is ignored)."""
        self.trn_ds = ChaBuDDataset(self.data_dir / "trn", transform=self.trn_tfm)
        self.val_ds = ChaBuDDataset(self.data_dir / "val", transform=self.val_tfm)
        self.tst_ds = ChaBuDDataset(self.data_dir / "val_orig", transform=self.tst_tfm)

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Shared DataLoader settings for all splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Shuffled training loader."""
        return self._make_dataloader(self.trn_ds, shuffle=True)

    def val_dataloader(self):
        """Deterministic validation loader."""
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Deterministic test loader (original validation chips)."""
        return self._make_dataloader(self.tst_ds, shuffle=False)
Loading