Skip to content

Commit

Permalink
Feature/pdct 1780 support multiple geos in update (#278)
Browse files Browse the repository at this point in the history
* feat: add function to retrieve list of geo ids from iso codes

- function that retrieves the list of geo ids from the geo repo, based
  on the iso codes thats provided in the payload. It validates these iso
  codes at the same time returning an error if the geography does not
  exist in the db

* feat: add geographies onto family write dto

this will currently have a default value of none, so that it works with front end requests until the payload data is also updated

* feat: add list of geo ids parameter to update function

* feat: validate geography iso codes to get ids in family service

* feat: update model for family geographies

* test: integration tests for update endpoint with invalid geo id in list

* test: update dto helpers with geographies property

* chore: update project version

* chore: update project version

* tests: add final integration, testing multi geos are added

* feat: default geographies to none on the model, to ensure that this does not break any requests made by the front before multi geos is implemented

* fix: use geography repo mock in repo unit tests

* chore: update project version

* refactor: make update_geographies check a bit more explicit

if the array of geo ids coming through is empty, i.e it was not passed int he payload as expected (the frontend hasnt been updted for multi geos we do nowt and carry on

* test: fixing missing argument in rollbackrepo

* feat: add linear ticket to todo comments - updating code once frontend multi geos has been implemented

* refactor: split out family geo update functionality into separate methods

* feat: helper functions for pergorming familygeo updates

* refactor: update sql query to be more explicit

* refactor : move functions into family repo and out of helpers

* tests: add tests for repository errors

* feat: wrap geography updates in a repository error

---------

Co-authored-by: Osneil Drakes <osneildrakes@Osneils-MBP.communityfibre.co.uk>
Co-authored-by: Osneil Drakes <osneildrakes@Osneils-MacBook-Pro.local>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent 3d2f936 commit 47c9067
Show file tree
Hide file tree
Showing 16 changed files with 419 additions and 13 deletions.
4 changes: 3 additions & 1 deletion app/model/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ class FamilyWriteDTO(BaseModel):
title: str
summary: str
geography: str
geographies: list[str]
geographies: Optional[list[str]] = (
None # Todo APP-97: remove default once implemented on the frontend
)
category: str
metadata: dict[str, list[str]]
collections: list[str]
Expand Down
116 changes: 113 additions & 3 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def _update_intention(
import_id: str,
family: FamilyWriteDTO,
geo_id: int,
geo_ids: list[int],
original_family: Family,
):
original_collections = [
Expand All @@ -238,6 +239,19 @@ def _update_intention(
.geography_id
!= geo_id
)

update_geographies = False
# TODO: Todo APP-97: remove this conditional once multi-geography support is
# implemented on the frontend
if geo_ids != []:
current_family_geographies_ids = [
family_geography.geography_id
for family_geography in db.query(FamilyGeography).filter(
FamilyGeography.family_import_id == import_id
)
]
update_geographies = set(current_family_geographies_ids) != set(geo_ids)

update_basics = (
update_title
or update_geo
Expand All @@ -250,7 +264,13 @@ def _update_intention(
.one()
)
update_metadata = existing_metadata.value != family.metadata
return update_title, update_basics, update_metadata, update_collections
return (
update_title,
update_basics,
update_metadata,
update_collections,
update_geographies,
)


def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]:
Expand Down Expand Up @@ -354,14 +374,17 @@ def search(
return [_family_to_dto_search_endpoint(db, f) for f in found]


def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> bool:
def update(
db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int, geo_ids: list[int]
) -> bool:
"""
Updates a single entry with the new values passed.
:param db Session: the database connection
:param str import_id: The family import id to change.
:param FamilyDTO family: The new values
:param int geo_id: a validated geography id
:param list[int] geo_ids: a list of validated geography ids
:return bool: True if new values were set otherwise false.
"""
new_values = family.model_dump()
Expand All @@ -380,7 +403,8 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
update_basics,
update_metadata,
update_collections,
) = _update_intention(db, import_id, family, geo_id, original_family)
update_geographies,
) = _update_intention(db, import_id, family, geo_id, geo_ids, original_family)

# Return if nothing to do
if not (update_title or update_basics or update_metadata or update_collections):
Expand Down Expand Up @@ -475,6 +499,10 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
)
db.add(new_collection)

# Update geographies if geographies have changed.
if update_geographies:
perform_family_geographies_update(db, import_id, geo_ids)

return True


Expand Down Expand Up @@ -655,3 +683,85 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]:
return

return n_families


