diff --git a/tutorials/sphinx_tuto/export.py b/tutorials/sphinx_tuto/export.py index 1d1e5f30b..df8e3fda5 100644 --- a/tutorials/sphinx_tuto/export.py +++ b/tutorials/sphinx_tuto/export.py @@ -132,6 +132,25 @@ # and the FX graph: print("fx graph:", model_export.graph_module.print_readable()) +################################################## +# Working with nested keys +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Nested keys are a core feature of the tensordict library, and being able to export modules that read and write +# nested entries is therefore an important feature to support. +# Because keyword arguments must be regualar strings, it is not possible for :class:`~tensordict.nn.dispatch` to work +# directly with them. Instead, ``dispatch`` will unpack nested keys joined with a regular underscore (`"_"`), as the +# following example shows. + +model_nested = Seq( + Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]), + Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]), +).select_out_keys(("some", "output")) + +model_nested_export = export(model_nested, args=(), kwargs={"some_key": x}) +print("exported module with nested input:", model_nested_export.module()) + + ################################################## # Note that the callable returned by `module()` is a pure python callable that can be in turn compiled using # :func:`~torch.compile`.