diff --git a/changes/537.feature b/changes/537.feature new file mode 100644 index 000000000..0e4a9ebf3 --- /dev/null +++ b/changes/537.feature @@ -0,0 +1 @@ +Allow mounting subpath of vfolders if specified as relative paths appended to vfolder names and improve storage error propagation. Also introduce `StructuredJSONObjectColumn` and `StructuredJSONObjectListColumn` to define database columns based on `common.types.JSONSerializableMixin`. diff --git a/setup.cfg b/setup.cfg index d4797f874..a178556f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ install_requires = aiohttp_sse~=2.0 aiomonitor~=0.4.5 aioredis[hiredis]~=2.0 - aiotools~=1.5.3 + aiotools~=1.5.4 alembic~=1.6.5 async_timeout~=4.0 asyncache>=0.1.1 diff --git a/src/ai/backend/manager/api/image.py b/src/ai/backend/manager/api/image.py index 4d75107eb..17a7a6d90 100644 --- a/src/ai/backend/manager/api/image.py +++ b/src/ai/backend/manager/api/image.py @@ -27,6 +27,7 @@ domains, groups, query_allowed_sgroups, association_groups_users as agus, ) +from ..types import UserScope from .auth import admin_required from .exceptions import InvalidAPIParameters from .manager import ALL_ALLOWED, READ_ALLOWED, server_status_required @@ -427,10 +428,12 @@ async def import_image(request: web.Request, params: Any) -> web.Response: None, SessionTypes.BATCH, resource_policy, - domain_name=request['user']['domain_name'], - group_id=group_id, - user_uuid=request['user']['uuid'], - user_role=request['user']['role'], + user_scope=UserScope( + domain_name=request['user']['domain_name'], + group_id=group_id, + user_uuid=request['user']['uuid'], + user_role=request['user']['role'], + ), internal_data={ 'domain_socket_proxies': ['/var/run/docker.sock'], 'docker_credentials': docker_creds, diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 33f7568fe..049b0a98e 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -89,6 +89,7 @@ from ..config import DEFAULT_CHUNK_SIZE from ..defs import DEFAULT_ROLE, REDIS_STREAM_DB +from ..types import UserScope from ..models import ( domains, association_groups_users as agus, groups, @@ -511,13 +512,14 @@ async def _create(request: web.Request, params: Any) -> web.Response: params['config']['scaling_group'], params['session_type'], resource_policy, - domain_name=params['domain'], # type: ignore # params always have it - group_id=group_id, - user_uuid=owner_uuid, - user_role=request['user']['role'], + user_scope=UserScope( + domain_name=params['domain'], # type: ignore # params always have it + group_id=group_id, + user_uuid=owner_uuid, + user_role=request['user']['role'], + ), cluster_mode=params['cluster_mode'], cluster_size=params['cluster_size'], - startup_command=params['startup_command'], session_tag=params['tag'], starts_at=starts_at, agent_list=params['config']['agent_list'], @@ -1006,10 +1008,12 @@ async def create_cluster(request: web.Request, params: Any) -> web.Response: params['scaling_group'], params['sess_type'], resource_policy, - domain_name=params['domain'], # type: ignore - group_id=group_id, - user_uuid=owner_uuid, - user_role=request['user']['role'], + user_scope=UserScope( + domain_name=params['domain'], # type: ignore + group_id=group_id, + user_uuid=owner_uuid, + user_role=request['user']['role'], + ), session_tag=params['tag'], ), )) diff --git a/src/ai/backend/manager/models/__init__.py b/src/ai/backend/manager/models/__init__.py index 23e381cec..d464244d2 100644 --- 
a/src/ai/backend/manager/models/__init__.py +++ b/src/ai/backend/manager/models/__init__.py @@ -8,6 +8,7 @@ from . import keypair as _keypair from . import user as _user from . import vfolder as _vfolder +from . import dotfile as _dotfile from . import resource_policy as _rpolicy from . import resource_preset as _rpreset from . import scaling_group as _sgroup @@ -25,6 +26,7 @@ *_keypair.__all__, *_user.__all__, *_vfolder.__all__, + *_dotfile.__all__, *_rpolicy.__all__, *_rpreset.__all__, *_sgroup.__all__, @@ -41,6 +43,7 @@ from .keypair import * # noqa from .user import * # noqa from .vfolder import * # noqa +from .dotfile import * # noqa from .resource_policy import * # noqa from .resource_preset import * # noqa from .scaling_group import * # noqa diff --git a/src/ai/backend/manager/models/alembic/versions/7dd1d81c3204_add_vfolder_mounts_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/7dd1d81c3204_add_vfolder_mounts_to_kernels.py new file mode 100644 index 000000000..8cc5525ec --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/7dd1d81c3204_add_vfolder_mounts_to_kernels.py @@ -0,0 +1,29 @@ +"""add-vfolder-mounts-to-kernels + +Revision ID: 7dd1d81c3204 +Revises: 60a1effa77d2 +Create Date: 2022-03-09 16:41:48.304128 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '7dd1d81c3204' +down_revision = '60a1effa77d2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('vfolder_mounts', sa.JSON(), nullable=True)) + op.drop_index('ix_keypairs_resource_policy', table_name='keypairs') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_index('ix_keypairs_resource_policy', 'keypairs', ['resource_policy'], unique=False) + op.drop_column('kernels', 'vfolder_mounts') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py index 38409df50..0ae07be32 100644 --- a/src/ai/backend/manager/models/base.py +++ b/src/ai/backend/manager/models/base.py @@ -55,6 +55,7 @@ KernelId, ResourceSlot, SessionId, + JSONSerializableMixin, ) from ai.backend.manager.models.utils import execute_with_retry @@ -192,9 +193,9 @@ def copy(self): return ResourceSlotColumn() -class StructuredJSONBColumn(TypeDecorator): +class StructuredJSONColumn(TypeDecorator): """ - A column type check scheduler_opts's validation using trafaret. + A column type to convert JSON values back and forth using a Trafaret. """ impl = JSONB @@ -213,7 +214,54 @@ def process_result_value(self, raw_value, dialect): return self._schema.check(raw_value) def copy(self): - return StructuredJSONBColumn(self._schema) + return StructuredJSONColumn(self._schema) + + +class StructuredJSONObjectColumn(TypeDecorator): + """ + A column type to convert JSON values back and forth using JSONSerializableMixin. 
+ """ + + impl = JSONB + cache_ok = True + + def __init__(self, attr_cls: Type[JSONSerializableMixin]) -> None: + super().__init__() + self._attr_cls = attr_cls + + def process_bind_param(self, value, dialect): + return self._attr_cls.to_json(value) + + def process_result_value(self, raw_value, dialect): + return self._attr_cls.from_json(raw_value) + + def copy(self): + return StructuredJSONObjectColumn(self._attr_cls) + + +class StructuredJSONObjectListColumn(TypeDecorator): + """ + A column type to convert JSON values back and forth using JSONSerializableMixin, + but store and load a list of the objects. + """ + + impl = JSONB + cache_ok = True + + def __init__(self, attr_cls: Type[JSONSerializableMixin]) -> None: + super().__init__() + self._attr_cls = attr_cls + + def process_bind_param(self, value, dialect): + return [self._attr_cls.to_json(item) for item in value] + + def process_result_value(self, raw_value, dialect): + if raw_value is None: + return [] + return [self._attr_cls.from_json(item) for item in raw_value] + + def copy(self): + return StructuredJSONObjectListColumn(self._attr_cls) class CurrencyTypes(enum.Enum): diff --git a/src/ai/backend/manager/models/dotfile.py b/src/ai/backend/manager/models/dotfile.py new file mode 100644 index 000000000..911de7823 --- /dev/null +++ b/src/ai/backend/manager/models/dotfile.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from pathlib import PurePosixPath +from typing import Any, Mapping, Sequence, TYPE_CHECKING + +import sqlalchemy as sa +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, + ) + +from ai.backend.common import msgpack +from ai.backend.common.types import VFolderMount + +from ..api.exceptions import BackendError +from ..types import UserScope +from .keypair import keypairs +from .domain import query_domain_dotfiles +from .group import query_group_dotfiles + +__all__ = ( + 'prepare_dotfiles', +) + + +async def prepare_dotfiles( + conn: SAConnection, + user_scope: UserScope, + access_key: str, + vfolder_mounts: Sequence[VFolderMount], +) -> Mapping[str, Any]: + # Feed SSH keypair and dotfiles if exists. 
+ internal_data = {} + query = ( + sa.select([ + keypairs.c.ssh_public_key, + keypairs.c.ssh_private_key, + keypairs.c.dotfiles, + ]) + .select_from(keypairs) + .where(keypairs.c.access_key == access_key) + ) + result = await conn.execute(query) + row = result.first() + dotfiles = msgpack.unpackb(row['dotfiles']) + internal_data.update({'dotfiles': dotfiles}) + if row['ssh_public_key'] and row['ssh_private_key']: + internal_data['ssh_keypair'] = { + 'public_key': row['ssh_public_key'], + 'private_key': row['ssh_private_key'], + } + # use dotfiles in the priority of keypair > group > domain + dotfile_paths = set(map(lambda x: x['path'], dotfiles)) + # add keypair dotfiles + internal_data.update({'dotfiles': list(dotfiles)}) + # add group dotfiles + dotfiles, _ = await query_group_dotfiles(conn, user_scope.group_id) + for dotfile in dotfiles: + if dotfile['path'] not in dotfile_paths: + internal_data['dotfiles'].append(dotfile) + dotfile_paths.add(dotfile['path']) + # add domain dotfiles + dotfiles, _ = await query_domain_dotfiles(conn, user_scope.domain_name) + for dotfile in dotfiles: + if dotfile['path'] not in dotfile_paths: + internal_data['dotfiles'].append(dotfile) + dotfile_paths.add(dotfile['path']) + # reverse the dotfiles list so that higher priority can overwrite + # in case the actual path is the same + internal_data['dotfiles'].reverse() + + # check if there is no name conflict of dotfile and vfolder + vfolder_kernel_paths = {m.kernel_path for m in vfolder_mounts} + for dotfile in internal_data.get('dotfiles', []): + dotfile_path = PurePosixPath(dotfile['path']) + if not dotfile_path.is_absolute(): + dotfile_path = PurePosixPath('/home/work', dotfile['path']) + if dotfile_path in vfolder_kernel_paths: + raise BackendError( + f"There is a kernel-side path from vfolders that conflicts with " + f"a dotfile '{dotfile['path']}'.", + ) + + return internal_data diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index 5b2d43b69..90b7bf2d0 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -41,6 +41,7 @@ SlotName, RedisConnectionInfo, ResourceSlot, + VFolderMount, ) from ..defs import DEFAULT_ROLE @@ -53,6 +54,7 @@ PaginatedList, ResourceSlotColumn, SessionIDColumnType, + StructuredJSONObjectListColumn, batch_result, batch_multiresult, metadata, @@ -65,7 +67,7 @@ if TYPE_CHECKING: from .gql import GraphQueryContext -__all__: Sequence[str] = ( +__all__ = ( 'kernels', 'session_dependencies', 'KernelStatus', @@ -183,8 +185,9 @@ def default_hostname(context) -> str: sa.Column('occupied_slots', ResourceSlotColumn(), nullable=False), sa.Column('occupied_shares', pgsql.JSONB(), nullable=False, default={}), # legacy sa.Column('environ', sa.ARRAY(sa.String), nullable=True), - sa.Column('mounts', sa.ARRAY(sa.String), nullable=True), # list of list - sa.Column('mount_map', pgsql.JSONB(), nullable=True, default={}), + sa.Column('mounts', sa.ARRAY(sa.String), nullable=True), # list of list; legacy since 22.03 + sa.Column('mount_map', pgsql.JSONB(), nullable=True, default={}), # legacy since 22.03 + sa.Column('vfolder_mounts', StructuredJSONObjectListColumn(VFolderMount), nullable=True), sa.Column('attached_devices', pgsql.JSONB(), nullable=True, default={}), sa.Column('resource_opts', pgsql.JSONB(), nullable=True, default={}), sa.Column('bootstrap_script', sa.String(length=16 * 1024), nullable=True), @@ -1292,7 +1295,7 @@ def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]: 'result': 
row['result'].name, 'service_ports': row['service_ports'], 'occupied_slots': row['occupied_slots'].to_json(), - 'mounts': row['mounts'], + 'vfolder_mounts': row['vfolder_mounts'], 'resource_opts': row['resource_opts'], 'num_queries': row['num_queries'], # optionally hidden diff --git a/src/ai/backend/manager/models/scaling_group.py b/src/ai/backend/manager/models/scaling_group.py index d339f5938..be5f596d6 100644 --- a/src/ai/backend/manager/models/scaling_group.py +++ b/src/ai/backend/manager/models/scaling_group.py @@ -28,7 +28,7 @@ set_if_set, batch_result, batch_multiresult, - StructuredJSONBColumn, + StructuredJSONColumn, ) from .group import resolve_group_name_or_id from .user import UserRole @@ -68,7 +68,7 @@ sa.Column('driver', sa.String(length=64), nullable=False), sa.Column('driver_opts', pgsql.JSONB(), nullable=False, default={}), sa.Column('scheduler', sa.String(length=64), nullable=False), - sa.Column('scheduler_opts', StructuredJSONBColumn( + sa.Column('scheduler_opts', StructuredJSONColumn( t.Dict({ t.Key('allowed_session_types', default=['interactive', 'batch']): t.List(tx.Enum(SessionTypes), min_length=1), diff --git a/src/ai/backend/manager/models/storage.py b/src/ai/backend/manager/models/storage.py index abf47e1c9..f1625d800 100644 --- a/src/ai/backend/manager/models/storage.py +++ b/src/ai/backend/manager/models/storage.py @@ -5,6 +5,7 @@ from contextvars import ContextVar import itertools import logging +from pathlib import PurePosixPath from typing import ( Any, AsyncIterator, @@ -120,12 +121,18 @@ async def _fetch( _ctx_volumes_cache.set(results) return results - async def get_mount_path(self, vfolder_host: str, vfolder_id: UUID) -> str: + async def get_mount_path( + self, + vfolder_host: str, + vfolder_id: UUID, + subpath: PurePosixPath = PurePosixPath("."), + ) -> str: async with self.request( vfolder_host, 'GET', 'folder/mount', json={ 'volume': self.split_host(vfolder_host)[1], 'vfid': str(vfolder_id), + 'subpath': str(subpath), }, ) as (_, resp): reply = await resp.json() @@ -157,7 +164,10 @@ async def request( if client_resp.status // 100 != 2: try: error_data = await client_resp.json() - raise VFolderOperationFailed(extra_data=error_data) + raise VFolderOperationFailed( + extra_msg=error_data.pop("msg"), + extra_data=error_data, + ) except aiohttp.ClientResponseError: # when the response body is not JSON, just raise with status info. 
raise VFolderOperationFailed( diff --git a/src/ai/backend/manager/models/vfolder.py b/src/ai/backend/manager/models/vfolder.py index 33c0eb159..747c6910d 100644 --- a/src/ai/backend/manager/models/vfolder.py +++ b/src/ai/backend/manager/models/vfolder.py @@ -1,6 +1,9 @@ from __future__ import annotations import enum +import os.path +import uuid +from pathlib import PurePosixPath from typing import ( Any, List, @@ -10,7 +13,6 @@ Set, TYPE_CHECKING, ) -import uuid from dateutil.parser import parse as dtparse import graphene @@ -20,7 +22,11 @@ from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection import trafaret as t +from ai.backend.common.types import VFolderMount + +from ..api.exceptions import InvalidAPIParameters, VFolderNotFound, VFolderOperationFailed from ..defs import RESERVED_VFOLDER_PATTERNS, RESERVED_VFOLDERS +from ..types import UserScope from .base import ( metadata, EnumValueType, GUID, IDColumn, Item, PaginatedList, BigInt, @@ -31,6 +37,7 @@ from .user import UserRole if TYPE_CHECKING: from .gql import GraphQueryContext + from .storage import StorageSessionManager __all__: Sequence[str] = ( 'vfolders', @@ -46,6 +53,7 @@ 'get_allowed_vfolder_hosts_by_group', 'get_allowed_vfolder_hosts_by_user', 'verify_vfolder_name', + 'prepare_vfolder_mounts', ) @@ -448,6 +456,142 @@ async def get_allowed_vfolder_hosts_by_user( return allowed_hosts +async def prepare_vfolder_mounts( + conn: SAConnection, + storage_manager: StorageSessionManager, + allowed_vfolder_types: Sequence[str], + user_scope: UserScope, + requested_mounts: Sequence[str], + requested_mount_map: Mapping[str, str], +) -> Sequence[VFolderMount]: + """ + Determine the actual mount information from the requested vfolder lists, + vfolder configurations, and the given user scope. + """ + + # Fast-path for empty requested mounts + if not requested_mounts: + return [] + + requested_vfolder_names: dict[str, str] = {} + requested_vfolder_subpaths: dict[str, str] = {} + requested_vfolder_dstpaths: dict[str, str] = {} + matched_vfolder_mounts: list[VFolderMount] = [] + + # Split the vfolder name and subpaths + for key in requested_mounts: + name, _, subpath = key.partition("/") + if not PurePosixPath(os.path.normpath(key)).is_relative_to(name): + raise InvalidAPIParameters( + f"The subpath '{subpath}' should designate " + f"a subdirectory of the vfolder '{name}'.", + ) + requested_vfolder_names[key] = name + requested_vfolder_subpaths[key] = os.path.normpath(subpath) + for key, value in requested_mount_map.items(): + requested_vfolder_dstpaths[key] = value + + # Check if there are overlapping mount sources + for p1 in requested_mounts: + for p2 in requested_mounts: + if p1 == p2: + continue + if PurePosixPath(p1).is_relative_to(PurePosixPath(p2)): + raise InvalidAPIParameters( + f"VFolder source path '{p1}' overlaps with '{p2}'", + ) + + # Query the accessible vfolders that satisfy either: + # - the name matches with the requested vfolder name, or + # - the name starts with a dot (dot-prefixed vfolder) for automatic mounting. 
+ if requested_vfolder_names: + extra_vf_conds = ( + vfolders.c.name.in_(requested_vfolder_names.values()) | + vfolders.c.name.startswith('.') + ) + else: + extra_vf_conds = vfolders.c.name.startswith('.') + accessible_vfolders = await query_accessible_vfolders( + conn, user_scope.user_uuid, + user_role=user_scope.user_role, + domain_name=user_scope.domain_name, + allowed_vfolder_types=allowed_vfolder_types, + extra_vf_conds=extra_vf_conds, + ) + + # for vfolder in accessible_vfolders: + for key, vfolder_name in requested_vfolder_names.items(): + for vfolder in accessible_vfolders: + if vfolder['name'] == requested_vfolder_names[key]: + break + else: + raise VFolderNotFound(f"VFolder {vfolder_name} is not found or accessible.") + if vfolder['group'] is not None and vfolder['group'] != str(user_scope.group_id): + # User's accessible group vfolders should not be mounted + # if not belong to the execution kernel. + continue + try: + mount_base_path = PurePosixPath( + await storage_manager.get_mount_path( + vfolder['host'], + vfolder['id'], + PurePosixPath(requested_vfolder_subpaths[key]), + ), + ) + except VFolderOperationFailed as e: + raise InvalidAPIParameters(e.extra_msg, e.extra_data) from None + if vfolder['name'] == '.local' and vfolder['group'] is not None: + # Auto-create per-user subdirectory inside the group-owned ".local" vfolder. + async with storage_manager.request( + vfolder['host'], 'POST', 'folder/file/mkdir', + params={ + 'volume': storage_manager.split_host(vfolder['host'])[1], + 'vfid': vfolder['id'], + 'relpath': str(user_scope.user_uuid.hex), + 'exist_ok': True, + }, + ): + pass + # Mount the per-user subdirectory as the ".local" vfolder. + matched_vfolder_mounts.append(VFolderMount( + name=vfolder['name'], + vfid=vfolder['id'], + vfsubpath=PurePosixPath(user_scope.user_uuid.hex), + host_path=mount_base_path / user_scope.user_uuid.hex, + kernel_path=PurePosixPath("/home/work/.local"), + mount_perm=vfolder['permission'], + )) + else: + # Normal vfolders + kernel_path_raw = requested_vfolder_dstpaths.get(key) + if kernel_path_raw is None: + kernel_path = PurePosixPath(f"/home/work/{vfolder['name']}") + else: + kernel_path = PurePosixPath(kernel_path_raw) + if not kernel_path.is_absolute(): + kernel_path = PurePosixPath("/home/work", kernel_path_raw) + matched_vfolder_mounts.append(VFolderMount( + name=vfolder['name'], + vfid=vfolder['id'], + vfsubpath=PurePosixPath(requested_vfolder_subpaths[key]), + host_path=mount_base_path / requested_vfolder_subpaths[key], + kernel_path=kernel_path, + mount_perm=vfolder['permission'], + )) + + # Check if there are overlapping mount targets + for vf1 in matched_vfolder_mounts: + for vf2 in matched_vfolder_mounts: + if vf1.name == vf2.name: + continue + if vf1.kernel_path.is_relative_to(vf2.kernel_path): + raise InvalidAPIParameters( + f"VFolder mount path {vf1.kernel_path} overlaps with {vf2.kernel_path}", + ) + + return matched_vfolder_mounts + + class VirtualFolder(graphene.ObjectType): class Meta: interfaces = (Item, ) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index b8eca0ee6..29d02d974 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -32,7 +32,6 @@ import weakref import aiodocker -import aiohttp import aioredis import aiotools from async_timeout import timeout as _timeout @@ -97,7 +96,6 @@ KernelCreationFailed, KernelDestructionFailed, KernelExecutionFailed, KernelRestartFailed, ScalingGroupNotFound, - VFolderNotFound, AgentError, GenericForbidden, 
QuotaExceeded, @@ -105,13 +103,14 @@ from .config import SharedConfig from .exceptions import MultiAgentError from .defs import DEFAULT_ROLE, INTRINSIC_SLOTS -from .types import SessionGetter +from .types import SessionGetter, UserScope from .models import ( - agents, kernels, keypairs, vfolders, - query_group_dotfiles, query_domain_dotfiles, + agents, kernels, keypairs, keypair_resource_policies, AgentStatus, KernelStatus, - query_accessible_vfolders, query_allowed_sgroups, + prepare_dotfiles, + prepare_vfolder_mounts, + query_allowed_sgroups, recalc_agent_resource_occupancy, recalc_concurrency_used, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, @@ -762,21 +761,15 @@ async def enqueue_session( session_type: SessionTypes, resource_policy: dict, *, - domain_name: str, - group_id: uuid.UUID, - user_uuid: uuid.UUID, - user_role: str, + user_scope: UserScope, cluster_mode: ClusterMode = ClusterMode.SINGLE_NODE, cluster_size: int = 1, - startup_command: str = None, session_tag: str = None, internal_data: dict = None, starts_at: datetime = None, agent_list: Sequence[str] = None, ) -> SessionId: - mounts = kernel_enqueue_configs[0]['creation_config'].get('mounts') or [] - mount_map = kernel_enqueue_configs[0]['creation_config'].get('mount_map') or {} session_id = SessionId(uuid.uuid4()) # Check keypair resource limit @@ -786,11 +779,13 @@ async def enqueue_session( f"{resource_policy['max_containers_per_session']} containers.", ) - # Check scaling group availability if scaling_group parameter is given. - # If scaling_group is not provided, it will be selected as the first one among - # the list of allowed scaling groups. async with self.db.begin_readonly() as conn: - sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key) + # Check scaling group availability if scaling_group parameter is given. + # If scaling_group is not provided, it will be selected as the first one among + # the list of allowed scaling groups. 
+ sgroups = await query_allowed_sgroups( + conn, user_scope.domain_name, user_scope.group_id, access_key, + ) if not sgroups: raise ScalingGroupNotFound("You have no scaling groups allowed to use.") if scaling_group is None: @@ -806,67 +801,28 @@ async def enqueue_session( break else: raise ScalingGroupNotFound(f"The scaling group {scaling_group} does not exist.") - assert scaling_group is not None + assert scaling_group is not None + + # Translate mounts/mount_map into vfolder mounts + requested_mounts = kernel_enqueue_configs[0]['creation_config'].get('mounts') or [] + requested_mount_map = kernel_enqueue_configs[0]['creation_config'].get('mount_map') or {} + allowed_vfolder_types = await self.shared_config.get_vfolder_types() + vfolder_mounts = await prepare_vfolder_mounts( + conn, + self.storage_manager, + allowed_vfolder_types, + user_scope, + requested_mounts, + requested_mount_map, + ) - # sanity check for vfolders - allowed_vfolder_types = ['user', 'group'] - # allowed_vfolder_types = await root_ctx.shared_config.etcd.get('path-to-vfolder-type') - determined_mounts = [] - matched_mounts = set() - async with self.db.begin_readonly() as conn: - if mounts: - extra_vf_conds = ( - vfolders.c.name.in_(mounts) | - vfolders.c.name.startswith('.') - ) - else: - extra_vf_conds = vfolders.c.name.startswith('.') - matched_vfolders = await query_accessible_vfolders( - conn, user_uuid, - user_role=user_role, domain_name=domain_name, - allowed_vfolder_types=allowed_vfolder_types, - extra_vf_conds=extra_vf_conds) - - for item in matched_vfolders: - if item['group'] is not None and item['group'] != str(group_id): - # User's accessible group vfolders should not be mounted - # if not belong to the execution kernel. - continue - mount_path = await self.storage_manager.get_mount_path(item['host'], item['id']) - if item['name'] == '.local' and item['group'] is not None: - try: - async with self.storage_manager.request( - item['host'], 'POST', 'folder/file/mkdir', - params={ - 'volume': self.storage_manager.split_host(item['host'])[1], - 'vfid': item['id'], - 'relpath': str(user_uuid.hex), - }, - ): - pass - except aiohttp.ClientResponseError: - # the server may respond with error if the directory already exists - pass - matched_mounts.add(item['name']) - determined_mounts.append(( - item['name'], - item['host'], - f"{mount_path}/{user_uuid.hex}", - item['permission'].value, - '', - )) - else: - matched_mounts.add(item['name']) - determined_mounts.append(( - item['name'], - item['host'], - mount_path, - item['permission'].value, - item['unmanaged_path'] if item['unmanaged_path'] else '', - )) - if mounts and set(mounts) > matched_mounts: - raise VFolderNotFound(extra_data=[*(set(mounts) - matched_mounts)]) - mounts = determined_mounts + # Prepare internal data for common dotfiles. + dotfile_data = await prepare_dotfiles( + conn, + user_scope, + access_key, + vfolder_mounts, + ) ids = [] is_multicontainer = cluster_size > 1 @@ -898,6 +854,10 @@ async def enqueue_session( else: raise InvalidAPIParameters("Missing kernel configurations") + # Prepare internal data. 
+ internal_data = {} if internal_data is None else internal_data + internal_data.update(dotfile_data) + hook_result = await self.hook_plugin_ctx.dispatch( 'PRE_ENQUEUE_SESSION', (session_id, session_name, access_key), @@ -916,7 +876,7 @@ async def enqueue_session( image_ref = kernel['image_ref'] resource_opts = creation_config.get('resource_opts') or {} - creation_config['mounts'] = mounts + creation_config['mounts'] = [vfmount.to_json() for vfmount in vfolder_mounts] # TODO: merge into a single call image_info = await self.shared_config.inspect_image(image_ref) image_min_slots, image_max_slots = \ @@ -1020,59 +980,6 @@ async def enqueue_session( environ = kernel_enqueue_configs[0]['creation_config'].get('environ') or {} # Create kernel object in PENDING state. - async with self.db.begin_readonly() as conn: - # Feed SSH keypair and dotfiles if exists. - query = (sa.select([keypairs.c.ssh_public_key, - keypairs.c.ssh_private_key, - keypairs.c.dotfiles]) - .select_from(keypairs) - .where(keypairs.c.access_key == access_key)) - result = await conn.execute(query) - row = result.first() - dotfiles = msgpack.unpackb(row['dotfiles']) - internal_data = {} if internal_data is None else internal_data - internal_data.update({'dotfiles': dotfiles}) - if row['ssh_public_key'] and row['ssh_private_key']: - internal_data['ssh_keypair'] = { - 'public_key': row['ssh_public_key'], - 'private_key': row['ssh_private_key'], - } - # use dotfiles in the priority of keypair > group > domain - dotfile_paths = set(map(lambda x: x['path'], dotfiles)) - # add keypair dotfiles - internal_data.update({'dotfiles': list(dotfiles)}) - # add group dotfiles - dotfiles, _ = await query_group_dotfiles(conn, group_id) - for dotfile in dotfiles: - if dotfile['path'] not in dotfile_paths: - internal_data['dotfiles'].append(dotfile) - dotfile_paths.add(dotfile['path']) - # add domain dotfiles - dotfiles, _ = await query_domain_dotfiles(conn, domain_name) - for dotfile in dotfiles: - if dotfile['path'] not in dotfile_paths: - internal_data['dotfiles'].append(dotfile) - dotfile_paths.add(dotfile['path']) - # reverse the dotfiles list so that higher priority can overwrite - # in case the actual path is the same - internal_data['dotfiles'].reverse() - - # check if there is no name conflict of dotfile and vfolder - for dotfile in internal_data.get('dotfiles', []): - if dotfile['path'].startswith('/'): - if dotfile['path'].startswith('/home/'): - path_arr = dotfile['path'].split('/') - # check if there is a dotfile whose path equals /home/work/vfolder_name - if len(path_arr) >= 3 and path_arr[2] == 'work' and \ - path_arr[3] in matched_mounts: - raise BackendError( - f'There is a vfolder whose name conflicts with ' - f'dotfile {path_arr[3]} with path "{dotfile["path"]}"') - else: - if dotfile['path'] in matched_mounts: - raise BackendError( - f'There is a vfolder whose name conflicts with ' - f'dotfile {dotfile["path"]}') mapped_agent = None if not agent_list: pass @@ -1096,9 +1003,9 @@ async def _enqueue() -> None: 'cluster_idx': kernel['cluster_idx'], 'cluster_hostname': f"{kernel['cluster_role']}{kernel['cluster_idx']}", 'scaling_group': scaling_group, - 'domain_name': domain_name, - 'group_id': group_id, - 'user_uuid': user_uuid, + 'domain_name': user_scope.domain_name, + 'group_id': user_scope.group_id, + 'user_uuid': user_scope.user_uuid, 'access_key': access_key, 'image': image_ref.canonical, 'registry': image_ref.registry, @@ -1110,8 +1017,10 @@ async def _enqueue() -> None: 'occupied_shares': {}, 'resource_opts': 
resource_opts, 'environ': [f'{k}={v}' for k, v in environ.items()], - 'mounts': [list(mount) for mount in mounts], # postgres save tuple as str - 'mount_map': mount_map, + 'mounts': [ # TODO: keep for legacy? + mount.name for mount in vfolder_mounts + ], + 'vfolder_mounts': vfolder_mounts, 'bootstrap_script': kernel.get('bootstrap_script'), 'repl_in_port': 0, 'repl_out_port': 0, @@ -1429,8 +1338,7 @@ async def _create_kernels_in_one_agent( 'cluster_idx': binding.kernel.cluster_idx, 'cluster_hostname': binding.kernel.cluster_hostname, 'idle_timeout': resource_policy['idle_timeout'], - 'mounts': scheduled_session.mounts, - 'mount_map': scheduled_session.mount_map, + 'mounts': [item.to_json() for item in scheduled_session.vfolder_mounts], 'environ': { # inherit per-session environment variables **scheduled_session.environ, diff --git a/src/ai/backend/manager/scheduler/types.py b/src/ai/backend/manager/scheduler/types.py index 63baa1d70..82a0eb5b1 100644 --- a/src/ai/backend/manager/scheduler/types.py +++ b/src/ai/backend/manager/scheduler/types.py @@ -37,6 +37,7 @@ ResourceSlot, SlotName, SlotTypes, + VFolderMount, ) from ..defs import DEFAULT_ROLE @@ -194,8 +195,7 @@ class PendingSession: requested_slots: ResourceSlot target_sgroup_names: MutableSequence[str] environ: MutableMapping[str, str] - mounts: Sequence[str] - mount_map: Mapping[str, str] + vfolder_mounts: Sequence[VFolderMount] bootstrap_script: Optional[str] startup_command: Optional[str] internal_data: Optional[MutableMapping[str, Any]] @@ -230,8 +230,7 @@ def db_cols(cls) -> Set[ColumnElement]: kernels.c.internal_data, kernels.c.resource_opts, kernels.c.environ, - kernels.c.mounts, - kernels.c.mount_map, + kernels.c.vfolder_mounts, kernels.c.bootstrap_script, kernels.c.startup_command, kernels.c.preopen_ports, @@ -276,8 +275,7 @@ def from_row(cls, row: Row) -> PendingSession: k: v for k, v in map(lambda s: s.split('=', maxsplit=1), row['environ']) }, - mounts=row['mounts'], - mount_map=row['mount_map'], + vfolder_mounts=row['vfolder_mounts'], bootstrap_script=row['bootstrap_script'], startup_command=row['startup_command'], preopen_ports=row['preopen_ports'], diff --git a/src/ai/backend/manager/types.py b/src/ai/backend/manager/types.py index fbfbfdce9..d70abd4b8 100644 --- a/src/ai/backend/manager/types.py +++ b/src/ai/backend/manager/types.py @@ -1,4 +1,6 @@ +import attr import enum +import uuid from typing import ( Protocol, ) @@ -23,3 +25,11 @@ def __call__(self, *, db_connection: SAConnection) -> Row: class Sentinel(enum.Enum): token = 0 + + +@attr.define(slots=True) +class UserScope: + domain_name: str + group_id: uuid.UUID + user_uuid: uuid.UUID + user_role: str diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 4c57e9a69..c933bedc6 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -252,8 +252,7 @@ class SessionKernelIdPair: group_id=example_group_id, resource_policy={}, resource_opts={}, - mounts=[], - mount_map={}, + vfolder_mounts=[], environ={}, bootstrap_script=None, startup_command=None,
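A minimal usage sketch of the new subpath mount syntax introduced above (`prepare_vfolder_mounts()` in models/vfolder.py): a requested mount such as `mydata/train/images` is split into the vfolder name `mydata` and the subpath `train/images`, then resolved into a `VFolderMount`. The connection, storage manager, `request`, and `group_id` are assumed to come from the surrounding manager context as in api/session.py; the vfolder name and paths are hypothetical.

```python
from ai.backend.manager.models import prepare_vfolder_mounts
from ai.backend.manager.types import UserScope

async def resolve_requested_mounts(conn, storage_manager, request, group_id):
    # Build the caller's scope the same way api/session.py does.
    user_scope = UserScope(
        domain_name=request['user']['domain_name'],
        group_id=group_id,
        user_uuid=request['user']['uuid'],
        user_role=request['user']['role'],
    )
    # "mydata/train/images" mounts the train/images subdirectory of the vfolder "mydata".
    vfolder_mounts = await prepare_vfolder_mounts(
        conn,
        storage_manager,
        ['user', 'group'],                  # allowed_vfolder_types
        user_scope,
        ['mydata/train/images'],            # requested_mounts
        {'mydata/train/images': 'images'},  # requested_mount_map (optional kernel-side override)
    )
    # Each item is a VFolderMount (name, vfid, vfsubpath, host_path, kernel_path, mount_perm)
    # and is what registry.enqueue_session() stores into kernels.vfolder_mounts.
    return vfolder_mounts
```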
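The new column types in models/base.py can wrap any `JSONSerializableMixin` subclass. A short sketch of declaring such a column, mirroring how `kernels.vfolder_mounts` is defined above; the table name here is hypothetical.

```python
import sqlalchemy as sa

from ai.backend.common.types import VFolderMount
from ai.backend.manager.models.base import metadata, StructuredJSONObjectListColumn

# Stored as JSONB; each list element round-trips through
# VFolderMount.to_json() / VFolderMount.from_json(), and a NULL value loads as [].
example_mounts = sa.Table(
    'example_mounts', metadata,
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('vfolder_mounts', StructuredJSONObjectListColumn(VFolderMount), nullable=True),
)
```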
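Storage error propagation, as changed in models/storage.py above: `get_mount_path()` now forwards a `subpath` to the storage proxy's `folder/mount` call, and `request()` surfaces the proxy's `msg` field through `VFolderOperationFailed.extra_msg`. A sketch of the resulting call pattern, assuming `storage_manager`, `vfolder_host`, and `vfolder_id` come from the surrounding context; the subpath value is hypothetical.

```python
from pathlib import PurePosixPath

from ai.backend.manager.api.exceptions import InvalidAPIParameters, VFolderOperationFailed

async def resolve_host_path(storage_manager, vfolder_host, vfolder_id):
    try:
        # Resolves the host-side path of the "train/images" subdirectory of the vfolder.
        return await storage_manager.get_mount_path(
            vfolder_host,
            vfolder_id,
            PurePosixPath('train/images'),
        )
    except VFolderOperationFailed as e:
        # e.extra_msg carries the storage proxy's "msg"; re-raise as a user-facing error,
        # following the same pattern as prepare_vfolder_mounts() above.
        raise InvalidAPIParameters(e.extra_msg, e.extra_data) from None
```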