Skip to content

Commit

Permalink
Fix device scope issues (#1841)
Browse files Browse the repository at this point in the history
We want to always place tf ops on a GPU device, but this broke.
  • Loading branch information
mattdangerw committed Sep 18, 2024
1 parent 0f35d5e commit f75965f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 9 deletions.
1 change: 1 addition & 0 deletions keras_nlp/src/tokenizers/byte_pair_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def tokenize(self, inputs):
if self.add_prefix_space:
inputs = tf.strings.join([" ", inputs])

inputs = tf.convert_to_tensor(inputs)
unbatched = inputs.shape.rank == 0
if unbatched:
inputs = tf.expand_dims(inputs, 0)
Expand Down
1 change: 1 addition & 0 deletions keras_nlp/src/tokenizers/sentence_piece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def _check_vocabulary(self):
@preprocessing_function
def tokenize(self, inputs):
self._check_vocabulary()
inputs = tf.convert_to_tensor(inputs)
unbatched = inputs.shape.rank == 0
if unbatched:
inputs = tf.expand_dims(inputs, 0)
Expand Down
1 change: 1 addition & 0 deletions keras_nlp/src/tokenizers/word_piece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def _check_vocabulary(self):
@preprocessing_function
def tokenize(self, inputs):
self._check_vocabulary()
inputs = tf.convert_to_tensor(inputs)
unbatched = inputs.shape.rank == 0
pattern = None
if self.split and self.special_tokens_in_strings:
Expand Down
17 changes: 9 additions & 8 deletions keras_nlp/src/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,29 @@ def preprocessing_function(fn):

params = inspect.signature(fn).parameters
accepts_labels = all(k in params for k in ("x", "y", "sample_weight"))
with tf.device("cpu"):
if not accepts_labels:
if not accepts_labels:

@functools.wraps(fn)
def wrapper(self, x, **kwargs):
@functools.wraps(fn)
def wrapper(self, x, **kwargs):
with tf.device("cpu"):
x = convert_preprocessing_inputs(x)
with no_convert_scope():
x = fn(self, x, **kwargs)
return convert_preprocessing_outputs(x)

else:
else:

@functools.wraps(fn)
def wrapper(self, x, y=None, sample_weight=None, **kwargs):
@functools.wraps(fn)
def wrapper(self, x, y=None, sample_weight=None, **kwargs):
with tf.device("cpu"):
x, y, sample_weight = convert_preprocessing_inputs(
(x, y, sample_weight)
)
with no_convert_scope():
x = fn(self, x, y=y, sample_weight=sample_weight, **kwargs)
return convert_preprocessing_outputs(x)

return wrapper
return wrapper


def convert_preprocessing_inputs(x):
Expand Down
15 changes: 14 additions & 1 deletion keras_nlp/src/utils/tensor_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
from keras_nlp.src.utils.tensor_utils import convert_preprocessing_outputs
from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch
from keras_nlp.src.utils.tensor_utils import is_tensor_type
from keras_nlp.src.utils.tensor_utils import preprocessing_function
from keras_nlp.src.utils.tensor_utils import tensor_to_list


class ConvertHelpers(TestCase):
def test_basics(self):
inputs = ops.array([1, 2, 3])
inputs = [1, 2, 3]
# Convert to tf.
outputs = convert_preprocessing_inputs(inputs)
self.assertAllEqual(outputs, ops.array(inputs))
Expand Down Expand Up @@ -92,6 +93,18 @@ def to_list(x):
inputs = tree.flatten(tree.map_structure(to_list, inputs))
self.assertAllEqual(outputs, inputs)

def test_placement(self):
    # Regardless of backend, preprocessing tensors must be pinned to the
    # CPU — verify no tf tensor ever lands on a GPU device.
    @preprocessing_function
    def check_devices(self, inputs):
        for tensor in inputs:
            if isinstance(tensor, tf.Tensor):
                self.assertIn("CPU", tensor.device)
                self.assertNotIn("GPU", tensor.device)
        return inputs

    check_devices(self, ([1, 2, 3], ["foo", "bar"], "foo"))


class TensorToListTest(TestCase):
def test_ragged_input(self):
Expand Down

0 comments on commit f75965f

Please sign in to comment.