def remove_old_geographies(
db: Session, import_id: str, geo_ids: list[int], original_geographies: set[int]
):
"""
Removes geographies that are no longer in geo_ids.
This function compares the original set of geographies with the new geo_ids
and removes the geographies that are no longer present.
:param Session db: the database session
:param str import_id: the family import ID for the geographies
:param list[int] geo_ids: the list of geography IDs to be kept
:param set[int] original_geographies: the set of original geography IDs to be compared
:raises RepositoryError: if a geography removal fails
"""
cols_to_remove = set(original_geographies) - set(geo_ids)
for col in cols_to_remove:
try:
db.execute(
db_delete(FamilyGeography).where(FamilyGeography.geography_id == col)
)
except Exception as e:
msg = f"Could not remove family {import_id} from geography {col}: {str(e)}"
_LOGGER.error(msg)
raise RepositoryError(msg)


def add_new_geographies(
db: Session, import_id: str, geo_ids: list[int], original_geographies: set[int]
):
"""
Adds new geographies that are not already in the original geographies.
This function identifies the geographies that need to be added (i.e., those
that are in geo_ids but not in the original set) and adds them to the database.
:param Session db: the database session
:param str import_id: the family import ID for the geographies
:param list[str] geo_ids: the list of geography IDs to be added
:param set[str] original_geographies: the set of original geography IDs to be checked against
:raises RepositoryError: if fails to add a geography
"""
cols_to_add = set(geo_ids) - set(original_geographies)

for col in cols_to_add:
try:
new_geography = FamilyGeography(
family_import_id=import_id,
geography_id=col,
)
db.add(new_geography)
db.flush()
except Exception as e:
msg = f"Failed to add geography {col} to family {import_id}: {str(e)}"
_LOGGER.error(msg)
raise RepositoryError(msg)


def perform_family_geographies_update(db: Session, import_id: str, geo_ids: list[int]):
"""
Updates geographies by removing old ones and adding new ones.
This function performs a complete update by removing geographies that are no
longer in geo_ids and adding new geographies that were not previously present.
:param Session db: the database session
:param str import_id: the family import ID for the geographies
:param list[str] geo_ids: the list of geography IDs to be updated
"""
original_geographies = set(
[
fg.geography_id
for fg in db.query(FamilyGeography).filter(
FamilyGeography.family_import_id == import_id
)
]
)

remove_old_geographies(db, import_id, geo_ids, original_geographies)
add_new_geographies(db, import_id, geo_ids, original_geographies)
23 changes: 23 additions & 0 deletions app/repository/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,27 @@


def get_id_from_value(db: Session, geo_string: str) -> Optional[int]:
"""
Fetch the ID of a geography based on its iso value.
:param Session db: Database session.
:param str geo_string: The geography value to look up.
:return Optional[int]: The ID of the geography if found, otherwise None.
"""
return db.query(Geography.id).filter_by(value=geo_string).scalar()


