diff --git a/judge/views/api/api_v2.py b/judge/views/api/api_v2.py index 17db906150..6db452b17c 100644 --- a/judge/views/api/api_v2.py +++ b/judge/views/api/api_v2.py @@ -1,7 +1,7 @@ from operator import attrgetter from django.conf import settings -from django.core.exceptions import PermissionDenied, ValidationError +from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError from django.db.models import Count, F, OuterRef, Prefetch, Q, Subquery from django.http import Http404, JsonResponse from django.utils import timezone @@ -18,6 +18,30 @@ from judge.views.submission import group_test_cases +class BaseSimpleFilter: + def __init__(self, lookup): + self.lookup = lookup + + def get_object(self, key): + raise NotImplementedError() + + def to_filter(self, key): + try: + return {self.lookup: self.get_object(key)} + except ObjectDoesNotExist: + return {self.lookup: None} + + +class ProfileSimpleFilter(BaseSimpleFilter): + def get_object(self, key): + return Profile.objects.get(user__username=key) + + +class ProblemSimpleFilter(BaseSimpleFilter): + def get_object(self, key): + return Problem.objects.get(code=key) + + class APILoginRequiredException(Exception): pass @@ -108,10 +132,13 @@ def filter_queryset(self, queryset): for key, filter_name in self.basic_filters: if key in self.request.GET: - # May raise ValueError or ValidationError, but is caught in APIMixin - queryset = queryset.filter(**{ - filter_name: self.request.GET.get(key), - }) + if isinstance(filter_name, BaseSimpleFilter): + queryset = queryset.filter(**filter_name.to_filter(self.request.GET.get(key))) + else: + # May raise ValueError or ValidationError, but is caught in APIMixin + queryset = queryset.filter(**{ + filter_name: self.request.GET.get(key), + }) self.used_basic_filters.add(key) for key, filter_name in self.list_filters: @@ -516,8 +543,8 @@ def get_object_data(self, profile): class APISubmissionList(APIListView): model = Submission basic_filters = ( - ('user', 'user__user__username'), - ('problem', 'problem__code'), + ('user', ProfileSimpleFilter('user')), + ('problem', ProblemSimpleFilter('problem')), ) list_filters = ( ('language', 'language__key'),