Skip to content

Commit

Permalink
Add for_update clause
Browse files Browse the repository at this point in the history
  • Loading branch information
dkopitsa committed Aug 8, 2024
1 parent 872d58f commit f3700b6
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 0 deletions.
62 changes: 62 additions & 0 deletions docs/src/piccolo/query_clauses/for_update.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
.. _limit:

for_update
=====

You can use ``for_update`` clauses with the following queries:

* :ref:`Objects`
* :ref:`Select`

Returns a query that will lock rows until the end of the transaction, generating a SELECT ... FOR UPDATE SQL statement.

.. note:: Postgres and CockroachDB only.

-------------------------------------------------------------------------------

default
~~~~~~~
To use select for update without extra parameters. All matched rows will be locked until the end of transaction.

.. code-block:: python
await Band.select(Band.name == 'Pythonistas').for_update()
equals to:

.. code-block:: sql
SELECT ... FOR UPDATE
nowait
~~~~~~~
If another transaction has already acquired a lock on one or more selected rows, the exception will be raised instead of waiting for another transaction


.. code-block:: python
await Band.select(Band.name == 'Pythonistas').for_update(nowait=True)
skip_locked
~~~~~~~
Ignore locked rows

.. code-block:: python
await Band.select(Band.name == 'Pythonistas').for_update(skip_locked=True)
of
~~~~~~~
By default, if there are many tables in query (e.x. when joining), all tables will be locked.
with `of` you can specify tables, which should be locked.


.. code-block:: python
await Band.select().where(Band.manager.name == 'Guido').for_update(of=(Band, ))
1 change: 1 addition & 0 deletions docs/src/piccolo/query_clauses/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ by modifying the return values.
./on_conflict
./output
./returning
./for_update

.. toctree::
:maxdepth: 1
Expand Down
10 changes: 10 additions & 0 deletions piccolo/query/methods/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AsOfDelegate,
CallbackDelegate,
CallbackType,
ForUpdateDelegate,
LimitDelegate,
OffsetDelegate,
OrderByDelegate,
Expand Down Expand Up @@ -194,6 +195,7 @@ class Objects(
"callback_delegate",
"prefetch_delegate",
"where_delegate",
"for_update_delegate",
)

def __init__(
Expand All @@ -213,6 +215,7 @@ def __init__(
self.prefetch_delegate = PrefetchDelegate()
self.prefetch(*prefetch)
self.where_delegate = WhereDelegate()
self.for_update_delegate = ForUpdateDelegate()

def output(self: Self, load_json: bool = False) -> Self:
self.output_delegate.output(
Expand Down Expand Up @@ -272,6 +275,12 @@ def first(self) -> First[TableInstance]:
self.limit_delegate.limit(1)
return First[TableInstance](query=self)

def for_update(
self: Self, nowait: bool = False, skip_locked: bool = False, of=()
) -> Self:
self.for_update_delegate.for_update(nowait, skip_locked, of)
return self

def get(self, where: Combinable) -> Get[TableInstance]:
self.where_delegate.where(where)
self.limit_delegate.limit(1)
Expand Down Expand Up @@ -322,6 +331,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]:
"offset_delegate",
"output_delegate",
"order_by_delegate",
"for_update_delegate",
):
setattr(select, attr, getattr(self, attr))

Expand Down
17 changes: 17 additions & 0 deletions piccolo/query/methods/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CallbackType,
ColumnsDelegate,
DistinctDelegate,
ForUpdateDelegate,
GroupByDelegate,
LimitDelegate,
OffsetDelegate,
Expand Down Expand Up @@ -150,6 +151,7 @@ class Select(Query[TableInstance, t.List[t.Dict[str, t.Any]]]):
"output_delegate",
"callback_delegate",
"where_delegate",
"for_update_delegate",
)

def __init__(
Expand All @@ -174,6 +176,7 @@ def __init__(
self.output_delegate = OutputDelegate()
self.callback_delegate = CallbackDelegate()
self.where_delegate = WhereDelegate()
self.for_update_delegate = ForUpdateDelegate()

self.columns(*columns_list)

Expand Down Expand Up @@ -219,6 +222,12 @@ def offset(self: Self, number: int) -> Self:
self.offset_delegate.offset(number)
return self

def for_update(
self: Self, nowait: bool = False, skip_locked: bool = False, of=()
) -> Self:
self.for_update_delegate.for_update(nowait, skip_locked, of)
return self

async def _splice_m2m_rows(
self,
response: t.List[t.Dict[str, t.Any]],
Expand Down Expand Up @@ -618,6 +627,14 @@ def default_querystrings(self) -> t.Sequence[QueryString]:
query += "{}"
args.append(self.offset_delegate._offset.querystring)

if engine_type == "sqlite" and self.for_update_delegate._for_update:
raise NotImplementedError(
"SQLite doesn't support SELECT .. FOR UPDATE"
)

if self.for_update_delegate._for_update:
args.append(self.for_update_delegate._for_update.querystring)

querystring = QueryString(query, *args)

return [querystring]
Expand Down
48 changes: 48 additions & 0 deletions piccolo/query/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,51 @@ def on_conflict(
target=target, action=action_, values=values, where=where
)
)


@dataclass
class ForUpdate:
__slots__ = ("nowait", "skip_locked", "of")

nowait: bool
skip_locked: bool
of: tuple[Table]

def __post_init__(self):
if not isinstance(self.nowait, bool):
raise TypeError("nowait must be an integer")
if not isinstance(self.skip_locked, bool):
raise TypeError("skip_locked must be an integer")
if not isinstance(self.of, tuple) or not all(
hasattr(x, "_meta") for x in self.of
):
raise TypeError("of must be an tuple of Table")
if self.nowait and self.skip_locked:
raise TypeError(
"The nowait option cannot be used with skip_locked"
)

@property
def querystring(self) -> QueryString:
sql = " FOR UPDATE"
if self.of:
tables = ", ".join(x._meta.tablename for x in self.of)
sql += " OF " + tables
if self.nowait:
sql += " NOWAIT"
if self.skip_locked:
sql += " SKIP LOCKED"

return QueryString(sql)

def __str__(self) -> str:
return self.querystring.__str__()


@dataclass
class ForUpdateDelegate:

_for_update: t.Optional[ForUpdate] = None

def for_update(self, nowait=False, skip_locked=False, of=()):
self._for_update = ForUpdate(nowait, skip_locked, of)
27 changes: 27 additions & 0 deletions tests/table/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,33 @@ def test_select_raw(self):
response, [{"name": "Pythonistas", "popularity_log": 3.0}]
)

def test_for_update(self):
"""
Make sure the for_update clause works.
"""
self.insert_rows()

query = Band.select()
self.assertNotIn("FOR UPDATE", query.__str__())

query = query.for_update()
self.assertTrue(query.__str__().endswith("FOR UPDATE"))

query = query.for_update(skip_locked=True)
self.assertTrue(query.__str__().endswith("FOR UPDATE SKIP LOCKED"))

query = query.for_update(nowait=True)
self.assertTrue(query.__str__().endswith("FOR UPDATE NOWAIT"))

query = query.for_update(of=(Band,))
self.assertTrue(query.__str__().endswith("FOR UPDATE OF band"))

with self.assertRaises(TypeError):
query = query.for_update(skip_locked=True, nowait=True)

response = query.run_sync()
assert response is not None


class TestSelectSecret(TestCase):
def setUp(self):
Expand Down

0 comments on commit f3700b6

Please sign in to comment.