Skip to content

Commit

Permalink
feat: support for soft deletion of casbin rules (#72)
Browse files Browse the repository at this point in the history
* feat: add support for soft deletion of casbin rules

* refactor: improve reusability of test case

* fix: type hints

* test: soft delete

* chore: updated .gitignore

* refactor: pass the sqlalchemy attribute itself instead of the attribute name string

* test: softdelete flag in database

* refactor: save_policy
          - load rules from db before making changes
          - improved comments

* test: save_policy softdelete strategy

* fix: formatted code with black

* fix: do not create test.db by default

* fix: units tests for CI/CD pipeline

* docs: added Soft Delete example

* fix: make sure softdelete filter is applied

* docs: make usage of  explicit

* docs: moved softdelete logic into base class

* docs: improvement

* feat: validate the type of db_class_softdelete_attribute

* fix: default value of is_deleted flag
  • Loading branch information
trbtm authored Jul 8, 2024
1 parent e9ff609 commit 8911c16
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 59 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,7 @@ venv.bak/

# mypy
.mypy_cache/
.idea
.idea

# vscode settings
.vscode
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ else:
pass
```

## Soft Delete example

Soft Delete for casbin rules is supported, only when using a custom casbin rule model.
The Soft Delete mechanism is enabled by passing the attribute of the flag indicating whether
a rule is deleted to `db_class_softdelete_attribute`.
That attribute needs to be of type `sqlalchemy.Boolean`.

```python
adapter = Adapter(
engine,
db_class=MyCustomCasbinRuleModel,
db_class_softdelete_attribute=MyCustomCasbinRuleModel.is_deleted,
)
```

Please be aware that this adapter only sets a flag like `is_deleted` to `True`.
The provided model needs to handle the update of fields like `deleted_by`, `deleted_at`, etc.
An example for this is given in [examples/softdelete.py](examples/softdelete.py).

### Getting Help

Expand Down
135 changes: 110 additions & 25 deletions casbin_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import sqlalchemy
from casbin import persist
from sqlalchemy import Column, Integer, String
from sqlalchemy import create_engine, or_
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy import create_engine, or_, not_
from sqlalchemy.orm import sessionmaker

# declarative base class
Expand Down Expand Up @@ -56,15 +56,33 @@ class Filter:
class Adapter(persist.Adapter, persist.adapters.UpdateAdapter):
"""the interface for Casbin adapters."""

def __init__(self, engine, db_class=None, filtered=False, create_all_models=True):
def __init__(
self,
engine,
db_class=None,
db_class_softdelete_attribute=None,
filtered=False,
create_all_models=True,
):
if isinstance(engine, str):
self._engine = create_engine(engine)
else:
self._engine = engine

self.softdelete_attribute = None

if db_class is None:
db_class = CasbinRule
else:
if db_class_softdelete_attribute is not None and not isinstance(
db_class_softdelete_attribute.type, Boolean
):
msg = f"The type of db_class_softdelete_attribute needs to be {str(Boolean)!r}. "
msg += f"An attribute of type {str(type(db_class_softdelete_attribute.type))!r} was given."
raise ValueError(msg)
# Softdelete is only supported when using custom class
self.softdelete_attribute = db_class_softdelete_attribute

for attr in (
"id",
"ptype",
Expand Down Expand Up @@ -102,7 +120,9 @@ def _session_scope(self):
def load_policy(self, model):
"""loads all policy rules from the storage."""
with self._session_scope() as session:
lines = session.query(self._db_class).all()
query = session.query(self._db_class)
query = self._softdelete_query(query)
lines = query.all()
for line in lines:
persist.load_policy_line(str(line), model)

Expand All @@ -113,6 +133,7 @@ def load_filtered_policy(self, model, filter) -> None:
"""loads all policy rules from the storage."""
with self._session_scope() as session:
query = session.query(self._db_class)
query = self._softdelete_query(query)
filters = self.filter_query(query, filter)
filters = filters.all()

Expand Down Expand Up @@ -140,15 +161,60 @@ def _save_policy_line(self, ptype, rule, session=None):

def save_policy(self, model):
"""saves all policy rules to the storage."""

# Use the default strategy when soft delete is not enabled
if self.softdelete_attribute is None:
with self._session_scope() as session:
query = session.query(self._db_class)
query.delete()
for sec in ["p", "g"]:
if sec not in model.model.keys():
continue
for ptype, ast in model.model[sec].items():
for rule in ast.policy:
self._save_policy_line(ptype, rule, session=session)
return True

# Custom stategy for softdelete since it does not make sense to recreate all of the
# entries when using soft delete
with self._session_scope() as session:
query = session.query(self._db_class)
query.delete()
query = self._softdelete_query(query)

# Delete entries that are not part of the model anymore
lines_before_changes = query.all()

# Create new entries in the database
for sec in ["p", "g"]:
if sec not in model.model.keys():
continue
for ptype, ast in model.model[sec].items():
for rule in ast.policy:
self._save_policy_line(ptype, rule, session=session)
# Filter for rule in the database
filter_query = query.filter(self._db_class.ptype == ptype)
for index, value in enumerate(rule):
v_value = getattr(self._db_class, "v{}".format(index))
filter_query = filter_query.filter(v_value == value)
# If the rule is not present, create an entry in the database
if filter_query.count() == 0:
self._save_policy_line(ptype, rule, session=session)

for line in lines_before_changes:
ptype = line.ptype
sec = ptype[0] # derived from persist.load_policy_line function
fields_with_None = [
line.v0,
line.v1,
line.v2,
line.v3,
line.v4,
line.v5,
]
rule = [element for element in fields_with_None if element is not None]
# If the the rule is not part of the model, set the deletion flag to True
if not model.has_policy(sec, ptype, rule):
setattr(line, self.softdelete_attribute.name, True)

return True

def add_policy(self, sec, ptype, rule):
Expand All @@ -164,10 +230,15 @@ def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the storage."""
with self._session_scope() as session:
query = session.query(self._db_class)
query = self._softdelete_query(query)
query = query.filter(self._db_class.ptype == ptype)
for i, v in enumerate(rule):
query = query.filter(getattr(self._db_class, "v{}".format(i)) == v)
r = query.delete()

if self.softdelete_attribute is None:
r = query.delete()
else:
r = query.update({self.softdelete_attribute: True})

return True if r > 0 else False

Expand All @@ -177,20 +248,27 @@ def remove_policies(self, sec, ptype, rules):
return
with self._session_scope() as session:
query = session.query(self._db_class)
query = self._softdelete_query(query)
query = query.filter(self._db_class.ptype == ptype)
rules = zip(*rules)
for i, rule in enumerate(rules):
query = query.filter(
or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule)
)
query.delete()

if self.softdelete_attribute is None:
query.delete()
else:
query.update({self.softdelete_attribute: True})

def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules that match the filter from the storage.
This is part of the Auto-Save feature.
"""
with self._session_scope() as session:
query = session.query(self._db_class).filter(self._db_class.ptype == ptype)
query = session.query(self._db_class)
query = self._softdelete_query(query)
query = query.filter(self._db_class.ptype == ptype)

if not (0 <= field_index <= 5):
return False
Expand All @@ -200,12 +278,16 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
if v != "":
v_value = getattr(self._db_class, "v{}".format(field_index + i))
query = query.filter(v_value == v)
r = query.delete()

if self.softdelete_attribute is None:
r = query.delete()
else:
r = query.update({self.softdelete_attribute: True})

return True if r > 0 else False

def update_policy(
self, sec: str, ptype: str, old_rule: [str], new_rule: [str]
self, sec: str, ptype: str, old_rule: list[str], new_rule: list[str]
) -> None:
"""
Update the old_rule with the new_rule in the database (storage).
Expand All @@ -219,7 +301,9 @@ def update_policy(
"""

with self._session_scope() as session:
query = session.query(self._db_class).filter(self._db_class.ptype == ptype)
query = session.query(self._db_class)
query = self._softdelete_query(query)
query = query.filter(self._db_class.ptype == ptype)

# locate the old rule
for index, value in enumerate(old_rule):
Expand All @@ -241,12 +325,8 @@ def update_policies(
self,
sec: str,
ptype: str,
old_rules: [
[str],
],
new_rules: [
[str],
],
old_rules: list[list[str]],
new_rules: list[list[str]],
) -> None:
"""
Update the old_rules with the new_rules in the database (storage).
Expand All @@ -262,8 +342,8 @@ def update_policies(
self.update_policy(sec, ptype, old_rules[i], new_rules[i])

def update_filtered_policies(
self, sec, ptype, new_rules: [[str]], field_index, *field_values
) -> [[str]]:
self, sec, ptype, new_rules: list[list[str]], field_index, *field_values
) -> list[list[str]]:
"""update_filtered_policies updates all the policies on the basis of the filter."""

filter = Filter()
Expand All @@ -278,16 +358,15 @@ def update_filtered_policies(

self._update_filtered_policies(new_rules, filter)

def _update_filtered_policies(self, new_rules, filter) -> [[str]]:
def _update_filtered_policies(self, new_rules, filter) -> list[list[str]]:
"""_update_filtered_policies updates all the policies on the basis of the filter."""

with self._session_scope() as session:

# Load old policies

query = session.query(self._db_class).filter(
self._db_class.ptype == filter.ptype
)
query = session.query(self._db_class)
query = self._softdelete_query(query)
query = query.filter(self._db_class.ptype == filter.ptype)
filtered_query = self.filter_query(query, filter)
old_rules = filtered_query.all()

Expand All @@ -302,3 +381,9 @@ def _update_filtered_policies(self, new_rules, filter) -> [[str]]:
# return deleted rules

return old_rules

def _softdelete_query(self, query):
query_softdelete = query
if self.softdelete_attribute is not None:
query_softdelete = query_softdelete.where(not_(self.softdelete_attribute))
return query_softdelete
91 changes: 91 additions & 0 deletions examples/softdelete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from datetime import datetime, UTC

import casbin
from casbin_sqlalchemy_adapter import Base, Adapter
from sqlalchemy import false, Column, DateTime, String, Integer, Boolean
from sqlalchemy.engine.default import DefaultExecutionContext

from some_user_library import get_current_user_id


def _deleted_at_default(context: DefaultExecutionContext) -> datetime | None:
current_parameters = context.get_current_parameters()
if current_parameters.get("is_deleted"):
return datetime.now(UTC)
else:
return None


def _deleted_by_default(context: DefaultExecutionContext) -> int | None:
current_parameters = context.get_current_parameters()
if current_parameters.get("is_deleted"):
return get_current_user_id()
else:
return None


class BaseModel(Base):
__abstract__ = True

created_at = Column(DateTime, default=lambda: datetime.now(UTC), nullable=False)
updated_at = Column(
DateTime,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
nullable=False,
)
deleted_at = Column(
DateTime,
default=_deleted_at_default,
onupdate=_deleted_at_default,
nullable=True,
)

created_by = Column(Integer, default=get_current_user_id, nullable=False)
updated_by = Column(
Integer,
default=get_current_user_id,
onupdate=get_current_user_id,
nullable=False,
)
deleted_by = Column(
Integer,
default=_deleted_by_default,
onupdate=_deleted_by_default,
nullable=True,
)
is_deleted = Column(
Boolean,
default=False,
server_default=false(),
index=True,
nullable=False,
)


class CasbinSoftDeleteRule(BaseModel):
__tablename__ = "casbin_rule"

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
v0 = Column(String(255))
v1 = Column(String(255))
v2 = Column(String(255))
v3 = Column(String(255))
v4 = Column(String(255))
v5 = Column(String(255))


engine = your_engine_factory()
# Initialize the Adapter, pass your custom CasbinRule model
# and pass the Boolean field indicating whether a rule is deleted or not
# your model needs to handle the update of fields
# 'updated_by', 'updated_at', 'deleted_by', etc.
adapter = Adapter(
engine,
CasbinSoftDeleteRule,
CasbinSoftDeleteRule.is_deleted,
)
# Create the Enforcer, etc.
e = casbin.Enforcer("path/to/model.conf", adapter)
...
Loading

0 comments on commit 8911c16

Please # to comment.