Skip to content

Commit

Permalink
chore: Better random bool values for example_tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed May 20, 2024
1 parent 00d1d85 commit fc6499e
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,13 @@ def example_tensor(
)

if isinstance(self.shape, dict):
return torch.rand(self.shape[optimization_profile_field]).to(
dtype=self.dtype.to(torch.dtype, use_default=True)
)
if self.dtype == dtype.b:
return torch.rand(self.shape[optimization_profile_field]) < 0.5
else:
return torch.rand(self.shape[optimization_profile_field]).to(
dtype=self.dtype.to(torch.dtype, use_default=True)
)

else:
raise RuntimeError(
f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
Expand Down

0 comments on commit fc6499e

Please sign in to comment.