diff --git a/docs/examples/te_gemma/check_cuda_graphs.py b/docs/examples/te_gemma/check_cuda_graphs.py
new file mode 100644
index 0000000000..aee35f6911
--- /dev/null
+++ b/docs/examples/te_gemma/check_cuda_graphs.py
@@ -0,0 +1,67 @@
+import torch
+from transformer_engine.pytorch import Linear, LayerNorm
+
+
+# 1. Define model with static buffers
+class TE_Model(torch.nn.Module):
+    def __init__(self, max_batch=8, max_seq_len=4096):
+ super().__init__()
+ self.max_seq_len = max_seq_len
+ self.ln = LayerNorm(1024)
+ self.attn_proj = Linear(1024, 1024)
+
+        # Pre-allocate static buffers (per-sequence KV cache and causal mask)
+        self.register_buffer(
+            "kv_cache", torch.zeros(max_batch, max_seq_len, 1024, device="cuda")
+        )
+ self.register_buffer(
+ "attn_mask", torch.tril(torch.ones(max_seq_len, max_seq_len, device="cuda"))
+ )
+
+    def forward(self, hidden_states, seq_start):
+ # Dynamic slicing of static buffers
+ seq_len = hidden_states.size(1)
+ current_mask = self.attn_mask[seq_start : seq_start + seq_len, :seq_len]
+
+ x = self.ln(hidden_states)
+ x = self.attn_proj(x)
+        # Update KV cache in-place; x is (batch, seq, hidden), so index the batch dimension too
+        self.kv_cache[: x.size(0), seq_start : seq_start + seq_len].copy_(x)
+ return x
+
+
+# 2. Create graphable callables
+model = TE_Model().cuda()
+static_input = torch.randn(8, 256, 1024, device="cuda") # (batch, seq, hidden)
+seq_start = torch.tensor(0, device="cuda")
+
+# Wrap with CUDA Graphs
+graph_model = torch.cuda.make_graphed_callables(
+    model,  # a single module; passing a tuple of modules would return a tuple of graphed callables
+    sample_args=(static_input, seq_start),  # must match the actual input structure
+    # pool=torch.cuda.graphs.graph_pool_handle(),
+    allow_unused_input=False,
+)
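+# Note (assumption, not exercised in this script): for Transformer Engine modules,
+# transformer_engine.pytorch.make_graphed_callables provides a similar wrapper that also
+# manages FP8 state, and is usually preferred over the plain torch.cuda one when FP8 is enabled.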
+
+
+# 3. Warmup and execution
+def run_inference(x, seq_start):
+ # Inputs must match sample_args' device/type/shape
+ x = x.to("cuda", non_blocking=True).requires_grad_(False)
+ seq_start = seq_start.to("cuda", non_blocking=True)
+
+    with torch.amp.autocast("cuda"):
+ return graph_model(x, seq_start)
+
+
+# Extra replay iterations (make_graphed_callables already performs its own warmup during capture)
+for _ in range(3):
+ _ = run_inference(static_input, seq_start)
+torch.cuda.synchronize()
+
+
+# 4. Usage with dynamic inputs (shapes must still match the captured sample_args)
+def process_batch(inputs, start_pos):
+    # inputs: (batch, seq, hidden) on CPU
+    inputs_gpu = inputs.to("cuda", non_blocking=True)
+
+    # Note: the returned tensor is the graph's static output buffer and is overwritten on the next call
+    return run_inference(inputs_gpu, start_pos)
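+
+
+# 5. Alternative: a minimal manual-capture sketch (assumption: inference-only replay of the same
+# `model`; `seq_start` is baked in as 0 at capture time, so only the input *values* may change).
+manual_graph = torch.cuda.CUDAGraph()
+side_stream = torch.cuda.Stream()
+side_stream.wait_stream(torch.cuda.current_stream())
+with torch.cuda.stream(side_stream), torch.no_grad():
+    for _ in range(3):  # warm-up on a side stream, required before capture
+        model(static_input, 0)
+torch.cuda.current_stream().wait_stream(side_stream)
+with torch.cuda.graph(manual_graph), torch.no_grad():
+    static_out = model(static_input, 0)
+# Replay: copy new data into the static input buffer, then replay the captured kernels.
+static_input.copy_(torch.randn_like(static_input))
+manual_graph.replay()
+torch.cuda.synchronize()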
diff --git a/docs/examples/te_gemma/check_gemm.py b/docs/examples/te_gemma/check_gemm.py
new file mode 100755
index 0000000000..1ed6edd23a
--- /dev/null
+++ b/docs/examples/te_gemma/check_gemm.py
@@ -0,0 +1,137 @@
+import functools
+from typing import Optional, Tuple, Union, List
+import torch
+import transformer_engine as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.utils import assert_dim_for_fp8_exec
+from transformer_engine.pytorch.module.base import get_workspace
+import transformer_engine.pytorch.cpp_extensions as cpp_tex
+
+
+@functools.lru_cache(maxsize=None)
+def _empty_tensor() -> torch.Tensor:
+ """Get tensor with no entries and no data"""
+ return torch.Tensor()
+
+
+def gemm(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ dtype: torch.dtype,
+ workspace: torch.Tensor,
+ gelu: bool = False,
+ gelu_input: Optional[torch.Tensor] = None,
+ grad: bool = False,
+ accumulate: bool = False,
+ layout: str = "TN",
+ out: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ use_bias: bool = False,
+    ub_algo: Optional[tex.CommOverlapAlgo] = None,
+    ub: Optional[Union[tex.CommOverlap, tex.CommOverlapP2P]] = None,
+    extra_output_tensor: Optional[torch.Tensor] = None,
+) -> Tuple[Union[torch.Tensor, None], ...]:
+    """Non-FP8 GEMM."""
+
+ assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
+ transa = layout[0] == "T"
+ transb = layout[1] == "T"
+ empty_tensor = _empty_tensor()
+ fp8_index = -1 # dummy index
+
+ if out is None:
+ out = torch.empty(
+ B.shape[1] if transb else B.shape[0],
+ A.shape[0] if transa else A.shape[1],
+ dtype=dtype,
+ device="cuda",
+ )
+ else:
+ if not out.is_contiguous():
+ raise ValueError("Output tensor is not contiguous.")
+
+ if gelu and not grad:
+ gelu_input = torch.empty_like(out, dtype=dtype)
+ elif not gelu:
+ gelu_input = empty_tensor
+
+ if grad and use_bias:
+ grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda")
+ else:
+ grad_bias = empty_tensor
+
+ bias = bias if use_bias else empty_tensor
+
+ assert (
+ A.dtype == dtype and B.dtype == dtype
+ ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}"
+ input_dtype = TE_DType[dtype]
+ output_dtype = TE_DType[out.dtype]
+ if use_bias:
+ bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
+ else:
+ bias_dtype = output_dtype
+
+ args = (
+ A,
+ empty_tensor,
+ fp8_index,
+ input_dtype,
+ transa,
+ B,
+ empty_tensor,
+ fp8_index,
+ input_dtype,
+ transb,
+ out,
+ empty_tensor, # out_scale
+ output_dtype,
+ empty_tensor, # out_amax
+ grad_bias if grad else bias,
+ bias_dtype,
+ gelu_input,
+ grad,
+ workspace,
+ workspace.shape[0],
+ accumulate,
+ False, # use_split_accumulator
+ )
+ fn = torch.ops.tex_ts.te_gemm_ts
+ if ub_algo is not None:
+ assert ub is not None, "ub object is None!"
+ _ = fn(*args)
+
+ import pdb
+
+ pdb.set_trace()
+ return out, grad_bias, gelu_input
+
+
+if __name__ == "__main__":
+ fc2_weight = torch.load("fc2_weight.pth").cuda()
+
+    # base_repo = "/perfhome/mnt/wkstn/work/repos/te_gemma_gen_support/TransformerEngine/docs/examples/te_gemma/"
+    base_repo = ""
+ gelu_out = torch.load(base_repo + "gelu_out.pth").cuda()
+
+ activation_dtype = torch.bfloat16
+ fc2_bias = _empty_tensor()
+ use_fc2_bias = False
+
+ dim_size = list(gelu_out.size())
+ dim_size[1] = fc2_weight.size(0)
+ fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+
+ _ = cpp_tex.gemm(
+ fc2_weight,
+ gelu_out,
+ activation_dtype,
+ get_workspace(),
+ bias=fc2_bias,
+ use_bias=use_fc2_bias,
+ out=fc2_out,
+ ub_algo=None,
+ ub=None,
+ extra_output_tensor=None,
+ )
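+
+    # Optional sanity check (a sketch, assuming the default "TN" layout so that
+    # fc2_out = gelu_out @ fc2_weight^T, i.e. a bias-free linear layer):
+    ref = torch.matmul(gelu_out, fc2_weight.t())
+    print("max abs diff vs torch.matmul:", (fc2_out - ref).abs().max().item())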
diff --git a/docs/examples/te_gemma/check_rope.ipynb b/docs/examples/te_gemma/check_rope.ipynb
new file mode 100755
index 0000000000..26d5c9058f
--- /dev/null
+++ b/docs/examples/te_gemma/check_rope.ipynb
@@ -0,0 +1,716 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "72f61b51-b6fc-4463-9783-d42a25ca3a2f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "before tex import\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import math\n",
+ "print(\"before tex import\")\n",
+ "import transformer_engine as te\n",
+ "import transformer_engine_torch as tex"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "1f81be75-bf64-43b2-852a-7c482a1c3418",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformer_engine.pytorch.attention import apply_rotary_pos_emb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "8853f973-d834-41a9-929d-8687b947134f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compare_rope_outputs(t, freqs_s11d, freqs_sb1d):\n",
+ " output1 = tex.fused_rope_forward(t, freqs_s11d, torch.Tensor(), False)\n",
+ " output2 = tex.fused_rope_forward(t, freqs_sb1d, torch.Tensor(), False)\n",
+ " print(output1, output2, sep=\"\\n\")\n",
+ " assert torch.allclose(output1, output2)\n",
+ " return output1, output2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6b7bada1-6748-46f1-93a4-c2ac1a617063",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.manual_seed(0)\n",
+ "b = 2\n",
+ "s = 3\n",
+ "h = 2\n",
+ "d = 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "54a8f6d6-28f8-4a9a-8ba0-0fdefff138e7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([3, 1, 1, 4]) torch.Size([3, 2, 1, 4])\n"
+ ]
+ }
+ ],
+ "source": [
+ "freqs_s11d = torch.ones(s, 1, 1, d).cuda() * math.pi/4\n",
+ "freqs_sb1d = freqs_s11d.broadcast_to(s, b, 1, d).clone()\n",
+ "t = torch.ones(s, b, h, d).cuda()\n",
+ "\n",
+ "print(freqs_s11d.shape, freqs_sb1d.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "5070307a-3104-401b-b84c-00f3bbf02ccc",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[[0.7854, 0.7854, 0.7854, 0.7854]]],\n",
+ "\n",
+ "\n",
+ " [[[0.7854, 0.7854, 0.7854, 0.7854]]],\n",
+ "\n",
+ "\n",
+ " [[[0.7854, 0.7854, 0.7854, 0.7854]]]], device='cuda:0')"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "freqs_s11d"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "81e52785-e6ad-4180-9567-564af692375c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(4, 4, 4, 1)"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "freqs_s11d.stride()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "0da9bc09-7e1e-4056-85eb-64b6122c7440",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4, 0\n",
+ "4, 4, 4, 1, \n",
+ "nvt_fused_rope_fwd: 4, 0fused_rope_fwd: 4, 0fused_rope_fwd_launcher: 4, 0thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n"
+ ]
+ }
+ ],
+ "source": [
+ "output = tex.fused_rope_forward(t, freqs_s11d, torch.Tensor(), False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "1b78017d-09b3-4b5f-93a8-75f6ba6f131c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output_unfused=apply_rotary_pos_emb(\n",
+ " t,\n",
+ " freqs_s11d,\n",
+ " tensor_format=\"sbhd\",\n",
+ " fused=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "6f5d9350-deb1-48ef-a0a2-e18e01ed336f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]]],\n",
+ " device='cuda:0')"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output_unfused"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "b01e29b8-dfdf-41ac-81a5-d8edf6a8c168",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4, 0\n",
+ "4, 4, 4, 1, \n",
+ "nvt_fused_rope_fwd: 4, 0fused_rope_fwd: 4, 0fused_rope_fwd_launcher: 4, 08, 4\n",
+ "8, 4, 4, 1, \n",
+ "nvt_fused_rope_fwd: 8, 4fused_rope_fwd: 8, 4fused_rope_fwd_launcher: 8, 4thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 4, freq_stride_b: 0, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 0, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 2, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 1, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 0, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 1, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 2, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "thread_id: 3, s_id: 1, b_id: 0, freq_stride_s: 8, freq_stride_b: 4, freq: 0.785398, v_sin: 0.707107, v_cos: 0.707107\n",
+ "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]]],\n",
+ " device='cuda:0')\n",
+ "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]]],\n",
+ " device='cuda:0')\n"
+ ]
+ }
+ ],
+ "source": [
+ "output1, output2 = compare_rope_outputs(t, freqs_s11d, freqs_sb1d)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "b168b178-1f63-4ccc-b084-2ac2c1ec016b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([6, 1, 1, 4]) torch.Size([6, 2, 1, 4])\n"
+ ]
+ }
+ ],
+ "source": [
+ "freqs_s11d = torch.randn(s, 1, 1, d).cuda()\n",
+ "freqs_sb1d = freqs_s11d.broadcast_to(s, b, 1, d).clone()\n",
+ "\n",
+ "print(freqs_s11d.shape, freqs_sb1d.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "33ec2e07-6e54-49f7-92f7-2f217a766456",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.0000e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.0000e+00]]]],\n",
+ " device='cuda:0')\n",
+ "tensor([[[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]],\n",
+ "\n",
+ " [[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.9605e-08, -5.9605e-08, 1.4142e+00, 1.4142e+00],\n",
+ " [ 7.0711e-01, 7.0711e-01, 7.0711e-01, 7.0711e-01]],\n",
+ "\n",
+ " [[ 7.0711e-01, 7.0711e-01, 7.0711e-01, 7.0711e-01],\n",
+ " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]]],\n",
+ " device='cuda:0')\n"
+ ]
+ },
+ {
+ "ename": "AssertionError",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m output1, output2 \u001b[38;5;241m=\u001b[39m \u001b[43mcompare_rope_outputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfreqs_s11d\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfreqs_sb1d\u001b[49m\u001b[43m)\u001b[49m\n",
+ "Cell \u001b[0;32mIn[8], line 5\u001b[0m, in \u001b[0;36mcompare_rope_outputs\u001b[0;34m(t, freqs_s11d, freqs_sb1d)\u001b[0m\n\u001b[1;32m 3\u001b[0m output2 \u001b[38;5;241m=\u001b[39m tex\u001b[38;5;241m.\u001b[39mfused_rope_forward(t, freqs_sb1d, torch\u001b[38;5;241m.\u001b[39mTensor(), \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(output1, output2, sep\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(output1, output2)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output1, output2\n",
+ "\u001b[0;31mAssertionError\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "output1, output2 = compare_rope_outputs(t, freqs_s11d, freqs_sb1d)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3b58b818-7b31-4ecd-80bd-b5ba049b3c2e",
+ "metadata": {},
+   "outputs": [],
+ "source": [
+ "freqs_s11d = torch.randn(s, 1, 1, d).cuda()\n",
+ "print(freqs_s11d)\n",
+ "freqs_sb1d = freqs_s11d.broadcast_to(s, b, 1, d).clone()\n",
+ "print(freqs_sb1d)\n",
+ "assert torch.all(torch.eq(freqs_sb1d[:, 0, ...], freqs_sb1d[:, 1, ...]))\n",
+ "\n",
+    "compare_rope_outputs(t, freqs_s11d, freqs_sb1d)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c04940b8-3056-466b-90f6-07a02ac47ace",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/examples/te_gemma/media/calibration.svg b/docs/examples/te_gemma/media/calibration.svg
new file mode 100755
index 0000000000..b1e1b5ae4b
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/calibration_1_half.svg b/docs/examples/te_gemma/media/calibration_1_half.svg
new file mode 100755
index 0000000000..af2641387f
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration_1_half.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/calibration_2_half.svg b/docs/examples/te_gemma/media/calibration_2_half.svg
new file mode 100755
index 0000000000..2d56f7d434
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration_2_half.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/fp8_model_init.svg b/docs/examples/te_gemma/media/fp8_model_init.svg
new file mode 100755
index 0000000000..c7fce2120d
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/fp8_model_init_1_half.svg b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg
new file mode 100755
index 0000000000..3b217a3eb2
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/fp8_model_init_2_half.svg b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg
new file mode 100755
index 0000000000..46587664fe
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif
new file mode 100755
index 0000000000..25150cb9b6
Binary files /dev/null and b/docs/examples/te_gemma/media/generation_animation.gif differ
diff --git a/docs/examples/te_gemma/media/graphs.svg b/docs/examples/te_gemma/media/graphs.svg
new file mode 100755
index 0000000000..f734637e6d
--- /dev/null
+++ b/docs/examples/te_gemma/media/graphs.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/graphs_1.png b/docs/examples/te_gemma/media/graphs_1.png
new file mode 100755
index 0000000000..f42b50fe0d
Binary files /dev/null and b/docs/examples/te_gemma/media/graphs_1.png differ
diff --git a/docs/examples/te_gemma/media/graphs_2.png b/docs/examples/te_gemma/media/graphs_2.png
new file mode 100755
index 0000000000..35c34ede55
Binary files /dev/null and b/docs/examples/te_gemma/media/graphs_2.png differ
diff --git a/docs/examples/te_gemma/media/plot.svg b/docs/examples/te_gemma/media/plot.svg
new file mode 100755
index 0000000000..481f156df6
--- /dev/null
+++ b/docs/examples/te_gemma/media/plot.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/media/thd_bshd.svg b/docs/examples/te_gemma/media/thd_bshd.svg
new file mode 100755
index 0000000000..47eed69565
--- /dev/null
+++ b/docs/examples/te_gemma/media/thd_bshd.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/examples/te_gemma/requirements.txt b/docs/examples/te_gemma/requirements.txt
new file mode 100755
index 0000000000..c90fb6dad0
--- /dev/null
+++ b/docs/examples/te_gemma/requirements.txt
@@ -0,0 +1,4 @@
+transformers==4.41.1
+accelerate==0.30.1
+datasets==2.19.1
+sentencepiece==0.2.0
\ No newline at end of file
diff --git a/docs/examples/te_gemma/run_gemma_2b.py b/docs/examples/te_gemma/run_gemma_2b.py
new file mode 100644
index 0000000000..db2fb087c9
--- /dev/null
+++ b/docs/examples/te_gemma/run_gemma_2b.py
@@ -0,0 +1,15 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from huggingface_hub import login
+
+access_token = ""
+login(access_token)
+
+model_name = "google/gemma-3-4b-it"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
+print(model.config)
+input_text = "Write me a poem about Machine Learning."
+input_ids = tokenizer(input_text, return_tensors="pt")
+
+outputs = model.generate(**input_ids)
+print(tokenizer.decode(outputs[0]))
diff --git a/docs/examples/te_gemma/run_generation.py b/docs/examples/te_gemma/run_generation.py
new file mode 100755
index 0000000000..910fa325d0
--- /dev/null
+++ b/docs/examples/te_gemma/run_generation.py
@@ -0,0 +1,55 @@
+from utils import *
+import transformer_engine.pytorch as te
+
+hyperparams.model_name = (  # <== Add model weight location here, e.g. "/path/to/downloaded/gemma/weights"
+ "/perfhome/repos/ckpt/models/gemma-7b-hf/"
+)
+hyperparams.qkv_format = "thd"
+
+run_generation = True
+run_calibration = False
+
+if run_calibration:
+ hyperparams.fuse_qkv_params = True # Needed by the final, FP8-enabled version of the model used later in the tutorial.
+
+ model = init_te_gemma_model(hyperparams)
+
+ # Calibration
+ with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(
+ device_type="cuda", dtype=torch.bfloat16
+ ):
+ model.train()
+ run_forward_pass(model, hyperparams, num_iters=512)
+
+ # Compute scale_fwd with enabled fp8 autocast
+ with te.fp8_autocast(enabled=True), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ run_forward_pass(model, hyperparams, 1)
+
+ # Some parameters point to the same tensors; saving them twice is avoided here.
+ dict_to_save = {
+ k: v
+ for k, v in model.state_dict().items()
+ if ("_context_phase" not in k and "_generation_phase" not in k)
+ }
+ torch.save(dict_to_save, "calibrated_weights.pth") # <== Add path to save calibrated weights.
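+
+ # A minimal sketch of how the calibrated checkpoint could be used afterwards
+ # (assumed hyperparams field names, matching those read in te_gemma_loading_weights.py):
+ # hyperparams.fp8_model_init = True
+ # hyperparams.fp8_model_weights_filename = "calibrated_weights.pth"
+ # model = init_te_gemma_model(hyperparams)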
+
+
+if run_generation:
+
+ # hyperparams.generation_cuda_graphs = False # 4.15s
+ hyperparams.generation_cuda_graphs = True # 4.38s
+
+ if hyperparams.generation_cuda_graphs:
+ # It is necessary to preallocate a static buffer.
+ # CUDA graphs require static input tensors for every kernel.
+ # This approach may result in a slight increase in memory consumption;
+ # however, the substantial speedup achieved makes it worthwhile.
+ hyperparams.cuda_graphs_static_batch_size = 64
+ hyperparams.cuda_graphs_static_max_seq_len = 128
+ hyperparams.cuda_graphs_static_max_context_len = 128
+
+ hyperparams.is_paged = False
+ model = init_te_gemma_model(hyperparams)
+
+ print_sample_of_generated_texts(model)
+ benchmark_generation(model)
diff --git a/docs/examples/te_gemma/run_generation_llama.py b/docs/examples/te_gemma/run_generation_llama.py
new file mode 100755
index 0000000000..1c3e6626ca
--- /dev/null
+++ b/docs/examples/te_gemma/run_generation_llama.py
@@ -0,0 +1,12 @@
+from utils import *
+
+hyperparams.model_name = (  # <== Add model weight location here, e.g. "/path/to/downloaded/llama/weights"
+ "/perfhome/repos/ckpt/models/llama2-7b-hf/"
+)
+hyperparams.qkv_format = "thd"
+
+# model = init_te_llama_model(hyperparams)
+model = init_baseline_model(hyperparams)
+
+print_sample_of_generated_texts(model)
+# benchmark_generation(model)
diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py
new file mode 100755
index 0000000000..706ea16bc4
--- /dev/null
+++ b/docs/examples/te_gemma/te_gemma.py
@@ -0,0 +1,594 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+from contextlib import contextmanager
+
+from typing import Optional
+from functools import partial
+from collections import OrderedDict
+
+import torch
+import transformer_engine as te
+from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
+from transformer_engine.common.recipe import Format, DelayedScaling
+from torch.cuda.amp import autocast
+
+import transformers
+from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel
+
+import torch.nn.functional as F
+
+
+class TEGemmaDecoderLayer(te.pytorch.TransformerLayer):
+ """
+ Wrapper class over TE's `TransformerLayer`. This makes the wrapper's interface very
+ similar to HF's `GemmaDecoderLayer`, so it can easily replace it in the code.
+
+ Args:
+ config: GemmaConfig
+ args: positional args (for compatibility with `GemmaDecoderLayer`)
+ kwargs: keyword args (for compatibility with `GemmaDecoderLayer`)
+ """
+
+ def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs):
+
+ self.gemma_config = config
+
+ super().__init__(
+ hidden_size=config.hidden_size,
+ ffn_hidden_size=config.intermediate_size,
+ num_attention_heads=config.num_attention_heads,
+ bias=False,
+ layernorm_epsilon=config.rms_norm_eps,
+ hidden_dropout=0,
+ attention_dropout=0,
+ fuse_qkv_params=config.fuse_qkv_params,
+ normalization="RMSNorm",
+ activation="geglu",
+ # attn_input_format=config.qkv_format,
+ attn_input_format="bshd",
+ num_gqa_groups=config.num_key_value_heads,
+ kv_channels=self.gemma_config.head_dim,
+ layer_number=(
+ layer_idx + 1
+ ), # Layer numbers in TE start from 1, not 0 as in HF.
+ zero_centered_gamma=True,
+ )
+
+ def forward(self, *args, **kwargs): # We need to additionally pass positional encoding.
+
+ # These args cannot be passed to TransformerLayer.
+ keys_to_remove = [
+ "position_ids",
+ "past_key_value",
+ "output_attentions",
+ "use_cache",
+ "cache_position",
+ ]
+ for key in keys_to_remove:
+ kwargs.pop(key, None)
+
+ rope_emb = kwargs.pop("rope_emb", None)
+ # We need to return a tuple to be compatible with HF.
+ return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),)
+
+
+class StaticGemmaModel(torch.nn.Module):
+ """
+ StaticGemmaModel is based on the HF GemmaModel class.
+ It is adjusted to work properly with CUDA Graphs.
+ """
+
+ def __init__(
+ self,
+ model: GemmaModel,
+ dtype: torch.dtype,
+ mask: torch.Tensor,
+ lm_head: torch.nn.Module,
+ ):
+ super().__init__()
+ self.model = model
+ self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype)
+ self.mask = mask
+ self.lm_head = lm_head
+
+ def set_inference_params(self, inference_params):
+ self.inference_params = inference_params
+
+ # @sudhakars: is `arbitrary` fine being the default here?
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor = None,
+ attn_mask_type: str = "arbitrary",
+ rope_emb: torch.Tensor = None,
+ ):
+ # print(f"StaticGemmaModel forward start")
+ with torch.no_grad():
+ # static operation - for CUDA graphs
+ hidden_states.data[:] = hidden_states.data[:] * self.normalizer
+
+ for i, decoder_layer in enumerate(self.model.layers):
+ # print(f"layer {i}")
+ hidden_states.data[:] = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type,
+ inference_params=self.inference_params,
+ rope_emb=rope_emb,
+ )[
+ 0
+ ] # static copy - for CUDA graphs
+
+ hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+ return logits, hidden_states
+
+
+class GemmaGenerator(torch.nn.Module):
+ """
+ GemmaGenerator takes the embeddings of the most recently generated tokens,
+ runs a forward pass and returns the next tokens.
+ """
+
+ def __init__(
+ self, model: GemmaModel, lm_head: torch.nn.Module, dtype: torch.dtype, qkv_format: str
+ ):
+ super().__init__()
+ self.model = model
+ self.gemma_layers = StaticGemmaModel(model, dtype, "arbitrary", lm_head)
+ self.qkv_format = qkv_format
+
+ def set_inference_params(self, inference_params):
+ self.inference_params = inference_params
+ self.gemma_layers.set_inference_params(inference_params)
+
+ # @sudhakars: is `arbitrary` a good default value here?
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ mask: torch.Tensor = None,
+ attn_mask_type: str = "arbitrary",
+ rope_emb: torch.Tensor = None,
+ ):
+ logits, _ = self.gemma_layers(
+ hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb
+ )
+
+ assert logits.shape[0] == hidden_states.shape[0] # b
+ assert logits.shape[1] == hidden_states.shape[1] # seq_len
+ # logits.shape[2] = number of tokens
+ logits = logits[:, -1, :]
+ next_tokens = torch.argmax(logits, dim=1)
+
+ # static copy for CUDA graphs
+ hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1))
+
+ return next_tokens
+
+
+@contextmanager
+def replace_decoder(te_decoder_cls):
+ """
+ Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`.
+ """
+ original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer
+ transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls
+ try:
+ yield
+ finally:
+ transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls
+
+
+class TEGemmaForCausalLM(GemmaForCausalLM):
+ """
+ Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer`
+ class is monkey-patched with `TEGemmaDecoderLayer` class before
+ initializing the causal LM with `GemmaForCausalLM`.
+
+ Args:
+ config: GemmaConfig
+ """
+
+ def __init__(self, config: GemmaConfig):
+ with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer):
+ super().__init__(config)
+ self.config = config
+ self.to(torch.bfloat16).cuda()
+ self.hidden_size = config.hidden_size
+ self._model_generation_phase = GemmaGenerator(
+ lm_head=self.lm_head,
+ model=self.model,
+ dtype=torch.bfloat16,
+ qkv_format=config.qkv_format,
+ )
+ self._model_context_phase = StaticGemmaModel(
+ self.model, torch.bfloat16, "arbitrary", self.lm_head
+ )
+
+ if self.config.fp8:
+ self.fp8_recipe = DelayedScaling(
+ fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max"
+ )
+
+ self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)(
+ max_seq_len=self.config.max_position_embeddings
+ ).cuda()
+
+ @staticmethod
+ def _padding_to_end(inputs, lengths, max_seq_len=None):
+ """
+ Gets a tensor with sequences padded at the beginning and
+ returns the tensor with the sequences padded at their end.
+
+ Parameters
+ ----------
+ inputs : Tensor, tensor with shape [b, s] containing token ids.
+ It is padded at the beginning.
+ lengths: Tensor, tensor with shape [b] with the lengths of the sequences.
+
+ """
+ max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len
+ batch_size, max_seq_len = inputs.shape
+ new_input_ids = inputs.clone()
+ for i in range(batch_size):
+ new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len]
+ new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])]
+
+ # Disable the input preparation that involves extra padding
+ # inputs.copy_(new_input_ids)
+
+ # Trim the inputs to no extra padding i.e. fix the max seq len to
+ # the longest sequence in the batch
+ actual_max_seq_len = max_seq_len
+ inputs.data = new_input_ids[:, :actual_max_seq_len]
+ # print(f"actual_max_seq_len: {actual_max_seq_len}")
+
+ # For Paged Attention, make the valid sequences, multiple of 64
+ # inputs.data = new_input_ids[:, :4].repeat(1, 16)
+
+ def _next_64_multiply(self, x):
+ return ((x + 63) // 64) * 64
+
+ # This function is overridden in TEGemmaForCausalLMCudaGraphs.
+ def _create_hidden_states_buffer(self, input_ids: torch.Tensor):
+ tensor = torch.empty(
+ (input_ids.shape[0], input_ids.shape[1], self.hidden_size),
+ device="cuda",
+ dtype=torch.float32,
+ )
+ # import pdb; pdb.set_trace()
+ return tensor
+
+ # This function is overridden in TEGemmaForCausalLMCudaGraphs.
+ def _create_inference_params(self, *args, **kwargs):
+ infer_params = InferenceParams(*args, **kwargs)
+ return infer_params
+
+ # This function is overridden in TEGemmaForCausalLMCudaGraphs.
+ def _get_max_input_seq_len(self, input_ids):
+ return (
+ input_ids.shape[1]
+ if not hasattr(self.config, "cuda_graphs_static_max_context_len")
+ else self.config.cuda_graphs_static_max_context_len
+ )
+
+ # The buffer for generation is a part (the beginning) of the hidden states buffer.
+ # This function returns a pointer to it and also copies data into it if provided.
+ def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None):
+ # hidden_states_buffer has shape [b, s, hd]
+ # generation_buffer will have shape [b, 1, hd]
+ # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)"
+ # would return a non-contiguous buffer, which we want to avoid.
+ output = hidden_states_buffer.view(-1)[
+ : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2]
+ ]
+ if data_to_copy is not None:
+ output.copy_(data_to_copy.reshape(-1))
+ generation_buffer = output.view(
+ (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])
+ )
+ return generation_buffer
+
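+ # A tiny illustration of the aliasing above (a sketch, not executed in this module):
+ # hs = torch.zeros(2, 3, 4) # [b, s, hd]
+ # gen = hs.view(-1)[: 2 * 4].view(2, 1, 4) # contiguous [b, 1, hd] view
+ # gen.fill_(1.0) # writes land in hs[0, 0, :] and hs[0, 1, :]
+ # The generation buffer simply reuses the first b * hd elements of the flat
+ # hidden-states storage, so it is contiguous and stays at a fixed address,
+ # which is what CUDA Graphs capture requires.
+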
+ def _generate_context_phase(self, input_ids: torch.Tensor, inference_params: InferenceParams):
+ # import pdb; pdb.set_trace()
+ hidden_states = self._create_hidden_states_buffer(input_ids)
+ hidden_states.copy_(self.model.embed_tokens(input_ids))
+
+ # We need to update offsets before every forward pass to make cache work properly.
+ lengths = input_ids.ne(0).sum(dim=1)
+
+ # import pdb; pdb.set_trace()
+ if self.config.qkv_format == "thd":
+ inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
+ else:
+ inference_params.setup_before_new_input(length=input_ids.shape[1])
+
+ logits, hs_buffer = self._model_context_phase(
+ hidden_states,
+ attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None),
+ attn_mask_type="padding_causal" if self.config.qkv_format == "thd" else "arbitrary",
+ rope_emb=self.te_rope_emb,
+ )
+
+ if self.config.qkv_format == "thd":
+ logits = logits[torch.arange(logits.size(0)), lengths - 1, :]
+ else:
+ logits = logits[:, -1, :]
+
+ next_tokens = torch.argmax(logits, dim=1)
+
+ # hidden_states has shape [b, s, hd].
+ # We return the hidden state for the last token - the output has shape [b, 1, hd].
+ hidden_states = self._get_generation_buffer(
+ hidden_states, self.model.embed_tokens(next_tokens)
+ )
+ return hidden_states, next_tokens
+
+ def _make_mask_one_token_longer(self, mask):
+ return torch.cat(
+ [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ pad_token_id: int = 0,
+ max_new_tokens: int = 0,
+ *args,
+ **kwargs
+ ):
+ self.eval()
+
+ # We need both autocasts: FP8 for operations that can run in lower precision
+ # and BF16 for those that cannot.
+ with autocast(dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast(
+ enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None
+ ):
+
+ lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [b]
+
+ # print(f"max_input_sequence_len: {max_input_sequence_len}")
+ # exit()
+
+ if self.config.qkv_format == "thd":
+ # For thd layout padding is at the end, otherwise at the beginning.
+ TEGemmaForCausalLM._padding_to_end(
+ input_ids,
+ lengths,
+ max_seq_len=(
+ self.config.cuda_graphs_static_max_context_len
+ if self.config.generation_cuda_graphs
+ else None
+ ),
+ )
+
+ batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(
+ input_ids
+ )
+
+ # InferenceParams is a cache in which the keys and values of previous tokens are stored.
+ # Moreover, it stores the lengths of both the already generated and the input sequences.
+ inference_params = self._create_inference_params(
+ max_batch_size=batch_size,
+ # num_layers=self.config.num_hidden_layers,
+ max_sequence_length=128,
+ num_heads_kv=self.config.num_key_value_heads,
+ # num_heads_q=self.config.num_attention_heads,
+ head_dim_v=self.config.head_dim,
+ head_dim_k=self.config.head_dim,
+ dtype=torch.bfloat16,
+ is_paged=self.config.is_paged,
+ page_size=64,
+ total_num_pages=64 * 128 // 64, # 64 * 128 (max_sequence_length) // 64 (page_size)
+ )
+
+ self._model_context_phase.set_inference_params(inference_params)
+ self._model_generation_phase.set_inference_params(inference_params)
+
+ # print(f"context phase start")
+ # import pdb; pdb.set_trace()
+ hidden_states, next_tokens = self._generate_context_phase(input_ids, inference_params)
+
+ # print(f"context phase done")
+ # Generation phase.
+ if self.config.qkv_format == "thd":
+ lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
+ inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ )
+ else:
+ inference_params.setup_before_new_input(length=1)
+
+ output_tokens = [next_tokens]
+
+ mask = None
+ if self.config.qkv_format != "thd":
+ mask = (input_ids == 0).unsqueeze(1).unsqueeze(1)
+
+ for _ in range(max_new_tokens):
+ if self.config.qkv_format != "thd":
+ # It will not work with cuda graphs, but it is not used for thd qkv_format.
+ # Attention mask in bshd needs attn_mask increased by 1 to
+ # include the next token to be generated
+ mask = self._make_mask_one_token_longer(mask)
+
+ next_tokens = self._model_generation_phase(
+ hidden_states,
+ mask=mask,
+ attn_mask_type="padding" if self.config.qkv_format == "thd" else "arbitrary",
+ rope_emb=self.te_rope_emb,
+ )
+
+ # self.inference_params contains for example kv_cache.
+ # This needs to be called before every pass,
+ # to update the information of sequence lengths.
+ # Here we increase sequence offsets by one,
+ # because we generated one token for every sequence.
+ if self.config.qkv_format == "thd":
+ lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
+ inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ )
+ else:
+ inference_params.setup_before_new_input(length=1)
+ # next_tokens is static output tensor, so we need to clone it
+ # - it gets changed every iteration.
+ output_tokens.append(next_tokens.clone())
+
+ result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)
+ return result
+
+ def forward(self, *args, **kwargs):
+ self._model_context_phase.set_inference_params(None)
+ hidden_states = self.model.embed_tokens(kwargs["input_ids"])
+ logits = self._model_context_phase(
+ hidden_states,
+ attention_mask=(
+ (kwargs["input_ids"] == 0) if self.config.qkv_format != "thd" else None
+ ),
+ attn_mask_type="arbitrary",
+ )
+ return logits
+
+
+class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM):
+ """
+ TEGemmaForCausalLMCudaGraphs is the version of the TEGemmaForCausalLM class
+ that uses CUDA Graphs to speed it up. This requires one trade-off:
+ batch_size, max_seq_len and max_context_seq_len need to be static.
+ Generation must be run with the same values of these variables
+ that the graphs were recorded with.
+ """
+
+ def __init__(self, config: GemmaConfig):
+ super().__init__(config)
+ assert (
+ config.qkv_format == "thd"
+ ), "Generation with CUDA Graphs are implemented only for thd format."
+
+ # Preparation of the static buffers.
+ self.config = config
+ self.hidden_states_buffer = torch.empty(
+ (
+ self.config.cuda_graphs_static_batch_size,
+ self.config.cuda_graphs_static_max_context_len,
+ self.config.hidden_size,
+ )
+ ).cuda()
+
+ # This is in fact part of the buffer for hidden_states.
+ self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer)
+ self.inference_params = InferenceParams(
+ max_batch_size=self.config.cuda_graphs_static_batch_size,
+ # num_layers=self.config.num_hidden_layers,
+ max_sequence_length=self.config.cuda_graphs_static_max_seq_len,
+ num_heads_kv=self.config.num_key_value_heads,
+ # num_heads_q=self.config.num_attention_heads,
+ head_dim_v=self.config.head_dim,
+ head_dim_k=self.config.head_dim,
+ dtype=torch.bfloat16,
+ is_paged=self.config.is_paged,
+ page_size=64,
+ total_num_pages=64
+ * self.config.cuda_graphs_static_max_seq_len
+ // 64, # 64 * cuda_graphs_static_max_seq_len (max_sequence_length) // 64 (page_size)
+ )
+
+ self._model_generation_phase.set_inference_params(self.inference_params)
+ self._model_context_phase.set_inference_params(self.inference_params)
+
+ def record(self):
+ # We want to record model in training=False, because it will be used in generation.
+ self.eval()
+
+ # Here "the trick" happens. We override methods from TEGemmaForCausalLM
+ # with their recorded versions. After each invocation,
+ # the captured graph is replayed with minimal CPU usage,
+ # which leads to a large speedup.
+ input_shape = (
+ self.config.cuda_graphs_static_batch_size,
+ self.config.cuda_graphs_static_max_context_len,
+ )
+
+ # [1] Should be same as lengths_tensor from TEGemmaForCausalLM
+ lengths = torch.tensor(input_shape[0] * [input_shape[1]], device="cuda", dtype=torch.int32)
+ max_input_length = input_shape[1]
+
+ self.inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
+ )
+
+ # print(f"context phase recording start")
+
+ self._model_context_phase = self.record_graph(
+ self._model_context_phase,
+ self.hidden_states_buffer,
+ attn_mask_type="padding_causal",
+ rope_emb=self.te_rope_emb,
+ ) # CUDA Graphs recording
+
+ # print(f"context phase recording done")
+ input_shape = (self.config.cuda_graphs_static_batch_size, 1)
+
+ lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32)
+
+ self.inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
+ )
+
+ self._model_generation_phase = self.record_graph(
+ self._model_generation_phase,
+ self.generation_buffer,
+ attn_mask_type="padding",
+ rope_emb=self.te_rope_emb,
+ ) # CUDA Graphs recording
+
+ """
+ The functions _create_hidden_states_buffer and _create_inference_params
+ from the base class are overridden to make hidden_states and inference_params static,
+ i.e. they do not change their position in memory between invocations.
+ """
+
+ def _create_hidden_states_buffer(self, *args, **kwargs):
+ return self.hidden_states_buffer
+
+ def _create_inference_params(self, *args, **kwargs):
+ self.inference_params.reset()
+ return self.inference_params
+
+ def _get_max_input_seq_len(self, _):
+ return self.config.cuda_graphs_static_max_context_len
+
+ @torch.no_grad()
+ def record_graph(self, function, input_tensor, **sample_kwargs):
+ # The function is invoked on the argument (input_tensor,) and all kernels are recorded.
+ # record_graph() returns the captured function, which can later be run with lower CPU overhead.
+ fp8_format = Format.HYBRID
+ fp8_recipe = DelayedScaling(
+ fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max"
+ )
+
+ # We need both autocasts: FP8 for operations that can run in lower precision
+ # and BF16 for those that cannot.
+ with autocast(dtype=torch.bfloat16, cache_enabled=False):
+ graphed_function = te.pytorch.make_graphed_callables(
+ function,
+ (input_tensor,),
+ fp8_enabled=self.config.fp8,
+ fp8_recipe=fp8_recipe,
+ allow_unused_input=True,
+ num_warmup_iters=5,
+ sample_kwargs=sample_kwargs,
+ )
+ return graphed_function
diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py
new file mode 100755
index 0000000000..41f62ad7f3
--- /dev/null
+++ b/docs/examples/te_gemma/te_gemma_loading_weights.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import re
+import gc
+import torch
+
+from typing import List
+
+from transformer_engine.pytorch.fp8 import fp8_model_init
+
+from transformers.modeling_utils import load_state_dict, _load_state_dict_into_model
+from transformers.utils.hub import get_checkpoint_shard_files
+
+"""
+ This file contains the logic for mapping HuggingFace GemmaModel parameters
+ to TransformerEngine TransformerLayer parameters. Once Transformer models have been
+ initialized both with HF and with TE, we can copy parameters from the former to the latter.
+"""
+
+
+def _load_weights_for_fp8_model(vanilla_model, hyperparams):
+ # The weights are loaded from a file containing the state_dict of a model
+ # that also includes FP8 parameters.
+ # The weights are in BF16 precision, but they carry FP8 metadata
+ # computed by the calibration procedure.
+ vanilla_model.load_state_dict(
+ torch.load(hyperparams.fp8_model_weights_filename),
+ strict=False,
+ # strict=False, because some parameters are referenced more than once:
+ # vanilla_model._model_context_phase.model
+ # and vanilla_model._model_generation_phase.model
+ # point to the same weights.
+ )
+
+
+def _load_weights_for_standard_model(vanilla_model, config):
+ # The weights are loaded from the file with original weights.
+ archive_file = os.path.join(config.model_name, "model.safetensors.index.json")
+ resolved_archive_file, _ = get_checkpoint_shard_files(config.model_name, archive_file)
+ total_dict = {}
+ for shard_file in resolved_archive_file:
+ state_dict = load_state_dict(shard_file)
+ total_dict.update(state_dict)
+
+ replace_params(
+ total_dict,
+ vanilla_model.state_dict(),
+ config,
+ qkv_fused_and_interleaved=config.fuse_qkv_params,
+ )
+ # Copy parameters like embedding:
+ _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="")
+
+ # Force mem release. Taken from huggingface code.
+ del total_dict
+ gc.collect()
+
+
+def load_te_model(cls, config):
+ """
+ Custom method adapted from `from_pretrained` method in HuggingFace
+ Transformers repo:
+ https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
+ """
+ config.use_cache = False # To make TransformerLayer compatible with GemmaModel
+ with fp8_model_init(config.fp8_model_init):
+ # Here we only need to create the model.
+ vanilla_model = cls(config).to(torch.bfloat16).cuda()
+
+ # Now copy the weights into it.
+ if config.fp8_model_weights_filename is not None:
+ _load_weights_for_fp8_model(vanilla_model, config)
+ else:
+ _load_weights_for_standard_model(vanilla_model, config)
+
+ return vanilla_model
+
+
+def _get_all_layer_prefixes_to_update(hf_state_dict):
+ """
+ Many parameters in hf_state_dict have names that start with "model.layers.[number].".
+ This function extracts all prefixes of the form "model.layers.[number]."
+ that appear at the start of keys in hf_state_dict.
+ """
+ all_layer_prefixes = set()
+ for param_key in hf_state_dict.keys():
+ layer_prefix_pat = "model.layers.\d+."
+ m = re.match(layer_prefix_pat, param_key)
+ if m is not None:
+ all_layer_prefixes.add(m.group())
+ return all_layer_prefixes
+
+
+def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
+ """
+ Replaces params from TE TransformerLayer state_dict with corresponding parameters
+ from HuggingFace GemmaModel state_dict.
+ """
+ all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict)
+
+ for layer_prefix in all_layer_prefixes:
+
+ def copy_from_ht_to_te(te_name, hf_name, start=None, end=None):
+ te_state_dict[layer_prefix + te_name].data[start:end].copy_(
+ hf_state_dict[layer_prefix + hf_name]
+ )
+
+ copy_from_ht_to_te(
+ "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight"
+ )
+ copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight")
+ copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight")
+ copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight")
+ copy_from_ht_to_te(
+ "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size
+ )
+ copy_from_ht_to_te(
+ "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size
+ )
+
+ if qkv_fused_and_interleaved:
+ """
+ When qkv_fused_and_interleaved=True, the query, key and value weights are stored
+ in a single tensor in TE TransformerLayer. Moreover, they are interleaved head by head.
+ Let q_i, k_i and v_i be the query, key and value weights for the i-th head respectively.
+ Then TE stores the weight tensor in the form:
+ [q1 k1 v1 q2 k2 v2 ...]
+ This layout is used to maximize performance.
+ """
+ te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"]
+
+ def copy_interleave(hf_name, idx):
+ src = hf_state_dict[layer_prefix + hf_name]
+ for head_nr in range(config.num_attention_heads):
+ dst_offset = head_nr * config.head_dim * 3
+ dst_slice = slice(
+ dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim
+ )
+ src_slice = slice(
+ head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim
+ )
+ te_qkv_layer[dst_slice, :] = src[src_slice, :]
+
+ copy_interleave("self_attn.q_proj.weight", 0)
+ copy_interleave("self_attn.k_proj.weight", 1)
+ copy_interleave("self_attn.v_proj.weight", 2)
+ else:
+ copy_from_ht_to_te(
+ "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"
+ )
+ copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
+ copy_from_ht_to_te(
+ "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"
+ )
+
+ return all_layer_prefixes
diff --git a/docs/examples/te_gemma/te_gemma_save.py b/docs/examples/te_gemma/te_gemma_save.py
new file mode 100755
index 0000000000..c83378840c
--- /dev/null
+++ b/docs/examples/te_gemma/te_gemma_save.py
@@ -0,0 +1,872 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+from contextlib import contextmanager
+
+from typing import Optional
+from functools import partial
+from collections import OrderedDict
+
+import torch
+import transformer_engine as te
+from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
+from transformer_engine.common.recipe import Format, DelayedScaling
+from torch.cuda.amp import autocast
+
+import transformers
+from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel
+
+import torch.nn.functional as F
+
+
+class CacheParams:
+ def __init__(
+ self,
+ max_seqlen_q,
+ max_seqlen_kv,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
+ ):
+ self.max_seqlen_q = max_seqlen_q
+ self.max_seqlen_kv = max_seqlen_kv
+ self.cu_seqlens_q = cu_seqlens_q
+ self.cu_seqlens_kv = cu_seqlens_kv
+ self.cu_seqlens_q_padded = cu_seqlens_q_padded
+ self.cu_seqlens_kv_padded = cu_seqlens_kv_padded
+
+
+def setup_cache_params_from_infer_params(inference_params, lengths_tensor, max_input_length):
+ """
+ Converts the sequence lengths to variables like `cu_seqlens_q/kv`, etc. which
+ will be used later.
+
+ (Currently a hack; this should be refactored into a better method.)
+ """
+
+ assert (
+ lengths_tensor is not None and max_input_length is not None
+ ), 'lengths_tensor and max_input_length should not be none for qkv_format = "thd"'
+
+ inference_params.max_incoming_seq_len = max_input_length
+
+ lengths_tensor = lengths_tensor.to(inference_params.cu_seqlens_q.device)
+
+ # inference_params.step_dict = OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ )
+
+ # print(inference_params.step_dict)
+
+ # def get_cache_params_in_infer_params():
+ # return CacheParams(max_seqlen_q, max_seqlen_kv, inference_params.cu_seqlens_q, inference_params.cu_seqlens_kv, inference_params.cu_seqlens_q_padded, inference_params.cu_seqlens_kv_padded)
+
+ # For the time being, create an ad-hoc field in `inference_params` to get the variables.
+ # @sudhakars: to create a better way later.
+ # inference_params.get_cache_params_from_infer_params = get_cache_params_in_infer_params
+
+
+# This class has been modified from
+# https://github.com/huggingface/transformers/blob/98adf24883b007c2a7fb17bab1c01b1614673433/src/transformers/models/gemma/modeling_gemma.py
+class GemmaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
+ )
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ )
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = (
+ device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ )
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb.unsqueeze(2) # should return in [b, s, 1, d] format
+
+
+class StaticBufferAllocator(torch.nn.Module):
+ """
+ This class is used when we use te.make_graphed_callables().
+ CUDA Graphs require all tensors to be static. Nevertheless,
+ the torch API make_graphed_callables() takes care of the outputs of torch modules
+ and makes them static. Thus, by wrapping memory allocation in a
+ torch.nn.Module, we can greatly simplify our code.
+ """
+
+ # pylint: disable=no-self-use
+ def forward(self, size, dtype, device):
+ """
+ Return buffer of given size, dtype and device.
+ """
+ return torch.zeros(size, dtype=dtype, device=device)
+
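+# A minimal sketch of the intended use (assumed shapes, not a definitive API):
+# allocator = StaticBufferAllocator()
+# buf = allocator((batch_size,), dtype=torch.int32, device="cuda")
+# Because the allocation happens inside an nn.Module's forward, graph-capture
+# utilities that make module outputs static can treat `buf` as a static tensor.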
+
+class TEGemmaDecoderLayer(te.pytorch.TransformerLayer):
+ """
+ Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
+ similar to HF's `GemmaDecoderLayer` and easier to replace it in the code.
+
+ Args:
+ config: GemmaConfig
+ args: positional args (for compatibility with `GemmaDecoderLayer`)
+ kwargs: keyword args (for compatibility with `GemmaDecoderLayer`)
+ """
+
+ def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs):
+
+ self.gemma_config = config
+
+ super().__init__(
+ hidden_size=config.hidden_size,
+ ffn_hidden_size=config.intermediate_size,
+ num_attention_heads=config.num_attention_heads,
+ bias=False,
+ layernorm_epsilon=config.rms_norm_eps,
+ hidden_dropout=0,
+ attention_dropout=0,
+ fuse_qkv_params=config.fuse_qkv_params,
+ normalization="RMSNorm",
+ activation="geglu",
+ # attn_input_format=config.qkv_format,
+ attn_input_format="bshd",
+ num_gqa_groups=config.num_key_value_heads,
+ kv_channels=self.gemma_config.head_dim,
+ layer_number=(
+ layer_idx + 1
+ ), # Layer numbers in TE starts from 1, not 0 like in the HF.
+ zero_centered_gamma=True,
+ )
+
+ # Allocator used by alloc() below; StaticBufferAllocator is defined above in this file.
+ self._allocator = StaticBufferAllocator()
+
+ def alloc(self, size, dtype, device):
+ """
+ Allocates a buffer in a way that works correctly with CUDA Graphs.
+ """
+ return self._allocator(size, dtype, device)
+
+ def forward(self, *args, **kwargs): # We need to additionally pass positional encoding.
+
+ # if "self_attn_mask_type" in kwargs:
+ # attn_mask_type = kwargs['self_attn_mask_type']
+ # else:
+ # attn_mask_type = "whatever_default_is"
+
+ # if attn_mask_type == "arbitrary":
+ # # @sudhakars: following logic doesn't work for `thd`
+ # attn_mask = kwargs['attention_mask']
+ # attention_mask_inv = ~attn_mask
+ # generation_case = torch.tensor(torch.tensor(attn_mask.shape).shape).item() > 2
+
+ # if generation_case:
+ # # @sudhakars: for some reason, `attention_mask` for generation is of the
+ # # form [b, 1, 1, s].
+ # attention_mask_inv = attention_mask_inv.squeeze(1).squeeze(1)
+ # assert torch.tensor(torch.tensor(attention_mask_inv.shape).shape).item() == 2
+
+ # # Create `position_ids` on the fly using `attention_mask` since HF
+ # # does the same in generation logic.
+ # position_ids = attention_mask_inv.long().cumsum(-1) - 1
+ # position_ids.masked_fill_(attention_mask_inv == 0, 1)
+
+ # if "position_ids" in kwargs and kwargs['position_ids'] is not None:
+ # assert torch.all(torch.eq(position_ids, kwargs["position_ids"])), "position ids don't match match exactly!"
+
+ # # convert [b, s] to [b, 1, s, s] since `arbitrary` is only set for
+ # # context phase and context phase gets [b, s] sized attn mask
+ # seq_len = 1 if torch.tensor(torch.tensor(attn_mask.shape).shape).item() > 2 else attention_mask_inv.shape[1]
+ # arbitrary_attn_mask = torch.zeros(attention_mask_inv.shape[0], 1, seq_len, attention_mask_inv.shape[1]).bool()
+ # for sample_idx in range(attn_mask.shape[0]):
+ # pad_len = attn_mask[sample_idx].sum().int().item()
+ # # set the columns to padded
+ # arbitrary_attn_mask[sample_idx, :, :, :pad_len] = True
+ # # set the rows to padded
+ # if not generation_case:
+ # arbitrary_attn_mask[sample_idx, :, :pad_len, :] = True
+ # arbitrary_attn_mask[sample_idx] = torch.tril(arbitrary_attn_mask[sample_idx].logical_not()).logical_not()
+
+ # # Update the attention mask to arbitrary
+ # kwargs['attention_mask'] = arbitrary_attn_mask.cuda()
+
+ # # @sudhakars: `max_position_embeddings` is not even used inside GemmaRotaryEmbedding
+ # # @sudhakars: change the hardcoded `dim` to something like config.head_dim
+ # te_rope_emb = GemmaRotaryEmbedding(dim=256, max_position_embeddings=self.gemma_config.max_position_embeddings).cuda()
+ # te_rope_emb = te_rope_emb(args[0], position_ids.cuda())
+ # else:
+ # When the `attention_mask` is not `arbitrary`, then for the purpose
+ # of this tutorial, we're using `padding_causal` (for context) and
+ # `padding` (for generation)
+ # @sudhakars: find a better way to provide the `tensor_format`
+ te_rope_emb = RotaryPositionEmbedding(self.gemma_config.head_dim)(
+ max_seq_len=self.gemma_config.max_position_embeddings
+ ).cuda()
+
+ inference_params = kwargs["inference_params"]
+ # @sudhakars: big assumption that the input is "sbhd"
+ # batch_size = args[0].shape[0]
+
+ # if inference_params.qkv_format_legacy == "thd":
+ # cache_params = kwargs["cache_params"]
+ # max_seqlen_q = cache_params.max_seqlen_q
+ # max_seqlen_kv = cache_params.max_seqlen_kv
+ # cu_seqlens_q = cache_params.cu_seqlens_q
+ # cu_seqlens_kv = cache_params.cu_seqlens_kv
+ # cu_seqlens_q_padded = cache_params.cu_seqlens_q_padded
+ # cu_seqlens_kv_padded = cache_params.cu_seqlens_kv_padded
+ # print(f"input_sequence_lengths (in forward): \n{inference_params.input_sequence_lengths}")
+
+ # this args cannot be passed to TransformerLayer
+ keys_to_remove = [
+ "position_ids",
+ "past_key_value",
+ "output_attentions",
+ "use_cache",
+ "cache_position",
+ ]
+ for key in keys_to_remove:
+ kwargs.pop(key, None)
+
+ # We need to return tuple to be compatible with HF.
+ return (
+ super().forward(
+ *args,
+ rotary_pos_emb=te_rope_emb,
+ # cu_seqlens_q=cu_seqlens_q,
+ # cu_seqlens_kv=cu_seqlens_kv,
+ # max_seqlen_q=max_seqlen_q,
+ # max_seqlen_kv=max_seqlen_kv,
+ **kwargs,
+ ),
+ )
+
+
+class StaticGemmaModel(torch.nn.Module):
+ """
+ StaticGemmaModel is based on the HF GemmaModel class.
+ It is adjusted to work properly with CUDA Graphs.
+ """
+
+ def __init__(
+ self,
+ model: GemmaModel,
+ dtype: torch.dtype,
+ mask: torch.Tensor,
+ lm_head: torch.nn.Module,
+ ):
+ super().__init__()
+ self.model = model
+ self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype)
+ self.mask = mask
+ self.lm_head = lm_head
+
+ def set_inference_params(self, inference_params):
+ self.inference_params = inference_params
+
+ # @sudhakars: is `arbitrary` fine being the default here?
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor = None,
+ attn_mask_type: str = "arbitrary",
+ ):
+ print(f"StaticGemmaModel forward start")
+ with torch.no_grad():
+ # static operation - for CUDA graphs
+ hidden_states.data[:] = hidden_states.data[:] * self.normalizer
+
+ for i, decoder_layer in enumerate(self.model.layers):
+ # print(f"layer {i}")
+ hidden_states.data[:] = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type,
+ inference_params=self.inference_params,
+ )[
+ 0
+ ] # static copy - for CUDA graphs
+
+ hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+ return logits, hidden_states
+
+
+class GemmaGenerator(torch.nn.Module):
+ """
+ GemmaGenerator takes the embeddings of the most recently generated tokens,
+ runs a forward pass and returns the next tokens.
+ """
+
+ def __init__(
+ self, model: GemmaModel, lm_head: torch.nn.Module, dtype: torch.dtype, qkv_format: str
+ ):
+ super().__init__()
+ self.model = model
+ self.gemma_layers = StaticGemmaModel(model, dtype, "arbitrary", lm_head)
+ self.qkv_format = qkv_format
+
+ def set_inference_params(self, inference_params):
+ self.inference_params = inference_params
+ self.gemma_layers.set_inference_params(inference_params)
+
+ # @sudhakars: is `arbitrary` a good default value here?
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ mask: torch.Tensor = None,
+ attn_mask_type: str = "arbitrary",
+ ):
+ logits, _ = self.gemma_layers(
+ hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type
+ )
+
+ assert logits.shape[0] == hidden_states.shape[0] # b
+ assert logits.shape[1] == hidden_states.shape[1] # seq_len
+ # logits.shape[2] = number of tokens
+ logits = logits[:, -1, :]
+ next_tokens = torch.argmax(logits, dim=1)
+
+ # static copy for CUDA graphs
+ hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1))
+
+ return next_tokens
+
+
+class PartialForwardWrapper(torch.nn.Module):
+ """
+ This class wraps a `torch.nn.Module` while partially applying its `forward`.
+
+ CUDA Graphs' `make_graphed_callables` method takes in a module, but if only
+ `functools.partial` is used to wrap the module, it changes the module's
+ type, and that interferes with the `make_graphed_callables` internals.
+ """
+
+ def __init__(self, module, **kwargs):
+ super().__init__()
+ self.module = module
+ self.partial_forward = partial(self.module.forward, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ return self.partial_forward(*args, **kwargs)
+
+ # @sudhakars: should we use better abstraction?
+ def set_inference_params(self, *args, **kwargs):
+ return self.module.set_inference_params(*args, **kwargs)
+
+
+@contextmanager
+def replace_decoder(te_decoder_cls):
+ """
+ Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`.
+ """
+ original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer
+ transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls
+ try:
+ yield
+ finally:
+ transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls
+
+
+class TEGemmaForCausalLM(GemmaForCausalLM):
+ """
+ Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer`
+ class is monkey-patched with `TEGemmaDecoderLayer` class before
+ initializing the causal LM with `GemmaForCausalLM`.
+
+ Args:
+ config: GemmaConfig
+ """
+
+ def __init__(self, config: GemmaConfig):
+ with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer):
+ super().__init__(config)
+ self.config = config
+ self.to(torch.bfloat16).cuda()
+ self.hidden_size = config.hidden_size
+ self._model_generation_phase = GemmaGenerator(
+ lm_head=self.lm_head,
+ model=self.model,
+ dtype=torch.bfloat16,
+ qkv_format=config.qkv_format,
+ )
+ self._model_context_phase = StaticGemmaModel(
+ self.model, torch.bfloat16, "arbitrary", self.lm_head
+ )
+
+ if self.config.fp8:
+ self.fp8_recipe = DelayedScaling(
+ fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max"
+ )
+
+ @staticmethod
+ def _padding_to_end(inputs, lengths, max_seq_len=None):
+ """
+ Gets a tensor with sequences padded at the beginning and
+ returns the tensor with the sequences padded at their end.
+
+ Parameters
+ ----------
+ inputs : Tensor, tensor with shape [b, s] containing token ids.
+ It is padded at the beginning.
+ lengths: Tensor, tensor with shape [b] with the lengths of the sequences.
+
+ """
+ max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len
+ batch_size, max_seq_len = inputs.shape
+ new_input_ids = inputs.clone()
+ for i in range(batch_size):
+ new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len]
+ new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])]
+
+ # Disable the input preparation that involves extra padding
+ # inputs.copy_(new_input_ids)
+
+ # Trim the inputs to no extra padding i.e. fix the max seq len to
+ # the longest sequence in the batch
+ actual_max_seq_len = max_seq_len
+ inputs.data = new_input_ids[:, :actual_max_seq_len]
+ print(f"actual_max_seq_len: {actual_max_seq_len}")
+
+ # For Paged Attention, make the valid sequences, multiple of 64
+ # inputs.data = new_input_ids[:, :4].repeat(1, 16)
+
+ def _next_64_multiply(self, x):
+ return ((x + 63) // 64) * 64
+
+ # This function is overridden in TEGemmaForCausalLMCudaGraphs.
+ def _create_hidden_states_buffer(self, input_ids: torch.Tensor):
+ tensor = torch.empty(
+ (input_ids.shape[0], input_ids.shape[1], self.hidden_size),
+ device="cuda",
+ dtype=torch.float32,
+ )
+ # import pdb; pdb.set_trace()
+ return tensor
+
+ # This function is overridden in TEGemmaForCausalLMCudaGraphs.
+ def _create_inference_params(self, *args, **kwargs):
+ infer_params = InferenceParams(*args, **kwargs)
+
+ # max_batch_size = kwargs["max_batch_size"]
+
+ # Initialize some legacy params
+ # _allocator = StaticBufferAllocator()
+ # infer_params.cached_sequence_lengths = _allocator((max_batch_size,), dtype=torch.int32, device="cuda")
+ # infer_params.input_sequence_lengths = _allocator((max_batch_size,), dtype=torch.int32, device="cuda")
+
+ # These are updated in setup_cache_params_from_infer_params and they should be static for
+ # the duration of the context as well as the generation phase.
+ # infer_params.cu_seqlens_q, infer_params.cu_seqlens_kv, infer_params.cu_seqlens_q_padded, infer_params.cu_seqlens_kv_padded = [
+ # _allocator(max_batch_size + 1, dtype=torch.int32, device="cuda")
+ # for _ in range(4)
+ # ]
+
+ return infer_params
+
+ # This function is overridden in TEGemmaForCausalLMCudaGraphs.
+ def _get_max_input_seq_len(self, input_ids):
+ return (
+ input_ids.shape[1]
+ if not hasattr(self.config, "cuda_graphs_static_max_context_len")
+ else self.config.cuda_graphs_static_max_context_len
+ )
+
+ # The buffer for generation is some part (beginning) of hidden states buffer.
+ # This function returns pointer to it and also copies there data if provided.
+ def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None):
+ # hidden_states_buffer has shape [b, s, hd]
+ # generation_buffer will have shape [b, 1, hd]
+ # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)"
+ # would return a non-contiguous buffer, which we want to avoid.
+ output = hidden_states_buffer.view(-1)[
+ : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2]
+ ]
+ if data_to_copy is not None:
+ output.copy_(data_to_copy.reshape(-1))
+ generation_buffer = output.view(
+ (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])
+ )
+ return generation_buffer
+
+ def _generate_context_phase(self, input_ids: torch.Tensor, inference_params: InferenceParams):
+ # import pdb; pdb.set_trace()
+ hidden_states = self._create_hidden_states_buffer(input_ids)
+ hidden_states.data[:] = self.model.embed_tokens(input_ids)
+
+ # We need to update offsets before every forward pass to make cache work properly.
+ lengths = input_ids.ne(0).sum(dim=1)
+
+ # import pdb; pdb.set_trace()
+ if self.config.qkv_format == "thd":
+ # inference_params.setup_before_new_input(
+ # lengths_tensor=lengths, max_input_length=input_ids.shape[1]
+ # )
+ lengths = input_ids.ne(0).sum(dim=1)
+ inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
+ else:
+ inference_params.setup_before_new_input(length=input_ids.shape[1])
+
+ logits, hs_buffer = self._model_context_phase(
+ hidden_states,
+ attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None),
+ attn_mask_type="padding_causal" if self.config.qkv_format == "thd" else "arbitrary",
+ )
+
+ # We choose the logits corresponding to the last token in each sequence.
+ # The sequences have various lengths: for qkv_format == "thd" the last-token
+ # positions are given by (lengths - 1), while for qkv_format != "thd"
+ # it is simply the last position in the sequence.
+ # import pdb; pdb.set_trace()
+ if self.config.qkv_format == "thd":
+ logits = logits[torch.arange(logits.size(0)), lengths - 1, :]
+ else:
+ logits = logits[:, -1, :]
+
+ next_tokens = torch.argmax(logits, dim=1)
+
+ # self.hidden_states have shape [b, s, hd].
+ # We return hidden state for the last token - output has shape [b, 1, hd]
+ hidden_states = self._get_generation_buffer(
+ hidden_states, self.model.embed_tokens(next_tokens)
+ )
+ return hidden_states, next_tokens
+
+ def _make_mask_one_token_longer(self, mask):
+ return torch.cat(
+ [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ pad_token_id: int = 0,
+ max_new_tokens: int = 0,
+ *args,
+ **kwargs,
+ ):
+ self.eval()
+
+ # We need both autocasts: FP8 for operations that can run in lower precision
+ # and BF16 for those that cannot.
+ with autocast(dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast(
+ enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None
+ ):
+
+ lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s]
+
+ batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(
+ input_ids
+ )
+
+ # This is not needed since the padding to the left is already done in utils.py
+ # # Pad input_ids with zeros on the left to match max_input_sequence_len
+ # # This adds padding tokens (0) to the left side of each sequence in the batch
+ # # Shape goes from [batch_size, seq_len] to [batch_size, max_input_sequence_len]
+ # input_ids = F.pad(
+ # input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0
+ # )
+
+ if self.config.qkv_format == "thd":
+ # For thd layout padding is at the end, otherwise at the beginning.
+ TEGemmaForCausalLM._padding_to_end(
+ input_ids,
+ lengths,
+ max_seq_len=(
+ self.config.cuda_graphs_static_max_context_len
+ if self.config.generation_cuda_graphs
+ else None
+ ),
+ )
+
+ # import pdb; pdb.set_trace()
+
+ # InferenceParams is a cache, where keys and values of previous tokens are stored.
+ # Moreover it stores length of both already generated and input sequences.
+ inference_params = self._create_inference_params(
+ max_batch_size=batch_size,
+ # num_layers=self.config.num_hidden_layers,
+ max_sequence_length=self._next_64_multiply(max_input_sequence_len + max_new_tokens),
+ num_heads_kv=self.config.num_key_value_heads,
+ # num_heads_q=self.config.num_attention_heads,
+ head_dim_v=self.config.head_dim,
+ head_dim_k=self.config.head_dim,
+ dtype=torch.bfloat16,
+ is_paged=self.config.is_paged,
+ page_size=64,
+ total_num_pages=64, # 64 * 64 (max_sequence_length) / 64 (page_size)
+ # is_cuda_graph=False
+ )
+
+ # def init_cache_params_in_infer_params(inference_params):
+ # _allocator = StaticBufferAllocator()
+ # inference_params.cached_sequence_lengths = _allocator(
+ # (batch_size,), dtype=torch.int32, device="cuda")
+ # inference_params.input_sequence_lengths = _allocator(
+ # (batch_size,), dtype=torch.int32, device="cuda")
+
+ # init_cache_params_in_infer_params(inference_params)
+
+ # inference_params.qkv_format_legacy = self.config.qkv_format
+
+ self._model_context_phase.set_inference_params(inference_params)
+ self._model_generation_phase.set_inference_params(inference_params)
+
+ print(f"context phase start")
+ # import pdb; pdb.set_trace()
+ hidden_states, next_tokens = self._generate_context_phase(input_ids, inference_params)
+
+ print(f"context phase done")
+ # Generation phase.
+ if self.config.qkv_format == "thd":
+ # inference_params.setup_before_new_input(
+ # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"),
+ # max_input_length=1,
+ # )
+ lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
+ inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ )
+ else:
+ inference_params.setup_before_new_input(length=1)
+
+ output_tokens = [next_tokens]
+
+ mask = None
+ if self.config.qkv_format != "thd":
+ mask = (input_ids == 0).unsqueeze(1).unsqueeze(1)
+
+ for _ in range(max_new_tokens):
+ if self.config.qkv_format != "thd":
+ # It will not work with cuda graphs, but it is not used for thd qkv_format.
+ # Attention mask in bshd needs attn_mask increased by 1 to
+ # include the next token to be generated
+ mask = self._make_mask_one_token_longer(mask)
+
+ # setup_cache_params_from_infer_params(inference_params, input_ids)
+ # @sudhakars: could create position_ids from mask here
+ next_tokens = self._model_generation_phase(
+ hidden_states,
+ mask,
+ attn_mask_type="padding" if self.config.qkv_format == "thd" else "arbitrary",
+ )
+
+ # self.inference_params contains for example kv_cache.
+ # This needs to be called before every pass,
+ # to update the information of sequence lengths.
+ # Here we increase sequence offsets by one,
+ # because we generated one token for every sequence.
+ if self.config.qkv_format == "thd":
+ # self.inference_params.setup_before_new_input(
+ # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"),
+ # max_input_length=1,
+ # )
+ lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
+ inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ )
+ else:
+ inference_params.setup_before_new_input(length=1)
+ # next_tokens is static output tensor, so we need to clone it
+ # - it gets changed every iteration.
+ output_tokens.append(next_tokens.clone())
+
+ result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)
+ return result
+
+ def forward(self, *args, **kwargs):
+ self._model_context_phase.set_inference_params(None)
+ hidden_states = self.model.embed_tokens(kwargs["input_ids"])
+ logits = self._model_context_phase(
+ hidden_states,
+ attention_mask=(
+ (kwargs["input_ids"] == 0) if self.config.qkv_format != "thd" else None
+ ),
+ attn_mask_type="arbitrary",
+ )
+ return logits
+
+
+class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM):
+ """
+ TEGemmaForCausalLMCudaGraphs is the version of the TEGemmaForCausalLM class
+ that uses CUDA Graphs to speed it up. This requires one trade-off:
+ batch_size, max_seq_len and max_context_seq_len need to be static.
+ Generation must be run with the same values of these variables
+ that the graphs were recorded with.
+ """
+
+ def __init__(self, config: GemmaConfig):
+ super().__init__(config)
+ assert (
+ config.qkv_format == "thd"
+ ), "Generation with CUDA Graphs are implemented only for thd format."
+
+ # Preparation of the static buffers.
+ self.config = config
+ self.hidden_states_buffer = torch.empty(
+ (
+ self.config.cuda_graphs_static_batch_size,
+ self.config.cuda_graphs_static_max_context_len,
+ self.config.hidden_size,
+ )
+ ).cuda()
+ # This is in fact part of the buffer for hidden_states.
+ self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer)
+ # self.inference_params = InferenceParams(
+ # max_batch_size=config.cuda_graphs_static_batch_size,
+ # max_sequence_length=config.cuda_graphs_static_max_seq_len,
+ # qkv_format="thd",
+ # )
+ self.inference_params = InferenceParams(
+ max_batch_size=self.config.cuda_graphs_static_batch_size,
+ # num_layers=self.config.num_hidden_layers,
+ max_sequence_length=self.config.cuda_graphs_static_max_seq_len,
+ num_heads_kv=self.config.num_key_value_heads,
+ # num_heads_q=self.config.num_attention_heads,
+ head_dim_v=self.config.head_dim,
+ head_dim_k=self.config.head_dim,
+ dtype=torch.bfloat16,
+ is_paged=self.config.is_paged,
+ page_size=64,
+ total_num_pages=64, # 64 * 64 (max_sequence_length) / 64 (page_size)
+ # is_cuda_graph=False
+ )
+
+ ## Taken from TEGemmaForCausalLM above
+ # max_batch_size = self.config.cuda_graphs_static_batch_size
+ # # Initialize some legacy params
+ # _allocator = StaticBufferAllocator()
+ # self.inference_params.cached_sequence_lengths = _allocator((max_batch_size,), dtype=torch.int32, device="cuda")
+ # self.inference_params.input_sequence_lengths = _allocator((max_batch_size,), dtype=torch.int32, device="cuda")
+
+ # self.inference_params.cu_seqlens_q, self.inference_params.cu_seqlens_kv, self.inference_params.cu_seqlens_q_padded, self.inference_params.cu_seqlens_kv_padded = [
+ # _allocator(max_batch_size + 1, dtype=torch.int32, device="cuda")
+ # for _ in range(4)
+ # ]
+
+ # def init_cache_params_in_infer_params(inference_params):
+ # inference_params.cached_sequence_lengths = torch.zeros(
+ # (batch_size,), device="cuda", dtype=torch.int32)
+ # inference_params.input_sequence_lengths = torch.zeros(
+ # (batch_size,), device="cuda", dtype=torch.int32)
+ # init_cache_params_in_infer_params(inference_params)
+
+ # self.inference_params.qkv_format_legacy = self.config.qkv_format
+
+ self._model_generation_phase.set_inference_params(self.inference_params)
+ self._model_context_phase.set_inference_params(self.inference_params)
+
+ def record(self):
+ # We want to record model in training=False, because it will be used in generation.
+ self.eval()
+
+ # Here "the trick" happens. We override methods from TEGemmaForCausalLM
+ # with their recorded version. After invocation of each of them,
+ # captured graph will be replayed with minimal usage of CPU,
+ # what will lead to huge speedup.
+ input_shape = (
+ self.config.cuda_graphs_static_batch_size,
+ self.config.cuda_graphs_static_max_context_len,
+ )
+ # self.inference_params.reset()
+ # self.inference_params.setup_before_new_input(
+ # lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"),
+ # max_input_length=input_shape[1],
+ # )
+
+ # [1] Should be same as lengths_tensor from TEGemmaForCausalLM
+ lengths = torch.tensor(input_shape[0] * [input_shape[1]], device="cuda", dtype=torch.int32)
+ max_input_length = input_shape[1]
+
+ self.inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
+ )
+
+ print(f"context phase recording start")
+ # self._model_context_phase.model.layers = torch.nn.ModuleList([
+ # self.record_graph(
+ # layer,
+ # self.hidden_states_buffer,
+ # self_attn_mask_type="padding_causal",
+ # inference_params=self.inference_params
+ # )
+ # for layer in self._model_context_phase.model.layers
+ # ])
+ self._model_context_phase = self.record_graph(
+ self._model_context_phase, self.hidden_states_buffer, attn_mask_type="padding_causal"
+ ) # CUDA Graphs recording
+
+ print(f"context phase recording done")
+ input_shape = (self.config.cuda_graphs_static_batch_size, 1)
+ # self.inference_params.reset()
+ # self.inference_params.setup_before_new_input(
+ # lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"),
+ # max_input_length=input_shape[1],
+ # )
+ lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32)
+ max_input_length = input_shape[1]
+
+ self.inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
+ )
+
+ self._model_generation_phase = self.record_graph(
+ self._model_generation_phase, self.generation_buffer, attn_mask_type="padding"
+ ) # CUDA Graphs recording
+
+ """
+    Functions _create_hidden_states_buffer and _create_inference_params
+    from the base class are overridden to make hidden_states and inference_params
+    static - they do not change their position in memory between invocations.
+ """
+
+ def _create_hidden_states_buffer(self, *args, **kwargs):
+ return self.hidden_states_buffer
+
+ def _create_inference_params(self, *args, **kwargs):
+ self.inference_params.reset()
+ return self.inference_params
+
+ def _get_max_input_seq_len(self, _):
+ return self.config.cuda_graphs_static_max_context_len
+
+ @torch.no_grad()
+ def record_graph(self, function, input_tensor, **sample_kwargs):
+        # `function` is invoked on the argument (input_tensor,) and all of its kernels are recorded.
+        # record_graph() returns the captured function, which can later be run with much lower CPU overhead.
+ fp8_format = Format.HYBRID
+ fp8_recipe = DelayedScaling(
+ fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max"
+ )
+
+ # We need both autocasts: FP8 for operations that can run in lower precision
+ # and BF16 for those that cannot.
+ with autocast(dtype=torch.bfloat16, cache_enabled=False):
+ graphed_function = te.pytorch.make_graphed_callables(
+ function,
+ (input_tensor,),
+ fp8_enabled=self.config.fp8,
+ fp8_recipe=fp8_recipe,
+ allow_unused_input=True,
+ num_warmup_iters=5,
+ sample_kwargs=sample_kwargs,
+ )
+ return graphed_function
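+
+
+# Illustrative usage sketch (assumption: `config` provides the `cuda_graphs_static_*`
+# fields referenced above and generation inputs match the recorded static shapes):
+#
+#   model = TEGemmaForCausalLMCudaGraphs(config).cuda()
+#   model.record()                                   # capture context + generation graphs once
+#   tokens = model.generate(input_ids=input_ids, max_new_tokens=40)
+#
+# After recording, generation replays the captured graphs with minimal CPU overhead.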
diff --git a/docs/examples/te_gemma/te_llama.py b/docs/examples/te_gemma/te_llama.py
new file mode 100755
index 0000000000..637f4f574c
--- /dev/null
+++ b/docs/examples/te_gemma/te_llama.py
@@ -0,0 +1,826 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+from contextlib import contextmanager
+
+from typing import Optional
+from functools import partial
+from collections import OrderedDict
+
+import torch
+import transformer_engine as te
+from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
+from transformer_engine.common.recipe import Format, DelayedScaling
+from torch.cuda.amp import autocast
+
+import transformers
+from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaConfig, LlamaModel
+
+import torch.nn.functional as F
+
+
+def setup_cache_params_from_infer_params(inference_params, lengths_tensor, max_input_length):
+ """
+    Converts the incoming sequence lengths to variables like `cu_seqlens_q/kv`, etc.,
+    which are used later by the attention layers.
+
+    (Currently a hack; this should be refactored into a better method.)
+ """
+
+ assert (
+ lengths_tensor is not None and max_input_length is not None
+ ), 'lengths_tensor and max_input_length should not be none for qkv_format = "thd"'
+ torch.add(
+ inference_params.cached_sequence_lengths,
+ inference_params.input_sequence_lengths,
+ out=inference_params.cached_sequence_lengths,
+ )
+ inference_params.input_sequence_lengths.copy_(lengths_tensor)
+ inference_params.max_incoming_seq_len = max_input_length
+
+ max_seqlen_q, max_seqlen_kv = (
+ inference_params.max_incoming_seq_len,
+ inference_params.max_sequence_length,
+ )
+
+    # Allocation of buffers; this works correctly with CUDA Graphs.
+ _allocator = StaticBufferAllocator()
+ NR_BUFFERS = 4
+
+ cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = [
+ _allocator(inference_params.max_batch_size + 1, dtype=torch.int32, device="cuda")
+ for _ in range(NR_BUFFERS)
+ ]
+
+ torch.cumsum(inference_params.input_sequence_lengths, dim=0, out=cu_seqlens_q[1:])
+ torch.cumsum(
+ inference_params.cached_sequence_lengths + inference_params.input_sequence_lengths,
+ dim=0,
+ out=cu_seqlens_kv[1:],
+ )
+ # If layer has shape [b * s_layer, h, d]
+ # offsets are of the form [k * s_layer * h * d for k = 0, ..., batch_size]
+ cu_seqlens_q_padded.copy_(
+ torch.arange(0, inference_params.max_batch_size + 1, device="cuda") * max_seqlen_q
+ )
+ cu_seqlens_kv_padded.copy_(
+ torch.arange(0, inference_params.max_batch_size + 1, device="cuda") * max_seqlen_kv
+ )
+
+ # inference_params.step_dict = OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ inference_params.pre_step(
+ OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
+ )
+
+ def get_cache_params_in_infer_params():
+ return (
+ max_seqlen_q,
+ max_seqlen_kv,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
+ )
+
+ # For the time being, create an ad-hoc field in `inference_params` to get the variables.
+ # @sudhakars: to create a better way later.
+ inference_params.get_cache_params_from_infer_params = get_cache_params_in_infer_params
+
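+# Worked example (illustrative) of the buffers built in setup_cache_params_from_infer_params
+# above. Suppose max_batch_size = 2 and, after the update, cached_sequence_lengths = [4, 2]
+# and input_sequence_lengths = [1, 1]. Then:
+#   cu_seqlens_q  = [0, 1, 2]   (cumulative sum of the incoming lengths)
+#   cu_seqlens_kv = [0, 5, 8]   (cumulative sum of cached + incoming lengths)
+# The *_padded variants are simply [k * max_seqlen for k in range(max_batch_size + 1)].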
+
+# This class has been modified from
+# https://github.com/huggingface/transformers/blob/98adf24883b007c2a7fb17bab1c01b1614673433/src/transformers/models/gemma/modeling_gemma.py
+class LlamaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
+ )
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ )
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = (
+ device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ )
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb.unsqueeze(2) # should return in [b, s, 1, d] format
+
+
+class StaticBufferAllocator(torch.nn.Module):
+ """
+    This class is used together with te.pytorch.make_graphed_callables().
+    CUDA Graphs require all tensors to be static. However,
+    make_graphed_callables() takes care of the outputs of torch modules
+    and makes them static. Thus, by wrapping memory allocation in a
+    torch.nn.Module, we can greatly simplify our code.
+ """
+
+ # pylint: disable=no-self-use
+ def forward(self, size, dtype, device):
+ """
+ Return buffer of given size, dtype and device.
+ """
+ return torch.zeros(size, dtype=dtype, device=device)
+
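+
+# Illustrative sketch (not used elsewhere in this tutorial): allocating a side buffer
+# through StaticBufferAllocator so that te.pytorch.make_graphed_callables() treats it
+# as a static module output during CUDA Graph capture.
+def _example_allocate_cu_seqlens(max_batch_size: int) -> torch.Tensor:
+    allocator = StaticBufferAllocator()
+    # One extra slot because cu_seqlens stores cumulative lengths starting from 0.
+    return allocator(max_batch_size + 1, dtype=torch.int32, device="cuda")
+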
+
+class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
+ """
+    Wrapper class over TE's `TransformerLayer`. This makes the wrapper's interface very
+    similar to HF's `LlamaDecoderLayer`, so it is easy to swap it into the code.
+
+ Args:
+ config: LlamaConfig
+ args: positional args (for compatibility with `LlamaDecoderLayer`)
+ kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
+ """
+
+ def __init__(self, config: LlamaConfig, layer_idx: int, *args, **kwargs):
+
+ self.llama_config = config
+ self.head_dim = self.llama_config.hidden_size // self.llama_config.num_attention_heads
+
+ super().__init__(
+ hidden_size=config.hidden_size,
+ ffn_hidden_size=config.intermediate_size,
+ num_attention_heads=config.num_attention_heads,
+ bias=False, # LLaMA specific
+ layernorm_epsilon=config.rms_norm_eps,
+ hidden_dropout=0,
+ attention_dropout=0,
+ fuse_qkv_params=config.fuse_qkv_params,
+ normalization="RMSNorm",
+ activation="swiglu", # LLaMA specific
+ # attn_input_format=config.qkv_format,
+ attn_input_format="bshd",
+ num_gqa_groups=config.num_key_value_heads,
+ kv_channels=self.head_dim, # LLaMA specific
+ layer_number=(
+ layer_idx + 1
+            ),  # Layer numbers in TE start from 1, not 0 as in HF.
+ zero_centered_gamma=True, # LLaMA specific
+        )
+
+        # Used by self.alloc(); allocates buffers in a CUDA Graphs friendly way.
+        self._allocator = StaticBufferAllocator()
+
+ def alloc(self, size, dtype, device):
+ """
+        Allocates a buffer in a way that works correctly with CUDA Graphs.
+ """
+ return self._allocator(size, dtype, device)
+
+ def forward(self, *args, **kwargs): # We need to additionally pass positional encoding.
+
+ if "self_attn_mask_type" in kwargs:
+ attn_mask_type = kwargs["self_attn_mask_type"]
+        else:
+            # Any value other than "arbitrary" falls through to the non-arbitrary branch below.
+            attn_mask_type = "whatever_default_is"
+
+ if attn_mask_type == "arbitrary":
+ # @sudhakars: following logic doesn't work for `thd`
+ attn_mask = kwargs["attention_mask"]
+ attention_mask_inv = ~attn_mask
+            generation_case = attn_mask.dim() > 2
+
+ if generation_case:
+ # @sudhakars: for some reason, `attention_mask` for generation is of the
+ # form [b, 1, 1, s].
+ attention_mask_inv = attention_mask_inv.squeeze(1).squeeze(1)
+                assert attention_mask_inv.dim() == 2
+
+ # Create `position_ids` on the fly using `attention_mask` since HF
+ # does the same in generation logic.
+ position_ids = attention_mask_inv.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask_inv == 0, 1)
+
+ if "position_ids" in kwargs and kwargs["position_ids"] is not None:
+ assert torch.all(
+ torch.eq(position_ids, kwargs["position_ids"])
+ ), "position ids don't match match exactly!"
+
+            # Convert [b, s] to [b, 1, s, s] since `arbitrary` is only set for the
+            # context phase, and the context phase gets a [b, s] sized attention mask.
+            seq_len = 1 if generation_case else attention_mask_inv.shape[1]
+ arbitrary_attn_mask = torch.zeros(
+ attention_mask_inv.shape[0], 1, seq_len, attention_mask_inv.shape[1]
+ ).bool()
+ for sample_idx in range(attn_mask.shape[0]):
+ pad_len = attn_mask[sample_idx].sum().int().item()
+ # set the columns to padded
+ arbitrary_attn_mask[sample_idx, :, :, :pad_len] = True
+ # set the rows to padded
+ if not generation_case:
+ arbitrary_attn_mask[sample_idx, :, :pad_len, :] = True
+ arbitrary_attn_mask[sample_idx] = torch.tril(
+ arbitrary_attn_mask[sample_idx].logical_not()
+ ).logical_not()
+
+ # Update the attention mask to arbitrary
+ kwargs["attention_mask"] = arbitrary_attn_mask.cuda()
+
+ # @sudhakars: `max_position_embeddings` is not even used inside GemmaRotaryEmbedding
+ # @sudhakars: change the hardcoded `dim` to something like config.head_dim
+ te_rope_emb = LlamaRotaryEmbedding(
+ dim=self.head_dim, max_position_embeddings=self.llama_config.max_position_embeddings
+ ).cuda()
+ te_rope_emb = te_rope_emb(args[0], position_ids.cuda())
+ else:
+ # When the `attention_mask` is not `arbitrary`, then for the purpose
+ # of this tutorial, we're using `padding_causal` (for context) and
+ # `padding` (for generation)
+ # @sudhakars: find a better way to provide the `tensor_format`
+ te_rope_emb = RotaryPositionEmbedding(self.head_dim)( # Use self.head_dim
+ max_seq_len=self.llama_config.max_position_embeddings
+ ).cuda()
+
+ inference_params = kwargs["inference_params"]
+ # @sudhakars: big assumption that the input is "sbhd"
+ # batch_size = args[0].shape[0]
+ if inference_params.qkv_format_legacy == "thd":
+ (
+ max_seqlen_q,
+ max_seqlen_kv,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
+ ) = inference_params.get_cache_params_from_infer_params()
+
+        # These args cannot be passed to TransformerLayer.
+ keys_to_remove = [
+ "position_ids",
+ "past_key_value",
+ "output_attentions",
+ "use_cache",
+ "cache_position",
+ ]
+ for key in keys_to_remove:
+ kwargs.pop(key, None)
+
+ # We need to return tuple to be compatible with HF.
+ return (
+ super().forward(
+ *args,
+ rotary_pos_emb=te_rope_emb,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_kv=max_seqlen_kv,
+ **kwargs
+ ),
+ )
+
+
+class StaticLlamaModel(torch.nn.Module):
+ """
+    StaticLlamaModel is based on the HF LlamaModel class.
+ It is adjusted to work properly with CUDA Graphs.
+ """
+
+ def __init__(
+ self,
+ model: LlamaModel,
+ dtype: torch.dtype,
+ mask: torch.Tensor,
+ lm_head: torch.nn.Module,
+ ):
+ super().__init__()
+ self.model = model
+ self.llama_config = model.config # Store LlamaConfig
+ self.normalizer = torch.tensor(self.llama_config.hidden_size**0.5, dtype=dtype)
+ self.mask = mask
+ self.lm_head = lm_head
+
+ def set_inference_params(self, inference_params):
+ self.inference_params = inference_params
+
+ # @sudhakars: is `arbitrary` fine being the default here?
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor = None,
+ attn_mask_type: str = "arbitrary",
+ ):
+
+ with torch.no_grad():
+ # static operation - for CUDA graphs
+ hidden_states.data[:] = hidden_states.data[:] * self.normalizer
+
+ for i, decoder_layer in enumerate(self.model.layers):
+ hidden_states.data[:] = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type,
+ inference_params=self.inference_params,
+ )[
+ 0
+ ] # static copy - for CUDA graphs
+
+ hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+ return logits
+
+
+class LlamaGenerator(torch.nn.Module):
+ """
+    LlamaGenerator takes the embeddings of the most recently generated tokens,
+    runs a forward pass and returns the next tokens.
+ """
+
+ def __init__(
+ self, model: LlamaModel, lm_head: torch.nn.Module, dtype: torch.dtype, qkv_format: str
+ ):
+ super().__init__()
+ self.model = model
+ self.llama_layers = StaticLlamaModel(model, dtype, "arbitrary", lm_head)
+ self.qkv_format = qkv_format
+
+ def set_inference_params(self, inference_params):
+ self.inference_params = inference_params
+ self.llama_layers.set_inference_params(inference_params)
+
+ # @sudhakars: is `arbitrary` a good default value here?
+ def forward(
+ self, hidden_states: torch.Tensor, mask: torch.Tensor = None, mask_type: str = "arbitrary"
+ ):
+ logits = self.llama_layers(hidden_states, attention_mask=mask, attn_mask_type=mask_type)
+
+ assert logits.shape[0] == hidden_states.shape[0] # b
+ assert logits.shape[1] == hidden_states.shape[1] # seq_len
+        # logits.shape[2] = vocabulary size
+ logits = logits[:, -1, :]
+ next_tokens = torch.argmax(logits, dim=1)
+
+ # static copy for CUDA graphs
+ hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1))
+
+        # self.inference_params holds, among other things, the KV cache.
+        # It needs to be updated before every pass
+        # with the current sequence-length information.
+        # Here we increase the sequence offsets by one,
+        # because one token was generated for every sequence.
+ if self.qkv_format == "thd":
+ # self.inference_params.setup_before_new_input(
+ # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"),
+ # max_input_length=1,
+ # )
+ setup_cache_params_from_infer_params(
+ self.inference_params,
+ lengths_tensor=torch.ones((next_tokens.shape[0],), dtype=int),
+ max_input_length=1,
+ )
+ else:
+ self.inference_params.setup_before_new_input(length=1)
+
+ return next_tokens
+
+
+class PartialForwardWrapper(torch.nn.Module):
+ """
+    This class wraps a `torch.nn.Module` while partially modifying its `forward`.
+
+    CUDA Graphs' `make_graphed_callables` method takes in a module, but if
+    `functools.partial` alone is used to wrap the module, it changes the module's
+    type, which interferes with the `make_graphed_callables` internals
+    (see the sketch after this class).
+ """
+
+ def __init__(self, module, **kwargs):
+ super().__init__()
+ self.module = module
+ self.partial_forward = partial(self.module.forward, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ return self.partial_forward(*args, **kwargs)
+
+ # @sudhakars: should we use better abstraction?
+ def set_inference_params(self, *args, **kwargs):
+ return self.module.set_inference_params(*args, **kwargs)
+
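+# Illustrative sketch (hypothetical, not part of the tutorial flow): why the wrapper is
+# needed instead of a bare functools.partial.
+#
+#   bare = partial(module.forward, mask_type="padding")          # a functools.partial object
+#   wrapped = PartialForwardWrapper(module, mask_type="padding")  # still a torch.nn.Module
+#
+# te.pytorch.make_graphed_callables() treats the callable as an nn.Module (parameters,
+# buffers, training mode), so only the wrapped version can be captured directly.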
+
+@contextmanager
+def replace_decoder(te_decoder_cls):
+ """
+ Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
+ """
+ original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
+ try:
+ yield
+ finally:
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls
+
+
+class TELlamaForCausalLM(LlamaForCausalLM):
+ """
+ Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
+ class is monkey-patched with `TELlamaDecoderLayer` class before
+ initializing the causal LM with `LlamaForCausalLM`.
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
+ super().__init__(config)
+ self.config = config
+ self.to(torch.bfloat16).cuda()
+ self.hidden_size = config.hidden_size
+ self._model_generation_phase = LlamaGenerator(
+ lm_head=self.lm_head,
+ model=self.model,
+ dtype=torch.bfloat16,
+ qkv_format=config.qkv_format,
+ )
+ self._model_context_phase = StaticLlamaModel(
+ self.model, torch.bfloat16, "arbitrary", self.lm_head
+ )
+
+ if self.config.fp8:
+ self.fp8_recipe = DelayedScaling(
+ fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max"
+ )
+
+ @staticmethod
+ def _padding_to_end(inputs, lengths):
+ """
+        Takes a tensor with sequences padded at the beginning and
+        returns a tensor with the padding moved to the end
+        (see the worked example below this method).
+
+        Parameters
+        ----------
+        inputs : Tensor, tensor with shape [b, s] containing token ids.
+            It is padded at the beginning.
+        lengths: Tensor, tensor with shape [b] with the lengths of the sequences.
+
+ """
+        batch_size, max_seq_len = inputs.shape
+ new_input_ids = inputs.clone()
+ for i in range(batch_size):
+ new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len]
+ new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])]
+
+ # Disable the input preparation that involves extra padding
+ # inputs.copy_(new_input_ids)
+
+ # Trim the inputs to no extra padding i.e. fix the max seq len to
+ # the longest sequence in the batch
+ actual_max_seq_len = inputs.ne(0).sum(dim=1).max()
+ inputs.data = new_input_ids[:, :actual_max_seq_len]
+
+ # For Paged Attention, make the valid sequences, multiple of 64
+ # inputs.data = new_input_ids[:, :4].repeat(1, 16)
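+
+    # Worked example (illustrative) for _padding_to_end with pad token 0:
+    #   inputs  = [[0, 0, 5, 6],     with lengths = [2, 3]
+    #              [0, 7, 8, 9]]
+    # is rearranged to
+    #             [[5, 6, 0, 0],
+    #              [7, 8, 9, 0]]
+    # and then trimmed to the longest real sequence, so inputs.data becomes [[5, 6, 0], [7, 8, 9]].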
+
+ def _next_64_multiply(self, x):
+ return ((x + 63) // 64) * 64
+
+    # This function is overridden in TELlamaForCausalLMCudaGraphs.
+ def _create_hidden_states_buffer(self, input_ids: torch.Tensor):
+ return torch.empty(
+ (input_ids.shape[0], input_ids.shape[1], self.hidden_size),
+ device="cuda",
+ dtype=torch.float32,
+ )
+
+    # This function is overridden in TELlamaForCausalLMCudaGraphs.
+ def _create_inference_params(self, *args, **kwargs):
+ infer_params = InferenceParams(*args, **kwargs)
+
+ max_batch_size = kwargs["max_batch_size"]
+
+ # Initialize some legacy params
+ infer_params.cached_sequence_lengths = torch.zeros(
+ (max_batch_size,), device="cuda", dtype=torch.int32
+ )
+ infer_params.input_sequence_lengths = torch.zeros(
+ (max_batch_size,), device="cuda", dtype=torch.int32
+ )
+
+ return infer_params
+
+    # This function is overridden in TELlamaForCausalLMCudaGraphs.
+ def _get_max_input_seq_len(self, input_ids):
+ return input_ids.shape[1]
+
+    # The buffer for generation is a part (the beginning) of the hidden states buffer.
+    # This function returns a view into it and also copies data into it if provided.
+ def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None):
+ # hidden_states_buffer has shape [b, s, hd]
+ # generation_buffer will have shape [b, 1, hd]
+ # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)"
+ # will return uncontiguous buffer, which we want to avoid.
+ output = hidden_states_buffer.view(-1)[
+ : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2]
+ ]
+ if data_to_copy is not None:
+ output.copy_(data_to_copy.reshape(-1))
+ generation_buffer = output.view(
+ (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])
+ )
+ return generation_buffer
+
+ def _generate_context_phase(self, input_ids: torch.Tensor, inference_params: InferenceParams):
+ hidden_states = self._create_hidden_states_buffer(input_ids)
+ hidden_states.data[:] = self.model.embed_tokens(input_ids)
+
+ # We need to update offsets before every forward pass to make cache work properly.
+ lengths = input_ids.ne(0).sum(dim=1)
+ if self.config.qkv_format == "thd":
+ # inference_params.setup_before_new_input(
+ # lengths_tensor=lengths, max_input_length=input_ids.shape[1]
+ # )
+ lengths = input_ids.ne(0).sum(dim=1)
+ max_input_length = input_ids.shape[1]
+ setup_cache_params_from_infer_params(inference_params, lengths, max_input_length)
+ else:
+ inference_params.setup_before_new_input(length=input_ids.shape[1])
+
+ logits = self._model_context_phase(
+ hidden_states,
+ attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None),
+ attn_mask_type="padding_causal" if self.config.qkv_format == "thd" else "arbitrary",
+ )
+
+        # We choose the logits corresponding to the last token in each sequence.
+        # The sequences have different lengths: when qkv_format == "thd", the last-token
+        # positions are given by (inference_params.input_sequence_lengths - 1);
+        # otherwise the last token is simply at the last position.
+ if self.config.qkv_format == "thd":
+ logits = logits[
+ torch.arange(logits.size(0)), inference_params.input_sequence_lengths - 1, :
+ ]
+ else:
+ logits = logits[:, -1, :]
+ next_tokens = torch.argmax(logits, dim=1)
+
+ # self.hidden_states have shape [b, s, hd].
+ # We return hidden state for the last token - output has shape [b, 1, hd]
+ hidden_states = self._get_generation_buffer(
+ hidden_states, self.model.embed_tokens(next_tokens)
+ )
+ return hidden_states, next_tokens
+
+ def _make_mask_one_token_longer(self, mask):
+ return torch.cat(
+ [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ pad_token_id: int = 0,
+ max_new_tokens: int = 0,
+ *args,
+ **kwargs
+ ):
+ self.eval()
+
+ # We need both autocasts: FP8 for operations that can run in lower precision
+ # and BF16 for those that cannot.
+ with autocast(dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast(
+ enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None
+ ):
+
+            lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze()  # [b]
+ # input_ids = F.pad(
+ # input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0
+ # )
+
+ if self.config.qkv_format == "thd":
+ # For thd layout padding is at the end, otherwise at the beginning.
+ TELlamaForCausalLM._padding_to_end(input_ids, lengths)
+
+ batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(
+ input_ids
+ )
+
+            # InferenceParams is a cache where the keys and values of previous tokens are stored.
+            # It also stores the lengths of both the already generated and the input sequences.
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
+ inference_params = self._create_inference_params(
+ max_batch_size=batch_size,
+ # num_layers=self.config.num_hidden_layers,
+ max_sequence_length=self._next_64_multiply(max_input_sequence_len + max_new_tokens),
+ num_heads_kv=self.config.num_key_value_heads,
+ # num_heads_q=self.config.num_attention_heads,
+ head_dim_v=head_dim,
+ head_dim_k=head_dim,
+ dtype=torch.bfloat16,
+ is_paged=True,
+ page_size=64,
+ total_num_pages=64 * 3, # 64 * 64 (max_sequence_length) / 64 (page_size)
+ # is_cuda_graph=False
+ )
+
+ def init_cache_params_in_infer_params(inference_params):
+ inference_params.cached_sequence_lengths = torch.zeros(
+ (batch_size,), device="cuda", dtype=torch.int32
+ )
+ inference_params.input_sequence_lengths = torch.zeros(
+ (batch_size,), device="cuda", dtype=torch.int32
+ )
+
+ init_cache_params_in_infer_params(inference_params)
+ inference_params.qkv_format_legacy = self.config.qkv_format
+
+ self._model_context_phase.set_inference_params(inference_params)
+ self._model_generation_phase.set_inference_params(inference_params)
+
+ hidden_states, next_tokens = self._generate_context_phase(input_ids, inference_params)
+
+ # Generation phase.
+ if self.config.qkv_format == "thd":
+ # inference_params.setup_before_new_input(
+ # lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"),
+ # max_input_length=1,
+ # )
+ setup_cache_params_from_infer_params(
+ inference_params,
+ lengths_tensor=torch.ones((next_tokens.shape[0],), dtype=int),
+ max_input_length=1,
+ )
+ else:
+ inference_params.setup_before_new_input(length=1)
+
+ output_tokens = [next_tokens]
+
+ mask = None
+ if self.config.qkv_format != "thd":
+ mask = (input_ids == 0).unsqueeze(1).unsqueeze(1)
+
+ for _ in range(max_new_tokens):
+ if self.config.qkv_format != "thd":
+                    # This will not work with CUDA Graphs, but it is not used for the thd qkv_format.
+                    # The attention mask in bshd needs to grow by one column to
+                    # include the next token to be generated.
+ mask = self._make_mask_one_token_longer(mask)
+
+ # setup_cache_params_from_infer_params(inference_params, input_ids)
+ # @sudhakars: could create position_ids from mask here
+ next_tokens = self._model_generation_phase(
+ hidden_states,
+ mask,
+ mask_type="padding" if self.config.qkv_format == "thd" else "arbitrary",
+ )
+                # next_tokens is a static output tensor, so we need to clone it
+                # - it is overwritten every iteration.
+ output_tokens.append(next_tokens.clone())
+
+ result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)
+ return result
+
+ def forward(self, *args, **kwargs):
+ self._model_context_phase.set_inference_params(None)
+ hidden_states = self.model.embed_tokens(kwargs["input_ids"])
+ logits = self._model_context_phase(
+ hidden_states,
+ attention_mask=(
+ (kwargs["input_ids"] == 0) if self.config.qkv_format != "thd" else None
+ ),
+ attn_mask_type="arbitrary",
+ )
+ return logits
+
+
+class TELlamaForCausalLMCudaGraphs(TELlamaForCausalLM):
+ """
+    TELlamaForCausalLMCudaGraphs is the version of the class TELlamaForCausalLM
+    that uses CUDA Graphs to speed it up. One trade-off is required:
+    batch_size, max_seq_len and max_context_seq_len need to be static.
+    Generation must be run with the same values of these variables
+    that the graphs were recorded with.
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ assert (
+ config.qkv_format == "thd"
+ ), "Generation with CUDA Graphs are implemented only for thd format."
+
+ # Preparation of the static buffers.
+ self.config = config
+ self.hidden_states_buffer = torch.empty(
+ (
+ config.cuda_graphs_static_batch_size,
+ config.cuda_graphs_static_max_context_len,
+ config.hidden_size,
+ )
+ ).cuda()
+ # This is in fact part of the buffer for hidden_states.
+ self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer)
+ self.inference_params = InferenceParams(
+ max_batch_size=config.cuda_graphs_static_batch_size,
+ max_sequence_length=config.cuda_graphs_static_max_seq_len,
+ qkv_format="thd",
+ )
+
+ self._model_generation_phase.set_inference_params(self.inference_params)
+ self._model_context_phase.set_inference_params(self.inference_params)
+
+ def record(self):
+ # We want to record model in training=False, because it will be used in generation.
+ self.eval()
+
+ # Here "the trick" happens. We override methods from TELlamaForCausalLM
+ # with their recorded version. After invocation of each of them,
+ # captured graph will be replayed with minimal usage of CPU,
+ # what will lead to huge speedup.
+ input_shape = (
+ self.config.cuda_graphs_static_batch_size,
+ self.config.cuda_graphs_static_max_context_len,
+ )
+ self.inference_params.reset()
+ self.inference_params.setup_before_new_input(
+ lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"),
+ max_input_length=input_shape[1],
+ )
+ self._model_context_phase = self.record_graph(
+ PartialForwardWrapper(
+ self._model_context_phase,
+ attn_mask_type=(
+ "padding_causal" if self.inference_params.qkv_format == "thd" else "arbitrary"
+ ),
+ ),
+ self.hidden_states_buffer,
+ ) # CUDA Graphs recording
+
+ input_shape = (self.config.cuda_graphs_static_batch_size, 1)
+ self.inference_params.reset()
+ self.inference_params.setup_before_new_input(
+ lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"),
+ max_input_length=input_shape[1],
+ )
+ self._model_generation_phase = self.record_graph(
+ PartialForwardWrapper(
+ self._model_generation_phase,
+ mask_type="padding" if self.inference_params.qkv_format == "thd" else "arbitrary",
+ ),
+ self.generation_buffer,
+ ) # CUDA Graphs recording
+
+ """
+    Functions _create_hidden_states_buffer and _create_inference_params
+    from the base class are overridden to make hidden_states and inference_params
+    static - they do not change their position in memory between invocations.
+ """
+
+ def _create_hidden_states_buffer(self, *args, **kwargs):
+ return self.hidden_states_buffer
+
+ def _create_inference_params(self, *args, **kwargs):
+ self.inference_params.reset()
+ return self.inference_params
+
+ def _get_max_input_seq_len(self, _):
+ return self.config.cuda_graphs_static_max_context_len
+
+ @torch.no_grad()
+ def record_graph(self, function, input_tensor):
+        # `function` is invoked on the argument (input_tensor,) and all of its kernels are recorded.
+        # record_graph() returns the captured function, which can later be run with much lower CPU overhead.
+ fp8_format = Format.HYBRID
+ fp8_recipe = DelayedScaling(
+ fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max"
+ )
+
+ # We need both autocasts: FP8 for operations that can run in lower precision
+ # and BF16 for those that cannot.
+ with autocast(dtype=torch.bfloat16, cache_enabled=False):
+ graphed_function = te.pytorch.make_graphed_callables(
+ function,
+ (input_tensor,),
+ fp8_enabled=self.config.fp8,
+ fp8_recipe=fp8_recipe,
+ allow_unused_input=True,
+ num_warmup_iters=3,
+ )
+ return graphed_function
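+
+
+# Illustrative usage sketch (assumption: `config` provides the `cuda_graphs_static_*`
+# fields referenced above, `qkv_format="thd"`, and generation inputs match the recorded
+# static shapes):
+#
+#   model = TELlamaForCausalLMCudaGraphs(config).cuda()
+#   model.record()                                   # capture context + generation graphs once
+#   tokens = model.generate(input_ids=input_ids, max_new_tokens=40)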
diff --git a/docs/examples/te_gemma/te_llama_loading_weights.py b/docs/examples/te_gemma/te_llama_loading_weights.py
new file mode 100755
index 0000000000..a5ab151f67
--- /dev/null
+++ b/docs/examples/te_gemma/te_llama_loading_weights.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import re
+import gc
+import torch
+
+from typing import List
+
+from transformer_engine.pytorch.fp8 import fp8_model_init
+
+from transformers.modeling_utils import load_state_dict, _load_state_dict_into_model
+from transformers.utils.hub import get_checkpoint_shard_files
+
+"""
+    This file contains the logic for mapping HuggingFace LlamaModel parameters
+    to TransformerEngine's TransformerLayer. Once Transformer models have been initialized
+    both with HF and with TE, we can copy the parameters from the first into the second.
+"""
+
+
+def _load_weights_for_fp8_model(vanilla_model, hyperparams):
+    # The weights are loaded from a file with the state_dict
+    # of a model that also contains FP8 parameters.
+    # The weights are in BF16 precision, but they carry the FP8 metadata
+    # computed by the calibration procedure.
+ vanilla_model.load_state_dict(
+ torch.load(hyperparams.fp8_model_weights_filename),
+ strict=False,
+        # strict=False, because some parameters have
+        # multiple pointers to the same weight:
+        # vanilla_model._model_context_phase.model
+        # and vanilla_model._model_generation_phase.model
+ )
+
+
+def _load_weights_for_standard_model(vanilla_model, config):
+ # The weights are loaded from the file with original weights.
+ archive_file = os.path.join(config.model_name, "model.safetensors.index.json")
+ resolved_archive_file, _ = get_checkpoint_shard_files(config.model_name, archive_file)
+ total_dict = {}
+ for shard_file in resolved_archive_file:
+ state_dict = load_state_dict(shard_file)
+ total_dict.update(state_dict)
+
+ replace_params(
+ total_dict,
+ vanilla_model.state_dict(),
+ config,
+ qkv_fused_and_interleaved=config.fuse_qkv_params,
+ )
+ # Copy parameters like embedding:
+ _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="")
+
+ # Force mem release. Taken from huggingface code.
+ del total_dict
+ gc.collect()
+
+
+def load_te_model(cls, config):
+ """
+ Custom method adapted from `from_pretrained` method in HuggingFace
+ Transformers repo:
+ https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
+ """
+
+ config.use_cache = False # To make TransformerLayer compatible with LlamaModel
+ with fp8_model_init(config.fp8_model_init):
+        # Here we only need to create the model.
+ vanilla_model = cls(config).to(torch.bfloat16).cuda()
+
+    # Now we copy the weights into it.
+ if config.fp8_model_weights_filename is not None:
+ _load_weights_for_fp8_model(vanilla_model, config)
+ else:
+ _load_weights_for_standard_model(vanilla_model, config)
+
+ return vanilla_model
+
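+# Illustrative usage (assumption: `config` carries `model_name`, `fp8_model_init`,
+# `fp8_model_weights_filename` and the other fields referenced above, and
+# `TELlamaForCausalLM` is the class defined in te_llama.py):
+#
+#   model = load_te_model(TELlamaForCausalLM, config)
+#   tokens = model.generate(input_ids=input_ids, max_new_tokens=40)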
+
+def _get_all_layer_prefixes_to_update(hf_state_dict):
+ """
+    Many parameters in hf_state_dict have names starting with "model.layers.[number].".
+    This function extracts all such prefixes, i.e. every string of the form
+    "model.layers.[number]." that starts a key in hf_state_dict.
+ """
+ all_layer_prefixes = set()
+ for param_key in hf_state_dict.keys():
+ layer_prefix_pat = "model.layers.\d+."
+ m = re.match(layer_prefix_pat, param_key)
+ if m is not None:
+ all_layer_prefixes.add(m.group())
+ return all_layer_prefixes
+
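+# Example (illustrative): for a key such as "model.layers.11.self_attn.q_proj.weight",
+# the pattern above matches the prefix "model.layers.11.".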
+
+def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
+    """
+    Copies parameters from the HuggingFace LlamaModel state_dict into the
+    corresponding entries of the TE TransformerLayer state_dict.
+    """
+    # Collect all layer prefixes to update.
+    all_layer_prefixes = _get_all_layer_prefixes_to_update(hf_state_dict)
+
+ for layer_prefix in all_layer_prefixes:
+        # When loading weights into a model with fewer layers, skip the
+        # copy if the corresponding layer doesn't exist in the HF model.
+ if layer_prefix + "input_layernorm.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
+ :
+ ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
+
+ if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
+ hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
+ )
+
+ if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
+ hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
+ )
+
+ if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
+ hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
+ )
+
+ if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
+ layer_prefix + "self_attn.o_proj.weight"
+ ].data[:]
+
+ if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
+ layer_prefix + "post_attention_layernorm.weight"
+ ].data[:]
+
+        # gate_proj.weight and up_proj.weight may be stored in different files, so we need to
+        # load them separately.
+ if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
+ : config.intermediate_size
+ ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data
+
+ if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
+ config.intermediate_size :
+ ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
+
+ if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
+ layer_prefix + "mlp.down_proj.weight"
+ ].data[:]
+ return all_layer_prefixes
+
+
+# def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
+# """
+# Replaces params from TE TransformerLayer state_dict with corresponding parameters
+# from HuggingFace LlamaModel state_dict.
+# """
+# all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict)
+
+# head_dim = config.hidden_size // config.num_attention_heads
+
+# for layer_prefix in all_layer_prefixes:
+
+# def copy_from_ht_to_te(te_name, hf_name, start=None, end=None):
+# te_state_dict[layer_prefix + te_name].data[start:end].copy_(
+# hf_state_dict[layer_prefix + hf_name]
+# )
+
+# copy_from_ht_to_te(
+# "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight"
+# )
+# copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight")
+# copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight")
+# copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight")
+# copy_from_ht_to_te(
+# "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size
+# )
+# copy_from_ht_to_te(
+# "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size
+# )
+
+# if qkv_fused_and_interleaved:
+# """
+# When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor
+# in TE TransformerLayer. Moreover they are interleaved within each head.
+# Let q_i, k_i and v_i be query, key and value layers for i-th head respectively.
+# Then TE stores weight tensor in the form:
+# [q1 k1 v1 q2 k2 v2 ...]
+# This is done to maximally optimize performance time.
+# """
+# te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"]
+
+# def copy_interleave(hf_name, idx):
+# src = hf_state_dict[layer_prefix + hf_name]
+# for head_nr in range(config.num_attention_heads):
+# dst_offset = head_nr * config.head_dim * 3
+# dst_slice = slice(
+# dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim
+# )
+# src_slice = slice(
+# head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim
+# )
+# te_qkv_layer[dst_slice, :] = src[src_slice, :]
+
+# copy_interleave("self_attn.q_proj.weight", 0)
+# copy_interleave("self_attn.k_proj.weight", 1)
+# copy_interleave("self_attn.v_proj.weight", 2)
+# else:
+# copy_from_ht_to_te(
+# "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"
+# )
+# copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
+# copy_from_ht_to_te(
+# "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"
+# )
+
+# return all_layer_prefixes
diff --git a/docs/examples/te_gemma/test_paged_attn.ipynb b/docs/examples/te_gemma/test_paged_attn.ipynb
new file mode 100755
index 0000000000..543ebe9262
--- /dev/null
+++ b/docs/examples/te_gemma/test_paged_attn.ipynb
@@ -0,0 +1,33 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ace403ac-c276-4378-a4e8-0155165f9934",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb
new file mode 100755
index 0000000000..7875ffc9f3
--- /dev/null
+++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb
@@ -0,0 +1,314 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Accelerating a Hugging Face Gemma model finetuning with Transformer Engine"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `LlamaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional speedup.\n",
+ "\n",
+ "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Dependencies for this tutorial\n",
+ "\n",
+ "Following files and media are necessary to effectively run this tutorial:\n",
+ "\n",
+ "1. `te_gemma.py`\n",
+ " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
+ "2. `utils.py`\n",
+ " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
+ "3. `requirements.txt`\n",
+ " - This file contains necessary Python packages for this tutorial.\n",
+ "4. `media/`\n",
+ " - This directory contains the images used in the following tutorial."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -r requirements.txt\n",
+ "\n",
+ "import torch\n",
+ "cudnn_version = torch.backends.cudnn.version()\n",
+ "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Differences between Llama and Gemma"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The most important architectural differences between them are the following:\n",
+ "\n",
+ "\n",
+ "| Feature | Llama | Gemma |\n",
+ "|----------------------------------------------|------------------------------------|--------------------------------------------|\n",
+ "| **Norm Layer** | Standard RMSNorm
$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * \\gamma + \\beta$ | RMSNorm with zero centered gamma parameter
$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) + \\beta$ |\n",
+ "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n",
+ "| **Activation Function** | SwiGlu | GeGlu |\n"
+ ]
+ },
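+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch (optional, not needed for the rest of the tutorial): Gemma's zero-centered\n",
+    "# gamma corresponds to the `zero_centered_gamma=True` flag passed to TE layers; here we\n",
+    "# assume transformer_engine.pytorch.RMSNorm exposes the same flag as TransformerLayer.\n",
+    "import torch\n",
+    "import transformer_engine.pytorch as te_pt\n",
+    "\n",
+    "rmsnorm = te_pt.RMSNorm(3072, zero_centered_gamma=True).cuda()\n",
+    "x = torch.randn(2, 16, 3072, device=\"cuda\")\n",
+    "y = rmsnorm(x)  # gamma is stored around zero and applied as (1 + gamma)\n",
+    "print(y.shape)"
+   ]
+  },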
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n",
+ "\n",
+ "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n",
+ "\n",
+ "