diff --git a/casbin_sqlalchemy_adapter/adapter.py b/casbin_sqlalchemy_adapter/adapter.py index d85681c..1666be9 100644 --- a/casbin_sqlalchemy_adapter/adapter.py +++ b/casbin_sqlalchemy_adapter/adapter.py @@ -46,7 +46,7 @@ class Filter: class Adapter(persist.Adapter, persist.adapters.UpdateAdapter): """the interface for Casbin adapters.""" - def __init__(self, engine, db_class=None, filtered=False): + def __init__(self, engine, db_class= None, filtered=False): if isinstance(engine, str): self._engine = create_engine(engine) else: @@ -54,6 +54,12 @@ def __init__(self, engine, db_class=None, filtered=False): if db_class is None: db_class = CasbinRule + else: + for attr in ('ptype', 'v0', 'v1', 'v2', 'v3', 'v4', 'v5'): + if not hasattr(db_class, attr): + raise Exception(f'{attr} not found in custom DatabaseClass.') + Base.metadata = db_class.metadata + self._db_class = db_class self.session_local = sessionmaker(bind=self._engine) @@ -185,7 +191,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): return True if r > 0 else False def update_policy( - self, sec: str, ptype: str, old_rule: [str], new_rule: [str] + self, sec: str, ptype: str, old_rule: [str], new_rule: [str] ) -> None: """ Update the old_rule with the new_rule in the database (storage). @@ -218,15 +224,15 @@ def update_policy( exec(f"old_rule_line.v{index} = None") def update_policies( - self, - sec: str, - ptype: str, - old_rules: [ - [str], - ], - new_rules: [ - [str], - ], + self, + sec: str, + ptype: str, + old_rules: [ + [str], + ], + new_rules: [ + [str], + ], ) -> None: """ Update the old_rules with the new_rules in the database (storage). diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 7c34280..f3819f8 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,12 +1,14 @@ +import os +from unittest import TestCase + +import casbin +from sqlalchemy import create_engine, Column, Integer, String +from sqlalchemy.orm import sessionmaker + from casbin_sqlalchemy_adapter import Adapter from casbin_sqlalchemy_adapter import Base from casbin_sqlalchemy_adapter import CasbinRule from casbin_sqlalchemy_adapter.adapter import Filter -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from unittest import TestCase -import casbin -import os def get_fixture(path): @@ -35,6 +37,30 @@ def get_enforcer(): class TestConfig(TestCase): + def test_custom_db_class(self): + class CustomRule(Base): + __tablename__ = "casbin_rule2" + + id = Column(Integer, primary_key=True) + ptype = Column(String(255)) + v0 = Column(String(255)) + v1 = Column(String(255)) + v2 = Column(String(255)) + v3 = Column(String(255)) + v4 = Column(String(255)) + v5 = Column(String(255)) + not_exist = Column(String(255)) + + engine = create_engine("sqlite://") + adapter = Adapter(engine, CustomRule) + + session = sessionmaker(bind=engine) + Base.metadata.create_all(engine) + s = session() + s.add(CustomRule(not_exist="NotNone")) + s.commit() + self.assertEqual(s.query(CustomRule).all()[0].not_exist, "NotNone") + def test_enforcer_basic(self): e = get_enforcer()