diff --git a/gridit/grid.py b/gridit/grid.py index f5d00be..073bca5 100644 --- a/gridit/grid.py +++ b/gridit/grid.py @@ -76,6 +76,17 @@ def __hash__(self): """Return unique hash based on content.""" return hash(tuple(self)) + def __getstate__(self): + """Serialize object attributes for pickle dumps.""" + state = dict(self) + if state["projection"] is None or state["projection"] == "": + del state["projection"] + return state + + def __setstate__(self, state): + """Set object attributes from pickle loads.""" + self.__init__(**state) + def __eq__(self, other): """Return True if objects are equal.""" if self.__class__.__name__ != other.__class__.__name__: diff --git a/tests/test_grid.py b/tests/test_grid.py index e1a1abc..63d8e7e 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -60,3 +60,34 @@ def test_grid_transform(grid_basic): from affine import Affine assert grid_basic.transform == Affine(10.0, 0.0, 1000.0, 0.0, -10.0, 2000.0) + + +def test_pickle(): + import pickle + + # Without projection + grid = Grid(25.0, (36, 33), (1748725.0, 5449775.0)) + expected_bytes = ( + b"\x80\x04\x95c\x00\x00\x00\x00\x00\x00\x00\x8c\x0b" + b"gridit.grid\x94\x8c\x04Grid\x94\x93\x94)\x81\x94}\x94(\x8c\n" + b"resolution\x94G@9\x00\x00\x00\x00\x00\x00\x8c\x05" + b"shape\x94K$K!\x86\x94\x8c\x08" + b"top_left\x94GA:\xae\xf5\x00\x00\x00\x00GAT\xca\x0b\xc0\x00\x00\x00\x86" + b"\x94ub." + ) + assert pickle.loads(expected_bytes) == grid, "failed loading previous serialization" + assert pickle.loads(pickle.dumps(grid)) == grid, "failed round-trip" + + # With projection + grid = Grid(25.0, (36, 33), (1748725.0, 5449775.0), "EPSG:2193") + expected_bytes = ( + b"\x80\x04\x95|\x00\x00\x00\x00\x00\x00\x00\x8c\x0b" + b"gridit.grid\x94\x8c\x04Grid\x94\x93\x94)\x81\x94}\x94(\x8c\n" + b"resolution\x94G@9\x00\x00\x00\x00\x00\x00\x8c\x05" + b"shape\x94K$K!\x86\x94\x8c\x08" + b"top_left\x94GA:\xae\xf5\x00\x00\x00\x00GAT\xca\x0b\xc0\x00\x00\x00\x86" + b"\x94\x8c\nprojection\x94\x8c\tEPSG:2193" + b"\x94ub." + ) + assert pickle.loads(expected_bytes) == grid, "failed loading previous serialization" + assert pickle.loads(pickle.dumps(grid)) == grid, "failed round-trip"