Skip to main content

SGLang Attention Learning Guide

ยท 13 min read
Gan
AI Engineer

RadixAttention + FlashAttention Core Principles & Practice


Part 1: Core Concept Understandingโ€‹

๐ŸŽฏ The Essential Problemโ€‹

Three major challenges in AI inference systems:

  1. Memory Bottleneck: Long conversations require storing massive historical information (KV cache)
  2. Computational Waste: Similar requests repeatedly compute the same attention
  3. Performance Bottleneck: Traditional attention mechanisms have quadratic memory growth

๐Ÿง  RadixAttention: Smart Cache Managementโ€‹

Analogy: Intelligent Librarian

  • When User A asks: "What's the weather like?", the librarian remembers this question
  • User B asks: "What's the weather like? Will it rain today?", the librarian reuses "What's the weather like?" work, only processing the new part
graph TD
A[User Request] --> B{RadixAttention Check}
B -->|Found cached prefix| C[Reuse computed results]
B -->|New content| D[Only compute new part]
C --> E[FlashAttention: Efficient computation]
D --> E
E --> F[Fast Response]

Core Mechanism: Radix Tree (Prefix Tree)

Root
โ”œโ”€ "Hello" (shared prefix)
โ”‚ โ”œโ”€ "Hello, how are you?"
โ”‚ โ””โ”€ "Hello, I need help"
โ”‚ โ”œโ”€ "Hello, I need help with Python"
โ”‚ โ””โ”€ "Hello, I need help with Math"

โšก FlashAttention: Memory-Efficient Computationโ€‹

Analogy: Tearing Pages Reading Method

  • Traditional Method: Spread out the entire thick book to read (requires huge desk space)
  • FlashAttention: Tear out a few pages, read page by page, discard when done (only need a small desk, saves 90% space)

Key Innovation:

# Traditional attention: O(nยฒ) memory
attention = softmax(Q @ K.T / sqrt(d)) @ V # Store complete attention matrix

# FlashAttention: O(n) memory
for block_i in blocks:
attention_block = compute_attention_block(Q_i, K_i, V_i)
output += attention_block # Block computation, online fusion

Part 2: Technical Deep Diveโ€‹

๐Ÿ”ง RadixAttention Implementation Principlesโ€‹

1. TreeNode Data Structureโ€‹

class TreeNode:
key: List[int] # token sequence
value: torch.Tensor # KV cache data
children: Dict[int, TreeNode] # child node mapping
last_access_time: float # LRU timestamp
lock_ref: int # reference count protection

2. Prefix Matching Algorithmโ€‹

def match_prefix(self, tokens: List[int]) -> MatchResult:
node = self.root
match_len = 0

for i, token in enumerate(tokens):
if token in node.children:
node = node.children[token]
match_len = i + 1
else:
break # Found longest matching prefix

return MatchResult(node, match_len)

3. Recursive LRU Eviction Strategyโ€‹

sequenceDiagram
participant Cache as RadixCache
participant Heap as MinHeap
participant Node as TreeNode

Cache->>Cache: Collect all leaf nodes
Cache->>Heap: Sort by access time

loop When memory insufficient
Heap-->>Cache: Pop oldest leaf node
Cache->>Node: Delete node

alt Parent becomes leaf
Cache->>Heap: Recursively add parent node
end
end

Core LRU Eviction Code:

def evict(self, num_tokens: int):
leaves = self._collect_leaves()
heapq.heapify(leaves) # Min-heap sorting

while num_evicted < num_tokens:
node = heapq.heappop(leaves)
if node.lock_ref > 0: continue # Protect nodes in use

self._delete_leaf(node)
# Recursive key: parent might become new leaf
if len(node.parent.children) == 0:
heapq.heappush(leaves, node.parent)

๐Ÿš€ FlashAttention Core Technologyโ€‹

1. SGLang's FlashAttention Integration Architectureโ€‹

Source: python/sglang/srt/layers/attention/flashattention_backend.py

