Skip to content

Commit

Permalink
Thread pool is no longer used for files in memory (#933)
Browse files Browse the repository at this point in the history
* Tinche/aiofiles#47
Thread pool is no longer used for files in memory

* fix tests

* fix import sorted

* little change
  • Loading branch information
abersheeran authored May 6, 2020
1 parent 152a05a commit 9725751
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ isort
mypy
pytest
pytest-cov
pytest-asyncio

# Documentation
mkdocs
Expand Down
24 changes: 20 additions & 4 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,17 +433,33 @@ def __init__(
file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
self.file = file

@property
def _in_memory(self) -> bool:
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk

async def write(self, data: typing.Union[bytes, str]) -> None:
await run_in_threadpool(self.file.write, data)
if self._in_memory:
self.file.write(data) # type: ignore
else:
await run_in_threadpool(self.file.write, data)

async def read(self, size: int = None) -> typing.Union[bytes, str]:
async def read(self, size: int = -1) -> typing.Union[bytes, str]:
if self._in_memory:
return self.file.read(size)
return await run_in_threadpool(self.file.read, size)

async def seek(self, offset: int) -> None:
await run_in_threadpool(self.file.seek, offset)
if self._in_memory:
self.file.seek(offset)
else:
await run_in_threadpool(self.file.seek, offset)

async def close(self) -> None:
await run_in_threadpool(self.file.close)
if self._in_memory:
self.file.close()
else:
await run_in_threadpool(self.file.close)


class FormData(ImmutableMultiDict):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_datastructures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io

import pytest

from starlette.datastructures import (
URL,
CommaSeparatedStrings,
Expand All @@ -8,6 +10,7 @@
MultiDict,
MutableHeaders,
QueryParams,
UploadFile,
)


Expand Down Expand Up @@ -210,6 +213,20 @@ def test_queryparams():
assert QueryParams(q) == q


class TestUploadFile(UploadFile):
spool_max_size = 1024


@pytest.mark.asyncio
async def test_upload_file():
big_file = TestUploadFile("big-file")
await big_file.write(b"big-data" * 512)
await big_file.write(b"big-data")
await big_file.seek(0)
assert await big_file.read(1024) == b"big-data" * 128
await big_file.close()


def test_formdata():
upload = io.BytesIO(b"test")
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
Expand Down

0 comments on commit 9725751

Please # to comment.