diff --git a/backend/core/config.py b/backend/core/config.py index c1ac2dd0..cc6f2b35 100644 --- a/backend/core/config.py +++ b/backend/core/config.py @@ -1,9 +1,6 @@ -import os -from jsmin import jsmin from fastapi import FastAPI, Request, Response from middleware import decode_token from fastapi.responses import FileResponse -from db.connection import SessionLocal from routes.user import user_route from routes.tag import tag_route @@ -18,32 +15,21 @@ from routes.visualization import visualization_route from routes.reference_data import reference_data_routes +import os +from jsmin import jsmin +from db.connection import SessionLocal + from models.business_unit import BusinessUnit from models.commodity_category import CommodityCategory from models.currency import Currency from models.country import Country -app = FastAPI( - root_path="/api", - title="IDH-IDC", - description="Auth Client ID: 99w2F1wVLZq8GqJwZph1kE42GuAZFvlF", - version="1.0.0", - contact={ - "name": "Akvo", - "url": "https://akvo.org", - "email": "dev@akvo.org", - }, - license_info={ - "name": "AGPL3", - "url": "https://www.gnu.org/licenses/agpl-3.0.en.html", - }, -) - JS_FILE = "./config.min.js" def generate_config_file() -> None: + print("[START] Generating config") session = SessionLocal() env_js = "var __ENV__={" env_js += 'client_id:"{}"'.format(os.environ["CLIENT_ID"]) @@ -51,9 +37,11 @@ def generate_config_file() -> None: env_js += "};" min_js = jsmin("".join([env_js, ""])) business_units = session.query(BusinessUnit).all() or [] + session.flush() if business_units: business_units = [bu.serialize for bu in business_units] commodity_categories = session.query(CommodityCategory).all() or [] + session.flush() if commodity_categories: commodity_categories = [ cc.serialize_with_commodities for cc in commodity_categories @@ -61,6 +49,7 @@ def generate_config_file() -> None: currencies = ( session.query(Currency.abbreviation, Currency.country).distinct() or [] ) + session.flush() if currencies: currencies = [ {"value": c[0], "label": c[0], "country": c[1]} for c in currencies @@ -69,6 +58,7 @@ def generate_config_file() -> None: session.query(Country).filter(Country.parent.is_(None)).all() or [] # noqa ) + session.flush() if countries: countries = [c.to_dropdown for c in countries] min_js += "var master={};".format( @@ -83,6 +73,25 @@ def generate_config_file() -> None: ) with open(JS_FILE, "w") as jsfile: jsfile.write(min_js) + print("[DONE] Config generated") + session.close() + + +app = FastAPI( + root_path="/api", + title="IDH-IDC", + description="Auth Client ID: 99w2F1wVLZq8GqJwZph1kE42GuAZFvlF", + version="1.0.0", + contact={ + "name": "Akvo", + "url": "https://akvo.org", + "email": "dev@akvo.org", + }, + license_info={ + "name": "AGPL3", + "url": "https://www.gnu.org/licenses/agpl-3.0.en.html", + }, +) # Routes register @@ -118,7 +127,6 @@ def health_check(): description="static javascript config", ) async def main(res: Response): - generate_config_file() res.headers["Content-Type"] = "application/x-javascript; charset=utf-8" return JS_FILE diff --git a/backend/db/crud_case.py b/backend/db/crud_case.py index 2f1db208..c83bdffb 100644 --- a/backend/db/crud_case.py +++ b/backend/db/crud_case.py @@ -171,6 +171,7 @@ def update_case(session: Session, id: int, payload: CaseBase) -> CaseDict: for ct in prev_tags: session.delete(ct) session.commit() + session.flush() # store new tags for tag_id in payload.tags: tag = CaseTag(tag=tag_id, case=case.id) @@ -290,9 +291,11 @@ def delete_case(session: Session, case_id: int): for sa in segment_answer: session.delete(sa) session.commit() + session.flush() for s in segment: session.delete(s) session.commit() + session.flush() # visualization visualization = ( @@ -303,6 +306,7 @@ def delete_case(session: Session, case_id: int): for vis in visualization: session.delete(vis) session.commit() + session.flush() # case_commodity case_commodity = ( @@ -313,12 +317,14 @@ def delete_case(session: Session, case_id: int): for cc in case_commodity: session.delete(cc) session.commit() + session.flush() # case tag case_tag = session.query(CaseTag).filter(CaseTag.case == case_id).all() for ct in case_tag: session.delete(ct) session.commit() + session.flush() # user case user_case_access = ( @@ -329,6 +335,7 @@ def delete_case(session: Session, case_id: int): for uca in user_case_access: session.delete(uca) session.commit() + session.flush() session.delete(case) session.commit() diff --git a/backend/db/crud_reset_password.py b/backend/db/crud_reset_password.py index e722cde7..4fe5fd0a 100644 --- a/backend/db/crud_reset_password.py +++ b/backend/db/crud_reset_password.py @@ -36,3 +36,4 @@ def get_reset_password(session: Session, url: str) -> ResetPasswordBase: def delete_reset_password(session: Session, url: str) -> None: session.query(ResetPassword).filter(ResetPassword.url == url).delete() session.commit() + session.flush() diff --git a/backend/db/crud_segment.py b/backend/db/crud_segment.py index 339fae87..58739750 100644 --- a/backend/db/crud_segment.py +++ b/backend/db/crud_segment.py @@ -71,6 +71,7 @@ def update_segment( for sa in prev_segment_answers: session.delete(sa) session.commit() + session.flush() # handle segment answers for val in payload.answers: segment_answer = SegmentAnswer( diff --git a/backend/db/crud_user.py b/backend/db/crud_user.py index 610cabe6..49b6b496 100644 --- a/backend/db/crud_user.py +++ b/backend/db/crud_user.py @@ -117,6 +117,7 @@ def update_user( for ut in prev_user_tags: session.delete(ut) session.commit() + session.flush() # add new user tags for tag in payload.tags: user_tag = UserTag(user=user.id, tag=tag) @@ -131,6 +132,7 @@ def update_user( for uc in prev_user_cases: session.delete(uc) session.commit() + session.flush() # add new user case access for proj in payload.cases: case_access = UserCaseAccess( @@ -153,6 +155,7 @@ def update_user( for bu in prev_user_bus: session.delete(bu) session.commit() + session.flush() # add new user business units for bu in payload.business_units: business_unit = UserBusinessUnit( diff --git a/backend/dev.sh b/backend/dev.sh index 7af02b0d..8d012005 100755 --- a/backend/dev.sh +++ b/backend/dev.sh @@ -5,5 +5,4 @@ pip -q install --cache-dir=.pip -r requirements.txt pip check alembic upgrade head - uvicorn main:app --reload --port 5000 diff --git a/backend/generate_config.sh b/backend/generate_config.sh new file mode 100755 index 00000000..fd6d3c2f --- /dev/null +++ b/backend/generate_config.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -eu + +# Run the Python command to import and execute the function +python -c "from core.config import generate_config_file; generate_config_file(); exit()" diff --git a/backend/main.py b/backend/main.py index c917af33..d77ab885 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,5 @@ import uvicorn + from db.connection import engine, Base from core.config import app, generate_config_file diff --git a/backend/models/business_unit.py b/backend/models/business_unit.py index 9739ca56..a3251c11 100644 --- a/backend/models/business_unit.py +++ b/backend/models/business_unit.py @@ -1,9 +1,10 @@ -from db.connection import Base -from sqlalchemy import Column, Integer, String -from sqlalchemy.orm import relationship from typing import Optional from typing_extensions import TypedDict +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship + from pydantic import BaseModel +from db.connection import Base class BusinessUnitDict(TypedDict): @@ -18,7 +19,7 @@ class BusinessUnit(Base): name = Column(String, nullable=False, unique=True) business_unit_users = relationship( - 'UserBusinessUnit', + "UserBusinessUnit", cascade="all, delete", passive_deletes=True, back_populates="business_unit_detail" diff --git a/backend/models/reset_password.py b/backend/models/reset_password.py index bda803d9..16daa6e3 100644 --- a/backend/models/reset_password.py +++ b/backend/models/reset_password.py @@ -18,7 +18,7 @@ class ResetPasswordBase(BaseModel): expired: bool class Config: - orm_mode = True + from_attributes = True class ResetPassword(Base): diff --git a/backend/models/user_business_unit.py b/backend/models/user_business_unit.py index a9b7994c..7935f6f9 100644 --- a/backend/models/user_business_unit.py +++ b/backend/models/user_business_unit.py @@ -1,12 +1,12 @@ import enum -from db.connection import Base -from sqlalchemy import Column, ForeignKey, Integer, Enum -from sqlalchemy.orm import relationship from typing import Optional from typing_extensions import TypedDict +from sqlalchemy import Column, ForeignKey, Integer, Enum +from sqlalchemy.orm import relationship + from pydantic import BaseModel -from models.business_unit import BusinessUnit +from db.connection import Base class UserBusinessUnitRole(enum.Enum): @@ -37,10 +37,10 @@ class UserBusinessUnit(Base): id = Column(Integer, primary_key=True, nullable=False) user = Column(Integer, ForeignKey("user.id"), nullable=False) - business_unit = Column(Integer, ForeignKey("business_unit.id"), nullable=False) - role = Column( - Enum(UserBusinessUnitRole, name="user_business_unit_role"), nullable=False - ) + business_unit = Column( + Integer, ForeignKey("business_unit.id"), nullable=False) + role = Column(Enum( + UserBusinessUnitRole, name="user_business_unit_role"), nullable=False) user_business_unit_user_detail = relationship( "User", @@ -49,7 +49,7 @@ class UserBusinessUnit(Base): back_populates="user_business_units", ) business_unit_detail = relationship( - BusinessUnit, + "BusinessUnit", cascade="all, delete", passive_deletes=True, back_populates="business_unit_users", @@ -60,7 +60,7 @@ def __init__( business_unit: int, role: UserBusinessUnitRole, id: Optional[int] = None, - user: Optional[int] = None, + user: Optional[int] = None ): self.id = id self.user = user diff --git a/backend/seeder/commodity.py b/backend/seeder/commodity.py index 5e5ab44c..248bbb3d 100644 --- a/backend/seeder/commodity.py +++ b/backend/seeder/commodity.py @@ -47,6 +47,7 @@ def seeder_commodity(session: Session): if category_objects: session.bulk_save_objects(category_objects, update_changed_only=True) session.commit() + session.flush() print("[DATABASE UPDATED]: Commodity Category") commodities = commodities[["id", "group_id", "name"]] @@ -73,6 +74,7 @@ def seeder_commodity(session: Session): if commodity_objects: session.bulk_save_objects(commodity_objects, update_changed_only=True) session.commit() + session.flush() print("[DATABASE UPDATED]: Commodity") session.close() diff --git a/backend/setup.cfg b/backend/setup.cfg index 1e39ded3..def45731 100644 --- a/backend/setup.cfg +++ b/backend/setup.cfg @@ -2,6 +2,9 @@ max-line-length = 88 ignore = E203, E266, E501, W503 select = E,W,F +per-file-ignores = + # imported but unused + __init__.py: F401 [tool.black] line-length = 88