Skip to content

Commit

Permalink
Merge branch 'hotfix/sqla1.4' of https://github.com/PnX-SI/RefGeo int…
Browse files Browse the repository at this point in the history
…o feat/sqlalchemy2.0
  • Loading branch information
Pierre-Narcisi committed Dec 11, 2023
2 parents f718a05 + d6b59e0 commit 66b8f12
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 35 deletions.
6 changes: 3 additions & 3 deletions src/ref_geo/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import click
from flask.cli import with_appcontext
from sqlalchemy import func, column
from sqlalchemy import func, select

from ref_geo.env import db
from ref_geo.models import BibAreasTypes, LAreas
Expand All @@ -16,10 +16,10 @@ def ref_geo():
def info():
click.echo("RefGeo : nombre de zones par type")
q = (
db.session.query(BibAreasTypes, func.count(LAreas.id_area).label("count"))
select(BibAreasTypes, func.count(LAreas.id_area).label("count"))
.join(LAreas)
.group_by(BibAreasTypes.id_type)
.order_by(BibAreasTypes.id_type)
)
for area_type, count in q.all():
for area_type, count in db.session.scalars(q).unique().all():
click.echo("\t{}: {}".format(area_type.type_name, count))
66 changes: 36 additions & 30 deletions src/ref_geo/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask import Blueprint, request, current_app
from flask.json import jsonify
import sqlalchemy as sa
from sqlalchemy import func, distinct, asc, desc
from sqlalchemy import func, select, asc, desc
from sqlalchemy.sql import text
from sqlalchemy.orm import joinedload, undefer
from werkzeug.exceptions import BadRequest
Expand Down Expand Up @@ -60,8 +60,10 @@ def getGeoInfo():
raise BadRequest("Missing 'geometry' in request payload")
geojson = json.dumps(geojson)

