diff --git a/tests/routes/test_collections.py b/tests/routes/test_collections.py index 4ce3cf3..65fdd06 100644 --- a/tests/routes/test_collections.py +++ b/tests/routes/test_collections.py @@ -372,3 +372,20 @@ def test_collections_temporal_extent_datetime_column(app): assert len(intervals) == 4 assert intervals[0][0] == "2004-10-19T10:23:54+00:00" assert intervals[0][1] == "2007-10-24T00:00:00+00:00" + +def test_collections_collectionId_substring_filter(app): + """Test /collections endpoint.""" + response = app.get("/collections", params={"collectionId_substring": "_mgrs"}) + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + body = response.json() + + ids = [x["id"] for x in body["collections"]] + + assert "public.sentinel_mgrs" in ids + assert "pg_temp.landsat_centroids" not in ids + assert "pg_temp.hexagons" not in ids + assert "pg_temp.squares" not in ids + assert "public.st_squaregrid" not in ids + assert "public.st_hexagongrid" not in ids + assert "public.st_subdivide" not in ids diff --git a/tipg/dependencies.py b/tipg/dependencies.py index a82feeb..58b543e 100644 --- a/tipg/dependencies.py +++ b/tipg/dependencies.py @@ -233,6 +233,21 @@ def datetime_query( return None +def collectionId_substring_query( + collectionId_substring: Annotated[Optional[str], Query(description="Filter based on collectionId substring regex.")] = None +) -> Optional[str]: + """collectionId substring dependency.""" + compiled_substring = None + if collectionId_substring: + try: + # Attempt to compile the substring pattern provided by the user + compiled_substring = re.compile(collectionId_substring) + except re.error as e: + raise HTTPException( + status_code=422, + detail=f"Invalid substring '{collectionId_substring}' provided for 'collectionId_substring': {e}" + ) + return compiled_substring def properties_query( properties: Annotated[ @@ -450,6 +465,7 @@ def CollectionsParams( description="Starts the response at an offset.", ), ] = None, + collectionId_substring: Annotated[Optional[str], Depends(collectionId_substring_query)] = None ) -> CollectionList: """Return Collections Catalog.""" limit = limit or 0 @@ -487,6 +503,15 @@ def CollectionsParams( and t_intersects(datetime_filter, collection.dt_bounds) ] + # collectionId substring filter + if collectionId_substring is not None: + collections_list = [ + collection + for collection in collections_list + # Use search() to find the substring anywhere in the collection ID + if collectionId_substring.search(collection.id) + ] + matched = len(collections_list) if limit: