Skip to content

Commit

Permalink
Merge pull request feast-dev#35 from dmartinol/feast-rbac-auth
Browse files Browse the repository at this point in the history
Feast RBAC Authorization Manager
  • Loading branch information
redhatHameed authored Jul 9, 2024
2 parents 0ceb0b8 + 8cd2d18 commit 087fd71
Show file tree
Hide file tree
Showing 27 changed files with 957 additions and 142 deletions.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from datetime import timedelta

import pandas as pd
import os

from feast import (
Entity,
Expand All @@ -14,6 +13,8 @@
PushSource,
RequestSource,
)
from feast.feature_logging import LoggingConfig
from feast.infra.offline_stores.file_source import FileLoggingDestination
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, Float64, Int64

Expand Down Expand Up @@ -89,6 +90,9 @@ def transformed_conv_rate(inputs: pd.DataFrame) -> pd.DataFrame:
driver_stats_fv[["conv_rate"]], # Sub-selects a feature from a feature view
transformed_conv_rate, # Selects all features from the feature view
],
logging_config=LoggingConfig(
destination=FileLoggingDestination(path=f"{os.path.dirname(os.path.abspath(__file__))}/data")
),
)
driver_activity_v2 = FeatureService(
name="driver_activity_v2", features=[driver_stats_fv, transformed_conv_rate]
Expand Down
24 changes: 21 additions & 3 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

import pandas as pd
from dateutil import parser
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.logger import logger
from fastapi.params import Depends
from google.protobuf.json_format import MessageToDict
from pydantic import BaseModel

Expand All @@ -20,6 +19,13 @@
from feast.errors import FeatureViewNotFoundException, PushSourceNotFoundException
from feast.permissions.action import WRITE, AuthzedAction
from feast.permissions.security_manager import assert_permissions
from feast.permissions.server.rest import inject_user_details
from feast.permissions.server.utils import (
ServerType,
auth_manager_type_from_env,
init_auth_manager,
init_security_manager,
)


# TODO: deprecate this in favor of push features
Expand Down Expand Up @@ -84,7 +90,11 @@ async def lifespan(app: FastAPI):
async def get_body(request: Request):
return await request.body()

@app.post("/get-online-features")
# TODO RBAC: complete the dependencies for the other endpoints
@app.post(
"/get-online-features",
dependencies=[Depends(inject_user_details)],
)
def get_online_features(body=Depends(get_body)):
try:
body = json.loads(body)
Expand Down Expand Up @@ -297,6 +307,14 @@ def start_server(
keep_alive_timeout: int,
registry_ttl_sec: int,
):
# TODO RBAC remove and use the auth section of the feature store config instead
auth_manager_type = auth_manager_type_from_env()
init_security_manager(auth_manager_type=auth_manager_type, fs=store)
init_auth_manager(
auth_manager_type=auth_manager_type,
server_type=ServerType.REST,
)

if sys.platform != "win32":
FeastServeApplication(
store=store,
Expand Down
26 changes: 25 additions & 1 deletion sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,27 @@
from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
from feast.permissions.action import AuthzedAction
from feast.permissions.security_manager import assert_permissions
from feast.permissions.server.arrow import (
arrowflight_middleware,
inject_user_details,
)
from feast.permissions.server.utils import (
ServerType,
auth_manager_type_from_env,
init_auth_manager,
init_security_manager,
)
from feast.saved_dataset import SavedDatasetStorage

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class OfflineServer(fl.FlightServerBase):
def __init__(self, store: FeatureStore, location: str, **kwargs):
super(OfflineServer, self).__init__(location, **kwargs)
super(OfflineServer, self).__init__(
location, middleware=arrowflight_middleware(), **kwargs
)
self._location = location
# A dictionary of configured flights, e.g. API calls received and not yet served
self.flights: Dict[str, Any] = {}
Expand Down Expand Up @@ -159,6 +172,9 @@ def _validate_do_get_parameters(self, command: dict):
# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
# and returns the stream of data
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
# TODO RBAC: add the same to all the authorized endpoints
inject_user_details(context)

key = ast.literal_eval(ticket.ticket.decode())
if key not in self.flights:
logger.error(f"Unknown key {key}")
Expand Down Expand Up @@ -432,6 +448,14 @@ def start_server(
host: str,
port: int,
):
# TODO RBAC remove and use the auth section of the feature store config instead
auth_manager_type = auth_manager_type_from_env()
init_security_manager(auth_manager_type=auth_manager_type, fs=store)
init_auth_manager(
auth_manager_type=auth_manager_type,
server_type=ServerType.ARROW,
)

location = "grpc+tcp://{}:{}".format(host, port)
server = OfflineServer(store, location)
logger.info(f"Offline store server serving on {location}")
Expand Down
Empty file.
68 changes: 68 additions & 0 deletions sdk/python/feast/permissions/auth/auth_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from abc import ABC
from typing import Optional

from .token_extractor import NoAuthTokenExtractor, TokenExtractor
from .token_parser import NoAuthTokenParser, TokenParser


class AuthManager(ABC):
    """
    Base class of the authorization managers.

    An authorization manager offers services to manage the authorization tokens of client
    requests, so that user details can be extracted before injecting them in the security context.
    It holds the two collaborators received at construction time, a `TokenParser` and a
    `TokenExtractor`, and exposes them as read-only properties.
    """

    _token_parser: TokenParser
    _token_extractor: TokenExtractor

    def __init__(self, token_parser: TokenParser, token_extractor: TokenExtractor):
        # Plain storage: no validation or copy is performed on the collaborators.
        self._token_parser, self._token_extractor = token_parser, token_extractor

    @property
    def token_parser(self) -> TokenParser:
        """The `TokenParser` given to the constructor."""
        return self._token_parser

    @property
    def token_extractor(self) -> TokenExtractor:
        """The `TokenExtractor` given to the constructor."""
        return self._token_extractor


# The possibly empty global instance of `AuthManager`.
# It is `None` until `set_auth_manager` assigns it.
_auth_manager: Optional[AuthManager] = None


def get_auth_manager() -> AuthManager:
    """
    Return the global instance of `AuthManager`.

    Raises:
        RuntimeError if the global instance is not set.
    """
    # Reading a module-level name needs no `global` declaration.
    if _auth_manager is not None:
        return _auth_manager

    raise RuntimeError("AuthManager is not initialized. Call 'set_auth_manager' first.")


def set_auth_manager(auth_manager: AuthManager):
    """
    Install `auth_manager` as the global instance of `AuthManager`.

    Any previously installed instance is replaced; the value is returned afterwards
    by `get_auth_manager`.
    """
    # Rebinding the module-level variable requires the explicit `global` declaration.
    global _auth_manager
    _auth_manager = auth_manager


class AllowAll(AuthManager):
    """
    An `AuthManager` not extracting nor parsing the authorization token: it is configured
    with the `NoAuthTokenExtractor` and `NoAuthTokenParser` implementations.
    """

    def __init__(self):
        # Build the extractor first and the parser second, then hand both to the base class.
        no_auth_extractor = NoAuthTokenExtractor()
        no_auth_parser = NoAuthTokenParser()
        super().__init__(
            token_parser=no_auth_parser,
            token_extractor=no_auth_extractor,
        )
103 changes: 103 additions & 0 deletions sdk/python/feast/permissions/auth/kubernetes_token_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import logging

import jwt
from kubernetes import client, config
from starlette.authentication import (
AuthenticationError,
)

from feast.permissions.auth.token_parser import TokenParser
from feast.permissions.user import User

logger = logging.getLogger(__name__)


class KubernetesTokenParser(TokenParser):
    """
    A `TokenParser` implementation to use Kubernetes RBAC resources to retrieve the user details.
    The assumption is that the request header includes an authorization bearer with the token of the
    client `ServiceAccount`.
    By inspecting the role bindings, this `TokenParser` extracts the associated `Role`s.
    The client `ServiceAccount` is instead used as the user name, together with the current namespace.
    """

    def __init__(self):
        # Uses the in-cluster configuration of the Kubernetes client, then creates the API
        # clients used to query the RBAC resources.
        config.load_incluster_config()
        self.v1 = client.CoreV1Api()
        self.rbac_v1 = client.RbacAuthorizationV1Api()

    async def user_details_from_access_token(self, access_token: str) -> User:
        """
        Extract the service account from the token and search the roles associated with it.

        Returns:
            User: Current user, with associated roles. The `username` is the `:` separated concatenation of `namespace` and `service account name`.
        Raises:
            AuthenticationError if the token cannot be decoded into a service account.
            NOTE(review): errors raised by the Kubernetes API calls in `get_roles` are not
            converted to `AuthenticationError` here - confirm whether callers expect that.
        """
        sa_namespace, sa_name = _decode_token(access_token)
        current_user = f"{sa_namespace}:{sa_name}"
        # Log through the module-level logger (not the root logger) so that the records
        # follow the logging configuration of this module.
        logger.info("Received request from %s in %s", sa_name, sa_namespace)

        roles = self.get_roles(sa_namespace, sa_name)
        logger.info("SA roles are: %s", roles)

        return User(username=current_user, roles=roles)

    def get_roles(self, namespace: str, service_account_name: str) -> list[str]:
        """
        Fetches the Kubernetes `Role`s associated to the given `ServiceAccount` in the given `namespace`.
        The research also includes the `ClusterRole`s, so the running deployment must be granted enough permissions to query
        for such instances in all the namespaces.

        A subject matches only when it designates the service account of the given `namespace`:
        a `RoleBinding` listing a same-named service account of another namespace grants nothing.

        Returns:
            list[str]: Name of the `Role`s and `ClusterRole`s associated to the service account, without duplicates and
            in sorted order. No string manipulation is performed on the role name.
        """
        role_bindings = self.rbac_v1.list_namespaced_role_binding(namespace)
        cluster_role_bindings = self.rbac_v1.list_cluster_role_binding()

        # In a namespaced RoleBinding the subject namespace may be omitted, in which case
        # Kubernetes defaults it to the namespace of the binding itself.
        roles = self._bound_role_names(
            role_bindings.items,
            namespace,
            service_account_name,
            allow_missing_namespace=True,
        )
        # ClusterRoleBinding subjects must always carry an explicit namespace.
        roles |= self._bound_role_names(
            cluster_role_bindings.items,
            namespace,
            service_account_name,
            allow_missing_namespace=False,
        )

        return sorted(roles)

    @staticmethod
    def _bound_role_names(
        bindings,
        namespace: str,
        service_account_name: str,
        *,
        allow_missing_namespace: bool,
    ) -> set[str]:
        """
        Collect the `role_ref` names of the bindings having the given service account among the subjects.

        Args:
            bindings: The role bindings (or cluster role bindings) to inspect; `None` is treated as empty.
            namespace: Namespace of the service account.
            service_account_name: Name of the service account.
            allow_missing_namespace: Whether a subject without namespace is considered to be in `namespace`.
        Returns:
            set[str]: The names of the referenced roles.
        """
        names: set[str] = set()
        for binding in bindings or []:
            # `subjects` is optional in the binding resources.
            for subject in binding.subjects or []:
                if (
                    subject.kind != "ServiceAccount"
                    or subject.name != service_account_name
                ):
                    continue
                if subject.namespace == namespace or (
                    allow_missing_namespace and not subject.namespace
                ):
                    names.add(binding.role_ref.name)
                    # One matching subject is enough for this binding.
                    break
        return names


def _decode_token(access_token: str) -> tuple[str, str]:
    """
    Decode the access token and extract the service account coordinates from it.

    The `sub` portion of the decoded token includes the service account name in the format: `system:serviceaccount:NAMESPACE:SA_NAME`

    Returns:
        str: the namespace name.
        str: the `ServiceAccount` name.
    Raises:
        AuthenticationError: if the token cannot be decoded, has no `sub` claim, or the `sub` claim
        is not a string made of four `:` separated fields with non-empty namespace and name.
    """
    try:
        # SECURITY NOTE(review): the signature is NOT verified, so the claims are taken as-is
        # from the client. Confirm that the token is validated elsewhere (e.g. with the
        # Kubernetes TokenReview API) before trusting the extracted identity.
        decoded_token = jwt.decode(access_token, options={"verify_signature": False})
    except jwt.InvalidTokenError as e:
        # `jwt.DecodeError` is a subclass of `jwt.InvalidTokenError`.
        raise AuthenticationError(f"Error decoding JWT token: {e}") from e

    if "sub" not in decoded_token:
        raise AuthenticationError("Missing sub section in received token.")

    subject = decoded_token["sub"]
    # A malformed subject must surface as an authentication failure, not as a
    # ValueError/AttributeError from the unpacking below.
    fields = subject.split(":") if isinstance(subject, str) else []
    if len(fields) != 4 or not fields[2] or not fields[3]:
        raise AuthenticationError(
            f"Invalid sub section in received token: {subject!r}"
        )

    _, _, sa_namespace, sa_name = fields
    return (sa_namespace, sa_name)
Loading

0 comments on commit 087fd71

Please sign in to comment.