Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Pass a device pointer into the quantize kernel for the scales #5159

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,

#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
float scale);
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);
Expand Down
15 changes: 9 additions & 6 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ namespace vllm {
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
scale_type scale, const int hidden_size) {
const scale_type* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
scale_type scale = *scale_ptr;

for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
Expand All @@ -39,11 +40,13 @@ __global__ void static_scaled_int8_quant_kernel(
}
} // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
float scale) {
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);

int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
Expand All @@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(), scale,
hidden_size);
out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size);
});
}
4 changes: 3 additions & 1 deletion tests/kernels/test_int8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
ops.static_scaled_int8_quant(out2, x, scale)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")

ops.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def scaled_fp8_quant(

# int8
def static_scaled_int8_quant(input: torch.Tensor,
scale: float) -> torch.Tensor:
scale: torch.Tensor) -> torch.Tensor:
"""
Quantize the input tensor to int8 and return the quantized tensor.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
act_scale = layer.input_scale

# Input quantize
x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
x_q = custom_ops.static_scaled_int8_quant(x, act_scale)

return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
weight_scale, x.dtype)
Loading