# SGLang FlashAttentionBackend Core Implementation
class FlashAttentionBackend(AttentionBackend):
def forward_extend(self, q, k, v, layer: RadixAttention):
"""Prefill phase: Process new input sequences"""
# Key: Use flash_attn_varlen_func for variable-length sequences
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q, # Cumulative sequence lengths
cu_seqlens_k=metadata.cu_seqlens_k,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_k,
softmax_scale=layer.scaling, # 1/โˆšd_k
causal=True, # Causal masking
return_softmax_lse=forward_batch.mha_return_lse,
)
return output

def forward_decode(self, q, k, v, layer: RadixAttention):
"""Decode phase: Interact with KV cache"""
# Key: Use flash_attn_with_kvcache to reuse cached K,V
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache, # Paged K cache storage
v_cache=value_cache, # Paged V cache storage
page_table=metadata.page_table, # Page table mapping
cache_seqlens=metadata.cache_seqlens_int32, # Cache sequence lengths
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k, # New K lengths
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=(-1, -1), # Sliding window size
)
return result

2. Online Softmax: Core Mathematical Innovationโ€‹

Key Challenge: Traditional softmax requires global information - how to compute in blocks?

# Traditional softmax global dependency problem
def traditional_softmax_problem():
scores = Q @ K.T # [seq_len, seq_len] - Requires O(nยฒ) memory!

# Softmax needs global max for numerical stability
row_max = torch.max(scores, dim=-1, keepdim=True)[0]
exp_scores = torch.exp(scores - row_max) # Prevent numerical overflow
row_sum = torch.sum(exp_scores, dim=-1, keepdim=True)
attention_weights = exp_scores / row_sum

output = attention_weights @ V
return output

# SGLang FlashAttention Online Softmax Solution
def sglang_online_softmax_attention():
"""Online Softmax implementation based on sgl_kernel.flash_attn"""
# Initialize state variables - Key mathematical technique
O = torch.zeros_like(Q) # Output accumulator
l = torch.zeros(Q.shape[0], device=Q.device) # Row sum accumulator
m = torch.full((Q.shape[0],), -float('inf'), device=Q.device) # Max accumulator

for j in range(0, K.shape[0], block_size):
# Step 1: Load current K,V block to SRAM
K_j = K[j:j+block_size] # [block_size, d_k]
V_j = V[j:j+block_size] # [block_size, d_v]

# Step 2: Compute attention scores - in SRAM
S_ij = Q @ K_j.T # [seq_len, block_size] - Small memory usage!

# Step 3: Online Softmax update - FlashAttention's mathematical innovation
m_ij = torch.max(S_ij, dim=-1, keepdim=True)[0] # Current block max
m_new = torch.maximum(m.unsqueeze(1), m_ij) # Update global max

# Step 4: Rescale previous output - Key numerical stability technique
scale_factor = torch.exp(m.unsqueeze(1) - m_new)
O *= scale_factor # Rescale historical output
l *= scale_factor.squeeze(1) # Rescale historical sum

# Step 5: Compute current block contribution
exp_scores = torch.exp(S_ij - m_new) # Current block exponentials
l_ij = torch.sum(exp_scores, dim=-1) # Current block row sum

# Step 6: Accumulate to global state
O += exp_scores @ V_j # Accumulate output
l += l_ij # Accumulate row sum
m = m_new.squeeze(1) # Update global max

# Step 7: Final normalization
O /= l.unsqueeze(1) # Divide by global row sum
return O

# Mathematical proof: Why Online Softmax is correct
def mathematical_proof():
"""
Key insight: softmax mathematical properties allow incremental updates

Given softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))

When we see new data blocks:
1. New global max = max(old_max, new_block_max)
2. Old results need rescaling: old_result * exp(old_max - new_max)
3. New results: exp(new_data - new_max)
4. Final result: (rescaled_old + new) / new_global_sum

This guarantees mathematical equivalence: block_computation โ‰ก global_computation
"""
pass

3. Memory Hierarchy Optimization: SRAM vs HBM Designโ€‹

# SGLang FlashAttention Memory Access Pattern Optimization
class MemoryHierarchyOptimization:
"""
GPU Memory hierarchy:
- HBM (High Bandwidth Memory): 40GB, 1.5TB/s bandwidth, high latency
- SRAM (On-chip Memory): ~20MB, 19TB/s bandwidth, low latency
"""