areas = LAreas.query.filter_by(enable=True).filter(
geojson_intersect_filter.params(geojson=geojson)
areas = (
select(LAreas)
.filter_by(enable=True)
.filter(geojson_intersect_filter.params(geojson=geojson))
)
if "area_type" in request.json:
areas = areas.join(BibAreasTypes).filter_by(type_code=request.json["area_type"])
Expand All @@ -78,7 +80,7 @@ def getGeoInfo():
{
"areas": [
area.as_dict(fields=["id_area", "id_type", "area_code", "area_name"])
for area in areas.all()
for area in db.session.scalars(areas).unique().all()
],
"altitude": altitude,
}
Expand Down Expand Up @@ -119,8 +121,10 @@ def getAreasIntersection():
raise BadRequest("Missing 'geometry' in request payload")
geojson = json.dumps(geojson)

areas = LAreas.query.filter_by(enable=True).filter(
geojson_intersect_filter.params(geojson=geojson)
areas = (
select(LAreas)
.filter_by(enable=True)
.filter(geojson_intersect_filter.params(geojson=geojson))
)
if "area_type" in request.json:
areas = areas.join(BibAreasTypes).filter_by(type_code=request.json["area_type"])
Expand All @@ -133,7 +137,9 @@ def getAreasIntersection():
areas = areas.order_by(LAreas.id_type)

response = {}
for id_type, _areas in groupby(areas.all(), key=lambda area: area.id_type):
for id_type, _areas in groupby(
db.session.scalars(areas).unique().all(), key=lambda area: area.id_type
):
_areas = list(_areas)
response[id_type] = _areas[0].area_type.as_dict(fields=["type_code", "type_name"])
response[id_type].update(
Expand Down Expand Up @@ -163,13 +169,13 @@ def get_municipalities():
"""
parameters = request.args

q = db.session.query(LiMunicipalities).order_by(LiMunicipalities.nom_com.asc())
q = select(LiMunicipalities).order_by(LiMunicipalities.nom_com.asc())

if "nom_com" in parameters:
q = q.filter(LiMunicipalities.nom_com.ilike("{}%".format(parameters.get("nom_com"))))
q = q.where(LiMunicipalities.nom_com.ilike("{}%".format(parameters.get("nom_com"))))
limit = int(parameters.get("limit")) if parameters.get("limit") else 100

data = q.limit(limit)
data = db.session.scalars(q.limit(limit)).all()
return jsonify([d.as_dict() for d in data])


Expand All @@ -190,8 +196,8 @@ def get_areas():
# change all args in a list of value
params = {key: request.args.getlist(key) for key, value in request.args.items()}

q = (
db.session.query(LAreas)
query = (
select(LAreas)
.options(joinedload("area_type").load_only("type_code"))
.order_by(LAreas.area_name.asc())
)
Expand All @@ -206,32 +212,34 @@ def get_areas():
}
return response, 400
if enable_param == "true":
q = q.filter(LAreas.enable == True)
query = query.where(LAreas.enable == True)
elif enable_param == "false":
q = q.filter(LAreas.enable == False)
query = query.where(LAreas.enable == False)
else:
q = q.filter(LAreas.enable == True)
query = query.where(LAreas.enable == True)

if "id_type" in params:
q = q.filter(LAreas.id_type.in_(params["id_type"]))
query = query.where(LAreas.id_type.in_(params["id_type"]))

if "type_code" in params:
q = q.filter(LAreas.area_type.has(BibAreasTypes.type_code.in_(params["type_code"])))
query = query.where(LAreas.area_type.has(BibAreasTypes.type_code.in_(params["type_code"])))

if "area_name" in params:
q = q.filter(LAreas.area_name.ilike("%{}%".format(params.get("area_name")[0])))
query = query.where(LAreas.area_name.ilike("%{}%".format(params.get("area_name")[0])))

limit = int(params.get("limit")[0]) if params.get("limit") else 100

data = q.limit(limit)
data = db.session.scalars(query.limit(limit)).unique().all()

# allow to format response
format = request.args.get("format", default="", type=str)

fields = {"area_type.type_code"}
if format == "geojson":
fields |= {"+geojson_4326"}
data = data.options(undefer("geojson_4326"))
query = query.options(undefer("geojson_4326"))

data = db.session.scalars(query).unique().all()
response = [d.as_dict(fields=fields) for d in data]
if format == "geojson":
# format features as geojson according to standard
Expand All @@ -256,7 +264,7 @@ def get_area_size():
raise BadRequest("Missing 'geometry' in request payload")
geojson = json.dumps(geojson)

query = db.session.query(area_size_func.params(geojson=geojson))
query = select(area_size_func.params(geojson=geojson))

return jsonify(db.session.execute(query).scalar())

Expand All @@ -275,25 +283,23 @@ def get_area_types():
type_code = request.args.get("code")
type_name = request.args.get("name")
sort = request.args.get("sort")
query = db.session.query(BibAreasTypes)
query = select(BibAreasTypes)
# GET ONLY INFO FOR A SPECIFIC CODE
if type_code:
code_exists = (
db.session.query(BibAreasTypes)
.filter(BibAreasTypes.type_code == type_code)
.one_or_none()
)
code_exists = db.session.scalars(
select(BibAreasTypes).where(BibAreasTypes.type_code == type_code)
).scalar_one_or_none()
if not code_exists:
raise BadRequest("This area type code does not exist")
query = query.filter(BibAreasTypes.type_code == type_code)
query = query.where(BibAreasTypes.type_code == type_code)
# FILTER BY NAME
if type_name:
query = query.filter(BibAreasTypes.type_name.ilike("%{}%".format(type_name)))
query = query.where(BibAreasTypes.type_name.ilike("%{}%".format(type_name)))
# SORT
if sort == "asc":
query = query.order_by(asc("type_name"))
if sort == "desc":
query = query.order_by(desc("type_name"))
# FIELDS
fields = ["type_name", "type_code", "id_type"]
return jsonify([d.as_dict(fields=fields) for d in query.all()])
return jsonify([d.as_dict(fields=fields) for d in db.session.scalars(query).all()])
5 changes: 3 additions & 2 deletions src/ref_geo/tests/test_ref_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ref_geo.env import db
from ref_geo.models import BibAreasTypes, LAreas
from sqlalchemy import select


polygon = {
Expand Down Expand Up @@ -42,7 +43,7 @@ def has_french_dem():

@pytest.fixture(scope="function")
def area_commune():
return BibAreasTypes.query.filter_by(type_code="COM").one()
return db.session.execute(select(BibAreasTypes).filter_by(type_code="COM")).scalar_one()


@pytest.mark.usefixtures("client_class", "temporary_transaction")
Expand Down Expand Up @@ -302,7 +303,7 @@ def test_get_areas_as_geojson(self, area_commune):
"""
type_code = area_commune.type_code
id_type = area_commune.id_type
first_comm = LAreas.query.filter(LAreas.id_type == id_type).first()
first_comm = db.session.scalars(db.select(LAreas).where(LAreas.id_type == id_type)).first()
# will test many responses are return
response = self.client.get(
url_for("ref_geo.get_areas"),
Expand Down

0 comments on commit 66b8f12

Please sign in to comment.