diff --git a/golem/task/rpc.py b/golem/task/rpc.py index c75cd568b3..bba922173e 100644 --- a/golem/task/rpc.py +++ b/golem/task/rpc.py @@ -14,6 +14,7 @@ from apps.core.task import coretask from apps.rendering.task import framerenderingtask +from apps.rendering.task.renderingtask import RenderingTask from golem.client import Client from golem.core import golem_async from golem.core import common @@ -668,3 +669,39 @@ def get_estimated_cost(self, _task_type: str, options: dict) -> dict: } logger.info('Estimated task cost. result=%r', result) return result + + @rpc_utils.expose('comp.task.rendering.task_fragments') + def get_fragments(self, task_id: str) -> \ + typing.Tuple[ + typing.Optional[typing.Dict[int, typing.List[typing.Dict]]], + typing.Optional[str] + ]: + """ + Returns the task fragments for a given rendering task. A single task + fragment is a collection of subtasks referring to the same, common part + of the whole task. Fragments are identified using incremental integer + indices. + :param task_id: Task ID of the rendering task for which fragments should + be obtained. + :return: A dictionary where keys are the fragment indices and values are + lists of subtasks asssociated with a given fragment. Returns None + (along with an error message) if the task is not known or it is not a + rendering task. + """ + task = self.task_manager.tasks.get(task_id) + if task is None: + return None, f"Task not found: '{task_id}'" + if not isinstance(task, RenderingTask): + return None, f"Incorrect task type: '{task.__class__.__name__}'" + + fragments: typing.Dict[int, typing.List[typing.Dict]] = {} + + for subtask_index in range(1, task.total_tasks + 1): + fragments[subtask_index] = [] + + for extra_data in task.subtasks_given.values(): + subtask = self.task_manager.get_subtask_dict( + extra_data['subtask_id']) + fragments[extra_data['start_task']].append(subtask) + + return fragments, None diff --git a/tests/golem/task/test_rpc.py b/tests/golem/task/test_rpc.py index 0875644b90..b9bd8eafc9 100644 --- a/tests/golem/task/test_rpc.py +++ b/tests/golem/task/test_rpc.py @@ -3,6 +3,7 @@ from tempfile import TemporaryDirectory import unittest from unittest import mock +import uuid import faker from ethereum.utils import denoms @@ -11,6 +12,8 @@ from twisted.internet import defer from apps.dummy.task import dummytaskstate +from apps.dummy.task.dummytask import DummyTask +from apps.rendering.task.renderingtask import RenderingTask from golem import clientconfigdescriptor from golem.core import common from golem.core import deferred as golem_deferred @@ -730,3 +733,58 @@ def test_basic(self, *_): subtasks, ) self.transaction_system.eth_for_deposit.assert_called_once_with() + + +@mock.patch('golem.task.taskmanager.TaskManager.get_subtask_dict', + return_value=Mock()) +class TestGetFragments(ProviderBase): + + def test_get_fragments(self, *_): + task_id = str(uuid.uuid4()) + subtasks_count = 3 + mock_task = Mock(spec=RenderingTask) + mock_task.total_tasks = subtasks_count + mock_task.subtasks_given = { + 'subtask-uuid-1': { + 'subtask_id': 'subtask-uuid-1', + 'start_task': 1, + }, + 'subtask-uuid-2': { + 'subtask_id': 'subtask-uuid-2', + 'start_task': 2, + }, + 'subtask-uuid-3': { + 'subtask_id': 'subtask-uuid-3', + 'start_task': 2, + }, + 'subtask-uuid-4': { + 'subtask_id': 'subtask-uuid-4', + 'start_task': 2, + }, + } + self.client.task_server.task_manager.tasks[task_id] = mock_task + + task_fragments, error = self.provider.get_fragments(task_id) + + self.assertTrue(len(task_fragments) == subtasks_count) + self.assertTrue(len(task_fragments[1]) == 1) + self.assertTrue(len(task_fragments[2]) == 3) + self.assertTrue(len(task_fragments[3]) == 0) + + def test_task_not_found(self, *_): + task_id = str(uuid.uuid4()) + + task_fragments, error = self.provider.get_fragments(task_id) + + self.assertIsNone(task_fragments) + self.assertTrue('Task not found' in error) + + def test_wrong_task_type(self, *_): + task_id = str(uuid.uuid4()) + mock_task = Mock(spec=DummyTask) + self.client.task_server.task_manager.tasks[task_id] = mock_task + + task_fragments, error = self.provider.get_fragments(task_id) + + self.assertIsNone(task_fragments) + self.assertTrue('Incorrect task type' in error)