From d2bce6a56eae5701cb72eb0cf6359626e7bd0190 Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Sat, 14 Sep 2024 09:13:02 +0800 Subject: [PATCH] Move `autoround` from `generate.py` to `eval.py` (#868) * move autoround from generate to eval Signed-off-by: yiliu30 * add llama3 back Signed-off-by: yiliu30 * update the scripts Signed-off-by: yiliu30 * update the scripts Signed-off-by: yiliu30 * rename eval_acc.sh -> evals.sh Signed-off-by: yiliu30 * update Signed-off-by: yiliu30 * update Signed-off-by: yiliu30 --------- Signed-off-by: yiliu30 --- torchao/_models/llama/benchmarks.sh | 8 +--- torchao/_models/llama/eval.py | 59 ++++++++++++++++++++++++++--- torchao/_models/llama/evals.sh | 9 +++++ torchao/_models/llama/generate.py | 53 ++------------------------ 4 files changed, 67 insertions(+), 62 deletions(-) create mode 100644 torchao/_models/llama/evals.sh diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 48c75931f..6582832f6 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -49,9 +49,6 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt -# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/o quant_lm_head -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head export MODEL_REPO=meta-llama/Meta-Llama-3-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt @@ -61,7 +58,4 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt -# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/o quant_lm_head -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt \ No newline at end of file diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index f3300e331..673c4f595 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -26,7 +26,7 @@ from tokenizer import get_tokenizer import time from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer -from torchao._models.llama.model import prepare_inputs_for_model +from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def run_evaluation( @@ -122,6 +122,51 @@ def run_evaluation( else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) + if "autoround" in quantization: + from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_ + from transformers import AutoTokenizer + + _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent) + # parse args from quantization string: + # autoround------- + _quant_args = quantization.split("-") + _default_quant_args = [False, 200, 128, 8, 2048, 128] + _model_devie = _quant_args[1] if len(_quant_args) > 1 else device + _quant_args = _quant_args[2:] + quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [ + int(x) for x in _quant_args + ] + _default_quant_args[len(_quant_args) :] + model = model.to(_model_devie) + print( + ( + f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, " + f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})" + ) + ) + with torch.device(_model_devie): + model.setup_caches( + max_batch_size=batch_size, max_seq_length=seqlen, training=True + ) + + if quant_lm_head: + is_target_module = ( + lambda mod, fqn: isinstance(mod, TransformerBlock) + or "output" in fqn + ) + else: + is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock) + quantize_model_with_autoround_( + model=model, + tokenizer=_tokenizer, + is_target_module=is_target_module, + bits=4, + seqlen=seqlen, + bs=batch_size, + iters=iters, + nsamples=nsamples, + ) + model.to(device) + model.reset_caches() if compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) @@ -145,11 +190,15 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument('-q', '--quantization', type=str, + parser.add_argument( + "-q", + "--quantization", + type=str, help=( - 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--gptq, autoquant, autoquant-int4, '+ - 'int4wo--hqq, uintx--, uintx---hqq, sparse-marlin' - ) + "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--gptq, " + "autoquant, autoquant-int4, int4wo--hqq, uintx--, uintx---hqq, " + "sparse-marlin, autoround-------" + ), ) parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') diff --git a/torchao/_models/llama/evals.sh b/torchao/_models/llama/evals.sh new file mode 100644 index 000000000..253d1dfee --- /dev/null +++ b/torchao/_models/llama/evals.sh @@ -0,0 +1,9 @@ +export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder + +export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf +python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head +python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head + +export MODEL_REPO=meta-llama/Meta-Llama-3-8B +python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantizatio autoround-cpu # auto-round w/o quant_lm_head +python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu-1 # auto-round w/ quant_lm_head diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 1ce3ef6b6..5fb905dbf 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -30,7 +30,7 @@ def device_sync(device): wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from torchao._models.llama.model import Transformer, prepare_inputs_for_model, TransformerBlock +from torchao._models.llama.model import Transformer, prepare_inputs_for_model from torchao._models.llama.tokenizer import get_tokenizer def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization @@ -227,53 +227,7 @@ def main( if "marlin" in quantization: from torchao.dtypes import MarlinSparseLayoutType quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) - if "autoround" in quantization: - from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_ - from transformers import AutoTokenizer - - _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent) - # parse args from quantization string: - # autoround------- - # A lightweight configuration for generation benchmarking. - _quant_args = quantization.split("-") - _default_quant_args = [True, 1, 128, 1, 512, 32] - _model_devie = _quant_args[1] if len(_quant_args) > 1 else device - _quant_args = _quant_args[2:] - quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [ - int(x) for x in _quant_args - ] + _default_quant_args[len(_quant_args) :] - model = model.to(_model_devie) - print( - ( - f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, " - f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})" - ) - ) - with torch.device(_model_devie): - model.setup_caches( - max_batch_size=batch_size, max_seq_length=seqlen, training=True - ) - - if quant_lm_head: - is_target_module = ( - lambda mod, fqn: isinstance(mod, TransformerBlock) or "output" in fqn - ) - else: - is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock) - quantize_model_with_autoround_( - model=model, - tokenizer=_tokenizer, - is_target_module=is_target_module, - bits=4, - seqlen=seqlen, - bs=batch_size, - iters=iters, - nsamples=nsamples, - ) - model.to(device) - model.reset_caches() - # TODO this needs to be expanded to all of fpx so they can - if "fp6" in quantization: + if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) if "uintx" in quantization: # uintx-nbits-groupsize, e.g. "uintx-2-64" @@ -461,8 +415,7 @@ def callback(x): parser.add_argument('-q', '--quantization', type=str, help=( 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' - +'autoquant-int4, autoround-------, ' - +'uintx--, uintx---hqq, sparse-marlin' + +'autoquant-int4, uintx--, uintx---hqq, sparse-marlin' ) ) parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')