Skip to content

Commit

Permalink
refactor(DataFetcher)!: rename DataFetcher to TileFetcher (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrsmrynk authored Jan 25, 2025
1 parent c3aa24d commit 517672d
Show file tree
Hide file tree
Showing 20 changed files with 618 additions and 599 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
ThreadPoolExecutor,
as_completed,
)
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Literal,
)

if TYPE_CHECKING:
from pathlib import Path
Expand All @@ -22,78 +25,84 @@
WMSVersion,
)
from aviary.core.exceptions import AviaryUserError
from aviary.core.tile import Tile

if TYPE_CHECKING:
from aviary.core.type_aliases import (
BufferSize,
Coordinate,
Channels,
ChannelsSet,
Coordinates,
EPSGCode,
GroundSamplingDistance,
TileSize,
)
from aviary.data.data_fetcher import DataFetcher
from aviary.inference.tile_fetcher import TileFetcher


def composite_fetcher(
x_min: Coordinate,
y_min: Coordinate,
data_fetchers: list[DataFetcher],
coordinates: Coordinates,
tile_fetchers: list[TileFetcher],
axis: Literal['channel', 'time_step'] = 'channel',
num_workers: int = 1,
) -> npt.NDArray:
"""Fetches data from the sources.
) -> Tile:
"""Fetches a tile from the sources.
Parameters:
x_min: minimum x coordinate
y_min: minimum y coordinate
data_fetchers: data fetchers
coordinates: coordinates (x_min, y_min) of the tile
tile_fetchers: tile fetchers
axis: axis to concatenate the tiles (`channel`, `time_step`)
num_workers: number of workers
Returns:
data
tile
"""
with ThreadPoolExecutor(max_workers=num_workers) as executor:
tasks = [
executor.submit(
data_fetcher,
x_min=x_min,
y_min=y_min,
tile_fetcher,
coordinates=coordinates,
)
for data_fetcher in data_fetchers
for tile_fetcher in tile_fetchers
]
data = [
tiles = [
futures.result() for futures in as_completed(tasks)
]

return np.concatenate(data, axis=-1)
return Tile.from_tiles(
tiles=tiles,
axis=axis,
)


def vrt_fetcher(
x_min: Coordinate,
y_min: Coordinate,
coordinates: Coordinates,
path: Path,
channels: Channels,
tile_size: TileSize,
ground_sampling_distance: GroundSamplingDistance,
interpolation_mode: InterpolationMode = InterpolationMode.BILINEAR,
buffer_size: BufferSize = 0,
drop_channels: list[int] | None = None,
ignore_channels: ChannelsSet | None = None,
fill_value: int = 0,
) -> npt.NDArray:
"""Fetches data from the virtual raster.
) -> Tile:
"""Fetches a tile from the virtual raster.
Parameters:
x_min: minimum x coordinate
y_min: minimum y coordinate
coordinates: coordinates (x_min, y_min) of the tile
path: path to the virtual raster (.vrt file)
channels: channels
tile_size: tile size in meters
ground_sampling_distance: ground sampling distance in meters
interpolation_mode: interpolation mode (`BILINEAR` or `NEAREST`)
buffer_size: buffer size in meters (specifies the area around the tile that is additionally fetched)
drop_channels: channel indices to drop (supports negative indexing)
ignore_channels: channels to ignore
fill_value: fill value of nodata pixels
Returns:
data
tile
"""
x_min, y_min = coordinates
bounding_box = BoundingBox(
x_min=x_min,
y_min=y_min,
Expand Down Expand Up @@ -129,47 +138,60 @@ def vrt_fetcher(
data = _permute_data(
data=data,
)
return _drop_channels(
tile = Tile.from_composite(
data=data,
drop_channels=drop_channels,
channels=channels,
coordinates=coordinates,
tile_size=tile_size,
buffer_size=buffer_size,
)

if ignore_channels is not None:
for channel in ignore_channels:
tile = tile.remove_channel(
channel=channel,
inplace=True,
)

return tile


def wms_fetcher(
x_min: Coordinate,
y_min: Coordinate,
coordinates: Coordinates,
url: str,
version: WMSVersion,
layer: str,
epsg_code: EPSGCode,
response_format: str,
channels: Channels,
tile_size: TileSize,
ground_sampling_distance: GroundSamplingDistance,
style: str | None = None,
buffer_size: BufferSize = 0,
drop_channels: list[int] | None = None,
ignore_channels: ChannelsSet | None = None,
fill_value: str = '0x000000',
) -> npt.NDArray:
"""Fetches data from the web map service.
) -> Tile:
"""Fetches a tile from the web map service.
Parameters:
x_min: minimum x coordinate
y_min: minimum y coordinate
coordinates: coordinates (x_min, y_min) of the tile
url: url of the web map service
version: version of the web map service (`V1_1_1` or `V1_3_0`)
layer: name of the layer
epsg_code: EPSG code
response_format: format of the response (MIME type, e.g., 'image/png')
channels: channels
tile_size: tile size in meters
ground_sampling_distance: ground sampling distance in meters
style: name of the style
buffer_size: buffer size in meters (specifies the area around the tile that is additionally fetched)
drop_channels: channel indices to drop (supports negative indexing)
ignore_channels: channels to ignore
fill_value: fill value of nodata pixels
Returns:
data
tile
"""
x_min, y_min = coordinates
bounding_box = BoundingBox(
x_min=x_min,
y_min=y_min,
Expand Down Expand Up @@ -204,11 +226,23 @@ def wms_fetcher(
data = _permute_data(
data=data,
)
return _drop_channels(
tile = Tile.from_composite(
data=data,
drop_channels=drop_channels,
channels=channels,
coordinates=coordinates,
tile_size=tile_size,
buffer_size=buffer_size,
)

if ignore_channels is not None:
for channel in ignore_channels:
tile = tile.remove_channel(
channel=channel,
inplace=True,
)

return tile


def _compute_tile_size_pixels(
tile_size: TileSize,
Expand All @@ -230,27 +264,6 @@ def _compute_tile_size_pixels(
return int((tile_size + 2 * buffer_size) / ground_sampling_distance)


def _drop_channels(
data: npt.NDArray,
drop_channels: list[int] | None,
) -> npt.NDArray:
"""Drops the specified channels from the data.
Parameters:
data: data
drop_channels: channel indices to drop (supports negative indexing)
Returns:
data
"""
if drop_channels is None:
return data

channels = np.arange(data.shape[-1])
retain_channels = np.delete(channels, drop_channels)
return data[..., retain_channels]


def _get_wms_params(
version: WMSVersion,
layer: str,
Expand Down
18 changes: 0 additions & 18 deletions aviary/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
from .data_fetcher import (
CompositeFetcher,
CompositeFetcherConfig,
DataFetcher,
DataFetcherConfig,
VRTFetcher,
VRTFetcherConfig,
WMSFetcher,
WMSFetcherConfig,
)
from .data_loader import DataLoader
from .data_preprocessor import (
CompositePreprocessor,
Expand All @@ -22,12 +12,8 @@
from .dataset import Dataset

__all__ = [
'CompositeFetcher',
'CompositeFetcherConfig',
'CompositePreprocessor',
'CompositePreprocessorConfig',
'DataFetcher',
'DataFetcherConfig',
'DataLoader',
'DataPreprocessor',
'DataPreprocessorConfig',
Expand All @@ -36,8 +22,4 @@
'NormalizePreprocessorConfig',
'StandardizePreprocessor',
'StandardizePreprocessorConfig',
'VRTFetcher',
'VRTFetcherConfig',
'WMSFetcher',
'WMSFetcherConfig',
]
18 changes: 18 additions & 0 deletions aviary/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,31 @@
SegmentationModel,
SegmentationModelConfig,
)
from .tile_fetcher import (
CompositeFetcher,
CompositeFetcherConfig,
TileFetcher,
TileFetcherConfig,
VRTFetcher,
VRTFetcherConfig,
WMSFetcher,
WMSFetcherConfig,
)

__all__ = [
'CompositeFetcher',
'CompositeFetcherConfig',
'Exporter',
'Model',
'ONNXSegmentationModel',
'SegmentationExporter',
'SegmentationExporterConfig',
'SegmentationModel',
'SegmentationModelConfig',
'TileFetcher',
'TileFetcherConfig',
'VRTFetcher',
'VRTFetcherConfig',
'WMSFetcher',
'WMSFetcherConfig',
]
Loading

0 comments on commit 517672d

Please sign in to comment.