Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve type annotations for execute_values. #12311

Merged
merged 6 commits into from
Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12311.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type annotations for `execute_values`.
17 changes: 7 additions & 10 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,15 @@ def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch

self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
self._do_execute(
lambda the_sql: execute_batch(self.txn, the_sql, args), sql
)
else:
self.executemany(sql, args)

def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
def execute_values(
self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.

Expand All @@ -305,15 +309,8 @@ def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple
from psycopg2.extras import execute_values

return self._do_execute(
# Type ignore: mypy is unhappy because if `x` is a 5-tuple, then there will
# be two values for `fetch`: one given positionally, and another given
# as a keyword argument. We might be able to fix this by
# - propagating the signature of psycopg2.extras.execute_values to this
# function, or
# - changing `*args: Any` to `values: T` for some appropriate T.
lambda *x: execute_values(self.txn, *x, fetch=fetch), # type: ignore[misc]
lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
sql,
*args,
)

def execute(self, sql: str, *args: Any) -> None:
Expand Down