diff --git a/docs/examples/te_gemma/check_cuda_graphs.py b/docs/examples/te_gemma/check_cuda_graphs.py new file mode 100644 index 0000000000..aee35f6911 --- /dev/null +++ b/docs/examples/te_gemma/check_cuda_graphs.py @@ -0,0 +1,67 @@ +import torch +from transformer_engine.pytorch import Linear, LayerNorm + + +# 1. Define model with static buffers +class TE_Model(torch.nn.Module): + def __init__(self, max_seq_len=4096): + super().__init__() + self.max_seq_len = max_seq_len + self.ln = LayerNorm(1024) + self.attn_proj = Linear(1024, 1024) + + # Pre-allocate static buffers + self.register_buffer("kv_cache", torch.zeros(max_seq_len, 1024, device="cuda")) + self.register_buffer( + "attn_mask", torch.tril(torch.ones(max_seq_len, max_seq_len, device="cuda")) + ) + + def forward(self, hidden_states, seq_start: int): + # Dynamic slicing of static buffers + seq_len = hidden_states.size(1) + current_mask = self.attn_mask[seq_start : seq_start + seq_len, :seq_len] + + x = self.ln(hidden_states) + x = self.attn_proj(x) + # Update KV cache (in-place) + self.kv_cache[seq_start : seq_start + seq_len].copy_(x) + return x + + +# 2. Create graphable callables +model = TE_Model().cuda() +static_input = torch.randn(8, 256, 1024, device="cuda") # (batch, seq, hidden) +seq_start = torch.tensor(0, device="cuda") + +# Wrap with CUDA Graphs +graph_model = torch.cuda.make_graphed_callables( + [model], # Module list + sample_args=[(static_input, seq_start)], # Must match actual input structure + # memory_pool=torch.cuda.graphs.graph_pool_handle(), + allow_unused_input=False, +) + + +# 3. Warmup and execution +def run_inference(x, seq_start): + # Inputs must match sample_args' device/type/shape + x = x.to("cuda", non_blocking=True).requires_grad_(False) + seq_start = seq_start.to("cuda", non_blocking=True) + + with torch.cuda.amp.autocast(): + return graph_model(x, seq_start) + + +# Warm-up (essential for TE's kernel auto-tuner) +for _ in range(3): + _ = run_inference(static_input, seq_start) +torch.cuda.synchronize() + + +# 4. 
Usage with dynamic sequence lengths +def process_batch(inputs, start_pos): + # inputs: (batch, seq) on CPU + inputs_gpu = inputs.to("cuda", non_blocking=True) + + # Output shares memory with pre-allocated buffers + return run_inference(inputs_gpu, start_pos) diff --git a/docs/examples/te_gemma/check_gemm.py b/docs/examples/te_gemma/check_gemm.py new file mode 100755 index 0000000000..1ed6edd23a --- /dev/null +++ b/docs/examples/te_gemma/check_gemm.py @@ -0,0 +1,137 @@ +import functools +from typing import Optional, Tuple, Union, List +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.utils import assert_dim_for_fp8_exec +from transformer_engine.pytorch.module.base import get_workspace +import transformer_engine.pytorch.cpp_extensions as cpp_tex + + +@functools.lru_cache(maxsize=None) +def _empty_tensor() -> torch.Tensor: + """Get tensor with no entries and no data""" + return torch.Tensor() + + +def gemm( + A: torch.Tensor, + B: torch.Tensor, + dtype: torch.dtype, + workspace: torch.Tensor, + gelu: bool = False, + gelu_input: Optional[torch.Tensor] = None, + grad: bool = False, + accumulate: bool = False, + layout: str = "TN", + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + use_bias: bool = False, + ub_algo: tex.CommOverlapAlgo = None, + ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, + extra_output_tensor: torch.Tensor = None, +) -> Tuple[Union[torch.Tensor, None], ...]: + """Non FP8 GEMM.""" + + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." + transa = layout[0] == "T" + transb = layout[1] == "T" + empty_tensor = _empty_tensor() + fp8_index = -1 # dummy index + + if out is None: + out = torch.empty( + B.shape[1] if transb else B.shape[0], + A.shape[0] if transa else A.shape[1], + dtype=dtype, + device="cuda", + ) + else: + if not out.is_contiguous(): + raise ValueError("Output tensor is not contiguous.") + + if gelu and not grad: + gelu_input = torch.empty_like(out, dtype=dtype) + elif not gelu: + gelu_input = empty_tensor + + if grad and use_bias: + grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda") + else: + grad_bias = empty_tensor + + bias = bias if use_bias else empty_tensor + + assert ( + A.dtype == dtype and B.dtype == dtype + ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" + input_dtype = TE_DType[dtype] + output_dtype = TE_DType[out.dtype] + if use_bias: + bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype] + else: + bias_dtype = output_dtype + + args = ( + A, + empty_tensor, + fp8_index, + input_dtype, + transa, + B, + empty_tensor, + fp8_index, + input_dtype, + transb, + out, + empty_tensor, # out_scale + output_dtype, + empty_tensor, # out_amax + grad_bias if grad else bias, + bias_dtype, + gelu_input, + grad, + workspace, + workspace.shape[0], + accumulate, + False, # use_split_accumulator + ) + fn = torch.ops.tex_ts.te_gemm_ts + if ub_algo is not None: + assert ub is not None, "ub object is None!" 
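+    # The positional tuple built above mirrors the argument order this te_gemm_ts
+    # call expects: A with its scale-inverse placeholder, FP8 index, dtype and
+    # transpose flag; B with the same fields; the output tensor plus its scale/amax
+    # placeholders and dtype; bias (or grad_bias) and its dtype; the GELU input
+    # buffer; the grad flag; the cuBLAS workspace and its size; and the
+    # accumulate / use_split_accumulator flags.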
+ _ = fn(*args) + + import pdb + + pdb.set_trace() + return out, grad_bias, gelu_input + + +if __name__ == "__main__": + fc2_weight = torch.load("fc2_weight.pth").cuda() + + base_repo = "/perfhome/mnt/wkstn/work/repos/te_gemma_gen_support/TransformerEngine/docs/examples/te_gemma/" + base_repo = "" + gelu_out = torch.load(base_repo + "gelu_out.pth").cuda() + + activation_dtype = torch.bfloat16 + fc2_bias = _empty_tensor() + use_fc2_bias = False + + dim_size = list(gelu_out.size()) + dim_size[1] = fc2_weight.size(0) + fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + + _ = cpp_tex.gemm( + fc2_weight, + gelu_out, + activation_dtype, + get_workspace(), + bias=fc2_bias, + use_bias=use_fc2_bias, + out=fc2_out, + ub_algo=None, + ub=None, + extra_output_tensor=None, + ) diff --git a/docs/examples/te_gemma/check_rope.ipynb b/docs/examples/te_gemma/check_rope.ipynb new file mode 100755 index 0000000000..26d5c9058f --- /dev/null +++ b/docs/examples/te_gemma/check_rope.ipynb @@ -0,0 +1,716 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "72f61b51-b6fc-4463-9783-d42a25ca3a2f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "before tex import\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "import math\n", + "print(\"before tex import\")\n", + "import transformer_engine as te\n", + "import transformer_engine_torch as tex" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1f81be75-bf64-43b2-852a-7c482a1c3418", + "metadata": {}, + "outputs": [], + "source": [ + "from transformer_engine.pytorch.attention import apply_rotary_pos_emb" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8853f973-d834-41a9-929d-8687b947134f", + "metadata": {}, + "outputs": [], + "source": [ + "def compare_rope_outputs(t, freqs_s11d, freqs_sb1d):\n", + " output1 = tex.fused_rope_forward(t, freqs_s11d, torch.Tensor(), False)\n", + " output2 = tex.fused_rope_forward(t, freqs_sb1d, torch.Tensor(), False)\n", + " print(output1, output2, sep=\"\\n\")\n", + " assert torch.allclose(output1, output2)\n", + " return output1, output2" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6b7bada1-6748-46f1-93a4-c2ac1a617063", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\n", + "b = 2\n", + "s = 3\n", + "h = 2\n", + "d = 4" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "54a8f6d6-28f8-4a9a-8ba0-0fdefff138e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([3, 1, 1, 4]) torch.Size([3, 2, 1, 4])\n" + ] + } + ], + "source": [ + "freqs_s11d = torch.ones(s, 1, 1, d).cuda() * math.pi/4\n", + "freqs_sb1d = freqs_s11d.broadcast_to(s, b, 1, d).clone()\n", + "t = torch.ones(s, b, h, d).cuda()\n", + "\n", + "print(freqs_s11d.shape, freqs_sb1d.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5070307a-3104-401b-b84c-00f3bbf02ccc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[0.7854, 0.7854, 0.7854, 0.7854]]],\n", + "\n", + "\n", + " [[[0.7854, 0.7854, 0.7854, 0.7854]]],\n", + "\n", + "\n", + " [[[0.7854, 0.7854, 0.7854, 0.7854]]]], device='cuda:0')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "freqs_s11d" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "81e52785-e6ad-4180-9567-564af692375c", 
+ "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(4, 4, 4, 1)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "freqs_s11d.stride()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0da9bc09-7e1e-4056-85eb-64b6122c7440", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4, 0\n", + "4, 4, 4, 1, \n", + "nvt_fused_rope_fwd: 4, 0fused_rope_fwd: 4, 0fused_rope_fwd_launcher: 4, 0thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, 
freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+    "    [... repeated per-thread debug lines elided; every (thread_id, s_id, b_id) combination reports freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107 ...]\n",
+    "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 
0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n" + ] + } + ], + "source": [ + "output = tex.fused_rope_forward(t, freqs_s11d, torch.Tensor(), False)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b78017d-09b3-4b5f-93a8-75f6ba6f131c", + "metadata": {}, + "outputs": [], + "source": [ + "output_unfused=apply_rotary_pos_emb(\n", + " t,\n", + " freqs_s11d,\n", + " tensor_format=\"sbhd\",\n", + " fused=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6f5d9350-deb1-48ef-a0a2-e18e01ed336f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]]],\n", + " device='cuda:0')" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_unfused" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b01e29b8-dfdf-41ac-81a5-d8edf6a8c168", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4, 0\n", + "4, 4, 4, 1, \n", + 
"nvt_fused_rope_fwd: 4, 0fused_rope_fwd: 4, 0fused_rope_fwd_launcher: 4, 08, 4\n", + "8, 4, 4, 1, \n", + "nvt_fused_rope_fwd: 8, 4fused_rope_fwd: 8, 4fused_rope_fwd_launcher: 8, 4thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 0, 
b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+    "    [... repeated per-thread debug lines elided; the first call (freqs_s11d) reports freq_stride_s: 4, freq_stride_b: 0 for every thread, the second call (freqs_sb1d) reports freq_stride_s: 8, freq_stride_b: 4; all threads print freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107 ...]\n",
+    "thread_id: 2, s_id: 1, b_id: 1, 
freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n", + "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]]],\n", + " device='cuda:0')\n", + "tensor([[[[-5.9605e-08, 
-5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]]],\n", + " device='cuda:0')\n" + ] + } + ], + "source": [ + "output1, output2 = compare_rope_outputs(t, freqs_s11d, freqs_sb1d)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b168b178-1f63-4ccc-b084-2ac2c1ec016b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([6, 1, 1, 4]) torch.Size([6, 2, 1, 4])\n" + ] + } + ], + "source": [ + "freqs_s11d = torch.randn(s, 1, 1, d).cuda()\n", + "freqs_sb1d = freqs_s11d.broadcast_to(s, b, 1, d).clone()\n", + "\n", + "print(freqs_s11d.shape, freqs_sb1d.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "33ec2e07-6e54-49f7-92f7-2f217a766456", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.0000e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.0000e+00]]]],\n", + " device='cuda:0')\n", + "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n", + "\n", + " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n", + "\n", + "\n", + " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n", + " [ 7.0711e-01, 
7.0711e-01, 7.0711e-01, 7.0711e-01]],\n", + "\n", + " [[ 7.0711e-01, 7.0711e-01, 7.0711e-01, 7.0711e-01],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]]],\n", + " device='cuda:0')\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m output1, output2 \u001b[38;5;241m=\u001b[39m \u001b[43mcompare_rope_outputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfreqs_s11d\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfreqs_sb1d\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[8], line 5\u001b[0m, in \u001b[0;36mcompare_rope_outputs\u001b[0;34m(t, freqs_s11d, freqs_sb1d)\u001b[0m\n\u001b[1;32m 3\u001b[0m output2 \u001b[38;5;241m=\u001b[39m tex\u001b[38;5;241m.\u001b[39mfused_rope_forward(t, freqs_sb1d, torch\u001b[38;5;241m.\u001b[39mTensor(), \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(output1, output2, sep\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(output1, output2)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output1, output2\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "output1, output2 = compare_rope_outputs(t, freqs_s11d, freqs_sb1d)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b58b818-7b31-4ecd-80bd-b5ba049b3c2e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "before tex import\n" + ] + } + ], + "source": [ + "freqs_s11d = torch.randn(s, 1, 1, d).cuda()\n", + "print(freqs_s11d)\n", + "freqs_sb1d = freqs_s11d.broadcast_to(s, b, 1, d).clone()\n", + "print(freqs_sb1d)\n", + "assert torch.all(torch.eq(freqs_sb1d[:, 0, ...], freqs_sb1d[:, 1, ...]))\n", + "\n", + "comp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c04940b8-3056-466b-90f6-07a02ac47ace", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/media/calibration.svg b/docs/examples/te_gemma/media/calibration.svg new file mode 100755 index 0000000000..b1e1b5ae4b --- /dev/null +++ b/docs/examples/te_gemma/media/calibration.svg @@ -0,0 +1 @@ +FP8 with initial scaling factorsHighprecisionweightInitialFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMMWeight calibrationHighprecisionweightFP8 scalingfactorsHighprecisioninputHighprecisionGEMMFP8 with calibrated scaling factorsHighprecisionweightCalibratedFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMM \ No newline at end of file diff --git a/docs/examples/te_gemma/media/calibration_1_half.svg 
b/docs/examples/te_gemma/media/calibration_1_half.svg new file mode 100755 index 0000000000..af2641387f --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_1_half.svg @@ -0,0 +1 @@ +HighprecisionweightInitialFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMMHighprecisionweightFP8 scalingfactorsHighprecisioninputHighprecisionGEMMFP8 with initial scaling factorsWeight calibration \ No newline at end of file diff --git a/docs/examples/te_gemma/media/calibration_2_half.svg b/docs/examples/te_gemma/media/calibration_2_half.svg new file mode 100755 index 0000000000..2d56f7d434 --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_2_half.svg @@ -0,0 +1 @@ +Weight calibrationHighprecisionweightFP8 scalingfactorsHighprecisioninputHighprecisionGEMMFP8 with calibrated scaling factorsHighprecisionweightCalibratedFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMM \ No newline at end of file diff --git a/docs/examples/te_gemma/media/fp8_model_init.svg b/docs/examples/te_gemma/media/fp8_model_init.svg new file mode 100755 index 0000000000..c7fce2120d --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init.svg @@ -0,0 +1 @@ +FP32/BF16FP8FP8 with fp8_model_init()FP8weightFP8GEMMHighprecisionweightHighprecisioninputHighprecisionGEMMHighprecisionweightFP8WeightFP8inputFP8GEMMFP8input \ No newline at end of file diff --git a/docs/examples/te_gemma/media/fp8_model_init_1_half.svg b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg new file mode 100755 index 0000000000..3b217a3eb2 --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg @@ -0,0 +1 @@ +FP32/BF16HighprecisionweightHighprecisioninputHighprecisionGEMMHighprecisionweightFP8WeightFP8inputFP8GEMMFP8 \ No newline at end of file diff --git a/docs/examples/te_gemma/media/fp8_model_init_2_half.svg b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg new file mode 100755 index 0000000000..46587664fe --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg @@ -0,0 +1 @@ +FP8FP8 with fp8_model_init()FP8weightFP8GEMMHighprecisionweightFP8WeightFP8inputFP8GEMMFP8input \ No newline at end of file diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif new file mode 100755 index 0000000000..25150cb9b6 Binary files /dev/null and b/docs/examples/te_gemma/media/generation_animation.gif differ diff --git a/docs/examples/te_gemma/media/graphs.svg b/docs/examples/te_gemma/media/graphs.svg new file mode 100755 index 0000000000..f734637e6d --- /dev/null +++ b/docs/examples/te_gemma/media/graphs.svg @@ -0,0 +1 @@ +Without CUDA GraphsWith CUDA GraphsLaunch 1Kernel 1Launch 2Kernel 2Launch 3Kernel 3Launch Graph 1Kernel 1Kernel 2Kernel 3 \ No newline at end of file diff --git a/docs/examples/te_gemma/media/graphs_1.png b/docs/examples/te_gemma/media/graphs_1.png new file mode 100755 index 0000000000..f42b50fe0d Binary files /dev/null and b/docs/examples/te_gemma/media/graphs_1.png differ diff --git a/docs/examples/te_gemma/media/graphs_2.png b/docs/examples/te_gemma/media/graphs_2.png new file mode 100755 index 0000000000..35c34ede55 Binary files /dev/null and b/docs/examples/te_gemma/media/graphs_2.png differ diff --git a/docs/examples/te_gemma/media/plot.svg b/docs/examples/te_gemma/media/plot.svg new file mode 100755 index 0000000000..481f156df6 --- /dev/null +++ b/docs/examples/te_gemma/media/plot.svg @@ -0,0 +1 @@ +87.68 s54.11 s28.22 s16.75 s12.13 s0 s10 s20 s30 s40 s50 s60 s70 s80 s90 s100 sHF (baseline)TE 
(subsitution ofGemmaDecoderLayer withte.TransformerLayer)TE + THD attentionTE + THD attention + CUDA GraphsTE + THD attention + FP8 \ No newline at end of file diff --git a/docs/examples/te_gemma/media/thd_bshd.svg b/docs/examples/te_gemma/media/thd_bshd.svg new file mode 100755 index 0000000000..47eed69565 --- /dev/null +++ b/docs/examples/te_gemma/media/thd_bshd.svg @@ -0,0 +1 @@ +BSHD LayoutQKVQKVCumulative sequence lengths:3, 3 + 1, 3 + 1 + 3, 3 + 1 + 3 + 1Sequence offsets:0, 4, 8, 12[batch_size,seq_len,head_nr,dim][total_nr_tokens,head_nr,dim]Seq. 1Seq. 2Seq. 4Seq. 3sbtTHD LayoutPad. 1Pad. 2Pad. 4Pad. 3Attention masktokenpadding \ No newline at end of file diff --git a/docs/examples/te_gemma/requirements.txt b/docs/examples/te_gemma/requirements.txt new file mode 100755 index 0000000000..c90fb6dad0 --- /dev/null +++ b/docs/examples/te_gemma/requirements.txt @@ -0,0 +1,4 @@ +transformers==4.41.1 +accelerate==0.30.1 +datasets==2.19.1 +sentencepiece==0.2.0 \ No newline at end of file diff --git a/docs/examples/te_gemma/run_gemma_2b.py b/docs/examples/te_gemma/run_gemma_2b.py new file mode 100644 index 0000000000..db2fb087c9 --- /dev/null +++ b/docs/examples/te_gemma/run_gemma_2b.py @@ -0,0 +1,15 @@ +from transformers import AutoTokenizer, AutoModelForCausalLM +from huggingface_hub import login + +access_token = "" +login(access_token) + +model_name = "google/gemma-3-4b-it" +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained(model_name) +print(model.config) +input_text = "Write me a poem about Machine Learning." +input_ids = tokenizer(input_text, return_tensors="pt") + +outputs = model.generate(**input_ids) +print(tokenizer.decode(outputs[0])) diff --git a/docs/examples/te_gemma/run_generation.py b/docs/examples/te_gemma/run_generation.py new file mode 100755 index 0000000000..910fa325d0 --- /dev/null +++ b/docs/examples/te_gemma/run_generation.py @@ -0,0 +1,55 @@ +from utils import * +import transformer_engine.pytorch as te + +hyperparams.model_name = ( # "/tmp/gemma-7b-hf/" # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights" + "/perfhome/repos/ckpt/models/gemma-7b-hf/" +) +hyperparams.qkv_format = "thd" + +run_generation = True +run_calibration = False + +if run_calibration: + hyperparams.fuse_qkv_params = True # This is needed by the last improvement. + + model = init_te_gemma_model(hyperparams) + + # Calibration + with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast( + device_type="cuda", dtype=torch.bfloat16 + ): + model.train() + run_forward_pass(model, hyperparams, num_iters=512) + + # Compute scale_fwd with enabled fp8 autocast + with te.fp8_autocast(enabled=True), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + run_forward_pass(model, hyperparams, 1) + + # Some parameters are in pointing to the same tensors, double save is avoided here. + dict_to_save = { + k: v + for k, v in model.state_dict().items() + if ("_context_phase" not in k and "_generation_phase" not in k) + } + torch.save(dict_to_save, "calibrated_weights.pth") # <== Add path to save calibrated weights. + + +if run_generation: + + # hyperparams.generation_cuda_graphs = False # 4.15s + hyperparams.generation_cuda_graphs = True # 4.38s + + if hyperparams.generation_cuda_graphs: + # It is necessary to preallocate a static buffer. + # CUDA graphs require static input tensors for every kernel. 
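The static-tensor requirement mentioned in the comment above is the crux of using CUDA Graphs here: a captured graph always reads from and writes to the same memory addresses, so fresh data has to be copied into pre-allocated buffers rather than bound to new tensors. A minimal, self-contained illustration of the difference (toy tensors, not part of the tutorial files):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
buf = torch.zeros(4, 8, device=device)        # pre-allocated "static" buffer
new_data = torch.randn(4, 8, device=device)

# Rebinding a Python name allocates fresh memory; a captured graph would keep
# reading the old storage and never see this result.
rebound = new_data * 2.0
print(rebound.data_ptr() == buf.data_ptr())   # False

# An in-place copy keeps the storage address fixed, which is what the
# "static copy - for CUDA graphs" comments in te_gemma.py rely on.
ptr_before = buf.data_ptr()
buf.copy_(new_data * 2.0)
print(buf.data_ptr() == ptr_before)           # True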
+ # This approach may result in a slight increase in memory consumption; + # however, the substantial speedup achieved makes it worthwhile. + hyperparams.cuda_graphs_static_batch_size = 64 + hyperparams.cuda_graphs_static_max_seq_len = 128 + hyperparams.cuda_graphs_static_max_context_len = 128 + + hyperparams.is_paged = False + model = init_te_gemma_model(hyperparams) + + print_sample_of_generated_texts(model) + benchmark_generation(model) diff --git a/docs/examples/te_gemma/run_generation_llama.py b/docs/examples/te_gemma/run_generation_llama.py new file mode 100755 index 0000000000..1c3e6626ca --- /dev/null +++ b/docs/examples/te_gemma/run_generation_llama.py @@ -0,0 +1,12 @@ +from utils import * + +hyperparams.model_name = ( # "/tmp/gemma-7b-hf/" # <== Add model weight location here e.g. "/path/to/downloaded/gemma/weights" + "/perfhome/repos/ckpt/models/llama2-7b-hf/" +) +hyperparams.qkv_format = "thd" + +# model = init_te_llama_model(hyperparams) +model = init_baseline_model(hyperparams) + +print_sample_of_generated_texts(model) +# benchmark_generation(model) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py new file mode 100755 index 0000000000..706ea16bc4 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma.py @@ -0,0 +1,594 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import contextmanager + +from typing import Optional +from functools import partial +from collections import OrderedDict + +import torch +import transformer_engine as te +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding +from transformer_engine.common.recipe import Format, DelayedScaling +from torch.cuda.amp import autocast + +import transformers +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel + +import torch.nn.functional as F + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. + + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + + def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs): + + self.gemma_config = config + + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="geglu", + # attn_input_format=config.qkv_format, + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + kv_channels=self.gemma_config.head_dim, + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. + zero_centered_gamma=True, + ) + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. + + # this args cannot be passed to TransformerLayer + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in keys_to_remove: + kwargs.pop(key, None) + + rope_emb = kwargs.pop("rope_emb", None) + # We need to return tuple to be compatible with HF. 
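Looking back at the calibration branch in run_generation.py above, the pattern is: run ordinary high-precision forward passes with calibrating=True so TE records amax statistics, run a single pass with FP8 enabled so the calibrated scaling factors are materialized, then save the state dict. A minimal sketch of that flow on a toy te.Linear (the layer, data and file name are illustrative, and step 2 assumes an FP8-capable GPU):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
data = [torch.randn(16, 1024, device="cuda") for _ in range(8)]

# 1. High-precision passes that only collect amax statistics.
with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast("cuda", dtype=torch.bfloat16):
    for x in data:
        layer(x)

# 2. One FP8 pass so calibrated scaling factors are computed and stored.
with te.fp8_autocast(enabled=True), torch.autocast("cuda", dtype=torch.bfloat16):
    layer(data[0])

# 3. The state dict now carries FP8 metadata alongside the high-precision weights.
torch.save(layer.state_dict(), "calibrated_toy_layer.pth")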
+ return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),) + + +class StaticGemmaModel(torch.nn.Module): + """ + StaticGemma is based of HF GemmaModel class. + It is adjusted to work properly with CUDA Graphs. + """ + + def __init__( + self, + model: GemmaModel, + dtype: torch.dtype, + mask: torch.Tensor, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.mask = mask + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + # @sudhakars: is `arbitrary` fine being the default here? + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + # print(f"StaticGemmaModel forward start") + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + + for i, decoder_layer in enumerate(self.model.layers): + # print(f"layer {i}") + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type, + inference_params=self.inference_params, + rope_emb=rope_emb, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits, hidden_states + + +class GemmaGenerator(torch.nn.Module): + """ + GemmaGenerator gets one layer of embeddins, + makes forward pass and returns next tokens. + """ + + def __init__( + self, model: GemmaModel, lm_head: torch.nn.Module, dtype: torch.dtype, qkv_format: str + ): + super().__init__() + self.model = model + self.gemma_layers = StaticGemmaModel(model, dtype, "arbitrary", lm_head) + self.qkv_format = qkv_format + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) + + # @sudhakars: is `arbitrary` a good default value here? + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + logits, _ = self.gemma_layers( + hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb + ) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + # logits.shape[2] = number of tokens + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + return next_tokens + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`. + """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. 
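GemmaGenerator above reduces per-token decoding to: take the logits at the last position, pick the argmax, re-embed the chosen token, and write it back into the static hidden-states buffer in place. A stripped-down sketch of that single step with toy modules and shapes (stand-ins, not the tutorial's buffers):

import torch

b, hidden, vocab = 2, 16, 32
embed = torch.nn.Embedding(vocab, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

hidden_states = torch.randn(b, 1, hidden)      # plays the role of the static [b, 1, hd] buffer
logits = lm_head(hidden_states)[:, -1, :]      # logits of the last position only
next_tokens = torch.argmax(logits, dim=1)      # greedy choice, shape [b]

# Write the embedding of the chosen token back into the same buffer (in place),
# so the next iteration of the graphed generation phase reads the new input.
hidden_states.copy_(embed(next_tokens).unsqueeze(1))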
+ + Args: + config: GemmaConfig + """ + + def __init__(self, config: GemmaConfig): + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + super().__init__(config) + self.config = config + self.to(torch.bfloat16).cuda() + self.hidden_size = config.hidden_size + self._model_generation_phase = GemmaGenerator( + lm_head=self.lm_head, + model=self.model, + dtype=torch.bfloat16, + qkv_format=config.qkv_format, + ) + self._model_context_phase = StaticGemmaModel( + self.model, torch.bfloat16, "arbitrary", self.lm_head + ) + + if self.config.fp8: + self.fp8_recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max" + ) + + self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)( + max_seq_len=self.config.max_position_embeddings + ).cuda() + + @staticmethod + def _padding_to_end(inputs, lengths, max_seq_len=None): + """ + Gets the tensor with sequence padded from the beginning and + return tensor padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + + # Disable the input preparation that involves extra padding + # inputs.copy_(new_input_ids) + + # Trim the inputs to no extra padding i.e. fix the max seq len to + # the longest sequence in the batch + actual_max_seq_len = max_seq_len + inputs.data = new_input_ids[:, :actual_max_seq_len] + # print(f"actual_max_seq_len: {actual_max_seq_len}") + + # For Paged Attention, make the valid sequences, multiple of 64 + # inputs.data = new_input_ids[:, :4].repeat(1, 16) + # import pdb; pdb.set_trace() + # print(f"inputs.data.shape: {inputs.data.shape}") + # exit() + + def _next_64_multiply(self, x): + return ((x + 63) // 64) * 64 + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_hidden_states_buffer(self, input_ids: torch.Tensor): + tensor = torch.empty( + (input_ids.shape[0], input_ids.shape[1], self.hidden_size), + device="cuda", + dtype=torch.float32, + ) + # import pdb; pdb.set_trace() + return tensor + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_inference_params(self, *args, **kwargs): + infer_params = InferenceParams(*args, **kwargs) + return infer_params + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _get_max_input_seq_len(self, input_ids): + return ( + input_ids.shape[1] + if not hasattr(self.config, "cuda_graphs_static_max_context_len") + else self.config.cuda_graphs_static_max_context_len + ) + + # The buffer for generation is some part (beginning) of hidden states buffer. + # This function returns pointer to it and also copies there data if provided. + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)" + # will return uncontiguous buffer, which we want to avoid. 
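The comment above is easy to miss, so here is a quick standalone check of why the flat view used just below is preferred: slicing position 0 of a [b, s, hd] tensor gives a non-contiguous view, while viewing the first b * hd elements of the flat storage gives a contiguous [b, 1, hd] buffer that aliases the start of the big allocation (small shapes for illustration):

import torch

b, s, hd = 2, 3, 4
hidden = torch.arange(b * s * hd, dtype=torch.float32).reshape(b, s, hd)

sliced = hidden[:, 0, :].unsqueeze(1)            # [b, 1, hd] but strided
print(sliced.is_contiguous())                    # False - unsuitable as a static CUDA-graph buffer

flat = hidden.view(-1)[: b * hd].view(b, 1, hd)  # first b*hd elements of the same storage
print(flat.is_contiguous())                      # True
print(flat.data_ptr() == hidden.data_ptr())      # True - it reuses the start of the big buffer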
+ output = hidden_states_buffer.view(-1)[ + : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2] + ] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view( + (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2]) + ) + return generation_buffer + + def _generate_context_phase(self, input_ids: torch.Tensor, inference_params: InferenceParams): + # import pdb; pdb.set_trace() + hidden_states = self._create_hidden_states_buffer(input_ids) + hidden_states.copy_(self.model.embed_tokens(input_ids)) + + # We need to update offsets before every forward pass to make cache work properly. + lengths = input_ids.ne(0).sum(dim=1) + + # import pdb; pdb.set_trace() + if self.config.qkv_format == "thd": + # inference_params.setup_before_new_input( + # lengths_tensor=lengths, max_input_length=input_ids.shape[1] + # ) + lengths = input_ids.ne(0).sum(dim=1) + inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))) + else: + inference_params.setup_before_new_input(length=input_ids.shape[1]) + + logits, hs_buffer = self._model_context_phase( + hidden_states, + attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None), + attn_mask_type="padding_causal" if self.config.qkv_format == "thd" else "arbitrary", + rope_emb=self.te_rope_emb, + ) + + if self.config.qkv_format == "thd": + logits = logits[torch.arange(logits.size(0)), lengths - 1, :] + else: + logits = logits[:, -1, :] + + next_tokens = torch.argmax(logits, dim=1) + + # self.hidden_states have shape [b, s, hd]. + # We return hidden state for the last token - output has shape [b, 1, hd] + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + def _make_mask_one_token_longer(self, mask): + return torch.cat( + [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1 + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs + ): + self.eval() + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. + with autocast(dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] + + # print(f"max_input_sequence_len: {max_input_sequence_len}") + # exit() + + if self.config.qkv_format == "thd": + # For thd layout padding is at the end, otherwise at the beginning. + TEGemmaForCausalLM._padding_to_end( + input_ids, + lengths, + max_seq_len=( + self.config.cuda_graphs_static_max_context_len + if self.config.generation_cuda_graphs + else None + ), + ) + + batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len( + input_ids + ) + + # InferenceParams is a cache, where keys and values of previous tokens are stored. + # Moreover it stores length of both already generated and input sequences. 
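Before the cache is created below, it may help to spell out the bookkeeping that pre_step expects and how the paged-cache size is derived: each step passes "how many new tokens arrive per sequence" as an OrderedDict, and with is_paged=True the cache needs enough pages for batch_size * max_sequence_length tokens at page_size tokens per page. A small helper sketch (the helper names are illustrative; only the OrderedDict format and the page arithmetic come from the code below):

from collections import OrderedDict
import torch

def step_lengths(lengths: torch.Tensor) -> OrderedDict:
    # Maps sequence index -> number of tokens fed in this step,
    # the format that inference_params.pre_step() consumes.
    return OrderedDict(zip(range(len(lengths)), lengths.tolist()))

def total_pages(batch_size: int, max_seq_len: int, page_size: int = 64) -> int:
    # e.g. batch 64, max sequence length 128, page size 64 -> 128 pages
    return batch_size * max_seq_len // page_size

prompt_lengths = torch.tensor([7, 3, 5])
print(step_lengths(prompt_lengths))                  # context phase: real prompt lengths
print(step_lengths(torch.ones(3, dtype=torch.int)))  # generation phase: one new token per sequence
print(total_pages(64, 128))                          # 128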
+ inference_params = self._create_inference_params( + max_batch_size=batch_size, + # num_layers=self.config.num_hidden_layers, + max_sequence_length=128, + num_heads_kv=self.config.num_key_value_heads, + # num_heads_q=self.config.num_attention_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=64, + total_num_pages=64 * 128 // 64, # 64 * 64 (max_sequence_length) / 64 (page_size) + ) + + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + # print(f"context phase start") + # import pdb; pdb.set_trace() + hidden_states, next_tokens = self._generate_context_phase(input_ids, inference_params) + + # print(f"context phase done") + # Generation phase. + if self.config.qkv_format == "thd": + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + else: + inference_params.setup_before_new_input(length=1) + + output_tokens = [next_tokens] + + mask = None + if self.config.qkv_format != "thd": + mask = (input_ids == 0).unsqueeze(1).unsqueeze(1) + + for _ in range(max_new_tokens): + if self.config.qkv_format != "thd": + # It will not work with cuda graphs, but it is not used for thd qkv_format. + # Attention mask in bshd needs attn_mask increased by 1 to + # include the next token to be generated + mask = self._make_mask_one_token_longer(mask) + + next_tokens = self._model_generation_phase( + hidden_states, + mask=mask, + attn_mask_type="padding" if self.config.qkv_format == "thd" else "arbitrary", + rope_emb=self.te_rope_emb, + ) + + # self.inference_params contains for example kv_cache. + # This needs to be called before every pass, + # to update the information of sequence lengths. + # Here we increase sequence offsets by one, + # because we generated one token for every sequence. + if self.config.qkv_format == "thd": + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + else: + inference_params.setup_before_new_input(length=1) + # next_tokens is static output tensor, so we need to clone it + # - it gets changed every iteration. + output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + def forward(self, *args, **kwargs): + self._model_context_phase.set_inference_params(None) + hidden_states = self.model.embed_tokens(kwargs["input_ids"]) + logits = self._model_context_phase( + hidden_states, + attention_mask=( + (kwargs["input_ids"] == 0) if self.config.qkv_format != "thd" else None + ), + attn_mask_type="arbitrary", + ) + return logits + + +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): + """ + TEGemmaForCausalLMCudaGraphs is the version of the class TEGemmaForCausalLM + using CUDA Graphs to speed it up. We need to make one trade-off. + Namely, batch_size, max_seq_len and max_context_seq_len need to be static. + It is necessary to run generation with the same value of + these variables that we recorded graph on. + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + assert ( + config.qkv_format == "thd" + ), "Generation with CUDA Graphs are implemented only for thd format." + + # Preparation of the static buffers. 
+ self.config = config + self.hidden_states_buffer = torch.empty( + ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + self.config.hidden_size, + ) + ).cuda() + + # This is in fact part of the buffer for hidden_states. + self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer) + self.inference_params = InferenceParams( + max_batch_size=self.config.cuda_graphs_static_batch_size, + # num_layers=self.config.num_hidden_layers, + max_sequence_length=self.config.cuda_graphs_static_max_seq_len, + num_heads_kv=self.config.num_key_value_heads, + # num_heads_q=self.config.num_attention_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=64, + total_num_pages=64 + * self.config.cuda_graphs_static_max_seq_len + // 64, # 64 * 64 (max_sequence_length) / 64 (page_size) + ) + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) + + def record(self): + # We want to record model in training=False, because it will be used in generation. + self.eval() + + # Here "the trick" happens. We override methods from TEGemmaForCausalLM + # with their recorded version. After invocation of each of them, + # captured graph will be replayed with minimal usage of CPU, + # what will lead to huge speedup. + input_shape = ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + ) + + # [1] Should be same as lengths_tensor from TEGemmaForCausalLM + lengths = torch.tensor(input_shape[0] * [input_shape[1]], device="cuda", dtype=torch.int32) + max_input_length = input_shape[1] + + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # print(f"context phase recording start") + + self._model_context_phase = self.record_graph( + self._model_context_phase, + self.hidden_states_buffer, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) # CUDA Graphs recording + + # print(f"context phase recording done") + input_shape = (self.config.cuda_graphs_static_batch_size, 1) + + lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32) + + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + self._model_generation_phase = self.record_graph( + self._model_generation_phase, + self.generation_buffer, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) # CUDA Graphs recording + + """ + Functions _create_hidden_states_buffer and _create_inference_params + from base class are overriden to make hidden_states and inference_params static + - not changing their position in memory between every invocation. + """ + + def _create_hidden_states_buffer(self, *args, **kwargs): + return self.hidden_states_buffer + + def _create_inference_params(self, *args, **kwargs): + self.inference_params.reset() + return self.inference_params + + def _get_max_input_seq_len(self, _): + return self.config.cuda_graphs_static_max_context_len + + @torch.no_grad() + def record_graph(self, function, input_tensor, **sample_kwargs): + # function is invoked on argument (self.hidden_states,) and all kernels are recorded. + # record_graph() returns captured function, which can be run later with lower of th CPU. 
+ fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling( + fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max" + ) + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. + with autocast(dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=self.config.fp8, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=5, + sample_kwargs=sample_kwargs, + ) + return graphed_function diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py new file mode 100755 index 0000000000..41f62ad7f3 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import re +import gc +import torch + +from typing import List + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import load_state_dict, _load_state_dict_into_model +from transformers.utils.hub import get_checkpoint_shard_files + +""" + This file contains logic of mapping the HuggingFace GemmaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" + + +def _load_weights_for_fp8_model(vanilla_model, hyperparams): + # The weights are loaded from the file with state_dict + # of model with weights which contains also fp8 parameters. + # The weights are in BF16 precision, but they contain fp8 metadata + # computed by the calibration procedure. + vanilla_model.load_state_dict( + torch.load(hyperparams.fp8_model_weights_filename), + strict=False, + # strict = false, because some parameters have + # multiple pointers to the same weight + # vanilla_model._model_context_phase.model + # and vanilla_model._model_generation_phase.model + ) + + +def _load_weights_for_standard_model(vanilla_model, config): + # The weights are loaded from the file with original weights. + archive_file = os.path.join(config.model_name, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.model_name, archive_file) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict.update(state_dict) + + replace_params( + total_dict, + vanilla_model.state_dict(), + config, + qkv_fused_and_interleaved=config.fuse_qkv_params, + ) + # Copy parameters like embedding: + _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="") + + # Force mem release. Taken from huggingface code. 
+ del total_dict + gc.collect() + + +def load_te_model(cls, config): + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + config.use_cache = False # To make TransformerLayer compatible with GemmaModel + with fp8_model_init(config.fp8_model_init): + # there we need only to create model + vanilla_model = cls(config).to(torch.bfloat16).cuda() + + # return vanilla_model + # and now we copy the weights into it + if config.fp8_model_weights_filename is not None: + _load_weights_for_fp8_model(vanilla_model, config) + else: + _load_weights_for_standard_model(vanilla_model, config) + + return vanilla_model + + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." + that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. + """ + all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + + for layer_prefix in all_layer_prefixes: + + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + te_state_dict[layer_prefix + te_name].data[start:end].copy_( + hf_state_dict[layer_prefix + hf_name] + ) + + copy_from_ht_to_te( + "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight" + ) + copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight") + copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight") + copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight") + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size + ) + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size + ) + + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. + Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. 
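Since the interleaved layout described above is the trickiest part of the weight mapping, here is a small worked example of the row offsets that copy_interleave below computes, using toy sizes (2 heads, head_dim 3) instead of the real Gemma dimensions:

num_heads, head_dim = 2, 3   # toy sizes for illustration

for head in range(num_heads):
    base = head * head_dim * 3          # start of this head's [q | k | v] block in the fused tensor
    q_rows = list(range(base, base + head_dim))
    k_rows = list(range(base + head_dim, base + 2 * head_dim))
    v_rows = list(range(base + 2 * head_dim, base + 3 * head_dim))
    print(f"head {head}: q -> {q_rows}, k -> {k_rows}, v -> {v_rows}")

# head 0: q -> [0, 1, 2],  k -> [3, 4, 5],    v -> [6, 7, 8]
# head 1: q -> [9, 10, 11], k -> [12, 13, 14], v -> [15, 16, 17]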
+ """ + te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"] + + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + dst_slice = slice( + dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim + ) + src_slice = slice( + head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim + ) + te_qkv_layer[dst_slice, :] = src[src_slice, :] + + copy_interleave("self_attn.q_proj.weight", 0) + copy_interleave("self_attn.k_proj.weight", 1) + copy_interleave("self_attn.v_proj.weight", 2) + else: + copy_from_ht_to_te( + "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight" + ) + copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight") + copy_from_ht_to_te( + "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight" + ) + + return all_layer_prefixes diff --git a/docs/examples/te_gemma/te_gemma_save.py b/docs/examples/te_gemma/te_gemma_save.py new file mode 100755 index 0000000000..c83378840c --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_save.py @@ -0,0 +1,872 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import contextmanager + +from typing import Optional +from functools import partial +from collections import OrderedDict + +import torch +import transformer_engine as te +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding +from transformer_engine.common.recipe import Format, DelayedScaling +from torch.cuda.amp import autocast + +import transformers +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel + +import torch.nn.functional as F + + +class CacheParams: + def __init__( + self, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ): + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_kv = max_seqlen_kv + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_kv = cu_seqlens_kv + self.cu_seqlens_q_padded = cu_seqlens_q_padded + self.cu_seqlens_kv_padded = cu_seqlens_kv_padded + + +def setup_cache_params_from_infer_params(inference_params, lengths_tensor, max_input_length): + """ + Converts the `input_ids` to variables like `cu_seqlens_q/kv`, etc. which + will be used later. + + (Currently a hack, this should be reformatted to a better method) + """ + + assert ( + lengths_tensor is not None and max_input_length is not None + ), 'lengths_tensor and max_input_length should not be none for qkv_format = "thd"' + + inference_params.max_incoming_seq_len = max_input_length + + lengths_tensor = lengths_tensor.to(inference_params.cu_seqlens_q.device) + + # inference_params.step_dict = OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + + # print(inference_params.step_dict) + + # def get_cache_params_in_infer_params(): + # return CacheParams(max_seqlen_q, max_seqlen_kv, inference_params.cu_seqlens_q, inference_params.cu_seqlens_kv, inference_params.cu_seqlens_q_padded, inference_params.cu_seqlens_kv_padded) + + # For the time being, create an ad-hoc field in `inference_params` to get the variables. + # @sudhakars: to create a better way later. 
+ # inference_params.get_cache_params_from_infer_params = get_cache_params_in_infer_params + + +# This class has been modified from +# https://github.com/huggingface/transformers/blob/98adf24883b007c2a7fb17bab1c01b1614673433/src/transformers/models/gemma/modeling_gemma.py +class GemmaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) + ) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.unsqueeze(2) # should return in [b, s, 1, d] format + + +class StaticBufferAllocator(torch.nn.Module): + """ + This class is used when we use te.make_graphed_callable(). + CUDA Graphs require all tensors to be static. Neverthless, + torch API make_graphed_callable() takes care of output of torch modules, + and makes them static. Thus by wrapping allocation of memory into + torch.nn.Module, we can greatly simplify our code. + """ + + # pylint: disable=no-self-use + def forward(self, size, dtype, device): + """ + Return buffer of given size, dtype and device. + """ + return torch.zeros(size, dtype=dtype, device=device) + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. + + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + + def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs): + + self.gemma_config = config + + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="geglu", + # attn_input_format=config.qkv_format, + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + kv_channels=self.gemma_config.head_dim, + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. + zero_centered_gamma=True, + ) + + def alloc(self, size, dtype, device): + """ + Allocated the buffer and works correctly with CUDA Graphs. + """ + return self._allocator(size, dtype, device) + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. 
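A quick way to sanity-check the rotary embedding math in GemmaRotaryEmbedding above: the inverse frequencies are base**(-2i/dim), each position index is multiplied by them, and the result is duplicated along the last dimension. Reshaped to [s, 1, 1, d], this matches the layout produced by TE's RotaryPositionEmbedding (used below) and expected by the fused RoPE kernel. A standalone numeric sketch with tiny sizes:

import torch

dim, base, seq_len = 8, 10000, 4
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # [dim/2]
positions = torch.arange(seq_len).float()                            # [s]

freqs = torch.outer(positions, inv_freq)      # [s, dim/2], one angle per (position, frequency)
emb = torch.cat((freqs, freqs), dim=-1)       # [s, dim], duplicated as in the class above
emb = emb[:, None, None, :]                   # [s, 1, 1, dim], broadcastable over batch and heads
print(emb.shape)                              # torch.Size([4, 1, 1, 8])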
+ + # if "self_attn_mask_type" in kwargs: + # attn_mask_type = kwargs['self_attn_mask_type'] + # else: + # attn_mask_type = "whatever_default_is" + + # if attn_mask_type == "arbitrary": + # # @sudhakars: following logic doesn't work for `thd` + # attn_mask = kwargs['attention_mask'] + # attention_mask_inv = ~attn_mask + # generation_case = torch.tensor(torch.tensor(attn_mask.shape).shape).item() > 2 + + # if generation_case: + # # @sudhakars: for some reason, `attention_mask` for generation is of the + # # form [b, 1, 1, s]. + # attention_mask_inv = attention_mask_inv.squeeze(1).squeeze(1) + # assert torch.tensor(torch.tensor(attention_mask_inv.shape).shape).item() == 2 + + # # Create `position_ids` on the fly using `attention_mask` since HF + # # does the same in generation logic. + # position_ids = attention_mask_inv.long().cumsum(-1) - 1 + # position_ids.masked_fill_(attention_mask_inv == 0, 1) + + # if "position_ids" in kwargs and kwargs['position_ids'] is not None: + # assert torch.all(torch.eq(position_ids, kwargs["position_ids"])), "position ids don't match match exactly!" + + # # convert [b, s] to [b, 1, s, s] since `arbitrary` is only set for + # # context phase and context phase gets [b, s] sized attn mask + # seq_len = 1 if torch.tensor(torch.tensor(attn_mask.shape).shape).item() > 2 else attention_mask_inv.shape[1] + # arbitrary_attn_mask = torch.zeros(attention_mask_inv.shape[0], 1, seq_len, attention_mask_inv.shape[1]).bool() + # for sample_idx in range(attn_mask.shape[0]): + # pad_len = attn_mask[sample_idx].sum().int().item() + # # set the columns to padded + # arbitrary_attn_mask[sample_idx, :, :, :pad_len] = True + # # set the rows to padded + # if not generation_case: + # arbitrary_attn_mask[sample_idx, :, :pad_len, :] = True + # arbitrary_attn_mask[sample_idx] = torch.tril(arbitrary_attn_mask[sample_idx].logical_not()).logical_not() + + # # Update the attention mask to arbitrary + # kwargs['attention_mask'] = arbitrary_attn_mask.cuda() + + # # @sudhakars: `max_position_embeddings` is not even used inside GemmaRotaryEmbedding + # # @sudhakars: change the hardcoded `dim` to something like config.head_dim + # te_rope_emb = GemmaRotaryEmbedding(dim=256, max_position_embeddings=self.gemma_config.max_position_embeddings).cuda() + # te_rope_emb = te_rope_emb(args[0], position_ids.cuda()) + # else: + # When the `attention_mask` is not `arbitrary`, then for the purpose + # of this tutorial, we're using `padding_causal` (for context) and + # `padding` (for generation) + # @sudhakars: find a better way to provide the `tensor_format` + te_rope_emb = RotaryPositionEmbedding(self.gemma_config.head_dim)( + max_seq_len=self.gemma_config.max_position_embeddings + ).cuda() + + inference_params = kwargs["inference_params"] + # @sudhakars: big assumption that the input is "sbhd" + # batch_size = args[0].shape[0] + + # if inference_params.qkv_format_legacy == "thd": + # cache_params = kwargs["cache_params"] + # max_seqlen_q = cache_params.max_seqlen_q + # max_seqlen_kv = cache_params.max_seqlen_kv + # cu_seqlens_q = cache_params.cu_seqlens_q + # cu_seqlens_kv = cache_params.cu_seqlens_kv + # cu_seqlens_q_padded = cache_params.cu_seqlens_q_padded + # cu_seqlens_kv_padded = cache_params.cu_seqlens_kv_padded + # print(f"input_sequence_lengths (in forward): \n{inference_params.input_sequence_lengths}") + + # this args cannot be passed to TransformerLayer + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in 
keys_to_remove: + kwargs.pop(key, None) + + # We need to return tuple to be compatible with HF. + return ( + super().forward( + *args, + rotary_pos_emb=te_rope_emb, + # cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_kv=cu_seqlens_kv, + # max_seqlen_q=max_seqlen_q, + # max_seqlen_kv=max_seqlen_kv, + **kwargs, + ), + ) + + +class StaticGemmaModel(torch.nn.Module): + """ + StaticGemma is based of HF GemmaModel class. + It is adjusted to work properly with CUDA Graphs. + """ + + def __init__( + self, + model: GemmaModel, + dtype: torch.dtype, + mask: torch.Tensor, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.mask = mask + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + # @sudhakars: is `arbitrary` fine being the default here? + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + ): + print(f"StaticGemmaModel forward start") + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + + for i, decoder_layer in enumerate(self.model.layers): + # print(f"layer {i}") + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type, + inference_params=self.inference_params, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits, hidden_states + + +class GemmaGenerator(torch.nn.Module): + """ + GemmaGenerator gets one layer of embeddins, + makes forward pass and returns next tokens. + """ + + def __init__( + self, model: GemmaModel, lm_head: torch.nn.Module, dtype: torch.dtype, qkv_format: str + ): + super().__init__() + self.model = model + self.gemma_layers = StaticGemmaModel(model, dtype, "arbitrary", lm_head) + self.qkv_format = qkv_format + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) + + # @sudhakars: is `arbitrary` a good default value here? + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + ): + logits, _ = self.gemma_layers( + hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type + ) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + # logits.shape[2] = number of tokens + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + return next_tokens + + +class PartialForwardWrapper(torch.nn.Module): + """ + This class wraps a `torch.nn.Module` while partially modifying its `forward` + + CUDAGraphs' `make_graphed_callables` method takes in a module but if only + `functools.partial` is used to wrap the module, it changes the modules' + type and that interferes with the `make_graphed_callables` intrinsics. 
+ """ + + def __init__(self, module, **kwargs): + super().__init__() + self.module = module + self.partial_forward = partial(self.module.forward, **kwargs) + + def __call__(self, *args, **kwargs): + return self.partial_forward(*args, **kwargs) + + # @sudhakars: should we use better abstraction? + def set_inference_params(self, *args, **kwargs): + return self.module.set_inference_params(*args, **kwargs) + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`. + """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. + + Args: + config: GemmaConfig + """ + + def __init__(self, config: GemmaConfig): + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + super().__init__(config) + self.config = config + self.to(torch.bfloat16).cuda() + self.hidden_size = config.hidden_size + self._model_generation_phase = GemmaGenerator( + lm_head=self.lm_head, + model=self.model, + dtype=torch.bfloat16, + qkv_format=config.qkv_format, + ) + self._model_context_phase = StaticGemmaModel( + self.model, torch.bfloat16, "arbitrary", self.lm_head + ) + + if self.config.fp8: + self.fp8_recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max" + ) + + @staticmethod + def _padding_to_end(inputs, lengths, max_seq_len=None): + """ + Gets the tensor with sequence padded from the beginning and + return tensor padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + + # Disable the input preparation that involves extra padding + # inputs.copy_(new_input_ids) + + # Trim the inputs to no extra padding i.e. fix the max seq len to + # the longest sequence in the batch + actual_max_seq_len = max_seq_len + inputs.data = new_input_ids[:, :actual_max_seq_len] + print(f"actual_max_seq_len: {actual_max_seq_len}") + + # For Paged Attention, make the valid sequences, multiple of 64 + # inputs.data = new_input_ids[:, :4].repeat(1, 16) + + def _next_64_multiply(self, x): + return ((x + 63) // 64) * 64 + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_hidden_states_buffer(self, input_ids: torch.Tensor): + tensor = torch.empty( + (input_ids.shape[0], input_ids.shape[1], self.hidden_size), + device="cuda", + dtype=torch.float32, + ) + # import pdb; pdb.set_trace() + return tensor + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. 
+ def _create_inference_params(self, *args, **kwargs): + infer_params = InferenceParams(*args, **kwargs) + + # max_batch_size = kwargs["max_batch_size"] + + # Initialize some legacy params + # _allocator = StaticBufferAllocator() + # infer_params.cached_sequence_lengths = _allocator((max_batch_size,), dtype=torch.int32, device="cuda") + # infer_params.input_sequence_lengths = _allocator((max_batch_size,), dtype=torch.int32, device="cuda") + + # These are updated in setup_cache_params_from_infer_params and they should be static for + # the duration of the context as well as the generation phase. + # infer_params.cu_seqlens_q, infer_params.cu_seqlens_kv, infer_params.cu_seqlens_q_padded, infer_params.cu_seqlens_kv_padded = [ + # _allocator(max_batch_size + 1, dtype=torch.int32, device="cuda") + # for _ in range(4) + # ] + + return infer_params + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _get_max_input_seq_len(self, input_ids): + return ( + input_ids.shape[1] + if not hasattr(self.config, "cuda_graphs_static_max_context_len") + else self.config.cuda_graphs_static_max_context_len + ) + + # The buffer for generation is some part (beginning) of hidden states buffer. + # This function returns pointer to it and also copies there data if provided. + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)" + # will return uncontiguous buffer, which we want to avoid. + output = hidden_states_buffer.view(-1)[ + : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2] + ] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view( + (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2]) + ) + return generation_buffer + + def _generate_context_phase(self, input_ids: torch.Tensor, inference_params: InferenceParams): + # import pdb; pdb.set_trace() + hidden_states = self._create_hidden_states_buffer(input_ids) + hidden_states.data[:] = self.model.embed_tokens(input_ids) + + # We need to update offsets before every forward pass to make cache work properly. + lengths = input_ids.ne(0).sum(dim=1) + + # import pdb; pdb.set_trace() + if self.config.qkv_format == "thd": + # inference_params.setup_before_new_input( + # lengths_tensor=lengths, max_input_length=input_ids.shape[1] + # ) + lengths = input_ids.ne(0).sum(dim=1) + inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))) + else: + inference_params.setup_before_new_input(length=input_ids.shape[1]) + + logits, hs_buffer = self._model_context_phase( + hidden_states, + attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None), + attn_mask_type="padding_causal" if self.config.qkv_format == "thd" else "arbitrary", + ) + + # We choose logits coresponding with last token in each sequence, + # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) + # Tensor when qkv_format == "thd" and + # they are the last token in the sequence when qkv_format != "thd". + # import pdb; pdb.set_trace() + if self.config.qkv_format == "thd": + logits = logits[torch.arange(logits.size(0)), lengths - 1, :] + else: + logits = logits[:, -1, :] + + next_tokens = torch.argmax(logits, dim=1) + + # self.hidden_states have shape [b, s, hd]. 
+ # We return hidden state for the last token - output has shape [b, 1, hd] + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + def _make_mask_one_token_longer(self, mask): + return torch.cat( + [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1 + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs, + ): + self.eval() + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. + with autocast(dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] + + batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len( + input_ids + ) + + # This is not needed since the padding to the left is already done in utils.py + # # Pad input_ids with zeros on the left to match max_input_sequence_len + # # This adds padding tokens (0) to the left side of each sequence in the batch + # # Shape goes from [batch_size, seq_len] to [batch_size, max_input_sequence_len] + # input_ids = F.pad( + # input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0 + # ) + + if self.config.qkv_format == "thd": + # For thd layout padding is at the end, otherwise at the beginning. + TEGemmaForCausalLM._padding_to_end( + input_ids, + lengths, + max_seq_len=( + self.config.cuda_graphs_static_max_context_len + if self.config.generation_cuda_graphs + else None + ), + ) + + # import pdb; pdb.set_trace() + + # InferenceParams is a cache, where keys and values of previous tokens are stored. + # Moreover it stores length of both already generated and input sequences. + inference_params = self._create_inference_params( + max_batch_size=batch_size, + # num_layers=self.config.num_hidden_layers, + max_sequence_length=self._next_64_multiply(max_input_sequence_len + max_new_tokens), + num_heads_kv=self.config.num_key_value_heads, + # num_heads_q=self.config.num_attention_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=64, + total_num_pages=64, # 64 * 64 (max_sequence_length) / 64 (page_size) + # is_cuda_graph=False + ) + + # def init_cache_params_in_infer_params(inference_params): + # _allocator = StaticBufferAllocator() + # inference_params.cached_sequence_lengths = _allocator( + # (batch_size,), dtype=torch.int32, device="cuda") + # inference_params.input_sequence_lengths = _allocator( + # (batch_size,), dtype=torch.int32, device="cuda") + + # init_cache_params_in_infer_params(inference_params) + + # inference_params.qkv_format_legacy = self.config.qkv_format + + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + print(f"context phase start") + # import pdb; pdb.set_trace() + hidden_states, next_tokens = self._generate_context_phase(input_ids, inference_params) + + print(f"context phase done") + # Generation phase. 
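At the very end of the generation loop below, the per-step token tensors are stacked and concatenated back onto input_ids. A small worked shape example of the stack(...).permute([1, 0]) step, assuming batch size 2 and three generated tokens:

import torch

input_ids = torch.tensor([[11, 12, 13],
                          [21, 22, 23]])                  # [b=2, s=3]
output_tokens = [torch.tensor([1, 2]),                    # step 1: one token per sequence
                 torch.tensor([3, 4]),                    # step 2
                 torch.tensor([5, 6])]                    # step 3

stacked = torch.stack(output_tokens)                      # [steps=3, b=2]
per_sequence = stacked.permute([1, 0])                    # [b=2, steps=3]
result = torch.cat((input_ids, per_sequence), dim=1)      # [b=2, s + steps = 6]
print(result)
# tensor([[11, 12, 13,  1,  3,  5],
#         [21, 22, 23,  2,  4,  6]])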
+ if self.config.qkv_format == "thd": + # inference_params.setup_before_new_input( + # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"), + # max_input_length=1, + # ) + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + else: + inference_params.setup_before_new_input(length=1) + + output_tokens = [next_tokens] + + mask = None + if self.config.qkv_format != "thd": + mask = (input_ids == 0).unsqueeze(1).unsqueeze(1) + + for _ in range(max_new_tokens): + if self.config.qkv_format != "thd": + # It will not work with cuda graphs, but it is not used for thd qkv_format. + # Attention mask in bshd needs attn_mask increased by 1 to + # include the next token to be generated + mask = self._make_mask_one_token_longer(mask) + + # setup_cache_params_from_infer_params(inference_params, input_ids) + # @sudhakars: could create position_ids from mask here + next_tokens = self._model_generation_phase( + hidden_states, + mask, + attn_mask_type="padding" if self.config.qkv_format == "thd" else "arbitrary", + ) + + # self.inference_params contains for example kv_cache. + # This needs to be called before every pass, + # to update the information of sequence lengths. + # Here we increase sequence offsets by one, + # because we generated one token for every sequence. + if self.config.qkv_format == "thd": + # self.inference_params.setup_before_new_input( + # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"), + # max_input_length=1, + # ) + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + else: + inference_params.setup_before_new_input(length=1) + # next_tokens is static output tensor, so we need to clone it + # - it gets changed every iteration. + output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + def forward(self, *args, **kwargs): + self._model_context_phase.set_inference_params(None) + hidden_states = self.model.embed_tokens(kwargs["input_ids"]) + logits = self._model_context_phase( + hidden_states, + attention_mask=( + (kwargs["input_ids"] == 0) if self.config.qkv_format != "thd" else None + ), + attn_mask_type="arbitrary", + ) + return logits + + +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): + """ + TEGemmaForCausalLMCudaGraphs is the version of the class TEGemmaForCausalLM + using CUDA Graphs to speed it up. We need to make one trade-off. + Namely, batch_size, max_seq_len and max_context_seq_len need to be static. + It is necessary to run generation with the same value of + these variables that we recorded graph on. + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + assert ( + config.qkv_format == "thd" + ), "Generation with CUDA Graphs are implemented only for thd format." + + # Preparation of the static buffers. + self.config = config + self.hidden_states_buffer = torch.empty( + ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + self.config.hidden_size, + ) + ).cuda() + # This is in fact part of the buffer for hidden_states. 
+        self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer)
+        self.inference_params = InferenceParams(
+            max_batch_size=self.config.cuda_graphs_static_batch_size,
+            max_sequence_length=self.config.cuda_graphs_static_max_seq_len,
+            num_heads_kv=self.config.num_key_value_heads,
+            head_dim_v=self.config.head_dim,
+            head_dim_k=self.config.head_dim,
+            dtype=torch.bfloat16,
+            is_paged=self.config.is_paged,
+            page_size=64,
+            total_num_pages=64,  # 64 * 64 (max_sequence_length) / 64 (page_size)
+        )
+
+        self._model_generation_phase.set_inference_params(self.inference_params)
+        self._model_context_phase.set_inference_params(self.inference_params)
+
+    def record(self):
+        # We want to record the model with training=False, because it will be used in generation.
+        self.eval()
+
+        # Here "the trick" happens. We override methods from TEGemmaForCausalLM
+        # with their recorded versions. After that, every invocation of each of them
+        # replays the captured graph with minimal CPU involvement, which leads to a
+        # large speedup.
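+        # A rough usage sketch (assuming the caller, e.g. utils.py, sets up the
+        # cuda_graphs_static_* fields of the config and loads the weights):
+        #
+        #     model = TEGemmaForCausalLMCudaGraphs(config).cuda()
+        #     model.record()                      # capture context + generation graphs once
+        #     tokens = model.generate(input_ids)  # subsequent calls replay the graphs
+        #
+        # Everything below must therefore use exactly the same static shapes as generate().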
+        input_shape = (
+            self.config.cuda_graphs_static_batch_size,
+            self.config.cuda_graphs_static_max_context_len,
+        )
+
+        # These lengths must match the lengths_tensor used in TEGemmaForCausalLM.generate().
+        lengths = torch.tensor(input_shape[0] * [input_shape[1]], device="cuda", dtype=torch.int32)
+        max_input_length = input_shape[1]
+
+        self.inference_params.pre_step(
+            OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
+        )
+
+        # Record the context phase on the full static context buffer.
+        self._model_context_phase = self.record_graph(
+            self._model_context_phase, self.hidden_states_buffer, attn_mask_type="padding_causal"
+        )  # CUDA Graphs recording
+
+        input_shape = (self.config.cuda_graphs_static_batch_size, 1)
+        lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32)
+        max_input_length = input_shape[1]
+
+        self.inference_params.pre_step(
+            OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
+        )
+
+        # Record the generation phase on the [batch, 1, hidden] generation buffer.
+        self._model_generation_phase = self.record_graph(
+            self._model_generation_phase, self.generation_buffer, attn_mask_type="padding"
+        )  # CUDA Graphs recording
+
+    """
+    Functions _create_hidden_states_buffer and _create_inference_params
+    from the base class are overridden to make hidden_states and inference_params static,
+    i.e. not changing their position in memory between invocations.
+    """
+
+    def _create_hidden_states_buffer(self, *args, **kwargs):
+        return self.hidden_states_buffer
+
+    def _create_inference_params(self, *args, **kwargs):
+        self.inference_params.reset()
+        return self.inference_params
+
+    def _get_max_input_seq_len(self, _):
+        return self.config.cuda_graphs_static_max_context_len
+
+    @torch.no_grad()
+    def record_graph(self, function, input_tensor, **sample_kwargs):
+        # function is invoked on the argument (input_tensor,) and all kernels are recorded.
+        # record_graph() returns the captured function, which can be run later with much
+        # lower CPU overhead.
+        fp8_format = Format.HYBRID
+        fp8_recipe = DelayedScaling(
+            fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max"
+        )
+
+        # We need both autocasts: FP8 for operations that can run in lower precision
+        # and BF16 for those that cannot.
+        with autocast(dtype=torch.bfloat16, cache_enabled=False):
+            graphed_function = te.pytorch.make_graphed_callables(
+                function,
+                (input_tensor,),
+                fp8_enabled=self.config.fp8,
+                fp8_recipe=fp8_recipe,
+                allow_unused_input=True,
+                num_warmup_iters=5,
+                sample_kwargs=sample_kwargs,
+            )
+        return graphed_function
diff --git a/docs/examples/te_gemma/te_llama.py b/docs/examples/te_gemma/te_llama.py
new file mode 100755
index 0000000000..637f4f574c
--- /dev/null
+++ b/docs/examples/te_gemma/te_llama.py
@@ -0,0 +1,826 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+ +from contextlib import contextmanager + +from typing import Optional +from functools import partial +from collections import OrderedDict + +import torch +import transformer_engine as te +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding +from transformer_engine.common.recipe import Format, DelayedScaling +from torch.cuda.amp import autocast + +import transformers +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaConfig, LlamaModel + +import torch.nn.functional as F + + +def setup_cache_params_from_infer_params(inference_params, lengths_tensor, max_input_length): + """ + Converts the `input_ids` to variables like `cu_seqlens_q/kv`, etc. which + will be used later. + + (Currently a hack, this should be reformatted to a better method) + """ + + assert ( + lengths_tensor is not None and max_input_length is not None + ), 'lengths_tensor and max_input_length should not be none for qkv_format = "thd"' + torch.add( + inference_params.cached_sequence_lengths, + inference_params.input_sequence_lengths, + out=inference_params.cached_sequence_lengths, + ) + inference_params.input_sequence_lengths.copy_(lengths_tensor) + inference_params.max_incoming_seq_len = max_input_length + + max_seqlen_q, max_seqlen_kv = ( + inference_params.max_incoming_seq_len, + inference_params.max_sequence_length, + ) + + # # Allocation of buffers, it works correctly with CUDA Graphs. + _allocator = StaticBufferAllocator() + NR_BUFFERS = 4 + + cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = [ + _allocator(inference_params.max_batch_size + 1, dtype=torch.int32, device="cuda") + for _ in range(NR_BUFFERS) + ] + + torch.cumsum(inference_params.input_sequence_lengths, dim=0, out=cu_seqlens_q[1:]) + torch.cumsum( + inference_params.cached_sequence_lengths + inference_params.input_sequence_lengths, + dim=0, + out=cu_seqlens_kv[1:], + ) + # If layer has shape [b * s_layer, h, d] + # offsets are of the form [k * s_layer * h * d for k = 0, ..., batch_size] + cu_seqlens_q_padded.copy_( + torch.arange(0, inference_params.max_batch_size + 1, device="cuda") * max_seqlen_q + ) + cu_seqlens_kv_padded.copy_( + torch.arange(0, inference_params.max_batch_size + 1, device="cuda") * max_seqlen_kv + ) + + # inference_params.step_dict = OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + + # print(inference_params.step_dict) + + def get_cache_params_in_infer_params(): + return ( + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) + + # For the time being, create an ad-hoc field in `inference_params` to get the variables. + # @sudhakars: to create a better way later. 
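+    # Worked illustration (assuming max_batch_size == 3): for a fresh cache
+    # (cached_sequence_lengths == [0, 0, 0]) and input_sequence_lengths == [3, 5, 2],
+    # the buffers computed above hold
+    #     cu_seqlens_q  == [0, 3, 8, 10]
+    #     cu_seqlens_kv == [0, 3, 8, 10]
+    # while cu_seqlens_q_padded / cu_seqlens_kv_padded hold offsets into the padded layout,
+    # i.e. multiples of max_seqlen_q and max_seqlen_kv respectively.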
+ inference_params.get_cache_params_from_infer_params = get_cache_params_in_infer_params + + +# This class has been modified from +# https://github.com/huggingface/transformers/blob/98adf24883b007c2a7fb17bab1c01b1614673433/src/transformers/models/gemma/modeling_gemma.py +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) + ) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.unsqueeze(2) # should return in [b, s, 1, d] format + + +class StaticBufferAllocator(torch.nn.Module): + """ + This class is used when we use te.make_graphed_callable(). + CUDA Graphs require all tensors to be static. Neverthlessly, + torch API make_graphed_callable() takes care of output of torch modules, + and makes them static. Thus by wrapping allocation of memory into + torch.nn.Module, we can greatly simplify our code. + """ + + # pylint: disable=no-self-use + def forward(self, size, dtype, device): + """ + Return buffer of given size, dtype and device. + """ + return torch.zeros(size, dtype=dtype, device=device) + + +class TELlamaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `LlamaDecoderLayer` and easier to replace it in the code. + + Args: + config: LlamaConfig + args: positional args (for compatibility with `LlamaDecoderLayer`) + kwargs: keyword args (for compatibility with `LlamaDecoderLayer`) + """ + + def __init__(self, config: LlamaConfig, layer_idx: int, *args, **kwargs): + + self.llama_config = config + self.head_dim = self.llama_config.hidden_size // self.llama_config.num_attention_heads + + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, # LLaMA specific + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="swiglu", # LLaMA specific + # attn_input_format=config.qkv_format, + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + kv_channels=self.head_dim, # LLaMA specific + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. + zero_centered_gamma=True, # LLaMA specific + ) + + def alloc(self, size, dtype, device): + """ + Allocated the buffer and works correctly with CUDA Graphs. 
+ """ + return self._allocator(size, dtype, device) + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. + + if "self_attn_mask_type" in kwargs: + attn_mask_type = kwargs["self_attn_mask_type"] + else: + attn_mask_type = "whatever_default_is" + + if attn_mask_type == "arbitrary": + # @sudhakars: following logic doesn't work for `thd` + attn_mask = kwargs["attention_mask"] + attention_mask_inv = ~attn_mask + generation_case = torch.tensor(torch.tensor(attn_mask.shape).shape).item() > 2 + + if generation_case: + # @sudhakars: for some reason, `attention_mask` for generation is of the + # form [b, 1, 1, s]. + attention_mask_inv = attention_mask_inv.squeeze(1).squeeze(1) + assert torch.tensor(torch.tensor(attention_mask_inv.shape).shape).item() == 2 + + # Create `position_ids` on the fly using `attention_mask` since HF + # does the same in generation logic. + position_ids = attention_mask_inv.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask_inv == 0, 1) + + if "position_ids" in kwargs and kwargs["position_ids"] is not None: + assert torch.all( + torch.eq(position_ids, kwargs["position_ids"]) + ), "position ids don't match match exactly!" + + # convert [b, s] to [b, 1, s, s] since `arbitrary` is only set for + # context phase and context phase gets [b, s] sized attn mask + seq_len = ( + 1 + if torch.tensor(torch.tensor(attn_mask.shape).shape).item() > 2 + else attention_mask_inv.shape[1] + ) + arbitrary_attn_mask = torch.zeros( + attention_mask_inv.shape[0], 1, seq_len, attention_mask_inv.shape[1] + ).bool() + for sample_idx in range(attn_mask.shape[0]): + pad_len = attn_mask[sample_idx].sum().int().item() + # set the columns to padded + arbitrary_attn_mask[sample_idx, :, :, :pad_len] = True + # set the rows to padded + if not generation_case: + arbitrary_attn_mask[sample_idx, :, :pad_len, :] = True + arbitrary_attn_mask[sample_idx] = torch.tril( + arbitrary_attn_mask[sample_idx].logical_not() + ).logical_not() + + # Update the attention mask to arbitrary + kwargs["attention_mask"] = arbitrary_attn_mask.cuda() + + # @sudhakars: `max_position_embeddings` is not even used inside GemmaRotaryEmbedding + # @sudhakars: change the hardcoded `dim` to something like config.head_dim + te_rope_emb = LlamaRotaryEmbedding( + dim=self.head_dim, max_position_embeddings=self.llama_config.max_position_embeddings + ).cuda() + te_rope_emb = te_rope_emb(args[0], position_ids.cuda()) + else: + # When the `attention_mask` is not `arbitrary`, then for the purpose + # of this tutorial, we're using `padding_causal` (for context) and + # `padding` (for generation) + # @sudhakars: find a better way to provide the `tensor_format` + te_rope_emb = RotaryPositionEmbedding(self.head_dim)( # Use self.head_dim + max_seq_len=self.llama_config.max_position_embeddings + ).cuda() + + inference_params = kwargs["inference_params"] + # @sudhakars: big assumption that the input is "sbhd" + # batch_size = args[0].shape[0] + if inference_params.qkv_format_legacy == "thd": + ( + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) = inference_params.get_cache_params_from_infer_params() + + # this args cannot be passed to TransformerLayer + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in keys_to_remove: + kwargs.pop(key, None) + + # import pdb; pdb.set_trace() + # We need to return tuple to be compatible with HF. 
+ return ( + super().forward( + *args, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + **kwargs + ), + ) + + +class StaticLlamaModel(torch.nn.Module): + """ + StaticLlama is based of HF LlamaModel class. + It is adjusted to work properly with CUDA Graphs. + """ + + def __init__( + self, + model: LlamaModel, + dtype: torch.dtype, + mask: torch.Tensor, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.llama_config = model.config # Store LlamaConfig + self.normalizer = torch.tensor(self.llama_config.hidden_size**0.5, dtype=dtype) + self.mask = mask + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + # @sudhakars: is `arbitrary` fine being the default here? + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + ): + # import pdb; pdb.set_trace() + if hidden_states.shape[1] > 1: + torch.save(hidden_states, "input_ctxt.pth") + + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + + for i, decoder_layer in enumerate(self.model.layers): + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type, + inference_params=self.inference_params, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + +class LlamaGenerator(torch.nn.Module): + """ + LlamaGenerator gets one layer of embeddins, + makes forward pass and returns next tokens. + """ + + def __init__( + self, model: LlamaModel, lm_head: torch.nn.Module, dtype: torch.dtype, qkv_format: str + ): + super().__init__() + self.model = model + self.llama_layers = StaticLlamaModel(model, dtype, "arbitrary", lm_head) + self.qkv_format = qkv_format + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.llama_layers.set_inference_params(inference_params) + + # @sudhakars: is `arbitrary` a good default value here? + def forward( + self, hidden_states: torch.Tensor, mask: torch.Tensor = None, mask_type: str = "arbitrary" + ): + logits = self.llama_layers(hidden_states, attention_mask=mask, attn_mask_type=mask_type) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + # logits.shape[2] = number of tokens + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + # self.inference_params contains for example kv_cache. + # This needs to be called before every pass, + # to update the information of sequence lengths. + # Here we increase sequence offsets by one, + # because we generated one token for every sequence. 
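+        # Note on the in-place `copy_` above: CUDA Graphs replay kernels on fixed memory
+        # addresses, so the embedding of the new token has to be written into the existing
+        # `hidden_states` buffer. Rebinding the name instead, e.g.
+        #     hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1)
+        # would allocate a new tensor that the captured graph knows nothing about.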
+ if self.qkv_format == "thd": + # self.inference_params.setup_before_new_input( + # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"), + # max_input_length=1, + # ) + setup_cache_params_from_infer_params( + self.inference_params, + lengths_tensor=torch.ones((next_tokens.shape[0],), dtype=int), + max_input_length=1, + ) + else: + self.inference_params.setup_before_new_input(length=1) + + return next_tokens + + +class PartialForwardWrapper(torch.nn.Module): + """ + This class wraps a `torch.nn.Module` while partially modifying its `forward` + + CUDAGraphs' `make_graphed_callables` method takes in a module but if only + `functools.partial` is used to wrap the module, it changes the modules' + type and that interferes with the `make_graphed_callables` intrinsics. + """ + + def __init__(self, module, **kwargs): + super().__init__() + self.module = module + self.partial_forward = partial(self.module.forward, **kwargs) + + def __call__(self, *args, **kwargs): + return self.partial_forward(*args, **kwargs) + + # @sudhakars: should we use better abstraction? + def set_inference_params(self, *args, **kwargs): + return self.module.set_inference_params(*args, **kwargs) + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`. + """ + original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer + transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls + + +class TELlamaForCausalLM(LlamaForCausalLM): + """ + Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer` + class is monkey-patched with `TELlamaDecoderLayer` class before + initializing the causal LM with `LlamaForCausalLM`. + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + with replace_decoder(te_decoder_cls=TELlamaDecoderLayer): + super().__init__(config) + self.config = config + self.to(torch.bfloat16).cuda() + self.hidden_size = config.hidden_size + self._model_generation_phase = LlamaGenerator( + lm_head=self.lm_head, + model=self.model, + dtype=torch.bfloat16, + qkv_format=config.qkv_format, + ) + self._model_context_phase = StaticLlamaModel( + self.model, torch.bfloat16, "arbitrary", self.lm_head + ) + + if self.config.fp8: + self.fp8_recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max" + ) + + @staticmethod + def _padding_to_end(inputs, lengths): + """ + Gets the tensor with sequence padded from the beginning and + return tensor padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + + # Disable the input preparation that involves extra padding + # inputs.copy_(new_input_ids) + + # Trim the inputs to no extra padding i.e. 
fix the max seq len to + # the longest sequence in the batch + actual_max_seq_len = inputs.ne(0).sum(dim=1).max() + inputs.data = new_input_ids[:, :actual_max_seq_len] + + # For Paged Attention, make the valid sequences, multiple of 64 + # inputs.data = new_input_ids[:, :4].repeat(1, 16) + + def _next_64_multiply(self, x): + return ((x + 63) // 64) * 64 + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_hidden_states_buffer(self, input_ids: torch.Tensor): + return torch.empty( + (input_ids.shape[0], input_ids.shape[1], self.hidden_size), + device="cuda", + dtype=torch.float32, + ) + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_inference_params(self, *args, **kwargs): + infer_params = InferenceParams(*args, **kwargs) + + max_batch_size = kwargs["max_batch_size"] + + # Initialize some legacy params + infer_params.cached_sequence_lengths = torch.zeros( + (max_batch_size,), device="cuda", dtype=torch.int32 + ) + infer_params.input_sequence_lengths = torch.zeros( + (max_batch_size,), device="cuda", dtype=torch.int32 + ) + + return infer_params + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _get_max_input_seq_len(self, input_ids): + return input_ids.shape[1] + + # The buffer for generation is some part (beginning) of hidden states buffer. + # This function returns pointer to it and also copies there data if provided. + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)" + # will return uncontiguous buffer, which we want to avoid. + output = hidden_states_buffer.view(-1)[ + : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2] + ] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view( + (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2]) + ) + return generation_buffer + + def _generate_context_phase(self, input_ids: torch.Tensor, inference_params: InferenceParams): + hidden_states = self._create_hidden_states_buffer(input_ids) + hidden_states.data[:] = self.model.embed_tokens(input_ids) + + # We need to update offsets before every forward pass to make cache work properly. + lengths = input_ids.ne(0).sum(dim=1) + # import pdb; pdb.set_trace() + if self.config.qkv_format == "thd": + # inference_params.setup_before_new_input( + # lengths_tensor=lengths, max_input_length=input_ids.shape[1] + # ) + lengths = input_ids.ne(0).sum(dim=1) + max_input_length = input_ids.shape[1] + setup_cache_params_from_infer_params(inference_params, lengths, max_input_length) + else: + inference_params.setup_before_new_input(length=input_ids.shape[1]) + + logits = self._model_context_phase( + hidden_states, + attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None), + attn_mask_type="padding_causal" if self.config.qkv_format == "thd" else "arbitrary", + ) + + # We choose logits coresponding with last token in each sequence, + # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) + # Tensor when qkv_format == "thd" and + # they are the last token in the sequence when qkv_format != "thd". 
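+        # Illustration: with two sequences of lengths [3, 5] in a padded [b, s, vocab] logits
+        # tensor, the thd branch below selects logits[0, 2, :] and logits[1, 4, :], i.e. the
+        # logits at position (length - 1) of every sequence.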
+ if self.config.qkv_format == "thd": + logits = logits[ + torch.arange(logits.size(0)), inference_params.input_sequence_lengths - 1, : + ] + else: + logits = logits[:, -1, :] + torch.save(logits, "logits_ctxt.pth") + next_tokens = torch.argmax(logits, dim=1) + + # self.hidden_states have shape [b, s, hd]. + # We return hidden state for the last token - output has shape [b, 1, hd] + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + def _make_mask_one_token_longer(self, mask): + return torch.cat( + [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1 + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs + ): + self.eval() + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. + with autocast(dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] + # input_ids = F.pad( + # input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0 + # ) + + if self.config.qkv_format == "thd": + # For thd layout padding is at the end, otherwise at the beginning. + TELlamaForCausalLM._padding_to_end(input_ids, lengths) + + batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len( + input_ids + ) + # import pdb; pdb.set_trace() + + # InferenceParams is a cache, where keys and values of previous tokens are stored. + # Moreover it stores length of both already generated and input sequences. + head_dim = self.config.hidden_size // self.config.num_attention_heads + inference_params = self._create_inference_params( + max_batch_size=batch_size, + # num_layers=self.config.num_hidden_layers, + max_sequence_length=self._next_64_multiply(max_input_sequence_len + max_new_tokens), + num_heads_kv=self.config.num_key_value_heads, + # num_heads_q=self.config.num_attention_heads, + head_dim_v=head_dim, + head_dim_k=head_dim, + dtype=torch.bfloat16, + is_paged=True, + page_size=64, + total_num_pages=64 * 3, # 64 * 64 (max_sequence_length) / 64 (page_size) + # is_cuda_graph=False + ) + + def init_cache_params_in_infer_params(inference_params): + inference_params.cached_sequence_lengths = torch.zeros( + (batch_size,), device="cuda", dtype=torch.int32 + ) + inference_params.input_sequence_lengths = torch.zeros( + (batch_size,), device="cuda", dtype=torch.int32 + ) + + init_cache_params_in_infer_params(inference_params) + inference_params.qkv_format_legacy = self.config.qkv_format + + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + hidden_states, next_tokens = self._generate_context_phase(input_ids, inference_params) + + # Generation phase. 
+ if self.config.qkv_format == "thd": + # inference_params.setup_before_new_input( + # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"), + # max_input_length=1, + # ) + setup_cache_params_from_infer_params( + inference_params, + lengths_tensor=torch.ones((next_tokens.shape[0],), dtype=int), + max_input_length=1, + ) + else: + inference_params.setup_before_new_input(length=1) + + output_tokens = [next_tokens] + + mask = None + if self.config.qkv_format != "thd": + mask = (input_ids == 0).unsqueeze(1).unsqueeze(1) + + for _ in range(max_new_tokens): + if self.config.qkv_format != "thd": + # It will not work with cuda graphs, but it is not used for thd qkv_format. + # Attention mask in bshd needs attn_mask increased by 1 to + # include the next token to be generated + mask = self._make_mask_one_token_longer(mask) + + # setup_cache_params_from_infer_params(inference_params, input_ids) + # @sudhakars: could create position_ids from mask here + next_tokens = self._model_generation_phase( + hidden_states, + mask, + mask_type="padding" if self.config.qkv_format == "thd" else "arbitrary", + ) + # next_tokens is static output tensor, so we need to clone it + # - it gets changed every iteration. + output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + def forward(self, *args, **kwargs): + self._model_context_phase.set_inference_params(None) + hidden_states = self.model.embed_tokens(kwargs["input_ids"]) + logits = self._model_context_phase( + hidden_states, + attention_mask=( + (kwargs["input_ids"] == 0) if self.config.qkv_format != "thd" else None + ), + attn_mask_type="arbitrary", + ) + return logits + + +class TELlamaForCausalLMCudaGraphs(TELlamaForCausalLM): + """ + TELlamaForCausalLMCudaGraphs is the version of the class TELlamaForCausalLM + using CUDA Graphs to speed it up. We need to make one trade-off. + Namely, batch_size, max_seq_len and max_context_seq_len need to be static. + It is necessary to run generation with the same value of + these variables that we recorded graph on. + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + assert ( + config.qkv_format == "thd" + ), "Generation with CUDA Graphs are implemented only for thd format." + + # Preparation of the static buffers. + self.config = config + self.hidden_states_buffer = torch.empty( + ( + config.cuda_graphs_static_batch_size, + config.cuda_graphs_static_max_context_len, + config.hidden_size, + ) + ).cuda() + # This is in fact part of the buffer for hidden_states. + self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer) + self.inference_params = InferenceParams( + max_batch_size=config.cuda_graphs_static_batch_size, + max_sequence_length=config.cuda_graphs_static_max_seq_len, + qkv_format="thd", + ) + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) + + def record(self): + # We want to record model in training=False, because it will be used in generation. + self.eval() + + # Here "the trick" happens. We override methods from TELlamaForCausalLM + # with their recorded version. After invocation of each of them, + # captured graph will be replayed with minimal usage of CPU, + # what will lead to huge speedup. 
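+        # Two separate graphs are captured below because the two phases run on different
+        # static shapes: the context phase on the full
+        # [batch, cuda_graphs_static_max_context_len, hidden] buffer and the generation phase
+        # on the [batch, 1, hidden] generation buffer. A captured graph can only be replayed
+        # with the exact shapes it was recorded with.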
+ input_shape = ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + ) + self.inference_params.reset() + self.inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"), + max_input_length=input_shape[1], + ) + self._model_context_phase = self.record_graph( + PartialForwardWrapper( + self._model_context_phase, + attn_mask_type=( + "padding_causal" if self.inference_params.qkv_format == "thd" else "arbitrary" + ), + ), + self.hidden_states_buffer, + ) # CUDA Graphs recording + + input_shape = (self.config.cuda_graphs_static_batch_size, 1) + self.inference_params.reset() + self.inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"), + max_input_length=input_shape[1], + ) + self._model_generation_phase = self.record_graph( + PartialForwardWrapper( + self._model_generation_phase, + mask_type="padding" if self.inference_params.qkv_format == "thd" else "arbitrary", + ), + self.generation_buffer, + ) # CUDA Graphs recording + + """ + Functions _create_hidden_states_buffer and _create_inference_params + from base class are overriden to make hidden_states and inference_params static + - not changing their position in memory between every invocation. + """ + + def _create_hidden_states_buffer(self, *args, **kwargs): + return self.hidden_states_buffer + + def _create_inference_params(self, *args, **kwargs): + self.inference_params.reset() + return self.inference_params + + def _get_max_input_seq_len(self, _): + return self.config.cuda_graphs_static_max_context_len + + @torch.no_grad() + def record_graph(self, function, input_tensor): + # function is invoked on argument (self.hidden_states,) and all kernels are recorded. + # record_graph() returns captured function, which can be run later with lower of th CPU. + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling( + fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max" + ) + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. + with autocast(dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=self.config.fp8, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=3, + ) + return graphed_function diff --git a/docs/examples/te_gemma/te_llama_loading_weights.py b/docs/examples/te_gemma/te_llama_loading_weights.py new file mode 100755 index 0000000000..a5ab151f67 --- /dev/null +++ b/docs/examples/te_gemma/te_llama_loading_weights.py @@ -0,0 +1,224 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import re +import gc +import torch + +from typing import List + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import load_state_dict, _load_state_dict_into_model +from transformers.utils.hub import get_checkpoint_shard_files + +""" + This file contains logic of mapping the HuggingFace LlamaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" + + +def _load_weights_for_fp8_model(vanilla_model, hyperparams): + # The weights are loaded from the file with state_dict + # of model with weights which contains also fp8 parameters. 
+ # The weights are in BF16 precision, but they contain fp8 metadata + # computed by the calibration procedure. + vanilla_model.load_state_dict( + torch.load(hyperparams.fp8_model_weights_filename), + strict=False, + # strict = false, because some parameters have + # multiple pointers to the same weight + # vanilla_model._model_context_phase.model + # and vanilla_model._model_generation_phase.model + ) + + +def _load_weights_for_standard_model(vanilla_model, config): + # The weights are loaded from the file with original weights. + archive_file = os.path.join(config.model_name, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.model_name, archive_file) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict.update(state_dict) + + replace_params( + total_dict, + vanilla_model.state_dict(), + config, + qkv_fused_and_interleaved=config.fuse_qkv_params, + ) + # Copy parameters like embedding: + _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="") + + # Force mem release. Taken from huggingface code. + del total_dict + gc.collect() + + +def load_te_model(cls, config): + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + + config.use_cache = False # To make TransformerLayer compatible with LlamaModel + with fp8_model_init(config.fp8_model_init): + # there we need only to create model + vanilla_model = cls(config).to(torch.bfloat16).cuda() + + # return vanilla_model + # and now we copy the weights into it + if config.fp8_model_weights_filename is not None: + _load_weights_for_fp8_model(vanilla_model, config) + else: + _load_weights_for_standard_model(vanilla_model, config) + + return vanilla_model + + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." + that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + # collect all layer prefixes to update + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." 
+ m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + + for layer_prefix in all_layer_prefixes: + # When loading weights into models with less number of layers, skip the + # copy if the corresponding layer doesn't exist in HF model + if layer_prefix + "input_layernorm.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[ + : + ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:] + + if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = ( + hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:] + ) + + if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = ( + hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:] + ) + + if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = ( + hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:] + ) + + if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[ + layer_prefix + "self_attn.o_proj.weight" + ].data[:] + + if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[ + layer_prefix + "post_attention_layernorm.weight" + ].data[:] + + # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to + # load them separately. + if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ + : config.intermediate_size + ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data + + if layer_prefix + "mlp.up_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ + config.intermediate_size : + ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data + + if layer_prefix + "mlp.down_proj.weight" in hf_state_dict: + te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[ + layer_prefix + "mlp.down_proj.weight" + ].data[:] + return all_layer_prefixes + + +# def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): +# """ +# Replaces params from TE TransformerLayer state_dict with corresponding parameters +# from HuggingFace LlamaModel state_dict. 
+# """ +# all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + +# head_dim = config.hidden_size // config.num_attention_heads + +# for layer_prefix in all_layer_prefixes: + +# def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): +# te_state_dict[layer_prefix + te_name].data[start:end].copy_( +# hf_state_dict[layer_prefix + hf_name] +# ) + +# copy_from_ht_to_te( +# "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight" +# ) +# copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight") +# copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight") +# copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight") +# copy_from_ht_to_te( +# "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size +# ) +# copy_from_ht_to_te( +# "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size +# ) + +# if qkv_fused_and_interleaved: +# """ +# When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor +# in TE TransformerLayer. Moreover they are interleaved within each head. +# Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. +# Then TE stores weight tensor in the form: +# [q1 k1 v1 q2 k2 v2 ...] +# This is done to maximally optimize performance time. +# """ +# te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"] + +# def copy_interleave(hf_name, idx): +# src = hf_state_dict[layer_prefix + hf_name] +# for head_nr in range(config.num_attention_heads): +# dst_offset = head_nr * config.head_dim * 3 +# dst_slice = slice( +# dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim +# ) +# src_slice = slice( +# head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim +# ) +# te_qkv_layer[dst_slice, :] = src[src_slice, :] + +# copy_interleave("self_attn.q_proj.weight", 0) +# copy_interleave("self_attn.k_proj.weight", 1) +# copy_interleave("self_attn.v_proj.weight", 2) +# else: +# copy_from_ht_to_te( +# "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight" +# ) +# copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight") +# copy_from_ht_to_te( +# "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight" +# ) + +# return all_layer_prefixes diff --git a/docs/examples/te_gemma/test_paged_attn.ipynb b/docs/examples/te_gemma/test_paged_attn.ipynb new file mode 100755 index 0000000000..543ebe9262 --- /dev/null +++ b/docs/examples/te_gemma/test_paged_attn.ipynb @@ -0,0 +1,33 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ace403ac-c276-4378-a4e8-0155165f9934", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb new file mode 100755 index 0000000000..7875ffc9f3 --- /dev/null +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb @@ 
-0,0 +1,314 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Accelerating a Hugging Face Gemma model finetuning with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `LlamaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional speedup.\n", + "\n", + "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial\n", + "\n", + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", + "2. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "3. `requirements.txt`\n", + " - This file contains necessary Python packages for this tutorial.\n", + "4. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -r requirements.txt\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Differences between Llama and Gemma" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The most important architectural differences between them are the following:\n", + "\n", + "\n", + "| Feature | Llama | Gemma |\n", + "|----------------------------------------------|------------------------------------|--------------------------------------------|\n", + "| **Norm Layer** | Standard RMSNorm
$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * \\gamma + \\beta$ | RMSNorm with zero centered gamma parameter
$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) + \\beta$ |\n", + "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n", + "| **Activation Function** | SwiGlu | GeGlu |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", + "\n", + "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n", + "\n", + "
\n", + "\n", + "Note\n", + " \n", + "This tutorial loads and trains a Gemma 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "298 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + "\n", + "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb). Let's observe the impact this change has on the model's speed." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "257 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **16%** even when using only BF16 precision!\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "\n", + "The last improvement is about enabling FP8 precision. Let's see how it works." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "214 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 214 | 1.39 |\n", + "\n", + "\n", + "After turning on FP8 precision, we get even more speedup of almost **39%**!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## See more\n", + "\n", + "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb new file mode 100755 index 0000000000..acb93b795e --- /dev/null +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -0,0 +1,1277 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "40364db7", + "metadata": {}, + "source": [ + "# Accelerating token generation of the Hugging Face Gemma Model with Transformer Engine\n", + "\n", + "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", + "\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Animation 1: Hugging Face Gemma model token generation.\n", + "
\n", + "
\n", + "\n", + "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", + "\n", + "In the previous tutorials on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb), it was demonstrated how finetuning can be accelerated using the Transformer Engine's `TransformerLayer`. Building on this foundation, the current objective is to enhance the generation speed of the Gemma model.\n", + "\n", + "This tutorial will introduce and explain several advanced features of the Transformer Engine that contribute to this goal:\n", + "\n", + "###### **1. THD Attention Layout.**\n", + "\n", + "Addressing the challenge of computing attention for sequences with varying lengths, a common method is to pad these sequences and apply an attention mask. The Transformer Engine, however, offers a more optimized approach—by specifying the lengths and offsets of the sequences, attention can be computed directly. Instead of passing the tensor with shape `[b, s, h, d]` and the attention mask, one can pass a tensor of the shape `[t, h, d]` along with tensors detailing cumulative sequence lengths and offsets to run the attention optimized for this case. This specific attention layout is referred to as the **THD layout**. \n", + "\n", + "\n", + "The letter `t` in the standard `[t, h, d]` layout is equal to the total length of the sequences, namely `t = s_1 + s_2 + ... + s_b`, where `s_i` denotes the length of sequence `i`. TransformerEngine supports a THD layout that incorporates gaps between these sequences - the lengths of the offsets need to be passed in the additional parameter.\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 1: The difference between BSHD (default) and THD attention layouts is as follows: with BSHD, one needs to provide the attention mask, while with THD, one needs to provide cumulative sequence lengths and sequence offsets.\n", + "
\n", + "
\n", + "\n", + "###### **2. CUDA Graphs API.**\n", + "\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to submit them, which can lead to significant overhead. CUDA Graphs can address this issue. When certain kernels are executed repeatedly, it allows us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where a `TransformerLayer` is run for every token that needs to be generated.\n", + "\n", + "One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the cuda graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 2: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n", + "
\n", + "
\n", + "\n", + "\n", + "###### **3. FP8 Weights Calibration.**\n", + "\n", + "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 3:\n", + "If the model is trained in BF16/FP32, it does not include the computed FP8 scaling factors. When it is run under fp8_autocast(), the value of these scaling factors will default to their initial values, which can cause numerical errors. Weight calibration involves calculating FP8 scaling factors from higher precision forward passes. Once these factors are computed, the model becomes numerically stable. \n", + "
\n", + "
\n", + "\n", + "###### **4. FP8 Model Weights.**\n", + "\n", + "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This may prevent accuraccy drops in training. However, for inference, this level of precision is not necessary.\n", + "\n", + "The TransformerEngine includes a wrapper `fp8_model_​init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast from higher precision to BF16, saving time in this casting process. \n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 4: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. It can leads to slowdown and increased memory usage. Using fp8_model_init() results in storing weight in FP8.\n", + "
\n", + "
\n", + "\n", + "###### Benchmarking\n", + "\n", + "We'll evaluate the generation time across one benchmark: generation with context phase max sequence length = 128, batch size = 64 and number of generated tokens = 896 on random texts with random lengths.\n", + "\n", + "
\n", + "Note\n", + " \n", + "This tutorial focuses on showcasing the mentioned features of Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT](https://developer.nvidia.com/tensorrt), which is optimized for inference tasks and should be considered for such use cases.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b18f91a9", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "e5201d77", + "metadata": {}, + "source": [ + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It does also contain code for generation with THD attention, CUDA Graphs and weight calibration.\n", + "2. `te_gemma_loading_weights.py`\n", + " - This file contains logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", + "3. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "4. `requirements.txt`\n", + " - This file contains necessary Python packages for this tutorial.\n", + "5. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31390c76", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", + "Collecting transformers==4.41.1 (from -r requirements.txt (line 1))\n", + " Downloading transformers-4.41.1-py3-none-any.whl.metadata (43 kB)\n", + "Collecting accelerate==0.30.1 (from -r requirements.txt (line 2))\n", + " Downloading accelerate-0.30.1-py3-none-any.whl.metadata (18 kB)\n", + "Collecting datasets==2.19.1 (from -r requirements.txt (line 3))\n", + " Downloading datasets-2.19.1-py3-none-any.whl.metadata (19 kB)\n", + "Collecting sentencepiece==0.2.0 (from -r requirements.txt (line 4))\n", + " Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.41.1->-r requirements.txt (line 1)) (3.16.1)\n", + "Collecting huggingface-hub<1.0,>=0.23.0 (from transformers==4.41.1->-r requirements.txt (line 1))\n", + " Downloading huggingface_hub-0.26.2-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.41.1->-r requirements.txt (line 1)) (1.24.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.41.1->-r requirements.txt (line 1)) (23.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.41.1->-r requirements.txt (line 1)) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.41.1->-r requirements.txt (line 1)) (2024.9.11)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.41.1->-r requirements.txt (line 1)) (2.32.3)\n", + "Collecting tokenizers<0.20,>=0.19 (from transformers==4.41.1->-r requirements.txt (line 1))\n", + " Downloading tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from 
transformers==4.41.1->-r requirements.txt (line 1)) (0.4.5)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.41.1->-r requirements.txt (line 1)) (4.66.5)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate==0.30.1->-r requirements.txt (line 2)) (6.0.0)\n", + "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate==0.30.1->-r requirements.txt (line 2)) (2.5.0a0+e000cf0ad9.nv24.10)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.1->-r requirements.txt (line 3)) (16.1.0)\n", + "Collecting pyarrow-hotfix (from datasets==2.19.1->-r requirements.txt (line 3))\n", + " Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)\n", + "Collecting dill<0.3.9,>=0.3.0 (from datasets==2.19.1->-r requirements.txt (line 3))\n", + " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.1->-r requirements.txt (line 3)) (2.2.2)\n", + "Collecting xxhash (from datasets==2.19.1->-r requirements.txt (line 3))\n", + " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", + "Collecting multiprocess (from datasets==2.19.1->-r requirements.txt (line 3))\n", + " Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)\n", + "Collecting fsspec<=2024.3.1,>=2023.1.0 (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets==2.19.1->-r requirements.txt (line 3))\n", + " Downloading fsspec-2024.3.1-py3-none-any.whl.metadata (6.8 kB)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.1->-r requirements.txt (line 3)) (3.10.5)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.1->-r requirements.txt (line 3)) (2.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.1->-r requirements.txt (line 3)) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.1->-r requirements.txt (line 3)) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.1->-r requirements.txt (line 3)) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.1->-r requirements.txt (line 3)) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.1->-r requirements.txt (line 3)) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.1->-r requirements.txt (line 3)) (4.0.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers==4.41.1->-r requirements.txt (line 1)) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.41.1->-r requirements.txt (line 1)) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in 
/usr/local/lib/python3.10/dist-packages (from requests->transformers==4.41.1->-r requirements.txt (line 1)) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.41.1->-r requirements.txt (line 1)) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.41.1->-r requirements.txt (line 1)) (2024.8.30)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==0.30.1->-r requirements.txt (line 2)) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==0.30.1->-r requirements.txt (line 2)) (3.1.4)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==0.30.1->-r requirements.txt (line 2)) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.10.0->accelerate==0.30.1->-r requirements.txt (line 2)) (1.3.0)\n", + "INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.\n", + "Collecting multiprocess (from datasets==2.19.1->-r requirements.txt (line 3))\n", + " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.1->-r requirements.txt (line 3)) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.1->-r requirements.txt (line 3)) (2023.4)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.1->-r requirements.txt (line 3)) (2024.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.19.1->-r requirements.txt (line 3)) (1.16.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate==0.30.1->-r requirements.txt (line 2)) (2.1.5)\n", + "Downloading transformers-4.41.1-py3-none-any.whl (9.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.1/9.1 MB\u001b[0m \u001b[31m175.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading accelerate-0.30.1-py3-none-any.whl (302 kB)\n", + "Downloading datasets-2.19.1-py3-none-any.whl (542 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m542.0/542.0 kB\u001b[0m \u001b[31m334.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m628.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", + "Downloading fsspec-2024.3.1-py3-none-any.whl (171 kB)\n", + "Downloading huggingface_hub-0.26.2-py3-none-any.whl (447 kB)\n", + "Downloading tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m296.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", + "Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n", + "Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", + "Installing collected packages: sentencepiece, xxhash, pyarrow-hotfix, fsspec, dill, multiprocess, huggingface-hub, tokenizers, accelerate, transformers, datasets\n", + " Attempting uninstall: fsspec\n", + " Found existing installation: fsspec 2024.6.1\n", + " Uninstalling fsspec-2024.6.1:\n", + " Successfully uninstalled fsspec-2024.6.1\n", + " Attempting uninstall: dill\n", + " Found existing installation: dill 0.3.9\n", + " Uninstalling dill-0.3.9:\n", + " Successfully uninstalled dill-0.3.9\n", + "Successfully installed accelerate-0.30.1 datasets-2.19.1 dill-0.3.8 fsspec-2024.3.1 huggingface-hub-0.26.2 multiprocess-0.70.16 pyarrow-hotfix-0.6 sentencepiece-0.2.0 tokenizers-0.19.1 transformers-4.41.1 xxhash-3.5.0\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", + "\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install -r requirements.txt\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "id": "e8dfabbf", + "metadata": {}, + "source": [ + "\n", + "|\n", + "## [Baseline] Running Hugging Face generation with Gemma model" + ] + }, + { + "cell_type": "markdown", + "id": "59560bff", + "metadata": {}, + "source": [ + "HuggingFace Transformers library offers generation API. \n", + "HuggingFace generation for the Gemma model will be used as a baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2803e0ec", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n", + "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n", + "Gemma's activation function will be set to `gelu_pytorch_tanh`. 
Please, use\n", + "`config.hidden_activation` if you want to override this behaviour.\n", + "See https://github.com/huggingface/transformers/pull/29402 for more details.\n", + "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.02s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. The second fact is\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and computer processing units (CPUs) for the gaming and professional markets.\n", + "* The company was founded in 1993 and is headquartered in Santa Clara\n", + "============================== Generation example 3 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. The second fact is\n", + "============================== Generation example 4 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and computer processing units (CPUs) for the gaming and professional markets.\n", + "* The company was founded in 1993 and is headquartered in Santa Clara\n", + "============================== Generation example 5 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. The second fact is\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "# !!! `model_name` attr must point to the location of the model weights !!!\n", + "# Weights can be downloaded from: https://huggingface.co/google/gemma-7b.\n", + "# Weights should be in the *.safetensors HF format, not in the original format.\n", + "hyperparams.model_name = \"/tmp/gemma-7b-hf\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "\n", + "model = init_baseline_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "# benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "b3698dc6", + "metadata": {}, + "source": [ + "Let's put this time into the table for later comparison.\n", + "\n", + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |" + ] + }, + { + "cell_type": "markdown", + "id": "8bb40f45", + "metadata": {}, + "source": [ + "## [Improvement 1] Using TransformerLayer from Transformer Engine instead of GemmaDecoderLayer." + ] + }, + { + "cell_type": "markdown", + "id": "263b40f2", + "metadata": {}, + "source": [ + "As in the [Gemma](./tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb) finetuning tutorial, a GemmaDecoderLayer is substituted by a tuned TransformerLayer from the Transformer Engine. Let's run it and compare the time with the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9dceef93", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/transformer_engine/pytorch/attention.py\u001b[0m(8223)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n", + "\u001b[0;32m 8221 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m 8222 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m-> 8223 \u001b[0;31m key_layer, value_layer = inference_params.save_to_kv_cache(\n", + "\u001b[0m\u001b[0;32m 8224 \u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer_number\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey_layer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue_layer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m 8225 \u001b[0;31m )\n", + "\u001b[0m\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> key_layer.shape\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([128, 64, 16, 256])\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> value_layer.shape\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([128, 64, 16, 256])\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> query_layer.shape\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8192, 16, 256])\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> c\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "Queries, keys and values must be 4D tensors when qkv_format = bshd!", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 11\u001b[0m\n\u001b[1;32m 7\u001b[0m hyperparams\u001b[38;5;241m.\u001b[39mmodel_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/tmp/gemma-7b-hf\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m# <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\u001b[39;00m\n\u001b[1;32m 9\u001b[0m model \u001b[38;5;241m=\u001b[39m init_te_gemma_model(hyperparams)\n\u001b[0;32m---> 11\u001b[0m \u001b[43mprint_sample_of_generated_texts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# benchmark_generation(model)\u001b[39;00m\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/docs/examples/te_gemma/utils.py:280\u001b[0m, in \u001b[0;36mprint_sample_of_generated_texts\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m 277\u001b[0m inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[1;32m 278\u001b[0m inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[0;32m--> 280\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 281\u001b[0m generated_texts \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(outputs, skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 283\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprint_output\u001b[39m(prompts, generated_texts, idx):\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/docs/examples/te_gemma/te_gemma.py:450\u001b[0m, in \u001b[0;36mTEGemmaForCausalLM.generate\u001b[0;34m(self, input_ids, pad_token_id, max_new_tokens, *args, **kwargs)\u001b[0m\n\u001b[1;32m 446\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mqkv_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthd\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 447\u001b[0m \u001b[38;5;66;03m# For thd layout padding is at the end, otherwise at the beginning.\u001b[39;00m\n\u001b[1;32m 448\u001b[0m 
TEGemmaForCausalLM\u001b[38;5;241m.\u001b[39m_padding_to_end(input_ids, lengths)\n\u001b[0;32m--> 450\u001b[0m hidden_states, next_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_generate_context_phase\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 452\u001b[0m \u001b[38;5;66;03m# Generation phase.\u001b[39;00m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mqkv_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthd\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/docs/examples/te_gemma/te_gemma.py:381\u001b[0m, in \u001b[0;36mTEGemmaForCausalLM._generate_context_phase\u001b[0;34m(self, input_ids, inference_params)\u001b[0m\n\u001b[1;32m 378\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 379\u001b[0m inference_params\u001b[38;5;241m.\u001b[39msetup_before_new_input(length\u001b[38;5;241m=\u001b[39minput_ids\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m])\n\u001b[0;32m--> 381\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_model_context_phase\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 382\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 383\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mqkv_format\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m!=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mthd\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 384\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_mask_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpadding_causal\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mqkv_format\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mthd\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43marbitrary\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m 385\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 387\u001b[0m \u001b[38;5;66;03m# We choose logits coresponding with last token in each sequence,\u001b[39;00m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;66;03m# which have various lengths - they are stored in 
(inference_params.incoming_seq_len - 1)\u001b[39;00m\n\u001b[1;32m 389\u001b[0m \u001b[38;5;66;03m# Tensor when qkv_format == \"thd\" and\u001b[39;00m\n\u001b[1;32m 390\u001b[0m \u001b[38;5;66;03m# they are the last token in the sequence when qkv_format != \"thd\".\u001b[39;00m\n\u001b[1;32m 391\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mqkv_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthd\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/docs/examples/te_gemma/te_gemma.py:183\u001b[0m, in \u001b[0;36mStaticGemmaModel.forward\u001b[0;34m(self, hidden_states, attention_mask, attn_mask_type)\u001b[0m\n\u001b[1;32m 180\u001b[0m 
hidden_states\u001b[38;5;241m.\u001b[39mdata[:] \u001b[38;5;241m=\u001b[39m hidden_states\u001b[38;5;241m.\u001b[39mdata[:] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnormalizer\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, decoder_layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mlayers):\n\u001b[0;32m--> 183\u001b[0m hidden_states\u001b[38;5;241m.\u001b[39mdata[:] \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 184\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 185\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 186\u001b[0m \u001b[43m \u001b[49m\u001b[43mself_attn_mask_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmask\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mattn_mask_type\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mattn_mask_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 187\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m[\n\u001b[1;32m 189\u001b[0m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 190\u001b[0m ] \u001b[38;5;66;03m# static copy - for CUDA graphs\u001b[39;00m\n\u001b[1;32m 192\u001b[0m hidden_states\u001b[38;5;241m.\u001b[39mcopy_(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mnorm(hidden_states)) \u001b[38;5;66;03m# static copy - for CUDA graphs\u001b[39;00m\n\u001b[1;32m 193\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to 
skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/docs/examples/te_gemma/te_gemma.py:151\u001b[0m, in \u001b[0;36mTEGemmaDecoderLayer.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m kwargs\u001b[38;5;241m.\u001b[39mpop(key, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 150\u001b[0m \u001b[38;5;66;03m# We need to return tuple to be compatible with HF.\u001b[39;00m\n\u001b[0;32m--> 151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mte_rope_emb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m,)\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/transformer_engine/pytorch/transformer.py:690\u001b[0m, in \u001b[0;36mTransformerLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, self_attn_mask_type, window_size, encoder_output, enc_dec_attn_mask, enc_dec_attn_mask_type, enc_dec_window_size, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, alibi_slopes, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, fast_zero_fill)\u001b[0m\n\u001b[1;32m 687\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m cast_if_needed(hidden_states, torch\u001b[38;5;241m.\u001b[39mget_autocast_gpu_dtype())\n\u001b[1;32m 689\u001b[0m \u001b[38;5;66;03m# Self attention.\u001b[39;00m\n\u001b[0;32m--> 690\u001b[0m self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 691\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 692\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 693\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_mask_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_mask_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 694\u001b[0m \u001b[43m \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwindow_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 695\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_first_microbatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_first_microbatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m \u001b[49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[43m \u001b[49m\u001b[43mcore_attention_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 701\u001b[0m \u001b[43m \u001b[49m\u001b[43malibi_slopes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malibi_slopes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 702\u001b[0m \u001b[43m \u001b[49m\u001b[43mcu_seqlens_q\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcu_seqlens_q\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 703\u001b[0m \u001b[43m \u001b[49m\u001b[43mcu_seqlens_kv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcu_seqlens_kv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 704\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_seqlen_q\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_seqlen_q\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 705\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_seqlen_kv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_seqlen_kv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 706\u001b[0m \u001b[43m \u001b[49m\u001b[43mfast_zero_fill\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfast_zero_fill\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 707\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapply_residual_connection_post_layernorm \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_layernorm:\n\u001b[1;32m 710\u001b[0m attention_output, attention_bias, residual \u001b[38;5;241m=\u001b[39m self_attention_outputs\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/transformer_engine/pytorch/attention.py:9453\u001b[0m, in \u001b[0;36mMultiheadAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, encoder_output, attn_mask_type, window_size, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, alibi_slopes, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, fast_zero_fill)\u001b[0m\n\u001b[1;32m 9447\u001b[0m query_layer \u001b[38;5;241m=\u001b[39m query_layer\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m*\u001b[39mquery_layer\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m:])\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m 9449\u001b[0m \u001b[38;5;66;03m# ===========================\u001b[39;00m\n\u001b[1;32m 9450\u001b[0m \u001b[38;5;66;03m# Core attention computation\u001b[39;00m\n\u001b[1;32m 9451\u001b[0m \u001b[38;5;66;03m# ===========================\u001b[39;00m\n\u001b[0;32m-> 9453\u001b[0m context_layer \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcore_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 9454\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_layer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9455\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_layer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9456\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue_layer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9457\u001b[0m \u001b[43m \u001b[49m\u001b[43mqkv_format\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mqkv_format\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9458\u001b[0m \u001b[43m \u001b[49m\u001b[43mcu_seqlens_q\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcu_seqlens_q\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9459\u001b[0m \u001b[43m \u001b[49m\u001b[43mcu_seqlens_kv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcu_seqlens_kv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9460\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_seqlen_q\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_seqlen_q\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9461\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_seqlen_kv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_seqlen_kv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9462\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9463\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_mask_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9464\u001b[0m \u001b[43m \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwindow_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9465\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9466\u001b[0m \u001b[43m \u001b[49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9467\u001b[0m \u001b[43m \u001b[49m\u001b[43mcore_attention_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9468\u001b[0m \u001b[43m \u001b[49m\u001b[43malibi_slopes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malibi_slopes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9469\u001b[0m \u001b[43m \u001b[49m\u001b[43mfast_zero_fill\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfast_zero_fill\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9470\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9471\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9473\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqkv_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthd\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 9474\u001b[0m \u001b[38;5;66;03m# [b * sq, h] -> [qs, b, h]\u001b[39;00m\n\u001b[1;32m 9475\u001b[0m context_layer \u001b[38;5;241m=\u001b[39m context_layer\u001b[38;5;241m.\u001b[39mview(\n\u001b[1;32m 9476\u001b[0m (inference_params\u001b[38;5;241m.\u001b[39mmax_batch_size, 
\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, context_layer\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 9477\u001b[0m )\u001b[38;5;241m.\u001b[39mcontiguous()\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", + "File \u001b[0;32m/perfhome/mnt/wkstn/work/repos/TransformerEngine/transformer_engine/pytorch/attention.py:8301\u001b[0m, in \u001b[0;36mDotProductAttention.forward\u001b[0;34m(self, query_layer, key_layer, value_layer, attention_mask, qkv_format, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, max_seqlen_q, max_seqlen_kv, attn_mask_type, window_size, checkpoint_core_attention, core_attention_bias_type, core_attention_bias, alibi_slopes, fast_zero_fill, inference_params, is_first_microbatch)\u001b[0m\n\u001b[1;32m 8298\u001b[0m context_parallel \u001b[38;5;241m=\u001b[39m cp_size \u001b[38;5;241m>\u001b[39m 
\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 8300\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m qkv_format \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msbhd\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbshd\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m-> 8301\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mall\u001b[39m(\n\u001b[1;32m 8302\u001b[0m \u001b[38;5;28mlen\u001b[39m(x\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m4\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m (query_layer, key_layer, value_layer)\n\u001b[1;32m 8303\u001b[0m ), \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQueries, keys and values must be 4D tensors when qkv_format = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mqkv_format\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m!\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 8304\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m qkv_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msbhd\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 8305\u001b[0m max_seqlen_q \u001b[38;5;241m=\u001b[39m query_layer\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m max_seqlen_q \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m max_seqlen_q\n", + "\u001b[0;31mAssertionError\u001b[0m: Queries, keys and values must be 4D tensors when qkv_format = bshd!" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"/tmp/gemma-7b-hf\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "# benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d40836", + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "source": [ + "The speedup of **62%** was obtained." + ] + }, + { + "cell_type": "markdown", + "id": "006d18e8", + "metadata": {}, + "source": [ + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | " + ] + }, + { + "cell_type": "markdown", + "id": "2bbf3d47", + "metadata": {}, + "source": [ + "## [Improvement 2] Use of THD attention layout.\n", + "\n", + "Input sequences can have various lengths. Hugging Face generation – as can be seen in Animation 1 – pads the sequences and then uses attention mask. In the THD attention layout cumulative sequence lengths and offsets need to be provided, instead of attention mask. The THD attention layout is much more optimized than BSHD layout.\n", + "\n", + "The class `transformer_engine.pytorch.DotProductAttention` supports this format. 
One needs to pass the following arguments to the forward method (a sketch of how this metadata can be built is shown after the note below):\n", + "- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` – offsets marking where each sequence begins in the query, key and value tensors,\n", + "- `cu_seqlens_q`, `cu_seqlens_kv` – cumulative sums of the sequence lengths for the query and key-value layers,\n", + "- `max_seqlen_q` – maximum sequence length in the query layer,\n", + "- `max_seqlen_kv` – maximum sequence length in the key-value layer.\n", + "\n", + "<div class=\"alert alert-info\">
\n", + "Note\n", + "\n", + "Currently, THD attention for `TransformerLayer` is supported only for token generation.\n", + "</div>
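\n", + "\n", + "Below is a rough sketch of how this metadata can be built for a batch of variable-length sequences. The snippet is illustrative only, it is not part of the tutorial code, and the lengths are made up:\n", + "```\n", + "import torch\n", + "\n", + "# Hypothetical lengths of three sequences in a batch\n", + "seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32, device=\"cuda\")\n", + "\n", + "# Cumulative sequence lengths: tensor([0, 5, 8, 15])\n", + "cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=\"cuda\")\n", + "cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)\n", + "\n", + "# Longest sequence in the batch\n", + "max_seqlen = int(seq_lens.max())\n", + "```\n", + "Tensors like these are what `DotProductAttention.forward` consumes as `cu_seqlens_q`, `cu_seqlens_kv`, `max_seqlen_q` and `max_seqlen_kv` when `qkv_format=\"thd\"` is used, instead of a padded attention mask.\n", + "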
\n", + "\n", + "Let's look how using TransformerEngine with THD attention impacts the speed of token generation:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4fc5e1cd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why they are so good at rendering 3D graphics.\n", + "\n", + "The GPU\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem.\n", + "============================== Generation example 3 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why they are so good at rendering 3D graphics.\n", + "\n", + "The GPU\n", + "============================== Generation example 4 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem.\n", + "============================== Generation example 5 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. 
They are very bad at doing different things at the same time.\n", + "\n", + "This is why they are so good at rendering 3D graphics.\n", + "\n", + "The GPU\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"/tmp/gemma-7b-hf/\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.qkv_format = \"thd\"\n", + "\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "# benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "8e397a65", + "metadata": {}, + "source": [ + "By using THD attention, the following speedup was obtained:\n", + "\n", + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | \n", + "| TE + THD attention | 28.22 | 3.11 | " + ] + }, + { + "cell_type": "markdown", + "id": "21a89d9c", + "metadata": {}, + "source": [ + "## [Improvement 3] Speeding up generation with CUDA Graphs" + ] + }, + { + "cell_type": "markdown", + "id": "e2d53e7b", + "metadata": {}, + "source": [ + "TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py` from class `TEGemmaForCausalLMCudaGraphs`:\n", + "```\n", + " def __init__(self, config : GemmaConfig):\n", + " (...)\n", + " \n", + " # Here \"the trick\" happens. We override methods from TEGemmaForCausalLM\n", + " # with their recorded version. After invocation of each of them,\n", + " # captured graph will be replayed with minimal usage of CPU,\n", + " # what will lead to huge speedup.\n", + " (...)\n", + " self._model_context_phase = \n", + " self.record_graph(self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording\n", + "\n", + " (...) \n", + " self._model_generation_phase = \n", + " self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording\n", + "\n", + " @torch.no_grad()\n", + " def record_graph(self, function, input_tensor):\n", + " (...)\n", + " # function is invoked on argument (self.hidden_states,) and all kernels are recorded.\n", + " # record_graph() returns captured function, which can be run later with minimal use of th CPU.\n", + " fp8_format = Format.HYBRID\n", + " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", + " with autocast(dtype=torch.bfloat16, cache_enabled=False):\n", + " graphed_function = te.pytorch.make_graphed_callables(\n", + " function, \n", + " (input_tensor,), \n", + " fp8_enabled=True, \n", + " fp8_recipe=fp8_recipe, \n", + " allow_unused_input=True,\n", + " num_warmup_iters=3\n", + " )\n", + " return graphed_function\n", + "```\n", + "\n", + "It is strongly reccomended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs." 
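, + "\n", + "To make the record-and-replay idea more concrete, here is a generic PyTorch-level sketch of what happens under the hood. It uses plain `torch.cuda.CUDAGraph` on a toy module rather than the tutorial's `record_graph` helper, and the module and sizes are made up:\n", + "```\n", + "import torch\n", + "\n", + "model = torch.nn.Linear(1024, 1024).cuda().eval()\n", + "static_input = torch.randn(64, 1024, device=\"cuda\")\n", + "\n", + "# Warm up on a side stream, as recommended before graph capture.\n", + "s = torch.cuda.Stream()\n", + "s.wait_stream(torch.cuda.current_stream())\n", + "with torch.cuda.stream(s), torch.no_grad():\n", + "    for _ in range(3):\n", + "        static_output = model(static_input)\n", + "torch.cuda.current_stream().wait_stream(s)\n", + "\n", + "# Capture the kernels launched by a single forward pass.\n", + "graph = torch.cuda.CUDAGraph()\n", + "with torch.cuda.graph(graph), torch.no_grad():\n", + "    static_output = model(static_input)\n", + "\n", + "# Replay: copy new data into the static input buffer and re-launch\n", + "# the whole captured kernel sequence with a single CPU call.\n", + "def run(new_input):\n", + "    static_input.copy_(new_input)\n", + "    graph.replay()\n", + "    return static_output\n", + "```\n", + "`make_graphed_callables` wraps this pattern, including the handling of static inputs and outputs, which is why the tutorial preallocates static buffers with a fixed batch size and maximum sequence length."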
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31a3a8a3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLMCudaGraphs is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why they are so good at rendering 3D graphics.\n", + "\n", + "The GPU\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem.\n", + "============================== Generation example 3 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why they are so good at rendering 3D graphics.\n", + "\n", + "The GPU\n", + "============================== Generation example 4 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem.\n", + "============================== Generation example 5 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. 
They are very bad at doing different things at the same time.\n", + "\n", + "This is why they are so good at rendering 3D graphics.\n", + "\n", + "The GPU\n" + ] + } + ], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"/tmp/gemma-7b-hf/\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.qkv_format = \"thd\"\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "\n", + "# It is necessary to preallocate a static buffer.\n", + "# CUDA graphs require static input tensors for every kernel.\n", + "# This approach may result in a slight increase in memory consumption;\n", + "# however, the substantial speedup achieved makes it worthwhile.\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "# benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "The **5.23x** speedup was obtained.\n", + "\n", + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | \n", + "| TE + THD attention | 28.22 | 3.11 | \n", + "| TE + THD attention + CUDA Graphs | 16.75 | 5.23 | \n" + ] + }, + { + "cell_type": "markdown", + "id": "0a11b75c", + "metadata": {}, + "source": [ + "Let's look at the screenshots from *NVIDIA Nsight System* profiler to see where this speedup comes from:\n", + "\n", + "
\n", + "\n", + "
\n", + "Figure 5: Without CUDA Graphs. One can see that the GPU (blue) is idle for a large portion of the time.\n", + "</b>
\n", + "
\n", + "\n", + "
\n", + "\n", + "
\n", + "Figure 6: With CUDA Graphs. One can see that the GPU (orange) is fully utilized.\n", + "</b>
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "e6b171a0", + "metadata": {}, + "source": [ + "## [Improvement 4] Running generation in FP8 of the model trained in higher precision " + ] + }, + { + "cell_type": "markdown", + "id": "1a80288b", + "metadata": {}, + "source": [ + "Implementing FP8 generation with the Gemma model is not straightforward, because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing. Running the model at this lower precision without proper scaling could lead to significant errors and incorrect results.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the necessity of scaling.\n", + "\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 7: Incorrect FP8 scaling factors lead to numerical errors. Weight calibration allows us to compute the FP8 metadata during forward passes run in higher precision.\n", + "</b>
\n", + "
\n", + "\n", + "### Weight Calibration\n", + "\n", + "To address the issue outlined above, weight calibration will be used. This involves running several forward iterations at BF16 precision within the `te.fp8_autocast(enabled=False, calibrating=True)` context. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to FP8 precision, which are essential for computing the FP8 scaling factors correctly.\n", + "\n", + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights are used in the subsequent sections." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aecee0e1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Repo card metadata block was not found. Setting CardData to empty.\n", + "[WARNING | huggingface_hub.repocard]: Repo card metadata block was not found. Setting CardData to empty.\n", + "Repo card metadata block was not found. Setting CardData to empty.\n", + "[WARNING | huggingface_hub.repocard]: Repo card metadata block was not found. Setting CardData to empty.\n" + ] + } + ], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "import transformer_engine.pytorch as te\n", + "\n", + "hyperparams.model_name = \"/tmp/gemma-7b-hf/\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.fuse_qkv_params = True # This is needed by the last improvement.\n", + "\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "# Calibration: run forward passes in BF16 while collecting FP8 statistics.\n", + "with te.fp8_autocast(enabled=False, calibrating=True), \\\n", + "     torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", + "    model.train()\n", + "    run_forward_pass(model, hyperparams, num_iters=512)\n", + "\n", + "# Compute scale_fwd with fp8 autocast enabled\n", + "with te.fp8_autocast(enabled=True), \\\n", + "     torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", + "    run_forward_pass(model, hyperparams, 1)\n", + "\n", + "# Some parameters point to the same tensors; saving them twice is avoided here.\n", + "dict_to_save = {k: v for k, v in model.state_dict().items() \\\n", + "                if (\"_context_phase\" not in k and \"_generation_phase\" not in k)}\n", + "torch.save(dict_to_save, 'calibrated_weights.pth') # <== Add path to save calibrated weights." + ] + }, + { + "cell_type": "markdown", + "id": "b6dcd135", + "metadata": {}, + "source": [ + "### Generation in FP8\n", + "\n", + "<div style=\"text-align: center;\">
\n", + "\n", + "
\n", + " Figure 8: After weight calibration, the FP8 scaling factors are correct and prevent numerical errors.\n", + "</b>
\n", + "
\n", + "\n", + "Now FP8 inference is ready to be run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38e005f5", + "metadata": {}, + "outputs": [], + "source": [ + "!ls -alh /perfhome/repos/data/gemma-7b-hf/" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a913f54d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLMCudaGraphs is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "This is a very important distinction to make.\n", + "\n", + "The first fact is a good thing\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 and is headquartered in Santa Clara, California.\n", + "* NVIDIA's\n", + "============================== Generation example 3 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. 
GPUs are very bad at doing different things at the same time.\n", + "\n", + "This is a very important distinction to make.\n", + "\n", + "The first fact is a good thing\n", + "============================== Generation example 4 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 and is headquartered in Santa Clara, California.\n", + "* NVIDIA's\n", + "============================== Generation example 5 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "This is a very important distinction to make.\n", + "\n", + "The first fact is a good thing\n" + ] + } + ], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"/tmp/gemma-7b-hf/\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.qkv_format = \"thd\"\n", + "hyperparams.fuse_qkv_params = True # This is needed by the last improvement.\n", + "\n", + "hyperparams.fp8 = True\n", + "# Calibrated fp8 weights are loaded directly from the file.\n", + "\n", + "hyperparams.fp8_model_weights_filename = \"calibrated_weights.pth\" # <== Add calibrated weights location here.\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "# benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "8cdbb56c", + "metadata": {}, + "source": [ + "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n", + "\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 9: Running the model at higher precision involves only one GEMM operation. However, when the model operates in FP8, it requires not just the low-precision GEMM but also weight casting.\n", + "
\n", + "
\n", + "\n", + "Running the model in FP8 does not imply that all weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using the saved scaling factors, before operations such as GEMMs.\n", + "\n", + "This approach is beneficial during training: a single cast serves both the forward and the backward pass, which leads to speedups. During token generation, however, the weights have to be cast anew for every forward pass, and this overhead outweighs the gains from the low-precision GEMMs. This issue will be addressed in the next section of the tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "8d3945e3", + "metadata": {}, + "source": [ + "### Use of only FP8 model weights" + ] + }, + { + "cell_type": "markdown", + "id": "2dd0cba9", + "metadata": {}, + "source": [ + "Transformer Engine stores parameters in higher precision and only casts them to FP8. This may be necessary to maintain accuracy during training. However, high precision is not needed when doing inference.\n", + "\n", + "Transformer Engine supports keeping only FP8 weights with the `fp8_model_init` context manager. Let's look at an example:\n", + "```\n", + "linear = te.Linear(1024, 1024) # this module is initialized with full precision weights\n", + "with te.fp8_model_init(enabled=True):\n", + "    linear_fp8 = te.Linear(1024, 1024) # this module is initialized only with fp8 weights\n", + "\n", + "assert type(linear.weight.data) is torch.Tensor\n", + "assert type(linear_fp8.weight.data) is te.float8_tensor.Float8Tensor\n", + "```\n", + "\n", + "<div style=\"text-align: center;\">
\n", + "\n", + "
\n", + " Figure 10: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage.\n", + "
\n", + "
\n", + "\n", + "Let's run the code with `fp8_model_init`:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "96264b9c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLMCudaGraphs is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "This is a very important distinction to make.\n", + "\n", + "The first fact is a good thing\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 and is headquartered in Santa Clara, California.\n", + "* NVIDIA's\n", + "============================== Generation example 3 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "This is a very important distinction to make.\n", + "\n", + "The first fact is a good thing\n", + "============================== Generation example 4 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "Generated text:\n", + "\n", + "\n", + "* NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA was founded in 1993 and is headquartered in Santa Clara, California.\n", + "* NVIDIA's\n", + "============================== Generation example 5 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "Generated text:\n", + "\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. 
GPUs are very bad at doing different things at the same time.\n", + "\n", + "This is a very important distinction to make.\n", + "\n", + "The first fact is a good thing\n" + ] + } + ], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"/tmp/gemma-7b-hf/\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.fuse_qkv_params = True # Needed for fp8_model_init().\n", + "hyperparams.qkv_format = \"thd\"\n", + "\n", + "hyperparams.fp8 = True\n", + "hyperparams.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "hyperparams.fp8_model_weights_filename = \"calibrated_weights.pth\" # <== Add calibrated weights location here.\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "# benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "3e30ca5a", + "metadata": {}, + "source": [ + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | \n", + "| TE + THD attention | 28.22 | 3.11 | \n", + "| TE + THD attention + CUDA Graphs | 16.75 | 5.23 | \n", + "| TE + THD attention + FP8 | 12.13 | 7.23 | \n", + "\n", + "The final speedup is **7.23x**." + ] + }, + { + "cell_type": "markdown", + "id": "c6e87275", + "metadata": {}, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "id": "7bb2452d", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "
\n", + " Figure 11: Generation times in seconds obtained with successive Transformer Engine optimizations.\n", + "</b>
\n", + "
\n", + "\n", + "In this tutorial, we've explored four features of the Transformer Engine:\n", + "1. Support for the THD attention layout,\n", + "2. Integration with CUDA Graphs,\n", + "3. FP8 weight calibration,\n", + "4. Models containing only the FP8 version of their parameters.\n", + "\n", + "Each of these features can be applied in various contexts, such as fast token generation. It's important to note that the fastest possible inference speeds can be achieved using NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py new file mode 100755 index 0000000000..27e07ee15a --- /dev/null +++ b/docs/examples/te_gemma/utils.py @@ -0,0 +1,367 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import time +import sys +import IPython +import random +import string + +from te_gemma_loading_weights import load_te_model +from te_llama_loading_weights import load_te_model as load_te_model_llama +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader + +from transformers import ( +    AutoModelForCausalLM, +    AutoTokenizer, +    get_linear_schedule_with_warmup, +    AutoConfig, +) +from transformers import DataCollatorForLanguageModeling +from datasets import load_dataset +from accelerate import Accelerator +from accelerate.utils.dataclasses import FP8RecipeKwargs + + +from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs +from te_llama import TELlamaForCausalLM, TELlamaForCausalLMCudaGraphs + + +class HyperParameters: +    def __init__(self): +        self.mixed_precision = "bf16" +        self.model_name = None + +        self.fp8 = False + +        # Weights in fp8 +        self.fp8_model_weights_filename = None +        self.fp8_model_init = False + +        # Cuda graphs +        self.generation_cuda_graphs = False +        self.cuda_graphs_static_batch_size = 16 +        self.cuda_graphs_static_max_seq_len = 256 +        self.cuda_graphs_static_max_context_len = 16 + +        # Finetuning settings. +        self.dataset_name = "timdettmers/openassistant-guanaco" +        self.dataset_text_field = "text" +        self.learning_rate = 1.41e-5 +        self.batch_size = 8 +        self.max_seq_length = 256 +        self.gradient_accumulation_steps = 1 +        self.num_warmup_steps = 5 +        self.num_training_steps = 10 + +        # QKV format. +        self.fuse_qkv_params = False +        self.qkv_format = "bshd" + + +hyperparams = HyperParameters() + +assert ( +    torch.backends.cudnn.version() >= 90100 +), "cuDNN version >= 9.1.0 is needed to run this tutorial."
+ + +def get_dataloaders(accelerator: Accelerator, hyperparams): + dataset = load_dataset(hyperparams.dataset_name, split="train") + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(element): + outputs = tokenizer( + element["text"], + truncation=True, + padding=False, + max_length=hyperparams.max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + with accelerator.main_process_first(): + dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) + + # Simply pad to the multiple of 16 for both FP8 and BF16 precision + pad_to_multiple_of = 16 + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of, + ) + + dataloader_params = { + "batch_size": hyperparams.batch_size, + "collate_fn": data_collator, + "drop_last": True, + } + train_dataloader = DataLoader(dataset, **dataloader_params) + return train_dataloader + + +def init_baseline_model(hyperparams): + # Init the model + config = AutoConfig.from_pretrained(hyperparams.model_name) + # make sure to use flash_attention to do iso comparison with TEGemmaModel + config._attn_implementation = "flash_attention_2" + model = AutoModelForCausalLM.from_pretrained( + hyperparams.model_name, + config=config, + torch_dtype=torch.bfloat16, + ) + return model.cuda() + + +def init_te_llama_model(hyperparams): + cls = TELlamaForCausalLMCudaGraphs if hyperparams.generation_cuda_graphs else TELlamaForCausalLM + config = AutoConfig.from_pretrained(hyperparams.model_name) + config._attn_implementation = "flash_attention_2" + # config.hidden_size = 1024 + # config.head_dim = 128 + print(config) + # Adding all params from the hyperparams to the config to make the code simpler. + for key, value in hyperparams.__dict__.items(): + setattr(config, key, value) + model = load_te_model_llama(cls, config) + if hyperparams.generation_cuda_graphs: + model.record() + return model.cuda() + + +def init_te_gemma_model(hyperparams): + cls = TEGemmaForCausalLMCudaGraphs if hyperparams.generation_cuda_graphs else TEGemmaForCausalLM + config = AutoConfig.from_pretrained(hyperparams.model_name) + config._attn_implementation = "flash_attention_2" + # config.hidden_size = 1024 + # config.head_dim = 128 + print(config) + # Adding all params from the hyperparams to the config to make the code simpler. 
+ for key, value in hyperparams.__dict__.items(): + setattr(config, key, value) + model = load_te_model(cls, config) + if hyperparams.generation_cuda_graphs: + model.record() + return model.cuda() + + +def wrap_with_accelerator(model, hyperparams): + # Create FP8 kwarg handler if required + fp8_kwarg_handler = ( + [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None + ) + + # Init HF accelerator that's used for training + accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision=hyperparams.mixed_precision, + kwargs_handlers=fp8_kwarg_handler, + ) + # accelerator.print(f'State: {accelerator.state}') + train_dataloader = get_dataloaders(accelerator, hyperparams) + + # Wrap model, optimizer/scheduler, dataloaders in accelerate + optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=hyperparams.num_training_steps, + ) + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + return accelerator, model, optimizer, train_dataloader, lr_scheduler + + +def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler): + model.train() + optimizer.zero_grad() + train_dataloader = enumerate(train_dataloader) + + def run_iters(num_iters): + for _ in range(num_iters): + _, batch = next(train_dataloader) + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + run_iters(hyperparams.num_warmup_steps) # Warmup iters + + # Get the timers ready + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + + start.record() + run_iters(hyperparams.num_training_steps) # Training iters + torch.cuda.synchronize() + end.record() + accelerator.end_training() + + print( + f"""{hyperparams.num_training_steps} finetuning steps complete!\n + Average time taken per step: + {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} + milliseconds""" + ) + + +def restart_jupyter_notebook(): + # Try restarting the Jupyter kernel + IPython.Application.instance().kernel.do_shutdown(True) + + # Check whether the device memory has been flushed + if torch.cuda.memory_allocated() != 0: + import warnings + + warnings.warn("The device memory hasn't been flushed, trying with a second method!") + + # Try restarting the Jupyter kernel another way + # Restart the kernel + from IPython.core.display import HTML + + HTML("") + + if torch.cuda.memory_allocated() != 0: + print( + "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!" + ) + + # Suppress the warnings + if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + torch.set_warn_always(False) + + +@torch.no_grad() +def run_forward_pass(model, hyperparams, num_iters): + """ + It runs num_iters forward passes with sample data. + """ + accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision="no", + ) + train_dataloader = get_dataloaders(accelerator, hyperparams) + + # @sudhakars: what's the point of calling `model.train` inside `no_grad` + # context? 
+ model.train() + train_dataloader = enumerate(train_dataloader) + + for _ in range(num_iters): + _, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + batch["attention_mask"] = batch["attention_mask"].cuda() + model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) + + +""" + Benchmarking and example generation functions. +""" + + +def print_sample_of_generated_texts(model): + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + prompts = ["Here are the two facts about GPUs:", "Some facts about NVIDIA:"] + prompts *= 32 + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + + max_length = inputs["input_ids"].size(1) + new_length = ((max_length + 63) // 64) * 128 + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (new_length - max_length, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + outputs = model.generate(**inputs, max_new_tokens=50) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def print_output(prompts, generated_texts, idx): + print("=" * 30 + f" Generation example {idx+1} " + "=" * 30) + print("Prompt:") + print(generated_texts[idx][: len(prompts[idx])]) + print("Generated text:") + print(generated_texts[idx][len(prompts[idx]) :]) + + for i in range(5): + print_output(prompts, generated_texts, i) + + +def _generate_random_words(num_words, max_word_length): + words = [] + for _ in range(num_words): + word_length = random.randint(1, max_word_length) + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + return words + + +def benchmark_generation(model): + batch_size = 64 + context_length = 128 + max_new_tokens = 156 - 128 + print("=" * 30 + " Benchmarking " + "=" * 30) + print( + f"Benchmarking for batch_size = {batch_size} and max total tokens =" + f" {context_length + max_new_tokens}" + ) + + input_str = _generate_random_words(batch_size, context_length) + + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + inputs = tokenizer(input_str, return_tensors="pt", padding=True) + + max_length = inputs["input_ids"].size(1) + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (context_length - max_length, 0), value=tokenizer.pad_token_id + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (context_length - max_length, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + model.generate(inputs["input_ids"].cuda(), max_new_tokens=max_new_tokens) + torch.cuda.synchronize() + end.record() + + print(f"Time: {start.elapsed_time(end)/1000:.2f} s.") diff --git a/transformer_engine/common/fused_attn/kv_cache.cu 
b/transformer_engine/common/fused_attn/kv_cache.cu index af69faaabe..ea468e435b 100644 --- a/transformer_engine/common/fused_attn/kv_cache.cu +++ b/transformer_engine/common/fused_attn/kv_cache.cu @@ -116,14 +116,14 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso bool is_non_paged, cudaStream_t stream) { if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) { if (is_non_paged) { - reindex_kv_cache_kernel<<<16, 256, 0, stream>>>( + reindex_kv_cache_kernel<<<128, 1024, 0, stream>>>( reinterpret_cast(k_cache.data.dptr), reinterpret_cast(v_cache.data.dptr), reinterpret_cast(page_table.data.dptr), reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len); } - copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>( + copy_to_kv_cache_kernel<<<128, 1024, 0, stream>>>( reinterpret_cast(new_k.data.dptr), reinterpret_cast(new_v.data.dptr), reinterpret_cast(k_cache.data.dptr), reinterpret_cast(v_cache.data.dptr), reinterpret_cast(page_table.data.dptr), diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 8267bf63c7..62a724ef79 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -214,6 +214,11 @@ def __init__( dtype=torch.int32, device=torch.cuda.current_device(), ) + self.pre_step_seqlens = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) def reset(self): """Reset InferenceParams state""" @@ -266,6 +271,11 @@ def pre_step( for k, v in self.sequences.items(): self.sequences_pre_step[k] = v - step_dict[k] + pre_step_seqlens = torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) + self.pre_step_seqlens[: len(pre_step_seqlens)].copy_(pre_step_seqlens, non_blocking=True) + seqlens_q = list(step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) @@ -280,9 +290,13 @@ def pre_step( def get_seqlens_pre_step(self): """Get cached sequence lengths before the stepping""" - return torch.Tensor(list(self.sequences_pre_step.values())).to( - dtype=torch.int32, device="cpu" - ) + # seqlens = torch.Tensor(list(self.sequences_pre_step.values())).to( + # dtype=torch.int32, device="cpu" + # ) + # # return seqlens.cuda() + # self.cu_pre_step_seqlens[:len(seqlens)].copy_(seqlens, non_blocking=True) + # return self.cu_pre_step_seqlens + return self.pre_step_seqlens def convert_paged_to_nonpaged(self, layer_number: int): """ @@ -455,14 +469,14 @@ def pre_step( finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_( + self.batch_indices.data[:].copy_( torch.Tensor( ( unfinished_indices + finished_indices + list(range(prev_batch_size, self.max_batch_size)) ) - ).to(dtype=torch.int32, device="cpu") + ) ) # Advance unfinished sequences diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f018465dc1..cac16d9de6 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -768,8 +768,8 @@ 
def forward( # sequence_start = inference_params.seqlens[0] sequence_end = sequence_start + sequence_length - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] + # q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] + # k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] query_layer = apply_rotary_pos_emb( query_layer, @@ -779,6 +779,7 @@ def forward( cu_seqlens=cu_seqlens_q, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start, interleaved=self.rotary_pos_interleaved, ) key_layer = apply_rotary_pos_emb( @@ -789,6 +790,7 @@ def forward( cu_seqlens=cu_seqlens_kv, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start, interleaved=self.rotary_pos_interleaved, ) diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index fe640f67c0..781002f154 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); - auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor if (start_positions) { start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); } if (qkv_format == NVTE_QKV_Format::NVTE_THD) {