TE Gemma tutorial attempt#2 #1839

Draft · wants to merge 10 commits into main
67 changes: 67 additions & 0 deletions docs/examples/te_gemma/check_cuda_graphs.py
@@ -0,0 +1,67 @@
import torch
from transformer_engine.pytorch import Linear, LayerNorm


# 1. Define a model with static buffers
class TE_Model(torch.nn.Module):
    def __init__(self, max_seq_len=4096, batch_size=8):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.ln = LayerNorm(1024)
        self.attn_proj = Linear(1024, 1024)

        # Pre-allocate static buffers so their addresses stay fixed across graph replays.
        # The KV cache carries a batch dimension so the in-place copy in forward()
        # matches the (batch, seq, hidden) activations.
        self.register_buffer(
            "kv_cache", torch.zeros(batch_size, max_seq_len, 1024, device="cuda")
        )
        self.register_buffer(
            "attn_mask", torch.tril(torch.ones(max_seq_len, max_seq_len, device="cuda"))
        )

    def forward(self, hidden_states, seq_start: torch.Tensor):
        # Dynamic slicing of the static buffers
        seq_len = hidden_states.size(1)
        current_mask = self.attn_mask[seq_start : seq_start + seq_len, :seq_len]

        x = self.ln(hidden_states)
        x = self.attn_proj(x)
        # Update the KV cache in place
        self.kv_cache[:, seq_start : seq_start + seq_len].copy_(x)
        return x


# 2. Create a graphed callable
model = TE_Model().cuda()
static_input = torch.randn(8, 256, 1024, device="cuda")  # (batch, seq, hidden)
seq_start = torch.tensor(0, device="cuda")

# Wrap with CUDA Graphs. Passing a single module (rather than a list) returns a
# single graphed callable, which is what run_inference() below expects.
graph_model = torch.cuda.make_graphed_callables(
    model,
    sample_args=(static_input, seq_start),  # must match the real input structure
    # pool=torch.cuda.graphs.graph_pool_handle(),
    allow_unused_input=False,
)


# 3. Warmup and execution
def run_inference(x, seq_start):
    # Inputs must match the device/dtype/shape of sample_args
    x = x.to("cuda", non_blocking=True).requires_grad_(False)
    seq_start = seq_start.to("cuda", non_blocking=True)

    with torch.amp.autocast("cuda"):
        return graph_model(x, seq_start)


# Warm-up (important for TE's kernel auto-tuner)
for _ in range(3):
    _ = run_inference(static_input, seq_start)
torch.cuda.synchronize()


# 4. Usage with dynamic sequence lengths
def process_batch(inputs, start_pos):
    # inputs: (batch, seq, hidden) on the CPU
    inputs_gpu = inputs.to("cuda", non_blocking=True)

    # The output shares memory with the graph's pre-allocated static buffers
    return run_inference(inputs_gpu, start_pos)
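
For reference, the same capture can also be sketched with the lower-level torch.cuda.CUDAGraph API, which makes the static input/output buffers explicit. This is a minimal sketch, not part of the PR: it reuses the TE_Model class from the script above, runs inference only, and passes the sequence offset as a plain Python int, so the slice offset is baked into the graph at capture time.

import torch

model = TE_Model().cuda().eval()
static_input = torch.randn(8, 256, 1024, device="cuda")  # (batch, seq, hidden)
seq_offset = 0  # plain int: frozen into the captured graph

# Warm up on a side stream so capture records steady-state kernels
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side), torch.no_grad():
    for _ in range(3):
        _ = model(static_input, seq_offset)
torch.cuda.current_stream().wait_stream(side)

# Capture a single forward pass; later calls just replay it
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_output = model(static_input, seq_offset)

# Run on new data by copying into the static input buffer and replaying
new_batch = torch.randn(8, 256, 1024, device="cuda")
static_input.copy_(new_batch)
graph.replay()
result = static_output.clone()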
137 changes: 137 additions & 0 deletions docs/examples/te_gemma/check_gemm.py
@@ -0,0 +1,137 @@
import functools
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.utils import assert_dim_for_fp8_exec
from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.cpp_extensions as cpp_tex


@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
"""Get tensor with no entries and no data"""
return torch.Tensor()


def gemm(
A: torch.Tensor,
B: torch.Tensor,
dtype: torch.dtype,
workspace: torch.Tensor,
gelu: bool = False,
gelu_input: Optional[torch.Tensor] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
    ub_algo: Optional[tex.CommOverlapAlgo] = None,
    ub: Optional[Union[tex.CommOverlap, tex.CommOverlapP2P]] = None,
    extra_output_tensor: Optional[torch.Tensor] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Non FP8 GEMM."""

assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
empty_tensor = _empty_tensor()
fp8_index = -1 # dummy index

if out is None:
out = torch.empty(
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
dtype=dtype,
device="cuda",
)
else:
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")

if gelu and not grad:
gelu_input = torch.empty_like(out, dtype=dtype)
elif not gelu:
gelu_input = empty_tensor

if grad and use_bias:
grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda")
else:
grad_bias = empty_tensor

bias = bias if use_bias else empty_tensor

assert (
A.dtype == dtype and B.dtype == dtype
), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}"
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out.dtype]
if use_bias:
bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
else:
bias_dtype = output_dtype

args = (
A,
empty_tensor,
fp8_index,
input_dtype,
transa,
B,
empty_tensor,
fp8_index,
input_dtype,
transb,
out,
empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias,
bias_dtype,
gelu_input,
grad,
workspace,
workspace.shape[0],
accumulate,
False, # use_split_accumulator
)
    fn = torch.ops.tex_ts.te_gemm_ts
    if ub_algo is not None:
        assert ub is not None, "ub object is None!"
    # Always launch the GEMM (the ub/ub_algo comm-overlap path is not exercised here)
    _ = fn(*args)

    # Drop into the debugger to inspect the GEMM inputs/outputs
    import pdb

    pdb.set_trace()
    return out, grad_bias, gelu_input
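
# Shape note: with the default "TN" layout, A appears to act as the weight
# [out_features, in_features] and B as the input [tokens, in_features]; the
# output allocation above then gives out of shape [tokens, out_features],
# consistent with out = B @ A.T.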


if __name__ == "__main__":
fc2_weight = torch.load("fc2_weight.pth").cuda()

base_repo = "/perfhome/mnt/wkstn/work/repos/te_gemma_gen_support/TransformerEngine/docs/examples/te_gemma/"
base_repo = ""
gelu_out = torch.load(base_repo + "gelu_out.pth").cuda()

activation_dtype = torch.bfloat16
fc2_bias = _empty_tensor()
use_fc2_bias = False

dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)

_ = cpp_tex.gemm(
fc2_weight,
gelu_out,
activation_dtype,
get_workspace(),
bias=fc2_bias,
use_bias=use_fc2_bias,
out=fc2_out,
ub_algo=None,
ub=None,
extra_output_tensor=None,
)
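
As a quick sanity check, the GEMM result could be compared against a plain PyTorch matmul. The following is a minimal sketch that could be appended to the __main__ block above, assuming the default "TN" layout computes fc2_out = gelu_out @ fc2_weight.T; it only uses the tensors already loaded in the script.

    # Hypothetical sanity check (not in the PR): compare against a dense matmul
    ref = torch.matmul(
        gelu_out.to(activation_dtype), fc2_weight.to(activation_dtype).t()
    )
    max_diff = (fc2_out - ref).abs().max().item()
    print(f"max abs diff vs torch.matmul reference: {max_diff:.3e}")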