Skip to content

Commit

Permalink
apply context exit logit only when exiting topmost context,
Browse files Browse the repository at this point in the history
also
- make it easier to dump storage configuration
- add backoff to transaction
  • Loading branch information
amakelov committed Jan 14, 2025
1 parent b82e3a9 commit af015e3
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 15 deletions.
65 changes: 57 additions & 8 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self,
db_path: str = ":memory:",
overflow_dir: Optional[str] = None,
overflow_threshold_MB: Optional[Union[int, float]] = 50.0,
#! versioning config. this is too much...
deps_path: Optional[Union[str, Path]] = None,
tracer_impl: Optional[type] = None,
strict_tracing: bool = False,
Expand Down Expand Up @@ -76,6 +77,15 @@ def __init__(self,
else:
current_versioner = self.sources['versioner']

self._deps_path = deps_path
self._tracer_impl = tracer_impl
self._strict_tracing = strict_tracing
self._skip_unhashable_globals = skip_unhashable_globals
self._skip_globals_silently = skip_globals_silently
self._skip_missing_deps = skip_missing_deps
self._skip_missing_silently = skip_missing_silently
self._deps_package = deps_package
self._track_globals = track_globals
if deps_path is not None:
deps_path = (
Path(deps_path).absolute().resolve()
Expand Down Expand Up @@ -118,6 +128,22 @@ def __init__(self,
self._mode_stack = []
self._next_mode = 'run'

def dump_config(self) -> dict[str, Any]:
return {
"db_path": self.db.db_path,
"overflow_dir": self.overflow_dir,
"overflow_threshold_MB": self.overflow_threshold_MB,
"deps_path": self._deps_path,
"tracer_impl": self._tracer_impl,
"strict_tracing": self._strict_tracing,
"skip_unhashable_globals": self._skip_unhashable_globals,
"skip_globals_silently": self._skip_globals_silently,
"skip_missing_deps": self._skip_missing_deps,
"skip_missing_silently": self._skip_missing_silently,
"deps_package": self._deps_package,
"track_globals": self._track_globals,
}

@property
def mode(self) -> str:
return self._mode_stack[-1] if self._mode_stack else 'run'
Expand Down Expand Up @@ -1177,22 +1203,45 @@ def __enter__(self) -> "Storage":
self.code_state = code_state
return self

# def __exit__(self, exc_type, exc_value, traceback) -> None:
# Context.current_context = None
# try:
# self.commit()
# except Exception as e:
# raise e
# finally:
# self.cached_versioner = None
# self.code_state = None
# _ = self._mode_stack.pop()
# if self._mode_stack:
# self._next_mode = self._mode_stack[-1]
# else:
# self._next_mode = 'run'
# for hook in self._exit_hooks:
# hook(self)

def __exit__(self, exc_type, exc_value, traceback) -> None:
Context.current_context = None
# Context.current_context = None
try:
self.commit()
if len(self._mode_stack) == 1: # this is the topmost level
self.commit()
except Exception as e:
raise e
finally:
self.cached_versioner = None
self.code_state = None
if len(self._mode_stack) == 1: # this is the topmost level
Context.current_context = None
self.cached_versioner = None
self.code_state = None

for hook in self._exit_hooks:
hook(self)

_ = self._mode_stack.pop()
if self._mode_stack:
self._next_mode = self._mode_stack[-1]
else:
self._next_mode = 'run'
for hook in self._exit_hooks:
hook(self)



class noop:
Expand All @@ -1206,9 +1255,9 @@ class noop:
def __init__(self,):
pass

def __enter__(self) -> "Storage":
def __enter__(self) -> Optional["Storage"]:
if Context.current_context is None:
return self
return None
storage = Context.current_context.storage
res = storage(mode='noop')
return res.__enter__()
Expand Down
38 changes: 31 additions & 7 deletions mandala/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,48 @@ def wrapper(self, *args, **kwargs):
if kwargs.get("conn") is not None: # already in a transaction
logging.debug("Folding into existing transaction")
return method(self, *args, **kwargs)
else: # open a connection
logging.debug(
f"Opening new transaction from {self.__class__.__name__}.{method.__name__}"
)

# 10 attempts with exponential backoff, max. time is ~17 minutes
max_attempts = 10
base_delay = 1.0 # in seconds

for attempt in range(max_attempts):
conn = self.conn()
try:
conn.execute("BEGIN IMMEDIATE")
res = method(self, *args, conn=conn, **kwargs)
conn.commit()
return res
except sqlite3.OperationalError as e:
delay = base_delay * (2 ** attempt)
logging.info(f'Transaction failed with error: {e}. Retrying in {delay:.2f} seconds...')
time.sleep(delay)
continue
except Exception as e:
conn.rollback()
raise e
finally:
if not is_in_memory_db(conn):
conn.close()
else:
# in-memory databases are kept open
pass
raise sqlite3.OperationalError("Max retry attempts reached")
# else: # open a connection
# logging.debug(
# f"Opening new transaction from {self.__class__.__name__}.{method.__name__}"
# )
# conn = self.conn()
# try:
# res = method(self, *args, conn=conn, **kwargs)
# conn.commit()
# return res
# except Exception as e:
# conn.rollback()
# raise e
# finally:
# if not is_in_memory_db(conn):
# conn.close()
# else:
# # in-memory databases are kept open
# pass
return wrapper


Expand Down

0 comments on commit af015e3

Please # to comment.