diff --git a/api/src/backend/api/db_utils.py b/api/src/backend/api/db_utils.py index 466f78aea6b..1bcf14209c3 100644 --- a/api/src/backend/api/db_utils.py +++ b/api/src/backend/api/db_utils.py @@ -5,7 +5,6 @@ from django.conf import settings from django.contrib.auth.models import BaseUserManager -from django.core.paginator import Paginator from django.db import connection, models, transaction from psycopg2 import connect as psycopg2_connect from psycopg2.extensions import AsIs, new_type, register_adapter, register_type @@ -120,15 +119,18 @@ def batch_delete(queryset, batch_size=5000): total_deleted = 0 deletion_summary = {} - paginator = Paginator(queryset.order_by("id").only("id"), batch_size) - - for page_num in paginator.page_range: - batch_ids = [obj.id for obj in paginator.page(page_num).object_list] + while True: + # Get a batch of IDs to delete + batch_ids = set( + queryset.values_list("id", flat=True).order_by("id")[:batch_size] + ) + if not batch_ids: + # No more objects to delete + break deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete() total_deleted += deleted_count - for model_label, count in deleted_info.items(): deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count diff --git a/api/src/backend/api/migrations/0005_update_provider_unique_constraint_with_is_deleted.py b/api/src/backend/api/migrations/0005_update_provider_unique_constraint_with_is_deleted.py new file mode 100644 index 00000000000..5fd5097376f --- /dev/null +++ b/api/src/backend/api/migrations/0005_update_provider_unique_constraint_with_is_deleted.py @@ -0,0 +1,23 @@ +# Generated by Django 5.1.1 on 2024-12-20 13:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("api", "0004_rbac_missing_admin_roles"), + ] + + operations = [ + migrations.RemoveConstraint( + model_name="provider", + name="unique_provider_uids", + ), + migrations.AddConstraint( + model_name="provider", + constraint=models.UniqueConstraint( + fields=("tenant_id", "provider", "uid", "is_deleted"), + name="unique_provider_uids", + ), + ), + ] diff --git a/api/src/backend/api/models.py b/api/src/backend/api/models.py index 8b2dc540727..f35e4d43136 100644 --- a/api/src/backend/api/models.py +++ b/api/src/backend/api/models.py @@ -256,7 +256,7 @@ class Meta(RowLevelSecurityProtectedModel.Meta): constraints = [ models.UniqueConstraint( - fields=("tenant_id", "provider", "uid"), + fields=("tenant_id", "provider", "uid", "is_deleted"), name="unique_provider_uids", ), RowLevelSecurityConstraint( diff --git a/api/src/backend/api/tests/integration/test_providers.py b/api/src/backend/api/tests/integration/test_providers.py new file mode 100644 index 00000000000..0f17c2d8391 --- /dev/null +++ b/api/src/backend/api/tests/integration/test_providers.py @@ -0,0 +1,85 @@ +from unittest.mock import Mock, patch + +import pytest +from conftest import get_api_tokens, get_authorization_header +from django.urls import reverse +from rest_framework.test import APIClient + +from api.models import Provider + + +@patch("api.v1.views.Task.objects.get") +@patch("api.v1.views.delete_provider_task.delay") +@pytest.mark.django_db +def test_delete_provider_without_executing_task( + mock_delete_task, mock_task_get, create_test_user, tenants_fixture, tasks_fixture +): + client = APIClient() + + test_user = "test_email@prowler.com" + test_password = "test_password" + + prowler_task = tasks_fixture[0] + task_mock = Mock() + task_mock.id = prowler_task.id + mock_delete_task.return_value = task_mock + mock_task_get.return_value = prowler_task + + user_creation_response = client.post( + reverse("user-list"), + data={ + "data": { + "type": "users", + "attributes": { + "name": "test", + "email": test_user, + "password": test_password, + }, + } + }, + format="vnd.api+json", + ) + assert user_creation_response.status_code == 201 + + access_token, _ = get_api_tokens(client, test_user, test_password) + auth_headers = get_authorization_header(access_token) + + create_provider_response = client.post( + reverse("provider-list"), + data={ + "data": { + "type": "providers", + "attributes": { + "provider": Provider.ProviderChoices.AWS, + "uid": "123456789012", + }, + } + }, + format="vnd.api+json", + headers=auth_headers, + ) + assert create_provider_response.status_code == 201 + provider_id = create_provider_response.json()["data"]["id"] + provider_uid = create_provider_response.json()["data"]["attributes"]["uid"] + + remove_provider = client.delete( + reverse("provider-detail", kwargs={"pk": provider_id}), + headers=auth_headers, + ) + assert remove_provider.status_code == 202 + + recreate_provider_response = client.post( + reverse("provider-list"), + data={ + "data": { + "type": "providers", + "attributes": { + "provider": Provider.ProviderChoices.AWS, + "uid": provider_uid, + }, + } + }, + format="vnd.api+json", + headers=auth_headers, + ) + assert recreate_provider_response.status_code == 201 diff --git a/api/src/backend/api/tests/test_db_utils.py b/api/src/backend/api/tests/test_db_utils.py index 15cbf883996..6c2364600a6 100644 --- a/api/src/backend/api/tests/test_db_utils.py +++ b/api/src/backend/api/tests/test_db_utils.py @@ -2,7 +2,15 @@ from enum import Enum from unittest.mock import patch -from api.db_utils import enum_to_choices, one_week_from_now, generate_random_token +import pytest + +from api.db_utils import ( + batch_delete, + enum_to_choices, + generate_random_token, + one_week_from_now, +) +from api.models import Provider class TestEnumToChoices: @@ -106,3 +114,26 @@ def test_generate_random_token_no_symbols_provided(self): token = generate_random_token(length=5, symbols="") # Default symbols assert len(token) == 5 + + +class TestBatchDelete: + @pytest.fixture + def create_test_providers(self, tenants_fixture): + tenant = tenants_fixture[0] + provider_id = 123456789012 + provider_count = 10 + for i in range(provider_count): + Provider.objects.create( + tenant=tenant, + uid=f"{provider_id + i}", + provider=Provider.ProviderChoices.AWS, + ) + return provider_count + + @pytest.mark.django_db + def test_batch_delete(self, create_test_providers): + _, summary = batch_delete( + Provider.objects.all(), batch_size=create_test_providers // 2 + ) + assert Provider.objects.all().count() == 0 + assert summary == {"api.Provider": create_test_providers}