Skip to content

Commit

Permalink
Add auth string handling to server side client (#16833)
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw authored Jan 23, 2025
1 parent 2b6f0d1 commit 131136b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/prefect/server/api/clients.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from urllib.parse import quote
from uuid import UUID
Expand All @@ -17,6 +18,7 @@
from prefect.server.schemas.core import WorkPool
from prefect.server.schemas.filters import VariableFilter, VariableFilterName
from prefect.server.schemas.responses import DeploymentResponse, OrchestrationResult
from prefect.settings import PREFECT_SERVER_API_AUTH_STRING
from prefect.types import StrictVariableValue

if TYPE_CHECKING:
Expand All @@ -37,6 +39,13 @@ def __init__(self, additional_headers: dict[str, str] | None = None):
# will point it to the the currently running server instance
api_app = create_app()

# we pull the auth string from _server_ settings because this client is run on the server
auth_string = PREFECT_SERVER_API_AUTH_STRING.value()

if auth_string:
token = base64.b64encode(auth_string.encode("utf-8")).decode("utf-8")
additional_headers.setdefault("Authorization", f"Basic {token}")

self._http_client = PrefectHttpxAsyncClient(
transport=httpx.ASGITransport(app=api_app, raise_app_exceptions=False),
headers={**additional_headers},
Expand Down
15 changes: 15 additions & 0 deletions tests/server/api/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
SetStateStatus,
)
from prefect.server.schemas.states import Paused, Suspended
from prefect.settings import (
PREFECT_API_AUTH_STRING,
PREFECT_SERVER_API_AUTH_STRING,
temporary_settings,
)

if TYPE_CHECKING:
from prefect.server.database.orm_models import ORMDeployment, ORMVariable
Expand Down Expand Up @@ -86,6 +91,16 @@ async def suspended_flow_run(session: AsyncSession) -> FlowRun:
return FlowRun.model_validate(flow_run, from_attributes=True)


async def test_get_client_includes_auth_string_from_context():
with temporary_settings(updates={PREFECT_API_AUTH_STRING: "admin:test"}):
async with OrchestrationClient() as client:
assert "Authorization" not in client._http_client.headers

with temporary_settings(updates={PREFECT_SERVER_API_AUTH_STRING: "admin:test"}):
async with OrchestrationClient() as client:
assert client._http_client.headers["Authorization"].startswith("Basic")


async def test_read_deployment(
deployment: "ORMDeployment", orchestration_client: OrchestrationClient
):
Expand Down

0 comments on commit 131136b

Please # to comment.