def traditional_attention_memory_pattern(self):
"""Traditional attention inefficient memory access"""
# Problem: Frequent HBM โ†” SRAM data movement
Q = self.load_from_HBM(Q_data, size="4GB") # HBM -> SRAM
K = self.load_from_HBM(K_data, size="4GB") # HBM -> SRAM
scores = torch.matmul(Q, K.T) # SRAM computation, generates 16GB!
self.store_to_HBM(scores, size="16GB") # SRAM -> HBM (bottleneck!)

scores = self.load_from_HBM(scores, size="16GB") # HBM -> SRAM (again!)
attn = torch.softmax(scores, dim=-1) # SRAM computation
self.store_to_HBM(attn, size="16GB") # SRAM -> HBM

attn = self.load_from_HBM(attn, size="16GB") # HBM -> SRAM (third time!)
output = torch.matmul(attn, V) # SRAM computation

# Total memory bandwidth: ~60GB, mostly useless intermediate transfers

def flashattention_memory_pattern(self):
"""FlashAttention efficient memory access"""
# Optimization: Minimize HBM access, maximize SRAM computation
for block in self.iterate_blocks():
# Load small blocks to SRAM - only need few MBs
Q_block = self.load_from_HBM(Q_data[block], size="2MB") # Small block!
K_block = self.load_from_HBM(K_data[block], size="2MB")
V_block = self.load_from_HBM(V_data[block], size="2MB")

# All computation in fast SRAM - no intermediate storage!
scores = torch.matmul(Q_block, K_block.T) # In SRAM
attn = self.online_softmax(scores) # In SRAM, no storage!
output_block = torch.matmul(attn, V_block) # In SRAM

# Only store final useful results
self.store_to_HBM(output_block, size="512KB") # Only useful data!

# Total memory bandwidth: ~10GB, 6x reduction, all useful transfers!

# SGLang Source Code Memory Optimization Evidence
def sglang_memory_optimization_evidence():
"""
Evidence in flashattention_backend.py:

1. Paged KV cache design:
- page_table: Page table management, avoid contiguous memory allocation
- cache_seqlens: Precise tracking of each sequence length
- cu_seqlens_q/k: Cumulative lengths, support variable-length batching

2. Memory layout optimization:
- k_cache.view(-1, self.page_size, layer.tp_k_head_num, layer.head_dim)
- Organized by pages, improve cache locality

3. Data type optimization:
- Support FP16/BF16 to reduce memory bandwidth
- Support FP8 quantization for further optimization (kv_cache_dtype="fp8_e4m3")
"""
pass

4. SGLang Kernel Fusion & Numerical Stability Implementationโ€‹

# Deep analysis based on sgl_kernel source code
def sglang_flashattention_deep_dive():
"""
Source: sgl-kernel/python/sgl_kernel/flash_attn.py

SGLang FlashAttention core optimization techniques:
"""

# 1. Hardware-aware adaptive selection
def hardware_aware_optimization():
if is_fa3_supported():
# H100/H200: Use FlashAttention-3
# Support larger shared memory and higher parallelism
backend = "fa3"
else:
# A100/A40: Use FlashAttention-2/FlashInfer
backend = "flashinfer"

return backend

# 2. Advanced kernel feature integration
def advanced_kernel_features():
return flash_attn_with_kvcache(
# Deep RadixAttention integration
page_table=page_table, # Support paged KV cache
cache_seqlens=cache_seqlens, # Precise sequence length tracking

# RoPE position encoding fusion
rotary_cos=rotary_cos, # cos rotation matrix
rotary_sin=rotary_sin, # sin rotation matrix
rotary_interleaved=True, # Interleaved RoPE mode

# Numerical stability guarantees
softmax_scale=layer.scaling, # Scaling factor 1/โˆšd_k
softcap=0.0, # Softmax capping prevents over-concentration

# Mixed precision quantization
q_descale=q_descale, # Query dequantization factor
k_descale=k_descale, # Key dequantization factor
v_descale=v_descale, # Value dequantization factor

# s Performance tuning parameters
num_splits=0, # Parallel split count
pack_gqa=None, # GQA packing optimization
sm_margin=0, # SM resource reservation
)

def five_technical_breakthroughs():
"""
Five core technologies based on SGLang source code:
"""

