forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
quant_api.py
407 lines (336 loc) · 14 KB
/
quant_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Quantization APIs
Generally these APIs can be applied directly to any model
with Linear modules to obtain quantized linear ops. The intended
usage involves applying torch.compile to the model afterwards
both because primitives were designed based on the fusions that
come along with it and because that is how we access the intended quantized
and mixed GEMM kernels
TODO: There are 2 different approaches to quantizing a model. The first and more historically
popular approach is to use module swaps which explicitly change the linear modules and the second
approach is to instead use subclasses to change the interpretation of the linear module
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable
from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
)
from .subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
to_laq,
)
from .quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)
from .autoquant import autoquant, AutoQuantizableLinearWeight
# Public names exported by this module. Names that are only imported behind a
# torch version check are appended further down, next to that check.
__all__ = [
    # module-swap and tensor-subclass entry points
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_int8_dqtensors",
    "change_linear_weights_to_int8_woqtensors",
    "change_linear_weights_to_int4_woqtensors",
    "swap_conv2d_1x1_to_linear",
    # quantizer classes
    "Quantizer",
    "TwoStepQuantizer",
    "Int4WeightOnlyGPTQQuantizer",
    "Int4WeightOnlyQuantizer",
    # generic subclass-based API
    "quantize",
    "autoquant",
    "_get_subclass_inserter",
    "get_apply_8da4w_quant",
    "get_apply_int4wo_quant",
    "get_apply_int8wo_quant",
    "get_apply_int8dyn_quant",
]

if TORCH_VERSION_AFTER_2_3:
    # These quantizers are imported (and therefore exported) only when
    # TORCH_VERSION_AFTER_2_3 is true.
    from .GPTQ import (
        Int8DynActInt4WeightQuantizer,
        Int8DynActInt4WeightGPTQQuantizer,
    )

    __all__.extend(
        [
            "Int8DynActInt4WeightQuantizer",
            "Int8DynActInt4WeightGPTQQuantizer",
        ]
    )
def _replace_with_custom_fn_if_matches_filter(
model,
replacement_fn,
filter_fn,
cur_fqn="",
) -> None:
"""
Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
if `filter_fn(child)` returns `True`.
Args:
model (torch.nn.Module): The model containing modules to be replaced.
replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
Returns:
None
"""
if filter_fn(model, cur_fqn[:-1]):
model = replacement_fn(model)
return model
else:
for name, child in model.named_children():
new_child = _replace_with_custom_fn_if_matches_filter(
child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
)
if new_child is not child:
setattr(model, name, new_child)
return model
def _is_linear(mod, *args):
return (
isinstance(mod, torch.nn.Linear)
and hasattr(mod, "weight")
and not isinstance(mod.weight, QuantizedLinearWeightBase)
and not isinstance(mod.weight, AutoQuantizableLinearWeight)
)
def _in_features_greater_than_16(mod, *args):
return hasattr(mod, "in_features") and mod.in_features > 16
def apply_weight_only_int8_quant(model, filter_fn=None):
    """
    Applies weight-only symmetric per-channel int8 quantization to all linear layers
    in the given model using module swaps.

    Every module selected by `filter_fn` is replaced by
    `WeightOnlyInt8QuantLinear.from_float(module)`.

    Args:
        model (torch.nn.Module): Model to quantize; matching children are swapped in place.
        filter_fn (Callable[[torch.nn.Module, str], bool], optional): Called with each
            module and its fully qualified name; selects the modules to swap.
            Defaults to `_is_linear`.

    Returns:
        torch.nn.Module: `model`, or its replacement when `model` itself was selected
        (e.g. a bare `torch.nn.Linear`). In that case the swap cannot happen in place,
        so callers must use the return value.
    """
    return _replace_with_custom_fn_if_matches_filter(
        model,
        WeightOnlyInt8QuantLinear.from_float,
        _is_linear if filter_fn is None else filter_fn,
    )
def apply_dynamic_quant(model, filter_fn=None):
    """
    Apply dynamic int8 quantization (symmetric per-token activations, per-channel
    weights) to the linear layers of `model` without swapping modules.

    This is a thin alias: all of the work, including the choice of default filter,
    is done by `change_linear_weights_to_int8_dqtensors`.
    """
    change_linear_weights_to_int8_dqtensors(model, filter_fn=filter_fn)
