CUDA

CUDA Kernel Optimization for Transformer Training: A Practical Guide

By Ryan Moore — September 22, 2025 — 14 min read


When most people think about improving training throughput, they think about scaling up — more GPUs, bigger clusters, better networking. Fewer people look inside the GPU itself, at the kernel-level operations that determine how efficiently each individual compute unit is being used. This is where the largest untapped performance gains usually live.

PyTorch and JAX are excellent frameworks, but their general-purpose kernels are not always optimally tuned for specific operations at specific batch sizes and sequence lengths. Custom CUDA kernels that target your exact workload can deliver 2–3x speedups on key transformer operations such as attention, layer normalization, and embedding lookups.

The Profiling-First Approach

Before writing a single line of CUDA, you need a detailed breakdown of where training time is actually spent. This sounds obvious, but many teams go straight to implementing FlashAttention because they have heard it is fast, without measuring whether attention is actually their bottleneck.

Use NVIDIA Nsight Systems to capture a timeline of a representative training step, and Nsight Compute to drill into the individual kernels that matter. Export the timeline and look for the top 5 kernel categories by total GPU time. For a standard transformer model trained at batch size 32 and sequence length 2048, you will typically see something like:

With this breakdown, it is clear that attention and GEMM are the right targets. Layer norm optimization will yield smaller absolute gains. Optimizing data loading before fixing attention would spend effort on the wrong bottleneck.
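Once per-kernel totals are exported from the profiler (for example as a CSV), ranking them by category takes only a few lines. The sketch below assumes the export has already been parsed into (kernel_name, duration_ns) pairs. The category rules and sample records are hypothetical and only illustrate the grouping step.

```python
from collections import defaultdict

# Hypothetical substring rules mapping kernel names to categories. Real kernel
# names depend on the library version and GPU, so adjust these to your profile.
CATEGORY_RULES = [
    ("attention", ("flash", "fmha", "softmax")),
    ("gemm", ("gemm", "cutlass")),
    ("layer_norm", ("layer_norm", "layernorm", "rms_norm")),
    ("embedding", ("embedding", "indexselect")),
]


def categorize(kernel_name: str) -> str:
    """Return the first category whose substrings appear in the kernel name."""
    name = kernel_name.lower()
    for category, needles in CATEGORY_RULES:
        if any(needle in name for needle in needles):
            return category
    return "other"


def top_categories(records, k=5):
    """Aggregate (kernel_name, duration_ns) records and return the k largest
    categories as (category, total_ns, share_of_gpu_time) tuples."""
    totals = defaultdict(int)
    for kernel_name, duration_ns in records:
        totals[categorize(kernel_name)] += duration_ns
    grand_total = sum(totals.values())
    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)[:k]
    return [(cat, ns, ns / grand_total) for cat, ns in ranked]


# Made-up durations purely to exercise the code; these are not measurements.
records = [
    ("flash_fwd_kernel", 400),
    ("ampere_bf16_s16816gemm", 350),
    ("vectorized_layer_norm_kernel", 100),
    ("indexSelectLargeIndex", 50),
    ("elementwise_kernel", 100),
]

for category, total_ns, share in top_categories(records):
    print(f"{category:<12} {total_ns:>6} ns {share:6.1%}")
```

Grouping by category, rather than by individual kernel name, keeps many small kernels from hiding the operation they collectively belong to.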

FlashAttention: Why It Works and When to Use It

The standard attention computation, softmax(QK^T / √d_k)V with d_k the head dimension, materializes the full attention score matrix of size (sequence_length, sequence_length) in the GPU's high-bandwidth memory (HBM). For sequence length 2048, this is a 4M-element matrix per head. For 96 attention heads at batch size 32, this creates significant memory pressure and forces expensive HBM reads and writes.
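The size of the score matrix is simple arithmetic. The sketch assumes FP16/BF16 storage (2 bytes per element) and counts only a single copy of the scores for one layer's forward pass:

```python
def attention_scores_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes needed to materialize one (seq_len x seq_len) score matrix per head
    for the whole batch. bytes_per_elem=2 assumes FP16/BF16 storage."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


per_head = 2048 * 2048                          # elements per head (about 4M)
total = attention_scores_bytes(32, 96, 2048)    # batch 32, 96 heads, FP16
print(per_head, total / 2**30)                  # 4194304 24.0  (GiB)
```

At 24 GiB, that one intermediate would take roughly a third of an A100 80GB's memory, before counting any activations saved for the backward pass.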

FlashAttention, and its successors FlashAttention-2 and FlashAttention-3, avoid materializing the full attention matrix by fusing the matmuls and the softmax into a single kernel. They tile the computation so that each block of Q, K, and V is processed in on-chip SRAM (shared memory and registers), and the intermediate scores never travel to HBM. The key insight is that you do not need to compute the entire softmax at once. You can compute it incrementally, one tile at a time, using a numerically stable online algorithm that rescales partial results whenever the running maximum changes.
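The online softmax recurrence can be checked in plain Python for a single query row. This sketch shows the idea only and is not FlashAttention's kernel. The scores are passed in directly instead of being computed from Q and K tiles, and `block_size` stands in for the tile size a real kernel would fit into SRAM. The loop keeps only three pieces of state: a running maximum, a running normalizer, and an unnormalized output accumulator. It rescales all three whenever a new tile raises the maximum.

```python
import math


def attention_row_reference(scores, values):
    """softmax(scores) @ values for one query row, using the full score vector."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]


def attention_row_online(scores, values, block_size=4):
    """Same result, visiting keys/values one tile at a time."""
    dim = len(values[0])
    running_max = -math.inf   # max of the scores seen so far
    norm = 0.0                # running sum of exp(score - running_max)
    acc = [0.0] * dim         # running sum of exp(score - running_max) * value
    for start in range(0, len(scores), block_size):
        s_tile = scores[start:start + block_size]
        v_tile = values[start:start + block_size]
        new_max = max(running_max, max(s_tile))
        # Rescale the old statistics so they are relative to the new max
        # (on the first tile, exp(-inf) == 0.0 zeroes the empty state).
        scale = math.exp(running_max - new_max)
        norm *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - new_max)
            norm += p
            acc = [a + p * vd for a, vd in zip(acc, v)]
        running_max = new_max
    # Normalize once at the end instead of after every tile.
    return [a / norm for a in acc]


scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5, -0.5, 2.5, 1.0]
values = [[float(i), float(i % 3)] for i in range(len(scores))]
reference = attention_row_reference(scores, values)
online = attention_row_online(scores, values, block_size=4)
print(max(abs(a - b) for a, b in zip(reference, online)))  # floating-point rounding only
```

The final division is deferred to the end, so each tile costs one rescale of the accumulator rather than a full renormalization.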

On an A100 80GB with sequence length 2048 and 96 attention heads, FlashAttention delivers 2.4–2.8x the throughput of the standard attention implementation. On H100 the gains are even more pronounced, because an FP8-capable fused kernel can run the attention matmuls on the H100's FP8 Tensor Cores, the hardware behind NVIDIA's Transformer Engine.

Fused Layer Normalization

Layer normalization computes the mean and variance of each token's hidden vector and then normalizes each element. In a naive implementation, this requires three passes over the data: one for the mean, one for the variance, and one for the normalization itself. Each pass reads the full tensor from HBM, and the final pass also writes the full output.

A fused layer norm kernel reads each row from HBM once. It uses CUDA warp-level reductions to accumulate the mean and variance in registers, normalizes the values it already holds, and writes the output once. The result is a 2.1–2.6x throughput improvement over PyTorch's default layer_norm kernel, with numerically equivalent output (differences are limited to floating-point rounding).
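The reduction step can be simulated in plain Python. The sketch below makes two choices that are assumptions about how such a kernel might be written, not a description of any particular library's kernel. First, each of 32 simulated lanes runs Welford's single-pass update over its strided slice of the row (a sum and sum-of-squares formulation is another common option). Second, the per-lane partial results are merged with a butterfly pattern that mirrors how a CUDA kernel exchanges values with `__shfl_xor_sync`.

```python
import math

WARP_SIZE = 32


def lane_stats(xs):
    """Welford's single-pass update over one lane's elements: (count, mean, M2)."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    return n, mean, m2


def combine(a, b):
    """Merge two partial (count, mean, M2) triples (Chan et al.'s parallel formula)."""
    na, ma, m2a = a
    nb, mb, m2b = b
    n = na + nb
    if n == 0:
        return 0, 0.0, 0.0
    delta = mb - ma
    mean = ma + delta * nb / n
    m2 = m2a + m2b + delta * delta * na * nb / n
    return n, mean, m2


def warp_layer_norm(row, gamma, beta, eps=1e-5):
    # Each simulated lane owns a strided slice of the row: lane, lane + 32, ...
    stats = [lane_stats(row[lane::WARP_SIZE]) for lane in range(WARP_SIZE)]
    # Butterfly reduction: lane i merges with lane i ^ offset at each step, like
    # __shfl_xor_sync; after log2(32) = 5 steps every lane holds the row's stats.
    offset = WARP_SIZE // 2
    while offset > 0:
        stats = [combine(stats[i], stats[i ^ offset]) for i in range(WARP_SIZE)]
        offset //= 2
    n, mean, m2 = stats[0]
    inv_std = 1.0 / math.sqrt(m2 / n + eps)   # biased variance, as layer norm uses
    # Normalize the values already in hand; no second read of the input.
    return [(x - mean) * inv_std * g + b for x, g, b in zip(row, gamma, beta)]


row = [math.sin(i) * 3.0 + 0.1 * i for i in range(100)]
out = warp_layer_norm(row, [1.0] * len(row), [0.0] * len(row))
print(sum(out) / len(out))  # mean of the normalized row, close to 0
```

Welford's update avoids subtracting two large, nearly equal sums, which is why it is often preferred for low-precision inputs.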

The Deepiix kernel library includes fused implementations of layer norm, RMSNorm (used in Llama and Mistral architectures), and GroupNorm for convolutional models. All variants are available for both FP16 and BF16 precision, with FP8 forward pass support for H100 deployments.

Embedding Kernel Optimization

Vocabulary embedding lookups are an often-overlooked bottleneck in LLM training. For vocabularies of 50K–128K tokens, the embedding table is large: up to 1 GiB in FP16 (131,072 × 4,096 × 2 bytes) for a 128K vocabulary at model dimension 4096. Lookup patterns are also irregular, since each training sample accesses a different, non-contiguous subset of rows.

The key optimization here is coalesced memory access. CUDA schedules threads in warps of 32, and the hardware merges the memory accesses of a warp's threads into as few transactions as possible. For maximum bandwidth, those 32 threads should read or write contiguous addresses. In an unoptimized embedding lookup, each thread fetches from a different, randomly chosen row. Those accesses cannot be merged, so the warp issues up to 32 separate memory transactions, effectively serializing what could be a single coalesced load.
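A simplified model makes the contrast concrete. The sketch assumes one memory transaction per distinct, aligned 128-byte segment that a warp touches. Real GPUs track 32-byte sectors, and the details vary by architecture. It uses the FP16 table with model dimension 4096 from the paragraph above, and the token ids are arbitrary.

```python
SEGMENT_BYTES = 128   # simplification: one transaction per aligned 128-byte segment
WARP_SIZE = 32
ELEM_BYTES = 2        # FP16
HIDDEN = 4096         # model dimension
ROW_BYTES = HIDDEN * ELEM_BYTES


def transactions_per_warp(byte_addresses):
    """Number of distinct aligned segments touched by one warp-wide load."""
    return len({addr // SEGMENT_BYTES for addr in byte_addresses})


# Arbitrary, distinct row ids inside a 128K-token vocabulary.
token_ids = [(lane * 997) % 131072 for lane in range(WARP_SIZE)]

# Layout A: each lane handles a different token and reads element 0 of its row.
per_token = [tid * ROW_BYTES for tid in token_ids]

# Layout B: the whole warp cooperates on one row, lanes reading consecutive elements.
row_start = token_ids[3] * ROW_BYTES
per_row = [row_start + lane * ELEM_BYTES for lane in range(WARP_SIZE)]

print(transactions_per_warp(per_token))  # 32
print(transactions_per_warp(per_row))    # 1
```

The same 32 loads cost either 32 transactions or 1, depending only on which addresses land in the same warp instruction.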

Deepiix's embedding kernels sort lookup indices within each warp before accessing the embedding table, enabling partial coalescing even for random access patterns. For vocabularies over 64K tokens, this typically improves embedding lookup throughput by 1.6–1.9x on A100 hardware.

Integration Without Code Changes

One of the design principles of the Deepiix kernel library is that it should not require model code changes. When a training job runs on Deepiix infrastructure, the platform intercepts PyTorch's kernel dispatch path and substitutes optimized kernels transparently, based on the operation type, tensor shapes, and hardware capabilities.

This is achieved through PyTorch's custom operator registration mechanism, combined with runtime selection of the kernel variant that matches each call's tensor shapes and GPU architecture. The result is that your existing model code, training scripts, and framework integrations continue to work unchanged, but the underlying CUDA operations are replaced with Deepiix's optimized implementations.
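Shape- and hardware-based selection of this kind can be modeled as a table of kernel variants, each with its own eligibility checks. The sketch below is a toy model with hypothetical variant names and constraints. It is not Deepiix's implementation or PyTorch's dispatcher, and it omits the registration step that actually hooks a kernel into PyTorch.

```python
from dataclasses import dataclass
from typing import Callable, Dict, FrozenSet, List, Tuple

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class KernelVariant:
    name: str                                # hypothetical kernel name
    min_compute_capability: Tuple[int, int]  # e.g. (8, 0) for A100, (9, 0) for H100
    dtypes: FrozenSet[str]
    supports_shape: Callable[[Shape], bool]


# op name -> candidate variants, most specialized first
REGISTRY: Dict[str, List[KernelVariant]] = {}


def register(op: str, variant: KernelVariant) -> None:
    REGISTRY.setdefault(op, []).append(variant)


def select_kernel(op: str, shape: Shape, dtype: str,
                  compute_capability: Tuple[int, int]) -> str:
    """Return the first eligible variant, else fall back to the framework's kernel."""
    for variant in REGISTRY.get(op, []):
        if (
            compute_capability >= variant.min_compute_capability
            and dtype in variant.dtypes
            and variant.supports_shape(shape)
        ):
            return variant.name
    return "framework_default"


# Hypothetical variants: an FP8 kernel for Hopper, and a warp-reduction kernel
# that assumes the hidden dimension is a multiple of the warp size.
register("layer_norm", KernelVariant(
    "layer_norm_fp8_sm90", (9, 0), frozenset({"fp8"}), lambda s: s[-1] % 32 == 0))
register("layer_norm", KernelVariant(
    "layer_norm_fused_warp", (8, 0), frozenset({"fp16", "bf16"}), lambda s: s[-1] % 32 == 0))

print(select_kernel("layer_norm", (32, 2048, 4096), "bf16", (8, 0)))  # layer_norm_fused_warp
print(select_kernel("layer_norm", (32, 2048, 4100), "bf16", (8, 0)))  # framework_default
```

Falling back to the framework's own kernel whenever no variant qualifies is what lets this kind of substitution stay transparent: unsupported shapes, dtypes, or GPUs still run.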

This approach has been validated on GPT-2, GPT-3, BERT, T5, LLaMA-1, LLaMA-2, Mistral, and Mixtral architectures, as well as custom transformer variants across our production user base.

