Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix potential data leak for shuffle tasks #2975

Merged
merged 3 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,7 @@ async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
return result

async def decref(self, *tileable_keys):
    """Decrement the reference count of the given tileables.

    Logs the keys being decref'ed on the client side, then forwards the
    request to the lifecycle API.

    Parameters
    ----------
    *tileable_keys
        Keys of the tileables whose reference counts should be decreased.

    Returns
    -------
    Whatever ``decref_tileables`` returns for the given keys.
    """
    logger.debug("Decref tileables on client: %s", tileable_keys)
    keys = list(tileable_keys)
    return await self._lifecycle_api.decref_tileables(keys)

async def _get_ref_counts(self) -> Dict[str, int]:
Expand Down
50 changes: 50 additions & 0 deletions mars/deploy/oscar/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,14 @@ async def test_execute(create_cluster, config):

del a, b

if not isinstance(session._isolated_session, _IsolatedWebSession):
worker_pools = session.client._cluster._worker_pools
await session.destroy()
for worker_pool in worker_pools:
_assert_storage_cleaned(
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
)


@pytest.mark.asyncio
async def test_iterative_tiling(create_cluster):
Expand All @@ -254,6 +262,14 @@ async def test_iterative_tiling(create_cluster):
assert df2.index_value.min_val >= 1
assert df2.index_value.max_val <= 30

if not isinstance(session._isolated_session, _IsolatedWebSession):
worker_pools = session.client._cluster._worker_pools
await session.destroy()
for worker_pool in worker_pools:
_assert_storage_cleaned(
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
)


@pytest.mark.asyncio
async def test_execute_describe(create_cluster):
Expand All @@ -271,6 +287,14 @@ async def test_execute_describe(create_cluster):
res = await session.fetch(r)
pd.testing.assert_frame_equal(res, raw.describe())

if not isinstance(session._isolated_session, _IsolatedWebSession):
worker_pools = session.client._cluster._worker_pools
await session.destroy()
for worker_pool in worker_pools:
_assert_storage_cleaned(
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
)


@pytest.mark.asyncio
async def test_sync_execute_in_async(create_cluster):
Expand Down Expand Up @@ -395,6 +419,12 @@ async def test_web_session(create_cluster, config):
await session.destroy()
await _run_web_session_test(web_address)

worker_pools = client._cluster._worker_pools
for worker_pool in worker_pools:
_assert_storage_cleaned(
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
)


def test_sync_execute():
session = new_session(n_cpu=2, web=False, use_uvloop=False)
Expand Down Expand Up @@ -541,6 +571,26 @@ def test_decref(setup_session):
ref_counts = session._get_ref_counts()
assert len(ref_counts) == 0

with tempfile.TemporaryDirectory() as tempdir:
file_path = os.path.join(tempdir, "test.csv")
pdf = pd.DataFrame(
np.random.RandomState(0).rand(100, 10),
columns=[f"col{i}" for i in range(10)],
)
pdf.to_csv(file_path, index=False)

df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
df2 = df.head(10)

result = df2.execute().fetch()
expected = pdf.head(10)
pd.testing.assert_frame_equal(result, expected)

del df, df2

ref_counts = session._get_ref_counts()
assert len(ref_counts) == 0

worker_addr = session._session.client._cluster._worker_pools[0].external_address
_assert_storage_cleaned(session.session_id, worker_addr, StorageLevel.MEMORY)

Expand Down
5 changes: 4 additions & 1 deletion mars/services/lifecycle/supervisor/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def _check_ref_counts(cls, keys: List[str], ref_counts: List[int]):
)

def incref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
logger.debug("Increase reference count for chunks %s", chunk_keys)
logger.debug(
"Increase reference count for chunks %s",
{ck: self._chunk_ref_counts[ck] for ck in chunk_keys},
)
self._check_ref_counts(chunk_keys, counts)
counts = counts if counts is not None else itertools.repeat(1)
for chunk_key, count in zip(chunk_keys, counts):
Expand Down
6 changes: 4 additions & 2 deletions mars/services/storage/api/oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,11 @@ async def fetch(
error: str
raise or ignore
"""
await self._storage_handler_ref.fetch_batch(
fetch_key = await self._storage_handler_ref.fetch_batch(
self._session_id, [data_key], level, band_name, remote_address, error
)
if fetch_key:
return fetch_key

@fetch.batch
async def batch_fetch(self, args_list, kwargs_list):
Expand All @@ -201,7 +203,7 @@ async def batch_fetch(self, args_list, kwargs_list):
assert extracted_args == (level, band_name, dest_address, error)
extracted_args = (level, band_name, dest_address, error)
data_keys.append(data_key)
await self._storage_handler_ref.fetch_batch(
return await self._storage_handler_ref.fetch_batch(
self._session_id, data_keys, *extracted_args
)

Expand Down
6 changes: 5 additions & 1 deletion mars/services/storage/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,11 @@ async def send_batch_data(
)
await self._data_manager_ref.unpin.batch(*unpin_tasks)
logger.debug(
"Finish sending data (%s, %s) to %s", session_id, data_keys, address
"Finish sending data (%s, %s) to %s, total size is %s",
session_id,
data_keys,
address,
sum(data_sizes),
)


Expand Down
6 changes: 3 additions & 3 deletions mars/services/task/execution/mars/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,9 @@ def _get_decref_stage_chunk_key_to_counts(
for inp_subtask in subtask_graph.predecessors(subtask):
for c in inp_subtask.chunk_graph.results:
decref_chunk_key_to_counts[c.key] += 1
# decref result of chunk graphs
for c in stage_processor.chunk_graph.results:
decref_chunk_key_to_counts[c.key] += 1
# decref result of chunk graphs
for c in stage_processor.chunk_graph.results:
decref_chunk_key_to_counts[c.key] += 1
return decref_chunk_key_to_counts

@mo.extensible
Expand Down