def get_ids_from_values(db: Session, geo_strings: list[str]) -> list[int]:
"""
Fetch IDs for multiple geographies based on their iso values.
:param Session db: Database session.
:param list[str] geo_strings: A list of geography iso values to look up.
:return list[int]: A list of IDs corresponding to the provided geography values.
"""
return [
geography.id
for geography in db.query(Geography)
.filter(Geography.value.in_(geo_strings))
.all()
]
84 changes: 83 additions & 1 deletion app/repository/helpers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""Helper functions for repos"""

import logging
from typing import Optional, Union, cast
from uuid import uuid4

from db_client.models.dfce.family import Slug
from db_client.models.dfce.family import FamilyGeography, Slug
from db_client.models.organisation.counters import CountedEntity, EntityCounter
from db_client.models.organisation.users import Organisation
from slugify import slugify
from sqlalchemy import delete as db_delete
from sqlalchemy.orm import Session

from app.errors import RepositoryError

_LOGGER = logging.getLogger(__name__)


def generate_unique_slug(
existing_slugs: set[str], title: str, attempts: int = 100, suffix_length: int = 6
Expand Down Expand Up @@ -91,3 +97,79 @@ def generate_import_id(
db.query(EntityCounter).filter(EntityCounter.prefix == org_name).one()
)
return counter.create_import_id(entity_type)


def remove_old_geographies(
db: Session, import_id: str, geo_ids: list[int], original_geographies: set[int]
):
"""
Removes geographies that are no longer in geo_ids.
This function compares the original set of geographies with the new geo_ids
and removes the geographies that are no longer present.
:param Session db: the database session
:param str import_id: the family import ID for the geographies
:param list[int] geo_ids: the list of geography IDs to be kept
:param set[int] original_geographies: the set of original geography IDs to be compared
:raises RepositoryError: if a geography removal fails
"""
cols_to_remove = set(original_geographies) - set(geo_ids)
for col in cols_to_remove:
result = db.execute(
db_delete(FamilyGeography).where(FamilyGeography.geography_id == col)
)

if result.rowcount == 0: # type: ignore
msg = f"Could not remove family {import_id} from geography {col}"
_LOGGER.error(msg)
raise RepositoryError(msg)


def add_new_geographies(
db: Session, import_id: str, geo_ids: list[int], original_geographies: set[int]
):
"""
Adds new geographies that are not already in the original geographies.
This function identifies the geographies that need to be added (i.e., those
that are in geo_ids but not in the original set) and adds them to the database.
:param Session db: the database session
:param str import_id: the family import ID for the geographies
:param list[str] geo_ids: the list of geography IDs to be added
:param set[str] original_geographies: the set of original geography IDs to be checked against
"""
cols_to_add = set(geo_ids) - set(original_geographies)

for col in cols_to_add:
db.flush()
new_geography = FamilyGeography(
family_import_id=import_id,
geography_id=col,
)
db.add(new_geography)


def perform_family_geographies_update(db: Session, import_id: str, geo_ids: list[int]):
"""
Updates geographies by removing old ones and adding new ones.
This function performs a complete update by removing geographies that are no
longer in geo_ids and adding new geographies that were not previously present.
:param Session db: the database session
:param str import_id: the family import ID for the geographies
:param list[str] geo_ids: the list of geography IDs to be updated
"""
original_geographies = set(
[
fg.geography_id
for fg in db.query(FamilyGeography).filter(
FamilyGeography.family_import_id == import_id
)
]
)

remove_old_geographies(db, import_id, geo_ids, original_geographies)
add_new_geographies(db, import_id, geo_ids, original_geographies)
6 changes: 5 additions & 1 deletion app/repository/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def search(

@staticmethod
def update(
db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int
db: Session,
import_id: str,
family: FamilyWriteDTO,
geo_id: int,
geography_ids: list[int],
) -> bool:
"""Updates a family"""
...
Expand Down
9 changes: 8 additions & 1 deletion app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def update(
# Validate geography
geo_id = geography.get_id(db, family_dto.geography)

# Validate geographies if they are passed as part of the json object, otherwise
# pass an empty list
# Todo APP-97: update this once the frontend can send multiple geographies
geography_ids = (
geography.get_ids(db, family_dto.geographies) if family_dto.geographies else []
)

# Validate family belongs to same org as current user.
entity_org_id: int = corpus.get_corpus_org_id(family.corpus_import_id, db)
app_user.raise_if_unauthorised_to_make_changes(
Expand All @@ -156,7 +163,7 @@ def update(
raise ValidationError(msg)

try:
if family_repo.update(db, import_id, family_dto, geo_id):
if family_repo.update(db, import_id, family_dto, geo_id, geography_ids):
db.commit()
else:
db.rollback()
Expand Down
28 changes: 28 additions & 0 deletions app/service/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,35 @@


def get_id(db: Session, geo_string: str) -> int:
"""
Fetch the ID of a geography and validate its existence.
:param Session db: Database session.
:param str geo_string: The geography iso value to look up.
:raises ValidationError: If the geography value is invalid.
:return int: The ID of the geography.
"""
id = geography_repo.get_id_from_value(db, geo_string)
if id is None:
raise ValidationError(f"The geography value {geo_string} is invalid!")
return id


def get_ids(db: Session, geo_strings: list[str]) -> list[int]:
"""
Fetch IDs for multiple geographies and validate their existence.
:param Session db: Database session.
:param list[str] geo_strings: A list of geography iso values to look up.
:raises ValidationError: If any of the geography values are invalid.
:return list[int]: A list of IDs corresponding to the provided geography values.
"""

geo_ids = geography_repo.get_ids_from_values(db, geo_strings)

if len(geo_ids) != len(geo_strings):
raise ValidationError(
f"One or more of the following geography values are invalid: {', '.join(geo_strings)}"
)

return geo_ids
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "admin_backend"
version = "2.17.29"
version = "2.17.30"
description = ""
authors = ["CPR-dev-team <tech@climatepolicyradar.org>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def create_family_write_dto(
title: str = "title",
summary: str = "summary",
geography: str = "CHN",
geographies: list[str] = ["CHN", "BRB", "BHS"],
geographies: Optional[list[str]] = [],
category: str = FamilyCategory.LEGISLATIVE.value,
metadata: Optional[dict] = None,
collections: Optional[list[str]] = None,
Expand Down
Loading

0 comments on commit 47c9067

Please # to comment.