Skip to content

Commit

Permalink
fix(firestore): fix get and getall method of transaction (apache#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
HemangChothani authored Feb 21, 2020
1 parent 3a37ce9 commit de3aca0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions google/cloud/firestore_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def get_all(self, references):
.DocumentSnapshot: The next document snapshot that fulfills the
query, or :data:`None` if the document does not exist.
"""
return self._client.get_all(references, transaction=self._id)
return self._client.get_all(references, transaction=self)

def get(self, ref_or_query):
"""
Expand All @@ -225,9 +225,9 @@ def get(self, ref_or_query):
query, or :data:`None` if the document does not exist.
"""
if isinstance(ref_or_query, DocumentReference):
return self._client.get_all([ref_or_query], transaction=self._id)
return self._client.get_all([ref_or_query], transaction=self)
elif isinstance(ref_or_query, Query):
return ref_or_query.stream(transaction=self._id)
return ref_or_query.stream(transaction=self)
else:
raise ValueError(
'Value for argument "ref_or_query" must be a DocumentReference or a Query.'
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/v1/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_get_all(self):
transaction = self._make_one(client)
ref1, ref2 = mock.Mock(), mock.Mock()
result = transaction.get_all([ref1, ref2])
client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction.id)
client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction)
self.assertIs(result, client.get_all.return_value)

def test_get_document_ref(self):
Expand All @@ -343,7 +343,7 @@ def test_get_document_ref(self):
transaction = self._make_one(client)
ref = DocumentReference("documents", "doc-id")
result = transaction.get(ref)
client.get_all.assert_called_once_with([ref], transaction=transaction.id)
client.get_all.assert_called_once_with([ref], transaction=transaction)
self.assertIs(result, client.get_all.return_value)

def test_get_w_query(self):
Expand All @@ -354,7 +354,7 @@ def test_get_w_query(self):
query = Query(parent=mock.Mock(spec=[]))
query.stream = mock.MagicMock()
result = transaction.get(query)
query.stream.assert_called_once_with(transaction=transaction.id)
query.stream.assert_called_once_with(transaction=transaction)
self.assertIs(result, query.stream.return_value)

def test_get_failure(self):
Expand Down

0 comments on commit de3aca0

Please sign in to comment.