Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making datastore batch/transaction more robust to failure. #2303

Merged
merged 1 commit into from
Sep 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions google/cloud/datastore/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,13 @@ def put(self, entity):
:type entity: :class:`google.cloud.datastore.entity.Entity`
:param entity: the entity to be saved.

:raises: ValueError if entity has no key assigned, or if the key's
:raises: :class:`~exceptions.ValueError` if the batch is not in
progress, if entity has no key assigned, or if the key's
``project`` does not match ours.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to put()')

if entity.key is None:
raise ValueError("Entity must have a key")

Expand All @@ -206,9 +210,13 @@ def delete(self, key):
:type key: :class:`google.cloud.datastore.key.Key`
:param key: the key to be deleted.

:raises: ValueError if key is not complete, or if the key's
:raises: :class:`~exceptions.ValueError` if the batch is not in
progress, if key is not complete, or if the key's
``project`` does not match ours.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to delete()')

if key.is_partial:
raise ValueError("Key must be complete")

Expand Down Expand Up @@ -255,7 +263,13 @@ def commit(self):
This is called automatically upon exiting a with statement,
however it can be called explicitly if you don't want to use a
context manager.

:raises: :class:`~exceptions.ValueError` if the batch is not
in progress.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to commit()')

try:
self._commit()
finally:
Expand All @@ -267,12 +281,19 @@ def rollback(self):
Marks the batch as aborted (can't be used again).

Overridden by :class:`google.cloud.datastore.transaction.Transaction`.

:raises: :class:`~exceptions.ValueError` if the batch is not
in progress.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to rollback()')

self._status = self._ABORTED

def __enter__(self):
self._client._push_batch(self)
self.begin()
# NOTE: We make sure begin() succeeds before pushing onto the stack.
self._client._push_batch(self)

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

return self

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/datastore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def put_multi(self, entities):

if not in_batch:
current = self.batch()
current.begin()

for entity in entities:
current.put(entity)
Expand Down Expand Up @@ -384,6 +385,7 @@ def delete_multi(self, keys):

if not in_batch:
current = self.batch()
current.begin()

for key in keys:
current.delete(key)
Expand Down
11 changes: 9 additions & 2 deletions google/cloud/datastore/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class Transaction(Batch):
:param client: the client used to connect to datastore.
"""

_status = None

This comment was marked as spam.


def __init__(self, client):
super(Transaction, self).__init__(client)
self._id = None
Expand Down Expand Up @@ -125,10 +127,15 @@ def begin(self):
statement, however it can be called explicitly if you don't want
to use a context manager.

:raises: :class:`ValueError` if the transaction has already begun.
:raises: :class:`~exceptions.ValueError` if the transaction has
already begun.
"""
super(Transaction, self).begin()
self._id = self.connection.begin_transaction(self.project)
try:
self._id = self.connection.begin_transaction(self.project)
except:
self._status = self._ABORTED
raise

def rollback(self):
"""Rolls back the current transaction.
Expand Down
73 changes: 72 additions & 1 deletion unit_tests/datastore/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,20 @@ def test_put_entity_wo_key(self):
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)

batch.begin()
self.assertRaises(ValueError, batch.put, _Entity())

def test_put_entity_wrong_status(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)
entity = _Entity()
entity.key = _Key('OTHER')

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.put, entity)

def test_put_entity_w_key_wrong_project(self):
_PROJECT = 'PROJECT'
connection = _Connection()
Expand All @@ -78,6 +90,7 @@ def test_put_entity_w_key_wrong_project(self):
entity = _Entity()
entity.key = _Key('OTHER')

batch.begin()

This comment was marked as spam.

This comment was marked as spam.

self.assertRaises(ValueError, batch.put, entity)

def test_put_entity_w_partial_key(self):
Expand All @@ -90,6 +103,7 @@ def test_put_entity_w_partial_key(self):
key = entity.key = _Key(_PROJECT)
key._id = None

batch.begin()
batch.put(entity)

mutated_entity = _mutated_pb(self, batch.mutations, 'insert')
Expand All @@ -113,6 +127,7 @@ def test_put_entity_w_completed_key(self):
entity.exclude_from_indexes = ('baz', 'spam')
key = entity.key = _Key(_PROJECT)

batch.begin()
batch.put(entity)

mutated_entity = _mutated_pb(self, batch.mutations, 'upsert')
Expand All @@ -129,6 +144,17 @@ def test_put_entity_w_completed_key(self):
self.assertTrue(spam_values[2].exclude_from_indexes)
self.assertFalse('frotz' in prop_dict)

def test_delete_wrong_status(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)
key = _Key(_PROJECT)
key._id = None

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.delete, key)

def test_delete_w_partial_key(self):
_PROJECT = 'PROJECT'
connection = _Connection()
Expand All @@ -137,6 +163,7 @@ def test_delete_w_partial_key(self):
key = _Key(_PROJECT)
key._id = None

batch.begin()
self.assertRaises(ValueError, batch.delete, key)

def test_delete_w_key_wrong_project(self):
Expand All @@ -146,6 +173,7 @@ def test_delete_w_key_wrong_project(self):
batch = self._makeOne(client)
key = _Key('OTHER')

batch.begin()
self.assertRaises(ValueError, batch.delete, key)

def test_delete_w_completed_key(self):
Expand All @@ -155,6 +183,7 @@ def test_delete_w_completed_key(self):
batch = self._makeOne(client)
key = _Key(_PROJECT)

batch.begin()
batch.delete(key)

mutated_key = _mutated_pb(self, batch.mutations, 'delete')
Expand All @@ -180,23 +209,43 @@ def test_rollback(self):
_PROJECT = 'PROJECT'
client = _Client(_PROJECT, None)
batch = self._makeOne(client)
self.assertEqual(batch._status, batch._INITIAL)
batch.begin()
self.assertEqual(batch._status, batch._IN_PROGRESS)
batch.rollback()
self.assertEqual(batch._status, batch._ABORTED)

def test_rollback_wrong_status(self):
_PROJECT = 'PROJECT'
client = _Client(_PROJECT, None)
batch = self._makeOne(client)

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.rollback)

def test_commit(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)

self.assertEqual(batch._status, batch._INITIAL)
batch.begin()
self.assertEqual(batch._status, batch._IN_PROGRESS)
batch.commit()
self.assertEqual(batch._status, batch._FINISHED)

self.assertEqual(connection._committed,
[(_PROJECT, batch._commit_request, None)])

def test_commit_wrong_status(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.commit)

def test_commit_w_partial_key_entities(self):
_PROJECT = 'PROJECT'
_NEW_ID = 1234
Expand All @@ -209,6 +258,8 @@ def test_commit_w_partial_key_entities(self):
batch._partial_key_entities.append(entity)

self.assertEqual(batch._status, batch._INITIAL)
batch.begin()
self.assertEqual(batch._status, batch._IN_PROGRESS)
batch.commit()
self.assertEqual(batch._status, batch._FINISHED)

Expand Down Expand Up @@ -295,6 +346,26 @@ def test_as_context_mgr_w_error(self):
self.assertEqual(mutated_entity.key, key._key)
self.assertEqual(connection._committed, [])

def test_as_context_mgr_enter_fails(self):
klass = self._getTargetClass()

class FailedBegin(klass):

def begin(self):
raise RuntimeError

client = _Client(None, None)
self.assertEqual(client._batches, [])

batch = FailedBegin(client)
with self.assertRaises(RuntimeError):
# The context manager will never be entered because
# of the failure.
with batch: # pragma: NO COVER
pass
# Make sure no batch was added.
self.assertEqual(client._batches, [])


class _PathElementPB(object):

Expand Down
3 changes: 3 additions & 0 deletions unit_tests/datastore/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ def __init__(self, client):
from google.cloud.datastore.batch import Batch
self._client = client
self._batch = Batch(client)
self._batch.begin()

def __enter__(self):
self._client._push_batch(self._batch)
Expand All @@ -972,10 +973,12 @@ def __exit__(self, *args):
class _NoCommitTransaction(object):

def __init__(self, client, transaction_id='TRANSACTION'):
from google.cloud.datastore.batch import Batch
from google.cloud.datastore.transaction import Transaction
self._client = client
xact = self._transaction = Transaction(client)
xact._id = transaction_id
Batch.begin(xact)

def __enter__(self):
self._client._push_batch(self._transaction)
Expand Down
25 changes: 22 additions & 3 deletions unit_tests/datastore/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ def test_begin_tombstoned(self):

self.assertRaises(ValueError, xact.begin)

def test_begin_w_begin_transaction_failure(self):
_PROJECT = 'PROJECT'
connection = _Connection(234)
client = _Client(_PROJECT, connection)
xact = self._makeOne(client)

connection._side_effect = RuntimeError
with self.assertRaises(RuntimeError):
xact.begin()

self.assertIsNone(xact.id)
self.assertEqual(connection._begun, _PROJECT)

def test_rollback(self):
_PROJECT = 'PROJECT'
connection = _Connection(234)
Expand Down Expand Up @@ -118,10 +131,10 @@ def test_commit_w_partial_keys(self):
connection._completed_keys = [_make_key(_KIND, _ID, _PROJECT)]
client = _Client(_PROJECT, connection)
xact = self._makeOne(client)
xact.begin()
entity = _Entity()
xact.put(entity)
xact._commit_request = commit_request = object()
xact.begin()
xact.commit()
self.assertEqual(connection._committed,
(_PROJECT, commit_request, 234))
Expand Down Expand Up @@ -176,7 +189,10 @@ def _make_key(kind, id_, project):

class _Connection(object):
_marker = object()
_begun = _rolled_back = _committed = None
_begun = None
_rolled_back = None
_committed = None
_side_effect = None

def __init__(self, xact_id=123):
self._xact_id = xact_id
Expand All @@ -185,7 +201,10 @@ def __init__(self, xact_id=123):

def begin_transaction(self, project):
self._begun = project
return self._xact_id
if self._side_effect is None:
return self._xact_id
else:
raise self._side_effect

def rollback(self, project, transaction_id):
self._rolled_back = project, transaction_id
Expand Down