
Random Replacement #274

Closed · wants to merge 21 commits

Conversation

@aflah02 (Collaborator) commented Jul 20, 2022

PR for Random Replacement Layer

@chenmoneygithub (Contributor) commented:
/gcbrun

@mattdangerw mattdangerw self-requested a review September 7, 2022 17:02
@mattdangerw (Member) left a comment:

Thanks! A round of comments on this!

"provided."
)

countReplaceOptions = [
@mattdangerw (Member) commented:

nit: use under_score naming, not camelCase, for local variables.

f"Received: rate={rate}"
)

if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
@mattdangerw (Member) commented:

nit: format this the same as the lower check.

    num_to_select, self.max_replacements
)
num_to_select = tf.math.minimum(
    num_to_select, tf.cast(positions.row_lengths(), tf.int32)
)

@mattdangerw (Member) commented:

This second minimum should not be needed; how could the binomial sample exceed the positions.row_lengths() given as its input?

# Convert to ragged tensor.
inputs = tf.RaggedTensor.from_tensor(inputs)

skip_masks = None
@mattdangerw (Member) commented:

Maybe, in the interest of readability, pull this whole block into a private _generate_skip_mask() method.

seed=self._generator.make_seeds()[:, 0],
)
]
synonym = inputs[index]
@mattdangerw (Member) commented:

Careful not to give this a name that infers the use case too much. Choose names like original_token and replacement_token that don't assume word-level or character-level usage.

]
synonym = inputs[index]

if self.replacement_fn is not None:
@mattdangerw (Member) commented:

Can we pull out a private method _replace_token() to make this more readable?


if self.replacement_fn is not None:
synonym = self.replacement_fn(synonym)
inputs = tf.tensor_scatter_nd_update(
@mattdangerw (Member) commented:

I'm skeptical we really need a map_fn with a loop inside it, with a scatter inside that. That sounds like it will be inefficient.

If we follow the flow we have in random deletion, we will have a complete mask containing only the indices we want to replace, right? Then we could do something like run a tf.map_fn over the pair of (inputs.flat_values, mask.flat_values) and return the token unchanged if mask == False, otherwise look up a replacement.

There might be other ways to do this, but overall we should:

  • avoid nested maps/loops
  • avoid scatters inside a map/loop

@mattdangerw mattdangerw self-assigned this Mar 22, 2023