[coverage] Value Error for aten._native_batch_norm_legit_no_training.default #3180

Open
Tracked by #3179
chohk88 opened this issue Sep 25, 2024 · 0 comments

@chohk88
Collaborator

chohk88 commented Sep 25, 2024

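Torch-TensorRT conversion of `aten._native_batch_norm_legit_no_training.default` fails while compiling a MONAI UNet (FP32, dynamic batch dimension) through Model Navigator's `ep2torchtrt` converter. TensorRT rejects the `shift` weights of the Scale layer emitted for the first batch norm (`/model/0/conv_unit0_adn_N`): 16 values are supplied but 64 are expected. The 16 matches the channel count of the 5-D batch-norm input `(s0, 16, 64, 64, 64)`, while the 64 matches its spatial extent. Because that layer's output shape cannot be computed, the failure cascades to the downstream PReLU and convolution nodes, and the converter for the next batch norm (`_native_batch_norm_legit_no_training_1`) finally aborts with `ValueError: __len__() should return >= 0` at `if len(input.shape) < 4:`. Both conversion attempts in the log (Process-5 and Process-6) fail the same way.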
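One plausible reading of the 16-vs-64 mismatch, offered as an assumption rather than a confirmed diagnosis: as we understand the TensorRT documentation, a scale layer created with `add_scale` treats the third-to-last dimension as the channel axis, whereas `add_scale_nd` takes an explicit channel axis. For a 5-D NCDHW tensor the third-to-last dimension is D, not C. The sketch below only does that arithmetic on the shape reported in the log; `expected_shift_count` is a hypothetical helper, not a TensorRT API, and the batch size of 2 stands in for the dynamic `s0`.

```python
# Batch-norm input shape from the log: (batch, C, D, H, W).
# The batch size is dynamic (s0); 2 is used here as a stand-in.
INPUT_SHAPE = (2, 16, 64, 64, 64)


def expected_shift_count(shape, channel_axis):
    """Number of per-channel shift values a CHANNEL-mode scale layer would
    expect if it treated `channel_axis` as the channel dimension.
    Hypothetical helper for illustration only."""
    return shape[channel_axis]


# Batch norm normalizes over dim 1 (C), so the converter supplies C shift values.
supplied = expected_shift_count(INPUT_SHAPE, channel_axis=1)

# Assumption: the layer used the "third-to-last dimension is the channel" convention.
expected = expected_shift_count(INPUT_SHAPE, channel_axis=-3)

print(f"supplied={supplied}, expected={expected}")

# For a 4-D NCHW tensor, axis -3 *is* C, so both conventions agree.
NCHW = (2, 16, 64, 64)
assert expected_shift_count(NCHW, 1) == expected_shift_count(NCHW, -3)
```

If this reading is right, the two conventions only diverge for 5-D (3-D volumetric) inputs such as this UNet's, which would explain why 4-D image models do not hit the same error.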

2024-08-30 23:23:57.723 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - Error Code: 3: [SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/model/0/conv_unit0_adn_N/_native_batch_norm_legit_no_training]:shift weights has count 16 but 64 was expected
2024-08-30 23:23:57.754 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error ([SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/model/0/conv_unit0_adn_N/_native_batch_norm_legit_no_training]: invalid `shift` weights, check IErrorRecorder for details.)
2024-08-30 23:23:57.771 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/model/0/conv_unit0_adn_N/_native_batch_norm_legit_no_training].)
2024-08-30 23:23:57.789 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [PARAMETRIC_RELU]-[aten_ops._prelu_kernel.default]-[/model/0/conv_unit0_adn_A/_prelu_kernel].)
2024-08-30 23:23:57.807 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[/model/0/conv_unit1_conv/convolution_2].)
2024-08-30 23:23:57.824 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[/model/0/conv_unit1_conv/convolution_2].)
2024-08-30 23:23:57.840 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[/model/0/conv_unit1_conv/convolution_2].)
2024-08-30 23:23:57.841 | ERROR    | Process-5 | /usr/local/lib/python3.10/dist-packages/model_navigator/commands/execution_context.py:164 - Command exited with error: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/model_navigator/commands/execution_context.py", line 156, in _execute_function
    fire.Fire(func, unwrapped_args)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/model_navigator/commands/convert/converters/ep2torchtrt.py", line 129, in convert
    tr_model_compiled = torch_tensorrt.dynamo.compile(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 243, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 431, in compile_module
    trt_module = convert_module(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 107, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 88, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 336, in run
    self._construct_trt_network_def()
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 317, in _construct_trt_network_def
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 378, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 493, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 111, in aten_ops_batch_norm_legit_no_training
    return impl.normalization.batch_norm(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py", line 65, in batch_norm
    if len(input.shape) < 4:
ValueError: __len__() should return >= 0
While executing %_native_batch_norm_legit_no_training_1 : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%convolution_2, %model_0_conv_unit1_adn_n_weight, %model_0_conv_unit1_adn_n_bias, %model_0_conv_unit1_adn_n_running_mean, %model_0_conv_unit1_adn_n_running_var, 0.1, 1e-05), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f3a216df4b0>: ((s0, 1, 128, 128, 128), torch.float32, False, (2097152, 2097152, 16384, 128, 1), torch.channels_last_3d, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f3a21624670>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f3a21543cf0>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f3a21624db0>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f3a21625db0>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f3a216e3430>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f3a215b0270>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {})}})
Original traceback:
  File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/unet.py", line 300, in forward
    x = self.model(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1725, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/monai/networks/blocks/convolutions.py", line 317, in forward
    cx: torch.Tensor = self.conv(x)  # apply x to sequence of operations
. Command to reproduce error: /bin/bash torch-trt-fp32/reproduce_conversion.sh
2024-08-30 23:23:57.845 | WARNING  | Process-5 | /usr/lib/python3.10/warnings.py:109 - /usr/lib/python3.10/tempfile.py:1008: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmprtco2y3s'>
  _warnings.warn(warn_message, ResourceWarning)
2024-08-30 23:23:59.935 | INFO     | MainProcess | /usr/local/lib/python3.10/dist-packages/model_navigator/commands/execution_context.py:218 - Command: /usr/bin/python torch-trt-fp32/reproduce_conversion.py --exported_model_path 'torch-exportedprogram/model.pt2' --converted_model_path 'torch-trt-fp32/model.ep' --input_metadata '{"metadata": [{"name": "input__0", "shape": (-1, 1, 128, 128, 128), "dtype": "float32"}], "pytree_metadata": {"metadata": ("input__0", {}), "tensor_type": "torch"}, "is_legacy": False}' --shapes '{"input__0": {"min": (1, 1, 128, 128, 128), "opt": (2, 1, 128, 128, 128), "max": (2, 1, 128, 128, 128)}}' --batch_dim '0' --max_workspace_size '8589934592' --precision 'fp32' --precision_mode 'hierarchy' --target_device 'cuda' --custom_args '{}' --debug 'False'
/usr/local/lib/python3.10/dist-packages/modelopt/torch/quantization/tensor_quant.py:92: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  scaled_e4m3_abstract = torch.library.impl_abstract("trt::quantize_fp8")(
[WARNING  | py.warnings        ]: /usr/local/lib/python3.10/dist-packages/tritonclient/grpc/service_pb2_grpc.py:21: RuntimeWarning: The grpc package installed is at version 1.62.1, but the generated code in grpc_service_pb2_grpc.py depends on grpcio>=1.65.5. Please upgrade your grpc module to grpcio>=1.65.5 or downgrade your generated code using grpcio-tools<=1.62.1. This warning will become an error in 1.66.0, scheduled for release on August 6, 2024.
  warnings.warn(
2024-08-30 23:24:04.941 | WARNING  | Process-6 | /usr/lib/python3.10/warnings.py:109 - /usr/local/lib/python3.10/dist-packages/torch/_export/serde/serialize.py:319: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  artifact = torch.load(buffer)
2024-08-30 23:24:09.643 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - Error Code: 3: [SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/model/0/conv_unit0_adn_N/_native_batch_norm_legit_no_training]:shift weights has count 16 but 64 was expected
2024-08-30 23:24:09.672 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error ([SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/model/0/conv_unit0_adn_N/_native_batch_norm_legit_no_training]: invalid `shift` weights, check IErrorRecorder for details.)
2024-08-30 23:24:09.688 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/model/0/conv_unit0_adn_N/_native_batch_norm_legit_no_training].)
2024-08-30 23:24:09.705 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [PARAMETRIC_RELU]-[aten_ops._prelu_kernel.default]-[/model/0/conv_unit0_adn_A/_prelu_kernel].)
2024-08-30 23:24:09.722 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[/model/0/conv_unit1_conv/convolution_2].)
2024-08-30 23:24:09.738 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[/model/0/conv_unit1_conv/convolution_2].)
2024-08-30 23:24:09.753 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/torch_tensorrt/logging.py:24 - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[/model/0/conv_unit1_conv/convolution_2].)
2024-08-30 23:24:09.754 | ERROR    | Process-6 | /usr/local/lib/python3.10/dist-packages/model_navigator/commands/execution_context.py:164 - Command exited with error: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/model_navigator/commands/execution_context.py", line 156, in _execute_function
    fire.Fire(func, unwrapped_args)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/model_navigator/commands/convert/converters/ep2torchtrt.py", line 129, in convert
    tr_model_compiled = torch_tensorrt.dynamo.compile(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 243, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 431, in compile_module
    trt_module = convert_module(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 107, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 88, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 336, in run
    self._construct_trt_network_def()
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 317, in _construct_trt_network_def
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 378, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 493, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 111, in aten_ops_batch_norm_legit_no_training
    return impl.normalization.batch_norm(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py", line 65, in batch_norm
    if len(input.shape) < 4:
ValueError: __len__() should return >= 0
While executing %_native_batch_norm_legit_no_training_1 : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%convolution_2, %model_0_conv_unit1_adn_n_weight, %model_0_conv_unit1_adn_n_bias, %model_0_conv_unit1_adn_n_running_mean, %model_0_conv_unit1_adn_n_running_var, 0.1, 1e-05), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f8711645cb0>: ((s0, 1, 128, 128, 128), torch.float32, False, (2097152, 2097152, 16384, 128, 1), torch.channels_last_3d, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f8711767c30>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f8711506cf0>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f8711616970>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f87117d0a70>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f8711507170>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {}), <tensorrt.tensorrt.ITensor object at 0x7f87116da1f0>: ((s0, 16, 64, 64, 64), torch.float32, False, (4194304, 262144, 4096, 64, 1), torch.contiguous_format, False, {})}})
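The final `ValueError` looks like a secondary symptom rather than the root cause. CPython raises `ValueError: __len__() should return >= 0` whenever an object's `__len__` returns a negative number. Our assumption is that once TensorRT cannot compute the upstream output shape (the `ITensor::getDimensions` errors in the log), `input.shape` reports a negative rank, so `len(input.shape)` in `batch_norm` raises instead of returning one. The sketch reproduces that Python behavior with a hypothetical `InvalidDims` stand-in (not the real `tensorrt.Dims`) and shows a hypothetical guard, `rank_or_raise`, that would surface a clearer message.

```python
class InvalidDims:
    """Hypothetical stand-in for a shape whose rank could not be computed.
    Assumption: it reports a rank of -1, like an invalid TensorRT shape."""

    def __len__(self):
        return -1


def rank_or_raise(shape, node_name):
    """Hypothetical guard: return the rank of `shape`, or raise an error that
    names the node instead of Python's generic __len__ message."""
    try:
        return len(shape)
    except ValueError as err:
        raise RuntimeError(
            f"{node_name}: input shape is unknown; an upstream layer likely "
            f"failed shape inference ({err})"
        ) from err


# Reproduce the message seen in the traceback.
message = None
try:
    len(InvalidDims())
except ValueError as err:
    message = str(err)
print(message)  # __len__() should return >= 0

# A valid shape passes straight through.
assert rank_or_raise((2, 16, 64, 64, 64), "batch_norm") == 5

# An invalid shape produces an error that points at the failing node.
try:
    rank_or_raise(InvalidDims(), "_native_batch_norm_legit_no_training_1")
except RuntimeError as err:
    print(err)
```

A guard like this would not fix the conversion; it would only make the log point back at the first failing layer instead of at the batch norm that happens to read the broken shape.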