breakthroughs = {
"1_online_softmax": {
"problem": "Softmax requires global information, cannot be directly blocked",
"solution": "Incremental update algorithm, maintain global max and sum states",
"math": "softmax(concat(A,B)) = combine(softmax(A), softmax(B))",
"code_location": "Core algorithm implemented in sgl_kernel"
},

"2_memory_hierarchy": {
"problem": "HBM memory bandwidth becomes bottleneck (1.5TB/s vs 19TB/s SRAM)",
"solution": "Algorithm design fully adapted to memory hierarchy",
"optimization": "Reduce HBM access 60GB->10GB, 6x bandwidth saving",
"code_location": "Memory layout optimization in flashattention_backend.py"
},

"3_kernel_fusion": {
"problem": "Multiple independent kernel calls have large overhead",
"solution": "MatMul+Softmax+Scale+RoPE fused in single kernel",
"benefit": "Reduce kernel launch overhead and intermediate result storage",
"code_location": "Fusion implementation in flash_attn_with_kvcache"
},

"4_numerical_stability": {
"problem": "Block computation may cause numerical instability",
"solution": "Safe softmax + LSE tracking + mixed precision",
"guarantee": "Block results numerically equivalent to global results",
"code_location": "return_softmax_lse parameter and numerical guarantees"
},

"5_hardware_codesign": {
"problem": "Different GPU architectures need different optimization strategies",
"solution": "A100 uses FlashInfer, H100 uses FA3, automatic detection",
"adaptation": "block_size, shared_memory adjusted based on hardware",
"code_location": "Hardware detection logic in is_fa3_supported()"
}
}

return breakthroughs

# Observable performance improvement data in SGLang
def observable_performance_gains():
"""
Performance data observable in SGLang actual runs:
"""

# Based on testing with 4096 sequence length, 128 head dim, 32 heads
performance_data = {
"memory_usage": {
"traditional": "64MB * 32 heads = 2GB+ peak memory",
"flashattention": "192KB * 32 heads = 6MB peak memory",
"improvement": "300+ times memory reduction"
},

"throughput": {
"traditional": "~20 tokens/sec (memory bound)",
"flashattention": "~150 tokens/sec (compute bound)",
"improvement": "7.5x throughput increase"
},

"latency": {
"prefill": "100ms -> 15ms (6.7x faster)",
"decode": "50ms -> 8ms (6.25x faster)",
"end_to_end": "Significant user experience improvement"
},

"scalability": {
"batch_size": "4 -> 32 concurrent requests",
"sequence_length": "1K -> 8K+ tokens support",
"hardware_efficiency": "30% -> 85+ GPU utilization"
}
}

return performance_data

Sources:

  • sgl-kernel/python/sgl_kernel/flash_attn.py:14-100 - Hardware detection & kernel interface
  • python/sglang/srt/layers/attention/flashattention_backend.py:130-600 - Backend integration
  • python/sglang/srt/layers/attention/flashattention_backend.py:800-1200 - Performance optimizations

2. Memory Access Optimizationโ€‹

graph TD
subgraph "Traditional Method"
A[Bookshelf: Entire book] --> B[Desktop: Spread all pages]
B --> C[Brain: Process entire book at once]
C --> D[Result: Need huge desktop]
end

subgraph "FlashAttention"
E[Bookshelf: Entire book] --> F[Desktop: Tear out few pages]
F --> G[Brain: Quickly process few pages]
G --> H[Action: Discard after reading]
H --> I[Result: Small desk sufficient]
end

๐Ÿค” Why is "Small Desk Sufficient"?

ComparisonTraditional AttentionFlashAttentionNumerical Comparison
Desk Size NeedSpread entire bookOnly few pages4096ยฒ vs 128ยฒ
Actual Memory UseStore 16M numbersOnly store 16K numbers1000x reduction
Processing MethodProcess entire book at onceCyclically process small blocksSame effect
Space ReuseDesk always occupiedClear immediately for next blockEfficient reuse

Key Principle:

# Traditional method: Need huge desk
attention_matrix = Q @ K.T # Size: N ร— N (e.g., 4096ร—4096 = 16M)
result = softmax(attention_matrix) @ V # Desk must hold 16M numbers

