diff --git a/golem/core/common.py b/golem/core/common.py index df9e29732f..da20a71728 100644 --- a/golem/core/common.py +++ b/golem/core/common.py @@ -236,7 +236,14 @@ def wrapper(*args, **kwargs): return decorator -def config_logging(suffix='', datadir=None, loglevel=None, config_desc=None): +# pylint: disable=too-many-branches,too-many-locals +def config_logging( + suffix='', + datadir=None, + loglevel=None, + config_desc=None, + formatter_prefix='', # prefix added to every logged line +): """Config logger""" try: from loggingconfig_local import LOGGING @@ -247,6 +254,8 @@ def config_logging(suffix='', datadir=None, loglevel=None, config_desc=None): datadir = simpleenv.get_local_datadir("default") logdir_path = os.path.join(datadir, 'logs') + for formatter in LOGGING.get('formatters', {}).values(): + formatter['format'] = f"{formatter_prefix}{formatter['format']}" for handler_name, handler in LOGGING.get('handlers', {}).items(): if 'filename' in handler: handler['filename'] %= { diff --git a/golem/network/p2p/p2pservice.py b/golem/network/p2p/p2pservice.py index e1df1b3d3a..3f17882876 100644 --- a/golem/network/p2p/p2pservice.py +++ b/golem/network/p2p/p2pservice.py @@ -788,20 +788,6 @@ def want_to_start_task_session( self.task_server\ .task_connections_helper.cannot_start_task_session(conn_id) - def peer_want_task_session(self, node_info, super_node_info, conn_id): - """Process request to start task session from this node to a node - from node_info. - :param Node node_info: node that requests task session with this node - :param Node|None super_node_info: information about supernode - that has passed this information - :param conn_id: connection id - """ - self.task_server.start_task_session( - node_info, - super_node_info, - conn_id - ) - ############################# # RANKING FUNCTIONS # ############################# diff --git a/golem/network/p2p/peersession.py b/golem/network/p2p/peersession.py index 1af963a399..e35f67cc0e 100644 --- a/golem/network/p2p/peersession.py +++ b/golem/network/p2p/peersession.py @@ -453,12 +453,10 @@ def _react_to_challenge_solution(self, msg): message.base.Disconnect.REASON.Unverified ) - def _react_to_want_to_start_task_session(self, msg): - self.p2p_service.peer_want_task_session( - msg.node_info, - msg.super_node_info, - msg.conn_id - ) + @classmethod + def _react_to_want_to_start_task_session(cls, msg): + # TODO: https://github.com/golemfactory/golem/issues/4005 + logger.debug("Ignored WTSTS. msg=%s", msg) def _react_to_set_task_session(self, msg): self.p2p_service.want_to_start_task_session( diff --git a/golem/network/transport/msg_queue.py b/golem/network/transport/msg_queue.py index 46c6b1c4da..6cc13a5532 100644 --- a/golem/network/transport/msg_queue.py +++ b/golem/network/transport/msg_queue.py @@ -5,22 +5,26 @@ import golem_messages from golem_messages import exceptions as msg_exceptions +from golem_messages import message from golem import decorators from golem import model from golem.core import variables -if typing.TYPE_CHECKING: - # pylint: disable=ungrouped-imports,unused-import - from golem_messages import message - - logger = logging.getLogger(__name__) READ_LOCK = threading.Lock() +# CLasses that aren't allowed in queue +FORBIDDEN_CLASSES = ( + message.base.Disconnect, + message.base.Hello, + message.base.RandVal, +) -def put(node_id: str, msg: 'message.base.Base') -> None: +def put(node_id: str, msg: message.base.Message) -> None: + assert not isinstance(msg, FORBIDDEN_CLASSES),\ + "Disconnect message shouldn't be in a queue" db_model = model.QueuedMessage.from_message(node_id, msg) db_model.save() diff --git a/golem/task/acl.py b/golem/task/acl.py index 995db35cb6..56ce1a7b24 100644 --- a/golem/task/acl.py +++ b/golem/task/acl.py @@ -1,4 +1,5 @@ import abc +import logging import operator import time from enum import Enum @@ -6,6 +7,10 @@ from typing import Dict, Set, Union, Iterable, Optional, Tuple from sortedcontainers import SortedList +from golem.core import common + +logger = logging.getLogger(__name__) + DENY_LIST_NAME = "deny.txt" ALL_EXCEPT_ALLOWED = "ALL_EXCEPT_ALLOWED" @@ -73,6 +78,12 @@ def is_allowed(self, node_id: str) -> Tuple[bool, Optional[DenyReason]]: def disallow(self, node_id: str, timeout_seconds: int = -1, persist: bool = False) -> None: + logger.info( + 'Banned node. node_id=%s, timeout=%ds, persist=%s', + common.short_node_id(node_id), + timeout_seconds, + persist, + ) if timeout_seconds < 0: self._deny_deadlines[node_id] = self._always else: @@ -114,6 +125,12 @@ def is_allowed(self, node_id: str) -> Tuple[bool, Optional[DenyReason]]: def disallow(self, node_id: str, timeout_seconds: int = 0, persist: bool = False) -> None: + logger.info( + 'Banned node. node_id=%s, timeout=%ds, persist=%s', + common.short_node_id(node_id), + timeout_seconds, + persist, + ) self._allow_set.discard(node_id) if persist and self._list_path: diff --git a/golem/task/server/helpers.py b/golem/task/server/helpers.py index c1b6a762d2..1e8b70aea3 100644 --- a/golem/task/server/helpers.py +++ b/golem/task/server/helpers.py @@ -1,7 +1,18 @@ import logging +import typing from golem_messages import message from golem_messages import helpers as msg_helpers +from golem_messages import utils as msg_utils + +from golem import model +from golem.core import common +from golem.network import history +from golem.network.transport import msg_queue + +if typing.TYPE_CHECKING: + # pylint: disable=unused-import + from golem.network.p2p.local_node import LocalNode logger = logging.getLogger(__name__) @@ -80,3 +91,130 @@ def on_error(exc, *_args, **_kwargs): client_options=client_options, output_dir=output_dir ) + +def send_report_computed_task(task_server, waiting_task_result) -> None: + """ Send task results after finished computations + """ + task_to_compute = history.get( + message_class_name='TaskToCompute', + node_id=waiting_task_result.owner.key, + task_id=waiting_task_result.task_id, + subtask_id=waiting_task_result.subtask_id + ) + + if not task_to_compute: + logger.warning( + "Cannot send ReportComputedTask. TTC missing." + " node=%s, task_id=%r, subtask_id=%r", + common.node_info_str( + waiting_task_result.owner.node_name, + waiting_task_result.owner.key, + ), + waiting_task_result.task_id, + waiting_task_result.subtask_id, + ) + return + + my_node: LocalNode = task_server.node + client_options = task_server.get_share_options( + waiting_task_result.task_id, + waiting_task_result.owner.prv_addr, + ) + + report_computed_task = message.tasks.ReportComputedTask( + task_to_compute=task_to_compute, + node_name=my_node.node_name, + address=my_node.prv_addr, + port=task_server.cur_port, + key_id=my_node.key, + node_info=my_node.to_dict(), + extra_data=[], + size=waiting_task_result.result_size, + package_hash='sha1:' + waiting_task_result.package_sha1, + multihash=waiting_task_result.result_hash, + secret=waiting_task_result.result_secret, + options=client_options.__dict__, + ) + + msg_queue.put( + waiting_task_result.owner.key, + report_computed_task, + ) + report_computed_task = msg_utils.copy_and_sign( + msg=report_computed_task, + private_key=task_server.keys_auth._private_key, # noqa pylint: disable=protected-access + ) + history.add( + msg=report_computed_task, + node_id=waiting_task_result.owner.key, + local_role=model.Actor.Provider, + remote_role=model.Actor.Requestor, + ) + + # if the Concent is not available in the context of this subtask + # we can only assume that `ReportComputedTask` above reaches + # the Requestor safely + + if not task_to_compute.concent_enabled: + logger.debug( + "Concent not enabled for this task, " + "skipping `ForceReportComputedTask`. " + "task_id=%r, " + "subtask_id=%r, ", + task_to_compute.task_id, + task_to_compute.subtask_id, + ) + return + + # we're preparing the `ForceReportComputedTask` here and + # scheduling the dispatch of that message for later + # (with an implicit delay in the concent service's `submit` method). + # + # though, should we receive the acknowledgement for + # the `ReportComputedTask` sent above before the delay elapses, + # the `ForceReportComputedTask` message to the Concent will be + # cancelled and thus, never sent to the Concent. + + delayed_forcing_msg = message.concents.ForceReportComputedTask( + report_computed_task=report_computed_task, + result_hash='sha1:' + waiting_task_result.package_sha1 + ) + logger.debug('[CONCENT] ForceReport: %s', delayed_forcing_msg) + + task_server.client.concent_service.submit_task_message( + waiting_task_result.subtask_id, + delayed_forcing_msg, + ) + + +def send_task_failure(waiting_task_failure) -> None: + """Inform task owner that an error occurred during task computation + """ + + task_to_compute = history.get( + message_class_name='TaskToCompute', + node_id=waiting_task_failure.owner.key, + task_id=waiting_task_failure.task_id, + subtask_id=waiting_task_failure.subtask_id + ) + + if not task_to_compute: + logger.warning( + "Cannot send TaskFailure. TTC missing." + " node=%s, task_id=%r, subtask_id=%r", + common.node_info_str( + waiting_task_failure.owner.node_name, + waiting_task_failure.owner.key, + ), + waiting_task_failure.task_id, + waiting_task_failure.subtask_id, + ) + return + + msg_queue.put( + waiting_task_failure.owner.key, + message.tasks.TaskFailure( + task_to_compute=task_to_compute, + err=waiting_task_failure.err_msg + ), + ) diff --git a/golem/task/server/queue.py b/golem/task/server/queue.py deleted file mode 100644 index 0e3caf299d..0000000000 --- a/golem/task/server/queue.py +++ /dev/null @@ -1,98 +0,0 @@ -import collections -import logging -import typing - -from golem.network import nodeskeeper -from golem.network.transport import msg_queue - -if typing.TYPE_CHECKING: - from golem_messages import message - - from golem.task import taskkeeper - from golem.task.tasksession import TaskSession - - -logger = logging.getLogger(__name__) - -class TaskMessagesQueueMixin: - """Message Queue functionality for TaskServer""" - - task_keeper: 'taskkeeper.TaskHeaderKeeper' - - def __init__(self): - for attr_name in ( - 'conn_established_for_type', - 'conn_failure_for_type', - 'conn_final_failure_for_type', - ): - if not hasattr(self, attr_name): - setattr(self, attr_name, {}) - - self.conn_established_for_type.update({ - 'msg_queue': self.msg_queue_connection_established, - }) - self.conn_failure_for_type.update({ - 'msg_queue': self.msg_queue_connection_failure, - }) - self.conn_final_failure_for_type.update({ - 'msg_queue': self.msg_queue_connection_final_failure, - }) - - def send_message(self, node_id: str, msg: 'message.base.Message'): - logger.debug('send_message(%r, %r)', node_id, msg) - msg_queue.put(node_id, msg) - - # Temporary code to immediately initiate session - node = self.task_keeper.find_newest_node(node_id) - if node is None: - node = nodeskeeper.get(node_id) - logger.debug("Found in memory %r", node) - if node is None: - logger.debug( - "Don't have any info about node. Will try later. node_id=%r", - node_id, - ) - return - self._add_pending_request( # type: ignore - 'msg_queue', - node, - prv_port=node.prv_port, - pub_port=node.pub_port, - args={ - 'node_id': node_id, - } - ) - - def msg_queue_connection_established( - self, - session: 'TaskSession', - conn_id, - node_id, - ): - self.new_session_prepare( # type: ignore - session=session, - key_id=node_id, - conn_id=conn_id, - ) - session.send_hello() - for msg in msg_queue.get(node_id): - session.send(msg) - - def msg_queue_connection_failure(self, conn_id, *args, **kwargs): - def cbk(session): - self.msg_queue_connection_established(session, conn_id, *args, **kwargs) - try: - self.response_list[conn_id].append(cbk) - except KeyError: - self.response_list[conn_id] = collections.deque([cbk]) - try: - pc = self.pending_connections[conn_id] - except KeyError: - pass - else: - pc.status = PenConnStatus.WaitingAlt - pc.time = time.time() - - def msg_queue_connection_final_failure(self, conn_id, *_args, **_kwargs): - self.remove_pending_conn(conn_id) - self.remove_responses(conn_id) diff --git a/golem/task/server/queue_.py b/golem/task/server/queue_.py new file mode 100644 index 0000000000..bec2a4ba99 --- /dev/null +++ b/golem/task/server/queue_.py @@ -0,0 +1,162 @@ +import logging +import time +import typing + +from golem_messages import message + +from golem.core import common +from golem.network import nodeskeeper +from golem.network.transport import msg_queue +from golem.network.transport import tcpserver + +if typing.TYPE_CHECKING: + # pylint: disable=unused-import + from golem.task import taskkeeper + from golem.task.tasksession import TaskSession + + +logger = logging.getLogger(__name__) + +class TaskMessagesQueueMixin: + """Message Queue functionality for TaskServer""" + + task_keeper: 'taskkeeper.TaskHeaderKeeper' + forwarded_session_requests: typing.Dict[str, dict] + + def __init__(self): + # Possible values of .sessions: + # None - PendingConnection + # TaskSession - session established + # Keys are always node_id a.k.a. key_id + self.sessions: 'typing.Dict[str, typing.Optional[TaskSession]]' = {} + + for attr_name in ( + 'conn_established_for_type', + 'conn_failure_for_type', + 'conn_final_failure_for_type', + ): + if not hasattr(self, attr_name): + setattr(self, attr_name, {}) + + self.conn_established_for_type.update({ + 'msg_queue': self.msg_queue_connection_established, + }) + self.conn_failure_for_type.update({ + 'msg_queue': self.msg_queue_connection_failure, + }) + self.conn_final_failure_for_type.update({ + 'msg_queue': self.msg_queue_connection_final_failure, + }) + + def initiate_session(self, node_id: str) -> None: + if node_id in self.sessions: + session = self.sessions[node_id] + if session is not None: + session.read_msg_queue() + return + + node = self.task_keeper.find_newest_node(node_id) + if node is None: + node = nodeskeeper.get(node_id) + logger.debug("Found in memory %r", node) + if node is None: + logger.debug( + "Don't have any info about node. Will try later. node_id=%r", + node_id, + ) + return + result = self._add_pending_request( # type: ignore + 'msg_queue', + node, + prv_port=node.prv_port, + pub_port=node.pub_port, + args={ + 'node_id': node_id, + } + ) + if result: + self.sessions[node_id] = None + + def remove_session_by_node_id(self, node_id): + try: + session = self.sessions[node_id] + except KeyError: + return + del self.sessions[node_id] + if session is None: + return + self.remove_session(session) + + def remove_session(self, session): + session.disconnect( + message.base.Disconnect.REASON.NoMoreMessages, + ) + self.remove_pending_conn(session.conn_id) + + def connect_to_nodes(self): + for node_id in msg_queue.waiting(): + self.initiate_session(node_id) + + def sweep_sessions(self): + for node_id in self.sessions: + session = self.sessions[node_id] + if session is None: + continue + if session.is_active: + continue + self.remove_session_by_node_id(node_id) + + def msg_queue_connection_established( + self, + session: 'TaskSession', + conn_id, + node_id, + ): + try: + if self.sessions[node_id] is not None: + # There is a session already established + # with this node_id. All messages will be processed + # in that other session. + session.dropped() + return + except KeyError: + pass + session.key_id = node_id + session.conn_id = conn_id + self.sessions[node_id] = session + self._mark_connected( # type: ignore + conn_id, + session.address, + session.port, + ) + self.forwarded_session_requests.pop(node_id, None) + session.send_hello() + + def msg_queue_connection_failure(self, conn_id, *_args, **_kwargs): + try: + pc = self.pending_connections[conn_id] + except KeyError: + pass + else: + pc.status = tcpserver.PenConnStatus.WaitingAlt + pc.time = time.time() + + def msg_queue_connection_final_failure( + self, + conn_id, + node_id, + *_args, + **_kwargs, + ): + logger.debug( + "Final connection failure for TaskSession." + " conn_id=%s, node_id=%s", + conn_id, + common.short_node_id(node_id), + ) + self.remove_pending_conn(conn_id) + try: + if self.sessions[node_id] is None: + del self.sessions[node_id] + except KeyError: + pass diff --git a/golem/task/server/resources.py b/golem/task/server/resources.py index 624bfa1393..dacf026cba 100644 --- a/golem/task/server/resources.py +++ b/golem/task/server/resources.py @@ -17,11 +17,13 @@ from golem.core.hostaddress import ip_address_private from golem.network.hyperdrive.client import HyperdriveClientOptions, \ to_hyperg_peer +from golem.network.transport import msg_queue from golem.resource.hyperdrive import resource as hpd_resource from golem.resource.resourcehandshake import ResourceHandshake if TYPE_CHECKING: + # pylint: disable=unused-import from golem.task import taskmanager @@ -264,12 +266,12 @@ def _nonce_shared(self, key_id, result, options): ) os.remove(handshake.file) - self.send_message( + msg_queue.put( node_id=key_id, msg=message.resources.ResourceHandshakeStart( resource=handshake.hash, options=options.__dict__, ), - ) + ) def _share_handshake_nonce(self, key_id): handshake = self.resource_handshakes.get(key_id) diff --git a/golem/task/server/verification.py b/golem/task/server/verification.py new file mode 100644 index 0000000000..c0eecb70f3 --- /dev/null +++ b/golem/task/server/verification.py @@ -0,0 +1,143 @@ +import logging +import typing + +from golem_messages import message +from golem_messages import utils as msg_utils +from golem_messages.datastructures import p2p as dt_p2p + +from golem import model +from golem.core import common +from golem.network import history +from golem.network.transport import msg_queue +from golem.task.result.resultmanager import ExtractedPackage + +if typing.TYPE_CHECKING: + # pylint: disable=unused-import + from golem.core import keysauth + from golem.task import taskmanager + +logger = logging.getLogger(__name__) + +class VerificationMixin: + keys_auth: 'keysauth.KeysAuth' + task_manager: 'taskmanager.TaskManager' + + def verify_results( + self, + report_computed_task: message.tasks.ReportComputedTask, + extracted_package: ExtractedPackage, + ) -> None: + + node = dt_p2p.Node(**report_computed_task.node_info) + subtask_id = report_computed_task.subtask_id + logger.info( + 'Verifying results. node=%s, subtask_id=%s', + common.node_info_str( + node.node_name, + node.key, + ), + subtask_id, + ) + result_files = extracted_package.get_full_path_files() + + def verification_finished(): + logger.debug("Verification finished handler.") + if not self.task_manager.verify_subtask(subtask_id): + logger.debug("Verification failure. subtask_id=%r", subtask_id) + self.send_result_rejected( + report_computed_task=report_computed_task, + reason=message.tasks.SubtaskResultsRejected.REASON + .VerificationNegative + ) + return + + task_to_compute = report_computed_task.task_to_compute + + config_desc = self.config_desc + if config_desc.disallow_node_timeout_seconds is not None: + # Experimental feature. Try to spread subtasks fairly amongst + # providers. + self.disallow_node( + node_id=task_to_compute.provider_id, + timeout_seconds=config_desc.disallow_node_timeout_seconds, + persist=False, + ) + if config_desc.disallow_ip_timeout_seconds is not None: + # Experimental feature. Try to spread subtasks fairly amongst + # providers. + self.disallow_ip( + ip=self.address, + timeout_seconds=config_desc.disallow_ip_timeout_seconds, + ) + + payment_processed_ts = self.accept_result( + subtask_id, + report_computed_task.provider_id, + task_to_compute.provider_ethereum_address, + task_to_compute.price, + ) + + response_msg = message.tasks.SubtaskResultsAccepted( + report_computed_task=report_computed_task, + payment_ts=payment_processed_ts, + ) + msg_queue.put(node.key, response_msg) + history.add( + msg_utils.copy_and_sign( + msg=response_msg, + private_key=self.keys_auth._private_key, # noqa pylint: disable=protected-access + ), + node_id=task_to_compute.provider_id, + local_role=model.Actor.Requestor, + remote_role=model.Actor.Provider, + ) + + self.task_manager.computed_task_received( + subtask_id, + result_files, + verification_finished + ) + + def send_result_rejected( + self, + report_computed_task: message.tasks.ReportComputedTask, + reason: message.tasks.SubtaskResultsRejected.REASON, + ) -> None: + """ + Inform that result doesn't pass the verification or that + the verification was not possible + + :param str subtask_id: subtask that has wrong result + :param SubtaskResultsRejected.Reason reason: the rejection reason + """ + + + logger.debug( + 'send_result_rejected. reason=%r, rct=%r', + reason, + report_computed_task, + ) + + node = dt_p2p.Node(**report_computed_task.node_info) + + self.reject_result( # type: ignore + report_computed_task.subtask_id, + node.key, + ) + + response_msg = message.tasks.SubtaskResultsRejected( + report_computed_task=report_computed_task, + reason=reason, + ) + msg_queue.put(node.key, response_msg) + + response_msg = msg_utils.copy_and_sign( + msg=response_msg, + private_key=self.keys_auth._private_key, # noqa pylint: disable=protected-access + ) + history.add( + response_msg, + node_id=report_computed_task.task_to_compute.provider_id, + local_role=model.Actor.Requestor, + remote_role=model.Actor.Provider, + ) diff --git a/golem/task/taskcomputer.py b/golem/task/taskcomputer.py index 8236324194..e9b688148d 100644 --- a/golem/task/taskcomputer.py +++ b/golem/task/taskcomputer.py @@ -132,7 +132,6 @@ def resource_failure(self, res_id, reason): 'Error downloading resources: {}'.format(reason), ) self.__task_finished(subtask) - self.session_closed() def task_computed(self, task_thread: TaskThread) -> None: if task_thread.end_time is None: @@ -317,12 +316,6 @@ def lock_config(self, on=True): for l in self.listeners: l.lock_config(on) - def session_timeout(self): - self.session_closed() - - def session_closed(self): - pass - def __request_task(self): if self.has_assigned_task(): return @@ -344,7 +337,7 @@ def __compute_task(self, subtask_id, docker_images, logger.warning("Subtask '%s' of task '%s' cannot be computed: " "task header has been unexpectedly removed", subtask_id, task_id) - return self.session_closed() + return deadline = min(task_header.deadline, subtask_deadline) task_timeout = deadline_to_timeout(deadline) diff --git a/golem/task/taskmanager.py b/golem/task/taskmanager.py index e6fd415d8c..30ab7f9435 100644 --- a/golem/task/taskmanager.py +++ b/golem/task/taskmanager.py @@ -18,6 +18,7 @@ from apps.core.task.coretask import CoreTask from apps.core.task.coretaskstate import TaskDefinition +from golem import model from golem.clientconfigdescriptor import ClientConfigDescriptor from golem.core.common import get_timestamp_utc, HandleForwardedError, \ HandleKeyError, node_info_str, short_node_id, to_unicode, update_dict @@ -396,8 +397,14 @@ def task_finished(self, task_id: str) -> bool: def task_needs_computation(self, task_id: str) -> bool: if self.task_finished(task_id): + task_status = self.tasks_states[task_id].status logger.info( - f'task is not active: {task_id}, status: {task_status}') + 'task is not active: %(task_id)s, status: %(task_status)s', + { + 'task_id': task_id, + 'task_status': task_status, + } + ) return False task = self.tasks[task_id] if not task.needs_computation(): @@ -730,6 +737,7 @@ def get_node_id_for_subtask(self, subtask_id): @handle_subtask_key_error def computed_task_received(self, subtask_id, result, verification_finished): + logger.debug("Computed task received. subtask_id=%s", subtask_id) task_id = self.subtask2task_mapping[subtask_id] subtask_state = self.tasks_states[task_id].subtask_states[subtask_id] @@ -747,6 +755,7 @@ def computed_task_received(self, subtask_id, result, @TaskManager.handle_generic_key_error def verification_finished_(): + logger.debug("Verification finished. subtask_id=%s", subtask_id) ss = self.__set_subtask_state_finished(subtask_id) if not self.tasks[task_id].verify_subtask(subtask_id): logger.debug("Subtask %r not accepted\n", subtask_id) @@ -999,6 +1008,17 @@ def query_task_state(self, task_id): return ts + def subtask_to_task( + self, + subtask_id: str, + local_role: model.Actor, + ) -> Optional[str]: + if local_role == model.Actor.Provider: + return self.comp_task_keeper.subtask_to_task.get(subtask_id) + elif local_role == model.Actor.Requestor: + return self.subtask2task_mapping.get(subtask_id) + return None + def get_subtasks(self, task_id) -> Optional[List[str]]: """ Get all subtasks related to given task id diff --git a/golem/task/taskserver.py b/golem/task/taskserver.py index 1020c42cfc..9c458b520b 100644 --- a/golem/task/taskserver.py +++ b/golem/task/taskserver.py @@ -5,7 +5,6 @@ import os import time import weakref -from collections import deque from enum import Enum from pathlib import Path from typing import ( @@ -18,7 +17,6 @@ from golem_messages import exceptions as msg_exceptions from golem_messages import message -from golem_messages.datastructures import p2p as dt_p2p from golem_messages.datastructures import tasks as dt_tasks from pydispatch import dispatcher from twisted.internet.defer import inlineCallbacks @@ -30,11 +28,13 @@ from golem.core.common import node_info_str, short_node_id from golem.environments.environment import SupportStatus, UnsupportReason from golem.marketplace import OfferPool +from golem.network.transport import msg_queue from golem.network.transport.network import ProtocolFactory, SessionFactory from golem.network.transport.tcpnetwork import ( TCPNetwork, SocketAddress, SafeProtocol) from golem.network.transport.tcpserver import ( - PendingConnectionsServer, PenConnStatus) + PendingConnectionsServer, +) from golem.ranking.helper.trust import Trust from golem.ranking.manager.database_manager import ( get_requestor_efficiency, @@ -53,10 +53,11 @@ from golem.task.taskstate import TaskOp from golem.utils import decode_hex -from .result.resultmanager import ExtractedPackage from .server import concent -from .server import queue as srv_queue +from .server import helpers +from .server import queue_ as srv_queue from .server import resources +from .server import verification as srv_verification from .taskcomputer import TaskComputer from .taskkeeper import TaskHeaderKeeper from .taskmanager import TaskManager @@ -83,6 +84,7 @@ class TaskServer( PendingConnectionsServer, resources.TaskResourcesMixin, srv_queue.TaskMessagesQueueMixin, + srv_verification.VerificationMixin, ): def __init__(self, node, @@ -126,11 +128,8 @@ def __init__(self, finished_cb=task_finished_cb) self.task_connections_helper = TaskConnectionsHelper() self.task_connections_helper.task_server = self - # Remove .task_sessions when Message Queue is implemented - # https://github.com/golemfactory/golem/issues/2223 - self.task_sessions = {} + self.sessions: Dict[str, TaskSession] = {} self.task_sessions_incoming: weakref.WeakSet = weakref.WeakSet() - self.task_sessions_outgoing: weakref.WeakSet = weakref.WeakSet() OfferPool.change_interval(self.config_desc.offer_pooling_interval) @@ -138,7 +137,6 @@ def __init__(self, self.min_trust = 0.0 self.last_messages = [] - self.last_message_time_threshold = config_desc.task_session_timeout self.results_to_send = {} self.failures_to_send = {} @@ -148,7 +146,6 @@ def __init__(self, self.forwarded_session_request_timeout = \ config_desc.waiting_for_task_session_timeout self.forwarded_session_requests = {} - self.response_list = {} self.acl = get_acl(Path(client.datadir), max_times=config_desc.disallow_id_max_times) self.acl_ip = DenyAcl([], max_times=config_desc.disallow_ip_max_times) @@ -180,36 +177,31 @@ def __init__(self, signal='golem.taskmanager' ) - @property - def all_sessions(self): - return frozenset( - self.task_sessions_outgoing | self.task_sessions_incoming, - ) - def sync_network(self, timeout=None): if timeout is None: - timeout = self.last_message_time_threshold + timeout = self.config_desc.task_session_timeout jobs = ( functools.partial( super().sync_network, timeout=timeout, ), self._sync_pending, - self.__send_waiting_results, + self._send_waiting_results, self.task_computer.run, self.task_connections_helper.sync, self._sync_forwarded_session_requests, self.__remove_old_tasks, - self.__remove_old_sessions, functools.partial( concent.process_messages_received_from_concent, concent_service=self.client.concent_service, ), + self.sweep_sessions, + self.connect_to_nodes, ) for job in jobs: try: - logger.debug("TServer sync running: job=%r", job) + #logger.debug("TServer sync running: job=%r", job) job() except Exception: # pylint: disable=broad-except logger.exception("TaskServer.sync_network job %r failed", job) @@ -326,7 +318,7 @@ def _request_task(self, theader: dt_tasks.TaskHeader) -> Optional[str]: provider_ethereum_public_key=self.get_key_id(), task_header=theader, ) - self.send_message( + msg_queue.put( node_id=theader.task_owner.key, msg=wtct, ) @@ -409,17 +401,27 @@ def send_task_failed( owner=header.task_owner) def new_connection(self, session): - if self.active: - self.task_sessions_incoming.add(session) - else: + if not self.active: session.disconnect(message.base.Disconnect.REASON.NoMoreMessages) + return + logger.debug( + 'Incoming TaskSession. address=%s:%d', + session.address, + session.port, + ) + self.task_sessions_incoming.add(session) def disconnect(self): - for task_session in self.all_sessions: + for node_id in list(self.sessions): try: + task_session = self.sessions[node_id] + if task_session is None: + # Pending connection + continue task_session.dropped() - except Exception as exc: - logger.error("Error closing incoming session: %s", exc) + del self.sessions[node_id] + except Exception as exc: # pylint: disable=broad-except + logger.error("Error closing session: %s", exc) def get_own_tasks_headers(self): return self.task_manager.get_tasks_headers() @@ -469,29 +471,12 @@ def remove_task_header(self, task_id) -> bool: self.requested_tasks.discard(task_id) return self.task_keeper.remove_task_header(task_id) - def add_task_session(self, subtask_id, session: TaskSession): - self.task_sessions[subtask_id] = session - - def remove_task_session(self, task_session: TaskSession): - self.remove_pending_conn(task_session.conn_id) - self.remove_responses(task_session.conn_id) - - for tsk in list(self.task_sessions.keys()): - if self.task_sessions[tsk] == task_session: - del self.task_sessions[tsk] - def set_last_message(self, type_, t, msg, address, port): if len(self.last_messages) >= 5: self.last_messages = self.last_messages[-4:] self.last_messages.append([type_, t, address, port, msg]) - def get_last_messages(self): - return self.last_messages - - def get_waiting_task_result(self, subtask_id): - return self.results_to_send.get(subtask_id, None) - def get_node_name(self): return self.config_desc.node_name @@ -518,7 +503,6 @@ def retry_sending_task_result(self, subtask_id): def change_config(self, config_desc, run_benchmarks=False): PendingConnectionsServer.change_config(self, config_desc) self.config_desc = config_desc - self.last_message_time_threshold = config_desc.task_session_timeout self.task_keeper.change_config(config_desc) return self.task_computer.change_config( config_desc, run_benchmarks=run_benchmarks) @@ -658,33 +642,6 @@ def reject_result(self, subtask_id, key_id): def get_computing_trust(self, node_id): return self.client.get_computing_trust(node_id) - def start_task_session(self, node_info, super_node_info, conn_id): - args = { - 'key_id': node_info.key, - 'node_info': node_info, - 'super_node_info': super_node_info, - 'ans_conn_id': conn_id - } - node = node_info - self._add_pending_request( - TASK_CONN_TYPES['start_session'], - node, - prv_port=node.prv_port, - pub_port=node.pub_port, - args=args - ) - - def respond_to(self, key_id, session, conn_id): - self.remove_pending_conn(conn_id) - responses = self.response_list.get(conn_id, None) - - if responses: - while responses: - res = responses.popleft() - res(session) - else: - session.dropped() - def get_socket_addresses(self, node_info, prv_port=None, pub_port=None): """ Change node info into tcp addresses. Adds a suggested address. :param Node node_info: node information @@ -718,20 +675,10 @@ def get_socket_addresses(self, node_info, prv_port=None, pub_port=None): def quit(self): self.task_computer.quit() - def remove_responses(self, conn_id): - self.response_list.pop(conn_id, None) - - def final_conn_failure(self, conn_id): - self.remove_responses(conn_id) - super(TaskServer, self).final_conn_failure(conn_id) - def add_forwarded_session_request(self, key_id, conn_id): self.forwarded_session_requests[key_id] = dict( conn_id=conn_id, time=time.time()) - def remove_forwarded_session_request(self, key_id): - return self.forwarded_session_requests.pop(key_id, None) - def get_min_performance_for_task(self, task: Task) -> float: env = self.get_environment_by_id(task.header.environment) return env.get_min_accepted_performance() @@ -865,8 +812,7 @@ def should_accept_requestor(self, node_id): logger.debug("Requesting trust level: %r", trust) if trust >= self.config_desc.requesting_trust: return SupportStatus.ok() - else: - return SupportStatus.err({UnsupportReason.REQUESTOR_TRUST: trust}) + return SupportStatus.err({UnsupportReason.REQUESTOR_TRUST: trust}) def disallow_node(self, node_id: str, timeout_seconds: int, persist: bool) \ -> None: @@ -879,13 +825,13 @@ def disallow_ip(self, ip: str, timeout_seconds: int) -> None: def _sync_forwarded_session_requests(self): now = time.time() for key_id, data in list(self.forwarded_session_requests.items()): - if data: - if now - data['time'] >= self.forwarded_session_request_timeout: - logger.debug('connection timeout: %s', data) - self.final_conn_failure(data['conn_id']) - self.remove_forwarded_session_request(key_id) - else: - self.forwarded_session_requests.pop(key_id) + if not data: + del self.forwarded_session_requests[key_id] + continue + if now - data['time'] >= self.forwarded_session_request_timeout: + logger.debug('connection timeout: %s', data) + del self.forwarded_session_requests[key_id] + self.final_conn_failure(data['conn_id']) def _get_factory(self): return self.factory(self) @@ -904,210 +850,6 @@ def _listening_failure(self, **kwargs): # sys.exit(0) ############################# - # CONNECTION REACTIONS # - ############################# - def __connection_for_task_request_established( - self, session: TaskSession, conn_id, node_name, key_id, task_id, - estimated_performance, price, max_resource_size, max_memory_size): - self.new_session_prepare( - session=session, - key_id=key_id, - conn_id=conn_id, - ) - session.send_hello() - session.request_task(node_name, task_id, estimated_performance, price, - max_resource_size, max_memory_size) - - def __connection_for_task_request_failure( - self, conn_id, node_name, key_id, task_id, estimated_performance, - price, max_resource_size, max_memory_size, *args): - def response(session): - return self.__connection_for_task_request_established( - session, conn_id, node_name, key_id, task_id, - estimated_performance, price, max_resource_size, - max_memory_size) - - if key_id in self.response_list: - self.response_list[conn_id].append(response) - else: - self.response_list[conn_id] = deque([response]) - - self.client.want_to_start_task_session(key_id, self.node, conn_id) - - pc = self.pending_connections.get(conn_id) - if pc: - pc.status = PenConnStatus.WaitingAlt - pc.time = time.time() - - def __connection_for_task_result_established(self, session, conn_id, - waiting_task_result): - self.new_session_prepare( - session=session, - key_id=waiting_task_result.owner.key, - conn_id=conn_id, - ) - - session.send_hello() - session.send_report_computed_task(waiting_task_result, - self.node.prv_addr, self.cur_port, - self.node) - - def __connection_for_task_result_failure(self, conn_id, - waiting_task_result): - def response(session): - self.__connection_for_task_result_established( - session, conn_id, waiting_task_result) - - if waiting_task_result.owner.key in self.response_list: - self.response_list[conn_id].append(response) - else: - self.response_list[conn_id] = deque([response]) - - self.client.want_to_start_task_session( - waiting_task_result.owner.key, self.node, conn_id) - - pc = self.pending_connections.get(conn_id) - if pc: - pc.status = PenConnStatus.WaitingAlt - pc.time = time.time() - - def __connection_for_task_failure_established(self, session, conn_id, - key_id, subtask_id, err_msg): - self.new_session_prepare( - session=session, - key_id=key_id, - conn_id=conn_id, - ) - session.send_hello() - session.send_task_failure(subtask_id, err_msg) - - def __connection_for_task_failure_failure(self, conn_id, key_id, - subtask_id, err_msg): - def response(session): - return self.__connection_for_task_failure_established( - session, conn_id, key_id, subtask_id, err_msg) - - if key_id in self.response_list: - self.response_list[conn_id].append(response) - else: - self.response_list[conn_id] = deque([response]) - - self.client.want_to_start_task_session(key_id, self.node, conn_id) - - pc = self.pending_connections.get(conn_id) - if pc: - pc.status = PenConnStatus.WaitingAlt - pc.time = time.time() - - def __connection_for_start_session_established( - self, session, conn_id, key_id, node_info, super_node_info, - ans_conn_id): - self.new_session_prepare( - session=session, - key_id=key_id, - conn_id=conn_id, - ) - session.send_hello() - session.send_start_session_response(ans_conn_id) - - def __connection_for_start_session_failure( - self, conn_id, key_id, node_info, super_node_info, ans_conn_id): - logger.info( - "Failed to start requested task session for node {}".format( - key_id)) - self.final_conn_failure(conn_id) - # self.__initiate_nat_traversal( - # key_id, node_info, super_node_info, ans_conn_id) - - def __connection_for_task_request_final_failure( - self, - conn_id, - node_name, - key_id, - task_id, - *_args, - **_kwargs, - ): - logger.info( - "Cannot connect to task owner. task_id: %s, node: %s", - task_id, - node_info_str( - node_name, - key_id, - ) - ) - self.task_keeper.remove_task_header(task_id) - self.task_manager.comp_task_keeper.request_failure(task_id) - self.remove_pending_conn(conn_id) - self.remove_responses(conn_id) - - def __connection_for_task_result_final_failure(self, conn_id, - waiting_task_result): - logger.info("Cannot connect to task {} owner".format( - waiting_task_result.subtask_id)) - - waiting_task_result.lastSendingTrial = time.time() - waiting_task_result.delayTime = \ - self.config_desc.max_results_sending_delay - waiting_task_result.alreadySending = False - self.remove_pending_conn(conn_id) - self.remove_responses(conn_id) - - def __connection_for_task_failure_final_failure(self, conn_id, key_id, - subtask_id, err_msg): - logger.info("Cannot connect to task {} owner".format(subtask_id)) - self.task_computer.session_timeout() - self.remove_pending_conn(conn_id) - self.remove_responses(conn_id) - - def __connection_for_start_session_final_failure( - self, conn_id, key_id, node_info, super_node_info, ans_conn_id): - logger.warning("Impossible to start session with {}".format(node_info)) - self.task_computer.session_timeout() - self.remove_pending_conn(conn_id) - self.remove_responses(conn_id) - self.remove_pending_conn(ans_conn_id) - self.remove_responses(ans_conn_id) - - def new_session_prepare(self, - session: TaskSession, - key_id: str, - conn_id: str): - self.remove_forwarded_session_request(key_id) - session.key_id = key_id - session.conn_id = conn_id - self._mark_connected(conn_id, session.address, session.port) - self.task_sessions_outgoing.add(session) - - @classmethod - def noop(cls, *args, **kwargs): - args_, kwargs_ = args, kwargs # avoid params name collision in logger - logger.debug('Noop(%r, %r)', args_, kwargs_) - - def __connection_for_task_verification_result_established( - self, - session: TaskSession, - conn_id, - extracted_package: ExtractedPackage, - key_id, - subtask_id: str): - - full_path_files = extracted_package.get_full_path_files() - self.new_session_prepare( - session=session, - key_id=key_id, - conn_id=conn_id, - ) - - session.send_hello() - session.result_received(subtask_id, full_path_files) - - def __connection_for_task_verification_result_failure( # noqa pylint:disable=no-self-use - self, _conn_id, _extracted_package, key_id, subtask_id: str): - logger.warning("Failed to establish a session to deliver " - "the verification result for %s to the provider %s", - subtask_id, key_id) - # SYNC METHODS ############################# def __remove_old_tasks(self): @@ -1117,135 +859,32 @@ def __remove_old_tasks(self): for node_id in nodes_with_timeouts: Trust.COMPUTED.decrease(node_id) - def __remove_old_sessions(self): - cur_time = time.time() - - for session in self.all_sessions: - dt = cur_time - session.last_message_time - if dt < self.last_message_time_threshold: - continue - if session.task_computer is not None: - session.task_computer.session_timeout() - session.dropped() - - def __send_waiting_results(self): + def _send_waiting_results(self): for subtask_id in list(self.results_to_send.keys()): - wtr = self.results_to_send[subtask_id] + wtr: WaitingTaskResult = self.results_to_send[subtask_id] now = time.time() if not wtr.already_sending: if now - wtr.last_sending_trial > wtr.delay_time: wtr.already_sending = True wtr.last_sending_trial = now - session = self.task_sessions.get(subtask_id, None) - if session: - self.__connection_for_task_result_established( - session, session.conn_id, wtr) - else: - args = {'waiting_task_result': wtr} - node = wtr.owner - self._add_pending_request( - TASK_CONN_TYPES['task_result'], - node, - prv_port=node.prv_port, - pub_port=node.pub_port, - args=args - ) - - for subtask_id in list(self.failures_to_send.keys()): - wtf = self.failures_to_send[subtask_id] - - session = self.task_sessions.get(subtask_id, None) - if session: - self.__connection_for_task_failure_established( - session, session.conn_id, wtf.owner.key, subtask_id, - wtf.err_msg) - else: - args = { - 'key_id': wtf.owner.key, - 'subtask_id': wtf.subtask_id, - 'err_msg': wtf.err_msg - } - node = wtf.owner - self._add_pending_request( - TASK_CONN_TYPES['task_failure'], - node, - prv_port=node.prv_port, - pub_port=node.pub_port, - args=args - ) + helpers.send_report_computed_task( + task_server=self, + waiting_task_result=wtr, + ) + for wtf in list(self.failures_to_send.values()): + helpers.send_task_failure( + waiting_task_failure=wtf, + ) self.failures_to_send.clear() - def verify_results( - self, - report_computed_task: message.tasks.ReportComputedTask, - extracted_package: ExtractedPackage) -> None: - - kwargs = { - 'extracted_package': extracted_package, - 'key_id': report_computed_task.key_id, - 'subtask_id': report_computed_task.subtask_id, - } - - node = dt_p2p.Node(**report_computed_task.node_info) - - self._add_pending_request( - TASK_CONN_TYPES['task_verification_result'], - node, - prv_port=node.prv_port, - pub_port=node.pub_port, - args=kwargs, - ) - # CONFIGURATION METHODS ############################# @staticmethod def __get_task_manager_root(datadir): return os.path.join(datadir, "ComputerRes") - def _set_conn_established(self): - self.conn_established_for_type.update({ - TASK_CONN_TYPES['task_request']: - self.__connection_for_task_request_established, - TASK_CONN_TYPES['task_result']: - self.__connection_for_task_result_established, - TASK_CONN_TYPES['task_failure']: - self.__connection_for_task_failure_established, - TASK_CONN_TYPES['start_session']: - self.__connection_for_start_session_established, - TASK_CONN_TYPES['task_verification_result']: - self.__connection_for_task_verification_result_established, - }) - - def _set_conn_failure(self): - self.conn_failure_for_type.update({ - TASK_CONN_TYPES['task_request']: - self.__connection_for_task_request_failure, - TASK_CONN_TYPES['task_result']: - self.__connection_for_task_result_failure, - TASK_CONN_TYPES['task_failure']: - self.__connection_for_task_failure_failure, - TASK_CONN_TYPES['start_session']: - self.__connection_for_start_session_failure, - TASK_CONN_TYPES['task_verification_result']: - self.__connection_for_task_verification_result_failure, - }) - - def _set_conn_final_failure(self): - self.conn_final_failure_for_type.update({ - TASK_CONN_TYPES['task_request']: - self.__connection_for_task_request_final_failure, - TASK_CONN_TYPES['task_result']: - self.__connection_for_task_result_final_failure, - TASK_CONN_TYPES['task_failure']: - self.__connection_for_task_failure_final_failure, - TASK_CONN_TYPES['start_session']: - self.__connection_for_start_session_final_failure, - TASK_CONN_TYPES['task_verification_result']: - self.__connection_for_task_verification_result_failure, - }) - # TODO: https://github.com/golemfactory/golem/issues/2633 # and remove linter switch offs @@ -1280,14 +919,3 @@ def __init__(self, task_id, subtask_id, err_msg, owner): self.subtask_id = subtask_id self.owner = owner self.err_msg = err_msg - - -# TODO: Get rid of archaic int labels and use plain strings instead. issue #2404 -TASK_CONN_TYPES = { - 'task_request': 1, - # unused: 'pay_for_task': 4, - 'task_result': 5, - 'task_failure': 6, - 'start_session': 7, - 'task_verification_result': 8, -} diff --git a/golem/task/tasksession.py b/golem/task/tasksession.py index 05a8ea064e..d7849fe787 100644 --- a/golem/task/tasksession.py +++ b/golem/task/tasksession.py @@ -1,11 +1,11 @@ # pylint: disable=too-many-lines -import copy import datetime import enum +import functools import logging import time -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from ethereum.utils import denoms from golem_messages import exceptions as msg_exceptions @@ -16,6 +16,7 @@ import golem from golem.core import common +from golem.core import golem_async from golem.core.keysauth import KeysAuth from golem.core import variables from golem.docker.environment import DockerEnvironment @@ -25,6 +26,7 @@ from golem.network import history from golem.network import nodeskeeper from golem.network.concent import helpers as concent_helpers +from golem.network.transport import msg_queue from golem.network.transport import tcpnetwork from golem.network.transport.session import BasicSafeSession from golem.ranking.manager.database_manager import ( @@ -49,12 +51,6 @@ def drop_after_attr_error(*args, **_): args[0].dropped() -def call_task_computer_and_drop_after_attr_error(*args, **_): - logger.warning("Attribute error occured(2)", exc_info=True) - args[0].task_computer.session_closed() - args[0].dropped() - - def get_task_message( message_class_name, node_id, @@ -81,20 +77,6 @@ def get_task_message( return msg -def copy_and_sign(msg: message.base.Message, private_key) \ - -> message.base.Message: - """Returns signed shallow copy of message - - Copy is made only if original is unsigned. - """ - if msg.sig is None: - # If message is delayed in msgs_to_send then will - # overcome this by making a signed copy - msg = copy.copy(msg) - msg.sign_message(private_key) - return msg - - def check_docker_images( ctd: message.ComputeTaskDef, env: DockerEnvironment, @@ -120,9 +102,6 @@ class TaskSession(BasicSafeSession, ResourceHandshakeSessionMixin): """ Session for Golem task network """ handle_attr_error = common.HandleAttributeError(drop_after_attr_error) - handle_attr_error_with_task_computer = common.HandleAttributeError( - call_task_computer_and_drop_after_attr_error - ) def __init__(self, conn): """ @@ -133,12 +112,10 @@ def __init__(self, conn): """ BasicSafeSession.__init__(self, conn) ResourceHandshakeSessionMixin.__init__(self) + # set in server.queue.msg_queue_connection_established() self.conn_id = None # connection id - # set in TaskServer.new_session_prepare() self.key_id: Optional[str] = None - # messages waiting to be send (because connection hasn't been - # verified yet) - self.msgs_to_send = [] + self.__set_msg_interpretations() @property @@ -157,6 +134,15 @@ def task_computer(self) -> 'TaskComputer': def concent_service(self): return self.task_server.client.concent_service + @property + def is_active(self) -> bool: + if not self.conn.opened: + return False + inactivity: float = time.time() - self.last_message_time + if inactivity > self.task_server.config_desc.task_session_timeout: + return False + return True + ######################## # BasicSession methods # ######################## @@ -180,7 +166,7 @@ def interpret(self, msg): def dropped(self): """ Close connection """ BasicSafeSession.dropped(self) - self.task_server.remove_task_session(self) + self.task_server.remove_session_by_node_id(self.key_id) ####################### # SafeSession methods # @@ -194,262 +180,44 @@ def my_private_key(self) -> bytes: def my_public_key(self) -> bytes: return self.task_server.keys_auth.public_key - ################################### - # IMessageHistoryProvider methods # - ################################### - - def _subtask_to_task(self, sid, local_role): - if local_role == Actor.Provider: - return self.task_manager.comp_task_keeper.subtask_to_task.get(sid) - elif local_role == Actor.Requestor: - return self.task_manager.subtask2task_mapping.get(sid) - return None + def verify_owners(self, msg, my_role) -> bool: + if self.concent_service.available: + concent_key = self.concent_service.variant['pubkey'] + else: + concent_key = None + if my_role is Actor.Provider: + requestor_key = msg_utils.decode_hex(self.key_id) + provider_key = self.task_server.keys_auth.ecc.raw_pubkey + else: + requestor_key = self.task_server.keys_auth.ecc.raw_pubkey + provider_key = msg_utils.decode_hex(self.key_id) + try: + msg.verify_owners( + requestor_public_key=requestor_key, + provider_public_key=provider_key, + concent_public_key=concent_key, + ) + except msg_exceptions.MessageError: + node_id = common.short_node_id(self.key_id) + logger.info( + 'Dropping invalid %(msg_class)s.' + ' sender_node_id: %(node_id)s, task_id: %(task_id)s,' + ' subtask_id: %(subtask_id)s', + { + 'msg_class': msg.__class__.__name__, + 'node_id': node_id, + 'task_id': msg.task_id, + 'subtask_id': msg.subtask_id, + }, + ) + logger.debug('Invalid message received', exc_info=True) + return False + return True ####################### # FileSession methods # ####################### - def result_received(self, subtask_id: str, result_files: List[str]): - """ Inform server about received result - """ - def send_verification_failure(): - self._reject_subtask_result( - subtask_id, - reason=message.tasks.SubtaskResultsRejected.REASON - .VerificationNegative - ) - - def verification_finished(): - logger.debug("Verification finished handler.") - if not self.task_manager.verify_subtask(subtask_id): - logger.debug("Verification failure. subtask_id=%r", subtask_id) - send_verification_failure() - self.dropped() - return - - task_id = self._subtask_to_task(subtask_id, Actor.Requestor) - - report_computed_task = get_task_message( - message_class_name='ReportComputedTask', - node_id=self.key_id, - task_id=task_id, - subtask_id=subtask_id - ) - task_to_compute = report_computed_task.task_to_compute - - # FIXME Remove in 0.20 - if not task_to_compute.sig: - task_to_compute.sign_message(self.my_private_key) - - config_desc = self.task_server.config_desc - if config_desc.disallow_node_timeout_seconds is not None: - # Experimental feature. Try to spread subtasks fairly amongst - # providers. - self.task_server.disallow_node( - node_id=task_to_compute.provider_id, - timeout_seconds=config_desc.disallow_node_timeout_seconds, - persist=False, - ) - if config_desc.disallow_ip_timeout_seconds is not None: - # Experimental feature. Try to spread subtasks fairly amongst - # providers. - self.task_server.disallow_ip( - ip=self.address, - timeout_seconds=config_desc.disallow_ip_timeout_seconds, - ) - - payment_processed_ts = self.task_server.accept_result( - subtask_id, - self.key_id, - task_to_compute.provider_ethereum_address, - task_to_compute.price, - ) - - response_msg = message.tasks.SubtaskResultsAccepted( - report_computed_task=report_computed_task, - payment_ts=payment_processed_ts, - ) - self.send(response_msg) - history.add( - copy_and_sign( - msg=response_msg, - private_key=self.my_private_key, - ), - node_id=task_to_compute.provider_id, - local_role=Actor.Requestor, - remote_role=Actor.Provider, - ) - self.dropped() - - self.task_manager.computed_task_received( - subtask_id, - result_files, - verification_finished - ) - - def _reject_subtask_result(self, subtask_id, reason): - logger.debug('_reject_subtask_result(%r, %r)', subtask_id, reason) - - self.task_server.reject_result(subtask_id, self.key_id) - self.send_result_rejected(subtask_id, reason) - - # TODO address, port and eth_account should be in node_info - # (or shouldn't be here at all). Issue #2403 - def send_report_computed_task( - self, - task_result, - address, - port, - node_info): - """ Send task results after finished computations - :param WaitingTaskResult task_result: finished computations result - with additional information - :param str address: task result owner address - :param int port: task result owner port - :param Node node_info: information about this node - :return: - """ - extra_data = [] - - node_name = self.task_server.get_node_name() - - task_to_compute = get_task_message( - message_class_name='TaskToCompute', - node_id=self.key_id, - task_id=task_result.task_id, - subtask_id=task_result.subtask_id - ) - - if not task_to_compute: - return - - client_options = self.task_server.get_share_options(task_result.task_id, - self.address) - - report_computed_task = message.tasks.ReportComputedTask( - task_to_compute=task_to_compute, - node_name=node_name, - address=address, - port=port, - key_id=self.task_server.get_key_id(), - node_info=node_info.to_dict(), - extra_data=extra_data, - size=task_result.result_size, - package_hash='sha1:' + task_result.package_sha1, - multihash=task_result.result_hash, - secret=task_result.result_secret, - options=client_options.__dict__, - ) - - self.send(report_computed_task) - report_computed_task = copy_and_sign( - msg=report_computed_task, - private_key=self.my_private_key, - ) - history.add( - msg=report_computed_task, - node_id=self.key_id, - local_role=Actor.Provider, - remote_role=Actor.Requestor, - ) - - # if the Concent is not available in the context of this subtask - # we can only assume that `ReportComputedTask` above reaches - # the Requestor safely - - if not task_to_compute.concent_enabled: - logger.debug( - "Concent not enabled for this task, " - "skipping `ForceReportComputedTask`. " - "task_id=%r, " - "subtask_id=%r, ", - task_to_compute.task_id, - task_to_compute.subtask_id, - ) - return - - # we're preparing the `ForceReportComputedTask` here and - # scheduling the dispatch of that message for later - # (with an implicit delay in the concent service's `submit` method). - # - # though, should we receive the acknowledgement for - # the `ReportComputedTask` sent above before the delay elapses, - # the `ForceReportComputedTask` message to the Concent will be - # cancelled and thus, never sent to the Concent. - - delayed_forcing_msg = message.concents.ForceReportComputedTask( - report_computed_task=report_computed_task, - result_hash='sha1:' + task_result.package_sha1 - ) - logger.debug('[CONCENT] ForceReport: %s', delayed_forcing_msg) - - self.concent_service.submit_task_message( - task_result.subtask_id, - delayed_forcing_msg, - ) - - def send_task_failure(self, subtask_id, err_msg): - """ Inform task owner that an error occurred during task computation - :param str subtask_id: - :param err_msg: error message that occurred during computation - """ - - task_id = self._subtask_to_task(subtask_id, Actor.Provider) - - task_to_compute = get_task_message( - message_class_name='TaskToCompute', - node_id=self.key_id, - task_id=task_id, - subtask_id=subtask_id - ) - - if not task_to_compute: - logger.warning("Could not retrieve TaskToCompute" - " for subtask_id: %s, task_id: %s", - subtask_id, task_id) - return - - self.send( - message.tasks.TaskFailure( - task_to_compute=task_to_compute, - err=err_msg - ) - ) - - def send_result_rejected(self, subtask_id, reason): - """ - Inform that result doesn't pass the verification or that - the verification was not possible - - :param str subtask_id: subtask that has wrong result - :param SubtaskResultsRejected.Reason reason: the rejection reason - """ - - task_id = self._subtask_to_task(subtask_id, Actor.Requestor) - - report_computed_task = get_task_message( - message_class_name='ReportComputedTask', - node_id=self.key_id, - task_id=task_id, - subtask_id=subtask_id - ) - - response_msg = message.tasks.SubtaskResultsRejected( - report_computed_task=report_computed_task, - reason=reason, - ) - self.send(response_msg) - response_msg = copy_and_sign( - msg=response_msg, - private_key=self.my_private_key, - ) - history.add( - response_msg, - node_id=report_computed_task.task_to_compute.provider_id, - local_role=Actor.Requestor, - remote_role=Actor.Provider, - ) - def send_hello(self): """ Send first hello message, that should begin the communication """ self.send( @@ -463,43 +231,44 @@ def send_hello(self): send_unverified=True ) - def send_start_session_response(self, conn_id): - """Inform that this session was started as an answer for a request - to start task session - :param uuid conn_id: connection id for reference - """ - self.send(message.tasks.StartSessionResponse(conn_id=conn_id)) + def read_msg_queue(self): + if not self.key_id: + return + if not self.verified: + return + for msg in msg_queue.get(self.key_id): + self.send(msg) ######################### # Reactions to messages # ######################### + def _cannot_assign_task(self, task_id, reason): + logger.debug("Cannot assign task: %r", reason) + self.send( + message.tasks.CannotAssignTask( + task_id=task_id, + reason=reason, + ), + ) + self.dropped() # pylint: disable=too-many-return-statements def _react_to_want_to_compute_task(self, msg): - def _cannot_assign(reason): - logger.debug("Cannot assign task: %r", reason) - self.send( - message.tasks.CannotAssignTask( - task_id=msg.task_id, - reason=reason, - ), - ) - self.dropped() reasons = message.tasks.CannotAssignTask.REASON if msg.concent_enabled and not self.concent_service.enabled: - _cannot_assign(reasons.ConcentDisabled) + self._cannot_assign_task(msg.task_id, reasons.ConcentDisabled) return if not self.task_manager.is_my_task(msg.task_id): - _cannot_assign(reasons.NotMyTask) + self._cannot_assign_task(msg.task_id, reasons.NotMyTask) return try: msg.task_header.verify(self.my_public_key) except msg_exceptions.InvalidSignature: - _cannot_assign(reasons.NotMyTask) + self._cannot_assign_task(msg.task_id, reasons.NotMyTask) return node_name_id = common.node_info_str(msg.node_name, self.key_id) @@ -538,7 +307,7 @@ def _cannot_assign(reason): ) if not task_server_ok: - _cannot_assign(reasons.NoMoreSubtasks) + self._cannot_assign_task(msg.task_id, reasons.NoMoreSubtasks) return if not self.task_manager.check_next_subtask( @@ -548,7 +317,7 @@ def _cannot_assign(reason): msg.task_id, node_name_id, ) - _cannot_assign(reasons.NoMoreSubtasks) + self._cannot_assign_task(msg.task_id, reasons.NoMoreSubtasks) return if self.task_manager.task_finished(msg.task_id): @@ -557,7 +326,7 @@ def _cannot_assign(reason): msg.task_id, node_name_id, ) - _cannot_assign(reasons.TaskFinished) + self._cannot_assign_task(msg.task_id, reasons.TaskFinished) return if self.task_manager.should_wait_for_node(msg.task_id, self.key_id): @@ -593,91 +362,104 @@ def _cannot_assign(reason): msg.task_id, node_name_id) return - def _offer_chosen(is_chosen: bool) -> None: - if not self.conn.opened: - logger.info( - "Provider disconnected. task_id=%r, node=%r", - msg.task_id, - node_name_id, - ) - return - - if not is_chosen: - logger.info( - "Provider not chosen by marketplace. task_id=%r, node=%r", - msg.task_id, - node_name_id, - ) - _cannot_assign(reasons.NoMoreSubtasks) - return - - logger.info("Offer confirmed, assigning subtask") - ctd = self.task_manager.get_next_subtask( - self.key_id, msg.node_name, msg.task_id, msg.perf_index, - msg.price, msg.max_resource_size, msg.max_memory_size, - self.address) - - ctd["resources"] = self.task_server.get_resources(msg.task_id) - logger.debug( - "CTD generated. task_id=%s, node=%s ctd=%s", - msg.task_id, - node_name_id, - ctd, - ) + task = self.task_manager.tasks[msg.task_id] + offer = Offer( + scaled_price=scale_price(task.header.max_price, msg.price), + reputation=get_provider_efficiency(self.key_id), + quality=get_provider_efficacy(self.key_id).vector, + ) - if ctd is None: - _cannot_assign(reasons.NoMoreSubtasks) - return + logger.debug( + "Offer accepted & added to pool. offer=%s", + offer, + ) + d = OfferPool.add(msg.task_id, offer) + d.addCallback( + functools.partial( + self._offer_chosen, + msg=msg, + node_id=self.key_id, + ), + ) + # Adding errback won't be needed in asyncio + d.addErrback(golem_async.default_errback) + def _offer_chosen( + self, + is_chosen: bool, + msg: message.tasks.WantToComputeTask, + node_id: str, + ) -> None: + node_name_id = common.node_info_str(msg.node_name, node_id) + reasons = message.tasks.CannotAssignTask.REASON + if not is_chosen: logger.info( - "Subtask assigned. task_id=%r, node=%s, subtask_id=%r", + "Provider not chosen by marketplace. task_id=%r, node=%r", msg.task_id, node_name_id, - ctd["subtask_id"], - ) - task = self.task_manager.tasks[ctd['task_id']] - task_state = self.task_manager.tasks_states[ctd['task_id']] - price = taskkeeper.compute_subtask_value( - msg.price, - task.header.subtask_timeout, - ) - ttc = message.tasks.TaskToCompute( - compute_task_def=ctd, - want_to_compute_task=msg, - requestor_id=task.header.task_owner.key, - requestor_public_key=task.header.task_owner.key, - requestor_ethereum_public_key=task.header.task_owner.key, - provider_id=self.key_id, - package_hash='sha1:' + task_state.package_hash, - concent_enabled=msg.concent_enabled, - price=price, - size=task_state.package_size, - resources_options=self.task_server.get_share_options( - ctd['task_id'], self.address).__dict__ - ) - ttc.generate_ethsig(self.my_private_key) - self.send(ttc) - history.add( - msg=copy_and_sign( - msg=ttc, - private_key=self.my_private_key, - ), - node_id=self.key_id, - local_role=Actor.Requestor, - remote_role=Actor.Provider, ) + self._cannot_assign_task(msg.task_id, reasons.NoMoreSubtasks) + return - task = self.task_manager.tasks[msg.task_id] - offer = Offer( - scaled_price=scale_price(task.header.max_price, msg.price), - reputation=get_provider_efficiency(self.key_id), - quality=get_provider_efficacy(self.key_id).vector, + logger.info("Offer confirmed, assigning subtask") + ctd = self.task_manager.get_next_subtask( + self.key_id, msg.node_name, msg.task_id, msg.perf_index, + msg.price, msg.max_resource_size, msg.max_memory_size, + self.address) + + logger.debug( + "CTD generated. task_id=%s, node=%s ctd=%s", + msg.task_id, + node_name_id, + ctd, ) - OfferPool.add(msg.task_id, offer).addCallback(_offer_chosen) + if ctd is None: + self._cannot_assign_task(msg.task_id, reasons.NoMoreSubtasks) + return + + ctd["resources"] = self.task_server.get_resources(msg.task_id) + + logger.info( + "Subtask assigned. task_id=%r, node=%s, subtask_id=%r", + msg.task_id, + node_name_id, + ctd["subtask_id"], + ) + task = self.task_manager.tasks[ctd['task_id']] + task_state = self.task_manager.tasks_states[ctd['task_id']] + price = taskkeeper.compute_subtask_value( + msg.price, + task.header.subtask_timeout, + ) + ttc = message.tasks.TaskToCompute( + compute_task_def=ctd, + want_to_compute_task=msg, + requestor_id=task.header.task_owner.key, + requestor_public_key=task.header.task_owner.key, + requestor_ethereum_public_key=task.header.task_owner.key, + provider_id=self.key_id, + package_hash='sha1:' + task_state.package_hash, + concent_enabled=msg.concent_enabled, + price=price, + size=task_state.package_size, + resources_options=self.task_server.get_share_options( + ctd['task_id'], self.address).__dict__ + ) + ttc.generate_ethsig(self.my_private_key) + self.send(ttc) + history.add( + msg=msg_utils.copy_and_sign( + msg=ttc, + private_key=self.my_private_key, + ), + node_id=self.key_id, + local_role=Actor.Requestor, + remote_role=Actor.Provider, + ) # pylint: disable=too-many-return-statements - @handle_attr_error_with_task_computer + @handle_attr_error @history.provider_history def _react_to_task_to_compute(self, msg): ctd: Optional[message.tasks.ComputeTaskDef] = msg.compute_task_def @@ -685,7 +467,6 @@ def _react_to_task_to_compute(self, msg): if ctd is None or want_to_compute_task is None: logger.debug( 'TaskToCompute without ctd or want_to_compute_task: %r', msg) - self.task_computer.session_closed() self.dropped() return @@ -696,7 +477,6 @@ def _react_to_task_to_compute(self, msg): logger.debug( 'WantToComputeTask attached to TaskToCompute is not signed ' 'with key: %r.', want_to_compute_task.provider_public_key) - self.task_computer.session_closed() self.dropped() return @@ -714,7 +494,6 @@ def _cannot_compute(reason): reason=reason, ), ) - self.task_computer.session_closed() self.dropped() reasons = message.tasks.CannotComputeTask.REASON @@ -784,9 +563,6 @@ def _cannot_compute(reason): return self.task_manager.comp_task_keeper.receive_subtask(msg) - self.task_server.add_task_session( - ctd['subtask_id'], self - ) if not self.task_server.task_given(self.key_id, ctd, msg.price): _cannot_compute(None) return @@ -804,37 +580,12 @@ def _react_to_waiting_for_results( self, msg: message.tasks.WaitingForResults, ): - if self.concent_service.available: - concent_key = self.concent_service.variant['pubkey'] - else: - concent_key = None - try: - msg.verify_owners( - requestor_public_key=msg_utils.decode_hex(self.key_id), - provider_public_key=self.task_server.keys_auth.ecc.raw_pubkey, - concent_public_key=concent_key, - ) - except msg_exceptions.MessageError: - node_id = common.short_node_id(self.key_id) - logger.info( - 'Dropping invalid WaitingForResults.' - ' sender_node_id: %(node_id)s, task_id: %(task_id)s,' - ' subtask_id: %(subtask_id)s', - { - 'node_id': node_id, - 'task_id': msg.task_id, - 'subtask_id': msg.subtask_id, - }, - ) - logger.debug('Invalid WaitingForResults received', exc_info=True) + if not self.verify_owners(msg, my_role=Actor.Provider): return self.task_server.subtask_waiting( task_id=msg.task_id, subtask_id=msg.subtask_id, ) - self.task_computer.session_closed() - if not self.msgs_to_send: - self.disconnect(message.base.Disconnect.REASON.NoMoreMessages) def _react_to_cannot_compute_task(self, msg): if self.check_provider_for_subtask(msg.subtask_id): @@ -866,15 +617,19 @@ def _react_to_cannot_assign_task(self, msg): msg.task_id, msg.reason, ) + self.task_server.requested_tasks.discard(msg.task_id) reasons = message.tasks.CannotAssignTask.REASON if msg.reason is reasons.TaskFinished: + # Requestor doesn't want us to ask again self.task_server.remove_task_header(msg.task_id) self.task_manager.comp_task_keeper.request_failure(msg.task_id) - self.task_computer.session_closed() self.dropped() @history.requestor_history def _react_to_report_computed_task(self, msg): + if not self.verify_owners(msg, my_role=Actor.Requestor): + return + subtask_id = msg.subtask_id if not self.check_provider_for_subtask(subtask_id): self.dropped() @@ -885,29 +640,21 @@ def _react_to_report_computed_task(self, msg): self.dropped() return - self.task_server.add_task_session( - msg.subtask_id, self - ) - returned_msg = concent_helpers.process_report_computed_task( msg=msg, ecc=self.task_server.keys_auth.ecc, ) self.send(returned_msg) if not isinstance(returned_msg, message.tasks.AckReportComputedTask): - self.dropped() return - def after_success(): - self.disconnect(message.base.Disconnect.REASON.NoMoreMessages) - def after_error(): if msg.task_to_compute.concent_enabled: return # in case of resources failure, if we're not using the Concent # we're immediately sending a rejection message to the Provider - self._reject_subtask_result( - subtask_id, + self.task_server.send_result_rejected( + report_computed_task=msg, reason=message.tasks.SubtaskResultsRejected.REASON .ResourcesFailure, ) @@ -915,12 +662,10 @@ def after_error(): subtask_id, 'Error downloading task result' ) - self.dropped() task_server_helpers.computed_task_reported( task_server=self.task_server, report_computed_task=msg, - after_success=after_success, after_error=after_error, ) @@ -1036,11 +781,26 @@ def _react_to_task_failure(self, msg): def _react_to_hello(self, msg): if not self.conn.opened: + logger.info("Hello received after connection closed. msg=%s", msg) return send_hello = False if self.key_id is None: self.key_id = msg.client_key_id + try: + existing_session = self.task_server.sessions[self.key_id] + except KeyError: + self.task_server.sessions[self.key_id] = self + else: + if (existing_session is not None)\ + and existing_session is not self: + node_name = getattr(msg.node_info, 'node_name', '') + logger.debug( + 'Duplicated session. Dropping. node=%s', + common.node_info_str(node_name, self.key_id), + ) + self.dropped() + return send_hello = True if (msg.proto_id != variables.PROTOCOL_CONST.ID)\ @@ -1083,17 +843,17 @@ def _react_to_rand_val(self, msg): if self.key_id is None: return - if self.rand_val == msg.rand_val: - self.verified = True - self.task_server.verified_conn(self.conn_id, ) - for msg_ in self.msgs_to_send: - self.send(msg_) - self.msgs_to_send = [] - else: + if self.rand_val != msg.rand_val: self.disconnect(message.base.Disconnect.REASON.Unverified) + self.verified = True + self.task_server.verified_conn(self.conn_id, ) + self.read_msg_queue() + def _react_to_start_session_response(self, msg): - self.task_server.respond_to(self.key_id, self, msg.conn_id) + raise NotImplementedError( + "Implement reversed task session request #4005", + ) @history.provider_history def _react_to_ack_report_computed_task(self, msg): @@ -1151,9 +911,22 @@ def _react_to_reject_report_computed_task(self, msg): "an unknown task (subtask_id='%s')", self.key_id, msg.subtask_id) + def disconnect(self, reason: message.base.Disconnect.REASON): + if not self.conn.opened: + return + if not (self.verified and self.key_id): + self.dropped() + return + super().disconnect(reason) + def send(self, msg, send_unverified=False): + if self.key_id and not self.conn.opened: + msg_queue.put(self.key_id, msg) + return if not self.verified and not send_unverified: - self.msgs_to_send.append(msg) + if not self.key_id: + raise RuntimeError('Connection unverified') + msg_queue.put(self.key_id, msg) return BasicSafeSession.send(self, msg, send_unverified=send_unverified) self.task_server.set_last_message( @@ -1167,7 +940,7 @@ def send(self, msg, send_unverified=False): def check_provider_for_subtask(self, subtask_id) -> bool: node_id = self.task_manager.get_node_id_for_subtask(subtask_id) if node_id != self.key_id: - logger.warning('Received message about subtask %r from diferrent ' + logger.warning('Received message about subtask %r from different ' 'node %r than expected %r', subtask_id, self.key_id, node_id) return False diff --git a/golem/testutils.py b/golem/testutils.py index f5adb9e45f..d22105debf 100644 --- a/golem/testutils.py +++ b/golem/testutils.py @@ -66,7 +66,7 @@ def tearDown(self): except OSError as e: logger.debug("%r", e, exc_info=True) tree = '' - for path, dirs, files in os.walk(self.path): + for path, _dirs, files in os.walk(self.path): tree += path + '\n' for f in files: tree += f + '\n' @@ -108,7 +108,7 @@ def additional_dir_content(self, file_num_list, dir_=None, results=None, results = [] for el in file_num_list: if isinstance(el, int): - for i in range(el): + for _ in range(el): t = tempfile.NamedTemporaryFile(dir=dir_, delete=False) results.append(t.name) else: diff --git a/requirements.txt b/requirements.txt index 1899d66120..2a134b8de5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ eth-keys==0.2.0b3 eth-tester==0.1.0-beta.24 eth-utils==1.0.3 ethereum==1.6.1 -Golem-Messages==3.3.0 +Golem-Messages==3.3.1 Golem-Smart-Contracts-Interface==1.7.0 greenlet==0.4.15 h2==3.0.1 diff --git a/requirements_to-freeze.txt b/requirements_to-freeze.txt index a746d49a83..65023525d7 100644 --- a/requirements_to-freeze.txt +++ b/requirements_to-freeze.txt @@ -15,7 +15,7 @@ docker==3.5.0 enforce==0.3.4 eth-utils==1.0.3 ethereum==1.6.1 -Golem-Messages==3.3.0 +Golem-Messages==3.3.1 Golem-Smart-Contracts-Interface==1.7.0 html2text==2018.1.9 humanize==0.5.1 diff --git a/tests/factories/taskserver.py b/tests/factories/taskserver.py index 05978584ae..84174892b1 100644 --- a/tests/factories/taskserver.py +++ b/tests/factories/taskserver.py @@ -51,3 +51,15 @@ def xsubtask_id( # pylint: disable=no-self-argument ): value = extracted or idgenerator.generate_id_from_hex(wtr.owner.key) # noqa pylint: disable=no-member wtr.subtask_id = value + + +class WaitingTaskFailureFactory(factory.Factory): + class Meta: + model = taskserver.WaitingTaskFailure + + task_id = factory.Faker('uuid4') + subtask_id = factory.Faker('uuid4') + err_msg = factory.Faker('text') + owner = factory.SubFactory( + 'golem_messages.factories.datastructures.p2p.Node', + ) diff --git a/tests/golem/network/test_nodeskeeper.py b/tests/golem/network/test_nodeskeeper.py new file mode 100644 index 0000000000..ec10dcc66a --- /dev/null +++ b/tests/golem/network/test_nodeskeeper.py @@ -0,0 +1,15 @@ +from golem_messages.factories.datastructures import p2p as dt_p2p_factory +from golem import testutils +from golem.network import nodeskeeper + +class TestNodesKeeper(testutils.DatabaseFixture): + def setUp(self): + super().setUp() + self.node = dt_p2p_factory.Node() + + def test_get(self): + nodeskeeper.store(self.node) + self.assertEqual( + self.node, + nodeskeeper.get(self.node.key), + ) diff --git a/tests/golem/task/dummy/runner.py b/tests/golem/task/dummy/runner.py index 7cced1fc80..f062459971 100644 --- a/tests/golem/task/dummy/runner.py +++ b/tests/golem/task/dummy/runner.py @@ -116,6 +116,7 @@ def _make_mock_ets(): ets.eth_base_for_batch_payment.return_value = 0.001 * denoms.ether ets.get_payment_address.return_value = '0x' + 40 * '6' ets.get_nodes_with_overdue_payments.return_value = [] + ets.add_payment_info.return_value = int(time.time()) return ets @@ -148,7 +149,11 @@ def shutdown(): start_time = time.time() report("Starting in {}".format(datadir)) from golem.core.common import config_logging - config_logging(datadir=datadir, loglevel="DEBUG") + config_logging( + datadir=datadir, + loglevel="DEBUG", + formatter_prefix="RQ ", + ) client = create_client(datadir, '[Requestor] DUMMY') client.are_terms_accepted = lambda: True @@ -204,7 +209,11 @@ def shutdown(): start_time = time.time() report("Starting in {}".format(datadir)) from golem.core.common import config_logging - config_logging(datadir=datadir, loglevel="DEBUG") + config_logging( + datadir=datadir, + loglevel="DEBUG", + formatter_prefix=f"P{provider_id} ", + ) client = create_client(datadir, f'[Provider{ provider_id }] DUMMY') client.are_terms_accepted = lambda: True diff --git a/tests/golem/task/dummy/task.py b/tests/golem/task/dummy/task.py index 062072ad40..afd0ca38e2 100644 --- a/tests/golem/task/dummy/task.py +++ b/tests/golem/task/dummy/task.py @@ -12,7 +12,7 @@ import golem from golem.appconfig import MIN_PRICE -from golem.core.common import timeout_to_deadline, get_timestamp_utc +from golem.core import common from golem.task.taskbase import Task, AcceptClientVerdict @@ -42,6 +42,10 @@ def __init__(self, shared_data_size, subtask_data_size, result_size, self.result_size = result_size self.difficulty = difficulty + def __str__(self): + import pprint + return pprint.pformat(self.__dict__) + # pylint: disable=too-many-locals class DummyTask(Task): @@ -79,13 +83,13 @@ def __init__(self, client_id, params, num_subtasks, public_key): task_id=task_id, task_owner=task_owner, environment=environment, - deadline=timeout_to_deadline(14400), + deadline=common.timeout_to_deadline(14400), subtask_timeout=1200, subtasks_count=num_subtasks, estimated_memory=0, max_price=MIN_PRICE, min_version=golem.__version__, - timestamp=int(get_timestamp_utc()), + timestamp=int(common.get_timestamp_utc()), ) # load the script to be run remotely from the file in the current dir @@ -113,8 +117,12 @@ def __init__(self, client_id, params, num_subtasks, public_key): self.subtask_results = {} self.assigned_nodes = {} self.assigned_subtasks = {} - self.total_tasks = 1 self._lock = Lock() + print( + "Task created." + f" num_subtasks={num_subtasks}" + f" params={params}" + ) def __setstate__(self, state): super(DummyTask, self).__setstate__(state) @@ -174,6 +182,11 @@ def query_extra_data(self, perf_index: float, # assign a task self.assigned_nodes[node_id] = subtask_id self.assigned_subtasks[subtask_id] = node_id + print( + "Subtask assigned" + f" subtask_id={subtask_id}" + f" node_id={common.short_node_id(node_id)}" + ) # create subtask-specific data, 4 bits go for one char (hex digit) data = random.getrandbits(self.task_params.subtask_data_size * 4) @@ -183,7 +196,7 @@ def query_extra_data(self, perf_index: float, subtask_def = ComputeTaskDef() subtask_def['task_id'] = self.task_id subtask_def['subtask_id'] = subtask_id - subtask_def['deadline'] = timeout_to_deadline(5 * 60) + subtask_def['deadline'] = common.timeout_to_deadline(5 * 60) subtask_def['extra_data'] = { 'data_file': self.shared_data_file, 'subtask_data': self.subtask_data[subtask_id], @@ -199,7 +212,12 @@ def verify_task(self): # Check if self.subtask_results contains a non None result # for each subtack. if not len(self.subtask_results) == self.subtasks_count: + print( + "Results vs Count: " + f"{len(self.subtask_results)} != {self.subtasks_count}", + ) return False + print(f"subtask results: {self.subtask_results}") return all(self.subtask_results.values()) def verify_subtask(self, subtask_id): @@ -221,6 +239,11 @@ def verify_subtask(self, subtask_id): def computation_finished(self, subtask_id, task_result, verification_finished=None): + print( + "Computation finished" + f" subtask_id: {subtask_id}" + f" task_result: {task_result}" + ) with self._lock: if subtask_id in self.assigned_subtasks: node_id = self.assigned_subtasks.pop(subtask_id, None) @@ -231,6 +254,8 @@ def computation_finished(self, subtask_id, task_result, if not self.verify_subtask(subtask_id): self.subtask_results[subtask_id] = None + if verification_finished is not None: + verification_finished() def get_resources(self): return self.task_resources @@ -255,7 +280,10 @@ def abort(self): print('DummyTask.abort called') def update_task_state(self, task_state): - print('DummyTask.update_task_state called') + print( + 'DummyTask.update_task_state called' + f" task_state={task_state}" + ) def get_active_tasks(self): return self.assigned_subtasks @@ -287,7 +315,8 @@ def get_finishing_subtasks(self, node_id): return [] def accept_client(self, node_id): - print('DummyTask.accept_client called node_id=%r ' - '- WIP: move more responsibilities from query_extra_data', - node_id) + print( + "DummyTask.accept_client called" + f" node_id={common.short_node_id(node_id)}" + ) return diff --git a/tests/golem/task/dummy/test_runner_script.py b/tests/golem/task/dummy/test_runner_script.py index 7129f19d94..5a5bdb7331 100644 --- a/tests/golem/task/dummy/test_runner_script.py +++ b/tests/golem/task/dummy/test_runner_script.py @@ -102,6 +102,7 @@ def test_run_computing_node(self, mock_config_logging, mock_reactor, _): mock_config_logging.assert_called_once_with( datadir=mock.ANY, loglevel='DEBUG', + formatter_prefix="Ppid ", ) client.quit() diff --git a/tests/golem/task/server/test_helpers.py b/tests/golem/task/server/test_helpers.py new file mode 100644 index 0000000000..14dc5aa0f6 --- /dev/null +++ b/tests/golem/task/server/test_helpers.py @@ -0,0 +1,161 @@ +import pathlib +import time +import unittest +from unittest import mock + +import faker +from golem_messages import message +from golem_messages import factories as msg_factories + +from golem import model +from golem import testutils +from golem.core import keysauth +from golem.network.hyperdrive.client import HyperdriveClientOptions +from golem.task.server import helpers +from tests import factories + +fake = faker.Faker() + + +@mock.patch( + 'golem.network.history.MessageHistoryService.get_sync_as_message', +) +@mock.patch("golem.network.transport.msg_queue.put") +class TestSendReportComputedTask(testutils.TempDirFixture): + def setUp(self): + super().setUp() + self.wtr = factories.taskserver.WaitingTaskResultFactory() + + self.task_server = mock.MagicMock() + self.task_server.cur_port = 31337 + self.task_server.node = msg_factories.datastructures.p2p.Node( + node_name=fake.name(), + ) + self.task_server.get_key_id.return_value = 'key id' + self.task_server.keys_auth = keysauth.KeysAuth( + self.path, + 'filename', + '', + ) + self.task_server.get_share_options.return_value =\ + HyperdriveClientOptions( + "CLI id", + 0.3, + ) + + self.ttc = msg_factories.tasks.TaskToComputeFactory( + task_id=self.wtr.task_id, + subtask_id=self.wtr.subtask_id, + compute_task_def__deadline=int(time.time()) + 3600, + ) + + def assert_submit_task_message(self, subtask_id, wtr): + submit_mock = self.task_server.client.concent_service\ + .submit_task_message + submit_mock.assert_called_once_with( + subtask_id, + mock.ANY, + ) + + msg = submit_mock.call_args[0][1] + self.assertEqual(msg.result_hash, 'sha1:' + wtr.package_sha1) + + @mock.patch( + 'golem.network.history.add', + ) + def test_basic(self, add_mock, put_mock, get_mock, *_): + get_mock.return_value = self.ttc + helpers.send_report_computed_task( + self.task_server, + self.wtr, + ) + + put_mock.assert_called_once() + node_id, rct = put_mock.call_args[0] + + self.assertEqual(node_id, self.wtr.owner.key) + self.assertIsInstance(rct, message.tasks.ReportComputedTask) + self.assertEqual(rct.subtask_id, self.wtr.subtask_id) + self.assertEqual(rct.node_name, self.task_server.node.node_name) + self.assertEqual(rct.address, self.task_server.node.prv_addr) + self.assertEqual(rct.port, self.task_server.cur_port) + self.assertEqual(rct.extra_data, []) + self.assertEqual(rct.node_info, self.task_server.node.to_dict()) + self.assertEqual(rct.package_hash, 'sha1:' + self.wtr.package_sha1) + self.assertEqual(rct.multihash, self.wtr.result_hash) + self.assertEqual(rct.secret, self.wtr.result_secret) + + add_mock.assert_called_once_with( + msg=mock.ANY, + node_id=self.wtr.owner.key, + local_role=model.Actor.Provider, + remote_role=model.Actor.Requestor, + ) + + def test_concent_no_message(self, _put_mock, get_mock, *_): + get_mock.return_value = self.ttc + helpers.send_report_computed_task( + self.task_server, + self.wtr, + ) + self.task_server.concent_service.submit.assert_not_called() + + def test_concent_success(self, _put_mock, get_mock, *_): + self.ttc.concent_enabled = True + get_mock.return_value = self.ttc + helpers.send_report_computed_task( + self.task_server, + self.wtr, + ) + self.assert_submit_task_message(self.wtr.subtask_id, self.wtr) + + def test_concent_success_many_files(self, _put_mock, get_mock, *_): + result = [] + for i in range(100, 300, 99): + p = pathlib.Path(self.tempdir) / str(i) + with p.open('wb') as f: + f.write(b'\0' * i * 2 ** 20) + result.append(str(p)) + self.wtr.result = result + self.ttc.concent_enabled = True + get_mock.return_value = self.ttc + helpers.send_report_computed_task( + self.task_server, + self.wtr, + ) + + self.assert_submit_task_message(self.wtr.subtask_id, self.wtr) + + def test_concent_disabled(self, _put_mock, get_mock, *_): + self.ttc.concent_enabled = False + get_mock.return_value = self.ttc + helpers.send_report_computed_task( + self.task_server, + self.wtr, + ) + self.task_server.client.concent_service.submit.assert_not_called() + + +@mock.patch( + 'golem.network.history.MessageHistoryService.get_sync_as_message', +) +@mock.patch("golem.network.transport.msg_queue.put") +class TestSendTaskFailure(unittest.TestCase): + def setUp(self): + super().setUp() + self.wtf = factories.taskserver.WaitingTaskFailureFactory() + self.ttc = msg_factories.tasks.TaskToComputeFactory( + task_id=self.wtf.task_id, + subtask_id=self.wtf.subtask_id, + compute_task_def__deadline=int(time.time()) + 3600, + ) + + def test_no_task_to_compute(self, put_mock, get_mock, *_): + get_mock.return_value = None + helpers.send_task_failure(self.wtf) + put_mock.assert_not_called() + + def test_basic(self, put_mock, get_mock, *_): + get_mock.return_value = self.ttc + helpers.send_task_failure(self.wtf) + put_mock.assert_called_once() diff --git a/tests/golem/task/server/test_queue.py b/tests/golem/task/server/test_queue.py index a09063a1ca..4c57a09ce5 100644 --- a/tests/golem/task/server/test_queue.py +++ b/tests/golem/task/server/test_queue.py @@ -2,14 +2,15 @@ from unittest import mock import uuid +from freezegun import freeze_time from golem_messages.factories import tasks as tasks_factories from golem import testutils -from golem.network import nodeskeeper +from golem.network.transport import tcpserver from golem.task import taskkeeper -from golem.task.server import queue as srv_queue +from golem.task.server import queue_ as srv_queue -class TestTaskResourcesMixin( +class TestTaskQueueMixin( testutils.DatabaseFixture, testutils.TestWithClient, ): @@ -17,6 +18,7 @@ def setUp(self): super().setUp() self.server = srv_queue.TaskMessagesQueueMixin() self.server._add_pending_request = mock.MagicMock() + self.server._mark_connected = mock.MagicMock() self.server.task_manager = self.client.task_manager self.server.client = self.client self.server.task_keeper = taskkeeper.TaskHeaderKeeper( @@ -24,11 +26,9 @@ def setUp(self): node=self.client.node, min_price=0 ) - self.server.new_session_prepare = mock.MagicMock() self.server.remove_pending_conn = mock.MagicMock() - self.server.remove_responses = mock.MagicMock() - self.server.response_list = {} self.server.pending_connections = {} + self.server.forwarded_session_requests = {} self.message = tasks_factories.ReportComputedTaskFactory() self.node_id = self.message.task_to_compute.want_to_compute_task\ @@ -36,53 +36,25 @@ def setUp(self): self.session = mock.MagicMock() self.conn_id = str(uuid.uuid4()) - @mock.patch("golem.network.transport.msg_queue.put") - def test_send_message(self, mock_put, *_): - nodeskeeper.store( - self.message.task_to_compute.want_to_compute_task.task_header\ - .task_owner, - ) - self.server.send_message( - node_id=self.node_id, - msg=self.message, - ) - mock_put.assert_called_once_with( - self.node_id, - self.message, - ) - - @mock.patch("golem.network.transport.msg_queue.get") - def test_conn_established(self, mock_get, *_): - mock_get.return_value = [self.message, ] + def test_conn_established(self, *_): self.server.msg_queue_connection_established( self.session, self.conn_id, self.node_id, ) - self.server.new_session_prepare.assert_called_once_with( - session=self.session, - key_id=self.node_id, - conn_id=self.conn_id, - ) + self.assertEqual(self.node_id, self.session.key_id) + self.assertEqual(self.conn_id, self.session.conn_id) self.session.send_hello.assert_called_once_with() - mock_get.assert_called_once_with(self.node_id) - self.session.send.assert_called_once_with(self.message) - @mock.patch( - "golem.task.server.queue.TaskMessagesQueueMixin" - ".msg_queue_connection_established", - ) - def test_conn_failure(self, mock_established, *_): + @freeze_time('2019-04-15 11:15:00') + def test_conn_failure(self, *_): + pc = self.server.pending_connections[self.conn_id] = mock.MagicMock() self.server.msg_queue_connection_failure( self.conn_id, node_id=self.node_id, ) - self.server.response_list[self.conn_id][0](self.session) - mock_established.assert_called_once_with( - self.session, - self.conn_id, - node_id=self.node_id, - ) + self.assertEqual(pc.status, tcpserver.PenConnStatus.WaitingAlt) + self.assertEqual(pc.time, 1555326900.0) def test_conn_final_failure(self, *_): self.server.msg_queue_connection_final_failure( diff --git a/tests/golem/task/test_concent_logic.py b/tests/golem/task/test_concent_logic.py index 2d93837751..ba34f64684 100644 --- a/tests/golem/task/test_concent_logic.py +++ b/tests/golem/task/test_concent_logic.py @@ -179,6 +179,10 @@ def test_want_to_compute_task_signed_by_different_key_than_it_contains( task_session_dropped.assert_called_once() +@mock.patch( + 'golem.task.tasksession.TaskSession.verify_owners', + return_value=True, +) class ReactToReportComputedTaskTestCase(testutils.TempDirFixture): def setUp(self): super().setUp() @@ -189,7 +193,7 @@ def setUp(self): private_key_name='priv_key', password='password', ) - self.task_session.key_id = "KEY_ID" + self.task_session.key_id = "deadbeef" self.msg = factories.tasks.ReportComputedTaskFactory() self.msg._fake_sign() self.now = datetime.datetime.utcnow() @@ -211,9 +215,9 @@ def setUp(self): self.task_session.task_manager.tasks_states[task_id] = task_state = \ taskstate.TaskState() ctk = self.task_session.task_manager.comp_task_keeper - ctk.get_node_for_task_id.return_value = "KEY_ID" + ctk.get_node_for_task_id.return_value = self.task_session.key_id self.task_session.task_manager.get_node_id_for_subtask.return_value = \ - "KEY_ID" + self.task_session.key_id task_state.subtask_states[self.msg.subtask_id] = subtask_state = \ taskstate.SubtaskState() subtask_state.deadline = self.msg.task_to_compute.compute_task_def[ @@ -231,25 +235,23 @@ def assert_reject_reason(self, send_mock, reason, **kwargs): self.assertEqual(getattr(msg, attr_name), kwargs[attr_name]) @mock.patch('golem.task.tasksession.TaskSession.dropped') - def test_subtask_id_unknown(self, dropped_mock): + def test_subtask_id_unknown(self, dropped_mock, *_): "Drop if subtask is unknown" self.task_session.task_manager.get_node_id_for_subtask.return_value = \ None self.task_session._react_to_report_computed_task(self.msg) dropped_mock.assert_called_once_with() - self.task_session.task_manager.get_node_id_for_subtask.return_value = \ - "KEY_ID" - @mock.patch('golem.task.tasksession.TaskSession.dropped') - def test_spoofed_task_to_compute(self, dropped_mock): + @mock.patch('golem.task.tasksession.TaskSession.send') + def test_spoofed_task_to_compute(self, send_mock, verify_mock, *_): "Drop if task_to_compute is spoofed" - self.msg.task_to_compute.sig = b'31337' + verify_mock.return_value = False self.task_session._react_to_report_computed_task(self.msg) - dropped_mock.assert_called_once_with() + send_mock.assert_not_called() @mock.patch('golem.network.history.MessageHistoryService.get_sync') @mock.patch('golem.task.tasksession.TaskSession.send') - def test_task_deadline_not_found(self, send_mock, get_mock): + def test_task_deadline_not_found(self, send_mock, get_mock, *_): "Reject if subtask timeout unreachable" get_mock.return_value = [] self.task_session.task_server.task_keeper.task_headers = {} @@ -261,7 +263,7 @@ def test_task_deadline_not_found(self, send_mock, get_mock): @mock.patch('golem.network.history.MessageHistoryService.get_sync') @mock.patch('golem.task.tasksession.TaskSession.send') - def test_subtask_deadline(self, send_mock, get_mock): + def test_subtask_deadline(self, send_mock, get_mock, *_): "Reject after subtask timeout" get_mock.return_value = [] after_deadline = self.now \ @@ -278,7 +280,7 @@ def test_subtask_deadline(self, send_mock, get_mock): 'golem.network.history.MessageHistoryService.get_sync_as_message' ) @mock.patch('golem.task.tasksession.TaskSession.send') - def test_cannot_compute_task_received(self, send_mock, get_mock): + def test_cannot_compute_task_received(self, send_mock, get_mock, *_): "Reject if CannotComputeTask received" get_mock.return_value = unwanted_msg = \ factories.tasks.CannotComputeTaskFactory( @@ -293,7 +295,7 @@ def test_cannot_compute_task_received(self, send_mock, get_mock): ) @mock.patch('golem.task.tasksession.TaskSession.send') - def test_task_failure_received(self, send_mock): + def test_task_failure_received(self, send_mock, *_): "Reject if TaskFailure received" unwanted_msg = factories.tasks.TaskFailureFactory( subtask_id=self.msg.subtask_id, diff --git a/tests/golem/task/test_taskcomputer.py b/tests/golem/task/test_taskcomputer.py index b6b296a113..c2e80364f0 100644 --- a/tests/golem/task/test_taskcomputer.py +++ b/tests/golem/task/test_taskcomputer.py @@ -12,7 +12,7 @@ from golem.core.common import timeout_to_deadline from golem.core.deferred import sync_wait from golem.docker.manager import DockerManager -from golem.task.taskcomputer import TaskComputer, PyTaskThread, logger +from golem.task.taskcomputer import TaskComputer, PyTaskThread from golem.testutils import DatabaseFixture from golem.tools.ci import ci_skip from golem.tools.assertlogs import LogTestCase @@ -70,7 +70,6 @@ def test_run(self): tc2.last_checking = 10 ** 10 tc2.run() - tc2.session_timeout() def test_resource_failure(self): task_server = self.task_server @@ -299,15 +298,12 @@ def test_compute_task(self, start): ) compute_task(*args, **kwargs) - assert task_computer.session_closed.called assert not start.called header = mock.Mock(deadline=time.time() + 3600) task_computer.task_server.task_keeper.task_headers[task_id] = header - task_computer.session_closed.reset_mock() compute_task(*args, **kwargs) - assert not task_computer.session_closed.called assert start.called @staticmethod diff --git a/tests/golem/task/test_taskmanager.py b/tests/golem/task/test_taskmanager.py index 9852875d2b..3f0cddb89c 100644 --- a/tests/golem/task/test_taskmanager.py +++ b/tests/golem/task/test_taskmanager.py @@ -19,11 +19,12 @@ from twisted.internet.defer import fail from apps.appsmanager import AppsManager -from golem.clientconfigdescriptor import ClientConfigDescriptor from apps.core.task.coretask import CoreTask from apps.core.task.coretaskstate import TaskDefinition from apps.blender.task.blenderrendertask import BlenderRenderTask +from golem import model from golem import testutils +from golem.clientconfigdescriptor import ClientConfigDescriptor from golem.core.common import timeout_to_deadline from golem.core.keysauth import KeysAuth from golem.network.p2p.local_node import LocalNode @@ -1332,6 +1333,30 @@ def test_check_timeouts_removes_output_directory(self, mock_get_dir, *_): TaskStatus.timeout, ) + def test_subtask_to_task(self, *_): + task_keeper = Mock(subtask_to_task=dict()) + mapping = dict() + + self.tm.comp_task_keeper = task_keeper + self.tm.subtask2task_mapping = mapping + task_keeper.subtask_to_task['sid_1'] = 'task_1' + mapping['sid_2'] = 'task_2' + + self.assertEqual( + self.tm.subtask_to_task('sid_1', model.Actor.Provider), + 'task_1', + ) + self.assertEqual( + self.tm.subtask_to_task('sid_2', model.Actor.Requestor), + 'task_2', + ) + self.assertIsNone( + self.tm.subtask_to_task('sid_2', model.Actor.Provider), + ) + self.assertIsNone( + self.tm.subtask_to_task('sid_1', model.Actor.Requestor), + ) + class TestCopySubtaskResults(DatabaseFixture): @@ -1464,3 +1489,52 @@ def test_waiting(self, *_): def test_finished(self, *_): self.tm.tasks_states[self.task_id].status = TaskStatus.finished self.assertTrue(self.tm.task_finished(self.task_id)) + + +@patch('golem.core.statskeeper.StatsKeeper._get_or_create') +class TestNeedsComputation(unittest.TestCase): + def setUp(self): + with patch('golem.core.statskeeper.StatsKeeper._get_or_create'): + self.tm = TaskManager( + node=dt_p2p_factory.Node(), + keys_auth=MagicMock(spec=KeysAuth), + root_path='/tmp', + config_desc=ClientConfigDescriptor(), + task_persistence=False + ) + dummy_path = '/fiu/bzdziu' + self.task_id = str(uuid.uuid4()) + self.tm.tasks_states[self.task_id] = TaskState() + definition = TaskDefinition() + definition.options = Mock() + definition.output_format = Mock() + definition.task_id = self.task_id + definition.task_type = "blender" + definition.subtask_timeout = 3671 + definition.timeout = 3671 * 10 + definition.max_price = 1 * 10 ** 18 + definition.resolution = [1920, 1080] + definition.resources = [str(uuid.uuid4()) for _ in range(5)] + #definition.output_file = os.path.join(self.tempdir, 'somefile') + definition.main_scene_file = dummy_path + definition.options.frames = [1] + self.task = BlenderRenderTask( + task_definition=definition, + owner=dt_p2p_factory.Node( + node_name='node', + ), + total_tasks=1, + root_path=dummy_path, + ) + self.tm.tasks[self.task_id] = self.task + + def test_finished(self, *_): + self.tm.tasks_states[self.task_id].status = TaskStatus.finished + self.assertFalse(self.tm.task_needs_computation(self.task_id)) + + def test_task_doesnt_need_computation(self, *_): + self.task.last_task = self.task.total_tasks + self.assertFalse(self.tm.task_needs_computation(self.task_id)) + + def test_needs_computation(self, *_): + self.assertTrue(self.tm.task_needs_computation(self.task_id)) diff --git a/tests/golem/task/test_taskserver.py b/tests/golem/task/test_taskserver.py index 78eb3f1c8f..5f9538995f 100644 --- a/tests/golem/task/test_taskserver.py +++ b/tests/golem/task/test_taskserver.py @@ -4,7 +4,6 @@ import random import tempfile import uuid -from collections import deque from math import ceil from unittest.mock import Mock, MagicMock, patch, ANY @@ -35,14 +34,16 @@ from golem.task.acl import DenyReason as AclDenyReason from golem.task.server import concent as server_concent from golem.task.taskbase import AcceptClientVerdict -from golem.task.taskserver import TASK_CONN_TYPES -from golem.task.taskserver import TaskServer, WaitingTaskResult, logger -from golem.task.tasksession import TaskSession +from golem.task.taskserver import ( + logger, + TaskServer, + WaitingTaskFailure, + WaitingTaskResult, +) from golem.task.taskstate import TaskState, TaskOp from golem.tools.assertlogs import LogTestCase from golem.tools.testwithreactor import TestDatabaseWithReactor -from tests.factories.resultpackage import ExtractedPackageFactory from tests.factories.hyperdrive import hyperdrive_client_kwargs @@ -314,22 +315,6 @@ def test_send_results(self, trust, *_): os.remove(result_file) - def test_connection_for_task_request_established(self, *_): - ccd = ClientConfigDescriptor() - ccd.min_price = 11 - ts = self.ts - session = Mock() - session.address = "10.10.10.10" - session.port = 1020 - ts.conn_established_for_type[TASK_CONN_TYPES['task_request']]( - session, "abc", "nodename", "key", "xyz", 1010, 30, 3, 1) - self.assertIn(session, self.ts.task_sessions_outgoing) - self.assertEqual(session.key_id, "key") - self.assertEqual(session.conn_id, "abc") - session.send_hello.assert_called_with() - session.request_task.assert_called_with("nodename", "xyz", 1010, 30, 3, - 1) - def test_change_config(self, *_): ts = self.ts @@ -340,15 +325,14 @@ def test_change_config(self, *_): # ccd2.use_waiting_ttl = False ts.change_config(ccd2) self.assertEqual(ts.config_desc, ccd2) - self.assertEqual(ts.last_message_time_threshold, 124) self.assertEqual(ts.task_keeper.min_price, 0.0057) self.assertEqual(ts.task_computer.task_request_frequency, 31) # self.assertEqual(ts.task_computer.use_waiting_ttl, False) @patch("golem.task.taskserver.TaskServer._sync_pending") - def test_sync(self, *_): + def test_sync(self, mock_sync_pending, *_): self.ts.sync_network() - self.ts._sync_pending.assert_called_once_with() + mock_sync_pending.assert_called_once_with() @patch("golem.task.taskserver.TaskServer._sync_pending", side_effect=RuntimeError("Intentional failure")) @@ -361,35 +345,6 @@ def test_sync_job_fails(self, *_): .assert_called_once() # pylint: enable=no-member - def test_forwarded_session_requests(self, *_): - ts = self.ts - ts.network = Mock() - - key_id = str(uuid.uuid4()) - conn_id = str(uuid.uuid4()) - subtask_id = str(uuid.uuid4()) - - ts.add_forwarded_session_request(key_id, conn_id) - self.assertEqual(len(ts.forwarded_session_requests), 1) - - ts.forwarded_session_requests[key_id]['time'] = 0 - ts._sync_forwarded_session_requests() - self.assertEqual(len(ts.forwarded_session_requests), 0) - - ts.add_forwarded_session_request(key_id, conn_id) - ts.forwarded_session_requests[key_id] = None - ts._sync_forwarded_session_requests() - self.assertEqual(len(ts.forwarded_session_requests), 0) - - session = MagicMock() - session.address = '127.0.0.1' - session.port = 65535 - - ts.conn_established_for_type[TASK_CONN_TYPES['task_failure']]( - session, conn_id, key_id, subtask_id, "None" - ) - self.assertIn(session, ts.task_sessions_outgoing) - def test_retry_sending_task_result(self, *_): ts = self.ts ts.network = Mock() @@ -403,203 +358,57 @@ def test_retry_sending_task_result(self, *_): ts.retry_sending_task_result(subtask_id) self.assertFalse(wtr.already_sending) - def test_send_waiting_results(self, *_): + @patch("golem.task.server.helpers.send_task_failure") + @patch("golem.task.server.helpers.send_report_computed_task") + def test_send_waiting_results(self, mock_send_rct, mock_send_tf, *_): ts = self.ts - ts.network = Mock() - ts._mark_connected = Mock() - ts.task_computer = Mock() - ts.task_manager = Mock() - ts.task_manager.check_timeouts.return_value = [] - ts.task_keeper = Mock() - ts.task_connections_helper = Mock() - ts._add_pending_request = Mock() - subtask_id = 'xxyyzz' - wtr = Mock() + wtr = WaitingTaskResult( + task_id='task_id', + subtask_id=subtask_id, + result=['result'], + last_sending_trial=0, + delay_time=0, + owner=dt_p2p_factory.Node(), + ) + ts.results_to_send[subtask_id] = wtr wtr.already_sending = True - wtr.last_sending_trial = 0 - wtr.delay_time = 0 - wtr.subtask_id = subtask_id - wtr.address = '127.0.0.1' - wtr.port = 10000 - ts.sync_network() - ts._add_pending_request.assert_not_called() + ts._send_waiting_results() + mock_send_rct.assert_not_called() wtr.last_sending_trial = 0 ts.retry_sending_task_result(subtask_id) - ts.sync_network() - self.assertEquals(ts._add_pending_request.call_count, 1) - - ts._add_pending_request.reset_mock() - ts.task_sessions[subtask_id] = Mock() - ts.task_sessions[subtask_id].last_message_time = float('infinity') - - ts.sync_network() - ts._add_pending_request.assert_not_called() - - ts._add_pending_request.reset_mock() - ts.results_to_send = dict() - - wtf = wtr - - ts.failures_to_send[subtask_id] = wtf - ts.sync_network() - ts._add_pending_request.assert_not_called() - self.assertEqual(ts.failures_to_send, {}) - - ts._add_pending_request.reset_mock() - ts.task_sessions.pop(subtask_id) - - ts.failures_to_send[subtask_id] = wtf - ts.sync_network() - self.assertEquals(ts._add_pending_request.call_count, 1) - self.assertEqual(ts.failures_to_send, {}) - - def test_add_task_session(self, *_): - ts = self.ts - ts.network = Mock() - - session = Mock() - subtask_id = 'xxyyzz' - ts.add_task_session(subtask_id, session) - self.assertIsNotNone(ts.task_sessions[subtask_id]) - - def test_remove_task_session(self, *_): - ts = self.ts - ts.network = Mock() - - conn_id = str(uuid.uuid4()) - session = Mock() - session.conn_id = conn_id - - ts.remove_task_session(session) - ts.task_sessions['task'] = session - ts.remove_task_session(session) - - def test_respond_to(self, *_): - ts = self.ts - ts.network = Mock() - session = Mock() - - ts.respond_to('key_id', session, 'conn_id') - self.assertTrue(session.dropped.called) - - session.dropped.called = False - ts.response_list['conn_id'] = deque([lambda *_: lambda x: x]) - ts.respond_to('key_id', session, 'conn_id') - self.assertFalse(session.dropped.called) - - def test_conn_for_task_failure_established(self, *_): - ts = self.ts - ts.network = Mock() - session = Mock() - session.address = '127.0.0.1' - session.port = 40102 - - method = ts._TaskServer__connection_for_task_failure_established - method(session, 'conn_id', 'key_id', 'subtask_id', 'err_msg') - - self.assertEqual(session.key_id, 'key_id') - self.assertIn(session, ts.task_sessions_outgoing) - self.assertTrue(session.send_hello.called) - session.send_task_failure.assert_called_once_with('subtask_id', - 'err_msg') - - def test_conn_for_start_session_failure(self, *_): - ts = self.ts - ts.network = Mock() - ts.final_conn_failure = Mock() - - method = ts._TaskServer__connection_for_start_session_failure - method('conn_id', 'key_id', Mock(), Mock(), 'ans_conn_id') + ts._send_waiting_results() + mock_send_rct.assert_called_once_with( + task_server=self.ts, + waiting_task_result=wtr, + ) - ts.final_conn_failure.assert_called_with('conn_id') + mock_send_rct.reset_mock() - def test_conn_final_failures(self, *_): - ts = self.ts - ts.network = Mock() - ts.final_conn_failure = Mock() - ts.task_computer = Mock() + ts._send_waiting_results() + mock_send_rct.assert_not_called() - ts.remove_pending_conn = Mock() - ts.remove_responses = Mock() + ts.results_to_send = {} - method = ts._TaskServer__connection_for_task_result_final_failure - wtr = Mock() - method('conn_id', wtr) - - self.assertTrue(ts.remove_pending_conn.called) - self.assertTrue(ts.remove_responses.called) - self.assertFalse(wtr.alreadySending) - self.assertTrue(wtr.lastSendingTrial) - - ts.remove_pending_conn.called = False - ts.remove_responses.called = False - - method = ts._TaskServer__connection_for_task_failure_final_failure - method('conn_id', 'key_id', 'subtask_id', 'err_msg') - - self.assertTrue(ts.remove_pending_conn.called) - self.assertTrue(ts.remove_responses.called) - self.assertTrue(ts.task_computer.session_timeout.called) - ts.remove_pending_conn.called = False - ts.remove_responses.called = False - ts.task_computer.session_timeout.called = False - - method = ts._TaskServer__connection_for_start_session_final_failure - method('conn_id', 'key_id', Mock(), Mock(), 'ans_conn_id') - - self.assertTrue(ts.remove_pending_conn.called) - self.assertTrue(ts.remove_responses.called) - self.assertTrue(ts.task_computer.session_timeout.called) - - ts.remove_pending_conn.reset_mock() - method = ts._TaskServer__connection_for_task_request_final_failure - method('conn_id', 'node_name', 'key_id', 'task_id', 1000, 1000, 1000, - 1024, 3) - ts.remove_pending_conn.assert_called_once_with('conn_id') - - def test_task_result_connection_failure(self, *_): - """Tests what happens after connection failure when sending - task_result""" - node = Mock( - key='deadbeef', - prv_port=None, - prv_addr='10.0.0.10', + wtf = WaitingTaskFailure( + task_id="failed_task_id", + subtask_id=subtask_id, + owner=dt_p2p_factory.Node(), + err_msg="Controlled failure", ) - ts = self.ts - ts.network = MagicMock() - ts.final_conn_failure = Mock() - ts.task_computer = Mock() - ts._is_address_accessible = Mock(return_value=True) - # Always fail on listening - from golem.network.transport import tcpnetwork - ts.network.listen = MagicMock( - side_effect=lambda listen_info, waiting_task_result: - tcpnetwork.TCPNetwork.__call_failure_callback( # noqa pylint: disable=too-many-function-args - listen_info.failure_callback, - {'waiting_task_result': waiting_task_result} - ) - ) - # Try sending mocked task_result - wtr = MagicMock( - owner=node, - ) - ts._add_pending_request( - TASK_CONN_TYPES['task_result'], - node, - prv_port=node.prv_port, - pub_port=node.pub_port, - args={'waiting_task_result': wtr} + ts.failures_to_send[subtask_id] = wtf + ts._send_waiting_results() + mock_send_tf.assert_called_once_with( + waiting_task_failure=wtf, ) - ts._sync_pending() - assert not ts.network.connect.called + self.assertEqual(ts.failures_to_send, {}) def test_should_accept_provider_no_such_task(self, *_args): # given @@ -960,29 +769,9 @@ def test_should_accept_requestor(self, *_): assert UnsupportReason.DENY_LIST in ss.desc self.assertEqual(ss.desc[UnsupportReason.DENY_LIST], "ABC") - @patch('golem.task.taskserver.TaskServer._mark_connected') - def test_new_session_prepare(self, mark_mock, *_): - session = tasksession.TaskSession(conn=MagicMock()) - session.address = '127.0.0.1' - session.port = 10 - - key_id = str(uuid.uuid4()) - conn_id = str(uuid.uuid4()) - - self.ts.new_session_prepare( - session=session, - key_id=key_id, - conn_id=conn_id - ) - self.assertEqual(session.key_id, key_id) - self.assertEqual(session.conn_id, conn_id) - mark_mock.assert_called_once_with(conn_id, session.address, - session.port) - self.assertIn(session, self.ts.task_sessions_outgoing) - def test_new_connection(self, *_): ts = self.ts - tss = TaskSession(Mock()) + tss = tasksession.TaskSession(Mock()) ts.new_connection(tss) assert len(ts.task_sessions_incoming) == 1 assert ts.task_sessions_incoming.pop() == tss @@ -1168,7 +957,8 @@ def test_results(self, trust, *_): def test_disconnect(self, *_): session_mock = Mock() - self.ts.task_sessions_outgoing.add(session_mock) + self.ts.sessions['active_node_id'] = session_mock + self.ts.sessions['pending_node_id'] = None self.ts.disconnect() session_mock.dropped.assert_called_once_with() @@ -1190,10 +980,13 @@ def setUp(self): for parent in self.__class__.__bases__: parent.setUp(self) - self.node = Mock(prv_addr='10.0.0.2', prv_port=40102, - pub_addr='1.2.3.4', pub_port=40102, - hyperg_prv_port=3282, hyperg_pub_port=3282, - prv_addresses=['10.0.0.2'],) + self.node = dt_p2p_factory.Node( + prv_addr='10.0.0.2', + prv_port=40102, + pub_addr='1.2.3.4', + pub_port=40102, + prv_addresses=['10.0.0.2'], + ) self.resource_manager = Mock( add_resources=Mock(side_effect=lambda *a, **b: ([], "a1b2c3")) @@ -1331,67 +1124,3 @@ def test_finished_task_listener(self, *_): op=TaskOp.TIMEOUT) assert remove_task.call_count == 2 assert remove_task_funds_lock.call_count == 2 - - -class TaskVerificationResultTest(TaskServerTestBase): - - def setUp(self): - super().setUp() - self.conn_id = 'connid' - self.key_id = 'keyid' - self.conn_type = TASK_CONN_TYPES['task_verification_result'] - - @staticmethod - def _mock_session(): - session = Mock() - session.address = "10.10.10.10" - session.port = 1020 - return session - - def test_connection_established(self): - session = self._mock_session() - extracted_package = ExtractedPackageFactory() - subtask_id = 'test_subtask_id' - - self.ts.conn_established_for_type[self.conn_type]( - session, self.conn_id, extracted_package, self.key_id, subtask_id - ) - self.assertEqual(session.key_id, self.key_id) - self.assertEqual(session.conn_id, self.conn_id) - result_received_call = session.result_received.call_args[0] - self.assertEqual(result_received_call[0], subtask_id) - - @patch('golem.task.taskserver.logger.warning') - def test_conection_failed(self, log_mock): - extracted_package = ExtractedPackageFactory() - subtask_id = 'test_subtask_id' - self.ts.conn_failure_for_type[self.conn_type]( - self.conn_id, extracted_package, self.key_id, subtask_id - ) - self.assertIn( - "Failed to establish a session", log_mock.call_args[0][0]) - self.assertIn(subtask_id, log_mock.call_args[0][1]) - self.assertIn(self.key_id, log_mock.call_args[0][2]) - - @patch('golem.task.taskserver.TaskServer._is_address_accessible', - Mock(return_value=True)) - @patch('golem.task.taskserver.TaskServer.get_socket_addresses', - Mock(return_value=[Mock()])) - def test_verify_results(self, *_): - rct = msg_factories.tasks.ReportComputedTaskFactory( - node_info=self.ts.node.to_dict()) - extracted_package = ExtractedPackageFactory() - self.ts.verify_results(rct, extracted_package) - pc = list(self.ts.pending_connections.values())[0] - - self.assertEqual( - pc.established.func.__name__, - '__connection_for_task_verification_result_established') - self.assertEqual( - pc.failure.func.__name__, - '__connection_for_task_verification_result_failure', - ) - self.assertEqual( - pc.final_failure.func.__name__, - '__connection_for_task_verification_result_failure', - ) diff --git a/tests/golem/task/test_tasksession.py b/tests/golem/task/test_tasksession.py index ff014d1645..e7015caa93 100644 --- a/tests/golem/task/test_tasksession.py +++ b/tests/golem/task/test_tasksession.py @@ -30,7 +30,6 @@ from golem.docker.environment import DockerEnvironment from golem.docker.image import DockerImage from golem.network.hyperdrive import client as hyperdrive_client -from golem.model import Actor from golem.network import history from golem.network.hyperdrive.client import HyperdriveClientOptions from golem.task import taskstate @@ -38,8 +37,6 @@ from golem.task.tasksession import TaskSession, logger, get_task_message from golem.tools.assertlogs import LogTestCase -from tests import factories - fake = faker.Faker() @@ -248,6 +245,7 @@ def test_request_task(self, *_): options = HyperdriveClientOptions("CLI1", 0.3) ts2.task_server.get_share_options.return_value = options ts2.interpret(mt) + ts2.conn.send_message.assert_called_once() ms = ts2.conn.send_message.call_args[0][0] self.assertIsInstance(ms, message.tasks.TaskToCompute) expected = [ @@ -296,7 +294,7 @@ def setUp(self): super().setUp() random.seed() self.task_session = TaskSession(Mock()) - self.task_session.key_id = 'unittest_key_id' + self.task_session.key_id = 'deadbeef' self.task_session.task_server.get_share_options.return_value = \ hyperdrive_client.HyperdriveClientOptions('1', 1.0) keys_auth = KeysAuth( @@ -306,6 +304,7 @@ def setUp(self): password='', ) self.task_session.task_server.keys_auth = keys_auth + self.task_session.task_server.sessions = {} self.task_session.task_manager.task_finished.return_value = False self.pubkey = keys_auth.public_key self.privkey = keys_auth._private_key @@ -333,118 +332,6 @@ def test_hello(self, send_mock, *_): msg = send_mock.call_args[0][0] self.assertCountEqual(msg.slots(), expected) - @patch( - 'golem.network.history.MessageHistoryService.get_sync_as_message', - ) - @patch( - 'golem.network.history.add', - ) - def test_send_report_computed_task(self, add_mock, get_mock, *_): - ts = self.task_session - ts.verified = True - ts.task_server.get_node_name.return_value = "ABC" - wtr = factories.taskserver.WaitingTaskResultFactory() - - ttc = msg_factories.tasks.TaskToComputeFactory( - task_id=wtr.task_id, - subtask_id=wtr.subtask_id, - compute_task_def__deadline=calendar.timegm(time.gmtime()) + 3600, - ) - get_mock.return_value = ttc - ts.task_server.get_key_id.return_value = 'key id' - ts.send_report_computed_task( - wtr, wtr.owner.pub_addr, wtr.owner.pub_port, wtr.owner) - - rct: message.tasks.ReportComputedTask = \ - ts.conn.send_message.call_args[0][0] - self.assertIsInstance(rct, message.tasks.ReportComputedTask) - self.assertEqual(rct.subtask_id, wtr.subtask_id) - self.assertEqual(rct.node_name, "ABC") - self.assertEqual(rct.address, wtr.owner.pub_addr) - self.assertEqual(rct.port, wtr.owner.pub_port) - self.assertEqual(rct.extra_data, []) - self.assertEqual(rct.node_info, wtr.owner.to_dict()) - self.assertEqual(rct.package_hash, 'sha1:' + wtr.package_sha1) - self.assertEqual(rct.multihash, wtr.result_hash) - self.assertEqual(rct.secret, wtr.result_secret) - - add_mock.assert_called_once_with( - msg=ANY, - node_id=ts.key_id, - local_role=Actor.Provider, - remote_role=Actor.Requestor, - ) - - ts2 = TaskSession(Mock()) - ts2.verified = True - ts2.key_id = "DEF" - ts2.can_be_not_encrypted.append(rct.__class__) - ts2.task_manager.subtask2task_mapping = {wtr.subtask_id: wtr.task_id} - task_state = taskstate.TaskState() - task_state.subtask_states[wtr.subtask_id] = taskstate.SubtaskState() - task_state.subtask_states[wtr.subtask_id].deadline = \ - calendar.timegm(time.gmtime()) + 3600 - ts2.task_manager.tasks_states = { - wtr.task_id: task_state, - } - ts2.task_manager.get_node_id_for_subtask.return_value = "DEF" - get_mock.side_effect = history.MessageNotFound - - with patch( - 'golem.network.concent.helpers.process_report_computed_task', - return_value=msg_factories.tasks.AckReportComputedTaskFactory() - ): - ts2.interpret(rct) - - @patch('golem.task.tasksession.get_task_message') - def test_result_received(self, get_msg_mock, *_): - conn = Mock() - conn.send_message.side_effect = lambda msg: msg._fake_sign() - ts = TaskSession(conn) - ts.task_manager.verify_subtask.return_value = True - keys_auth = KeysAuth( - datadir=self.path, - difficulty=4, - private_key_name='prv', - password='', - ) - ts.task_server.keys_auth = keys_auth - subtask_id = "xxyyzz" - get_msg_mock.return_value = msg_factories \ - .tasks.ReportComputedTaskFactory( - subtask_id=subtask_id, - ) - - def finished(): - if not ts.task_manager.verify_subtask(subtask_id): - ts._reject_subtask_result(subtask_id, '') - ts.dropped() - return - - payment = ts.task_server.accept_result( - subtask_id, - 'key_id', - 'eth_address', - ) - rct = msg_factories.tasks.ReportComputedTaskFactory( - task_to_compute__compute_task_def__subtask_id=subtask_id, - ) - ts.send(msg_factories.tasks.SubtaskResultsAcceptedFactory( - report_computed_task=rct, - payment_ts=payment.processed_ts)) - ts.dropped() - - ts.task_manager.computed_task_received = Mock( - side_effect=finished(), - ) - ts.result_received(subtask_id, pickle.dumps({'stdout': 'xyz'})) - - self.assertTrue(ts.msgs_to_send) - sra = ts.msgs_to_send[0] - self.assertIsInstance(sra, message.tasks.SubtaskResultsAccepted) - - conn.close.assert_called() - def _get_srr(self, key2=None, concent=False): key1 = 'known' key2 = key2 or key1 @@ -534,11 +421,9 @@ def __reset_mocks(): msg.want_to_compute_task.sign_message(keys.raw_privkey) # noqa pylint: disable=no-member msg._fake_sign() ts._react_to_task_to_compute(msg) - ts.task_server.add_task_session.assert_not_called() ts.task_server.task_given.assert_not_called() ts.task_manager.comp_task_keeper.receive_subtask.assert_not_called() ts.send.assert_not_called() - ts.task_computer.session_closed.assert_called_with() assert conn.close.called # No source code in the local environment -> failure @@ -577,8 +462,6 @@ def _prepare_and_react(compute_task_def, resource_size=102400): env.get_source_code.return_value = "print 'Hello world'" msg = _prepare_and_react(ctd) ts.task_manager.comp_task_keeper.receive_subtask.assert_called_with(msg) - ts.task_computer.session_closed.assert_not_called() - ts.task_server.add_task_session.assert_called_with(msg.subtask_id, ts) ts.task_server.task_given.assert_called_with( header.task_owner.key, ctd, @@ -588,7 +471,6 @@ def _prepare_and_react(compute_task_def, resource_size=102400): def __assert_failure(ts, conn, reason): ts.task_manager.comp_task_keeper.receive_subtask.assert_not_called() - ts.task_computer.session_closed.assert_called_with() assert conn.close.called ts.send.assert_called_once_with(ANY) msg = ts.send.call_args[0][0] @@ -633,8 +515,6 @@ def __assert_failure(ts, conn, reason): __reset_mocks() ctd['extra_data']['src_code'] = "print 'Hello world!'" msg = _prepare_and_react(ctd) - ts.task_computer.session_closed.assert_not_called() - ts.task_server.add_task_session.assert_called_with(msg.subtask_id, ts) ts.task_server.task_given.assert_called_with( header.task_owner.key, ctd, @@ -718,20 +598,6 @@ def test_react_to_ack_reject_report_computed_task(self, *_): self.assert_concent_cancel( cancel.call_args[0], subtask_id, 'ForceReportComputedTask') - def test_subtask_to_task(self, *_): - task_keeper = Mock(subtask_to_task=dict()) - mapping = dict() - - self.task_session.task_manager.comp_task_keeper = task_keeper - self.task_session.task_manager.subtask2task_mapping = mapping - task_keeper.subtask_to_task['sid_1'] = 'task_1' - mapping['sid_2'] = 'task_2' - - assert self.task_session._subtask_to_task('sid_1', Actor.Provider) - assert self.task_session._subtask_to_task('sid_2', Actor.Requestor) - assert not self.task_session._subtask_to_task('sid_2', Actor.Provider) - assert not self.task_session._subtask_to_task('sid_1', Actor.Requestor) - @patch('golem.task.taskkeeper.ProviderStatsManager', Mock()) def test_react_to_cannot_assign_task(self, *_): self._test_react_to_cannot_assign_task() @@ -815,7 +681,6 @@ def test_react_to_want_to_compute_invalid_task_header_signature(self, *_): ts._react_to_want_to_compute_task(wtct) sent_msg = ts.conn.send_message.call_args[0][0] - ts.task_server.remove_task_session.assert_called() self.assertIsInstance(sent_msg, message.tasks.CannotAssignTask) self.assertEqual(sent_msg.reason, message.tasks.CannotAssignTask.REASON.NotMyTask) @@ -834,7 +699,6 @@ def test_react_to_want_to_compute_not_my_task_id(self, *_): ts._react_to_want_to_compute_task(wtct) sent_msg = ts.conn.send_message.call_args[0][0] - ts.task_server.remove_task_session.assert_called() self.assertIsInstance(sent_msg, message.tasks.CannotAssignTask) self.assertEqual(sent_msg.reason, message.tasks.CannotAssignTask.REASON.NotMyTask) @@ -966,60 +830,6 @@ def _mock_task_to_compute(task_id, subtask_id, node_id, **kwargs): service = history.MessageHistoryService.instance service.add_sync(nmsg_dict) - def assert_submit_task_message(self, subtask_id, wtr): - self.ts.concent_service.submit_task_message.assert_called_once_with( - subtask_id, ANY) - - msg = self.ts.concent_service.submit_task_message.call_args[0][1] - self.assertEqual(msg.result_hash, 'sha1:' + wtr.package_sha1) - - def test_send_report_computed_task_concent_no_message(self): - wtr = factories.taskserver.WaitingTaskResultFactory(owner=self.n) - self.ts.send_report_computed_task( - wtr, wtr.owner.pub_addr, wtr.owner.pub_port, self.n) - self.ts.concent_service.submit.assert_not_called() - - def test_send_report_computed_task_concent_success(self): - wtr = factories.taskserver.WaitingTaskResultFactory( - xtask_id=self.task_id, xsubtask_id=self.subtask_id, owner=self.n) - self._mock_task_to_compute(self.task_id, self.subtask_id, - self.ts.key_id, concent_enabled=True) - self.ts.send_report_computed_task( - wtr, wtr.owner.pub_addr, wtr.owner.pub_port, self.n) - - self.assert_submit_task_message(self.subtask_id, wtr) - - def test_send_report_computed_task_concent_success_many_files(self): - result = [] - for i in range(100, 300, 99): - p = pathlib.Path(self.tempdir) / str(i) - with p.open('wb') as f: - f.write(b'\0' * i * 2 ** 20) - result.append(str(p)) - - wtr = factories.taskserver.WaitingTaskResultFactory( - xtask_id=self.task_id, xsubtask_id=self.subtask_id, owner=self.n, - result=result - ) - self._mock_task_to_compute(self.task_id, self.subtask_id, - self.ts.key_id, concent_enabled=True) - - self.ts.send_report_computed_task( - wtr, wtr.owner.pub_addr, wtr.owner.pub_port, self.n) - - self.assert_submit_task_message(self.subtask_id, wtr) - - def test_send_report_computed_task_concent_disabled(self): - wtr = factories.taskserver.WaitingTaskResultFactory( - task_id=self.task_id, subtask_id=self.subtask_id, owner=self.n) - - self._mock_task_to_compute( - self.task_id, self.subtask_id, self.node_id, concent_enabled=False) - - self.ts.send_report_computed_task( - wtr, wtr.owner.pub_addr, wtr.owner.pub_port, self.n) - self.ts.concent_service.submit.assert_not_called() - class GetTaskMessageTest(TestCase): def test_get_task_message(self): @@ -1042,6 +852,7 @@ def test_get_task_message_fail(self): class SubtaskResultsAcceptedTest(TestCase): def setUp(self): self.task_session = TaskSession(Mock()) + self.task_session.verified = True self.task_server = Mock() self.task_session.conn.server = self.task_server self.requestor_keys = cryptography.ECCx(None) @@ -1104,6 +915,7 @@ def test_react_with_wrong_key(self): # given key_id = "CDEF" sra = msg_factories.tasks.SubtaskResultsAcceptedFactory() + sra._fake_sign() ctk = self.task_session.task_manager.comp_task_keeper ctk.get_node_for_task_id.return_value = "ABC" self.task_session.key_id = key_id @@ -1114,50 +926,9 @@ def test_react_with_wrong_key(self): # then self.task_server.subtask_accepted.assert_not_called() - def test_result_received(self): - self.task_server.keys_auth._private_key = \ - self.requestor_keys.raw_privkey - self.task_server.keys_auth.public_key = \ - self.requestor_keys.raw_pubkey - self.task_server.accept_result.return_value = 11111 - - def computed_task_received(*args): - args[2]() - - self.task_session.task_manager.computed_task_received = \ - computed_task_received - - rct = msg_factories.tasks.ReportComputedTaskFactory() - ttc = rct.task_to_compute - ttc.sign_message(private_key=self.requestor_keys.raw_privkey) - - self.task_session.send = Mock() - - history_dict = { - 'TaskToCompute': ttc, - 'ReportComputedTask': rct, - } - with patch('golem.task.tasksession.get_task_message', - side_effect=lambda **kwargs: - history_dict[kwargs['message_class_name']]): - self.task_session.result_received( - ttc.compute_task_def.get('subtask_id'), # noqa pylint:disable=no-member - pickle.dumps({'stdout': 'xyz'}), - ) - - assert self.task_session.send.called - sra = self.task_session.send.call_args[0][0] # noqa pylint:disable=unsubscriptable-object - self.assertIsInstance(sra.task_to_compute, message.tasks.TaskToCompute) - self.assertIsInstance(sra.report_computed_task, - message.tasks.ReportComputedTask) - self.assertTrue(sra.task_to_compute.sig) - self.assertTrue( - sra.task_to_compute.verify_signature( - self.requestor_keys.raw_pubkey - ) - ) - +@patch("golem.task.tasksession.TaskSession.verify_owners", return_value=True) +@patch("golem.network.transport.msg_queue.put") class ReportComputedTaskTest( ConcentMessageMixin, LogTestCase, @@ -1190,7 +961,6 @@ def setUp(self): self.subtask_id = idgenerator.generate_id_from_hex(self.node_id) ts = TaskSession(Mock()) - ts.result_received = Mock() ts.key_id = "ABC" ts.task_manager.get_node_id_for_subtask.return_value = ts.key_id ts.task_manager.subtask2task_mapping = { @@ -1217,13 +987,15 @@ def setUp(self): self.addCleanup(gsam.stop) def _prepare_report_computed_task(self, **kwargs): - return msg_factories.tasks.ReportComputedTaskFactory( + msg = msg_factories.tasks.ReportComputedTaskFactory( task_to_compute__task_id=self.task_id, task_to_compute__subtask_id=self.subtask_id, **kwargs, ) + msg._fake_sign() + return msg - def test_result_received(self): + def test_result_received(self, *_): msg = self._prepare_report_computed_task() self.ts.task_manager.task_result_manager.pull_package = \ self._create_pull_package(True) @@ -1237,7 +1009,7 @@ def test_result_received(self): self.assert_concent_cancel( cancel.call_args[0], self.subtask_id, 'ForceGetTaskResult') - def test_reject_result_pull_failed_no_concent(self): + def test_reject_result_pull_failed_no_concent(self, *_): msg = self._prepare_report_computed_task( task_to_compute__concent_enabled=False) @@ -1248,12 +1020,14 @@ def test_reject_result_pull_failed_no_concent(self): with patch('golem.task.tasksession.get_task_message', return_value=msg): with patch('golem.network.concent.helpers.' 'process_report_computed_task', - return_value=message.tasks.AckReportComputedTask()): + return_value=message.tasks.AckReportComputedTask( + report_computed_task=msg, + )): self.ts._react_to_report_computed_task(msg) - assert self.ts.task_server.reject_result.called - assert self.ts.task_manager.task_computation_failure.called + self.ts.task_server.send_result_rejected.assert_called_once() + self.ts.task_manager.task_computation_failure.assert_called_once() - def test_reject_result_pull_failed_with_concent(self): + def test_reject_result_pull_failed_with_concent(self, *_): msg = self._prepare_report_computed_task( task_to_compute__concent_enabled=True) @@ -1299,6 +1073,7 @@ def setUp(self): ) self.task_session = TaskSession(conn) self.task_session.task_server.config_desc.key_difficulty = 1 + self.task_session.task_server.sessions = {} @patch('golem.task.tasksession.TaskSession.send_hello') def test_positive(self, mock_hello, *_): @@ -1375,3 +1150,54 @@ def test_react_to_hello_key_difficult(self, mock_hello, *_): self.task_session._react_to_hello(self.msg) # then mock_hello.assert_called_once_with() + + +class TestDisconnect(TestCase): + def setUp(self): + addr = twisted.internet.address.IPv4Address( + type='TCP', + host=fake.ipv4(), + port=fake.random_int(min=1, max=2**16-1), + ) + conn = MagicMock( + transport=MagicMock( + getPeer=MagicMock(return_value=addr), + ), + ) + self.task_session = TaskSession(conn) + + def test_unverified_without_key_id(self, *_): + self.assertIsNone(self.task_session.key_id) + self.assertFalse(self.task_session.verified) + self.task_session.disconnect( + message.base.Disconnect.REASON.NoMoreMessages, + ) + + +@patch('golem.task.tasksession.TaskSession._cannot_assign_task') +class TestOfferChosen(TestCase): + def setUp(self): + addr = twisted.internet.address.IPv4Address( + type='TCP', + host=fake.ipv4(), + port=fake.random_int(min=1, max=2**16-1), + ) + conn = MagicMock( + transport=MagicMock( + getPeer=MagicMock(return_value=addr), + ), + ) + self.task_session = TaskSession(conn) + self.msg = msg_factories.tasks.WantToComputeTaskFactory() + + def test_ctd_is_none(self, mock_cat, *_): + self.task_session.task_manager.get_next_subtask.return_value = None + self.task_session._offer_chosen( + msg=self.msg, + node_id='deadbeef', + is_chosen=True, + ) + mock_cat.assert_called_once_with( + self.msg.task_id, + message.tasks.CannotAssignTask.REASON.NoMoreSubtasks, + )