Skip to content

Commit

Permalink
fix(db-utils): fix batch_delete function (#6285)
Browse files Browse the repository at this point in the history
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
  • Loading branch information
prowler-bot and vicferpoy authored Dec 20, 2024
1 parent 7022b7b commit c656cf8
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 8 deletions.
14 changes: 8 additions & 6 deletions api/src/backend/api/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from django.conf import settings
from django.contrib.auth.models import BaseUserManager
from django.core.paginator import Paginator
from django.db import connection, models, transaction
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
Expand Down Expand Up @@ -120,15 +119,18 @@ def batch_delete(queryset, batch_size=5000):
total_deleted = 0
deletion_summary = {}

paginator = Paginator(queryset.order_by("id").only("id"), batch_size)

for page_num in paginator.page_range:
batch_ids = [obj.id for obj in paginator.page(page_num).object_list]
while True:
# Get a batch of IDs to delete
batch_ids = set(
queryset.values_list("id", flat=True).order_by("id")[:batch_size]
)
if not batch_ids:
# No more objects to delete
break

deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()

total_deleted += deleted_count

for model_label, count in deleted_info.items():
deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 5.1.1 on 2024-12-20 13:16

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("api", "0004_rbac_missing_admin_roles"),
]

operations = [
migrations.RemoveConstraint(
model_name="provider",
name="unique_provider_uids",
),
migrations.AddConstraint(
model_name="provider",
constraint=models.UniqueConstraint(
fields=("tenant_id", "provider", "uid", "is_deleted"),
name="unique_provider_uids",
),
),
]
2 changes: 1 addition & 1 deletion api/src/backend/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class Meta(RowLevelSecurityProtectedModel.Meta):

constraints = [
models.UniqueConstraint(
fields=("tenant_id", "provider", "uid"),
fields=("tenant_id", "provider", "uid", "is_deleted"),
name="unique_provider_uids",
),
RowLevelSecurityConstraint(
Expand Down
85 changes: 85 additions & 0 deletions api/src/backend/api/tests/integration/test_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from unittest.mock import Mock, patch

import pytest
from conftest import get_api_tokens, get_authorization_header
from django.urls import reverse
from rest_framework.test import APIClient

from api.models import Provider


@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.delete_provider_task.delay")
@pytest.mark.django_db
def test_delete_provider_without_executing_task(
mock_delete_task, mock_task_get, create_test_user, tenants_fixture, tasks_fixture
):
client = APIClient()

test_user = "test_email@prowler.com"
test_password = "test_password"

prowler_task = tasks_fixture[0]
task_mock = Mock()
task_mock.id = prowler_task.id
mock_delete_task.return_value = task_mock
mock_task_get.return_value = prowler_task

user_creation_response = client.post(
reverse("user-list"),
data={
"data": {
"type": "users",
"attributes": {
"name": "test",
"email": test_user,
"password": test_password,
},
}
},
format="vnd.api+json",
)
assert user_creation_response.status_code == 201

access_token, _ = get_api_tokens(client, test_user, test_password)
auth_headers = get_authorization_header(access_token)

create_provider_response = client.post(
reverse("provider-list"),
data={
"data": {
"type": "providers",
"attributes": {
"provider": Provider.ProviderChoices.AWS,
"uid": "123456789012",
},
}
},
format="vnd.api+json",
headers=auth_headers,
)
assert create_provider_response.status_code == 201
provider_id = create_provider_response.json()["data"]["id"]
provider_uid = create_provider_response.json()["data"]["attributes"]["uid"]

remove_provider = client.delete(
reverse("provider-detail", kwargs={"pk": provider_id}),
headers=auth_headers,
)
assert remove_provider.status_code == 202

recreate_provider_response = client.post(
reverse("provider-list"),
data={
"data": {
"type": "providers",
"attributes": {
"provider": Provider.ProviderChoices.AWS,
"uid": provider_uid,
},
}
},
format="vnd.api+json",
headers=auth_headers,
)
assert recreate_provider_response.status_code == 201
33 changes: 32 additions & 1 deletion api/src/backend/api/tests/test_db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from enum import Enum
from unittest.mock import patch

from api.db_utils import enum_to_choices, one_week_from_now, generate_random_token
import pytest

from api.db_utils import (
batch_delete,
enum_to_choices,
generate_random_token,
one_week_from_now,
)
from api.models import Provider


class TestEnumToChoices:
Expand Down Expand Up @@ -106,3 +114,26 @@ def test_generate_random_token_no_symbols_provided(self):
token = generate_random_token(length=5, symbols="")
# Default symbols
assert len(token) == 5


class TestBatchDelete:
@pytest.fixture
def create_test_providers(self, tenants_fixture):
tenant = tenants_fixture[0]
provider_id = 123456789012
provider_count = 10
for i in range(provider_count):
Provider.objects.create(
tenant=tenant,
uid=f"{provider_id + i}",
provider=Provider.ProviderChoices.AWS,
)
return provider_count

@pytest.mark.django_db
def test_batch_delete(self, create_test_providers):
_, summary = batch_delete(
Provider.objects.all(), batch_size=create_test_providers // 2
)
assert Provider.objects.all().count() == 0
assert summary == {"api.Provider": create_test_providers}

0 comments on commit c656cf8

Please # to comment.