Skip to content

Commit

Permalink
Merge pull request #76 from voipro/fix-async-during-watch
Browse files Browse the repository at this point in the history
Multiple queued commands during watch
  • Loading branch information
fiorix committed Nov 21, 2014
2 parents d7f13e5 + 076ff53 commit 537836b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
17 changes: 17 additions & 0 deletions tests/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,20 @@ def testRedisWithBulkCommands_hgetall(self):
self.assertEqual({"foo": "bar", "bar": "foo"}, h0)
self.assertEqual(h0, h1)
self.assertEqual(h0, h2)

@defer.inlineCallbacks
def testRedisWithAsyncCommandsDuringWatch(self):
yield self.db.hset(self._KEYS[0], "foo", "bar")
yield self.db.hset(self._KEYS[0], "bar", "foo")

h0 = yield self.db.hgetall(self._KEYS[0])
t = yield self.db.watch(self._KEYS[0])
(h1, h2) = yield defer.gatherResults([
t.hgetall(self._KEYS[0]),
t.hgetall(self._KEYS[0]),
], consumeErrors=True)
yield t.unwatch()

self.assertEqual({"foo": "bar", "bar": "foo"}, h0)
self.assertEqual(h0, h1)
self.assertEqual(h0, h2)
16 changes: 8 additions & 8 deletions txredisapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __init__(self, charset="utf-8", errors="strict"):

self.transactions = 0
self.inTransaction = False
self.inMulti = False
self.unwatch_cc = lambda: ()
self.commit_cc = lambda: ()

Expand Down Expand Up @@ -418,11 +419,7 @@ def handleTransactionData(self, reply):
self.transactions -= len(reply)
if self.transactions == 0:
self.commit_cc()
if self.inTransaction: # watch but no multi: process the reply as usual
f = self.post_proc[1:]
if len(f) == 1 and callable(f[0]):
reply = f[0](reply)
else: # multi: this must be an exec reply
if not self.inTransaction: # multi: this must be an exec reply
tmp = []
for f, v in zip(self.post_proc[1:], reply):
if callable(f):
Expand Down Expand Up @@ -494,7 +491,7 @@ def execute_command(self, *args, **kwargs):
if self.pipelining:
self.pipelined_replies.append(r)

if self.inTransaction:
if self.inMulti:
self.post_proc.append(kwargs.get("post_proc"))
else:
if "post_proc" in kwargs:
Expand Down Expand Up @@ -1376,11 +1373,13 @@ def sort(self, key, start=None, end=None, by=None, get=None,
def _clear_txstate(self):
if self.inTransaction:
self.inTransaction = False
self.inMulti = False
self.factory.connectionQueue.put(self)

def watch(self, keys):
if not self.inTransaction:
self.inTransaction = True
self.inMulti = False
self.unwatch_cc = self._clear_txstate
self.commit_cc = lambda: ()
if isinstance(keys, (str, unicode)):
Expand All @@ -1399,6 +1398,7 @@ def unwatch(self):
# must be executed.
def multi(self, keys=None):
self.inTransaction = True
self.inMulti = True
self.unwatch_cc = lambda: ()
self.commit_cc = self._clear_txstate
if keys is not None:
Expand All @@ -1423,12 +1423,12 @@ def _commit_check(self, response):
return response

def commit(self):
if self.inTransaction is False:
if self.inMulti is False:
raise RedisError("Not in transaction")
return self.execute_command("EXEC").addCallback(self._commit_check)

def discard(self):
if self.inTransaction is False:
if self.inMulti is False:
raise RedisError("Not in transaction")
self.post_proc = []
self.transactions = 0
Expand Down

0 comments on commit 537836b

Please sign in to comment.