import torch.nn.utils.parametrize as parametrize
def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
"""
Returns a function which inserts the given subclass into all linear modules
in the model. The inserted module will have its weight set to the result of
`cls(mod.weight, **kwargs)`. If parametrization is enabled then this will be done using
torch.nn.utils.parametrize instead of directly setting the attribute on the module.
Args:
cls (torch.Tensor): The class to insert as a child module.
kwargs (Any): Any additional arguments for the constructor.
"""
constructor = kwargs.pop("constructor", "subclass_constructor")
from_float = kwargs.pop("method", "from_float")
def insert_subclass(lin):
if enable_parametrization:
lin.weight = torch.nn.Parameter(cls.from_float(lin.weight, **kwargs), requires_grad=False)
_, args = lin.weight.__tensor_flatten__()
parametrize.register_parametrization(lin, "weight", getattr(cls, constructor)(*args))
else:
lin.weight = torch.nn.Parameter(
# cls.from_float(...)
getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False
)
return lin
return insert_subclass
def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
    """
    Quantize the weights of selected linear layers for dynamic int8 quantization,
    leaving the modules themselves in place.

    By default the selected modules are those accepted by `_is_linear` that also have
    `in_features > 16`.

    When `TORCH_VERSION_AFTER_2_4` is true, the weights are converted with
    `quantize(model, get_apply_int8dyn_quant(), filter_fn)` followed by
    `unwrap_tensor_subclass(model, filter_fn)`; `kwargs` are not used on that path.
    Otherwise each selected weight becomes an `Int8DynamicallyQuantizedLinearWeight`,
    created through `_get_subclass_inserter` with `kwargs`.
    """
    if filter_fn is None:
        def filter_fn(*args):
            return _is_linear(*args) and _in_features_greater_than_16(*args)

    if not TORCH_VERSION_AFTER_2_4:
        inserter = _get_subclass_inserter(
            Int8DynamicallyQuantizedLinearWeight,
            enable_parametrization=False,
            **kwargs,
        )
        _replace_with_custom_fn_if_matches_filter(model, inserter, filter_fn)
        return

    quantize(model, get_apply_int8dyn_quant(), filter_fn)
    unwrap_tensor_subclass(model, filter_fn)
def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
    """
    Converts all linear weight tensors to the
    `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, i.e. int8 weight-only
    quantization. This is the tensor-subclass counterpart of
    `apply_weight_only_int8_quant`: the linear modules themselves are not swapped.

    Args:
        model (torch.nn.Module): Model whose selected linear weights are converted in place.
        filter_fn (Callable[[torch.nn.Module, str], bool], optional): Selects the modules
            to convert. Defaults to `_is_linear`.
        kwargs: Forwarded to `_get_subclass_inserter`. The keys "constructor" and "method"
            are consumed there; the rest reach the subclass conversion method.

    Parametrization is enabled when `TORCH_VERSION_AFTER_2_4` is true.
    """
    if filter_fn is None:
        filter_fn = _is_linear
    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(
            Int8WeightOnlyQuantizedLinearWeight,
            enable_parametrization=TORCH_VERSION_AFTER_2_4,
            **kwargs,
        ),
        filter_fn,
    )
def change_linear_weights_to_int4_woqtensors(model, filter_fn=None, **kwargs):
    """
    Converts all linear weight tensors to the
    `Int4WeightOnlyQuantizedLinearWeight` tensor subclass, i.e. int4 weight-only
    quantization, without swapping the linear modules themselves.

    Args:
        model (torch.nn.Module): Model whose selected linear weights are converted in place.
        filter_fn (Callable[[torch.nn.Module, str], bool], optional): Selects the modules
            to convert. Defaults to `_is_linear`. Passing it as a keyword, as before,
            still works; `None` now also means "use the default".
        kwargs: Forwarded to `_get_subclass_inserter`. The keys "constructor" and "method"
            are consumed there; the rest reach the subclass conversion method.

    Parametrization is enabled when `TORCH_VERSION_AFTER_2_4` is true.
    """
    if filter_fn is None:
        filter_fn = _is_linear
    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(
            Int4WeightOnlyQuantizedLinearWeight,
            enable_parametrization=TORCH_VERSION_AFTER_2_4,
            **kwargs,
        ),
        filter_fn,
    )
