Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Commit

Permalink
fix: missing pieces in #400 (#405)
Browse files Browse the repository at this point in the history
* Fix a regression in session list graph queries
  • Loading branch information
achimnol authored Mar 19, 2021
1 parent e4b2597 commit b929ba5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 30 deletions.
1 change: 1 addition & 0 deletions changes/405.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Follow-up fixes for #400
3 changes: 1 addition & 2 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
ComputeContainerList,
LegacyComputeSession,
LegacyComputeSessionList,
KernelStatus,
)
from .keypair import (
KeyPair,
Expand Down Expand Up @@ -1043,7 +1042,7 @@ async def resolve_compute_session_list(
domain_name: str = None,
group_id: uuid.UUID = None,
access_key: AccessKey = None,
status: KernelStatus = None,
status: str = None,
order_key: str = None,
order_asc: bool = True,
) -> ComputeSessionList:
Expand Down
56 changes: 28 additions & 28 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ class Meta:
dependencies = graphene.List(lambda: ComputeSession)

@classmethod
def parse_row(cls, context: Any, row: RowProxy) -> Mapping[str, Any]:
def parse_row(cls, ctx: GraphQueryContext, row: RowProxy) -> Mapping[str, Any]:
assert row is not None
return {
# identity
Expand Down Expand Up @@ -746,10 +746,10 @@ def parse_row(cls, context: Any, row: RowProxy) -> Mapping[str, Any]:
}

@classmethod
def from_row(cls, context: Mapping[str, Any], row: RowProxy) -> Optional[ComputeSession]:
def from_row(cls, ctx: GraphQueryContext, row: RowProxy) -> ComputeSession | None:
if row is None:
return None
props = cls.parse_row(context, row)
props = cls.parse_row(ctx, row)
return cls(**props)

async def resolve_occupied_slots(self, info: graphene.ResolveInfo) -> Mapping[str, Any]:
Expand Down Expand Up @@ -787,18 +787,18 @@ async def resolve_dependencies(
@classmethod
async def load_count(
cls,
context,
ctx: GraphQueryContext,
*,
domain_name=None,
group_id=None,
access_key=None,
status=None,
domain_name: str = None,
group_id: uuid.UUID = None,
access_key: str = None,
status: str = None,
) -> int:
if isinstance(status, str):
status_list = [KernelStatus[s] for s in status.split(',')]
elif isinstance(status, KernelStatus):
status_list = [status]
async with context['dbpool'].acquire() as conn:
async with ctx.dbpool.acquire() as conn:
query = (
sa.select([sa.func.count(kernels.c.id)])
.select_from(kernels)
Expand All @@ -819,22 +819,22 @@ async def load_count(
@classmethod
async def load_slice(
cls,
context,
ctx: GraphQueryContext,
limit: int,
offset: int,
*,
domain_name=None,
group_id=None,
access_key=None,
status=None,
order_key=None,
order_asc=None,
) -> Sequence[Optional[ComputeSession]]:
domain_name: str = None,
group_id: uuid.UUID = None,
access_key: str = None,
status: str = None,
order_key: str = None,
order_asc: bool = True,
) -> Sequence[ComputeSession | None]:
if isinstance(status, str):
status_list = [KernelStatus[s] for s in status.split(',')]
elif isinstance(status, KernelStatus):
status_list = [status]
async with context['dbpool'].acquire() as conn:
async with ctx.dbpool.acquire() as conn:
if order_key is None:
_ordering = DEFAULT_SESSION_ORDERING
else:
Expand Down Expand Up @@ -865,15 +865,15 @@ async def load_slice(
query = query.where(kernels.c.access_key == access_key)
if status is not None:
query = query.where(kernels.c.status.in_(status_list))
return [cls.from_row(context, r) async for r in conn.execute(query)]
return [cls.from_row(ctx, r) async for r in conn.execute(query)]

@classmethod
async def batch_load_by_dependency(
cls,
context,
ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
) -> Sequence[Sequence[ComputeSession]]:
async with context['dbpool'].acquire() as conn:
async with ctx.dbpool.acquire() as conn:
j = sa.join(
kernels, session_dependencies,
kernels.c.session_id == session_dependencies.c.depends_on,
Expand All @@ -887,20 +887,20 @@ async def batch_load_by_dependency(
)
)
return await batch_multiresult(
context, conn, query, cls,
ctx, conn, query, cls,
session_ids, lambda row: row['id'],
)

@classmethod
async def batch_load_detail(
cls,
context,
ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
*,
domain_name=None,
access_key=None,
) -> Sequence[Optional[ComputeSession]]:
async with context['dbpool'].acquire() as conn:
domain_name: str = None,
access_key: str = None,
) -> Sequence[ComputeSession | None]:
async with ctx.dbpool.acquire() as conn:
j = (
kernels
.join(groups, groups.c.id == kernels.c.group_id)
Expand All @@ -922,7 +922,7 @@ async def batch_load_detail(
if access_key is not None:
query = query.where(kernels.c.access_key == access_key)
return await batch_result(
context, conn, query, cls,
ctx, conn, query, cls,
session_ids, lambda row: row['id'],
)

Expand Down

0 comments on commit b929ba5

Please sign in to comment.