Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix(db-utils): fix batch_delete function #6285

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions api/src/backend/api/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from django.conf import settings
from django.contrib.auth.models import BaseUserManager
from django.core.paginator import Paginator
from django.db import connection, models, transaction
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
Expand Down Expand Up @@ -120,15 +119,18 @@ def batch_delete(queryset, batch_size=5000):
total_deleted = 0
deletion_summary = {}

paginator = Paginator(queryset.order_by("id").only("id"), batch_size)

for page_num in paginator.page_range:
batch_ids = [obj.id for obj in paginator.page(page_num).object_list]
while True:
# Get a batch of IDs to delete
batch_ids = set(
queryset.values_list("id", flat=True).order_by("id")[:batch_size]
)
if not batch_ids:
# No more objects to delete
break

deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()

total_deleted += deleted_count

for model_label, count in deleted_info.items():
deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 5.1.1 on 2024-12-20 13:16

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("api", "0004_rbac_missing_admin_roles"),
]

operations = [
migrations.RemoveConstraint(
model_name="provider",
name="unique_provider_uids",
),
migrations.AddConstraint(
model_name="provider",
constraint=models.UniqueConstraint(
fields=("tenant_id", "provider", "uid", "is_deleted"),
name="unique_provider_uids",
),
),
]
2 changes: 1 addition & 1 deletion api/src/backend/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class Meta(RowLevelSecurityProtectedModel.Meta):

constraints = [
models.UniqueConstraint(
fields=("tenant_id", "provider", "uid"),
fields=("tenant_id", "provider", "uid", "is_deleted"),
name="unique_provider_uids",
),
RowLevelSecurityConstraint(
Expand Down
85 changes: 85 additions & 0 deletions api/src/backend/api/tests/integration/test_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from unittest.mock import Mock, patch

import pytest
from conftest import get_api_tokens, get_authorization_header
from django.urls import reverse
from rest_framework.test import APIClient

from api.models import Provider


@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.delete_provider_task.delay")
@pytest.mark.django_db
def test_delete_provider_without_executing_task(
mock_delete_task, mock_task_get, create_test_user, tenants_fixture, tasks_fixture
):
client = APIClient()

test_user = "test_email@prowler.com"
test_password = "test_password"

prowler_task = tasks_fixture[0]
task_mock = Mock()
task_mock.id = prowler_task.id
mock_delete_task.return_value = task_mock
mock_task_get.return_value = prowler_task

user_creation_response = client.post(
reverse("user-list"),
data={
"data": {
"type": "users",
"attributes": {
"name": "test",
"email": test_user,
"password": test_password,
},
}
},
format="vnd.api+json",
)
assert user_creation_response.status_code == 201

access_token, _ = get_api_tokens(client, test_user, test_password)
auth_headers = get_authorization_header(access_token)

create_provider_response = client.post(
reverse("provider-list"),
data={
"data": {
"type": "providers",
"attributes": {
"provider": Provider.ProviderChoices.AWS,
"uid": "123456789012",
},
}
},
format="vnd.api+json",
headers=auth_headers,
)
assert create_provider_response.status_code == 201
provider_id = create_provider_response.json()["data"]["id"]
provider_uid = create_provider_response.json()["data"]["attributes"]["uid"]

remove_provider = client.delete(
reverse("provider-detail", kwargs={"pk": provider_id}),
headers=auth_headers,
)
assert remove_provider.status_code == 202

recreate_provider_response = client.post(
reverse("provider-list"),
data={
"data": {
"type": "providers",
"attributes": {
"provider": Provider.ProviderChoices.AWS,
"uid": provider_uid,
},
}
},
format="vnd.api+json",
headers=auth_headers,
)
assert recreate_provider_response.status_code == 201
33 changes: 32 additions & 1 deletion api/src/backend/api/tests/test_db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from enum import Enum
from unittest.mock import patch

from api.db_utils import enum_to_choices, one_week_from_now, generate_random_token
import pytest

from api.db_utils import (
batch_delete,
enum_to_choices,
generate_random_token,
one_week_from_now,
)
from api.models import Provider


class TestEnumToChoices:
Expand Down Expand Up @@ -106,3 +114,26 @@ def test_generate_random_token_no_symbols_provided(self):
token = generate_random_token(length=5, symbols="")
# Default symbols
assert len(token) == 5


class TestBatchDelete:
@pytest.fixture
def create_test_providers(self, tenants_fixture):
tenant = tenants_fixture[0]
provider_id = 123456789012
provider_count = 10
for i in range(provider_count):
Provider.objects.create(
tenant=tenant,
uid=f"{provider_id + i}",
provider=Provider.ProviderChoices.AWS,
)
return provider_count

@pytest.mark.django_db
def test_batch_delete(self, create_test_providers):
_, summary = batch_delete(
Provider.objects.all(), batch_size=create_test_providers // 2
)
assert Provider.objects.all().count() == 0
assert summary == {"api.Provider": create_test_providers}
Loading