def swap_conv2d_1x1_to_linear(model, filter_fn=None):
    """
    Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.

    Each selected `torch.nn.Conv2d` is replaced by a wrapper around a `torch.nn.Linear`.
    The Linear uses the conv's weight, with the trailing 1x1 spatial dims squeezed, and the
    conv's bias. The wrapper permutes a 4D (N, C, H, W) input to (N, H, W, C), applies the
    linear over the channel dimension, and permutes the result back to (N, C', H, W).

    Args:
        model (torch.nn.Module): Model to transform; matching children are replaced in place.
        filter_fn (Callable[[torch.nn.Module, str], bool], optional): Selects the conv
            modules to swap. The default only selects 1x1 `Conv2d` modules for which the
            swap is exact: `groups == 1`, stride 1 and no padding.

    Returns:
        torch.nn.Module: `model`, or its replacement when `model` itself was selected.
    """

    class PermuteSandwich(torch.nn.Module):
        def __init__(self, mod):
            super().__init__()
            self.mod = mod

        def forward(self, *args):
            # Expects a 4D input: move channels last for the linear, then restore the layout.
            return self.mod(args[0].permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def replace_conv2d_1x1(conv):
        assert conv.kernel_size == (1, 1)
        lin = torch.nn.Linear(
            conv.in_channels, conv.out_channels, bias=conv.bias is not None
        )
        lin.weight = torch.nn.Parameter(conv.weight.squeeze(-1, -2))
        lin.bias = conv.bias
        return PermuteSandwich(lin)

    if filter_fn is None:
        def filter_fn(mod, *args):
            # A 1x1 conv equals a per-pixel linear only without stride, padding or grouping.
            return (
                isinstance(mod, torch.nn.Conv2d)
                and mod.kernel_size == (1, 1)
                and mod.groups == 1
                and mod.stride == (1, 1)
                and mod.padding in ((0, 0), "valid", "same")
            )

    return _replace_with_custom_fn_if_matches_filter(
        model, replace_conv2d_1x1, filter_fn=filter_fn
    )
def _get_linear_subclass_inserter(constructor):
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
return lin
return insert_subclass
def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
    """Convert the weights of linear modules in `model` with `apply_tensor_subclass`.

    Every module selected by `filter_fn` has its `weight` replaced, in place, by a
    non-trainable `torch.nn.Parameter` wrapping `apply_tensor_subclass(weight)`.

    Args:
        model: input model; modified in place.
        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts
            a floating point Tensor to a (quantized) tensor subclass instance.
        filter_fn (Callable[[torch.nn.Module, str], bool], optional): called with each module
            and its fully qualified name; selects the modules whose weight is converted.
            Once a module is selected, its children are not visited, so the filter should
            select the linear modules themselves rather than their parent. Defaults to
            `_is_linear` (linear modules whose weight is not already a quantized subclass).

    Returns:
        torch.nn.Module: `model`, after in-place modification.

    Example::

        from torchao.dtypes.aqt import to_aq

        # weight settings
        groupsize = 32
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, groupsize)
        target_dtype = torch.int32
        quant_min = 0
        quant_max = 15
        eps = 1e-6
        preserve_zero = False
        zero_point_dtype = torch.bfloat16
        zero_point_domain = ZeroPointDomain.FLOAT
        apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

        # apply only to linear modules under the block0 submodule
        def filter_fn(module, fqn):
            return isinstance(module, torch.nn.Linear) and fqn.startswith("block0.")

        m = MyModel(...)
        m = quantize(m, apply_weight_quant, filter_fn)
    """
    if filter_fn is None:
        filter_fn = _is_linear
    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_linear_subclass_inserter(apply_tensor_subclass),
        filter_fn,
    )
    return model
