Skip to content

Commit

Permalink
fix(tenant): fix delete tenants behavior (#6013)
Browse files Browse the repository at this point in the history
  • Loading branch information
vicferpoy authored Dec 4, 2024
1 parent 58723ae commit ad7134d
Show file tree
Hide file tree
Showing 8 changed files with 1,162 additions and 945 deletions.
1,956 changes: 1,025 additions & 931 deletions api/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ drf-nested-routers = "^0.94.1"
drf-spectacular = "0.27.2"
drf-spectacular-jsonapi = "0.5.1"
gunicorn = "23.0.0"
prowler = {git = "https://github.com/prowler-cloud/prowler.git", branch = "master"}
prowler = {git = "https://github.com/prowler-cloud/prowler.git", tag = "5.0.0"}
psycopg2-binary = "2.9.9"
pytest-celery = {extras = ["redis"], version = "^1.0.1"}
# Needed for prowler compatibility
Expand Down
28 changes: 25 additions & 3 deletions api/src/backend/api/base_views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid

from django.db import transaction, connection
from django.db import connection, transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
from rest_framework.filters import SearchFilter
Expand Down Expand Up @@ -69,10 +69,32 @@ def dispatch(self, request, *args, **kwargs):
return super().dispatch(request, *args, **kwargs)

def initial(self, request, *args, **kwargs):
user_id = str(request.user.id)
if (
request.resolver_match.url_name != "tenant-detail"
and request.method != "DELETE"
):
user_id = str(request.user.id)

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
return super().initial(request, *args, **kwargs)

# TODO: DRY this when we have time
if request.auth is None:
raise NotAuthenticated

tenant_id = request.auth.get("tenant_id")
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")

try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)


Expand Down
13 changes: 12 additions & 1 deletion api/src/backend/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,24 @@ def test_tenants_partial_update_invalid_content(
)
assert response.status_code == status.HTTP_400_BAD_REQUEST

def test_tenants_delete(self, authenticated_client, tenants_fixture):
@patch("api.db_router.MainRouter.admin_db", new="default")
@patch("api.v1.views.delete_tenant_task.apply_async")
def test_tenants_delete(
self, delete_tenant_mock, authenticated_client, tenants_fixture
):
def _delete_tenant(kwargs):
Tenant.objects.filter(pk=kwargs.get("tenant_id")).delete()

delete_tenant_mock.side_effect = _delete_tenant
tenant1, *_ = tenants_fixture
response = authenticated_client.delete(
reverse("tenant-detail", kwargs={"pk": tenant1.id})
)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert Tenant.objects.count() == len(tenants_fixture) - 1
assert Membership.objects.filter(tenant_id=tenant1.id).count() == 0
# User is not deleted because it has another membership
assert User.objects.count() == 1

def test_tenants_delete_invalid(self, authenticated_client):
response = authenticated_client.delete(
Expand Down
22 changes: 21 additions & 1 deletion api/src/backend/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tasks.tasks import (
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
perform_scan_summary_task,
perform_scan_task,
)
Expand Down Expand Up @@ -171,7 +172,7 @@ class SchemaView(SpectacularAPIView):

def get(self, request, *args, **kwargs):
spectacular_settings.TITLE = "Prowler API"
spectacular_settings.VERSION = "1.0.0"
spectacular_settings.VERSION = "1.0.1"
spectacular_settings.DESCRIPTION = (
"Prowler API specification.\n\nThis file is auto-generated."
)
Expand Down Expand Up @@ -401,6 +402,25 @@ def create(self, request, *args, **kwargs):
)
return Response(data=serializer.data, status=status.HTTP_201_CREATED)

def destroy(self, request, *args, **kwargs):
# This will perform validation and raise a 404 if the tenant does not exist
tenant_id = kwargs.get("pk")
get_object_or_404(Tenant, id=tenant_id)

with transaction.atomic():
# Delete memberships
Membership.objects.using(MainRouter.admin_db).filter(
tenant_id=tenant_id
).delete()

