Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

feat: Upgrade SQLAlchemy v1.4 for native asyncio support #406

Merged
merged 21 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/406.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Upgrade SQLAlchemy to v1.4 for native asyncio support and better transaction/concurrency handling
2 changes: 1 addition & 1 deletion config/ci.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ drivers = ["console"]
[logging.pkg-ns]
"" = "WARNING"
"aiotools" = "INFO"
"aiopg" = "WARNING"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"

[logging.console]
colored = true
Expand Down
2 changes: 1 addition & 1 deletion config/halfstack.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ drivers = ["console"]
[logging.pkg-ns]
"" = "WARNING"
"aiotools" = "INFO"
"aiopg" = "WARNING"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"

[logging.console]
colored = true
Expand Down
2 changes: 1 addition & 1 deletion config/sample.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ ssl-verify = true
[logging.pkg-ns]
"" = "WARNING"
"aiotools" = "INFO"
"aiopg" = "WARNING"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"


[debug]
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ install_requires =
aiohttp_cors~=0.7
aiohttp_sse~=2.0
aiojobs~=0.3.0
aiopg~=1.1.0
aioredis~=1.3.1
aioredlock~=0.7.0
aiotools~=1.2.1
alembic~=1.4.3
async_timeout~=3.0
asyncache>=0.1.1
asyncpg>=0.22.0
attrs>=20.3
boltons~=20.2.1
callosum~=0.9.7
Expand All @@ -62,7 +62,7 @@ install_requires =
python-snappy~=0.6.0
PyYAML~=5.4.1
pyzmq~=22.0.3
SQLAlchemy~=1.3.20
SQLAlchemy~=1.4.2
uvloop~=0.15.1
setproctitle~=1.2.2
tabulate~=0.8.6
Expand Down
62 changes: 32 additions & 30 deletions src/ai/backend/gateway/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,38 @@ async def handle_gql(request: web.Request, params: Any) -> web.Response:
app_ctx: PrivateContext = request.app['admin.context']
manager_status = await root_ctx.shared_config.get_manager_status()
known_slot_types = await root_ctx.shared_config.get_resource_slots()
gql_ctx = GraphQueryContext(
dataloader_manager=DataLoaderManager(),
local_config=root_ctx.local_config,
shared_config=root_ctx.shared_config,
etcd=root_ctx.shared_config.etcd,
user=request['user'],
access_key=request['keypair']['access_key'],
dbpool=root_ctx.dbpool,
redis_stat=root_ctx.redis_stat,
redis_image=root_ctx.redis_image,
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.background_task_manager,
storage_manager=root_ctx.storage_manager,
registry=root_ctx.registry,
)
result = app_ctx.gql_schema.execute(
params['query'],
app_ctx.gql_executor,
variable_values=params['variables'],
operation_name=params['operation_name'],
context_value=gql_ctx,
middleware=[
GQLLoggingMiddleware(),
GQLMutationUnfrozenRequiredMiddleware(),
GQLMutationPrivilegeCheckMiddleware(),
],
return_promise=True)
if inspect.isawaitable(result):
result = await result
async with root_ctx.db.begin() as db_conn:
gql_ctx = GraphQueryContext(
dataloader_manager=DataLoaderManager(),
local_config=root_ctx.local_config,
shared_config=root_ctx.shared_config,
etcd=root_ctx.shared_config.etcd,
user=request['user'],
access_key=request['keypair']['access_key'],
db=root_ctx.db,
db_conn=db_conn,
redis_stat=root_ctx.redis_stat,
redis_image=root_ctx.redis_image,
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.background_task_manager,
storage_manager=root_ctx.storage_manager,
registry=root_ctx.registry,
)
result = app_ctx.gql_schema.execute(
params['query'],
app_ctx.gql_executor,
variable_values=params['variables'],
operation_name=params['operation_name'],
context_value=gql_ctx,
middleware=[
GQLLoggingMiddleware(),
GQLMutationUnfrozenRequiredMiddleware(),
GQLMutationPrivilegeCheckMiddleware(),
],
return_promise=True)
if inspect.isawaitable(result):
result = await result
if result.errors:
errors = []
for e in result.errors:
Expand Down
80 changes: 40 additions & 40 deletions src/ai/backend/gateway/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
params = _extract_auth_params(request)
if params:
sign_method, access_key, signature = params
async with root_ctx.dbpool.acquire() as conn, conn.begin():
async with root_ctx.db.begin() as conn:
j = (
keypairs
.join(users, keypairs.c.user == users.c.uuid)
Expand All @@ -414,7 +414,7 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
)
)
result = await conn.execute(query)
row = await result.first()
row = result.first()
if row is None:
raise AuthorizationFailed('Access key not found')
my_signature = \
Expand All @@ -429,27 +429,27 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
.where(keypairs.c.access_key == access_key)
)
await conn.execute(query)
request['is_authorized'] = True
request['keypair'] = {
col.name: row[f'keypairs_{col.name}']
for col in keypairs.c
if col.name != 'secret_key'
}
request['keypair']['resource_policy'] = {
col.name: row[f'keypair_resource_policies_{col.name}']
for col in keypair_resource_policies.c
}
request['user'] = {
col.name: row[f'users_{col.name}']
for col in users.c
if col.name not in ('password', 'description', 'created_at')
}
request['user']['id'] = row['keypairs_user_id'] # legacy
# if request['role'] in ['admin', 'superadmin']:
if row['keypairs_is_admin']:
request['is_admin'] = True
if request['user']['role'] == 'superadmin':
request['is_superadmin'] = True
request['is_authorized'] = True
request['keypair'] = {
col.name: row[f'keypairs_{col.name}']
for col in keypairs.c
if col.name != 'secret_key'
}
request['keypair']['resource_policy'] = {
col.name: row[f'keypair_resource_policies_{col.name}']
for col in keypair_resource_policies.c
}
request['user'] = {
col.name: row[f'users_{col.name}']
for col in users.c
if col.name not in ('password', 'description', 'created_at')
}
request['user']['id'] = row['keypairs_user_id'] # legacy
# if request['role'] in ['admin', 'superadmin']:
if row['keypairs_is_admin']:
request['is_admin'] = True
if request['user']['role'] == 'superadmin':
request['is_superadmin'] = True

