Skip to content

Commit

Permalink
Fix: Inherit from traitlets.HasTraits
Browse files Browse the repository at this point in the history
  • Loading branch information
martinRenou committed Nov 23, 2022
1 parent 27d14fe commit 0a306ce
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
11 changes: 7 additions & 4 deletions comm/base_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import uuid
import logging

from traitlets import HasTraits
from traitlets.utils.importstring import import_item

logger = logging.getLogger('Comm')



class BaseComm:
class BaseComm(HasTraits):
"""Class for communicating between a Frontend and a Kernel
Must be subclassed with a publish_msg method implementation which
Expand Down Expand Up @@ -92,7 +93,6 @@ def open(self, data=None, metadata=None, buffers=None):

def close(self, data=None, metadata=None, buffers=None, deleting=False):
"""Close the frontend-side version of this comm"""
from comm import get_comm_manager
if self._closed:
# only close once
return
Expand All @@ -107,6 +107,7 @@ def close(self, data=None, metadata=None, buffers=None, deleting=False):
)
if not deleting:
# If deleting, the comm can't be registered
from comm import get_comm_manager
get_comm_manager().unregister_comm(self)

def send(self, data=None, metadata=None, buffers=None):
Expand Down Expand Up @@ -160,15 +161,17 @@ def handle_msg(self, msg):
shell.events.trigger("post_execute")


class CommManager:
class CommManager(HasTraits):
"""Default CommManager singleton implementation for Comms in the Kernel"""

# Public APIs

def __init__(self):
def __init__(self, *args, **kwargs):
self.comms = {}
self.targets = {}

super(CommManager, self).__init__(*args, **kwargs)

def register_target(self, target_name, f):
"""Register a callable f for a given target name
Expand Down
14 changes: 13 additions & 1 deletion tests/test_comm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from traitlets import HasTraits, Instance, Any

from comm.base_comm import CommManager, BaseComm


Expand All @@ -7,11 +9,21 @@ def publish_msg(self, msg_type, data=None, metadata=None, buffers=None, **keys):
pass


class CustomCommManager(CommManager):

parent = Any()


def test_comm_manager():
test = CommManager()
assert test.targets == {}


def test_base_comm():
test = MyComm()
assert test.target_name == "comm"
assert test.target_name == "comm"


def test_custom_comm_manager():
test = CustomCommManager(parent=None)
assert type(test) is CustomCommManager

0 comments on commit 0a306ce

Please # to comment.