diff --git a/jaxtyping/_ipython_extension.py b/jaxtyping/_ipython_extension.py index 7960b2f..ff5219b 100644 --- a/jaxtyping/_ipython_extension.py +++ b/jaxtyping/_ipython_extension.py @@ -20,7 +20,9 @@ from ._import_hook import JaxtypingTransformer, Typechecker -try: +def choose_typechecker_magics(): + # The import is local to avoid degrading import times when the magic is + # not needed. from IPython.core.magic import line_magic, Magics, magics_class @magics_class @@ -40,17 +42,15 @@ def typechecker(self, typechecker): JaxtypingTransformer(typechecker=Typechecker(typechecker)) ) -except Exception: - # Very broad exception-handling, as e.g. IPython will sometimes be - # present but fail to import for mysterious reasons. - pass + return ChooseTypecheckerMagics def load_ipython_extension(ipython): try: - ipython.register_magics(ChooseTypecheckerMagics) - except NameError: - raise NameError( - "ChooseTypecheckerMagics is not defined.\n\n" - + "You may be trying to use IPython extension without IPython installed." - ) + ChooseTypecheckerMagics = choose_typechecker_magics() + except Exception as e: + # Very broad exception-handling, as e.g. IPython will sometimes be + # present but fail to import for mysterious reasons. + raise RuntimeError("Failed to define jaxtyping.typechecker magic") from e + + ipython.register_magics(ChooseTypecheckerMagics)