Commit
Merge branch 'refs/heads/dev'
turboderp committed Jul 24, 2024
2 parents 61e9ae8 + 4bbd969 commit 509f9aa
Showing 10 changed files with 304 additions and 274 deletions.
396 changes: 142 additions & 254 deletions exllamav2/architecture.py

Large diffs are not rendered by default.

40 changes: 33 additions & 7 deletions exllamav2/config.py
@@ -10,7 +10,9 @@
T = TypeVar('T')
no_default = object()

def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], default = no_default) -> T:
def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:

expected_types = expected_type if isinstance(expected_type, list) else [expected_type]

if isinstance(keys, str): keys = [keys]

@@ -34,10 +36,10 @@ def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str],
if expected_type == int and isinstance(x, float) and x == int(x):
x = int(x)

if isinstance(x, expected_type):
return cast(T, x)
else:
raise TypeError(f"Value for {key} is not of expected type {expected_type}")
for t in expected_types:
if isinstance(x, t):
return cast(T, x)
raise TypeError(f"Value for {key} is not of expected type {expected_type}")

if default != no_default: return default
raise ValueError(f"Missing any of the following keys: {keys}")
@@ -104,8 +106,13 @@ class ExLlamaV2Config:
final_logit_softcapping: float | None
attn_logit_softcapping: float | None
sliding_window: int

norm_head: int | None
l3_rope_factor: float | None
l3_rope_low_freq_factor: float | None
l3_rope_high_freq_factor: float | None
l3_rope_original_max_position_embeddings: int | None
checkpoint_fused_mlp: bool
checkpoint_offset_qzeros: bool


def __init__(self,
@@ -189,10 +196,13 @@ def prepare(self, no_tensors: bool = False):
# Vocab params

self.bos_token_id = read(read_config, int, "bos_token_id", None) # 1
self.eos_token_id = read(read_config, int, "eos_token_id", None) # 2
self.eos_token_id = read(read_config, [int, list], "eos_token_id", None) # 2
self.pad_token_id = read(read_config, int, "pad_token_id", None) # 0
self.vocab_size = read(read_config, int, "vocab_size")

if isinstance(self.eos_token_id, list):
self.eos_token_id = self.eos_token_id[0] # TODO: Figure out a way to maybe use all the EOS tokens somehow

# Standard params

self.initializer_range = read(read_config, float, ["initializer_range"])
@@ -251,6 +261,10 @@ def prepare(self, no_tensors: bool = False):
self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None)
self.final_logit_softcapping = read(read_config, float, "final_logit_softcapping", None)

# Normalize weights in head layer

self.norm_head = read(read_config, int, "norm_head", None)

# Positional embeddings

self.rotary_embedding_base = read(read_config, float, ["rope_theta", "attn_config->rope_theta"], 10000.0)
@@ -281,6 +295,18 @@ def prepare(self, no_tensors: bool = False):
self.alt_rope_method = "su"
# if scaling_type == "yarn":
# self.scale_alpha_value = factor
rope_type = rs.get("rope_type", None)
if rope_type == "llama3":
self.alt_rope_method = "llama3"
self.l3_rope_factor = rs["factor"]
self.l3_rope_low_freq_factor = rs["low_freq_factor"]
self.l3_rope_high_freq_factor = rs["high_freq_factor"]
self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

# Checkpoint format (for GPTQ models)

checkpoint_format = read(read_config, str, ["quantization_config->checkpoint_format"], None)
self.checkpoint_offset_qzeros = (checkpoint_format == "gptq_v2")

# Create map of model tensors

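For reference, a minimal standalone sketch of the widened read() helper as it appears after this change, with the nested-key ("->") lookup and int/float coercion from the original omitted; the sample config dict at the bottom is hypothetical.

from typing import Any, TypeVar, cast

T = TypeVar('T')
no_default = object()

def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:
    # Accept either a single expected type or a list of acceptable types
    expected_types = expected_type if isinstance(expected_type, list) else [expected_type]
    if isinstance(keys, str): keys = [keys]
    for key in keys:
        if key not in input_dict: continue
        x = input_dict[key]
        for t in expected_types:
            if isinstance(x, t):
                return cast(T, x)
        raise TypeError(f"Value for {key} is not of expected type {expected_type}")
    if default != no_default: return default
    raise ValueError(f"Missing any of the following keys: {keys}")

# Hypothetical config: "eos_token_id" may now be either a single int or a list of ints
cfg = {"eos_token_id": [128001, 128009]}
eos = read(cfg, [int, list], "eos_token_id", None)
if isinstance(eos, list): eos = eos[0]   # mirrors the fallback in prepare() above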
8 changes: 8 additions & 0 deletions exllamav2/conversion/convert_exl2.py
@@ -31,6 +31,7 @@
parser.add_argument("-ml", "--measurement_length", type = int, default = 2048, help = "Max no. tokens per sample when measuring")
parser.add_argument("-so", "--status_output", action = "store_true", help = "Include machine-parseable status updates in console output")
parser.add_argument("-hsol", "--hidden_state_offload_layers", type = int, default = 0, help = "Number of hidden/target states to keep in VRAM. Speed-up but increases VRAM usage")
parser.add_argument("-fst", "--fast_safetensors", action = "store_true", help = "Use fast-safetensors to load layers of the unquantized model. This can help alleviate some out-of-memory issues, especially on Windows.")

args = parser.parse_args()

@@ -112,6 +113,7 @@ def save_job():
"rope_scale": args.rope_scale,
"rope_alpha": args.rope_alpha,
"output_measurement": output_measurement,
"fast_safetensors": args.fast_safetensors,
"progress": "begin"}

if args.measurement is not None:
@@ -160,6 +162,8 @@ def save_job():
else:
print(f" -- Measurement will be saved to {job['output_measurement']}")
print(f" !! Conversion script will end after measurement pass")
if job.get("fast_safetensors"):
print(f" -- Enabled fast_safetensors option.")

if job['rope_scale']: print(f" -- RoPE scale: {job['rope_scale']:.2f}")
if job['rope_alpha']: print(f" -- RoPE alpha: {job['rope_alpha']:.2f}")
@@ -190,6 +194,10 @@ def save_job():

tokenizer = ExLlamaV2Tokenizer(config)

# Set fast_safetensors in config

if job.get("fast_safetensors"): config.fasttensors = True

# Set scaling for input model

if job["rope_scale"] is not None: config.scale_pos_emb = job["rope_scale"]
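As a hedged illustration of how the new flag propagates from the command line into the loader config (the simulated argv and the DummyConfig stand-in are for demonstration only, not the real ExLlamaV2Config):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-fst", "--fast_safetensors", action = "store_true",
                    help = "Use fast-safetensors to load layers of the unquantized model")
args = parser.parse_args(["-fst"])                  # simulate passing the flag

job = {"fast_safetensors": args.fast_safetensors}   # persisted in the resumable job state

class DummyConfig:                                  # illustrative stand-in for ExLlamaV2Config
    fasttensors = False

config = DummyConfig()
if job.get("fast_safetensors"):
    config.fasttensors = True                       # mirrors the assignment shown in the diff
print(config.fasttensors)                           # True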
6 changes: 5 additions & 1 deletion exllamav2/ext.py
@@ -320,7 +320,8 @@ def make_q_matrix(w: dict,
temp_dq: torch.Tensor,
key: str = None,
prescale: float = 1,
max_dq_rows = 0):
max_dq_rows = 0,
offset_qzeros: bool = False):

# EXL2

@@ -354,6 +355,9 @@ def make_q_matrix(w: dict,
if prescale != 1: w["scales"] *= prescale
if w["scales"].dtype == torch.float: w["scales"] = w["scales"].half()

if offset_qzeros:
w["qzeros"] -= 0b00010001000100010001000100010001

# GPTQ with g_idx (act_order)

if "g_idx" in w and not (w["g_idx"] == 0).all().item():
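The qzeros adjustment above shifts every packed 4-bit zero-point in a gptq_v2 checkpoint down by one, so the tensor matches the offset convention the existing kernels expect. A small standalone demonstration of the bit trick, assuming every packed nibble is at least 1 so no borrow crosses nibble boundaries:

def unpack_nibbles(word: int) -> list[int]:
    # qzeros packs eight 4-bit zero-points per 32-bit word
    return [(word >> (4 * i)) & 0xF for i in range(8)]

word = 0
for i, z in enumerate([3, 7, 1, 8, 2, 5, 4, 6]):      # hypothetical zero-points
    word |= z << (4 * i)

adjusted = word - 0b00010001000100010001000100010001  # subtract 1 from each nibble at once
print(unpack_nibbles(word))       # [3, 7, 1, 8, 2, 5, 4, 6]
print(unpack_nibbles(adjusted))   # [2, 6, 0, 7, 1, 4, 3, 5]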
58 changes: 51 additions & 7 deletions exllamav2/generator/dynamic.py
@@ -893,6 +893,11 @@ def iterate(self) -> list[dict]:
"stop_string"
"max_new_tokens"
"end_filter"
optional, if "eos_reason" == "stop_token":
"eos_triggering_token_id": int
"eos_triggering_token_str": str
optional, if "eos_reason" == "stop_string":
"eos_triggering_string": str
"full_completion": str - full text completion
"new_tokens": int - number of tokens generated
"time_enqueued": float - time from job was enqueued until it started, in seconds
@@ -1849,7 +1854,10 @@ def emit(
eos_reason: str = None,
emit_held = False,
suppressed_text = None,
suppressed_tokens = None
suppressed_tokens = None,
stop_token: int = None,
stop_string: str = None,
rem_held_text: str = None
):
r = {
"job": self,
@@ -1860,6 +1868,15 @@

if eos_reason is not None:
r.update({ "eos_reason": eos_reason })
if eos_reason == "stop_token":
id_to_piece = self.generator.tokenizer.get_id_to_piece_list(True)
r.update({
"eos_triggering_token_id": stop_token,
"eos_triggering_token_str": id_to_piece[stop_token]
})
pass
if eos_reason == "stop_string":
r.update({ "eos_triggering_string": stop_string })

if emit_held:
if self.held_text != "":
Expand Down Expand Up @@ -1903,18 +1920,29 @@ def emit(
"accepted_draft_tokens": self.accepted_draft_tokens,
"rejected_draft_tokens": self.rejected_draft_tokens
})
if eos_reason == "stop_string":
self.held_text = rem_held_text
rh = {}
if self.held_text:
rh.update({ "text": self.held_text })
if self.held_tokens:
rh.update({ "token_ids": self.held_tokens.torch().clone() })
if self.held_probs:
rh.update({ "token_probs": self.held_probs.torch().clone() })
if self.held_k_tokens:
rh.update({ "top_k_tokens": self.held_k_tokens.torch().clone() })
rh.update({ "top_k_probs": self.held_k_probs.torch().clone() })
if self.held_logits:
rh.update({ "logits": self.held_logits.torch().clone() })
if rh:
r.update({ "held": rh })

if self.identifier is not None:
r.update({ "identifier": self.identifier })

results.append(r)
return emit_eos, next_token

# End on stop tokens

if next_token.item() in self.stop_tokens:
return emit(results, emit_eos = True, eos_reason = "stop_token")

# Decode and buffer output

id_to_piece = self.generator.tokenizer.get_id_to_piece_list(self.decode_special_tokens)
@@ -1934,6 +1962,11 @@ def emit(
if self.return_logits:
self.held_logits.append(logits[:1, :, :])

# End on stop tokens

if next_token.item() in self.stop_tokens:
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())

# Stop if we reach max_new_tokens

if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens:
@@ -2032,8 +2065,19 @@ def rewind_checkpoint():
self.stop_strings_utf32_buffer
)
if match >= 0:
held = self.held_text[match:]
self.held_text = self.held_text[:match]
return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_string")
for s in self.stop_strings:
if held.startswith(s):
return emit(
results,
emit_eos = True,
emit_held = True,
eos_reason = "stop_string",
stop_string = s,
rem_held_text = held
)
assert False, "Detected stop string but couldn't identify it (logic error)"
if match == -2:
return emit(results)

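A hedged sketch of how a caller might consume the new end-of-stream metadata; it assumes an already-constructed ExLlamaV2DynamicGenerator with queued jobs, and the field names follow the docstring and emit() changes above:

# generator: ExLlamaV2DynamicGenerator with jobs already enqueued (setup elided)
while generator.num_remaining_jobs():
    for result in generator.iterate():
        if not result.get("eos"): continue
        reason = result.get("eos_reason")
        if reason == "stop_token":
            print("stopped on token",
                  result["eos_triggering_token_id"],
                  repr(result["eos_triggering_token_str"]))
        elif reason == "stop_string":
            print("stopped on string", repr(result["eos_triggering_string"]))
            held = result.get("held", {})   # text/tokens held back past the stop string
            if "text" in held:
                print("held text:", repr(held["text"]))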
6 changes: 6 additions & 0 deletions exllamav2/generator/dynamic_async.py
@@ -75,6 +75,7 @@ class ExLlamaV2DynamicJobAsync:
job: ExLlamaV2DynamicJob
queue: asyncio.Queue
generator: ExLlamaV2DynamicGeneratorAsync
cancelled: bool = False

def __init__(self, generator: ExLlamaV2DynamicGeneratorAsync, *args: object, **kwargs: object):
self.generator = generator
@@ -87,6 +88,10 @@ async def put_result(self, result):

async def __aiter__(self):
while True:
# Get out if the job is cancelled
if self.cancelled:
break

result = await self.queue.get()
if isinstance(result, Exception):
raise result
@@ -96,3 +101,4 @@ async def __aiter__(self):

async def cancel(self):
await self.generator.cancel(self)
self.cancelled = True
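A hedged sketch of cancelling a streaming job early using the new cancelled flag; it assumes the usual import path and an already-constructed ExLlamaV2DynamicGeneratorAsync, and generate_capped is a hypothetical helper:

from exllamav2.generator import ExLlamaV2DynamicJobAsync

async def generate_capped(generator, input_ids, max_chars = 200):
    job = ExLlamaV2DynamicJobAsync(generator, input_ids = input_ids, max_new_tokens = 512)
    text = ""
    async for result in job:
        text += result.get("text", "")
        if len(text) >= max_chars:
            # cancel() sets job.cancelled, so the async iterator exits on its next pass
            await job.cancel()
    return text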
19 changes: 16 additions & 3 deletions exllamav2/linear.py
@@ -54,7 +54,8 @@ def __init__(self,
f_beg: int = None,
f_end: int = None,
is_sub_module: bool = True,
altpack_qkv: bool = False):
altpack_qkv: bool = False,
normalize_unq: bool = False):
super().__init__(model, key)

self.is_sub_module = is_sub_module
@@ -89,20 +90,23 @@ def __init__(self,
self.altpack_qkv = altpack_qkv

self.assumed_footprint = in_features * (out_features + self.padding) * 2 + 128
self.normalize_unq = normalize_unq


@torch.inference_mode
def load(self,
w: dict | nn.Parameter | tuple | None = None,
device_tensors: bool = True):

cfg = self.model.config

if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
if w is None: w = self.load_weight()

# Load quantized linear layer from dictionary

if isinstance(w, dict):
assert not self.model.config.load_in_q4, "Can't load quantized layer in Q4 mode"
assert not cfg.load_in_q4, "Can't load quantized layer in Q4 mode"
if self.has_bias:
assert "bias" in w, self.key + " has no bias but bias expected"
else:
@@ -117,14 +121,17 @@ def load(self,
self.q_handle = ext.make_q_matrix(w,
self.temp_dq,
prescale = self.prescale,
max_dq_rows = self.model.config.max_dq_size // self.out_features)
max_dq_rows = cfg.max_dq_size // self.out_features,
offset_qzeros = cfg.checkpoint_offset_qzeros)
self.prev_prescale = self.prescale
self.prescale = 1

# Load FP16 linear layer without bias, optionally quantize to Q4

elif isinstance(w, nn.Parameter):
assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
if self.normalize_unq:
w = self.normalize(w)
if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
if not self.model.config.load_in_q4 or not ".layers." in self.key:
self.linear = nn.Linear(self.in_features, self.out_features, self.has_bias, device = "meta", dtype = torch.float16)
@@ -138,6 +145,8 @@

elif isinstance(w, tuple):
assert self.has_bias, self.key + " has bias tensor but bias is not expected"
if self.normalize_unq:
w = self.normalize(w[0]), w[1]
ww = w[0]
wb = w[1]
if self.padding > 0:
@@ -154,6 +163,10 @@
self.fp16_bias = wb


def normalize(self, w: torch.Tensor):
return nn.functional.normalize(w)


def matrix_shape(self):

return self.in_features, self.out_features
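The new normalize_unq path applies torch's default F.normalize, which L2-normalizes each row of the unquantized weight (dim = 1), presumably wired up for models that set the new norm_head config field (the architecture.py diff that would connect the two is collapsed above). A quick demonstration with a random stand-in weight:

import torch
import torch.nn as nn

w = torch.randn(8, 16)                    # hypothetical [vocab, hidden] head weight (fp32 for the demo)
w_normed = nn.functional.normalize(w)     # p = 2, dim = 1 by default
print(w_normed.norm(dim = 1))             # every row now has (approximately) unit norm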
2 changes: 2 additions & 0 deletions exllamav2/lora.py
@@ -81,6 +81,8 @@ def __init__(self,
f = load_file(self.lora_path, map_location = "cpu")

for key in f.keys():
if any(key.endswith(x) for x in [".original_module.weight", ".modules_to_save.weight"]):
continue
tensor = f[key]

# Find target
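The two skipped suffixes look like artifacts of PEFT's modules_to_save feature, which saves full copies of selected modules alongside the LoRA tensors; a tiny sketch of the filter with hypothetical tensor keys:

skip_suffixes = [".original_module.weight", ".modules_to_save.weight"]

keys = [                                                       # hypothetical safetensors keys
    "base_model.model.lm_head.original_module.weight",
    "base_model.model.lm_head.modules_to_save.weight",
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
]
kept = [k for k in keys if not any(k.endswith(s) for s in skip_suffixes)]
print(kept)   # only the lora_A tensor remains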