def get_apply_8da4w_quant(groupsize=32):
    """
    Return a function that quantizes a weight for int8 dynamic activation + int4-range weight.

    The weight is quantized symmetrically to int8 storage with values in [-8, 7], using
    blocks of shape (1, groupsize). It is then wrapped with `to_laq` so that inputs are
    quantized on the fly: asymmetric int8, one block per token (all dims except the last
    collapsed to 1).
    """

    def apply_8da4w_quant(weight):
        # imported lazily to avoid a circular dependency
        from torchao.dtypes.aqt import to_aq

        def get_per_token_block_size(x):
            # [1, ..., 1, last_dim]: one quantization block per token
            return [1] * (len(x.shape) - 1) + [x.shape[-1]]

        def input_quant_func(x):
            return to_aq(
                x,
                MappingType.ASYMMETRIC,
                get_per_token_block_size(x),
                torch.int8,
            )

        quantized_weight = to_aq(
            weight,
            MappingType.SYMMETRIC,
            (1, groupsize),
            torch.int8,
            -8,
            7,
            torch.finfo(torch.float32).eps,
        )
        return to_laq(quantized_weight, input_quant_func)

    return apply_8da4w_quant
def get_apply_int4wo_quant(groupsize=32):
    """
    Return a function that applies int4 weight-only quantization to a weight tensor.

    Quantization is asymmetric, with values in [0, 15] stored as int32, one scale and zero
    point per block of shape (1, groupsize), and bfloat16 zero points in the float domain.
    The result uses the "tensor_core_tiled" layout.

    Args:
        groupsize (int): number of consecutive elements along dim 1 that share a scale and
            zero point. Assumes a 2D weight (TODO confirm for other ranks).
    """

    def apply_int4wo_quant(weight):
        # imported lazily to avoid a circular dependency
        from torchao.dtypes.aqt import to_aq

        # `groupsize` comes from the enclosing call. The former local `groupsize = 32`
        # shadowed it, so the argument was silently ignored.
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, groupsize)
        target_dtype = torch.int32
        quant_min = 0
        quant_max = 15
        eps = 1e-6
        preserve_zero = False
        zero_point_dtype = torch.bfloat16
        zero_point_domain = ZeroPointDomain.FLOAT
        return to_aq(
            weight,
            mapping_type,
            block_size,
            target_dtype,
            quant_min,
            quant_max,
            eps,
            zero_point_dtype=zero_point_dtype,
            preserve_zero=preserve_zero,
            zero_point_domain=zero_point_domain,
            extended_layout="tensor_core_tiled",
        )

    return apply_int4wo_quant
def get_apply_int8wo_quant():
    """
    Return a function that applies symmetric int8 weight-only quantization to a weight.

    The block shape is (1, weight.shape[1]), i.e. one scale per row of a 2D weight,
    with int64 zero points.
    """

    def apply_int8wo_quant(weight):
        # imported lazily to avoid a circular dependency
        from torchao.dtypes.aqt import to_aq

        per_row_block = (1, weight.shape[1])
        return to_aq(
            weight,
            MappingType.SYMMETRIC,
            per_row_block,
            torch.int8,
            eps=torch.finfo(torch.float32).eps,
            zero_point_dtype=torch.int64,
        )

    return apply_int8wo_quant
def get_apply_int8dyn_quant():
    """
    Return a function that prepares a weight for int8 dynamic quantization.

    The weight is quantized symmetrically to int8 with one block per row, shape
    (1, weight.shape[1]), and int64 zero points. It is then wrapped with `to_laq` so that
    inputs are quantized on the fly: symmetric int8 in [-127, 127], eps 1e-5, one block per
    token, with float32 scales when the input is float16.
    """

    def apply_int8dyn_quant(weight):
        # imported lazily to avoid a circular dependency
        from torchao.dtypes.aqt import to_aq

        def per_token_block_size(x):
            # every dim except the last collapses to 1; a 0-dim shape yields []
            dims = list(x.shape)
            return [1] * (len(dims) - 1) + dims[-1:]

        def input_quant_func(x):
            return to_aq(
                x,
                MappingType.SYMMETRIC,
                per_token_block_size(x),
                torch.int8,
                eps=1e-5,
                quant_min=-127,
                quant_max=127,
                scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
            )

        weight_block_size = (1, weight.shape[1])
        quantized_weight = to_aq(
            weight,
            MappingType.SYMMETRIC,
            weight_block_size,
            torch.int8,
            eps=torch.finfo(torch.float32).eps,
            zero_point_dtype=torch.int64,
        )
        return to_laq(quantized_weight, input_quant_func)

    return apply_int8dyn_quant