diff --git a/src/skan/csr.py b/src/skan/csr.py index bc310f1..fb18136 100644 --- a/src/skan/csr.py +++ b/src/skan/csr.py @@ -207,7 +207,7 @@ def csr_to_nbgraph(csr, node_props=None): csr.indices, csr.data, np.array(csr.shape, dtype=np.int32), - node_props, + node_props.astype(np.float64), ) @@ -525,7 +525,7 @@ def __init__( if np.issubdtype(skeleton_image.dtype, np.floating): self.pixel_values = skeleton_image[coords] elif np.issubdtype(skeleton_image.dtype, np.integer): - self.pixel_values = skeleton_image.astype(float)[coords] + self.pixel_values = skeleton_image.astype(np.float64)[coords] else: self.pixel_values = None self.graph = graph diff --git a/src/skan/test/test_csr.py b/src/skan/test/test_csr.py index 2c937e6..efbf9e2 100644 --- a/src/skan/test/test_csr.py +++ b/src/skan/test/test_csr.py @@ -10,7 +10,11 @@ from numpy.testing import assert_equal, assert_almost_equal import pandas as pd import pytest +import scipy +from scipy import ndimage as ndi +from skimage import data from skimage.draw import line +from skimage.morphology import skeletonize from skan import csr from skan._testdata import ( @@ -357,6 +361,16 @@ def test_skeleton_integer_dtype(dtype): assert stats['mean_pixel_value'].max() > 1 +@pytest.mark.parametrize('dtype', [np.float32, np.float64]) +def test_skeleton_all_float_dtypes(dtype): + """Test that skeleton data types can be both float32 and float64.""" + horse = ~data.horse() + skeleton_image = skeletonize(horse) + dt = ndi.distance_transform_edt(horse) + float_skel = (dt * skeleton_image).astype(dtype) + _ = csr.Skeleton(float_skel) + + def test_default_summarize_separator(): with pytest.warns(np.exceptions.VisibleDeprecationWarning, match='separator in column name'): @@ -523,7 +537,7 @@ def test_nx_to_skeleton( @pytest.mark.parametrize( - 'wrong_skeleton', + ('wrong_skeleton'), [ pytest.param(skeleton0, id='Numpy Array.'), pytest.param(csr.Skeleton(skeleton0), id='Skeleton.'), @@ -538,3 +552,36 @@ def test_nx_to_skeleton_attribute_error(wrong_skeleton: Any) -> None: """Test various errors are raised by nx_to_skeleton().""" with pytest.raises(Exception): csr.nx_to_skeleton(wrong_skeleton) + + +@pytest.mark.parametrize( + ('skeleton'), + [ + pytest.param(skeleton0, id='Numpy Array'), + pytest.param(csr.Skeleton(skeleton0), id='Skeleton'), + pytest.param(nx_graph, id='NetworkX Graph without edges.'), + ], + ) +def test_csr_to_nbgraph_attribute_error(skeleton: Any) -> None: + """Raise AttributeError if csr_to_nbgraph() passed incomplete objects.""" + with pytest.raises(AttributeError): + csr.csr_to_nbgraph(skeleton) + + +@pytest.mark.parametrize( + ('graph'), + [ + pytest.param( + scipy.sparse.csr_matrix(skeleton0), + id='Sparse matrix directly from Numpy Array', + ), + pytest.param( + scipy.sparse.csr_matrix(csr.Skeleton(skeleton0)), + id='Sparse matrix from csr.Skeleton', + ), + ], + ) +def test_csr_to_nbgraph_type_error(graph: scipy.sparse.csr_matrix) -> None: + """Test TypeError is raised by csr_to_nbgraph() if wrong type is passed.""" + with pytest.raises(TypeError): + csr.csr_to_nbgraph(graph)