Skip to content

Commit dddc201

Browse files
authored
[Ray] Load subtask inputs from meta (#2976)
1 parent c43918d commit dddc201

File tree

8 files changed

+123
-39
lines changed

8 files changed

+123
-39
lines changed

mars/deploy/oscar/session.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,9 @@ async def init(
845845
from .local import new_cluster_in_isolation
846846

847847
return (
848-
await new_cluster_in_isolation(address, timeout=timeout, **kwargs)
848+
await new_cluster_in_isolation(
849+
address, timeout=timeout, backend=backend, **kwargs
850+
)
849851
).session
850852

851853
if kwargs: # pragma: no cover

mars/deploy/oscar/tests/test_local.py

+24-19
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,11 @@ async def test_web_session(create_cluster, config):
396396
await _run_web_session_test(web_address)
397397

398398

399-
def test_sync_execute():
400-
session = new_session(n_cpu=2, web=False, use_uvloop=False)
399+
@pytest.mark.parametrize("config", [{"backend": "mars", "incremental_index": True}])
400+
def test_sync_execute(config):
401+
session = new_session(
402+
backend=config["backend"], n_cpu=2, web=False, use_uvloop=False
403+
)
401404

402405
# web not started
403406
assert session._session.client.web_address is None
@@ -421,23 +424,25 @@ def test_sync_execute():
421424
assert d is c
422425
assert abs(session.fetch(d) - raw.sum()) < 0.001
423426

424-
with tempfile.TemporaryDirectory() as tempdir:
425-
file_path = os.path.join(tempdir, "test.csv")
426-
pdf = pd.DataFrame(
427-
np.random.RandomState(0).rand(100, 10),
428-
columns=[f"col{i}" for i in range(10)],
429-
)
430-
pdf.to_csv(file_path, index=False)
431-
432-
df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
433-
result = df.sum(axis=1).execute().fetch()
434-
expected = pd.read_csv(file_path).sum(axis=1)
435-
pd.testing.assert_series_equal(result, expected)
436-
437-
df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
438-
result = df.head(10).execute().fetch()
439-
expected = pd.read_csv(file_path).head(10)
440-
pd.testing.assert_frame_equal(result, expected)
427+
# TODO(fyrestone): Remove this when the Ray backend support incremental index.
428+
if config["incremental_index"]:
429+
with tempfile.TemporaryDirectory() as tempdir:
430+
file_path = os.path.join(tempdir, "test.csv")
431+
pdf = pd.DataFrame(
432+
np.random.RandomState(0).rand(100, 10),
433+
columns=[f"col{i}" for i in range(10)],
434+
)
435+
pdf.to_csv(file_path, index=False)
436+
437+
df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
438+
result = df.sum(axis=1).execute().fetch()
439+
expected = pd.read_csv(file_path).sum(axis=1)
440+
pd.testing.assert_series_equal(result, expected)
441+
442+
df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
443+
result = df.head(10).execute().fetch()
444+
expected = pd.read_csv(file_path).head(10)
445+
pd.testing.assert_frame_equal(result, expected)
441446

442447
for worker_pool in session._session.client._cluster._worker_pools:
443448
_assert_storage_cleaned(

mars/deploy/oscar/tests/test_ray_dag.py

+13
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,16 @@ async def create_cluster(request):
9595
@pytest.mark.asyncio
9696
async def test_execute(ray_start_regular_shared2, create_cluster, config):
9797
await test_local.test_execute(create_cluster, config)
98+
99+
100+
@require_ray
101+
@pytest.mark.asyncio
102+
async def test_iterative_tiling(ray_start_regular_shared2, create_cluster):
103+
await test_local.test_iterative_tiling(create_cluster)
104+
105+
106+
# TODO(fyrestone): Support incremental index in ray backend.
107+
@require_ray
108+
@pytest.mark.parametrize("config", [{"backend": "ray", "incremental_index": False}])
109+
def test_sync_execute(config):
110+
test_local.test_sync_execute(config)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
# TODO(fyrestone): Should implement the mars.core.context.Context.
17+
class RayExecutionContext(dict):
18+
@staticmethod
19+
def new_custom_log_dir():
20+
return None

mars/services/task/execution/ray/executor.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@
3636
from ....cluster.api import ClusterAPI
3737
from ....lifecycle.api import LifecycleAPI
3838
from ....meta.api import MetaAPI
39-
from ....subtask import SubtaskGraph
39+
from ....subtask import Subtask, SubtaskGraph
4040
from ....subtask.utils import iter_input_data_keys, iter_output_data
4141
from ..api import TaskExecutor, ExecutionChunkResult, register_executor_cls
42+
from .context import RayExecutionContext
4243

4344
ray = lazy_import("ray")
4445
logger = logging.getLogger(__name__)
@@ -55,7 +56,7 @@ def execute_subtask(
5556
ensure_coverage()
5657
subtask_chunk_graph = deserialize(*subtask_chunk_graph)
5758
# inputs = [i[1] for i in inputs]
58-
context = dict(zip(input_keys, inputs))
59+
context = RayExecutionContext(zip(input_keys, inputs))
5960
# optimize chunk graph.
6061
subtask_chunk_graph = optimize(subtask_chunk_graph)
6162
# from data_key to results
@@ -117,7 +118,7 @@ async def create(
117118
session_id: str,
118119
address: str,
119120
task,
120-
tile_context,
121+
tile_context: TileContext,
121122
**kwargs
122123
) -> "TaskExecutor":
123124
ray_executor = ray.remote(execute_subtask)
@@ -159,14 +160,10 @@ async def execute_subtask_graph(
159160
result_keys = {chunk.key for chunk in chunk_graph.result_chunks}
160161
for subtask in subtask_graph.topological_iter():
161162
subtask_chunk_graph = subtask.chunk_graph
162-
chunk_key_to_data_keys = get_chunk_key_to_data_keys(subtask_chunk_graph)
163-
key_to_input = {
164-
key: context[key]
165-
for key, _ in iter_input_data_keys(
166-
subtask, subtask_chunk_graph, chunk_key_to_data_keys
167-
)
168-
}
169-
output_keys = self._get_output_keys(subtask_chunk_graph)
163+
key_to_input = await self._load_subtask_inputs(
164+
stage_id, subtask, subtask_chunk_graph, context
165+
)
166+
output_keys = self._get_subtask_output_keys(subtask_chunk_graph)
170167
output_meta_keys = result_keys & output_keys
171168
output_count = len(output_keys) + bool(output_meta_keys)
172169
output_object_refs = self._ray_executor.options(
@@ -250,8 +247,44 @@ async def get_progress(self) -> float:
250247
async def cancel(self):
251248
"""Cancel execution."""
252249

250+
async def _load_subtask_inputs(
251+
self, stage_id: str, subtask: Subtask, chunk_graph: ChunkGraph, context: Dict
252+
):
253+
"""
254+
Load a dict of input key to object ref of subtask from context.
255+
256+
It updates the context if the input object refs are fetched from
257+
the meta service.
258+
"""
259+
key_to_input = {}
260+
key_to_get_meta = {}
261+
chunk_key_to_data_keys = get_chunk_key_to_data_keys(chunk_graph)
262+
for key, _ in iter_input_data_keys(
263+
subtask, chunk_graph, chunk_key_to_data_keys
264+
):
265+
if key in context:
266+
key_to_input[key] = context[key]
267+
else:
268+
key_to_get_meta[key] = self._meta_api.get_chunk_meta.delay(
269+
key, fields=["object_refs"]
270+
)
271+
if key_to_get_meta:
272+
logger.info(
273+
"Fetch %s metas and update context of stage %s.",
274+
len(key_to_get_meta),
275+
stage_id,
276+
)
277+
meta_list = await self._meta_api.get_chunk_meta.batch(
278+
*key_to_get_meta.values()
279+
)
280+
for key, meta in zip(key_to_get_meta.keys(), meta_list):
281+
object_ref = meta["object_refs"][0]
282+
key_to_input[key] = object_ref
283+
context[key] = object_ref
284+
return key_to_input
285+
253286
@staticmethod
254-
def _get_output_keys(chunk_graph):
287+
def _get_subtask_output_keys(chunk_graph: ChunkGraph):
255288
output_keys = {}
256289
for chunk in chunk_graph.results:
257290
if isinstance(chunk.op, VirtualOperand):

mars/services/task/supervisor/processor.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,12 @@ async def _process_stage_chunk_graph(
204204
# for all execution backends.
205205
try:
206206
key_to_bands = await meta_api.get_chunk_meta.batch(*get_meta_tasks)
207-
except KeyError:
208-
key_to_bands = {}
209-
fetch_op_to_bands = dict(
210-
(key, meta["bands"][0]) for key, meta in zip(fetch_op_keys, key_to_bands)
211-
)
207+
fetch_op_to_bands = dict(
208+
(key, meta["bands"][0])
209+
for key, meta in zip(fetch_op_keys, key_to_bands)
210+
)
211+
except (KeyError, IndexError):
212+
fetch_op_to_bands = {}
212213
with Timer() as timer:
213214
subtask_graph = await asyncio.to_thread(
214215
self._preprocessor.analyze,

mars/services/task/supervisor/tests/test_task_manager.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,9 @@ async def test_numexpr(actor_pool):
549549
) == [1] * len(result_tileable.chunks)
550550

551551

552+
@pytest.mark.parametrize("config", [{"incremental_index": True}])
552553
@pytest.mark.asyncio
553-
async def test_optimization(actor_pool):
554+
async def test_optimization(actor_pool, config):
554555
(
555556
execution_backend,
556557
pool,
@@ -574,7 +575,7 @@ async def test_optimization(actor_pool):
574575
)
575576
pdf.to_csv(file_path, index=False)
576577

577-
df = md.read_csv(file_path)
578+
df = md.read_csv(file_path, incremental_index=config["incremental_index"])
578579
df2 = df.groupby("c").agg({"a": "sum"})
579580
df3 = df[["b", "a"]]
580581

mars/services/task/supervisor/tests/test_task_manager_on_ray_dag.py

+9
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,12 @@ async def test_iterative_tiling(ray_start_regular_shared2, actor_pool):
5555
@pytest.mark.asyncio
5656
async def test_numexpr(ray_start_regular_shared2, actor_pool):
5757
await test_task_manager.test_numexpr(actor_pool)
58+
59+
60+
# TODO(fyrestone): Support incremental index in ray backend.
61+
@require_ray
62+
@pytest.mark.parametrize("config", [{"incremental_index": False}])
63+
@pytest.mark.parametrize("actor_pool", [{"backend": "ray"}], indirect=True)
64+
@pytest.mark.asyncio
65+
async def test_optimization(ray_start_regular_shared2, actor_pool, config):
66+
await test_task_manager.test_optimization(actor_pool, config)

0 commit comments

Comments
 (0)