Move autoround from generate.py to eval.py (#868)
* move autoround from generate to eval

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* add llama3 back

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* update the scripts

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* update the scripts

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* rename eval_acc.sh -> evals.sh

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* update

Signed-off-by: yiliu30 <yi4.liu@intel.com>

* update

Signed-off-by: yiliu30 <yi4.liu@intel.com>

---------

Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 committed Sep 14, 2024
1 parent 90c8cbd commit d2bce6a
Showing 4 changed files with 67 additions and 62 deletions.
8 changes: 1 addition & 7 deletions torchao/_models/llama/benchmarks.sh
@@ -49,9 +49,6 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/o quant_lm_head
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
@@ -61,7 +58,4 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/o quant_lm_head
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
59 changes: 54 additions & 5 deletions torchao/_models/llama/eval.py
@@ -26,7 +26,7 @@
from tokenizer import get_tokenizer
import time
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models.llama.model import prepare_inputs_for_model
from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

def run_evaluation(
@@ -122,6 +122,51 @@ def run_evaluation(
else:
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(model)
if "autoround" in quantization:
from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
from transformers import AutoTokenizer

_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
# parse args from quantization string:
# autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
_quant_args = quantization.split("-")
_default_quant_args = [False, 200, 128, 8, 2048, 128]
_model_device = _quant_args[1] if len(_quant_args) > 1 else device
_quant_args = _quant_args[2:]
quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
int(x) for x in _quant_args
] + _default_quant_args[len(_quant_args) :]
model = model.to(_model_device)
print(
(
f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, "
f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})"
)
)
with torch.device(_model_device):
model.setup_caches(
max_batch_size=batch_size, max_seq_length=seqlen, training=True
)

if quant_lm_head:
is_target_module = (
lambda mod, fqn: isinstance(mod, TransformerBlock)
or "output" in fqn
)
else:
is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock)
quantize_model_with_autoround_(
model=model,
tokenizer=_tokenizer,
is_target_module=is_target_module,
bits=4,
seqlen=seqlen,
bs=batch_size,
iters=iters,
nsamples=nsamples,
)
model.to(device)
model.reset_caches()

if compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -145,11 +190,15 @@ def run_evaluation(
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument('-q', '--quantization', type=str,
parser.add_argument(
"-q",
"--quantization",
type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, autoquant, autoquant-int4, '+
'int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
)
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, "
"autoquant, autoquant-int4, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, "
"sparse-marlin, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>"
),
)
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
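
For reference, here is a minimal standalone sketch (not part of the commit) of how the new eval.py code parses the autoround quantization string. The helper name parse_autoround_args is hypothetical; the format and defaults mirror _default_quant_args in the diff above.

def parse_autoround_args(quantization: str, device: str = "cuda"):
    # Format: autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
    parts = quantization.split("-")
    # Defaults mirror _default_quant_args in eval.py:
    # quant_lm_head=False, iters=200, groupsize=128, batch_size=8, seqlen=2048, nsamples=128
    defaults = [False, 200, 128, 8, 2048, 128]
    model_device = parts[1] if len(parts) > 1 else device
    rest = parts[2:]
    # Fields present in the string override the defaults from left to right
    quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
        int(x) for x in rest
    ] + defaults[len(rest):]
    return model_device, bool(quant_lm_head), iters, groupsize, batch_size, seqlen, nsamples

Any trailing field omitted from the string falls back to its default, so a plain "autoround" runs on the --device target without quantizing the lm_head.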
9 changes: 9 additions & 0 deletions torchao/_models/llama/evals.sh
@@ -0,0 +1,9 @@
export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder

export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu-1 # auto-round w/ quant_lm_head
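
Continuing the sketch above (same hypothetical helper), the strings used in evals.sh would resolve as follows, assuming eval.py's default --device of cuda:

parse_autoround_args("autoround")         # ("cuda", False, 200, 128, 8, 2048, 128)  w/o quant_lm_head
parse_autoround_args("autoround-cuda-1")  # ("cuda", True,  200, 128, 8, 2048, 128)  w/  quant_lm_head
parse_autoround_args("autoround-cpu")     # ("cpu",  False, 200, 128, 8, 2048, 128)  w/o quant_lm_head
parse_autoround_args("autoround-cpu-1")   # ("cpu",  True,  200, 128, 8, 2048, 128)  w/  quant_lm_head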
53 changes: 3 additions & 50 deletions torchao/_models/llama/generate.py
@@ -30,7 +30,7 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from torchao._models.llama.model import Transformer, prepare_inputs_for_model, TransformerBlock
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
@@ -227,53 +227,7 @@ def main(
if "marlin" in quantization:
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if "autoround" in quantization:
from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
from transformers import AutoTokenizer

_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
# parse args from quantization string:
# autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
# A lightweight configuration for generation benchmarking.
_quant_args = quantization.split("-")
_default_quant_args = [True, 1, 128, 1, 512, 32]
_model_devie = _quant_args[1] if len(_quant_args) > 1 else device
_quant_args = _quant_args[2:]
quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
int(x) for x in _quant_args
] + _default_quant_args[len(_quant_args) :]
model = model.to(_model_devie)
print(
(
f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, "
f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})"
)
)
with torch.device(_model_devie):
model.setup_caches(
max_batch_size=batch_size, max_seq_length=seqlen, training=True
)

if quant_lm_head:
is_target_module = (
lambda mod, fqn: isinstance(mod, TransformerBlock) or "output" in fqn
)
else:
is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock)
quantize_model_with_autoround_(
model=model,
tokenizer=_tokenizer,
is_target_module=is_target_module,
bits=4,
seqlen=seqlen,
bs=batch_size,
iters=iters,
nsamples=nsamples,
)
model.to(device)
model.reset_caches()
# TODO this needs to be expanded to all of fpx so they can
if "fp6" in quantization:
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "uintx" in quantization:
# uintx-nbits-groupsize, e.g. "uintx-2-64"
@@ -461,8 +415,7 @@ def callback(x):
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, '
+'uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+'autoquant-int4, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
)
)
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
