Skip to content

Commit

Permalink
fix: get_delimited_template handles streams as bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
k-burt-uch committed Feb 11, 2025
1 parent dce0aec commit 805e656
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sheepdog/utils/transforms/graph_to_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def entity_to_template_str(label, file_format, **kwargs):
writer = csv.writer(output, delimiter=DELIMITERS[file_format])
writer.writerow(template)
writer.writerow(tsv_example_row(label, template))
return output.getvalue()
return output.getvalue().encode("utf-8")
else:
raise UnsupportedError(file_format)

Expand Down Expand Up @@ -216,15 +216,15 @@ def get_json_template(entity_types):

def get_delimited_template(entity_types, file_format, filename=TEMPLATE_NAME):
    """Return :param: `file_format` (TSV or CSV) template for entity types.

    Each entity type is rendered with ``entity_to_template_str`` and stored
    as a member named ``<entity_type>.<file_format>`` inside an in-memory,
    gzip-compressed tar archive.

    :param entity_types: iterable of entity type labels to include.
    :param file_format: format passed to ``entity_to_template_str``; also
        used as the file extension of every archive member.
    :param filename: archive name handed to ``tarfile.open``.
    :return: the complete archive as ``bytes``.
    """
    # tarfile emits binary data, so the archive buffer has to be a BytesIO;
    # a text buffer (io.StringIO) cannot accept the bytes tarfile writes.
    tar_obj = io.BytesIO()

    # The context manager closes the archive even if rendering one of the
    # entity types raises part-way through the loop.
    with tarfile.open(filename, mode="w|gz", fileobj=tar_obj) as tar:
        for entity_type in entity_types:
            content = entity_to_template_str(entity_type, file_format=file_format)
            # Accept text as well as bytes from the renderer: encode first so
            # that the member size below is a byte count, not a character
            # count, and the payload stream is always binary.
            if isinstance(content, str):
                content = content.encode("utf-8")

            tarinfo = tarfile.TarInfo(name="{}.{}".format(entity_type, file_format))
            tarinfo.size = len(content)
            tar.addfile(tarinfo, io.BytesIO(content))

    return tar_obj.getvalue()
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_templates.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
import pytest

from sheepdog import dictionary
from sheepdog.utils.transforms.graph_to_doc import (
entity_to_template,
is_property_hidden,
get_delimited_template,
)
from sheepdog.utils import _get_links


@pytest.mark.parametrize("file_format", ["json", "csv", "tsv"])
def test_get_delimited_template(file_format):
    """Building a template archive succeeds and yields non-empty binary data."""
    # No try/except wrapper: pytest already reports an uncaught exception as
    # a test failure, and letting it propagate keeps the full traceback.
    archive = get_delimited_template(
        dictionary.schema, file_format=file_format, filename="test"
    )

    assert isinstance(archive, bytes)
    assert archive


def test_urls_in_templates_json():
"""Test that urls is in JSON template iff entity is data_file"""
for label in dictionary.schema:
Expand Down

0 comments on commit 805e656

Please sign in to comment.