Random Replacement #274
Conversation
/gcbrun
Thanks! Round of comments on this!
    "provided."
)

countReplaceOptions = [
Use snake_case here (e.g. `count_replace_options`) — not camelCase — for local variables.
    f"Received: rate={rate}"
)

if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
nit: format this the same as the check lower down in the file.
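As a hedged sketch (plain Python, with argument names taken from the diff above), the validation could be written in the same style as the later check:

```python
def validate_skip_args(skip_list=None, skip_fn=None, skip_py_fn=None):
    # At most one of the three skip arguments may be provided; fewer than
    # two `None`s means the caller passed more than one.
    if [skip_list, skip_fn, skip_py_fn].count(None) < 2:
        raise ValueError(
            "Exactly one of `skip_list`, `skip_fn`, `skip_py_fn` may be "
            "provided."
        )
```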
    num_to_select, self.max_replacements
)
num_to_select = tf.math.minimum(
    num_to_select, tf.cast(positions.row_lengths(), tf.int32)
This should not be needed — how could the binomial sample exceed `positions.row_lengths()` when those lengths are given as the counts?
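The reviewer's point can be checked directly: a binomial sample is bounded above by its count parameter, so clamping against the row lengths is redundant when those lengths are the counts. A quick illustration (NumPy used here only for brevity, in place of the TF sampling op):

```python
import numpy as np

rng = np.random.default_rng(0)
row_lengths = np.array([3, 5, 2, 8])
# Sample how many positions to replace per row; the counts are the row lengths.
samples = rng.binomial(n=row_lengths, p=0.5)
# Each binomial draw is at most its count, so no extra minimum is needed.
assert np.all(samples <= row_lengths)
```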
# Convert to ragged tensor.
inputs = tf.RaggedTensor.from_tensor(inputs)

skip_masks = None
Maybe, in the interest of readability, pull this whole block into a private `_generate_skip_mask()` method.
        seed=self._generator.make_seeds()[:, 0],
    )
]
synonym = inputs[index]
Careful not to let this name infer the use case too much. Choose names like `original_token` and `replacement_token` that don't assume word-level or character-level usage.
]
synonym = inputs[index]

if self.replacement_fn is not None:
Can we pull out a private method `_replace_token()` to make this more readable?
if self.replacement_fn is not None:
    synonym = self.replacement_fn(synonym)
inputs = tf.tensor_scatter_nd_update(
I'm skeptical we really need a `map_fn` with a loop inside it, with a scatter inside it. That sounds like it will be inefficient. If we follow the flow we have in random deletion, we will have a complete mask containing only the indices we want to run a replacement on, right? Then we could do something like run a `tf.map_fn` over the pairs of `(inputs.flat_values, mask.flat_values)` and early return if `mask == False`; otherwise we look up a replacement. There might be other ways to do this, but overall we should:
- avoid nested maps/loops
- avoid scatters inside a map/loop
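A rough sketch of the mask-based flow the review describes, over already-flattened values (NumPy stands in for the TF ops; the replacement lookup is a placeholder):

```python
import numpy as np

flat_values = np.array([1, 2, 3, 4, 5])
# Complete mask: True only at the indices selected for replacement.
mask = np.array([True, False, True, False, True])
# Compute a candidate replacement for every position (placeholder lookup),
# then select with the mask -- no nested loop, no scatter.
replacements = flat_values * 10
outputs = np.where(mask, replacements, flat_values)
# outputs -> [10, 2, 30, 4, 50]
```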
PR for Random Replacement Layer