Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat(aws): get regions by partition #5748

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions prowler/providers/aws/aws_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
AWSIAMRoleARNPartitionEmptyError,
AWSIAMRoleARNRegionNotEmtpyError,
AWSIAMRoleARNServiceNotIAMnorSTSError,
AWSInvalidPartitionError,
AWSInvalidProviderIdError,
AWSNoCredentialsError,
AWSProfileNotFoundError,
Expand Down Expand Up @@ -1297,6 +1298,44 @@ def create_sts_session(
)
raise error

@staticmethod
def get_regions_by_partition(partition: str = None) -> set:
"""
Get the available AWS regions from the AWS services JSON file with the ability of filtering by partition.

Args:
- partition (str): The AWS partition name. Default is None.

Returns:
set: A set of available AWS regions. All if no `partition` is especified.
"""
try:
data = read_aws_regions_file()

regions = set()
if partition is None:
for service in data["services"].values():
for partition in service["regions"]:
for item in service["regions"][partition]:
regions.add(item)
else:
for service in data["services"].values():
try:
for item in service["regions"][partition]:
regions.add(item)
except KeyError as key_error:
logger.error(
f"{key_error.__class__.__name__}[{key_error.__traceback__.tb_lineno}]: {key_error}"
)
raise AWSInvalidPartitionError(
message=f"Invalid partition name: {partition}",
file=os.path.basename(__file__),
)
return regions
except Exception as error:
logger.error(f"{error.__class__.__name__}: {error}")
return set()


def read_aws_regions_file() -> dict:
"""
Expand All @@ -1313,27 +1352,6 @@ def read_aws_regions_file() -> dict:
return data


def get_aws_available_regions() -> set:
"""
Get the available AWS regions from the AWS services JSON file.

Returns:
set: A set of available AWS regions.
"""
try:
data = read_aws_regions_file()

regions = set()
for service in data["services"].values():
for partition in service["regions"]:
for item in service["regions"][partition]:
regions.add(item)
return regions
except Exception as error:
logger.error(f"{error.__class__.__name__}: {error}")
return set()


# TODO: This can be moved to another class since it doesn't need self
def get_aws_region_for_sts(session_region: str, regions: set[str]) -> str:
# If there is no region passed with -f/--region/--filter-region
Expand Down
11 changes: 11 additions & 0 deletions prowler/providers/aws/exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class AWSBaseException(ProwlerException):
"message": "The provided AWS Session Token is expired",
"remediation": "Get a new AWS Session Token and configure it for the provider.",
},
(1917, "AWSInvalidPartitionError"): {
"message": "The provided AWS partition is invalid",
"remediation": "Check the provided AWS partition and ensure it is valid.",
},
}

def __init__(self, code, file=None, original_exception=None, message=None):
Expand Down Expand Up @@ -220,3 +224,10 @@ def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
1016, file=file, original_exception=original_exception, message=message
)


class AWSInvalidPartitionError(AWSBaseException):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
1917, file=file, original_exception=original_exception, message=message
)
4 changes: 2 additions & 2 deletions prowler/providers/aws/lib/arguments/arguments.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from argparse import ArgumentTypeError, Namespace
from re import fullmatch, search

from prowler.providers.aws.aws_provider import get_aws_available_regions
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.config import ROLE_SESSION_NAME
from prowler.providers.aws.lib.arn.arn import arn_type

Expand Down Expand Up @@ -64,7 +64,7 @@ def init_parser(self):
"-f",
nargs="+",
help="AWS region names to run Prowler against",
choices=get_aws_available_regions(),
choices=AwsProvider.get_regions_by_partition(),
)
# AWS Organizations
aws_orgs_subparser = aws_parser.add_argument_group("AWS Organizations")
Expand Down
12 changes: 9 additions & 3 deletions tests/config/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
load_and_validate_config_file,
load_and_validate_fixer_config_file,
)
from prowler.providers.aws.aws_provider import get_aws_available_regions
from prowler.providers.aws.aws_provider import AwsProvider

MOCK_PROWLER_VERSION = "3.3.0"
MOCK_OLD_PROWLER_VERSION = "0.0.0"
Expand Down Expand Up @@ -346,8 +346,14 @@ def mock_prowler_get_latest_release(_, **kwargs):


class Test_Config:
def test_get_aws_available_regions(self):
assert len(get_aws_available_regions()) == 34
def test_get_regions_by_partition(self):
assert len(AwsProvider.get_regions_by_partition()) == 34

def test_get_regions_by_partition_with_partition(self):
assert len(AwsProvider.get_regions_by_partition("aws-cn")) == 2

def test_get_regions_by_partition_with_unknown_partition(self):
assert len(AwsProvider.get_regions_by_partition("unknown")) == 0

@mock.patch(
"prowler.config.config.requests.get", new=mock_prowler_get_latest_release
Expand Down
58 changes: 51 additions & 7 deletions tests/providers/aws/aws_provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
from pytest import raises
from tzlocal import get_localzone

from prowler.providers.aws.aws_provider import (
AwsProvider,
get_aws_available_regions,
get_aws_region_for_sts,
)
from prowler.providers.aws.aws_provider import AwsProvider, get_aws_region_for_sts
from prowler.providers.aws.config import (
AWS_STS_GLOBAL_ENDPOINT_REGION,
BOTO3_USER_AGENT_EXTRA,
Expand Down Expand Up @@ -1720,7 +1716,7 @@ def test_get_regions_from_audit_resources_without_regions(self):
)
assert not recovered_regions

def test_get_aws_available_regions(self):
def test_get_regions_by_partition(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
Expand All @@ -1741,12 +1737,60 @@ def test_get_aws_available_regions(self):
}
},
):
assert get_aws_available_regions() == {
assert AwsProvider.get_regions_by_partition() == {
"af-south-1",
"cn-north-1",
"us-gov-west-1",
}

def test_get_regions_by_partition_with_partition(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
"services": {
"acm": {
"regions": {
"aws": [
"af-south-1",
],
"aws-cn": [
"cn-north-1",
],
"aws-us-gov": [
"us-gov-west-1",
],
}
}
}
},
):
assert AwsProvider.get_regions_by_partition("aws-cn") == {
"cn-north-1",
}

def test_get_regions_by_partition_with_unknown_partition(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
"services": {
"acm": {
"regions": {
"aws": [
"af-south-1",
],
"aws-cn": [
"cn-north-1",
],
"aws-us-gov": [
"us-gov-west-1",
],
}
}
}
},
):
assert AwsProvider.get_regions_by_partition("unknown") == set()

def test_get_aws_region_for_sts_input_regions_none_session_region_none(self):
input_regions = None
session_region = None
Expand Down