|
1 | 1 | import importlib
|
2 | 2 | import sys
|
| 3 | +from contextlib import contextmanager |
3 | 4 | from dataclasses import dataclass
|
4 | 5 | from logging import getLogger
|
5 | 6 | from pathlib import Path
|
@@ -46,6 +47,18 @@ class ModuleData:
|
46 | 47 | module_import_str: str
|
47 | 48 | extra_sys_path: Path
|
48 | 49 |
|
| 50 | + @contextmanager |
| 51 | + def sys_path(self): |
| 52 | + """ Context manager to temporarily alter sys.path""" |
| 53 | + extra_sys_path = str(self.extra_sys_path) if self.extra_sys_path else "" |
| 54 | + if extra_sys_path: |
| 55 | + logger.warning("Adding %s to sys.path...", extra_sys_path) |
| 56 | + sys.path.insert(0, extra_sys_path) |
| 57 | + yield |
| 58 | + if extra_sys_path and sys.path and sys.path[0] == extra_sys_path: |
| 59 | + logger.warning("Removing %s from sys.path...", extra_sys_path) |
| 60 | + sys.path.pop(0) |
| 61 | + |
49 | 62 |
|
50 | 63 | def get_module_data_from_path(path: Path) -> ModuleData:
|
51 | 64 | logger.info(
|
@@ -165,3 +178,54 @@ def get_import_string(
|
165 | 178 | import_string = f"{mod_data.module_import_str}:{use_app_name}"
|
166 | 179 | logger.info(f"Using import string [b green]{import_string}[/b green]")
|
167 | 180 | return import_string
|
| 181 | + |
| 182 | +def get_app( |
| 183 | + *, path: Union[Path, None] = None, app_name: Union[str, None] = None |
| 184 | +) -> FastAPI: |
| 185 | + if not path: |
| 186 | + path = get_default_path() |
| 187 | + logger.debug(f"Using path [blue]{path}[/blue]") |
| 188 | + logger.debug(f"Resolved absolute path {path.resolve()}") |
| 189 | + if not path.exists(): |
| 190 | + raise FastAPICLIException(f"Path does not exist {path}") |
| 191 | + mod_data = get_module_data_from_path(path) |
| 192 | + try: |
| 193 | + with mod_data.sys_path(): |
| 194 | + mod = importlib.import_module(mod_data.module_import_str) |
| 195 | + except (ImportError, ValueError) as e: |
| 196 | + logger.error(f"Import error: {e}") |
| 197 | + logger.warning( |
| 198 | + "Ensure all the package directories have an [blue]__init__.py[" |
| 199 | + "/blue] file" |
| 200 | + ) |
| 201 | + raise |
| 202 | + if not FastAPI: # type: ignore[truthy-function] |
| 203 | + raise FastAPICLIException( |
| 204 | + "Could not import FastAPI, try running 'pip install fastapi'" |
| 205 | + ) from None |
| 206 | + object_names = dir(mod) |
| 207 | + object_names_set = set(object_names) |
| 208 | + if app_name: |
| 209 | + if app_name not in object_names_set: |
| 210 | + raise FastAPICLIException( |
| 211 | + f"Could not find app name {app_name} in " |
| 212 | + f"{mod_data.module_import_str}" |
| 213 | + ) |
| 214 | + app = getattr(mod, app_name) |
| 215 | + if not isinstance(app, FastAPI): |
| 216 | + raise FastAPICLIException( |
| 217 | + f"The app name {app_name} in {mod_data.module_import_str} " |
| 218 | + f"doesn't seem to be a FastAPI app" |
| 219 | + ) |
| 220 | + return app |
| 221 | + for preferred_name in ["app", "api"]: |
| 222 | + if preferred_name in object_names_set: |
| 223 | + obj = getattr(mod, preferred_name) |
| 224 | + if isinstance(obj, FastAPI): |
| 225 | + return obj |
| 226 | + for name in object_names: |
| 227 | + obj = getattr(mod, name) |
| 228 | + if isinstance(obj, FastAPI): |
| 229 | + return obj |
| 230 | + raise FastAPICLIException( |
| 231 | + "Could not find FastAPI app in module, try using --app") |
0 commit comments