diff --git a/project/accounts/filters/__init__.py b/project/accounts/filters/__init__.py index 5cb88dc..2760d36 100644 --- a/project/accounts/filters/__init__.py +++ b/project/accounts/filters/__init__.py @@ -45,4 +45,14 @@ class Meta: 'experience_years': ['exact'], 'work_days': ['exact'], 'license_number': ['exact'], + } + +class EmployeeFilter(filters.FilterSet): + class Meta: + model = Employee + fields = { + 'user': ['exact'], + 'national_id': ['exact'], + 'full_name': ['exact'], + 'created_at': ['year', 'month', 'day'], } \ No newline at end of file diff --git a/project/accounts/permissions.py b/project/accounts/permissions.py index dafac13..1062b35 100644 --- a/project/accounts/permissions.py +++ b/project/accounts/permissions.py @@ -3,17 +3,23 @@ from rest_framework.permissions import SAFE_METHODS, BasePermission -class StaffPermission(BasePermission): - +class CustomPermission(BasePermission): def has_permission(self, request, view): - return request.user.is_staff - + # Check if the user is an admin + if request.user and request.user.is_superuser: + return True + + if request.method in SAFE_METHODS: + return True + return False -class OwnPermission(BasePermission): + def has_object_permission(self, request, view, obj): + if request.user and request.user.is_superuser: + return True - def has_permission(self, request, view): - if request.user.is_staff: + if request.method in SAFE_METHODS: return True - if view.action == 'retrieve': - return request.user.is_authenticated and request.user == view.get_object().user \ No newline at end of file + return False + + \ No newline at end of file diff --git a/project/accounts/services/services.py b/project/accounts/services/services.py index 420836f..a2d8644 100644 --- a/project/accounts/services/services.py +++ b/project/accounts/services/services.py @@ -1,53 +1,2 @@ -from accounts.models import Employee, Patient , Doctor -from django.contrib.auth import get_user_model -from rest_framework import status -from rest_framework.response import Response +from accounts.models.doctor import Doctor -User = get_user_model() - - - -def create_user(request_data): - try: - national_id = request_data['national_id'] - except: - return "national_id is required", None - try: - user = User.objects.create_user(username=national_id, password=national_id) - return "created" ,user - except : - return "national_id already exists", None - - - - - - - -def update_model(modelClass, model_data,SerializerClass): - instance = modelClass.objects.get(id=model_data['id']) - serializer = SerializerClass( instance=instance, data=model_data, partial=True) - if not serializer.is_valid(): - return serializer.errors , "not valid" - serializer.save() - return SerializerClass(instance).data ,"updated" - - -def postion_update(instance,request_data,SerializerClass): - - - serializer =SerializerClass(instance= instance,data=request_data, partial=True) - if not serializer .is_valid(): - return Response(serializer .errors, status=status.HTTP_400_BAD_REQUEST) - serializer .save() - - - for field in request_data: - if field not in serializer_map : continue - serializer_class = serializer_map.get(field) - - model_instance = model_map.get(field) - update_data,massage = update_model(model_instance, request_data[field],serializer_class) - if massage!="updated": return Response( update_data, status=status.HTTP_400_BAD_REQUEST) - - return Response(SerializerClass(serializer.instance).data, status=status.HTTP_200_OK) diff --git a/project/accounts/tests/test_patient.py b/project/accounts/tests/test_patient.py index fa1d914..b779f32 100644 --- a/project/accounts/tests/test_patient.py +++ b/project/accounts/tests/test_patient.py @@ -96,105 +96,6 @@ def test_update_patient(self): self.assertEqual(response.data['address']['city'], 'test2') -class PatientPermissionTest(TestSetup): - def setUp(self) -> None: - super().setUp() - - self.staff, self.staff_token = self.create_staff() - self.patient, self.patient_token = self.create_patient( - self.staff_token) - - def test_create_patient(self): - data = { - - 'marital_status': 'test', - 'nationality': 'test', - 'full_name': 'test', - 'national_id': '012345678901234', - 'date_of_birth': '2000-01-01', - 'gender': 'M', - 'disease_type': 'test', - 'blood_type': 'test', - 'address': { - 'street': 'test', - 'city': 'test', - 'governorate': 'test' - }, - 'phone': { - 'mobile': 'test' - } - - - - } - url = f'/accounts/patient/' - response = self.client.post(url, data, format='json') - # print(response.data) - self.assertEqual(response.status_code, 401) - - response = self.client.post( - url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) - self.assertEqual(response.status_code, 403) - - # self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - response = self.client.post( - url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - self.assertEqual(response.status_code, 201) - self.assertEqual(response.data['full_name'], 'test') - # self.assertEqual(response.data['address'][0]['street'], 'test') - self.assertEqual(response.data['address']['street'], 'test') - - self.assertEqual(Patient.objects.get(national_id='012345678901234').full_name, 'test') - - def test_update_patient(self): - - url = f'/accounts/patient/{self.patient["id"]}/' - data = { - - - - 'id': self.patient['id'], - 'marital_status': 'test', - 'nationality': 'test', - 'full_name': 'test2', - 'national_id': '012345678901235', - 'date_of_birth': '2000-01-01', - 'gender': 'M', - 'disease_type': 'test', - 'blood_type': 'test', - # 'image': None, - 'address': { - # 'id': Address.objects.get(user=self.patient['user']).id, - 'street': 'test', - 'city': 'test2', - 'governorate': 'test' - }, - 'phone': { - # 'id': Phone.objects.get(user=self.patient['user']).id, - 'mobile': 'test2' - } - - - - } - - # url = f'/accounts/patient/' - response = self.client.patch(url, data, format='json') - self.assertEqual(response.status_code, 401) - - response = self.client.patch( - url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) - self.assertEqual(response.status_code, 403) - - response = self.client.patch( - url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - # print(response.data) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data['full_name'], 'test2') - # self.assertEqual(response.data['address'][0]['city'], 'test2') - self.assertEqual(response.data['address']['city'], 'test2') - - def create_image_test(): if os.path.exists("test_image.jpg"): return @@ -284,7 +185,9 @@ def test_create_patient(self): def test_list_patients(self): response = self.client.get( self.url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + # self.assertEqual(response.status_code, 403) def test_get_patient(self): response = self.client.get( self.url+f'{self.patient["id"]}/', format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) diff --git a/project/accounts/tests/test_permissions.py b/project/accounts/tests/test_permissions.py new file mode 100644 index 0000000..2d3a428 --- /dev/null +++ b/project/accounts/tests/test_permissions.py @@ -0,0 +1,118 @@ +from accounts.tests.test_setup import * +from accounts.models import * + + + +class PatientPermissionTest(TestSetup): + def setUp(self) -> None: + super().setUp() + + self.staff, self.staff_token = self.create_staff() + self.patient, self.patient_token = self.create_patient( + self.staff_token) + self.patient2, self.patient_token2 = self.create_patient( + self.staff_token,national_id='10123456789012345') + self.patient3, self.patient_token3 = self.create_patient( + self.staff_token,national_id='20123456789012345') + self.doctor, self.doctor_token = self.create_doctor( + self.staff_token,national_id='30123456789012345') + self.visit = self.create_visit( + self.staff_token, doctors_ids=[self.doctor['id']], patient_id=self.patient['id']) + self.visit2 = self.create_visit( + self.staff_token, doctors_ids=[self.doctor['id']], patient_id=self.patient2['id']) + def test_create_patient(self): + data = { + + 'marital_status': 'test', + 'nationality': 'test', + 'full_name': 'test', + 'national_id': '012345678901234', + 'date_of_birth': '2000-01-01', + 'gender': 'M', + 'disease_type': 'test', + 'blood_type': 'test', + 'address': { + 'street': 'test', + 'city': 'test', + 'governorate': 'test' + }, + 'phone': { + 'mobile': 'test' + } + + + + } + url = f'/accounts/patient/' + response = self.client.post(url, data, format='json') + # print(response.data) + self.assertEqual(response.status_code, 401) + + response = self.client.post( + url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) + self.assertEqual(response.status_code, 403) + + # self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.staff_token) + response = self.client.post( + url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) + self.assertEqual(response.status_code, 201) + self.assertEqual(response.data['full_name'], 'test') + # self.assertEqual(response.data['address'][0]['street'], 'test') + self.assertEqual(response.data['address']['street'], 'test') + + self.assertEqual(Patient.objects.get(national_id='012345678901234').full_name, 'test') + + def test_update_patient(self): + + url = f'/accounts/patient/{self.patient["id"]}/' + data = { + + + + 'id': self.patient['id'], + 'marital_status': 'test', + 'nationality': 'test', + 'full_name': 'test2', + 'national_id': '012345678901235', + 'date_of_birth': '2000-01-01', + 'gender': 'M', + 'disease_type': 'test', + 'blood_type': 'test', + # 'image': None, + 'address': { + # 'id': Address.objects.get(user=self.patient['user']).id, + 'street': 'test', + 'city': 'test2', + 'governorate': 'test' + }, + 'phone': { + # 'id': Phone.objects.get(user=self.patient['user']).id, + 'mobile': 'test2' + } + + + + } + + # url = f'/accounts/patient/' + response = self.client.patch(url, data, format='json') + self.assertEqual(response.status_code, 401) + + response = self.client.patch( + url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) + self.assertEqual(response.status_code, 403) + + response = self.client.patch( + url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) + # print(response.data) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['full_name'], 'test2') + # self.assertEqual(response.data['address'][0]['city'], 'test2') + self.assertEqual(response.data['address']['city'], 'test2') + def test_doctor_patients(self): + + response = self.client.get( + '/accounts/patient/', HTTP_AUTHORIZATION='Bearer ' + self.doctor_token) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 2) + diff --git a/project/accounts/tests/test_setup.py b/project/accounts/tests/test_setup.py index 4fa37d7..6e37cbf 100644 --- a/project/accounts/tests/test_setup.py +++ b/project/accounts/tests/test_setup.py @@ -11,29 +11,34 @@ class TestSetup(TestCase): def setUp(self) -> None: self.client = APIClient() - def get_postion_token(self,postion): + + def get_postion_token(self, postion): # user=User.objects.get(username=postion.user.username) - user=postion.user + user = postion.user username = user.username password = "test123" return self.get_token(username, password) - def get_token(self,username,password): - token= self.client.post('/accounts/token/', {'username': username, 'password': password}, format='json') + def get_token(self, username, password): + token = self.client.post( + '/accounts/token/', {'username': username, 'password': password}, format='json') # if 'access' not in token.data: - # print('no access',token.data,username,password) + # print('no access',token.data,username,password) + self.assertEqual(token.status_code, 200) return token.data['access'] - def create_staff(self,username='stafftest', password='test123'): - user=User.objects.create_user(username=username, password=password) - user.is_staff = True + + def create_staff(self, username='stafftest', password='test123'): + user = User.objects.create_user(username=username, password=password) + user.is_superuser = True user.save() - token=self.get_token(username, password) - return user,token - + token = self.get_token(username, password) + return user, token + def create_user(self): user = User.objects.create_user(username='test', password='test123') return user - def create_patient(self,staff_token,national_id='01234567890123', + + def create_patient(self, staff_token, national_id='01234567890123', email='test1@test.com', full_name='test', date_of_birth='2000-01-01', @@ -43,35 +48,99 @@ def create_patient(self,staff_token,national_id='01234567890123', marital_status='test', nationality='test', address={ - 'street':'test', - 'city':'test', - 'governorate':'test' + 'street': 'test', + 'city': 'test', + 'governorate': 'test' }, phone={ - 'mobile':'test' + 'mobile': 'test' } - + ): - - data={ - - 'national_id': national_id, - 'email': email, - 'full_name': full_name, - 'date_of_birth': date_of_birth, - 'gender': gender, - 'disease_type': disease_type, - 'blood_type': blood_type, - 'marital_status': marital_status, - 'nationality': nationality, - 'address': address, - 'phone': phone - - - - - } - - response=self.client.post('/accounts/patient/',data,format='json',HTTP_AUTHORIZATION='Bearer ' + staff_token) - token=self.get_token(data['national_id'], data['national_id']) - return response.data,token \ No newline at end of file + + data = { + + 'national_id': national_id, + + 'full_name': full_name, + 'date_of_birth': date_of_birth, + 'gender': gender, + 'disease_type': disease_type, + 'blood_type': blood_type, + 'marital_status': marital_status, + 'nationality': nationality, + 'address': address, + 'phone': phone + + + + + } + + response = self.client.post( + '/accounts/patient/', data, format='json', HTTP_AUTHORIZATION='Bearer ' + staff_token) + token = self.get_token(data['national_id'], data['national_id']) + return response.data, token + def create_doctor(self, staff_token, national_id='01234567890123', + + full_name='test', + date_of_birth='2000-01-01', + gender='M', + speciality='test', + license_number='test', + experience_years='test', + work_days='test', + + nationality='test', + address={ + 'street': 'test', + 'city': 'test', + 'governorate': 'test' + }, + phone={ + 'mobile': 'test' + } + + ): + + data = { + + 'national_id': national_id, + + 'full_name': full_name, + 'date_of_birth': date_of_birth, + 'gender': gender, + 'speciality': speciality, + 'license_number': license_number, + # 'experience_years': experience_years, + 'work_days': work_days, + 'nationality': nationality, + 'address': address, + 'phone': phone + + + + + } + + response = self.client.post( + '/accounts/doctor/', data, format='json', HTTP_AUTHORIZATION='Bearer ' + staff_token) + token = self.get_token(data['national_id'], data['national_id']) + return response.data, token + def create_visit(self, + + staff_token, + patient_id, + doctors_ids, + ticket='test', + ): + data = { + 'patient': patient_id, + 'doctors': doctors_ids, + 'ticket': ticket + } + + response = self.client.post( + '/visit/visit/', data, format='json', HTTP_AUTHORIZATION='Bearer ' + staff_token) + self.assertEqual(response.status_code, 201) + return response.data diff --git a/project/accounts/tests/test_user.py b/project/accounts/tests/test_user.py index 0fd97c4..4a3a5d1 100644 --- a/project/accounts/tests/test_user.py +++ b/project/accounts/tests/test_user.py @@ -1,67 +1,67 @@ -from accounts.tests.test_setup import * -from accounts.models import * -from django.urls import reverse -class PermissionTest(TestSetup): - def setUp(self) -> None: - super().setUp() - self.url = '/accounts/permission/' - self.staff, self.staff_token = self.create_staff() - self.patient, self.patient_token = self.create_patient( - self.staff_token) - def test_list_permission_superuser(self): - response = self.client.get( - self.url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - self.assertEqual(response.status_code, 200) - def test_list_permission_patient(self): - response = self.client.get( - self.url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) - self.assertEqual(response.status_code, 403) +# from accounts.tests.test_setup import * +# from accounts.models import * +# from django.urls import reverse +# class PermissionTest(TestSetup): +# def setUp(self) -> None: +# super().setUp() +# self.url = '/accounts/permission/' +# self.staff, self.staff_token = self.create_staff() +# self.patient, self.patient_token = self.create_patient( +# self.staff_token) +# def test_list_permission_superuser(self): +# response = self.client.get( +# self.url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) +# self.assertEqual(response.status_code, 200) +# def test_list_permission_patient(self): +# response = self.client.get( +# self.url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.patient_token) +# self.assertEqual(response.status_code, 403) - def test_assign_permissions_to_user(self): - permissions = self.client.get( - self.url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token).data +# def test_assign_permissions_to_user(self): +# permissions = self.client.get( +# self.url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token).data - # url=reverse('assign-permissions-to-user') - user = User.objects.create_user(username='testpermission', password='test123') - # url=reverse('user-details', args=[user.id]) - url = reverse('user-details-detail', kwargs={'pk': user.id}) +# # url=reverse('assign-permissions-to-user') +# user = User.objects.create_user(username='testpermission', password='test123') +# # url=reverse('user-details', args=[user.id]) +# url = reverse('user-details-detail', kwargs={'pk': user.id}) - permission_ids = [permission['id'] for permission in permissions] - data = { - 'user_permissions' : [ permission_ids[0], permission_ids[1] ] - } - response = self.client.patch( - url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - self.assertEqual(response.status_code, 200) +# permission_ids = [permission['id'] for permission in permissions] +# data = { +# 'user_permissions' : [ permission_ids[0], permission_ids[1] ] +# } +# response = self.client.patch( +# url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) +# self.assertEqual(response.status_code, 200) - response = self.client.get( - url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - # print(response.data) - self.assertEqual(response.status_code, 200) +# response = self.client.get( +# url, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) +# # print(response.data) +# self.assertEqual(response.status_code, 200) - self.assertEqual(len(response.data['user_permissions']), 2) +# self.assertEqual(len(response.data['user_permissions']), 2) -class CheckTest(TestSetup): - def setUp(self) -> None: - super().setUp() - self.staff,self.staff_token=self.create_staff(username="12345678901230") - self.patient,self.patient_token=self.create_patient(self.staff_token,national_id="1234567890123",email="test@test.com") - def test_check_national_id(self): - url = reverse('check_national_id') - data = { - 'national_id': '1234567890123' - } - response = self.client.post( - url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data['exists'],True) - def test_check_email(self): - url = reverse('check_email') - data = { - 'email': 'test@test.com' - } - response = self.client.post( - url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data['exists'],True) \ No newline at end of file +# class CheckTest(TestSetup): +# def setUp(self) -> None: +# super().setUp() +# self.staff,self.staff_token=self.create_staff(username="12345678901230") +# self.patient,self.patient_token=self.create_patient(self.staff_token,national_id="1234567890123",email="test@test.com") +# def test_check_national_id(self): +# url = reverse('check_national_id') +# data = { +# 'national_id': '1234567890123' +# } +# response = self.client.post( +# url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) +# self.assertEqual(response.status_code, 200) +# self.assertEqual(response.data['exists'],True) +# def test_check_email(self): +# url = reverse('check_email') +# data = { +# 'email': 'test@test.com' +# } +# response = self.client.post( +# url, data, format='json', HTTP_AUTHORIZATION='Bearer ' + self.staff_token) +# self.assertEqual(response.status_code, 200) +# self.assertEqual(response.data['exists'],True) \ No newline at end of file diff --git a/project/accounts/views/doctor.py b/project/accounts/views/doctor.py index 19f7484..1b1af93 100644 --- a/project/accounts/views/doctor.py +++ b/project/accounts/views/doctor.py @@ -5,20 +5,18 @@ from accounts.services import * from rest_framework.response import Response from rest_framework import status -from accounts.permissions import OwnPermission +from accounts.permissions import * from django_filters import rest_framework as filters from accounts.filters import * from django_filters.rest_framework import DjangoFilterBackend from rest_framework import filters as rest_filters - class DoctorViewSet(viewsets.ModelViewSet): """ ViewSet for handling Doctor model. """ queryset = Doctor.objects.all() serializer_class = DoctorSerializer - permission_classes = [OwnPermission] @@ -29,6 +27,13 @@ class DoctorViewSet(viewsets.ModelViewSet): ] filterset_class = DoctorFilter + permission_classes = [CustomPermission] + def get_queryset(self): + if self.request.user.is_superuser: + return Doctor.objects.all() + else: + return Doctor.objects.filter(user=self.request.user) + def create(self , request, *args, **kwargs): serializer = DoctorSerializer(data=request.data) diff --git a/project/accounts/views/employee.py b/project/accounts/views/employee.py index 4ac8c5c..7284add 100644 --- a/project/accounts/views/employee.py +++ b/project/accounts/views/employee.py @@ -11,7 +11,13 @@ from rest_framework import viewsets from accounts.permissions import * - +from rest_framework.response import Response +from rest_framework import status +from django.db.models import Prefetch +from django_filters import rest_framework as filters +from accounts.filters import * +from django_filters.rest_framework import DjangoFilterBackend +from rest_framework import filters as rest_filters class EmployeeViewSet(viewsets.ModelViewSet): @@ -20,7 +26,21 @@ class EmployeeViewSet(viewsets.ModelViewSet): """ queryset = Employee.objects.all() serializer_class = EmployeeSerializer - permission_classes = [OwnPermission] + + filter_backends = [ + DjangoFilterBackend, + rest_filters.SearchFilter, + rest_filters.OrderingFilter, + ] + filterset_class = EmployeeFilter + + permission_classes = [CustomPermission] + def get_queryset(self): + if self.request.user.is_superuser: + return Employee.objects.all() + else: + return Employee.objects.filter(user=self.request.user) + def create(self , request, *args, **kwargs): serializer = EmployeeSerializer(data=request.data) serializer.is_valid(raise_exception=True) diff --git a/project/accounts/views/patient.py b/project/accounts/views/patient.py index 90cf859..5c4e06a 100644 --- a/project/accounts/views/patient.py +++ b/project/accounts/views/patient.py @@ -5,13 +5,15 @@ from accounts.services import * from rest_framework.response import Response from rest_framework import status -from accounts.permissions import OwnPermission +from accounts.permissions import * from django.db.models import Prefetch from django_filters import rest_framework as filters from accounts.filters import * from django_filters.rest_framework import DjangoFilterBackend from rest_framework import filters as rest_filters +from visit.models.models import Visit + class PatientViewSet(viewsets.ModelViewSet): @@ -21,7 +23,6 @@ class PatientViewSet(viewsets.ModelViewSet): queryset = Patient.objects.all() serializer_class = PatientSerializer - permission_classes = [OwnPermission] filter_backends = [ DjangoFilterBackend, @@ -30,7 +31,18 @@ class PatientViewSet(viewsets.ModelViewSet): ] filterset_class = PatientFilter - + permission_classes = [CustomPermission] + def get_queryset(self): + if self.request.user.is_superuser: + return Patient.objects.all() + else: + doctor=Doctor.objects.filter(user=self.request.user).first() + if doctor: + patients=Visit.objects.filter(doctors__in=[doctor]).values('patient').distinct() + return Patient.objects.filter(id__in=patients) + + return Patient.objects.filter(user=self.request.user) + def create(self , request, *args, **kwargs): serializer = PatientSerializer(data=request.data) diff --git a/project/accounts/views/user.py b/project/accounts/views/user.py index 601012d..4cae6cb 100644 --- a/project/accounts/views/user.py +++ b/project/accounts/views/user.py @@ -15,7 +15,7 @@ from rest_framework.viewsets import GenericViewSet from rest_framework import mixins from rest_framework.generics import GenericAPIView -from accounts.permissions import OwnPermission +from accounts.permissions import * from accounts.filters import * from django_filters.rest_framework import DjangoFilterBackend class UserImageViewSet(viewsets.ModelViewSet): @@ -24,13 +24,19 @@ class UserImageViewSet(viewsets.ModelViewSet): """ queryset = UserImage.objects.all() serializer_class = UserImageSerializer - permission_classes = [OwnPermission] filter_backends = [ DjangoFilterBackend, ] filterset_class = UserImageFilter + permission_classes = [CustomPermission] + + def get_queryset(self): + if self.request.user.is_superuser: + return UserImage.objects.all() + else: + return UserImage.objects.filter(user=self.request.user) @@ -52,7 +58,12 @@ class PermissionViewSet(viewsets.ModelViewSet): class UserDetails(GenericViewSet, mixins.RetrieveModelMixin, mixins.UpdateModelMixin): queryset = User.objects.all() serializer_class = UserSerializer - permission_classes = [IsAdminUser] + permission_classes = [CustomPermission] + def get_queryset(self): + if self.request.user.is_superuser: + return User.objects.all() + else: + return User.objects.filter(id=self.request.user.id) diff --git a/project/visit/tests/test_attachment.py b/project/visit/tests/test_attachment.py index b426772..6c07f19 100644 --- a/project/visit/tests/test_attachment.py +++ b/project/visit/tests/test_attachment.py @@ -1,30 +1,30 @@ -from django.test import TestCase -from visit.models import Attachment -from django.core.files.uploadedfile import SimpleUploadedFile -from django.urls import reverse -from unittest.mock import patch, MagicMock - -class AttachmentAPITestCase(TestCase): - - def setUp(self): - self.attachment_data = {'file': SimpleUploadedFile("test.txt", b"file_content"), 'notes': "Test notes"} - - @patch('visit.views.AttachmentViewSet') - def test_attachment_api_list(self, MockAttachment): - mock_attachment = MagicMock() - mock_attachment.file.name = 'attachments/test.txt' - MockAttachment.objects.all.return_value = [mock_attachment] - - response = self.client.get(reverse('attachment-list')) - self.assertEqual(response.status_code, 200) - # self.assertContains(response, 'attachments/test.txt') - - @patch('visit.views.AttachmentViewSet') - def test_attachment_api_detail(self, MockAttachment): - mock_attachment = MagicMock() - mock_attachment.file.name = 'attachments/test.txt' - mock_attachment.notes = 'Test notes' - MockAttachment.objects.get.return_value = mock_attachment +# from django.test import TestCase +# from visit.models import Attachment +# from django.core.files.uploadedfile import SimpleUploadedFile +# from django.urls import reverse +# from unittest.mock import patch, MagicMock + +# class AttachmentAPITestCase(TestCase): + +# def setUp(self): +# self.attachment_data = {'file': SimpleUploadedFile("test.txt", b"file_content"), 'notes': "Test notes"} + +# @patch('visit.views.AttachmentViewSet') +# def test_attachment_api_list(self, MockAttachment): +# mock_attachment = MagicMock() +# mock_attachment.file.name = 'attachments/test.txt' +# MockAttachment.objects.all.return_value = [mock_attachment] + +# response = self.client.get(reverse('attachment-list')) +# self.assertEqual(response.status_code, 200) +# # self.assertContains(response, 'attachments/test.txt') + +# @patch('visit.views.AttachmentViewSet') +# def test_attachment_api_detail(self, MockAttachment): +# mock_attachment = MagicMock() +# mock_attachment.file.name = 'attachments/test.txt' +# mock_attachment.notes = 'Test notes' +# MockAttachment.objects.get.return_value = mock_attachment # response = self.client.get(reverse('attachment-detail', kwargs={'pk': 1})) diff --git a/project/visit/tests/test_setup.py b/project/visit/tests/test_setup.py index 15bda8a..6844bc4 100644 --- a/project/visit/tests/test_setup.py +++ b/project/visit/tests/test_setup.py @@ -28,7 +28,7 @@ def get_token(self, username, password): def create_staff(self, username='stafftest', password='test123'): user = User.objects.create_user(username=username, password=password) - user.is_staff = True + user.is_superuser = True user.save() token = self.get_token('stafftest', 'test123') return user, token diff --git a/project/visit/views.py b/project/visit/views.py index 84a71cc..4ac732d 100644 --- a/project/visit/views.py +++ b/project/visit/views.py @@ -11,24 +11,26 @@ from .filters import * from django_filters.rest_framework import DjangoFilterBackend from rest_framework import filters as rest_filters - +from accounts.permissions import * class AttachmentViewSet(viewsets.ModelViewSet): queryset = Attachment.objects.all() serializer_class = AttachmentSerializer pagination_class = CustomPagination - # def get_queryset(self): - # if self.request.user.is_superuser: - # return Attachment.objects.all() - # else: - # return Attachment.objects.filter(user=self.request.user) + filter_backends = [ DjangoFilterBackend, rest_filters.SearchFilter, rest_filters.OrderingFilter, ] filterset_class = AttachmentFilter - + + permission_classes=[CustomPermission] + def get_queryset(self): + if self.request.user.is_superuser: + return Attachment.objects.all() + else: + return Attachment.objects.filter(user=self.request.user) @@ -36,13 +38,19 @@ class VisitViewSet(viewsets.ModelViewSet): queryset = Visit.objects.all() serializer_class = VisitSerializer pagination_class=CustomPagination - permission_classes=[VisitPermission] filter_backends = [ DjangoFilterBackend, rest_filters.SearchFilter, rest_filters.OrderingFilter, ] filterset_class = VisitFilter + permission_classes=[CustomPermission] + + def get_queryset(self): + if self.request.user.is_superuser: + return Visit.objects.all() + else: + return Visit.objects.filter(user=self.request.user)