
Cannot build quantized int8 models for Phi3 128k models [TensorRT-LLM 0.12.0] #2214

louis845 opened this issue Sep 10, 2024 · 1 comment
Labels: bug (Something isn't working)


System Info

  • CPU: x86_64 (Intel i9)
  • RAM: 128 GB
  • GPU: 1 x RTX A6000
  • Libraries:
    • TensorRT-LLM 0.12.0 (stable)
    • TensorRT 10.3.0
    • transformers 4.42.4
    • CUDA version 12.2 (driver 535.183.01)
    • Python 3.10.12

Who can help?

@Tracin
@ncomly-nvidia
@kaiyux

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce the behavior:

  1. Use examples/quantization/quantize.py to convert Phi3 128k models (mini and medium) to int8 (int8_sq quantization, int8 kv_cache); an example invocation is sketched after these steps
  2. Use trtllm-build to build the engine
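
For reference, the quantization step was run with a command roughly of the following form (the --model_dir path is a placeholder for the locally downloaded Hugging Face checkpoint; the flags follow the examples/quantization/quantize.py interface, and the exact model path and dtype may differ):

python examples/quantization/quantize.py --model_dir ./Phi-3-mini-128k-instruct --dtype float16 --qformat int8_sq --kv_cache_dtype int8 --output_dir ./tempquant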

Expected behavior

The engine builds successfully, as per the Phi3 support matrix in TensorRT-LLM v0.12.0: https://github.com/NVIDIA/TensorRT-LLM/tree/28fb9aacaa5a05494635194a9cbb264da9a744bd/examples/phi

Actual behavior

Errors occur when building the engine with trtllm-build. Command:

trtllm-build --checkpoint_dir ./tempquant --output_dir ./temptrtllm --gemm_plugin auto --max_batch_size 1 --max_input_len 1024 --max_seq_len 2048

Errors:

[TensorRT-LLM] TensorRT-LLM version: 0.12.0
[09/10/2024-16:41:48] [TRT-LLM] [I] Set bert_attention_plugin to auto.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set gpt_attention_plugin to auto.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set gemm_plugin to auto.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set gemm_swiglu_plugin to None.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set fp8_rowwise_gemm_plugin to None.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set nccl_plugin to auto.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set lookup_plugin to None.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set lora_plugin to None.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set moe_plugin to auto.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set context_fmha to True.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set bert_context_fmha_fp32_acc to False.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set paged_kv_cache to True.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set remove_input_padding to True.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set reduce_fusion to False.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set enable_xqa to True.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set tokens_per_block to 64.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set use_paged_context_fmha to False.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set use_fp8_context_fmha to False.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set multiple_profiles to False.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set paged_state to True.
[09/10/2024-16:41:48] [TRT-LLM] [I] Set streamingllm to False.
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.producer = {'name': 'modelopt', 'version': '0.15.1'}
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.residual_mlp = False
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.bias = False
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.rotary_pct = 1.0
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.rank = 0
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.decoder = phi3
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.rmsnorm = True
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.lm_head_bias = False
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.original_max_position_embeddings = 4096
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.longrope_scaling_short_factors = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.01, 1.02, 1.02, 1.04, 1.04, 1.07, 1.07, 1.1, 1.3000000000000003, 1.3000000000000003, 1.5000000000000004, 1.5700000000000005, 1.9000000000000008, 2.3100000000000014, 2.759999999999992, 3.3899999999999784, 3.9399999999999666, 4.009999999999965, 4.289999999999959, 4.349999999999958, 5.349999999999937, 6.659999999999909, 7.029999999999901, 7.51999999999989, 8.00999999999988, 8.249999999999876, 8.279999999999875, 9.629999999999846, 9.89999999999984, 10.589999999999826, 11.049999999999816, 11.7899999999998, 12.189999999999792, 12.889999999999777, 13.129999999999772, 13.16999999999977, 13.20999999999977, 13.479999999999764, 13.539999999999763, 13.779999999999758, 13.929999999999755, 14.429999999999744, 14.759999999999737, 15.149999999999729, 15.419999999999723, 15.53999999999972, 15.659999999999718, 15.749999999999716, 15.759999999999716, 15.799999999999715, 16.05999999999971, 16.079999999999714, 16.11999999999972, 16.11999999999972, 16.18999999999973, 16.31999999999975, 16.539999999999786, 16.799999999999827]
[09/10/2024-16:41:48] [TRT-LLM] [W] Implicitly setting Phi3Config.longrope_scaling_long_factors = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.25, 1.25, 1.5, 2.0, 2.75, 5.75, 5.75, 6.5, 9.25, 11.0, 13.25, 19.25, 19.75, 19.75, 21.25, 21.5, 26.5, 30.0, 33.75, 35.25, 38.5, 42.0, 42.25, 46.0, 47.0, 50.0, 50.5, 51.0, 52.0, 52.75, 53.75, 54.75, 57.0, 57.25, 58.5, 59.25, 59.5, 62.0, 62.5, 62.75, 63.25, 63.25, 63.25, 63.75, 64.0, 64.0, 64.25, 64.5, 64.5, 65.0, 65.0]
[09/10/2024-16:41:48] [TRT-LLM] [I] Compute capability: (8, 6)
[09/10/2024-16:41:48] [TRT-LLM] [I] SM count: 84
[09/10/2024-16:41:48] [TRT-LLM] [I] SM clock: 2100 MHz
[09/10/2024-16:41:48] [TRT-LLM] [I] int4 TFLOPS: 722
[09/10/2024-16:41:48] [TRT-LLM] [I] int8 TFLOPS: 361
[09/10/2024-16:41:48] [TRT-LLM] [I] fp8 TFLOPS: 0
[09/10/2024-16:41:48] [TRT-LLM] [I] float16 TFLOPS: 180
[09/10/2024-16:41:48] [TRT-LLM] [I] bfloat16 TFLOPS: 180
[09/10/2024-16:41:48] [TRT-LLM] [I] float32 TFLOPS: 90
[09/10/2024-16:41:48] [TRT-LLM] [I] Total Memory: 47 GiB
[09/10/2024-16:41:48] [TRT-LLM] [I] Memory clock: 8001 MHz
[09/10/2024-16:41:48] [TRT-LLM] [I] Memory bus width: 384
[09/10/2024-16:41:48] [TRT-LLM] [I] Memory bandwidth: 768 GB/s
[09/10/2024-16:41:48] [TRT-LLM] [I] NVLink is active: False
[09/10/2024-16:41:48] [TRT-LLM] [I] PCIe speed: 2500 Mbps
[09/10/2024-16:41:48] [TRT-LLM] [I] PCIe link width: 8
[09/10/2024-16:41:48] [TRT-LLM] [I] PCIe bandwidth: 2 GB/s
Traceback (most recent call last):
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 647, in from_string
    return RotaryScalingType[s]
  File "/usr/lib/python3.10/enum.py", line 440, in __getitem__
    return cls._member_map_[name]
KeyError: 'su'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/louis_ml/.local/bin/trtllm-build", line 8, in <module>
    sys.exit(main())
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 500, in main
    parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 377, in parallel_build
    passed = build_and_save(rank, rank % workers, ckpt_dir,
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 344, in build_and_save
    engine = build_model(build_config,
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 313, in build_model
    model = model_cls.from_checkpoint(ckpt_dir, config=rank_config)
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/models/modeling_utils.py", line 480, in from_checkpoint
    model = cls(config)
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/models/modeling_utils.py", line 418, in __call__
    obj = type.__call__(cls, *args, **kwargs)
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/models/phi3/model.py", line 232, in __init__
    super().__init__(config, transformer, lm_head)
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/models/modeling_utils.py", line 717, in __init__
    Attention.create_attention_const_params(self, config)
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/layers/attention.py", line 515, in create_attention_const_params
    rotary_embedding_scale_type = RotaryScalingType.from_string(
  File "/home/louis_ml/.local/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 649, in from_string
    raise ValueError(f'Unsupported rotary scaling type: {s}')
ValueError: Unsupported rotary scaling type: su

Additional notes

TensorRT-LLM is run inside a VM with GPU passthrough (Ubuntu host, Ubuntu guest). The models are the latest versions downloaded from Microsoft's official Phi3 collection on Hugging Face, e.g. https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3.

Building and running the models (Phi3 mini 4k, mini 128k, medium 4k, medium 128k) without quantization (dtype bf16 or fp16) works as intended in the same VM environment, while only the 4k models (Phi3 mini 4k and medium 4k) work with int8 quantization. I believe the problem lies in TensorRT-LLM's handling of the RoPE scaling type, as suggested by the error message and by the fact that only the 128k Phi3 models fail with int8 quantization.
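
A minimal sketch of the failing call, based on the traceback above (assuming TensorRT-LLM 0.12.0 is installed): the quantized 128k checkpoint carries the rotary scaling type string "su" (the rope_scaling type name used in the Phi3 HF config), which the enum lookup in tensorrt_llm.functional does not recognize in this release.

from tensorrt_llm.functional import RotaryScalingType

# "su" is the rope_scaling type written into the quantized Phi3 128k
# checkpoint config; this lookup is what raises the error above.
RotaryScalingType.from_string("su")
# ValueError: Unsupported rotary scaling type: su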

I hope this can be resolved, as models with longer contexts are very useful in practice. If possible, please also consider officially supporting int4 quantization for Phi3 models in the future. Thanks!

louis845 added the bug label on Sep 10, 2024
@eoastafurov

Same for me
