diff --git a/src/sora.cpp b/src/sora.cpp index 517e2a7..d4966e8 100644 --- a/src/sora.cpp +++ b/src/sora.cpp @@ -208,6 +208,14 @@ std::shared_ptr Sora::CreateConnection( if (video_frame_transformer) { conn->SetVideoSenderFrameTransformer(video_frame_transformer); } + + weak_connections_.erase( + std::remove_if( + weak_connections_.begin(), weak_connections_.end(), + [](std::weak_ptr w) { return w.expired(); }), + weak_connections_.end()); + weak_connections_.push_back(conn); + return conn; } diff --git a/src/sora.h b/src/sora.h index 3ec8bcb..bef0bb6 100644 --- a/src/sora.h +++ b/src/sora.h @@ -159,6 +159,8 @@ class Sora : public DisposePublisher { */ SoraVideoSource* CreateVideoSource(); + std::vector> weak_connections_; + private: /** * Python で渡された値を boost::json::value に変換します。 diff --git a/src/sora_sdk_ext.cpp b/src/sora_sdk_ext.cpp index b08b029..c3d0b40 100644 --- a/src/sora_sdk_ext.cpp +++ b/src/sora_sdk_ext.cpp @@ -288,6 +288,38 @@ PyType_Slot connection_slots[] = { {Py_tp_clear, (void*)connection_tp_clear}, {0, nullptr}}; +int sora_tp_traverse(PyObject* self, visitproc visit, void* arg) { + if (!nb::inst_ready(self)) { + return 0; + } + + Sora* sora = nb::inst_ptr(self); + for (auto wc : sora->weak_connections_) { + auto conn = wc.lock(); + if (conn) { + nb::object conn_obj = nb::find(conn); + Py_VISIT(conn_obj.ptr()); + } + } + + return 0; +} + +int sora_tp_clear(PyObject* self) { + if (!nb::inst_ready(self)) { + return 0; + } + + Sora* sora = nb::inst_ptr(self); + sora->weak_connections_.clear(); + + return 0; +} + +PyType_Slot sora_slots[] = {{Py_tp_traverse, (void*)sora_tp_traverse}, + {Py_tp_clear, (void*)sora_tp_clear}, + {0, nullptr}}; + /** * Python で利用するすべてのクラスと定数は以下のように定義しなければならない */ @@ -529,7 +561,7 @@ NB_MODULE(sora_sdk_ext, m) { .def("__del__", &SoraVideoFrameTransformer::Del) .def_rw("on_transform", &SoraVideoFrameTransformer::on_transform_); - nb::class_(m, "Sora") + nb::class_(m, "Sora", nb::type_slots(sora_slots)) .def(nb::init, std::optional>(), "use_hardware_encoder"_a = nb::none(), "openh264"_a = nb::none()) .def("create_connection", &Sora::CreateConnection, "signaling_urls"_a, diff --git a/test_with_llvm.py b/test_with_llvm.py index 77f28c6..59e371c 100644 --- a/test_with_llvm.py +++ b/test_with_llvm.py @@ -8,7 +8,9 @@ def test(debugger, command, result, internal_dict): debugger.HandleCommand("settings set target.process.follow-fork-mode child") target = debugger.CreateTargetWithFileAndArch("uv", lldb.LLDB_ARCH_DEFAULT) - process = target.LaunchSimple(["run", "pytest", "tests", "-s"], None, None) + process = target.LaunchSimple( + ["run", "pytest", "tests/test_sora_disconnect.py", "-s"], None, None + ) if not process: print("Error: could not launch process")