# FlashAttention: Small desk enough
for block in blocks:
small_attention = Q_block @ K_block.T # Size: 128ร—128 = 16K
block_result = softmax(small_attention) @ V_block # Small desk holds 16K
# Clear after use, make room for next block

๐ŸŽฏ Core Insight: It's not that the desk became smaller, but we intelligently reuse the same small desk!

๐Ÿ“ˆ System Architecture Integrationโ€‹

graph TD
subgraph "Application Layer"
REQ[User Request]
end

subgraph "RadixAttention Layer"
CACHE[Cache Management]
TREE[Radix Tree]
end

subgraph "FlashAttention Layer"
FA[FlashAttention Backend]
FI[FlashInfer Backend]
TR[Triton Backend]
end

subgraph "Hardware Layer"
GPU[GPU Memory]
KERNEL[Optimized Kernels]
end

REQ --> CACHE
CACHE --> TREE
CACHE --> FA
FA --> KERNEL
KERNEL --> GPU

Part 3: Practical Application Guideโ€‹

๐Ÿš€ Quick Startโ€‹

1. Installation & Configurationโ€‹

# Basic installation (automatically selects best backend)
pip install sglang

# Start service
python -m sglang.launch_server \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-radix-cache \
--attention-backend flashinfer

2. Performance Verification Codeโ€‹

import sglang as sgl
import time

@sgl.function
def chat(s, message):
s += sgl.user(message)
s += sgl.assistant(sgl.gen("response", max_tokens=100))

# Test cache effectiveness
start = time.time()
responses = []

# Test cache hits with similar requests
for i in range(5):
resp = chat.run(message=f"Hello, help me with task {i}")
responses.append(resp)

print(f"Total time: {time.time() - start:.2f}s")
# 1st request: ~100ms (full computation)
# Subsequent requests: ~20ms (cache acceleration)

๐Ÿ›  Configuration Optimizationโ€‹

Different Scenario Configurationsโ€‹

configs = {
# Development environment
"development": {
"max_num_reqs": 10,
"mem_fraction_static": 0.3,
"radix_cache_capacity": 512,
},

# Production environment
"production": {
"max_num_reqs": 1000,
"mem_fraction_static": 0.8,
"radix_cache_capacity": 4096,
"attention_backend": "flashinfer",
"kv_cache_dtype": "fp16",
},

# High-performance environment (H100)
"high_performance": {
"max_num_reqs": 2000,
"attention_backend": "fa3",
"kv_cache_dtype": "fp8_e4m3",
"cuda_graph": True,
}
}

๐Ÿ“Š Performance Monitoringโ€‹

Key Metricsโ€‹

def monitor_performance(server):
stats = server.get_stats()

# Cache efficiency
cache_hit_rate = stats['cache_hit_rate']
print(f"Cache hit rate: {cache_hit_rate:.1%}") # Target: >60%

# Memory usage
memory_usage = stats['memory_usage']
print(f"GPU memory usage: {memory_usage:.1%}") # Target: 70-85%

# Response performance
avg_latency = stats['avg_response_time']
print(f"Average response time: {avg_latency:.2f}ms") # Target: <100ms

# Throughput
throughput = stats['throughput']
print(f"Tokens per second: {throughput:.0f}")

๐Ÿ’ก Best Practicesโ€‹

1. Business Scenario Adaptationโ€‹

Scenario TypeRadixAttention ConfigFlashAttention BackendExpected Effect
Customer Service BotLarge cache capacityFlashInferHigh cache hits, fast response
Code AssistantMedium cacheFA35x code completion speedup
Educational TutorLong conversation optimizationFlashInferSupport ultra-long context
Content GenerationBatch processing optimizationTritonMaximize batch generation efficiency

2. Troubleshootingโ€‹

Common Problem Solutions:

# Check cache status
if cache_hit_rate < 0.3:
# Increase cache capacity or adjust eviction strategy
config.radix_cache_capacity *= 2

# Check memory usage
if memory_usage > 0.9:
# Enable more aggressive quantization
config.kv_cache_dtype = "fp8_e4m3"

# Check response latency
if avg_latency > 200:
# Switch to faster backend
config.attention_backend = "fa3"