# No matter if authenticated or not, pass-through to the handler.
# (if it's required, auth_required decorator will handle the situation.)
Expand Down Expand Up @@ -531,9 +531,9 @@ async def get_role(request: web.Request, params: Any) -> web.Response:
(association_groups_users.c.user_id == request['user']['uuid'])
)
)
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
result = await conn.execute(query)
row = await result.first()
row = result.first()
if row is None:
raise GenericNotFound('No such user group or '
'you are not the member of the group.')
Expand Down Expand Up @@ -563,12 +563,12 @@ async def authorize(request: web.Request, params: Any) -> web.Response:

# [Hooking point for AUTHORIZE with the FIRST_COMPLETED requirement]
# The hook handlers should accept the whole ``params`` dict, and optional
# ``dbpool`` parameter (if the hook needs to query to database).
# ``db`` parameter (if the hook needs to query to database).
# They should return a corresponding Backend.AI user object after performing
# their own authentication steps, like LDAP authentication, etc.
hook_result = await root_ctx.hook_plugin_ctx.dispatch(
'AUTHORIZE',
(params, root_ctx.dbpool),
(params, root_ctx.db),
return_when=FIRST_COMPLETED,
)
if hook_result.status != PASSED:
Expand All @@ -579,7 +579,7 @@ async def authorize(request: web.Request, params: Any) -> web.Response:
else:
# No AUTHORIZE hook is defined (proceed with normal login)
user = await check_credential(
root_ctx.dbpool,
root_ctx.db,
params['domain'], params['username'], params['password']
)
if user is None:
Expand All @@ -588,7 +588,7 @@ async def authorize(request: web.Request, params: Any) -> web.Response:
raise AuthorizationFailed('This account needs email verification.')
if user.get('status') in INACTIVE_USER_STATUSES:
raise AuthorizationFailed('User credential mismatch.')
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
query = (sa.select([keypairs.c.access_key, keypairs.c.secret_key])
.select_from(keypairs)
.where(
Expand All @@ -597,7 +597,7 @@ async def authorize(request: web.Request, params: Any) -> web.Response:
)
.order_by(sa.desc(keypairs.c.is_admin)))
result = await conn.execute(query)
keypair = await result.first()
keypair = result.first()
if keypair is None:
raise AuthorizationFailed('No API keypairs found.')
return web.json_response({
Expand Down Expand Up @@ -640,13 +640,13 @@ async def #(request: web.Request, params: Any) -> web.Response:
# Merge the hook results as a single map.
user_data_overriden = ChainMap(*cast(Mapping, hook_result.result))

async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
# Check if email already exists.
query = (sa.select([users])
.select_from(users)
.where((users.c.email == params['email'])))
result = await conn.execute(query)
row = await result.first()
row = result.first()
if row is not None:
raise GenericBadRequest('Email already exists')

Expand All @@ -673,7 +673,7 @@ async def #(request: web.Request, params: Any) -> web.Response:
if result.rowcount > 0:
checkq = users.select().where(users.c.email == params['email'])
result = await conn.execute(checkq)
user = await result.first()
user = result.first()
# Create user's first access_key and secret_key.
ak, sk = _gen_keypair()
resource_policy = (
Expand Down Expand Up @@ -701,7 +701,7 @@ async def #(request: web.Request, params: Any) -> web.Response:
.where(groups.c.domain_name == params['domain'])
.where(groups.c.name == group_name))
result = await conn.execute(query)
grp = await result.first()
grp = result.first()
if grp is not None:
values = [{'user_id': user.uuid, 'group_id': grp.id}]
query = association_groups_users.insert().values(values)
Expand Down Expand Up @@ -741,11 +741,11 @@ async def signout(request: web.Request, params: Any) -> web.Response:
if request['user']['email'] != params['email']:
raise GenericForbidden('Not the account owner')
result = await check_credential(
root_ctx.dbpool,
root_ctx.db,
domain_name, params['email'], params['password'])
if result is None:
raise GenericBadRequest('Invalid email and/or password')
async with root_ctx.dbpool.acquire() as conn, conn.begin():
async with root_ctx.db.begin() as conn:
# Inactivate the user.
query = (
users.update()
Expand Down Expand Up @@ -779,7 +779,7 @@ async def update_password(request: web.Request, params: Any) -> web.Response:
log_args = (domain_name, email)
log.info(log_fmt, *log_args)

user = await check_credential(root_ctx.dbpool, domain_name, email, params['old_password'])
user = await check_credential(root_ctx.db, domain_name, email, params['old_password'])
if user is None:
log.info(log_fmt + ': old password mismtach', *log_args)
raise AuthorizationFailed('Old password mismatch')
Expand All @@ -801,7 +801,7 @@ async def update_password(request: web.Request, params: Any) -> web.Response:
hook_result.reason = hook_result.reason or 'invalid password format'
raise RejectedByHook.from_hook_result(hook_result)

async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
# Update user password.
data = {
'password': params['new_password'],
Expand All @@ -821,7 +821,7 @@ async def get_ssh_keypair(request: web.Request) -> web.Response:
log_fmt = 'AUTH.GET_SSH_KEYPAIR(d:{}, ak:{})'
log_args = (domain_name, access_key)
log.info(log_fmt, *log_args)
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
# Get SSH public key. Return partial string from the public key just for checking.
query = (
sa.select([keypairs.c.ssh_public_key])
Expand All @@ -840,7 +840,7 @@ async def refresh_ssh_keypair(request: web.Request) -> web.Response:
log_args = (domain_name, access_key)
log.info(log_fmt, *log_args)
root_ctx: RootContext = request.app['_root.context']
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
pubkey, privkey = generate_ssh_keypair()
data = {
'ssh_public_key': pubkey,
Expand Down
Loading