Skip to content

Commit

Permalink
fix: fix db_class not vaild
Browse files Browse the repository at this point in the history
Signed-off-by: Zxilly <zhouxinyu1001@gmail.com>
  • Loading branch information
Zxilly committed Sep 7, 2021
1 parent d834662 commit 4effb2d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
28 changes: 17 additions & 11 deletions casbin_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,20 @@ class Filter:
class Adapter(persist.Adapter, persist.adapters.UpdateAdapter):
"""the interface for Casbin adapters."""

def __init__(self, engine, db_class=None, filtered=False):
def __init__(self, engine, db_class= None, filtered=False):
if isinstance(engine, str):
self._engine = create_engine(engine)
else:
self._engine = engine

if db_class is None:
db_class = CasbinRule
else:
for attr in ('ptype', 'v0', 'v1', 'v2', 'v3', 'v4', 'v5'):
if not hasattr(db_class, attr):
raise Exception(f'{attr} not found in custom DatabaseClass.')
Base.metadata = db_class.metadata

self._db_class = db_class
self.session_local = sessionmaker(bind=self._engine)

Expand Down Expand Up @@ -185,7 +191,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
return True if r > 0 else False

def update_policy(
self, sec: str, ptype: str, old_rule: [str], new_rule: [str]
self, sec: str, ptype: str, old_rule: [str], new_rule: [str]
) -> None:
"""
Update the old_rule with the new_rule in the database (storage).
Expand Down Expand Up @@ -218,15 +224,15 @@ def update_policy(
exec(f"old_rule_line.v{index} = None")

def update_policies(
self,
sec: str,
ptype: str,
old_rules: [
[str],
],
new_rules: [
[str],
],
self,
sec: str,
ptype: str,
old_rules: [
[str],
],
new_rules: [
[str],
],
) -> None:
"""
Update the old_rules with the new_rules in the database (storage).
Expand Down
36 changes: 31 additions & 5 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
from unittest import TestCase

import casbin
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker

from casbin_sqlalchemy_adapter import Adapter
from casbin_sqlalchemy_adapter import Base
from casbin_sqlalchemy_adapter import CasbinRule
from casbin_sqlalchemy_adapter.adapter import Filter
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from unittest import TestCase
import casbin
import os


def get_fixture(path):
Expand Down Expand Up @@ -35,6 +37,30 @@ def get_enforcer():


class TestConfig(TestCase):
def test_custom_db_class(self):
class CustomRule(Base):
__tablename__ = "casbin_rule2"

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
v0 = Column(String(255))
v1 = Column(String(255))
v2 = Column(String(255))
v3 = Column(String(255))
v4 = Column(String(255))
v5 = Column(String(255))
not_exist = Column(String(255))

engine = create_engine("sqlite://")
adapter = Adapter(engine, CustomRule)

session = sessionmaker(bind=engine)
Base.metadata.create_all(engine)
s = session()
s.add(CustomRule(not_exist="NotNone"))
s.commit()
self.assertEqual(s.query(CustomRule).all()[0].not_exist, "NotNone")

def test_enforcer_basic(self):
e = get_enforcer()

Expand Down

0 comments on commit 4effb2d

Please # to comment.