diff --git a/docs/testing/automated.rst b/docs/testing/automated.rst index 9da4f532..db62f092 100644 --- a/docs/testing/automated.rst +++ b/docs/testing/automated.rst @@ -41,6 +41,17 @@ Or:: # samples behave normally. assert waffle.sample_is_active('sample_name') +``override_flag`` also allows providing values for its different checks. +For example:: + + @override_flag('flag_name', percent=25) + def test_with_flag(): + ... + + @override_flag('flag_name', staff=True, superusers=True) + def test_with_flag(): + ... + All three will restore the relevant flag, sample, or switch to its previous state: they will restore the old values and will delete objects that did not exist. diff --git a/waffle/tests/test_testutils.py b/waffle/tests/test_testutils.py index 3872a041..98f5560b 100644 --- a/waffle/tests/test_testutils.py +++ b/waffle/tests/test_testutils.py @@ -1,11 +1,15 @@ +import contextlib +import random from decimal import Decimal +from unittest import mock +from django.contrib.auth import get_user_model from django.contrib.auth.models import AnonymousUser from django.db import transaction -from django.test import TransactionTestCase, RequestFactory, TestCase +from django.test import RequestFactory, TestCase, TransactionTestCase import waffle -from waffle.testutils import override_switch, override_flag, override_sample +from waffle.testutils import override_flag, override_sample, override_switch class OverrideSwitchMixin: @@ -116,6 +120,14 @@ def req(): return r +@contextlib.contextmanager +def provide_user(**kwargs): + user = get_user_model()(**kwargs) + user.save() + yield user + user.delete() + + class OverrideFlagTestsMixin: def test_flag_existed_and_was_active(self): waffle.get_waffle_flag_model().objects.create(name='foo', everyone=True) @@ -173,6 +185,240 @@ def test_cache_is_flushed_by_testutils_even_in_transaction(self): assert waffle.flag_is_active(req(), 'foo') + @mock.patch.object(random, 'uniform') + def test_flag_existed_and_was_active_for_percent(self, uniform): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, percent='50') + + uniform.return_value = '75' + + with override_flag('foo', percent=80.0): + assert waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', percent=40.0): + assert not waffle.flag_is_active(req(), 'foo') + + assert waffle.get_waffle_flag_model().objects.get(name='foo').percent == Decimal('50') + + @mock.patch.object(random, 'uniform') + def test_flag_existed_and_was_inactive_for_percent(self, uniform): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, percent=None) + + uniform.return_value = '75' + + with override_flag('foo', percent=80.0): + assert waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', percent=40.0): + assert not waffle.flag_is_active(req(), 'foo') + + assert not waffle.get_waffle_flag_model().objects.get(name='foo').percent + + def test_flag_existed_and_was_active_for_testing(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, testing=True) + + with override_flag('foo', testing=True): + request = req() + request.COOKIES['dwft_foo'] = 'True' + assert waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', testing=False): + request = req() + request.COOKIES['dwft_foo'] = 'True' + assert not waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + assert waffle.get_waffle_flag_model().objects.get(name='foo').testing + + def test_flag_existed_and_was_inactive_for_testing(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, testing=False) + + with override_flag('foo', testing=True): + request = req() + request.COOKIES['dwft_foo'] = 'True' + assert waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', testing=False): + request = req() + request.COOKIES['dwft_foo'] = 'True' + assert not waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + assert not waffle.get_waffle_flag_model().objects.get(name='foo').testing + + def test_flag_existed_and_was_active_for_superusers(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, superusers=True) + + with override_flag('foo', superusers=True): + with provide_user(username='foo', is_superuser=True) as user: + request = req() + request.user = user + assert waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_superuser=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + with override_flag('foo', superusers=False): + with provide_user(username='foo', is_superuser=True) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_superuser=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + assert waffle.get_waffle_flag_model().objects.get(name='foo').superusers + + def test_flag_existed_and_was_inactive_for_superusers(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, superusers=False) + + with override_flag('foo', superusers=True): + with provide_user(username='foo', is_superuser=True) as user: + request = req() + request.user = user + assert waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_superuser=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + with override_flag('foo', superusers=False): + with provide_user(username='foo', is_superuser=True) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_superuser=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + assert not waffle.get_waffle_flag_model().objects.get(name='foo').superusers + + def test_flag_existed_and_was_active_for_staff(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, staff=True) + + with override_flag('foo', staff=True): + with provide_user(username='foo', is_staff=True) as user: + request = req() + request.user = user + assert waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_staff=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + with override_flag('foo', staff=False): + with provide_user(username='foo', is_staff=True) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_staff=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + assert waffle.get_waffle_flag_model().objects.get(name='foo').staff + + def test_flag_existed_and_was_inactive_for_staff(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, staff=False) + + with override_flag('foo', staff=True): + with provide_user(username='foo', is_staff=True) as user: + request = req() + request.user = user + assert waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_staff=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + with override_flag('foo', staff=False): + with provide_user(username='foo', is_staff=True) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + with provide_user(username='foo', is_staff=False) as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + + assert not waffle.get_waffle_flag_model().objects.get(name='foo').staff + + def test_flag_existed_and_was_active_for_authenticated(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, authenticated=True) + + with override_flag('foo', authenticated=True): + with provide_user(username='foo') as user: + request = req() + request.user = user + assert waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', authenticated=False): + with provide_user(username='foo') as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + assert waffle.get_waffle_flag_model().objects.get(name='foo').authenticated + + def test_flag_existed_and_was_inactive_for_authenticated(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, authenticated=False) + + with override_flag('foo', authenticated=True): + with provide_user(username='foo') as user: + request = req() + request.user = user + assert waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', authenticated=False): + with provide_user(username='foo') as user: + request = req() + request.user = user + assert not waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + assert not waffle.get_waffle_flag_model().objects.get(name='foo').authenticated + + def test_flag_existed_and_was_active_for_languages(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, languages="en,es") + + with override_flag('foo', languages="en,es"): + request = req() + request.LANGUAGE_CODE = "en" + assert waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', languages=""): + request = req() + request.LANGUAGE_CODE = "en" + assert not waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + assert waffle.get_waffle_flag_model().objects.get(name='foo').languages == "en,es" + + def test_flag_existed_and_was_inactive_for_languages(self): + waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, languages="") + + with override_flag('foo', languages="en,es"): + request = req() + request.LANGUAGE_CODE = "en" + assert waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + with override_flag('foo', languages=""): + request = req() + request.LANGUAGE_CODE = "en" + assert not waffle.flag_is_active(request, 'foo') + assert not waffle.flag_is_active(req(), 'foo') + + assert waffle.get_waffle_flag_model().objects.get(name='foo').languages == "" + class OverrideFlagsTestCase(OverrideFlagTestsMixin, TestCase): """ diff --git a/waffle/testutils.py b/waffle/testutils.py index c54b81cc..f258c329 100644 --- a/waffle/testutils.py +++ b/waffle/testutils.py @@ -1,3 +1,4 @@ +import sys from typing import Generic, Optional, TypeVar, Union from django.test.utils import TestContextDecorator @@ -10,11 +11,27 @@ from waffle.models import Switch, Sample +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + TypedDict = dict + + __all__ = ['override_flag', 'override_sample', 'override_switch'] _T = TypeVar("_T") +class _FlagValues(TypedDict): + active: Optional[bool] + percent: Optional[float] + testing: Optional[bool] + superusers: Optional[bool] + staff: Optional[bool] + authenticated: Optional[bool] + languages: Optional[str] + + class _overrider(TestContextDecorator, Generic[_T]): def __init__(self, name: str, active: _T): super().__init__() @@ -44,6 +61,69 @@ def disable(self) -> None: self.update(self.old_value) +class _flag_overrider(TestContextDecorator): + def __init__( + self, + name: str, + active: Optional[bool] = None, + percent: Optional[float] = None, + testing: Optional[bool] = None, + superusers: Optional[bool] = None, + staff: Optional[bool] = None, + authenticated: Optional[bool] = None, + languages: Optional[str] = None, + ): + super().__init__() + self.name = name + self.active = active + self.percent = percent + self.testing = testing + self.superusers = superusers + self.staff = staff + self.authenticated = authenticated + self.languages = languages + + def get(self) -> None: + self.obj, self.created = self.cls.objects.get_or_create(name=self.name) + + def update( + self, + active: Optional[bool] = None, + percent: Optional[float] = None, + testing: Optional[bool] = None, + superusers: Optional[bool] = None, + staff: Optional[bool] = None, + authenticated: Optional[bool] = None, + languages: Optional[str] = None, + ) -> None: + raise NotImplementedError + + def get_values(self) -> _FlagValues: + raise NotImplementedError + + def enable(self) -> None: + self.get() + self.old_values = self.get_values() + current_values = _FlagValues( + active=self.active, + percent=self.percent, + testing=self.testing, + superusers=self.superusers, + staff=self.staff, + authenticated=self.authenticated, + languages=self.languages, + ) + if self.old_values != current_values: + self.update(**current_values) + + def disable(self) -> None: + if self.created: + self.obj.delete() + self.obj.flush() + else: + self.update(**self.old_values) + + class override_switch(_overrider[bool]): """ override_switch is a contextmanager for easier testing of switches. @@ -54,7 +134,7 @@ class override_switch(_overrider[bool]): with override_switch('happy_mode', active=True): ... - If `Switch` already existed, it's value would be changed inside the context + If `Switch` already existed, its value would be changed inside the context block, then restored to the original value. If `Switch` did not exist before entering the context, it is created, then removed at the end of the block. @@ -78,17 +158,71 @@ def get_value(self) -> bool: return self.obj.active -class override_flag(_overrider[Optional[bool]]): +class override_flag(_flag_overrider): + """ + override_flag is a contextmanager for easier testing of flags. + + It accepts two parameters, name of the switch and it's state. Example + usage:: + + with override_flag('happy_mode', active=True): + ... + + with override_flag('happy_mode', staff=True): + ... + + If `Flag` already existed, its values would be changed inside the context + block, then restored to its original values. If `Flag` did not exist + before entering the context, it is created, then removed at the end of the + block. + + It can also act as a decorator:: + + @override_flag('happy_mode', active=True) + def test_happy_mode_enabled(): + ... + + """ cls = get_waffle_flag_model() - def update(self, active: Optional[bool]) -> None: + def update( + self, + active: Optional[bool] = None, + percent: Optional[float] = None, + testing: Optional[bool] = None, + superusers: Optional[bool] = None, + staff: Optional[bool] = None, + authenticated: Optional[bool] = None, + languages: Optional[str] = None, + ) -> None: obj = self.cls.objects.get(pk=self.obj.pk) obj.everyone = active + obj.percent = percent + + if testing is not None: + obj.testing = testing + if superusers is not None: + obj.superusers = superusers + if staff is not None: + obj.staff = staff + if authenticated is not None: + obj.authenticated = authenticated + if languages is not None: + obj.languages = languages + obj.save() obj.flush() - def get_value(self) -> Optional[bool]: - return self.obj.everyone + def get_values(self) -> _FlagValues: + return _FlagValues( + active=self.obj.everyone, + percent=self.obj.percent, + testing=self.obj.testing, + superusers=self.obj.superusers, + staff=self.obj.staff, + authenticated=self.obj.authenticated, + languages=self.obj.languages, + ) class override_sample(_overrider[Union[bool, float]]):