# Delete users without memberships
User.objects.using(MainRouter.admin_db).filter(
membership__isnull=True
).delete()
# Delete tenant in batches
delete_tenant_task.apply_async(kwargs={"tenant_id": tenant_id})
return Response(status=status.HTTP_204_NO_CONTENT)


@extend_schema_view(
list=extend_schema(
Expand Down
28 changes: 26 additions & 2 deletions api/src/backend/tasks/jobs/deletion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from celery.utils.log import get_task_logger
from django.db import transaction

from api.db_utils import batch_delete
from api.models import Finding, Provider, Resource, Scan, ScanSummary
from api.db_router import MainRouter
from api.db_utils import batch_delete, tenant_transaction
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant

logger = get_task_logger(__name__)

Expand Down Expand Up @@ -49,3 +50,26 @@ def delete_provider(pk: str):
deletion_summary.update(provider_summary)

return deletion_summary


def delete_tenant(pk: str):
"""
Gracefully deletes an instance of a tenant along with its related data.
Args:
pk (str): The primary key of the Tenant instance to delete.
Returns:
dict: A dictionary with the count of deleted objects per model,
including related models.
"""
deletion_summary = {}

for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
with tenant_transaction(pk):
summary = delete_provider(provider.id)
deletion_summary.update(summary)

Tenant.objects.using(MainRouter.admin_db).filter(id=pk).delete()

return deletion_summary
7 changes: 6 additions & 1 deletion api/src/backend/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from config.celery import RLSTask
from django_celery_beat.models import PeriodicTask
from tasks.jobs.connection import check_provider_connection
from tasks.jobs.deletion import delete_provider
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan

from api.db_utils import tenant_transaction
Expand Down Expand Up @@ -134,3 +134,8 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
@shared_task(name="scan-summary")
def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)


@shared_task(name="tenant-deletion")
def delete_tenant_task(tenant_id: str):
return delete_tenant(pk=tenant_id)
51 changes: 46 additions & 5 deletions api/src/backend/tasks/tests/test_deletion.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,63 @@
from unittest.mock import patch

import pytest
from django.core.exceptions import ObjectDoesNotExist
from tasks.jobs.deletion import delete_provider
from tasks.jobs.deletion import delete_provider, delete_tenant

from api.models import Provider
from api.models import Provider, Tenant


@pytest.mark.django_db
class TestDeleteInstance:
def test_delete_instance_success(self, providers_fixture):
class TestDeleteProvider:
def test_delete_provider_success(self, providers_fixture):
instance = providers_fixture[0]
result = delete_provider(instance.id)

assert result
with pytest.raises(ObjectDoesNotExist):
Provider.objects.get(pk=instance.id)

def test_delete_instance_does_not_exist(self):
def test_delete_provider_does_not_exist(self):
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"

with pytest.raises(ObjectDoesNotExist):
delete_provider(non_existent_pk)


@patch("api.db_router.MainRouter.admin_db", new="default")
@pytest.mark.django_db
class TestDeleteTenant:
def test_delete_tenant_success(self, tenants_fixture, providers_fixture):
"""
Test successful deletion of a tenant and its related data.
"""
tenant = tenants_fixture[0]
providers = Provider.objects.filter(tenant_id=tenant.id)

# Ensure the tenant and related providers exist before deletion
assert Tenant.objects.filter(id=tenant.id).exists()
assert providers.exists()

# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)

assert deletion_summary is not None
assert not Tenant.objects.filter(id=tenant.id).exists()
assert not Provider.objects.filter(tenant_id=tenant.id).exists()

def test_delete_tenant_with_no_providers(self, tenants_fixture):
"""
Test deletion of a tenant with no related providers.
"""
tenant = tenants_fixture[1] # Assume this tenant has no providers
providers = Provider.objects.filter(tenant_id=tenant.id)

# Ensure the tenant exists but has no related providers
assert Tenant.objects.filter(id=tenant.id).exists()
assert not providers.exists()

# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)

assert deletion_summary == {} # No providers, so empty summary
assert not Tenant.objects.filter(id=tenant.id).exists()

0 comments on commit ad7134d

Please # to comment.