
[BUG]RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #370

Closed
dlutsniper opened this issue Oct 17, 2023 · 15 comments
Labels
bug Something isn't working

Comments

@dlutsniper

Describe the bug
Tried the quick-start demo code to quantize a finetuned Qwen model, but it fails with RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Hardware details
GPU: RTX A40 48GB
CPU: 15 vCPU, Memory: 80GB

Software version
OS: Ubuntu 22.04.1 LTS
cuda: 11.8
python: 3.10.8
auto-gptq 0.4.2
transformers 4.32.0
torch 2.1.0
accelerate 0.23.0

To Reproduce

from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "/root/autodl-tmp/newmodel_14b"
quantized_model_dir = "/root/autodl-tmp/newmodel_14b_int4"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, trust_remote_code=True)
examples = [
    tokenizer(
        "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
    , return_tensors="pt")
]

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize the model to 4-bit
    group_size=128,  # 128 is the generally recommended value for this parameter
    desc_act=False,  # set to False to significantly speed up inference, at the cost of slightly worse perplexity
)

# load the un-quantized model; by default the model is always loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config,
                                            trust_remote_code=True, device_map="cuda:0")

# quantize the model; examples should be a List[Dict] whose dicts contain exactly the keys input_ids and attention_mask
model.quantize(examples)

# save the quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)

Expected behavior
Quantization completes and produces a quantized model, as shown in the demo.

Screenshots
(screenshot omitted)

Additional context
log:

Warning: please make sure that you are using the latest codes and checkpoints, especially if you used Qwen-7B before 09.25.2023.请使用最新模型和代码,尤其如果你在9月25日前已经开始使用Qwen-7B,千万注意不要使用错误代码和模型。
Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary
Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm
Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention

Loading checkpoint shards: 100%
15/15 [00:03<00:00, 4.34it/s]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 25
     21 model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config,
     22                                             trust_remote_code=True, device_map="cuda:0")
     24 # 量化模型, 样本的数据类型应该为 List[Dict],其中字典的键有且仅有 input_ids 和 attention_mask
---> 25 model.quantize(examples)
     27 # 使用 safetensors 保存量化好的模型
     28 model.save_quantized(quantized_model_dir, use_safetensors=True)

File ~/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/auto_gptq/modeling/_base.py:359, in BaseGPTQForCausalLM.quantize(self, examples, batch_size, use_triton, use_cuda_fp16, autotune_warmup_after_quantized, cache_examples_on_gpu)
    357         else:
    358             additional_layer_inputs[k] = v
--> 359     layer(layer_input, **additional_layer_inputs)
    360 for h in handles:
    361     h.remove()

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.cache/huggingface/modules/transformers_modules/newmodel_14b/modeling_qwen.py:653, in QWenBlock.forward(self, hidden_states, rotary_pos_emb_list, registered_causal_mask, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)
    638 def forward(
    639     self,
    640     hidden_states: Optional[Tuple[torch.FloatTensor]],
   (...)
    649     output_attentions: Optional[bool] = False,
    650 ):
    651     layernorm_output = self.ln_1(hidden_states)
--> 653     attn_outputs = self.attn(
    654         layernorm_output,
    655         rotary_pos_emb_list,
    656         registered_causal_mask=registered_causal_mask,
    657         layer_past=layer_past,
    658         attention_mask=attention_mask,
    659         head_mask=head_mask,
    660         use_cache=use_cache,
    661         output_attentions=output_attentions,
    662     )
    663     attn_output = attn_outputs[0]
    665     outputs = attn_outputs[1:]

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.cache/huggingface/modules/transformers_modules/newmodel_14b/modeling_qwen.py:482, in QWenAttention.forward(self, hidden_states, rotary_pos_emb_list, registered_causal_mask, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions, use_cache)
    480     q_pos_emb, k_pos_emb = rotary_pos_emb
    481     # Slice the pos emb for current inference
--> 482     query = apply_rotary_pos_emb(query, q_pos_emb)
    483     key = apply_rotary_pos_emb(key, k_pos_emb)
    484 else:

File ~/.cache/huggingface/modules/transformers_modules/newmodel_14b/modeling_qwen.py:1410, in apply_rotary_pos_emb(t, freqs)
   1408 t_ = t_.float()
   1409 t_pass_ = t_pass_.float()
-> 1410 t_ = (t_ * cos) + (_rotate_half(t_) * sin)
   1411 return torch.cat((t_, t_pass_), dim=-1).type_as(t)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
@dlutsniper dlutsniper added the bug Something isn't working label Oct 17, 2023
@wangitu

wangitu commented Oct 17, 2023

The model.quantize function contains the following code:

if isinstance(v, torch.Tensor):
    one_kwargs[k] = move_to_device(v, self.data_device)
else:
    one_kwargs[k] = v

But for some models, e.g. Qwen, one of the arguments that QWenBlock accepts in forward is rotary_pos_emb_list, which is a List[torch.Tensor]. The code above does not move a List[torch.Tensor] onto CUDA, yet the computation runs on CUDA, hence the cuda-vs-cpu error.

The simplest workaround is to define a nested_move_to_device:

def nested_move_to_device(v, device):
    if isinstance(v, torch.Tensor):
        return move_to_device(v, device)
    elif isinstance(v, (list, tuple)):
        return type(v)([nested_move_to_device(e, device) for e in v])
    else:
        return v

and use it in place of the original code; that resolves the problem.

For reference, see this repo: https://github.com/wangitu/unpadded-AutoGPTQ
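A small self-contained check of the idea (illustration only: it uses plain tensor.to() instead of auto_gptq's move_to_device, and the rotary_pos_emb_list / attention_mask kwargs are just mimicked with toy tensors):

import torch

def nested_move_to_device(v, device):
    # Same helper as above, but using tensor.to() so the snippet runs standalone.
    if isinstance(v, torch.Tensor):
        return v.to(device)
    if isinstance(v, (list, tuple)):
        return type(v)(nested_move_to_device(e, device) for e in v)
    return v

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Mimic the extra kwargs a QWenBlock.forward receives during quantization:
additional_layer_inputs = {
    "rotary_pos_emb_list": [torch.randn(1, 32, 1, 64), torch.randn(1, 32, 1, 64)],
    "attention_mask": torch.ones(1, 1, 32, 32, dtype=torch.bool),
}
moved = {k: nested_move_to_device(v, device) for k, v in additional_layer_inputs.items()}
print(moved["rotary_pos_emb_list"][0].device)  # now on `device`, not stuck on CPU

The substitution point in auto_gptq would be wherever quantize() currently does the per-kwarg isinstance(v, torch.Tensor) move (one_kwargs and the additional_layer_inputs assignment visible in the traceback above).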

@dlutsniper
Author

(quoting @wangitu's workaround above)

Awesome, quantization succeeded!

The new problem is an error during inference:

---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[9], line 6
      4 quantized_model_dir = "/root/autodl-tmp/newmodel_7b-Int4"
      5 # load quantized model to the first GPU
----> 6 model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0",trust_remote_code=True)
      8 # download quantized model from Hugging Face Hub and load to the first GPU
      9 # model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
     10 
     11 # inference with model.generate
     12 print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))

File ~/miniconda3/lib/python3.10/site-packages/auto_gptq/modeling/auto.py:108, in AutoGPTQForCausalLM.from_quantized(cls, model_name_or_path, device_map, max_memory, device, low_cpu_mem_usage, use_triton, inject_fused_attention, inject_fused_mlp, use_cuda_fp16, quantize_config, model_basename, use_safetensors, trust_remote_code, warmup_triton, trainable, disable_exllama, **kwargs)
    102 # TODO: do we need this filtering of kwargs? @PanQiWei is there a reason we can't just pass all kwargs?
    103 keywords = {
    104     key: kwargs[key]
    105     for key in list(signature(quant_func).parameters.keys()) + huggingface_kwargs
    106     if key in kwargs
    107 }
--> 108 return quant_func(
    109     model_name_or_path=model_name_or_path,
    110     device_map=device_map,
    111     max_memory=max_memory,
    112     device=device,
    113     low_cpu_mem_usage=low_cpu_mem_usage,
    114     use_triton=use_triton,
    115     inject_fused_attention=inject_fused_attention,
    116     inject_fused_mlp=inject_fused_mlp,
    117     use_cuda_fp16=use_cuda_fp16,
    118     quantize_config=quantize_config,
    119     model_basename=model_basename,
    120     use_safetensors=use_safetensors,
    121     trust_remote_code=trust_remote_code,
    122     warmup_triton=warmup_triton,
    123     trainable=trainable,
    124     disable_exllama=disable_exllama,
    125     **keywords
    126 )

File ~/miniconda3/lib/python3.10/site-packages/auto_gptq/modeling/_base.py:791, in BaseGPTQForCausalLM.from_quantized(cls, model_name_or_path, device_map, max_memory, device, low_cpu_mem_usage, use_triton, torch_dtype, inject_fused_attention, inject_fused_mlp, use_cuda_fp16, quantize_config, model_basename, use_safetensors, trust_remote_code, warmup_triton, trainable, disable_exllama, **kwargs)
    788             break
    790 if resolved_archive_file is None: # Could not find a model file to use
--> 791     raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
    793 model_save_name = resolved_archive_file
    795 if not disable_exllama and trainable:

FileNotFoundError: Could not find model in /root/autodl-tmp/newmodel_7b-Int4


@wangitu

wangitu commented Oct 18, 2023

Regarding your new inference error: after model.save_quantized, the directory holding your quantized model contains only the weight file and the related config files.

But when AutoGPTQForCausalLM.from_quantized is called with trust_remote_code, the logic is to find the file inside the quantized model's directory that defines the model architecture (for Qwen that is modeling_qwen.py), and then load the weights into the model defined by that file.

Since your quantized model's directory has no such modeling_qwen.py, the code raises an error.

So copy every file except the weights from the original Qwen checkpoint (including the tokenizer code, etc.) into your quantized model's directory; that should resolve the problem.
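A rough sketch of that copy step, assuming the directory layout from the reproduction script above (the paths are illustrative, and the list of things to skip — weight shards, weight index files, anything save_quantized already wrote — is my own guess):

import shutil
from pathlib import Path

src = Path("/root/autodl-tmp/newmodel_14b")        # original finetuned checkpoint
dst = Path("/root/autodl-tmp/newmodel_14b_int4")   # output of save_quantized

weight_suffixes = {".bin", ".safetensors"}
for f in src.iterdir():
    # copy code / config / tokenizer files; leave weights, weight indexes,
    # and files already present in the quantized directory untouched
    if (f.is_file()
            and f.suffix not in weight_suffixes
            and ".index.json" not in f.name
            and not (dst / f.name).exists()):
        shutil.copy2(f, dst)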

@dlutsniper
Author

(quoting @wangitu's explanation above)

Thanks for the pointer. I copied the other files over from the original model; now inference hits a new problem: RuntimeError: value cannot be converted to type at::BFloat16 without overflow

@wangitu

wangitu commented Oct 18, 2023

(quoting the exchange above)

I haven't run into this overflow issue myself, but I suspect a likely cause is the last line of your traceback:

masked_fill(~causal_mask, torch.finfo(query.dtype).min)

amp.autocast presumably auto-converts torch.finfo(query.dtype).min to another dtype, and that value falls outside the representable range of the target type.

I hit a similar overflow when doing full-parameter SFT on Baichuan; my fix back then was to replace torch.finfo(query.dtype).min with -50 (e^-50 is already small enough). You could try that. A similar solution is described here: https://discuss.pytorch.org/t/runtimeerror-value-cannot-be-converted-to-type-at-half-without-overflow-1e-30/109768

I can't guarantee this will solve it, though.
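An isolated illustration of the suggested change (toy shapes, not Qwen's actual modeling code):

import torch

attn_weights = torch.randn(1, 2, 8, 8, dtype=torch.bfloat16)  # toy attention logits
causal_mask = torch.tril(torch.ones(8, 8)).bool()

# Original pattern, which can overflow once the fill value is converted to another dtype:
# attn_weights = attn_weights.masked_fill(~causal_mask, torch.finfo(attn_weights.dtype).min)

# Suggested workaround: a finite, "sufficiently negative" constant
# (exp(-50) is effectively zero after softmax).
attn_weights = attn_weights.masked_fill(~causal_mask, -50.0)
probs = torch.softmax(attn_weights, dim=-1)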

@xunfeng1980

Same problem with Qwen-Chat-14B.

@xunfeng1980

Curious how the official Qwen Int4 models were quantized; no quantization script was provided, though.

@dlutsniper
Author

Comparing the config.json of the original model, the original Int4 model, and the finetuned model plus its AutoGPTQ-quantized version:

Finetuned model and its AutoGPTQ-quantized model: bf16: true, fp16: false
Original Int4 model: bf16: false, fp16: true

So I manually edited the config to match the original Int4 model, and the inference error (RuntimeError: value cannot be converted to type at::BFloat16 without overflow) is resolved.
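If it helps, a minimal sketch of that manual edit (the path is reused from earlier in this thread, and the two flags flipped are just the ones described above):

import json
from pathlib import Path

cfg_path = Path("/root/autodl-tmp/newmodel_14b_int4") / "config.json"
cfg = json.loads(cfg_path.read_text())
cfg["bf16"] = False  # finetuned + AutoGPTQ-quantized config had bf16: true
cfg["fp16"] = True   # official Int4 config has fp16: true
cfg_path.write_text(json.dumps(cfg, ensure_ascii=False, indent=2))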

The new problem is that the behavior learned during finetuning has disappeared 😂

@dlutsniper
Author

It turns out that since the model was finetuned from Qwen-7B-Chat, inference after quantization should switch the demo's pipeline to the chat interface; with that change the model behaves as expected again.
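For anyone else hitting this, a sketch of what the chat-style inference might look like (model.chat is the interface exposed by Qwen's remote code for the *-Chat variants; the path is reused from the reproduction script, and if the AutoGPTQ wrapper does not forward the attribute, call it on model.model instead):

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

quantized_model_dir = "/root/autodl-tmp/newmodel_14b_int4"
tokenizer = AutoTokenizer.from_pretrained(quantized_model_dir, trust_remote_code=True)
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0",
                                           trust_remote_code=True)

# chat interface instead of the plain TextGenerationPipeline from the quick-start demo
response, history = model.chat(tokenizer, "你好,请介绍一下你自己", history=None)
print(response)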

@xunfeng1980

(quoting @dlutsniper's comment above)

Hi, have you run into errors whenever the temperature is below 0.6?

@dlutsniper
Author

(quoting the exchange above)

I'd suggest opening a new issue with a detailed description of your dependencies, the command line, and the full error output, so the community can look into it.

@xunfeng1980

#377

@lonngxiang

(quoting @xunfeng1980's question above about how the official Qwen Int4 models were quantized)

@lonngxiang

@dlutsniper Could you please share the code of your final working solution for reference?

@xunfeng1980

(quoting the question above about how the official Qwen Int4 models were quantized)

Quantization works on my side now, following the approach above: #370 (comment)
