Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

fix: Avoid TooManySessions error when there is an exact match, not only prefix matches #506

Merged
merged 7 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/506.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix too many session matching problem when exectue app by session_name
36 changes: 30 additions & 6 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,17 @@ async def match_session_ids(
)
if extra_cond is not None:
cond_id = cond_id & extra_cond
cond_name = (
cond_equal_name = (
(kernels.c.session_name == (f'{session_name_or_id}')) &
(kernels.c.access_key == access_key)
)
cond_prefix_name = (
(kernels.c.session_name.like(f'{session_name_or_id}%')) &
(kernels.c.access_key == access_key)
)
if extra_cond is not None:
cond_name = cond_name & extra_cond
cond_equal_name = cond_equal_name & extra_cond
cond_prefix_name = cond_prefix_name & extra_cond
cond_session_id = (
(sa.sql.expression.cast(kernels.c.session_id, sa.String).like(f'{session_name_or_id}%')) &
(kernels.c.access_key == access_key)
Expand Down Expand Up @@ -358,7 +363,24 @@ async def match_session_ids(
)
if for_update:
match_sid_by_id = match_sid_by_id.with_for_update()
match_sid_by_name = (
match_sid_by_equal_name = (
sa.select(info_cols)
.select_from(kernels)
.where(
(kernels.c.session_id.in_(
sa.select(
[kernels.c.session_id],
)
.select_from(kernels)
.where(cond_equal_name)
.group_by(kernels.c.session_id)
.limit(max_matches).offset(0),
)) &
(kernels.c.cluster_role == DEFAULT_ROLE),
)
.order_by(sa.desc(kernels.c.created_at))
)
match_sid_by_prefix_name = (
sa.select(info_cols)
.select_from(kernels)
.where(
Expand All @@ -367,7 +389,7 @@ async def match_session_ids(
[kernels.c.session_id],
)
.select_from(kernels)
.where(cond_name)
.where(cond_prefix_name)
.group_by(kernels.c.session_id)
.limit(max_matches).offset(0),
)) &
Expand All @@ -376,7 +398,8 @@ async def match_session_ids(
.order_by(sa.desc(kernels.c.created_at))
)
if for_update:
match_sid_by_name = match_sid_by_name.with_for_update()
match_sid_by_equal_name = match_sid_by_equal_name.with_for_update()
match_sid_by_prefix_name = match_sid_by_prefix_name.with_for_update()
match_sid_by_session_id = (
sa.select(info_cols)
.select_from(kernels)
Expand All @@ -398,7 +421,8 @@ async def match_session_ids(
match_sid_by_session_id = match_sid_by_session_id.with_for_update()
for match_query in [
match_sid_by_session_id,
match_sid_by_name,
match_sid_by_equal_name,
match_sid_by_prefix_name,
match_sid_by_id,
]:
result = await db_connection.execute(match_query)
Expand Down