Skip to content

Commit dad4df7

Browse files
Make serialization contain comm_id to respect jupyter comm handlers
1 parent 64114f0 commit dad4df7

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/compiler/util/serializedCompiledScript.kt

+3-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ data class SerializedVariablesState(
5252

5353
@Serializable
5454
class SerializationReply(
55-
val cellId: Int = 1,
56-
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
55+
val cell_id: Int = 1,
56+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
57+
val comm_id: String = ""
5758
)
5859

5960
@Serializable

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

+5-3
Original file line numberDiff line numberDiff line change
@@ -561,13 +561,15 @@ class SerializationRequest(
561561
val cellId: Int,
562562
val descriptorsState: Map<String, SerializedVariablesState>,
563563
val topLevelDescriptorName: String = "",
564-
val pathToDescriptor: List<String> = emptyList()
564+
val pathToDescriptor: List<String> = emptyList(),
565+
val commId: String = ""
565566
) : MessageContent()
566567

567568
@Serializable
568569
class SerializationReply(
569-
val cellId: Int = 1,
570-
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
570+
val cell_id: Int = 1,
571+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
572+
val comm_id: String = ""
571573
) : MessageContent()
572574

573575
@Serializable(MessageDataSerializer::class)

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

+6-4
Original file line numberDiff line numberDiff line change
@@ -307,21 +307,23 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
307307
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_INFO_REPLY, content = CommInfoReply(mapOf())))
308308
}
309309
is CommOpen -> {
310-
if (!content.commId.equals(MessageType.SERIALIZATION_REQUEST.name, ignoreCase = true)) {
310+
if (!content.targetName.equals("kotlin_serialization", ignoreCase = true)) {
311311
send(makeReplyMessage(msg, MessageType.NONE))
312312
return
313313
}
314314
log.debug("Message type in CommOpen: $msg, ${msg.type}")
315315
val data = content.data ?: return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
316-
316+
if (data.isEmpty()) return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
317+
log.debug("Message data: $data")
317318
val messageContent = getVariablesDescriptorsFromJson(data)
318319
GlobalScope.launch(Dispatchers.Default) {
319320
repl.serializeVariables(
320321
messageContent.topLevelDescriptorName,
321322
messageContent.descriptorsState,
323+
content.commId,
322324
messageContent.pathToDescriptor
323325
) { result ->
324-
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_OPEN, content = result))
326+
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_MSG, content = result))
325327
}
326328
}
327329
}
@@ -342,7 +344,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
342344
is SerializationRequest -> {
343345
GlobalScope.launch(Dispatchers.Default) {
344346
if (content.topLevelDescriptorName.isNotEmpty()) {
345-
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, content.pathToDescriptor) { result ->
347+
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, commID = content.commId, content.pathToDescriptor) { result ->
346348
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
347349
}
348350
} else {

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

+5-5
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ interface ReplForJupyter {
134134

135135
suspend fun serializeVariables(cellId: Int, topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
136136

137-
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String> = emptyList(),
137+
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String = "", pathToDescriptor: List<String> = emptyList(),
138138
callback: (SerializationReply) -> Unit)
139139

140140
val homeDir: File?
@@ -552,9 +552,8 @@ class ReplForJupyterImpl(
552552
doWithLock(SerializationArgs(descriptorsState, cellId = cellId, topLevelVarName = topLevelVarName, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
553553
}
554554

555-
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String>,
556-
callback: (SerializationReply) -> Unit) {
557-
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
555+
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String, pathToDescriptor: List<String>, callback: (SerializationReply) -> Unit) {
556+
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, comm_id = commID ,pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
558557
}
559558

560559
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
@@ -569,7 +568,7 @@ class ReplForJupyterImpl(
569568
}
570569
log.debug("Serialization cellID: $cellId")
571570
log.debug("Serialization answer: ${resultMap.entries.first().value.fieldDescriptor}")
572-
return SerializationReply(cellId, resultMap)
571+
return SerializationReply(cellId, resultMap, args.comm_id)
573572
}
574573

575574

@@ -610,6 +609,7 @@ class ReplForJupyterImpl(
610609
var cellId: Int = -1,
611610
val topLevelVarName: String = "",
612611
val pathToDescriptor: List<String> = emptyList(),
612+
val comm_id: String = "",
613613
override val callback: (SerializationReply) -> Unit
614614
) : LockQueueArgs<SerializationReply>
615615

0 commit comments

Comments
 (0)