Commit
Remove common_timestamp function from postgis driver (#1623)
* Duplicates fully working at last

* Use explicit timezones in metadata.

* Passing all tests with old style test data.

* Cleanup.

* Update whats_new.rst

* Lintage and mypy

* Minor cleanup.

* Commented some non-intuitive code.

* Commented some non-intuitive code.

* Cleanup.

* oops whitespace.
SpacemanPaul committed Sep 20, 2024
1 parent eeb5e6d commit 2824c16
Showing 7 changed files with 228 additions and 128 deletions.
187 changes: 112 additions & 75 deletions datacube/drivers/postgis/_api.py
@@ -24,7 +24,6 @@
from sqlalchemy import select, text, and_, or_, func
from sqlalchemy.dialects.postgresql import INTERVAL
from sqlalchemy.exc import IntegrityError
from sqlalchemy.engine import Row

from typing import Iterable, Sequence, Optional, Set, Any
from typing import cast as type_cast
@@ -50,22 +49,40 @@

# Make a function because it's broken
def _dataset_select_fields() -> tuple:
return tuple(f.alchemy_expression for f in _dataset_fields())


def _dataset_fields() -> tuple:
native_flds = get_native_fields()
return (
Dataset,
# All active URIs, from newest to oldest
func.array(
select(
SelectedDatasetLocation.uri
).where(
and_(
SelectedDatasetLocation.dataset_ref == Dataset.id,
SelectedDatasetLocation.archived == None
)
).order_by(
SelectedDatasetLocation.added.desc(),
SelectedDatasetLocation.id.desc()
).label('uris')
).label('uris')
native_flds["id"],
native_flds["indexed_time"],
native_flds["indexed_by"],
native_flds["product_id"],
native_flds["metadata_type_id"],
native_flds["metadata_doc"],
NativeField(
'archived',
'Archived date',
Dataset.archived
),
NativeField("uris",
"all uris",
func.array(
select(
SelectedDatasetLocation.uri
).where(
and_(
SelectedDatasetLocation.dataset_ref == Dataset.id,
SelectedDatasetLocation.archived == None
)
).order_by(
SelectedDatasetLocation.added.desc(),
SelectedDatasetLocation.id.desc()
).label('uris')
),
alchemy_table=Dataset.__table__ # type: ignore[attr-defined]
)
)

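As a hedged aside (the names mirror the diff, but the usage below is illustrative rather than lifted from the driver): each entry returned by `_dataset_fields()` pairs an SQLAlchemy expression with a field name, so the same tuple can drive both the SELECT column list and the decoding of result rows.

```python
# Illustrative sketch only, assuming the field objects expose .name and
# .alchemy_expression as the diff suggests.
fields = _dataset_fields()
columns = tuple(f.alchemy_expression.label(f.name) for f in fields)  # for the SELECT
names = tuple(f.name for f in fields)                                # for decoding rows
```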

@@ -230,6 +247,29 @@ def extract_dataset_fields(ds_metadata, fields):
return result


# Min/Max aggregating time fields for temporal_extent methods
time_min = DateDocField('acquisition_time_min',
'Min of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:start_datetime'],
['properties', 'datetime']
],
selection='least')


time_max = DateDocField('acquisition_time_max',
'Max of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:end_datetime'],
['properties', 'datetime']
],
selection='greatest')

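The module-level `time_min` / `time_max` fields above hard-code the eo3 property locations. Conceptually, the `least`/`greatest` selection over the two offsets behaves like the pure-Python sketch below (an illustration under that assumption; the real work happens in SQL via `DateDocField`).

```python
# Conceptual equivalent of the 'least'/'greatest' selection over the two offsets.
# Assumes at least one of the properties is present, as eo3 requires 'datetime'.
def acquisition_bounds(metadata_doc: dict) -> tuple:
    props = metadata_doc.get("properties", {})
    starts = [v for v in (props.get("dtr:start_datetime"), props.get("datetime")) if v]
    ends = [v for v in (props.get("dtr:end_datetime"), props.get("datetime")) if v]
    return min(starts), max(ends)
```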

class PostgisDbAPI:
def __init__(self, parentdb, connection):
self._db = parentdb
@@ -476,19 +516,6 @@ def all_dataset_ids(self, archived: bool | None = False):
)
return self._connection.execute(query).fetchall()

# Not currently implemented.
# def insert_dataset_source(self, classifier, dataset_id, source_dataset_id):
# r = self._connection.execute(
# insert(DatasetSource).on_conflict_do_nothing(
# index_elements=['classifier', 'dataset_ref']
# ).values(
# classifier=classifier,
# dataset_ref=dataset_id,
# source_dataset_ref=source_dataset_id
# )
# )
# return r.rowcount > 0

def archive_dataset(self, dataset_id):
r = self._connection.execute(
update(Dataset).where(
@@ -548,10 +575,10 @@ def get_datasets(self, dataset_ids):
).fetchall()

def get_derived_datasets(self, dataset_id):
raise NotImplementedError
raise NotImplementedError()

def get_dataset_sources(self, dataset_id):
raise NotImplementedError
raise NotImplementedError()

def search_datasets_by_metadata(self, metadata, archived):
"""
@@ -621,14 +648,13 @@ def search_datasets_query(self,
assert source_exprs is None
assert not with_source_ids

if select_fields:
select_columns = tuple(
f.alchemy_expression.label(f.name)
for f in select_fields
)
else:
select_columns = _dataset_select_fields()
if not select_fields:
select_fields = _dataset_fields()

select_columns = tuple(
f.alchemy_expression.label(f.name)
for f in select_fields
)
if geom:
SpatialIndex, spatialquery = self.geospatial_query(geom)
else:
@@ -663,12 +689,21 @@ def search_datasets(self, expressions,
:type with_source_ids: bool
:type select_fields: tuple[datacube.drivers.postgis._fields.PgField]
:type expressions: tuple[datacube.drivers.postgis._fields.PgExpression]
:return: An iterable of tuples of decoded values
"""
if select_fields is None:
select_fields = _dataset_fields()
select_query = self.search_datasets_query(expressions, source_exprs,
select_fields, with_source_ids,
limit, geom=geom, archived=archived)
_LOG.debug("search_datasets SQL: %s", str(select_query))
return self._connection.execute(select_query)

def decode_row(raw: Iterable[Any]) -> dict[str, Any]:
return {f.name: f.normalise_value(r) for r, f in zip(raw, select_fields)}

for row in self._connection.execute(select_query):
yield decode_row(row)

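The effect of the change above is that `search_datasets` now yields decoded dicts rather than raw SQLAlchemy rows. A hedged usage sketch (`api` and `expressions` are placeholder names, not taken from the diff):

```python
# Hedged usage sketch: each yielded row is a plain dict keyed by field name,
# with every value already passed through the field's normalise_value().
for ds in api.search_datasets(expressions):
    print(ds["id"], ds["uris"])
```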
def bulk_simple_dataset_search(self, products=None, batch_size=0):
"""
@@ -690,7 +725,7 @@ def bulk_simple_dataset_search(self, products=None, batch_size=0):
query = select(
*_dataset_bulk_select_fields()
).select_from(Dataset).where(
Dataset.archived == None
Dataset.archived.is_(None)
)
if products:
query = query.where(Dataset.product_ref.in_(products))
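A side note on the `Dataset.archived == None` → `Dataset.archived.is_(None)` changes repeated through this diff: both forms compile to the same `archived IS NULL` SQL, and `.is_(None)` is simply the explicit SQLAlchemy spelling that keeps linters quiet about comparing to `None` with `==`. For example:

```python
# Equivalent filters; the second avoids the flake8 E711 complaint.
query.where(Dataset.archived == None)   # noqa: E711
query.where(Dataset.archived.is_(None))
```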
@@ -733,10 +768,12 @@ def insert_lineage_bulk(self, values):
)
return res.rowcount, requested - res.rowcount

def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[PgExpression]) -> Iterable[Row]:
def get_duplicates(self,
match_fields: Sequence[PgField],
expressions: Sequence[PgExpression]) -> Iterable[dict[str, Any]]:
# TODO
if "time" in [f.name for f in match_fields]:
return self.get_duplicates_with_time(match_fields, expressions)
yield from self.get_duplicates_with_time(match_fields, expressions)

group_expressions = tuple(f.alchemy_expression for f in match_fields)
join_tables = PostgisDbAPI._join_tables(expressions, match_fields)
@@ -749,43 +786,47 @@ def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[
query = query.join(*joins)

query = query.where(
and_(Dataset.archived == None, *(PostgisDbAPI._alchemify_expressions(expressions)))
and_(Dataset.archived.is_(None), *(PostgisDbAPI._alchemify_expressions(expressions)))
).group_by(
*group_expressions
).having(
func.count(Dataset.id) > 1
)
return self._connection.execute(query)
for row in self._connection.execute(query):
drow = {"ids": row.ids}
for f in match_fields:
drow[f.name] = getattr(row, f.name)
yield drow

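A hedged sketch of what callers of `get_duplicates` now receive (the variable names are illustrative, not from the diff): each yielded dict carries the array of duplicate dataset ids under `"ids"` plus one entry per match field.

```python
# Illustrative only; match_fields and expressions come from the calling index layer.
for dup in api.get_duplicates(match_fields, expressions):
    duplicate_ids = dup["ids"]                            # datasets sharing the same values
    matched = {k: v for k, v in dup.items() if k != "ids"}
    print(len(duplicate_ids), matched)
```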
def get_duplicates_with_time(
self, match_fields: Sequence[PgField], expressions: Sequence[PgExpression]
) -> Iterable[Row]:
) -> Iterable[dict[str, Any]]:
fields = []
for f in match_fields:
if f.name == "time":
time_field = type_cast(DateRangeDocField, f).expression_with_leniency
for fld in match_fields:
if fld.name == "time":
time_field = type_cast(DateRangeDocField, fld)
else:
fields.append(f.alchemy_expression)
fields.append(fld.alchemy_expression)

join_tables = PostgisDbAPI._join_tables(expressions, match_fields)

cols = [Dataset.id, time_field.label('time'), *fields]
cols = [Dataset.id, time_field.expression_with_leniency.label('time'), *fields]
query = select(
*cols
).select_from(Dataset)
for joins in join_tables:
query = query.join(*joins)

query = query.where(
and_(Dataset.archived == None, *(PostgisDbAPI._alchemify_expressions(expressions)))
and_(Dataset.archived.is_(None), *(PostgisDbAPI._alchemify_expressions(expressions)))
)

t1 = query.alias("t1")
t2 = query.alias("t2")

time_overlap = select(
t1.c.id,
text("t1.time * t2.time as time_intersect"),
t1.c.time.intersection(t2.c.time).label('time_intersect'),
*fields
).select_from(
t1.join(
Expand All @@ -797,15 +838,24 @@ def get_duplicates_with_time(
query = select(
func.array_agg(func.distinct(time_overlap.c.id)).label("ids"),
*fields, # type: ignore[arg-type]
text("(lower(time_intersect) at time zone 'UTC', upper(time_intersect) at time zone 'UTC') as time")
text("time_intersect as time")
).select_from(
time_overlap # type: ignore[arg-type]
).group_by(
*fields, text("time_intersect")
).having(
func.count(time_overlap.c.id) > 1
)
return self._connection.execute(query)

for row in self._connection.execute(query):
# TODO: Use decode_rows above - would require creating a field class for the ids array.
drow: dict[str, Any] = {
"ids": row.ids,
}
for f in fields:
drow[f.key] = getattr(row, f.key) # type: ignore[union-attr]
drow["time"] = time_field.normalise_value((row.time.lower, row.time.upper))
yield drow

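For the time-aware variant, the overlapping range is now built with SQLAlchemy's range intersection rather than raw SQL text, and its bounds are pushed back through `normalise_value`. A hedged consumption sketch, assuming the decoded `"time"` value behaves like a begin/end pair:

```python
# Illustrative only; the exact decoded type depends on DateRangeDocField.normalise_value.
for dup in api.get_duplicates_with_time(match_fields, expressions):
    begin, end = dup["time"]        # normalised from row.time.lower / row.time.upper
    print(dup["ids"], begin, end)
```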
def count_datasets(self, expressions, archived: bool | None = False, geom: Geometry | None = None):
"""
@@ -1474,33 +1524,20 @@ def remove_lineage_relations(self,
def temporal_extent_by_prod(self, product_id: int) -> tuple[datetime.datetime, datetime.datetime]:
query = self.temporal_extent_full().where(Dataset.product_ref == product_id)
res = self._connection.execute(query)
return res.first()
for tmin, tmax in res:
return (self.time_min.normalise_value(tmin), self.time_max.normalise_value(tmax))
raise RuntimeError("Product has no datasets and therefore no temporal extent")

def temporal_extent_by_ids(self, ids: Iterable[DSID]) -> tuple[datetime.datetime, datetime.datetime]:
query = self.temporal_extent_full().where(Dataset.id.in_(ids))
res = self._connection.execute(query)
return res.first()
for tmin, tmax in res:
return (self.time_min.normalise_value(tmin), self.time_max.normalise_value(tmax))
raise ValueError("no dataset ids provided")

def temporal_extent_full(self) -> Select:
# Hardcode eo3 standard time locations - do not use this approach in a legacy index driver.
time_min = DateDocField('aquisition_time_min',
'Min of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:start_datetime'],
['properties', 'datetime']
],
selection='least')
time_max = DateDocField('aquisition_time_max',
'Max of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:end_datetime'],
['properties', 'datetime']
],
selection='greatest')

return select(
func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)
func.min(self.time_min.alchemy_expression), func.max(self.time_max.alchemy_expression)
)
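Both temporal-extent helpers now return normalised datetimes rather than a raw result row, and raise when there is nothing to aggregate. A hedged usage sketch (`api` and `product_id` are placeholders):

```python
# Illustrative only; mirrors the RuntimeError raised when a product has no datasets.
try:
    start, end = api.temporal_extent_by_prod(product_id)
except RuntimeError:
    start = end = None   # product has no datasets, so no temporal extent
```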
1 change: 0 additions & 1 deletion datacube/drivers/postgis/_core.py
@@ -128,7 +128,6 @@ def ensure_db(engine, with_permissions=True):
c.execute(text(f"""
grant usage on schema {SCHEMA_NAME} to odc_user;
grant select on all tables in schema {SCHEMA_NAME} to odc_user;
grant execute on function {SCHEMA_NAME}.common_timestamp(text) to odc_user;
grant insert on {SCHEMA_NAME}.dataset,
{SCHEMA_NAME}.location,
(Diffs for the remaining 5 changed files are not loaded here.)
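The `_core.py` change above drops the grant for the now-removed `common_timestamp(text)` SQL helper, in line with the "use explicit timezones in metadata" commits listed in the message. As a rough, hedged illustration of explicit-timezone handling in Python (this is not the driver's code, and the removed SQL helper's exact behaviour is an assumption):

```python
from datetime import datetime, timezone

def as_utc(value: str | datetime) -> datetime:
    """Parse a timestamp and pin it to UTC; naive values are assumed to be UTC."""
    if isinstance(value, str):
        value = datetime.fromisoformat(value)
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)
```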

0 comments on commit 2824c16
