Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Ariana Barzinpour committed Sep 19, 2024
1 parent d6b9357 commit 210b7f0
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions datacube/drivers/postgis/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,14 +1355,15 @@ def get_all_relations(self, dsids: Iterable[uuid.UUID]) -> Iterable[LineageRelat
source_id=rel.source_dataset_ref,
derived_id=rel.derived_dataset_ref)

def write_relations(self, relations: Iterable[LineageRelation], allow_updates: bool):
def write_relations(self, relations: Iterable[LineageRelation], allow_updates: bool) -> int:
"""
Write a set of LineageRelation objects to the database.
:param relations: An Iterable of LineageRelation objects
:param allow_updates: if False, only allow adding new relations, not updating old ones.
:return: Count of database rows affected
"""
affected = 0
if allow_updates:
by_classifier: dict[str, Any] = {}
for rel in relations:
Expand All @@ -1375,34 +1376,31 @@ def write_relations(self, relations: Iterable[LineageRelation], allow_updates: b
by_classifier[rel.classifier].append(db_repr)
else:
by_classifier[rel.classifier] = [db_repr]
updates = 0
for classifier, values in by_classifier.items():
qry = insert(DatasetLineage).on_conflict_do_update(
index_elements=["derived_dataset_ref", "source_dataset_ref"],
set_={"classifier": classifier},
where=(DatasetLineage.classifier != classifier))
res = self._connection.execute(qry, values)
updates += res.rowcount
return updates
affected += res.rowcount
else:
if len(relations):
for rel in relations:
values = [
{
"derived_dataset_ref": rel.derived_id,
"source_dataset_ref": rel.source_id,
"classifier": rel.classifier
}
for rel in relations
]
qry = insert(DatasetLineage)
try:
res = self._connection.execute(
qry, values
)
return res.rowcount
affected += res.rowcount
except IntegrityError:
return 0
return
return affected

def load_lineage_relations(self,
roots: Iterable[uuid.UUID],
Expand Down

0 comments on commit 210b7f0

Please sign in to comment.