From 2d9513f009544795a13aa0f27647196382996a44 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Sat, 30 Nov 2024 14:25:26 +0000 Subject: [PATCH] refactor: re-implement pygls' builtin handlers using generators The underlying cause of #433 is that pygls' current implementation of builtin feature handlers cannot guarantee that an async user handler will finish executing before pygls responds with the answer generated from the builtin handler. This commit adds support for another execution model, generators. A generator handler can yield to another sub-handler method like so ``` yield handler_func, args, kwargs ``` The `JsonRPCProtocol` class with then schedule the execution of `handler_func(*args, **kwargs)` as if it were a normal handler function (meaning `handler_func could be async, threaded, sync or a generator itself!) The result of the sub-handler is then sent back into the generator handler allowing the top-level handler to continue and even make use of the result! This gives pygls' built-in handlers much greater control over exactly when a user handler is called, allowing us to fix #433 and opens up a lot other exciting possibilities! This also removes the need for the `LSPMeta` metaclass, so it and the corresponding module have been deleted. --- pygls/protocol/__init__.py | 6 +- pygls/protocol/json_rpc.py | 113 ++++++++++++++++++++++------ pygls/protocol/language_server.py | 118 ++++++++++++++++++++---------- pygls/protocol/lsp_meta.py | 51 ------------- tests/test_protocol.py | 31 -------- 5 files changed, 173 insertions(+), 146 deletions(-) delete mode 100644 pygls/protocol/lsp_meta.py diff --git a/pygls/protocol/__init__.py b/pygls/protocol/__init__.py index 1a30b485..6c61e877 100644 --- a/pygls/protocol/__init__.py +++ b/pygls/protocol/__init__.py @@ -1,7 +1,6 @@ import json -from typing import Any - from collections import namedtuple +from typing import Any from lsprotocol import converters @@ -12,7 +11,6 @@ JsonRPCResponseMessage, ) from pygls.protocol.language_server import LanguageServerProtocol, lsp_method -from pygls.protocol.lsp_meta import LSPMeta, call_user_feature def _dict_to_object(d: Any): @@ -68,8 +66,6 @@ def default_converter(): "JsonRPCRequestMessage", "JsonRPCResponseMessage", "JsonRPCNotification", - "LSPMeta", - "call_user_feature", "_dict_to_object", "_params_field_structure_hook", "_result_field_structure_hook", diff --git a/pygls/protocol/json_rpc.py b/pygls/protocol/json_rpc.py index a9e2cfb9..b4f68d59 100644 --- a/pygls/protocol/json_rpc.py +++ b/pygls/protocol/json_rpc.py @@ -34,7 +34,6 @@ from lsprotocol.types import ( CANCEL_REQUEST, EXIT, - WORKSPACE_EXECUTE_COMMAND, ResponseError, ResponseErrorMessage, ) @@ -51,6 +50,8 @@ from pygls.feature_manager import FeatureManager, is_thread_function if typing.TYPE_CHECKING: + from collections.abc import Generator + from cattrs import Converter from pygls.io_ import AsyncWriter, Writer @@ -130,8 +131,9 @@ def _execute_handler( self, msg_id: str | int, handler: MessageHandler, - params: Any, callback: MessageCallback, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, ): """Execute the given message handler. @@ -143,22 +145,41 @@ def _execute_handler( handler The request handler to call - params - The parameters object to pass to the handler - callback An optional callback function to call upon completion of the handler + + args + Positional arguments to pass to the handler + + kwargs + Keyword arguments to pass to the handler """ + args = args or tuple() + kwargs = kwargs or {} + if asyncio.iscoroutinefunction(handler): - future = asyncio.ensure_future(handler(params)) + future = asyncio.ensure_future(handler(*args, **kwargs)) self._request_futures[msg_id] = future future.add_done_callback(callback) elif is_thread_function(handler): - future = self._server.thread_pool.submit(handler, params) + future = self._server.thread_pool.submit(handler, *args, **kwargs) self._request_futures[msg_id] = future future.add_done_callback(callback) + + elif inspect.isgeneratorfunction(handler): + future: Future[Any] = Future() + self._request_futures[msg_id] = future + future.add_done_callback(callback) + + try: + self._run_generator( + future=None, gen=handler(*args, **kwargs), result_future=future + ) + except Exception as exc: + future.set_exception(exc) + else: # While a future is not necessary for a synchronous function, it allows us to use a single # pattern across all handler types @@ -166,11 +187,61 @@ def _execute_handler( future.add_done_callback(callback) try: - result = handler(params) + result = handler(*args, **kwargs) future.set_result(result) except Exception as exc: future.set_exception(exc) + def _run_generator( + self, + future: Future[Any] | None, + *, + gen: Generator[Any, Any, Any], + result_future: Future[Any], + ): + """Run the next portion of the given generator. + + Generator handlers are designed to ``yield`` to other handlers that are executed + separately before their results are sent back into the generator allowing + execution to continue. + + Generator handlers are primarily used in the implementation of pygls' builtin + feature handlers. + + Parameters + ---------- + future + The future that contains the result of the previously executed handler, if any + + gen + The generator to run + + result_future + The future to send the final result to once the generator stops. + """ + + if result_future.cancelled(): + return + + try: + value = future.result() if future is not None else None + handler, args, kwargs = gen.send(value) + + self._execute_handler( + str(uuid.uuid4()), + handler, + args=args, + kwargs=kwargs, + callback=partial( + self._run_generator, gen=gen, result_future=result_future + ), + ) + except StopIteration as result: + result_future.set_result(result.value) + + except Exception as exc: + result_future.set_exception(exc) + def _send_handler_result(self, future: Future[Any], *, msg_id: str | int): """Callback function that sends the result of the given future to the client. @@ -192,6 +263,7 @@ def _send_handler_result(self, future: Future[Any], *, msg_id: str | int): error = JsonRpcInternalError.of(sys.exc_info()) logger.exception('Exception occurred for message "%s": %s', msg_id, error) self._send_response(msg_id, error=error.to_response_error()) + self._server._report_server_error(error, FeatureRequestError) def _check_handler_result(self, future: Future[Any]): """Check the result of the future to see if an error occurred. @@ -237,7 +309,10 @@ def _handle_notification(self, method_name, params): try: handler = self._get_handler(method_name) self._execute_handler( - str(uuid.uuid4()), handler, params, self._check_handler_result + msg_id=str(uuid.uuid4()), + handler=handler, + args=(params,), + callback=self._check_handler_result, ) except JsonRpcMethodNotFound: logger.warning("Ignoring notification for unknown method %r", method_name) @@ -255,16 +330,12 @@ def _handle_request(self, msg_id, method_name, params): try: handler = self._get_handler(method_name) - # workspace/executeCommand is a special case - if method_name == WORKSPACE_EXECUTE_COMMAND: - handler(params, msg_id) - else: - self._execute_handler( - msg_id, - handler, - params, - callback=partial(self._send_handler_result, msg_id=msg_id), - ) + self._execute_handler( + msg_id=msg_id, + handler=handler, + args=(params,), + callback=partial(self._send_handler_result, msg_id=msg_id), + ) except JsonRpcMethodNotFound as error: logger.warning( @@ -369,10 +440,10 @@ def handle_message(self, message): if hasattr(message, "method"): if hasattr(message, "id"): - logger.debug("Request message received.") + logger.debug("Request %r received", message.method) self._handle_request(message.id, message.method, message.params) else: - logger.debug("Notification message received.") + logger.debug("Notification %r received", message.method) self._handle_notification(message.method, message.params) else: if hasattr(message, "error"): diff --git a/pygls/protocol/language_server.py b/pygls/protocol/language_server.py index 5a7bb76a..9a879924 100644 --- a/pygls/protocol/language_server.py +++ b/pygls/protocol/language_server.py @@ -22,29 +22,25 @@ import logging import sys import typing -from functools import lru_cache, partial +from functools import lru_cache from itertools import zip_longest -from typing import ( - Callable, - Optional, - Type, - TypeVar, -) from lsprotocol import types from pygls.capabilities import ServerCapabilitiesBuilder from pygls.protocol.json_rpc import JsonRPCProtocol -from pygls.protocol.lsp_meta import LSPMeta from pygls.uris import from_fs_path from pygls.workspace import Workspace if typing.TYPE_CHECKING: + from collections.abc import Generator + from typing import Any, Callable, Optional, Type, TypeVar + from cattrs import Converter from pygls.lsp.server import LanguageServer -F = TypeVar("F", bound=Callable) + F = TypeVar("F", bound=Callable) logger = logging.getLogger(__name__) @@ -57,7 +53,7 @@ def decorator(f: F) -> F: return decorator -class LanguageServerProtocol(JsonRPCProtocol, metaclass=LSPMeta): +class LanguageServerProtocol(JsonRPCProtocol): """A class that represents language server protocol. It contains implementations for generic LSP features. @@ -105,17 +101,22 @@ def workspace(self) -> Workspace: return self._workspace @lru_cache() - def get_message_type(self, method: str) -> Optional[Type]: + def get_message_type(self, method: str) -> Type[Any] | None: """Return LSP type definitions, as provided by `lsprotocol`""" return types.METHOD_TO_TYPES.get(method, (None,))[0] @lru_cache() - def get_result_type(self, method: str) -> Optional[Type]: + def get_result_type(self, method: str) -> Type[Any] | None: return types.METHOD_TO_TYPES.get(method, (None, None))[1] @lsp_method(types.EXIT) def lsp_exit(self, *args) -> None: """Stops the server process.""" + + # Ensure that the user handler is called first + if (user_handler := self.fm.features.get(types.EXIT)) is not None: + yield user_handler, args, None + returncode = 0 if self._shutdown else 1 if self.writer is None: sys.exit(returncode) @@ -176,13 +177,19 @@ def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResu ) @lsp_method(types.INITIALIZED) - def lsp_initialized(self, *args) -> None: + def lsp_initialized(self, *args): """Notification received when client and server are connected.""" - pass + + if (user_handler := self.fm.features.get(types.INITIALIZED)) is not None: + yield user_handler, args, None @lsp_method(types.SHUTDOWN) def lsp_shutdown(self, *args) -> None: """Request from client which asks server to shutdown.""" + + if (user_handler := self.fm.features.get(types.SHUTDOWN)) is not None: + yield user_handler, args, None + for future in self._request_futures.values(): future.cancel() @@ -190,59 +197,86 @@ def lsp_shutdown(self, *args) -> None: return None @lsp_method(types.TEXT_DOCUMENT_DID_CHANGE) - def lsp_text_document__did_change( - self, params: types.DidChangeTextDocumentParams - ) -> None: + def lsp_text_document__did_change(self, params: types.DidChangeTextDocumentParams): """Updates document's content. (Incremental(from server capabilities); not configurable for now) """ for change in params.content_changes: self.workspace.update_text_document(params.text_document, change) + if ( + user_handler := self.fm.features.get(types.TEXT_DOCUMENT_DID_CHANGE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.TEXT_DOCUMENT_DID_CLOSE) - def lsp_text_document__did_close( - self, params: types.DidCloseTextDocumentParams - ) -> None: + def lsp_text_document__did_close(self, params: types.DidCloseTextDocumentParams): """Removes document from workspace.""" self.workspace.remove_text_document(params.text_document.uri) + if ( + user_handler := self.fm.features.get(types.TEXT_DOCUMENT_DID_CLOSE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.TEXT_DOCUMENT_DID_OPEN) - def lsp_text_document__did_open( - self, params: types.DidOpenTextDocumentParams - ) -> None: + def lsp_text_document__did_open(self, params: types.DidOpenTextDocumentParams): """Puts document to the workspace.""" self.workspace.put_text_document(params.text_document) + if ( + user_handler := self.fm.features.get(types.TEXT_DOCUMENT_DID_OPEN) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.NOTEBOOK_DOCUMENT_DID_OPEN) def lsp_notebook_document__did_open( self, params: types.DidOpenNotebookDocumentParams - ) -> None: + ): """Put a notebook document into the workspace""" self.workspace.put_notebook_document(params) + if ( + user_handler := self.fm.features.get(types.NOTEBOOK_DOCUMENT_DID_OPEN) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.NOTEBOOK_DOCUMENT_DID_CHANGE) def lsp_notebook_document__did_change( self, params: types.DidChangeNotebookDocumentParams - ) -> None: + ): """Update a notebook's contents""" self.workspace.update_notebook_document(params) + if ( + user_handler := self.fm.features.get(types.NOTEBOOK_DOCUMENT_DID_CHANGE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.NOTEBOOK_DOCUMENT_DID_CLOSE) def lsp_notebook_document__did_close( self, params: types.DidCloseNotebookDocumentParams - ) -> None: + ): """Remove a notebook document from the workspace.""" self.workspace.remove_notebook_document(params) + if ( + user_handler := self.fm.features.get(types.NOTEBOOK_DOCUMENT_DID_CLOSE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.SET_TRACE) def lsp_set_trace(self, params: types.SetTraceParams) -> None: """Changes server trace value.""" self.trace = params.value + if (user_handler := self.fm.features.get(types.SET_TRACE)) is not None: + yield user_handler, (params,), None + @lsp_method(types.WORKSPACE_DID_CHANGE_WORKSPACE_FOLDERS) def lsp_workspace__did_change_workspace_folders( self, params: types.DidChangeWorkspaceFoldersParams - ) -> None: + ): """Adds/Removes folders from the workspace.""" logger.info("Workspace folders changed: %s", params) @@ -255,23 +289,26 @@ def lsp_workspace__did_change_workspace_folders( if f_remove: self.workspace.remove_folder(f_remove.uri) + if ( + user_handler := self.fm.features.get( + types.WORKSPACE_DID_CHANGE_WORKSPACE_FOLDERS + ) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.WORKSPACE_EXECUTE_COMMAND) def lsp_workspace__execute_command( - self, params: types.ExecuteCommandParams, msg_id: str - ) -> None: + self, params: types.ExecuteCommandParams + ) -> Generator[Any, Any, Any]: """Executes commands with passed arguments and returns a value.""" cmd_handler = self.fm.commands[params.command] - self._execute_handler( - msg_id, - cmd_handler, - params.arguments, - partial(self._send_handler_result, msg_id=msg_id), - ) + + # Call the user's command implementation + result = yield cmd_handler, (params.arguments,), None + return result @lsp_method(types.WINDOW_WORK_DONE_PROGRESS_CANCEL) - def lsp_work_done_progress_cancel( - self, params: types.WorkDoneProgressCancelParams - ) -> None: + def lsp_work_done_progress_cancel(self, params: types.WorkDoneProgressCancelParams): """Received a progress cancellation from client.""" future = self.progress.tokens.get(params.token) if future is None: @@ -280,3 +317,8 @@ def lsp_work_done_progress_cancel( ) else: future.cancel() + + if ( + user_handler := self.fm.features.get(types.WINDOW_WORK_DONE_PROGRESS_CANCEL) + ) is not None: + yield user_handler, (params,), None diff --git a/pygls/protocol/lsp_meta.py b/pygls/protocol/lsp_meta.py deleted file mode 100644 index 0dc52db0..00000000 --- a/pygls/protocol/lsp_meta.py +++ /dev/null @@ -1,51 +0,0 @@ -import functools -import logging -from pygls.constants import ATTR_FEATURE_TYPE -from pygls.feature_manager import assign_help_attrs - - -logger = logging.getLogger(__name__) - - -def call_user_feature(base_func, method_name): - """Wraps generic LSP features and calls user registered feature - immediately after it. - """ - - @functools.wraps(base_func) - def decorator(self, *args, **kwargs): - ret_val = base_func(self, *args, **kwargs) - - try: - user_func = self.fm.features[method_name] - self._execute_notification(user_func, *args, **kwargs) - except KeyError: - pass - except Exception: - logger.exception( - 'Failed to handle user defined notification "%s": %s', method_name, args - ) - - return ret_val - - return decorator - - -class LSPMeta(type): - """Wraps LSP built-in features (`lsp_` naming convention). - - Built-in features cannot be overridden but user defined features with - the same LSP name will be called after them. - """ - - def __new__(mcs, cls_name, cls_bases, cls): - for attr_name, attr_val in cls.items(): - if callable(attr_val) and hasattr(attr_val, "method_name"): - method_name = attr_val.method_name - wrapped = call_user_feature(attr_val, method_name) - assign_help_attrs(wrapped, method_name, ATTR_FEATURE_TYPE) - cls[attr_name] = wrapped - - logger.debug('Added decorator for lsp method: "%s"', attr_name) - - return super().__new__(mcs, cls_name, cls_bases, cls) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 9a526370..73d6047f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -16,22 +16,17 @@ ############################################################################ import io import json -from pathlib import Path from typing import Optional -from unittest.mock import Mock import attrs import pytest from lsprotocol.types import ( PROGRESS, TEXT_DOCUMENT_COMPLETION, - ClientCapabilities, CompletionItem, CompletionItemKind, CompletionParams, CompletionResponse, - InitializeParams, - InitializeResult, Position, ProgressParams, ShutdownResponse, @@ -521,29 +516,3 @@ def test_serialize_request_message(method, params, expected): actual = json.loads(buffer.getvalue()) assert actual == expected - - -def test_initialize_should_return_server_capabilities(client_server): - _, server = client_server - params = InitializeParams( - process_id=1234, - root_uri=Path(__file__).parent.as_uri(), - capabilities=ClientCapabilities(), - ) - - server_capabilities = server.protocol.lsp_initialize(params) - - assert isinstance(server_capabilities, InitializeResult) - - -def test_ignore_unknown_notification(client_server): - _, server = client_server - - fn = server.protocol._execute_notification - server.protocol._execute_notification = Mock() - - server.protocol._handle_notification("random/notification", None) - assert not server.protocol._execute_notification.called - - # Remove mock - server.protocol._execute_notification = fn