diff --git a/solara/server/kernel.py b/solara/server/kernel.py index c46b6108b5..dcc9f1fc1d 100644 --- a/solara/server/kernel.py +++ b/solara/server/kernel.py @@ -6,6 +6,7 @@ import struct from typing import Set +import ipykernel import ipykernel.kernelbase import jupyter_client.session as session from ipykernel.comm import CommManager @@ -17,6 +18,33 @@ logger = logging.getLogger("solara.server.kernel") + +ipykernel_version = tuple(map(int, ipykernel.__version__.split("."))) +if ipykernel_version >= (6, 18, 0): + import comm.base_comm + + class Comm(comm.base_comm.BaseComm): + # log = logging.getLogger("Comm") + + def __init__(self, **kwargs) -> None: + self.kernel = ipykernel.kernelbase.Kernel.instance() + super().__init__(**kwargs) + + def publish_msg(self, msg_type, data=None, metadata=None, buffers=None, **keys): + data = {} if data is None else data + metadata = {} if metadata is None else metadata + content = dict(data=data, comm_id=self.comm_id, **keys) + self.kernel.session.send( + self.kernel.iopub_socket, + msg_type, + content, + metadata=metadata, + parent=self.kernel.get_parent("shell"), + ident=self.topic, + buffers=buffers, + ) + + comm.create_comm = Comm # from notebook.base.zmqhandlers import serialize_binary_message # this saves us a depdendency on notebook/jupyter_server when e.g. # running on pyodide @@ -160,7 +188,11 @@ def __init__(self): # solara/server/kernel.py:111: error: "SessionWebsocket" has no attribute "stream" # not sure why we cannot reproduce that locally self.session.stream = self.iopub_socket # type: ignore - self.comm_manager = CommManager(parent=self, kernel=self) + if ipykernel_version >= (6, 18, 0): + # from this version on, ipykernel uses the comm package https://github.com/ipython/ipykernel/pull/973 + self.comm_manager = CommManager() + else: + self.comm_manager = CommManager(parent=self, kernel=self) self.shell = None self.log = logging.getLogger("fake")