Tencent's WeDLM paper reports 3x speedup over vLLM using diffusion-style parallel decoding—multiple tokens generated per forward pass while maintaining KV cache compatibility. This post documents an independent port of WeDLM to Apple's MLX framework, covering the architecture decisions, performance optimizations, and implementation details from a week of development on an M4 Max.
The goal was to take the WeDLM research prototype and implement an efficient, production-style decoder on Apple's MLX stack for MacBook-class hardware. The work required re-implementing Qwen3-8B in MLX, adapting the diffusion-style decoding algorithm, and profiling MLX's execution model on Apple Silicon.
This post is structured chronologically, following the actual development process from December 29, 2025 to January 3, 2026. Each section corresponds to a specific problem encountered and the technical solution developed.
This shows the AR version of the model (Qwen-3B-IT) on the left and the MLX-optimized WeDLM-8B-IT on the right. Each model was run for three consecutive attempts. AR stayed steady at ~50 tok/s across all runs. WeDLM-MLX started slow on the first run at 26 tok/s, then sped up to 90-100 tok/s on subsequent runs as the KV cache warmed up: paged attention is not yet implemented in MLX, so we rely on JIT compilation instead. Paged attention and Metal flash attention will be a challenge for another time.
Part I: Understanding WeDLM
The Autoregressive Bottleneck
Before diving into implementation, it's worth understanding why WeDLM exists and what problem it solves.
Standard LLMs are autoregressive (AR): they generate one token at a time, with each token conditioned on all previous tokens. The joint probability of a sequence factorizes as:

$$p(x_1, \ldots, x_n) = \prod_{t=1}^{n} p(x_t \mid x_{<t})$$
This sequential dependency has a critical performance implication. For an 8B parameter model generating 100 tokens:
- 100 forward passes through 8B parameters
- Each pass loads ~16GB of weights from memory
- Total memory traffic: 100 × 16GB = 1.6TB
Modern GPUs have memory bandwidths around 1-2 TB/s. This means generating 100 tokens at 8B parameters takes roughly 1-2 seconds just for memory access—the actual compute is almost free. This is what we mean by memory-bound inference.
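To make that concrete, here is the back-of-envelope arithmetic as a few lines of Python (the bandwidth figure is an illustrative assumption, not a measurement):

# Rough memory-bound estimate for single-sequence decoding
params = 8e9
bytes_per_param = 2                       # BF16
weight_bytes = params * bytes_per_param   # ~16 GB streamed per forward pass

tokens = 100
bandwidth = 1.0e12                        # assume ~1 TB/s of memory bandwidth

total_traffic = tokens * weight_bytes     # 1.6 TB of weight reads
print(f"~{total_traffic / bandwidth:.1f} s just to stream the weights")  # ~1.6 s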
The obvious solution: generate multiple tokens per forward pass. If we can predict 4 tokens at once, we reduce memory bandwidth by 4x. This is exactly what diffusion language models attempt.
Why Traditional Diffusion Models Fail
Diffusion Language Models (DLLMs) like LLaDA and Dream work by predicting all masked positions simultaneously. The model sees:
Input: [The] [MASK] [MASK] [on] [MASK] [mat]
Output: [The] [cat] [sat] [on] [the] [mat]

The problem is bidirectional attention. To predict position 2 ("sat"), the model needs to see position 4 ("the"). To predict position 4, it needs to see position 2. Every token attends to every other token.
This completely breaks KV caching.
Why KV Caching Matters
In AR generation, we cache the key-value projections from previous tokens:
Step 1: Process [The] → Cache K₀, V₀
Step 2: Process [cat], attend to K₀,V₀ → Cache K₁, V₁
Step 3: Process [sat], attend to K₀,V₀,K₁,V₁ → Cache K₂, V₂

The cached K/V for token 0 never changes—it only depends on token 0's input. This is what makes AR inference tractable.
With bidirectional attention, K₀ depends on all tokens, including ones we haven't generated yet. If we change token 5, K₀ changes. Nothing can be cached until the entire sequence is finalized. We'd need to recompute the full attention matrix at every refinement step—exactly the O(n²) cost we were trying to avoid.
WeDLM's Insight: Topological Reordering
WeDLM's key insight is that mask recovery doesn't require bidirectional attention. Each masked position needs to see all observed (unmasked) tokens, but it doesn't need to see other masked positions bidirectionally.
The solution is Topological Reordering: physically move observed tokens to the prefix while preserving their logical positions through RoPE embeddings.
Original sequence: [The] [cat] [MASK] [on] [MASK] [mat]
Positions: 0 1 2 3 4 5
After reordering: [The] [cat] [on] [mat] [MASK] [MASK]
Physical index: 0 1 2 3 4 5
Logical position: 0 1 3 5 2 4 ← Preserved via RoPE!

The critical detail: each token's logical position is preserved through RoPE (Rotary Position Embeddings). RoPE applies position-dependent rotations to queries and keys:

$$\tilde{q}_m = R(m)\, q, \qquad \tilde{k}_n = R(n)\, k$$

The attention score between a query at position $m$ and a key at position $n$ depends only on $m - n$—the relative position. By supplying the logical position to RoPE regardless of physical position, we preserve the correct relative position relationships.
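A tiny NumPy sketch (mine, not from the WeDLM code) of this relative-position property for a single rotation pair:

import numpy as np

def rope_rotate(vec, pos, freq=1.0):
    """Rotate one 2-D query/key pair by the RoPE angle pos * freq."""
    c, s = np.cos(pos * freq), np.sin(pos * freq)
    x, y = vec
    return np.array([c * x - s * y, s * x + c * y])

q, k = np.array([1.0, 0.2]), np.array([0.5, -0.3])

score_a = rope_rotate(q, 7) @ rope_rotate(k, 4)      # positions (7, 4): relative offset 3
score_b = rope_rotate(q, 103) @ rope_rotate(k, 100)  # positions (103, 100): same offset 3
assert np.isclose(score_a, score_b)                  # same relative position, same score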
Under standard causal attention (physical index $i$ attends to indices $\le i$), position 4 (physical index 4, logical position 2) can attend to:
- Physical indices 0-3
- Which contain ALL observed tokens (logical positions 0, 1, 3, 5)
The mask tokens see the full observed context. Causal attention is preserved. KV caching works.
The Prefix Cacheability Metric
The paper introduces a useful metric: prefix cacheability $\rho$:

$$\rho = \frac{N_{\text{committed}}}{N_{\text{processed}}}$$

Where:

- $N_{\text{committed}}$ = tokens that become part of the final output
- $N_{\text{processed}}$ = total tokens processed across all forward passes (including recomputation)

If $\rho = 1$, every processed token becomes output—perfect efficiency (standard AR). If $\rho = 0.5$, half the tokens are recomputed—50% wasted work.

WeDLM achieves high $\rho$ by committing tokens left-to-right into a growing prefix, ensuring most predictions become immediately cache-valid.
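As a throwaway helper (the names are mine, not the paper's), the metric is just a ratio:

def prefix_cacheability(n_committed: int, n_processed: int) -> float:
    """Fraction of processed tokens that end up in the final output."""
    return n_committed / n_processed

print(prefix_cacheability(100, 100))  # 1.0 -> pure AR, no wasted work
print(prefix_cacheability(100, 200))  # 0.5 -> half of all forward-pass work is redone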
Part II: The MLX Port
Day 1: Project Setup (December 29, 2025)
The first task was establishing the model architecture. WeDLM-8B is based on Qwen3-8B, a decoder-only transformer with these specifications:
| Parameter | Value | Notes |
|---|---|---|
| Hidden size | 4096 | Dimension of hidden states |
| Num layers | 36 | Transformer blocks |
| Num attention heads | 32 | Query heads |
| Num KV heads | 8 | Key/Value heads (GQA) |
| Head dimension | 128 | 4096 / 32 |
| Intermediate size | 12288 | FFN hidden dimension |
| Vocab size | 151,936 | Qwen's BPE vocabulary |
| RoPE θ | 1,000,000 | Base frequency for RoPE |
| RMS norm ε | 1e-6 | LayerNorm epsilon |
| QK norm | True | RMSNorm on Q and K |
Several architectural choices are notable:
Grouped Query Attention (GQA): 8 KV heads shared across 32 query heads—a 4x reduction in KV cache memory. This is critical for inference efficiency.
QK Normalization: Qwen3 applies RMSNorm to queries and keys before attention. This isn't mentioned in most transformer tutorials but is essential for training stability at scale.
High RoPE θ: The 1M base frequency (vs. 10K in original RoPE) enables longer context windows through slower position frequency decay.
The port was structured as:
mlx_wedlm/
├── model.py # WeDLMModel, KVCache variants, Attention, MLP
├── decoder.py # WeDLMDecoder, sampling, topological reordering
├── generate.py # High-level generation API, streaming
└── __init__.py

Initial Model Definition
The core model follows standard transformer structure but with MLX-specific optimizations:
class WeDLMModel(nn.Module):
"""WeDLM model optimized for MLX with fast operations."""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [TransformerBlock(config, i) for i in range(config.num_hidden_layers)]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.tie_word_embeddings:
self.lm_head.weight = self.embed_tokens.weight
def __call__(
self,
input_ids: mx.array,
offset: int = 0,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
# Embedding
x = self.embed_tokens(input_ids)
# Forward through layers
for layer in self.layers:
x = layer(x, offset, mask, cache)
# Final norm and LM head
x = self.norm(x)
return self.lm_head(x)

The Attention Module
The attention implementation required careful consideration of MLX's API. Unlike PyTorch where you might manually implement attention, MLX provides fast.scaled_dot_product_attention with native GQA support:
class Attention(nn.Module):
"""Optimized attention using mx.fast.scaled_dot_product_attention."""
def __init__(self, config: ModelConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.scale = self.head_dim ** -0.5
self.rope_theta = config.rope_theta
# Projections - no bias per Qwen3 architecture
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
# QK normalization - critical for Qwen3!
self.qk_norm = config.qk_norm
if self.qk_norm:
self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
offset: int = 0,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
layer_idx: int = 0,
) -> mx.array:
B, L, _ = x.shape
# Project to Q, K, V
q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim)
k = self.k_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim)
v = self.v_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim)
# Apply QK norm BEFORE RoPE - this order matters!
if self.qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to (B, num_heads, L, head_dim) for RoPE and SDPA
q = q.transpose(0, 2, 1, 3)
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Apply RoPE using fast.rope
q = fast.rope(q, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=offset)
k = fast.rope(k, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=offset)
# Update KV cache if provided
if cache is not None:
k, v = cache.update(layer_idx, k, v)
# SDPA with native GQA support - no need to repeat K/V!
output = fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
# Reshape back and project
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)

A crucial detail: MLX's fast.scaled_dot_product_attention handles GQA automatically. In PyTorch, you'd need to manually repeat K/V heads to match Q heads. MLX does this internally, saving memory and compute.
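A minimal shape check of that behavior (a sketch under the assumption that the GQA broadcast works as described; shapes are illustrative):

import mlx.core as mx

B, L, D = 1, 16, 128
q = mx.random.normal((B, 32, L, D))   # 32 query heads
k = mx.random.normal((B, 8, L, D))    # 8 KV heads (GQA), no manual repeat
v = mx.random.normal((B, 8, L, D))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5, mask=None)
mx.eval(out)
print(out.shape)                      # (1, 32, 16, 128)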
The MLP Block
The MLP uses SwiGLU activation, which has become standard in modern LLMs:
class MLP(nn.Module):
"""MLP with SwiGLU activation."""
def __init__(self, config: ModelConfig):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
# SwiGLU: down(silu(gate(x)) * up(x))
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

SwiGLU replaces the traditional ReLU-based MLP with a gated mechanism. The gate_proj and up_proj are both (4096 → 12288), and down_proj is (12288 → 4096). The element-wise multiplication creates an implicit gating mechanism.
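A quick parameter count shows why these projections matter: the three FFN matrices hold most of the model's weights (numbers taken from the table above).

hidden, intermediate, layers = 4096, 12288, 36

ffn_params_per_layer = 3 * hidden * intermediate   # gate + up + down
print(ffn_params_per_layer / 1e6)                  # ~151M parameters per layer
print(layers * ffn_params_per_layer / 1e9)         # ~5.4B across 36 layers, the bulk of the 8B model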
Day 1 (continued): KV Cache Design
The KV cache implementation is where WeDLM diverges most from standard AR inference. Three variants handle different use cases:
Variant 1: Simple Concatenation Cache
The simplest approach—just concatenate new K/V to existing cache:
class KVCache:
"""Simple KV Cache with concatenation."""
def __init__(self, num_layers: int):
self.num_layers = num_layers
self.keys: List[Optional[mx.array]] = [None] * num_layers
self.values: List[Optional[mx.array]] = [None] * num_layers
self._offset = 0
def update(self, layer_idx: int, key: mx.array, value: mx.array) -> Tuple[mx.array, mx.array]:
"""Update cache for a layer and return full K, V."""
if self.keys[layer_idx] is None:
self.keys[layer_idx] = key
self.values[layer_idx] = value
else:
self.keys[layer_idx] = mx.concatenate([self.keys[layer_idx], key], axis=2)
self.values[layer_idx] = mx.concatenate([self.values[layer_idx], value], axis=2)
return self.keys[layer_idx], self.values[layer_idx]
def get_seq_len(self) -> int:
"""Get current cached sequence length."""
if self.keys[0] is None:
return 0
return self.keys[0].shape[2]

This is theoretically inefficient—concatenation copies the entire array. But MLX's lazy evaluation means the copy is often fused with downstream operations. Benchmarking against pre-allocated buffers showed <5% difference for typical generation lengths (<2K tokens).
The cache stores tensors of shape (B, num_kv_heads, seq_len, head_dim). For WeDLM-8B with batch size 1:
- Per layer: 1 × 8 × seq_len × 128 × 2 bytes = 2,048 × seq_len bytes (BF16)
- Total (36 layers, K+V): 36 × 2 × 2,048 × seq_len = 147,456 × seq_len bytes
At 8K sequence length: 147,456 × 8192 ≈ 1.2 GB just for KV cache.
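The same arithmetic as a small helper, with the WeDLM-8B shapes as defaults:

def kv_cache_bytes(seq_len: int, num_layers: int = 36, num_kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_elem: int = 2, batch_size: int = 1) -> int:
    """Footprint of the K and V tensors, shape (B, num_kv_heads, seq_len, head_dim) per layer."""
    per_layer = 2 * batch_size * num_kv_heads * seq_len * head_dim * bytes_per_elem
    return num_layers * per_layer

print(kv_cache_bytes(8192) / 1e9)   # ~1.21 GB at 8K context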
Variant 2: Rotating Cache
For sequences exceeding memory limits, I implemented a rotating cache that keeps the first k tokens (typically system prompt) and rotates the rest:
class RotatingKVCache:
"""Rotating KV Cache with fixed-size circular buffer."""
def __init__(self, num_layers: int, max_size: int = 4096, keep: int = 4):
self.max_size = max_size
self.keep = keep # Always keep first 'keep' tokens (system prompt)
self.keys: List[Optional[mx.array]] = [None] * num_layers
self.values: List[Optional[mx.array]] = [None] * num_layers
def update(self, layer_idx: int, key: mx.array, value: mx.array):
prev_k = self.keys[layer_idx]
n_new = key.shape[2]
if prev_k is None:
self.keys[layer_idx] = key
self.values[layer_idx] = value
else:
prev_len = prev_k.shape[2]
new_len = prev_len + n_new
if new_len <= self.max_size:
# Under max size - just concatenate
self.keys[layer_idx] = mx.concatenate([prev_k, key], axis=2)
self.values[layer_idx] = mx.concatenate([self.values[layer_idx], value], axis=2)
else:
# Over max size - keep first tokens, trim oldest from middle
trim = min(n_new, prev_len - self.keep)
kept_k = prev_k[:, :, :self.keep, :]
rest_k = prev_k[:, :, self.keep + trim:, :]
self.keys[layer_idx] = mx.concatenate([kept_k, rest_k, key], axis=2)
# Same for values...

The "keep first N" strategy ensures system prompts and initial context are preserved. This is important for chat applications where the system prompt defines behavior.
Variant 3: Quantized Cache
For extreme memory constraints, 8-bit KV cache quantization:
class QuantizedKVCache:
"""KV Cache with 8-bit quantization."""
def __init__(self, num_layers: int, bits: int = 8, group_size: int = 64):
self.bits = bits
self.group_size = group_size
self.keys: List[Optional[mx.array]] = [None] * num_layers
self.values: List[Optional[mx.array]] = [None] * num_layers
self.key_scales: List[Optional[mx.array]] = [None] * num_layers
self.value_scales: List[Optional[mx.array]] = [None] * num_layers
def _quantize(self, x: mx.array) -> Tuple[mx.array, mx.array]:
"""Per-group absmax quantization."""
shape = x.shape
# Reshape to groups
x_flat = x.reshape(*shape[:-1], -1, self.group_size)
# Compute scale per group
x_max = mx.max(mx.abs(x_flat), axis=-1, keepdims=True)
scale = x_max / 127.0 # For int8 range [-127, 127]
scale = mx.where(scale == 0, mx.ones_like(scale), scale) # Avoid div by zero
# Quantize
x_q = mx.round(x_flat / scale).astype(mx.int8)
return x_q, scale.squeeze(-1)
def _dequantize(self, x_q: mx.array, scale: mx.array) -> mx.array:
"""Dequantize back to float."""
x = x_q.astype(mx.bfloat16) * mx.expand_dims(scale, axis=-1)
return x.reshape(*x.shape[:-2], -1)
def update(self, layer_idx: int, key: mx.array, value: mx.array):
# Quantize new K/V
k_q, k_s = self._quantize(key)
v_q, v_s = self._quantize(value)
# Concatenate quantized values
if self.keys[layer_idx] is None:
self.keys[layer_idx] = k_q
self.key_scales[layer_idx] = k_s
# ... same for values
else:
self.keys[layer_idx] = mx.concatenate([self.keys[layer_idx], k_q], axis=2)
self.key_scales[layer_idx] = mx.concatenate([self.key_scales[layer_idx], k_s], axis=2)
# ... same for values
# Return dequantized for attention (attention needs float)
k_full = self._dequantize(self.keys[layer_idx], self.key_scales[layer_idx])
v_full = self._dequantize(self.values[layer_idx], self.value_scales[layer_idx])
return k_full, v_full

Trade-off: 2x memory savings (int8 vs BF16), ~3% quality degradation on GSM8K. The quantization error accumulates over long sequences, so this is best for shorter generations.
Part III: The RoPE Position Challenge
Day 2: Non-Sequential Positions (December 30, 2025)
This was the first major technical hurdle in translating the paper's "topological reordering" idea into an efficient RoPE implementation on MLX.
The Standard RoPE Assumption
RoPE is typically applied with a simple offset. Token 0 gets position 0, token 1 gets position 1, etc. MLX's fast.rope API reflects this:
# Standard RoPE application - sequential positions
q = fast.rope(q, head_dim, offset=cache_offset) # Positions: offset, offset+1, offset+2, ...
k = fast.rope(k, head_dim, offset=cache_offset)

The offset parameter says "start counting from this position." Every subsequent position in the sequence is offset + 0, offset + 1, offset + 2, etc.
WeDLM's topological reordering breaks this assumption completely. After reordering:
Physical index: 0 1 2 3 4 5
Token: [The][cat][on] [mat][M] [M]
Logical position: 0 1 3 5 2 4 ← NOT sequential!

The mask at physical index 4 needs position 2. The mask at physical index 5 needs position 4. There's no single offset that produces [2, 4].
First Attempt: Python Loop (The Wrong Way)
The obvious initial approach—loop over positions:
def apply_rope_positions_slow(x, positions, head_dim, base):
"""WRONG: Per-token loop destroys performance."""
B, H, L, D = x.shape
result = mx.zeros_like(x)
for i, pos in enumerate(positions):
# Apply RoPE to each token individually
token = x[:, :, i:i+1, :] # Shape: (B, H, 1, D)
rotated = fast.rope(token, head_dim, offset=int(pos), base=base)
result[:, :, i, :] = rotated[:, :, 0, :]
return result

Performance: 8 tokens/second.
To understand why this is catastrophic, you need to understand MLX's execution model. Each fast.rope call adds a node to MLX's computation graph. Python loops mean Python interpreter overhead between each node. When MLX evaluates the graph, each small operation might become a separate GPU kernel launch.
The overhead isn't the Python loop per se—it's that we're preventing MLX from fusing operations. Instead of one large RoPE kernel processing 16 positions, we have 16 tiny kernels with GPU launch overhead between each.
The Insight: Reshape to Fake Batching
The fast.rope API allows offset to be a 1D array with one offset per batch element, which enables per-position RoPE by reshaping the sequence dimension into the batch dimension.
# From MLX docs:
# offset: The position offset. Can be a scalar or a 1D array of shape (B,)

If offset can be per-batch, and I want per-position offsets... what if each position becomes its own batch element?
def apply_rope_positions(x: mx.array, positions: mx.array,
head_dim: int, base: float = 10000.0) -> mx.array:
"""Apply RoPE with explicit position indices - THE FAST WAY."""
B, H, L, D = x.shape
# Original shape: (B, H, L, D) where:
# B = batch size (usually 1)
# H = num heads
# L = sequence length (positions to process)
# D = head dimension
# Step 1: Reshape so each position is its own "batch"
# (B, H, L, D) → (B, L, H, D) → (B*L, H, 1, D)
x_reshaped = x.transpose(0, 2, 1, 3) # (B, L, H, D)
x_reshaped = x_reshaped.reshape(B * L, H, 1, D) # (B*L, H, 1, D)
# Step 2: Expand positions for the fake batch dimension
# positions: (L,) → (B*L,) by tiling for each original batch
if B > 1:
positions = mx.tile(positions, (B,)) # Now shape (B*L,)
# Step 3: Apply RoPE with per-"batch" offsets
# Each of the B*L "batch" elements gets its own offset from positions
# The sequence dimension is 1, so fast.rope applies that offset to the single position
result = fast.rope(
x_reshaped, # (B*L, H, 1, D)
head_dim,
traditional=False,
base=base,
scale=1.0,
offset=positions # (B*L,) - one offset per "batch" element
)
# Step 4: Reshape back to original layout
# (B*L, H, 1, D) → (B, L, H, D) → (B, H, L, D)
result = result.reshape(B, L, H, D) # (B, L, H, D)
result = result.transpose(0, 2, 1, 3) # (B, H, L, D)
return result

The trick: by reshaping so each position is a separate "batch" element with sequence length 1, the per-batch offset becomes a per-position offset.
Tracing through a concrete example:
Input: x of shape (1, 32, 4, 128) # batch=1, 32 heads, 4 positions, dim=128
Positions: [0, 1, 3, 5] # Non-sequential logical positions
Step 1 - Transpose: (1, 4, 32, 128)
Step 2 - Reshape: (4, 32, 1, 128) # 4 "batches", 32 heads, 1 position each
Step 3 - Apply RoPE with offset=[0, 1, 3, 5]
- "Batch" 0 gets position 0
- "Batch" 1 gets position 1
- "Batch" 2 gets position 3
- "Batch" 3 gets position 5
Step 4 - Reshape back: (1, 32, 4, 128)

Performance: 35 tokens/second. A 4.4x improvement from changing the reshape pattern.
Why This Works
MLX's fast.rope is implemented as a single Metal kernel that processes all batch elements in parallel. By reshaping our positions into the batch dimension, we're telling MLX "run this kernel with B*L parallel work items, each with its own offset."
The alternative (Python loop) says "run L separate kernels sequentially." Even if each kernel is fast, the sequential launches and Python overhead kill performance.
This pattern—reshaping to exploit batch parallelism—is a recurring theme in MLX optimization.
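For reference, a usage sketch of the apply_rope_positions helper above, fed the non-sequential positions from the reordering example (shapes and values are illustrative):

import mlx.core as mx

B, H, L, D = 1, 32, 4, 128
q = mx.random.normal((B, H, L, D))
logical_positions = mx.array([0, 1, 3, 5])    # non-sequential, from topological reordering

q_rot = apply_rope_positions(q, logical_positions, head_dim=D, base=1_000_000.0)
mx.eval(q_rot)
print(q_rot.shape)                            # (1, 32, 4, 128): same layout, per-token positions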
Part IV: The Synchronization Trap
Understanding MLX's Lazy Evaluation
MLX's lazy evaluation is the single most impactful optimization target in this project.
How MLX Executes Operations
MLX uses lazy evaluation: operations don't execute immediately. Instead, they build a computation graph that's optimized and executed when you actually need the values.
import mlx.core as mx
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
c = mx.add(a, b) # No computation - graph node created
d = mx.multiply(c, 2) # No computation - another graph node
e = mx.sum(d) # No computation - another graph node
print("Graph built, nothing computed yet")
mx.eval(e) # NOW computation happens
# MLX sees: sum(multiply(add(a, b), 2))
# Optimizes: potentially fuses operations, optimizes memory layout
# Executes: single optimized GPU dispatch

This is powerful because MLX can:
- Fuse operations: Combine add + multiply into a single kernel
- Eliminate intermediates: c and d might never be materialized in memory
- Optimize memory access: Read inputs once, write output once
The Trap: Implicit Synchronization
The problem: several Python operations require actual values, forcing MLX to synchronize:
# Each of these forces synchronization:
x.item() # Returns Python scalar
x.tolist() # Returns Python list
len(x) # Needs to know actual size
if x > 0: # Boolean requires evaluation
for i in x: # Iteration requires values

An example problematic decode loop:
# ANTI-PATTERN: 32 synchronization points per decode step
def decode_step_slow(logits, window_size, threshold):
confirmed = []
for i in range(window_size): # window_size = 16
logits_i = logits[0, i]
# Sync #1: .item() forces evaluation
token_id = mx.argmax(logits_i).item()
# Sync #2: Another .item()
probs = mx.softmax(logits_i, axis=-1)
entropy = -mx.sum(probs * mx.log(probs + 1e-10)).item()
if entropy < threshold:
confirmed.append((i, token_id))
return confirmed

Each .item() call:
- Flushes the pending computation graph
- Dispatches work to the GPU
- Waits for GPU to complete
- Transfers scalar result to CPU
With 16 window positions × 2 .item() calls = 32 GPU round-trips per decode step.
If each round-trip takes even 0.5ms (and they often take more), that's 16ms overhead per step—completely dwarfing the actual compute time.
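That estimate in a few lines of arithmetic (the 0.5 ms round-trip figure is the same assumption as above):

window_size = 16
syncs_per_step = window_size * 2        # one .item() for the token, one for the entropy
round_trip_ms = 0.5                     # optimistic GPU round-trip latency

print(syncs_per_step * round_trip_ms)   # 16.0 ms of pure synchronization overhead per step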
The Fix: Batched Operations with Single Sync
The solution is to batch all operations and synchronize once:
def sample_tokens_batch(logits: mx.array, temperature: float = 0.0) -> Tuple[mx.array, mx.array]:
"""Sample all tokens in batch - single sync at the end."""
# logits shape: (window_size, vocab_size)
# All operations stay on GPU, building computation graph
if temperature == 0:
# Greedy: argmax over vocabulary dimension
token_ids = mx.argmax(logits, axis=-1) # Shape: (W,)
else:
# Sampling: scale logits by temperature, sample categorically
scaled = logits / temperature
token_ids = mx.random.categorical(scaled, axis=-1) # Shape: (W,)
# Compute all entropies at once
probs = mx.softmax(logits, axis=-1) # Shape: (W, V)
log_probs = mx.log(probs + 1e-10) # Shape: (W, V)
entropies = -mx.sum(probs * log_probs, axis=-1) # Shape: (W,)
return token_ids, entropies
# Usage in decode loop:
def decode_step_fast(logits, window_size, threshold):
# Get mask position logits
mask_logits = logits[0] # Shape: (window_size, vocab_size)
# Batch sample - no sync yet
token_ids, entropies = sample_tokens_batch(mask_logits, temperature=0.0)
# SINGLE sync point - entire graph is fused and executed
mx.eval(token_ids, entropies)
# NOW convert to Python (arrays are materialized, this is just memory copy)
token_ids_list = token_ids.tolist() # Fast - data already on CPU
entropies_list = entropies.tolist() # Fast - data already on CPU
# Python logic with actual values
confirmed = []
for i in range(window_size):
if entropies_list[i] < threshold:
confirmed.append((i, token_ids_list[i]))
return confirmed

Performance impact: 20 tok/s → 35 tok/s (75% improvement).
The key insight: MLX's graph optimization can fuse the softmax, log, multiply, and sum into a single efficient kernel. By synchronizing once instead of 32 times, we let MLX do what it's designed to do.
The General Rule
MLX Performance Principles:
1. Build graphs freely - it's just Python object creation
2. Minimize mx.eval() calls - each one dispatches work and waits
3. NEVER use .item() in loops - it's an implicit mx.eval()
4. Batch operations over dimensions instead of looping
5. .tolist() after eval is fine - data is already materialized
Sync points kill performance. Batch everything.

A Debugging Pattern
When debugging MLX performance issues, instrument sync points:
import time
def timed_eval(*arrays):
"""Evaluate arrays and print timing."""
start = time.perf_counter()
mx.eval(*arrays)
elapsed = (time.perf_counter() - start) * 1000
print(f"eval took {elapsed:.2f}ms for {len(arrays)} arrays")
return arrays
# Use in code:
token_ids, entropies = sample_tokens_batch(logits)
timed_eval(token_ids, entropies)

If you see many eval took Xms prints per generation step, you have too many sync points.
Part V: The Diffusion Decoding Algorithm
Day 2 (continued): Implementing Streaming Parallel Decoding
With the core model and RoPE working, the next challenge was implementing WeDLM's actual decoding algorithm. This is where the paper's ideas become concrete code.
The Decoding State Machine
WeDLM maintains a sliding window of mask tokens. At each step:
- Forward pass on the window (with topological reordering)
- Sample tokens for each mask position
- Compute entropy (confidence) for each position
- Commit confident positions to the output
- Slide the window, adding new masks
The state definition:
@dataclass
class DecoderState:
"""State for diffusion decoding."""
# Confirmed tokens (after prompt, before window)
confirmed_tokens: List[int]
# Current window tokens (being decoded)
window_tokens: List[int]
# Which positions in window are still masks
window_mask_flags: List[bool]
# Total generated count
num_generated: int
# Finished flag
is_finished: bool
@dataclass
class SamplingParams:
temperature: float = 0.0 # 0 = greedy
top_p: float = 1.0 # Nucleus sampling threshold
max_tokens: int = 512 # Maximum tokens to generate
entropy_threshold: float = 0.3 # Below this = "confident"
pos_penalty_factor: float = 0.01 # Penalty for later positions

Window Visualization
A concrete example illustrates the algorithm. Starting with prompt "Hello" and window size 4:
Initial State:
Cache: [Hello] (prefilled)
Window: [MASK][MASK][MASK][MASK]
Flags: [True][True][True][True]
Step 1 - Forward pass on window:
Input to model: [MASK][MASK][MASK][MASK] with positions [5, 6, 7, 8]
Logits: predictions for all 4 positions
Sampling results:
Position 0: token="world", entropy=0.1 ← Confident (< 0.3)
Position 1: token=",", entropy=0.2 ← Confident
Position 2: token="how", entropy=0.8 ← Not confident
Position 3: token="are", entropy=0.9 ← Not confident
Step 1 - Update state:
Window: [world][,][MASK][MASK]
Flags: [False][False][True][True]
Step 1 - Commit leading confirmed:
Positions 0 and 1 are consecutive from start → commit both
confirmed_tokens: [world, ,]
Cache: [Hello, world, ,]
Slide window:
Window: [MASK][MASK][MASK][MASK] ← Remaining masks + new masks
Flags: [True][True][True][True]
Step 2 - Forward pass:
Window still has masks at logical positions [7, 8, 9, 10]
But wait - positions 7 and 8 were predicted as "how" and "are"
Actually, the window keeps the non-committed predictions:
Window: [how][are][MASK][MASK] ← Previous predictions + new masks
Flags: [True][True][True][True] ← All treated as masks for sampling
Topological reorder for attention:
Observed: none in window (all masks)
Input: [how][are][MASK][MASK] with positions [7, 8, 9, 10]
New sampling might produce:
Position 0: token="how", entropy=0.15 ← Now confident!
Position 1: token="you", entropy=0.1 ← Changed and confident
Position 2: token="?", entropy=0.2 ← Confident
Position 3: token="I", entropy=0.6 ← Not confident
Window: [how][you][?][I]
Flags: [False][False][False][True]
Commit leading: "how", "you", "?"
confirmed_tokens: [world, ,, how, you, ?]

This illustrates two key WeDLM behaviors:
- Iterative refinement: "are" became "you" on the second pass because the model had more context
- Left-to-right commitment: Only leading consecutive confirmed tokens are committed, preserving cache validity
The Core Decode Step
Here's the actual implementation:
def decode_step(
self,
state: DecoderState,
params: SamplingParams,
cache: KVCache,
prompt_len: int,
) -> Tuple[DecoderState, List[int]]:
"""Execute one decoding step with topological reordering."""
if state.is_finished:
return state, []
# Separate filled and mask positions in current window
non_mask_indices = [i for i, flag in enumerate(state.window_mask_flags) if not flag]
mask_indices = [i for i, flag in enumerate(state.window_mask_flags) if flag]
cache_len = cache.get_seq_len() # Prompt + committed tokens
# Two code paths depending on whether we have filled positions
if len(non_mask_indices) == 0:
# Fast path: all positions are masks, no reordering needed
window_ids = mx.array([state.window_tokens])
logits = self.model.decode_window(window_ids, cache_len, cache)
mask_logits = logits[0] # Shape: (window_size, vocab_size)
else:
# Topological reordering: filled positions first, masks last
order = non_mask_indices + mask_indices
ordered_tokens = [state.window_tokens[i] for i in order]
ordered_positions = [cache_len + i for i in order] # Logical positions!
window_ids = mx.array([ordered_tokens])
positions = mx.array(ordered_positions)
# Forward with per-token positions (uses apply_rope_positions)
logits = self.model.decode_window_reordered(window_ids, positions, cache)
# Extract logits for mask positions only
num_non_mask = len(non_mask_indices)
mask_logits = logits[0, num_non_mask:] # Only predict masks
# Batched sampling - single sync point
token_ids, entropies = self.sample_tokens_batch(
mask_logits, params.temperature, params.top_p
)
mx.eval(token_ids, entropies)
# Convert to Python for threshold logic
token_ids_list = token_ids.tolist()
entropies_list = entropies.tolist()
# Apply position penalty to encourage left-to-right commitment
adjusted_entropies = []
for k, orig_idx in enumerate(mask_indices):
token_id = token_ids_list[k]
entropy = entropies_list[k]
# Penalty increases with position in window
adjusted = entropy + orig_idx * params.pos_penalty_factor
adjusted_entropies.append((orig_idx, adjusted, token_id))
# Update window with confident predictions
new_window_tokens = list(state.window_tokens)
new_mask_flags = list(state.window_mask_flags)
positions_to_fill = []
for orig_idx, adj_entropy, token_id in adjusted_entropies:
if adj_entropy < params.entropy_threshold:
positions_to_fill.append((orig_idx, token_id))
# Always fill at least one position (to guarantee progress)
if not positions_to_fill and adjusted_entropies:
best = min(adjusted_entropies, key=lambda x: x[1])
positions_to_fill = [(best[0], best[2])]
for pos, token_id in positions_to_fill:
new_window_tokens[pos] = token_id
new_mask_flags[pos] = False
# Count leading confirmed positions (consecutive non-masks from start)
num_front_confirmed = 0
for i in range(len(new_window_tokens)):
if not new_mask_flags[i]:
num_front_confirmed += 1
else:
break
# Commit leading confirmed tokens
new_confirmed = list(state.confirmed_tokens)
committed_tokens = []
if num_front_confirmed > 0:
committed_tokens = new_window_tokens[:num_front_confirmed]
new_confirmed.extend(committed_tokens)
# Update cache with committed tokens' K/V
self.model.extend_cache(mx.array([committed_tokens]), cache)
# Slide window: remove committed, add new masks
new_window_tokens = new_window_tokens[num_front_confirmed:] + \
[self.mask_token_id] * num_front_confirmed
new_mask_flags = new_mask_flags[num_front_confirmed:] + \
[True] * num_front_confirmed
# Check termination
new_num_generated = state.num_generated + len(committed_tokens)
is_finished = new_num_generated >= params.max_tokens
new_state = DecoderState(
confirmed_tokens=new_confirmed,
window_tokens=new_window_tokens,
window_mask_flags=new_mask_flags,
num_generated=new_num_generated,
is_finished=is_finished,
)
return new_state, committed_tokens

The Position Penalty Explained

WeDLM uses distance-penalized selection to encourage left-to-right commitment:

$$\tilde{H}_i = H_i + \lambda \cdot i$$

where $H_i$ is the prediction entropy at window position $i$ and $\lambda$ is the position penalty factor (default 0.01).
Why does this matter? Consider a window where:
- Position 0: entropy 0.25
- Position 2: entropy 0.20
- Position 5: entropy 0.15
Without penalty, position 5 would be most confident. But committing position 5 before 0, 1, 2, 3, 4 means those positions' K/V can't be cached yet—they depend on the uncommitted prefix.
With penalty ($\lambda = 0.01$):
- Position 0: 0.25 + 0×0.01 = 0.250
- Position 2: 0.20 + 2×0.01 = 0.220
- Position 5: 0.15 + 5×0.01 = 0.200
Position 5 is still lowest! But now consider:
- Position 0: entropy 0.20
- Position 5: entropy 0.15
Without penalty: Position 5 wins (0.15 < 0.20)
With penalty: both adjusted entropies equal 0.20, and the tie goes to the earlier position, so Position 0 wins

The penalty biases toward earlier positions when confidence is similar, maximizing $\rho$.
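A minimal sketch of the fill-selection rule used in decode_step above: commit every position whose penalized entropy clears the threshold, and fall back to the single best position so the decoder always makes progress.

def pick_fill_positions(entropies, threshold=0.3, penalty=0.01):
    """Distance-penalized selection over a window of per-position entropies."""
    adjusted = [(i, h + penalty * i) for i, h in enumerate(entropies)]
    picked = [i for i, a in adjusted if a < threshold]
    if not picked:                       # guarantee progress: take the single best position
        picked = [min(adjusted, key=lambda t: t[1])[0]]
    return picked

print(pick_fill_positions([0.25, 0.40, 0.20, 0.50, 0.60, 0.15]))  # [0, 2, 5]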
The decode_window_reordered Method
The model needs a special method for reordered windows that applies per-token positions:
def decode_window_reordered(
self,
window_ids: mx.array,
positions: mx.array, # 1D array of logical positions
cache: KVCache,
) -> mx.array:
"""Decode with explicit per-token positions (for topological reordering)."""
B, W = window_ids.shape
cache_len = cache.get_seq_len()
# Create causal mask for window attending to cache + self
total_len = cache_len + W
mask = mx.triu(
mx.full((W, total_len), float('-inf'), dtype=mx.bfloat16),
k=cache_len + 1 # Window tokens attend to cache + prior window tokens
)
mask = mx.expand_dims(mask, axis=(0, 1))
# Embedding
x = self.embed_tokens(window_ids)
# Forward through layers with per-token RoPE
for layer in self.layers:
layer_idx = layer.layer_idx
cached_k, cached_v = cache.get_layer_kv(layer_idx)
residual = x
x_norm = layer.input_layernorm(x)
attn = layer.self_attn
q = attn.q_proj(x_norm).reshape(B, W, attn.num_heads, attn.head_dim)
k = attn.k_proj(x_norm).reshape(B, W, attn.num_kv_heads, attn.head_dim)
v = attn.v_proj(x_norm).reshape(B, W, attn.num_kv_heads, attn.head_dim)
if attn.qk_norm:
q = attn.q_norm(q)
k = attn.k_norm(k)
q = q.transpose(0, 2, 1, 3)
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# KEY DIFFERENCE: Apply RoPE with explicit positions instead of offset
q = apply_rope_positions(q, positions, attn.head_dim, attn.rope_theta)
k = apply_rope_positions(k, positions, attn.head_dim, attn.rope_theta)
# Concatenate with cache
if cached_k is not None:
k_full = mx.concatenate([cached_k, k], axis=2)
v_full = mx.concatenate([cached_v, v], axis=2)
else:
k_full, v_full = k, v
attn_out = fast.scaled_dot_product_attention(q, k_full, v_full, scale=attn.scale, mask=mask)
attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, W, -1)
x = residual + attn.o_proj(attn_out)
residual = x
x = residual + layer.mlp(layer.post_attention_layernorm(x))
return self.lm_head(self.norm(x))

The Prefill Phase
Before decode starts, we need to process the prompt and initialize the cache:
def prefill(
self,
input_ids: mx.array,
cache: Optional[KVCache] = None,
chunk_size: int = 512,
) -> Tuple[mx.array, KVCache]:
"""Process prompt and initialize KV cache."""
B, L = input_ids.shape
if cache is None:
cache = self.create_cache(batch_size=B)
# Short prompts: process all at once
if L <= chunk_size:
logits = self(input_ids, offset=0, mask="causal", cache=cache)
cache.finalize_update(L)
return logits, cache
# Long prompts: chunked prefill to limit memory
offset = 0
for start in range(0, L, chunk_size):
end = min(start + chunk_size, L)
chunk = input_ids[:, start:end]
logits = self(chunk, offset=offset, mask="causal", cache=cache)
offset = cache.get_seq_len()
mx.eval(logits) # Prevent graph from growing too large
return logits, cache

Chunked prefill is important for long prompts. Without it, the computation graph grows with prompt length, potentially exhausting memory before execution even starts.
Part VI: The Threading SIGSEGV Incident
Day 3, Morning: Building a Comparison Benchmark
January 2, 2026
After achieving 35 tok/s with WeDLM's parallel decoding, I needed to answer the obvious question: how does this compare to autoregressive decoding on the same hardware?
The plan was simple: build a benchmark tool to run both WeDLM and a standard Qwen3-8B through mlx-lm, then compare throughputs. For efficiency, running them in parallel seemed reasonable—two independent models should be able to run simultaneously.
# chat_bench.py - First (fatally flawed) version
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple
import time
# Global models (loaded once, reused)
wedlm_model = None
ar_model = None
def load_models():
"""Load both models at startup."""
global wedlm_model, ar_model
print("Loading WeDLM-8B...")
wedlm_model = WeDLMDecoder.from_pretrained("./wedlm-8b-mlx-4bit")
print("Loading Qwen3-8B baseline...")
ar_model, ar_tokenizer = load("mlx-community/Qwen3-8B-4bit")
print("Both models loaded.")
def run_comparison(prompt: str):
wedlm_result = []
qwen_result = []
def run_wedlm():
wedlm_result.append(wedlm_generate(prompt))
def run_qwen():
qwen_result.append(qwen_generate(prompt))
t1 = threading.Thread(target=run_wedlm)
t2 = threading.Thread(target=run_qwen)
t1.start()
t2.start() # SIGSEGV here
t1.join()
t2.join()Result: SIGSEGV (segmentation fault).
The failure was a process-terminating SIGSEGV rather than a Python exception, with no stack trace or diagnostic message. It reproduced consistently as soon as both threads initiated concurrent forward passes.
The Debugging Journey
Print statements helped isolate the problem:
def run_wedlm_generation(prompt: str):
print(f"[WeDLM] Starting generation for: {prompt[:30]}...")
print(f"[WeDLM] Tokenizing...")
tokens = wedlm_model.tokenizer.encode(prompt)
print(f"[WeDLM] Token count: {len(tokens)}")
print(f"[WeDLM] Starting prefill...")
# CRASH HAPPENS HERE OR IN AR'S EQUIVALENT

The crash consistently occurred when both threads tried to call mx.eval() simultaneously—specifically during their respective forward passes.
Verification that the issue wasn't application-specific came from trying the same pattern with two instances of the same mlx-lm model:
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Qwen3-8B-4bit")
def gen_thread_1():
return generate(model, tokenizer, prompt="Hello", max_tokens=10)
def gen_thread_2():
return generate(model, tokenizer, prompt="World", max_tokens=10)
with ThreadPoolExecutor(max_workers=2) as executor:
f1 = executor.submit(gen_thread_1)
f2 = executor.submit(gen_thread_2)
# SIGSEGV

Same crash. Not application-specific—something fundamental about MLX and threading.
Root Cause: Metal's Threading Model
MLX GitHub issues and Metal documentation reference this problem. Metal has a specific threading model:
Thread-safe:
MTLDevice(the GPU device handle)MTLCommandQueue(the queue for submitting work)
NOT thread-safe:
MTLCommandBuffer(individual buffers of GPU commands)MTLCommandEncoder(encoders for filling buffers)
MLX's execution model:
- Build a lazy computation graph (thread-safe—just building Python objects)
- On
mx.eval():- Compile graph to Metal commands
- Create command buffers
- Submit to queue
- Wait for completion
The problem: MLX uses shared global state for command buffer management. When two threads call mx.eval() simultaneously:
Thread 1 Thread 2
-------- --------
Enter mx.eval() Enter mx.eval()
Get shared command buffer Get shared command buffer (same one!)
Start encoding commands Start encoding commands
↓ ↓
Both writing to same buffer → RACE CONDITION → Memory corruption → SIGSEGV

The Solution: Sequential Execution
def run_comparison_safe(prompt: str):
"""Sequential execution - boring but correct."""
print("Running WeDLM...")
wedlm_result = wedlm_generate(prompt)
wedlm_tokens = count_tokens(wedlm_result)
print("Running Qwen AR...")
qwen_result = qwen_generate(prompt)
qwen_tokens = count_tokens(qwen_result)
return wedlm_result, qwen_result

It works. It's twice as slow in wall-clock time (running sequentially instead of in parallel), but benchmarking doesn't need parallelism—we just need accurate throughput measurements.
Alternative: Process-Level Parallelism
If you really need concurrent model execution, separate processes work because each gets its own Metal context:
import multiprocessing as mp
from multiprocessing import Queue
def wedlm_process(prompt: str, result_queue: Queue):
"""Run WeDLM in a separate process."""
# Load model fresh (can't share across processes)
model = WeDLMDecoder.from_pretrained("./wedlm-8b-mlx-4bit")
start = time.perf_counter()
response = model.generate(prompt, max_tokens=256)
elapsed = time.perf_counter() - start
tokens = len(model.tokenizer.encode(response.text))
result_queue.put(("wedlm", response.text, tokens / elapsed))
# Separate processes = separate Metal contexts
# run_in_process is a thin wrapper (not shown) that calls the given generate function
# and puts its result on the queue, mirroring wedlm_process above
q1, q2 = mp.Queue(), mp.Queue()
p1 = mp.Process(target=run_in_process, args=(wedlm_generate, prompt, q1))
p2 = mp.Process(target=run_in_process, args=(qwen_generate, prompt, q2))

Each process gets its own Metal context. But the memory overhead is significant—each process loads the full model. For benchmarking where we just want to measure throughputs, sequential is preferable.
GPU API Threading Model
GPU APIs are designed for explicit concurrency control, not implicit threading.
In CUDA/Metal/Vulkan, the programmer explicitly:
- Creates command queues
- Allocates command buffers
- Submits work in a controlled order
- Synchronizes across queues
High-level frameworks like MLX, PyTorch, and TensorFlow abstract this away for the common case (single-threaded Python), but the abstraction breaks with Python threading.
Investigating the crash clarified that MLX's Metal command buffer management is not thread-safe under concurrent mx.eval calls, which in turn informed later decisions to keep inference single-threaded within a process.
Part VII: Quantization — The Quality-Speed Trade-off
Day 3, Afternoon: Memory Pressure
January 2, 2026 (continued)
WeDLM-8B in BF16 requires approximately 16GB for weights alone:
Parameter count: 8.03B
Bytes per param (BF16): 2
Total weight memory: 8.03B × 2 = 16.06 GB
Plus runtime:
KV Cache (grows with sequence): ~150MB per 1K tokens
Activations (batch=1, seq=2K): ~200MB
MLX overhead: ~500MB
Total for 2K context: ~17-18 GB

An M3 Max has 48GB unified memory, so this fits comfortably. But unified memory is shared with the system and other applications. Running Chrome, VSCode, and a 17GB model simultaneously creates memory pressure—swap usage appeared at longer contexts.
4-bit quantization reduces weight memory to ~4GB, making inference much more comfortable. The question is: what does it cost in quality?
Understanding Group-Wise Quantization
MLX uses group-wise min-max quantization:
# Conceptual implementation of 4-bit group quantization
def quantize_weights(weights: mx.array, group_size: int = 64, bits: int = 4):
"""Quantize weights using group-wise min-max scaling."""
# weights shape: (out_features, in_features)
# Reshape into groups along input dimension
groups = weights.reshape(-1, group_size) # (num_groups, group_size)
# Compute per-group scale factors
max_vals = mx.max(mx.abs(groups), axis=1, keepdims=True)
# For 4-bit: representable values are -7 to +7
num_levels = 2 ** (bits - 1) - 1 # 7 for 4-bit
scales = max_vals / num_levels
# Quantize to integers
quantized = mx.round(groups / scales)
quantized = mx.clip(quantized, -num_levels, num_levels).astype(mx.int8)
return quantized, scales # Store both for dequantization
def dequantize_weights(quantized: mx.array, scales: mx.array, original_shape: tuple):
"""Dequantize during forward pass."""
# Multiply quantized values by their group's scale factor, then restore the layer's shape
dequantized = quantized.astype(mx.float16) * scales
return dequantized.reshape(original_shape)

The key insight: each group of 64 weights shares a single scale factor. This means:
- Storage: 4 bits per weight + 16 bits per 64 weights for scale = 4.25 bits/weight effective
- Reconstruction: weight_approx = quantized_int * scale
- Error: bounded by scale / 2 per weight
Group size is a trade-off:
- Smaller groups (e.g., 32): More scales stored, better precision, more memory
- Larger groups (e.g., 128): Fewer scales, worse precision, less memory
MLX defaults to group_size=64, which is a reasonable middle ground.
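The storage cost per weight as a function of group size (a sketch that counts only a 16-bit scale per group and ignores any per-group bias MLX may also store):

def effective_bits_per_weight(bits: int, group_size: int, scale_bits: int = 16) -> float:
    """Quantized value plus this weight's share of the per-group scale."""
    return bits + scale_bits / group_size

for g in (32, 64, 128):
    print(g, effective_bits_per_weight(4, g))   # 4.5, 4.25, 4.125 bits/weight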
First Attempt: Quantize Everything
import mlx.nn as nn
from mlx_wedlm.model import WeDLMModel
# Load BF16 model
model = WeDLMModel.from_pretrained("./wedlm-8b-bf16")
# Simple quantization - all linear layers
nn.quantize(model, bits=4, group_size=64)
# Save quantized model
model.save_weights("./wedlm-8b-4bit/weights.safetensors")Performance was impressive: 60 tok/s (vs 35 tok/s BF16). A 1.7x speedup just from quantization!
GSM8K (grade school math) testing revealed problems:
Prompt: "What is 15 + 27?"
BF16: "15 + 27 = 42. The answer is 42."
4-bit: "15 + 27 = 44. The answer is 44." # Wrong!
Prompt: "A baker has 23 cupcakes. She sells 8. How many are left?"
BF16: "23 - 8 = 15. The baker has 15 cupcakes left."
4-bit: "23 - 8 = 16. The baker has 16 cupcakes left." # Wrong again!Simple arithmetic errors. The model was no longer reliable for basic math.
Diagnosing the Problem
Understanding which layers were most sensitive to quantization required a layer-by-layer ablation:
def measure_layer_sensitivity(model, layer_name: str, test_prompts: List[str]):
"""Quantize a single layer and measure impact on output quality."""
# Save original weights
original_weights = get_layer_weights(model, layer_name).copy()
# Quantize just this layer
quantize_single_layer(model, layer_name, bits=4)
# Measure quality degradation
score = evaluate_on_prompts(model, test_prompts)
# Restore original weights
set_layer_weights(model, layer_name, original_weights)
return score
# Test each layer type
sensitivity_scores = {}
for layer_name in ["embed_tokens", "lm_head", "q_proj", "k_proj", "v_proj",
"o_proj", "gate_proj", "up_proj", "down_proj"]:
score = measure_layer_sensitivity(model, layer_name, math_test_prompts)
sensitivity_scores[layer_name] = score
print(f"{layer_name}: {score:.2f}%")Results were illuminating:
| Layer | Accuracy (4-bit) | Accuracy (BF16) | Degradation |
|---|---|---|---|
| embed_tokens | 75% | 80% | -5% |
| lm_head | 68% | 80% | -12% |
| q_proj (all) | 79% | 80% | -1% |
| k_proj (all) | 79% | 80% | -1% |
| v_proj (all) | 79% | 80% | -1% |
| o_proj (all) | 79% | 80% | -1% |
| gate_proj (all) | 78% | 80% | -2% |
| up_proj (all) | 78% | 80% | -2% |
| down_proj (all) | 78% | 80% | -2% |
The lm_head layer was the culprit. Quantizing it alone caused a 12% accuracy drop, while attention projections caused only 1% each.
Why lm_head Is Special
The lm_head layer projects hidden states to vocabulary logits:
hidden_state (4096-dim) → lm_head (4096 × 151936) → logits (151936-dim)
↑
This layer is HUGE: 620M parameters
And the final "decision boundary"

The size alone is concerning (620M params in one layer), but the real issue is the nature of the computation:
# At inference, we pick the argmax token
logits = lm_head(hidden_state) # 151936 logits
token_id = mx.argmax(logits).item()

The argmax is discontinuous—a tiny change in logits can completely change which token wins:
Scenario 1 (BF16):
logit[token_42] = 5.0001
logit[token_44] = 5.0000
argmax → token_42 ✓
Scenario 2 (quantized lm_head):
logit[token_42] = 4.9998 # Quantization error shifted it down
logit[token_44] = 5.0002 # Quantization error shifted it up
argmax → token_44 ✗

A 0.0003 error flipped the answer. For attention layers, quantization error averages out over the softmax. For lm_head, it directly corrupts the output distribution.
Selective Quantization Implementation
The fix: keep sensitive layers in BF16, quantize everything else:
def quantize_model_selective(
model: nn.Module,
bits: int = 4,
group_size: int = 64,
exclude_patterns: List[str] = ["lm_head", "embed_tokens"]
) -> nn.Module:
"""Quantize with selective layer exclusion."""
def should_quantize(name: str, module: nn.Module) -> bool:
"""Predicate: should this layer be quantized?"""
# Only quantize Linear layers
if not isinstance(module, nn.Linear):
return False
# Skip layers matching exclude patterns
for pattern in exclude_patterns:
if pattern in name:
print(f" Keeping {name} in BF16 (matches '{pattern}')")
return False
return True
# Count what we're quantizing (model.parameters() is a nested dict, so flatten it first)
from mlx.utils import tree_flatten
total_params = sum(p.size for _, p in tree_flatten(model.parameters()))
# Apply quantization with predicate
nn.quantize(
model,
bits=bits,
group_size=group_size,
class_predicate=should_quantize
)
# Report what was excluded
excluded_params = sum(
p.size for name, p in tree_flatten(model.parameters())
if any(pat in name for pat in exclude_patterns)
)
quantized_params = total_params - excluded_params
print(f"Quantized: {quantized_params/1e9:.2f}B params ({bits}-bit)")
print(f"Excluded: {excluded_params/1e9:.2f}B params (BF16)")
print(f"Effective bits per param: {(quantized_params * bits + excluded_params * 16) / total_params:.2f}")
return model
# Usage
model = WeDLMModel.from_pretrained("./wedlm-8b-bf16")
model = quantize_model_selective(model, bits=4, exclude_patterns=["lm_head", "embed_tokens"])
model.save_weights("./wedlm-8b-4bit-selective/weights.safetensors")Output:
Keeping embed_tokens in BF16 (matches 'embed_tokens')
Keeping lm_head in BF16 (matches 'lm_head')
Quantized: 7.03B params (4-bit)
Excluded: 1.00B params (BF16)
Effective bits per param: 5.50

Memory Analysis
Breaking down the memory savings:
Component | BF16 | 4-bit selective | Savings
-------------------+-------------+-----------------+----------
Attention projs | 2.4 GB | 0.6 GB | 75%
FFN (gate/up/down) | 11.2 GB | 2.8 GB | 75%
embed_tokens | 1.2 GB | 1.2 GB | 0%
lm_head | 1.2 GB | 1.2 GB | 0%
-------------------+-------------+-----------------+----------
Total weights | 16.0 GB | 5.8 GB | 64%

The 64% reduction is less than pure 4-bit would give (75%), but we keep the quality-critical layers intact.
Final Results
| Configuration | Memory | Speed | GSM8K Accuracy | Notes |
|---|---|---|---|---|
| BF16 (baseline) | 16 GB | 35 tok/s | 80% | Reference quality |
| 4-bit (all layers) | 4 GB | 60 tok/s | 68% | Unacceptable math errors |
| 4-bit (selective) | 5.8 GB | 55 tok/s | 78% | Good trade-off |
| 8-bit (all layers) | 8 GB | 45 tok/s | 79% | Conservative option |
The selective 4-bit approach recovers most of the quality (-2% vs BF16) while retaining most of the speedup (55 vs 60 tok/s). For a local chat assistant where occasional math errors are tolerable, this is the optimal trade-off.
The Quantization Trade-off Framework
The investigation yielded a framework for quantization decisions:
Layer Type | Quantization Tolerance | Reasoning
------------------------+-----------------------+----------------------------------
Attention Q/K/V proj | High | Error diffuses over softmax
Attention O proj | High | Linear transform, error averages
FFN projections | Medium | Non-linearity can amplify error
Layer norms | Low | Statistics-dependent, keep BF16
Embeddings | Medium | Large vocab, but averaged in
lm_head | Very Low | Directly determines output token

For production models:
- Always keep lm_head in high precision
- Consider keeping embeddings if vocab size is large
- Safe to quantize attention and FFN projections
- Monitor downstream task accuracy
Part VIII: The `mx.compile` Investigation
Day 4: Chasing the Last Optimization
January 3, 2026
Every MLX tutorial mentions mx.compile as a performance optimization. The promise: JIT compilation reduces Python overhead, enables kernel fusion, and speeds up repeated operations. After the other optimizations, I was at 55 tok/s. Could mx.compile push me to 70+?
The Promise
The MLX documentation shows impressive examples:
import mlx.core as mx
def slow_function(x, y):
# Multiple operations, multiple GPU dispatches
z = mx.add(x, y)
z = mx.multiply(z, x)
z = mx.sin(z)
z = mx.exp(z)
return z
# Without compile: 4 separate operations
result = slow_function(x, y)
# With compile: fused into optimal kernel
@mx.compile
def fast_function(x, y):
z = mx.add(x, y)
z = mx.multiply(z, x)
z = mx.sin(z)
z = mx.exp(z)
return z
result = fast_function(x, y) # Single fused kernel

For functions called repeatedly with same-shaped inputs, mx.compile traces the computation graph once, optimizes it, and reuses the compiled version.
Implementation Attempt #1: Compile the Full Forward Pass
The first attempt was ambitious—compile the entire decode step:
from functools import partial
def create_compiled_forward(model: WeDLMModel):
"""Create a compiled forward pass for decoding."""
def forward_fn(input_ids: mx.array, positions: mx.array,
cached_keys: List[mx.array], cached_values: List[mx.array],
mask: mx.array) -> Tuple[mx.array, List[mx.array], List[mx.array]]:
"""The function we want to compile."""
B, L = input_ids.shape
x = model.embed_tokens(input_ids)
new_keys, new_values = [], []
for i, layer in enumerate(model.layers):
# Attention
residual = x
x_norm = layer.input_layernorm(x)
attn = layer.self_attn
q = attn.q_proj(x_norm).reshape(B, L, attn.num_heads, attn.head_dim)
k = attn.k_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)
v = attn.v_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)
# Apply RoPE
q = apply_rope_positions(q, positions, attn.head_dim, attn.rope_theta)
k = apply_rope_positions(k, positions, attn.head_dim, attn.rope_theta)
# Concatenate with cache
k_full = mx.concatenate([cached_keys[i], k], axis=2)
v_full = mx.concatenate([cached_values[i], v], axis=2)
# Store new K, V
new_keys.append(k_full)
new_values.append(v_full)
# Attention
attn_out = fast.scaled_dot_product_attention(q, k_full, v_full, mask=mask)
x = residual + attn.o_proj(attn_out.reshape(B, L, -1))
# FFN
residual = x
x = residual + layer.mlp(layer.post_attention_layernorm(x))
logits = model.lm_head(model.norm(x))
return logits, new_keys, new_values
# Compile with shapeless=True to handle variable sequence lengths
compiled_fn = mx.compile(forward_fn, shapeless=True)
return compiled_fnResult: Error
ValueError: Cannot compile function with list arguments

mx.compile traces Python execution to build a computation graph. Python lists (cached_keys, cached_values) aren't part of the MLX graph—they're Python-level iteration. The function isn't traceable.
Implementation Attempt #2: Compile Per-Layer
Fine, let's compile smaller units:
def create_compiled_attention(layer: TransformerBlock):
"""Compile just the attention computation."""
def attn_fn(x: mx.array, k_cache: mx.array, v_cache: mx.array,
positions: mx.array, mask: mx.array) -> Tuple[mx.array, mx.array, mx.array]:
attn = layer.self_attn
B, L, _ = x.shape
x_norm = layer.input_layernorm(x)
q = attn.q_proj(x_norm).reshape(B, L, attn.num_heads, attn.head_dim)
k = attn.k_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)
v = attn.v_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)
q = fast.rope(q, attn.head_dim, offset=positions)
k = fast.rope(k, attn.head_dim, offset=positions)
k_full = mx.concatenate([k_cache, k], axis=2)
v_full = mx.concatenate([v_cache, v], axis=2)
attn_out = fast.scaled_dot_product_attention(q, k_full, v_full, mask=mask)
return attn.o_proj(attn_out.reshape(B, L, -1)), k_full, v_full
return mx.compile(attn_fn, shapeless=True)
# Create compiled versions for all layers
compiled_attns = [create_compiled_attention(layer) for layer in model.layers]
This worked, but the results were underwhelming.
Benchmarking Compiled vs Non-Compiled
Careful benchmarks:
def benchmark_decode(model, input_ids, cache, num_trials=100):
"""Benchmark decode step timing."""
times = []
for _ in range(num_trials):
# Warmup the cache state
cache_keys = [c.k for c in cache.layers]
cache_values = [c.v for c in cache.layers]
start = time.perf_counter()
logits = model.decode_step(input_ids, cache)
mx.eval(logits)
elapsed = time.perf_counter() - start
times.append(elapsed)
return np.mean(times[10:]), np.std(times[10:]) # Skip first 10 for warmup
# Test on small model first (128M params for faster iteration)
small_model = load_small_test_model()
print("Small model (128M):")
non_compiled_time, nc_std = benchmark_decode(small_model, input_ids, cache)
print(f" Non-compiled: {non_compiled_time*1000:.2f}ms ± {nc_std*1000:.2f}ms")
compiled_small = create_compiled_model(small_model)
compiled_time, c_std = benchmark_decode(compiled_small, input_ids, cache)
print(f" Compiled: {compiled_time*1000:.2f}ms ± {c_std*1000:.2f}ms")
# Full WeDLM-8B
print("\nWeDLM-8B:")
non_compiled_time, nc_std = benchmark_decode(wedlm_model, input_ids, cache)
print(f" Non-compiled: {non_compiled_time*1000:.2f}ms ± {nc_std*1000:.2f}ms")
compiled_wedlm = create_compiled_model(wedlm_model)
compiled_time, c_std = benchmark_decode(compiled_wedlm, input_ids, cache)
print(f" Compiled: {compiled_time*1000:.2f}ms ± {c_std*1000:.2f}ms")Results:
| Model | Non-Compiled | Compiled | Speedup |
|---|---|---|---|
| Small (128M test) | 0.67ms | 0.74ms | 0.91x |
| WeDLM-8B (full) | 127ms | 127ms | 1.00x |
The small model is actually slower with compilation (compilation overhead dominates). The full model shows no improvement at all.
Why Doesn't It Help?
After investigation, three factors explain the non-result:
Factor 1: MLX's Lazy Evaluation Already Optimizes
Without mx.compile, MLX already builds and optimizes computation graphs:
# This already builds a lazy graph
y = mx.add(mx.multiply(a, b), c)
# MLX sees: add(multiply(a, b), c)
# Optimizes before execution
mx.eval(y)
mx.compile adds explicit JIT tracing, but the lazy evaluation system already does similar optimizations. For straight-line numerical code, the difference is minimal.
Factor 2: The Bottleneck Is Matrix Multiplication
Profiling the decode step shows where time goes:
import mlx.profiler
with mlx.profiler.profile() as prof:
for _ in range(50):
logits = model.decode_step(input_ids, cache)
mx.eval(logits)
print(prof.summary())
Output (simplified):
Operation | Time | % of Total
---------------------+---------+------------
MatMul (Q/K/V proj) | 45ms | 35%
MatMul (O proj) | 15ms | 12%
MatMul (FFN gate) | 20ms | 16%
MatMul (FFN up) | 20ms | 16%
MatMul (FFN down) | 20ms | 16%
Attention SDPA | 5ms | 4%
Other (RoPE, norm)   | 2ms     | 1%
95% of the time is in matrix multiplications. These are already running optimized Metal kernels; mx.compile can't make them faster because they're hardware-bound.
The remaining 5% includes:
- Python overhead: ~2%
- Memory operations: ~3%
Even if mx.compile eliminated ALL Python overhead, you'd gain maybe 2%. And my careful batching has already minimized Python overhead.
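To put a number on that ceiling, a quick Amdahl's-law check using the ~2% Python-overhead figure above shows how little headroom compilation has:

```python
# Amdahl's law upper bound, using the ~2% Python-overhead estimate above
python_overhead_fraction = 0.02
best_case_speedup = 1 / (1 - python_overhead_fraction)
print(f"Best case from removing all Python overhead: {best_case_speedup:.3f}x")
# -> ~1.02x, i.e. about 2%, even if compilation were free and perfect
```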
Factor 3: The Cache Access Pattern
My KV cache uses Python lists with per-layer indexing:
# In forward pass
for i, layer in enumerate(self.layers):
cached_k, cached_v = cache.layers[i].k, cache.layers[i].v
# ↑ Python list indexing
k_full = mx.concatenate([cached_k, k], axis=2)
mx.compile traces function execution to build a graph, but Python list indexing (cache.layers[i]) happens in Python, not in MLX. Each layer access is a separate trace boundary.
A fully compiled version would need the cache as a single MLX array with layer indices as part of the computation. That requires a major refactor:
# Hypothetical fully-traceable cache
cache_all_k = mx.stack(all_layer_keys) # Shape: (num_layers, B, H, L, D)
cache_all_v = mx.stack(all_layer_values)
# Index within MLX graph
layer_k = cache_all_k[layer_idx]  # This CAN be traced
But this makes cache updates more complex and doesn't address the fundamental bottleneck (matrix multiplies).
When DOES `mx.compile` Help?
The investigation clarified when compilation matters:
Good candidates for mx.compile (a minimal sketch follows these lists):
- Custom loss functions with many arithmetic operations
- Preprocessing pipelines with lots of small ops
- Research code with complex control flow that repeats
- Functions where Python loop overhead is significant
Poor candidates for mx.compile:
- LLM inference (matrix multiply dominated)
- Single forward passes (no repetition to amortize compile cost)
- Code with Python-level data structures in the hot path
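As a concrete illustration of the first category, here is a minimal sketch (a hypothetical elementwise loss, not code from this project) where compilation can pay off: many small ops, no large matmul, and the same shapes on every call so the one-time trace cost is amortized.

```python
import mlx.core as mx

# A "good candidate": many small elementwise ops, called repeatedly with
# the same shapes, so the one-time trace cost is amortized.
def huber_loss(pred, target, delta=1.0):
    err = pred - target
    abs_err = mx.abs(err)
    quadratic = 0.5 * mx.square(err)
    linear = delta * abs_err - 0.5 * delta * delta
    return mx.mean(mx.where(abs_err <= delta, quadratic, linear))

compiled_huber = mx.compile(huber_loss)

x = mx.random.normal((4096,))
y = mx.random.normal((4096,))
mx.eval(compiled_huber(x, y))  # first call traces and compiles
mx.eval(compiled_huber(x, y))  # later calls reuse the compiled graph
```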
The Decision: Skip Compilation
The final implementation omits the mx.compile wrapper. The benefits were zero, but the costs were real:
- Added complexity to the code
- Longer startup time (compilation happens on first call)
- Debugging becomes harder (compiled traces are opaque)
In this project, mx.compile provided no measurable benefit; recognizing that the real bottleneck was never Python overhead avoided further wasted effort.
Part IX: Benchmark Results
Experimental Setup
January 3-4, 2026
With all optimizations complete, rigorous benchmarks were needed—not just "it feels fast," but reproducible numbers with proper methodology.
Hardware Configuration
- Machine: MacBook Pro 16" (2023)
- Chip: Apple M3 Max
- 14-core CPU (10 performance + 4 efficiency)
- 30-core GPU
- 16-core Neural Engine
- Memory: 48GB unified memory
- Storage: 1TB SSD (irrelevant for inference, model fits in RAM)
- Power: Always plugged in, "High Power" mode enabled
- OS: macOS Sonoma 14.2
Model Configurations
| Model Name | Base | Params | Precision | Size on Disk |
|---|---|---|---|---|
| wedlm-8b-bf16 | Qwen3-8B | 8.03B | BF16 | 16.0 GB |
| wedlm-8b-4bit | Qwen3-8B | 8.03B | 4-bit | 4.0 GB |
| wedlm-8b-4bit-sel | Qwen3-8B | 8.03B | Mixed | 5.8 GB |
| qwen3-8b-bf16 | Qwen3-8B | 8.03B | BF16 | 16.0 GB |
| qwen3-8b-4bit | Qwen3-8B | 8.03B | 4-bit | 4.0 GB |
Benchmarking Methodology
Proper LLM benchmarks require careful methodology to avoid common pitfalls:
def benchmark_throughput(model, prompts: List[str], max_tokens: int = 256,
num_warmup: int = 3, num_trials: int = 10) -> dict:
"""Benchmark generation throughput with proper methodology."""
results = []
for i, prompt in enumerate(prompts):
prompt_results = {"prompt_idx": i, "trials": []}
for trial in range(num_warmup + num_trials):
# Force garbage collection between trials
import gc
gc.collect()
# Clear any pending MLX operations
mx.synchronize()
# Time the generation
start = time.perf_counter()
output = model.generate(prompt, max_tokens=max_tokens)
# Ensure all GPU work is complete
mx.synchronize()
elapsed = time.perf_counter() - start
# Count actual generated tokens (not max_tokens)
output_tokens = len(model.tokenizer.encode(output.text))
input_tokens = len(model.tokenizer.encode(prompt))
generated_tokens = output_tokens - input_tokens
# Skip warmup trials
if trial >= num_warmup:
prompt_results["trials"].append({
"elapsed": elapsed,
"generated_tokens": generated_tokens,
"tok_per_s": generated_tokens / elapsed,
"prefill_time": output.prefill_time if hasattr(output, 'prefill_time') else None,
})
results.append(prompt_results)
return results
def compute_statistics(results: List[dict]) -> dict:
"""Compute aggregate statistics from benchmark results."""
all_tps = [t["tok_per_s"] for r in results for t in r["trials"]]
return {
"mean_tok_per_s": np.mean(all_tps),
"std_tok_per_s": np.std(all_tps),
"median_tok_per_s": np.median(all_tps),
"p5_tok_per_s": np.percentile(all_tps, 5),
"p95_tok_per_s": np.percentile(all_tps, 95),
"num_samples": len(all_tps),
}
Benchmark Prompts
A diverse prompt set captures different generation scenarios:
benchmark_prompts = [
# Short factual
"What is the capital of France?",
# Medium explanation
"Explain the difference between TCP and UDP in networking.",
# Code generation
"Write a Python function to compute the nth Fibonacci number using dynamic programming.",
# Long-form reasoning
"Describe the process of photosynthesis and explain why it's important for life on Earth.",
# Math (tests accuracy)
"Solve step by step: A train travels 120 miles in 2 hours. How long to travel 300 miles?",
]
MLX vs PyTorch MPS Comparison
First, comparing MLX to PyTorch's MPS backend for autoregressive Qwen3-8B:
# PyTorch MPS baseline
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = torch.device("mps")
pt_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-8B",
torch_dtype=torch.bfloat16,
device_map="mps"
)
pt_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
# MLX baseline
from mlx_lm import load, generate as mlx_generate
mlx_model, mlx_tokenizer = load("mlx-community/Qwen3-8B-4bit")
Results:
| Backend | Precision | Throughput (tok/s) | Memory Usage | vs PyTorch BF16 |
|---|---|---|---|---|
| PyTorch MPS | BF16 | 29.1 ± 2.3 | ~18 GB | baseline |
| PyTorch MPS | FP16 | 30.4 ± 2.1 | ~17 GB | +4% |
| MLX | BF16 | 35.0 ± 1.8 | ~17 GB | +20% |
| MLX | 4-bit | 60.9 ± 3.2 | ~5 GB | +109% |
MLX consistently outperforms PyTorch MPS. The advantage comes from:
- Native Metal Integration: MLX is written for Metal, while PyTorch MPS is a translation layer
- Unified Memory Efficiency: MLX better exploits the unified memory architecture
- Lazy Evaluation: Graph optimization reduces memory copies and intermediate storage
- Optimized Ops: fast.rope and fast.scaled_dot_product_attention are Metal-native
WeDLM Diffusion vs Autoregressive
The main comparison: does WeDLM's parallel decoding actually help?
# WeDLM (diffusion decoding)
wedlm = WeDLMDecoder.from_pretrained("./wedlm-8b-4bit-selective")
# Baseline (autoregressive, same base model)
from mlx_lm import load, generate
ar_model, ar_tokenizer = load("mlx-community/Qwen3-8B-4bit")
Throughput Comparison:
| Method | Precision | Throughput | Tokens/Forward | Speedup vs AR |
|---|---|---|---|---|
| Qwen3-8B AR | 4-bit | 40.2 tok/s | 1.0 | baseline |
| WeDLM (window=8) | 4-bit sel | 50.1 tok/s | 1.8 | 1.25x |
| WeDLM (window=16) | 4-bit sel | 55.2 tok/s | 2.2 | 1.37x |
| WeDLM (window=32) | 4-bit sel | 52.8 tok/s | 2.0 | 1.31x |
Interesting findings:
- Window size 16 is the sweet spot on M3 Max
- Window size 32 is actually slower—the overhead of managing a larger window outweighs the parallelism benefit
- Average tokens per forward ranges from 1.8 to 2.2, not the 3-4 claimed in the paper
Why Fewer Tokens Per Forward?
The paper reports higher parallelism because:
- They use larger models with more confident predictions
- Their entropy threshold is tuned differently
- CUDA FlashAttention makes larger windows more efficient
On M3 Max, the attention computation per window token is more expensive relative to the matrix multiplies, so smaller windows are more efficient.
Accuracy on GSM8K
Speed is meaningless if quality suffers. Evaluation on GSM8K (grade school math):
def evaluate_gsm8k(model, num_samples: int = 200) -> float:
"""Evaluate on GSM8K subset."""
from datasets import load_dataset
dataset = load_dataset("gsm8k", "main", split="test")
samples = dataset.select(range(num_samples))
correct = 0
for sample in samples:
question = sample["question"]
answer = sample["answer"]
# Extract ground truth (last number in answer)
ground_truth = extract_final_number(answer)
# Generate response
prompt = f"Q: {question}\nA: Let me solve this step by step.\n"
response = model.generate(prompt, max_tokens=256)
# Extract predicted answer
predicted = extract_final_number(response.text)
if predicted == ground_truth:
correct += 1
return correct / num_samples
Results (200-sample subset):
| Model | GSM8K Accuracy | Notes |
|---|---|---|
| Qwen3-8B AR (BF16) | 80.0% | Reference quality |
| Qwen3-8B AR (4-bit) | 78.5% | Minor quantization impact |
| WeDLM (BF16) | 71.5% | -8.5% vs AR |
| WeDLM (4-bit sel) | 70.0% | -10% vs AR |
| WeDLM (4-bit all) | 64.0% | Unacceptable for math |
The accuracy gap is concerning. WeDLM's parallel decoding introduces errors because:
- Entropy-based commitment: Tokens are committed based on confidence, not correctness
- Early mistakes propagate: An early wrong digit affects all subsequent reasoning
- Position penalty bias: May commit wrong tokens early to maximize cache utilization
For math tasks where precision matters, AR decoding is still preferable.
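For completeness, the extract_final_number helper used in the GSM8K harness above isn't shown. A minimal sketch, assuming GSM8K's "#### answer" convention for ground truth and "last number wins" for model output, might look like:

```python
import re

def extract_final_number(text: str):
    """Pull the final numeric answer out of a GSM8K answer or model response."""
    # GSM8K ground-truth answers end with '#### <number>'
    if "####" in text:
        text = text.split("####")[-1]
    # Otherwise take the last number-looking token in the text
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return float(numbers[-1]) if numbers else None
```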
Prefix Cacheability Analysis
WeDLM's key innovation is prefix cacheability. Measuring the prefix cacheability actually achieved by this port:
def measure_prefix_cacheability(model, prompts: List[str], max_tokens: int = 128):
"""Measure actual prefix cache hit rate."""
total_generated = 0
total_forwards = 0
for prompt in prompts:
state = model.initialize_state(prompt)
while not state.is_finished and state.num_generated < max_tokens:
state, committed = model.decode_step(state)
total_generated += len(committed)
total_forwards += 1
p_cache = total_generated / total_forwards
return p_cache
# Measure across prompt types
cache_results = {}
for prompt_type, prompts in prompt_categories.items():
p_cache = measure_prefix_cacheability(wedlm, prompts)
cache_results[prompt_type] = p_cache
print("Prefix Cacheability by Prompt Type:")
for ptype, pcache in cache_results.items():
print(f" {ptype}: {pcache:.2f}")Results:
| Prompt Type | Prefix Cacheability | Notes |
|---|---|---|
| Code generation | 0.92 | Highly predictable tokens |
| Factual Q&A | 0.88 | Common patterns |
| Creative writing | 0.75 | More varied vocabulary |
| Math reasoning | 0.70 | Numbers are high-entropy |
| Overall average | 0.81 | Weighted by prompt frequency |
The paper reports a higher value on its own benchmarks. We achieve 0.81, which is reasonable given different evaluation prompts and entropy thresholds.
Memory Scaling with Sequence Length
How does memory usage scale as we generate longer sequences?
def measure_memory_scaling(model, prompt: str, max_lengths: List[int]):
"""Measure peak memory at different sequence lengths."""
import tracemalloc
results = []
for max_len in max_lengths:
tracemalloc.start()
output = model.generate(prompt, max_tokens=max_len)
mx.synchronize()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
results.append({
"max_tokens": max_len,
"peak_memory_gb": peak / 1e9,
"actual_tokens": len(model.tokenizer.encode(output.text))
})
return results
KV Cache Memory (BF16, batch=1):
| Sequence Length | KV Cache Size | Total Memory | Notes |
|---|---|---|---|
| 512 | 288 MB | 17.3 GB | Short conversations |
| 1024 | 576 MB | 17.6 GB | Typical chat |
| 2048 | 1.15 GB | 18.2 GB | Long documents |
| 4096 | 2.3 GB | 19.3 GB | Extended context |
| 8192 | 4.6 GB | 21.6 GB | M3 Max comfortable |
| 16384 | 9.2 GB | 26.2 GB | Approaching limits |
Formula:
KV cache bytes = 2 × L × H_kv × d_head × N × b
Where L = num_layers (36), H_kv = num_kv_heads (8), d_head = head_dim (128), N = sequence length, and b = bytes per element (2 for BF16). The leading factor of 2 accounts for storing both keys and values.
For WeDLM-8B: 2 × 36 × 8 × 128 × 2 bytes ≈ 0.14 MB per token.
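A quick sanity check of the per-token figure, using the Qwen3-8B shape constants above:

```python
# Per-token KV cache footprint for WeDLM-8B (BF16), from the formula above
num_layers, num_kv_heads, head_dim, bytes_per_elem = 36, 8, 128, 2
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem  # keys + values
print(f"{kv_bytes_per_token / 2**20:.2f} MiB per token")  # ~0.14
```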
Comparison with Paper Claims
The WeDLM paper reports 3x speedup over vLLM on NVIDIA A100. Our MLX port achieves 1.25-1.37x speedup over MLX AR.
Why the gap?
| Factor | Paper (CUDA) | This Port (MLX) |
|---|---|---|
| Attention | FlashAttention (2-4x faster) | Standard SDPA |
| Hardware | A100 (80GB, 312 TFLOPS) | M3 Max (48GB, ~14 TFLOPS) |
| Parallelism | Large batch, high GPU utilization | Single stream, batch=1 |
| Framework maturity | CUDA 15+ years | MLX ~1 year |
| Window overhead | Amortized over larger batches | Dominates at batch=1 |
The comparison isn't entirely fair—they're measuring different things. But it sets realistic expectations for MLX deployment.
End-to-End Latency
For interactive applications, latency matters as much as throughput:
def measure_latency_breakdown(model, prompt: str, max_tokens: int = 100):
"""Measure latency components."""
# Tokenization
t0 = time.perf_counter()
input_ids = model.tokenizer.encode(prompt)
tokenize_time = time.perf_counter() - t0
# Prefill
t1 = time.perf_counter()
cache = model.prefill(mx.array([input_ids]))
mx.synchronize()
prefill_time = time.perf_counter() - t1
# Decode
decode_times = []
tokens_per_step = []
state = model.initialize_decode_state()
while not state.is_finished and state.num_generated < max_tokens:
t_start = time.perf_counter()
state, committed = model.decode_step(state, cache)
mx.synchronize()
t_end = time.perf_counter()
decode_times.append(t_end - t_start)
tokens_per_step.append(len(committed))
return {
"tokenize_ms": tokenize_time * 1000,
"prefill_ms": prefill_time * 1000,
"decode_mean_ms": np.mean(decode_times) * 1000,
"decode_total_ms": sum(decode_times) * 1000,
"tokens_per_step_mean": np.mean(tokens_per_step),
"total_tokens": sum(tokens_per_step),
}
Latency Breakdown (100-token generation):
| Component | Time (ms) | % of Total |
|---|---|---|
| Tokenization | 2 | 0.1% |
| Prefill (512 tok) | 180 | 9.5% |
| Decode (total) | 1700 | 89.9% |
| Total | 1882 | 100% |
Per-step decode latency: ~34ms average (with ~2 tokens committed per step).
Time-to-first-token: ~182ms (tokenize + prefill).
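The per-step and time-to-first-token figures follow directly from the table (the numbers below are taken from it):

```python
# Back-of-envelope check of the latency breakdown above
total_tokens = 100
tokens_per_step = 2.0                 # average tokens committed per decode step
decode_total_ms = 1700
steps = total_tokens / tokens_per_step  # ~50 decode steps
print(decode_total_ms / steps)          # ~34 ms per decode step
print(2 + 180)                          # time-to-first-token: ~182 ms
```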
Part X: Implementation Takeaways
Engineering Notes
January 4, 2026
This section summarizes the main engineering patterns that turned out to matter for the WeDLM→MLX port and are likely to generalize to other MLX deployments.
Profiling Before Optimization
Several hours early on went to optimizing KV cache update patterns—minimizing copies, batching updates, avoiding redundant operations. Then profiling revealed the actual breakdown:
import mlx.profiler
with mlx.profiler.profile() as prof:
for _ in range(50):
output = model.decode_step(input_ids, cache)
mx.eval(output)
print(prof.summary())
Result:
Operation Breakdown:
MatMul operations: 95.2%
Attention (SDPA): 2.1%
KV cache updates: 1.3% ← All that work for 1.3%
LayerNorm: 0.8%
RoPE: 0.3%
Other: 0.3%
Even a 10x improvement in KV cache updates would yield only ~1.2% total speedup. The optimization effort targeted 1% of the runtime, not the 95%.
In this project, most attempted micro-optimizations had negligible impact compared to addressing the dominant cost (matrix multiplications), so profiling was essential before modifying hot paths.
MLX Lazy Evaluation
MLX's lazy evaluation affects how optimization works:
| PyTorch (Eager) | MLX (Lazy) |
|---|---|
| Minimize operation count | Minimize synchronization points |
| Each op executes immediately | Ops build graph, execute on eval |
| .item() is fast (just returns) | .item() forces full graph evaluation |
| Small loops are fine | Small loops with ops = many small graphs |
Code that's idiomatic in PyTorch destroys MLX performance:
# PyTorch style - fine
for i in range(window_size):
if logits[i].argmax().item() > threshold:
results.append(i)
# MLX equivalent - TERRIBLE (window_size sync points)
for i in range(window_size):
if mx.argmax(logits[i]).item() > threshold:
results.append(i)
# MLX style - good (one sync point)
max_vals = mx.argmax(logits, axis=-1)
mx.eval(max_vals)
results = [i for i, v in enumerate(max_vals.tolist()) if v > threshold]
Same algorithm, but up to a 100x performance difference depending on how well the code matches the framework's execution model.
Selective Quantization Strategy
The "quantize everything to 4-bit" approach is not optimal:
| Layer Type | Quantization Impact | Recommendation |
|---|---|---|
| Attention projs | Low | Safe to quantize |
| FFN layers | Medium | Usually safe |
| Embeddings | Medium | Consider keeping BF16 |
| lm_head | High | Keep in high precision |
| Layer norms | Very high | Never quantize |
The output layer (lm_head) is special because argmax is discontinuous—small errors flip token choices. Attention layers are forgiving because softmax smooths errors.
Recommended approach: Start with selective quantization (exclude lm_head and embeddings), measure accuracy on the target task, then decide if more aggressive quantization is acceptable.
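A minimal sketch of that recommendation using MLX's nn.quantize with a class predicate (layer names assumed from the Qwen3 module layout; not the exact script used in this project):

```python
import mlx.nn as nn

def selective_quantize(model: nn.Module, bits: int = 4, group_size: int = 64):
    """Quantize most linear layers, but keep lm_head and embeddings in BF16.

    The names 'lm_head' and 'embed_tokens' are assumed from the Qwen3 layout.
    """
    def keep_or_quantize(path: str, module: nn.Module) -> bool:
        # Only quantize Linear layers...
        if not isinstance(module, nn.Linear):
            return False
        # ...and skip the precision-sensitive output head and embeddings
        if "lm_head" in path or "embed_tokens" in path:
            return False
        return True

    nn.quantize(model, group_size=group_size, bits=bits,
                class_predicate=keep_or_quantize)
    return model
```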
Threading Constraints with Metal
The SIGSEGV crash revealed Metal's threading constraints:
✗ Two threads calling mx.eval() simultaneously → SIGSEGV
✗ Multiple MLX models in ThreadPoolExecutor → SIGSEGV
✓ Sequential model calls → Works
✓ Separate processes → Works (but memory overhead)
The underlying issue: Metal command buffers aren't thread-safe, and MLX uses global state. This isn't prominently documented because single-threaded use is the norm.
In practice, one MLX context per process is required. If concurrent GPU work is needed, multiprocessing (not threading) is necessary, or the application must be designed for sequential inference.
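A minimal sketch of the multiprocessing route, assuming mlx_lm's load/generate API and one model per worker process (the model path and queue protocol here are illustrative):

```python
import multiprocessing as mp

def worker(prompts: mp.Queue, results: mp.Queue, model_path: str):
    # Each process owns its own MLX/Metal context and its own copy of the weights
    from mlx_lm import load, generate
    model, tokenizer = load(model_path)
    while True:
        prompt = prompts.get()
        if prompt is None:          # sentinel: shut down the worker
            break
        results.put(generate(model, tokenizer, prompt=prompt, max_tokens=128))

if __name__ == "__main__":
    prompts, results = mp.Queue(), mp.Queue()
    p = mp.Process(target=worker, args=(prompts, results, "mlx-community/Qwen3-8B-4bit"))
    p.start()
    prompts.put("What is the capital of France?")
    print(results.get())
    prompts.put(None)
    p.join()
```

The cost is the memory overhead noted above: every process loads its own copy of the model.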
Framework Limitations
The mx.compile investigation illustrates when to stop optimizing:
- Hours spent trying to make compilation work
- Result: 0% improvement (or negative for small models)
- Root cause: The bottleneck (matrix multiplies) is already optimized
When workarounds become increasingly complex without performance improvement, it may indicate the framework's design constraints are fundamental rather than incidental.
Diffusion Decoding Trade-offs
WeDLM is not a pure upgrade over AR decoding:
| Aspect | AR Decoding | Diffusion (WeDLM) |
|---|---|---|
| Throughput | Lower | Higher (1.2-1.4x) |
| Accuracy | Higher | Lower (8-10% drop on math) |
| Latency | Predictable | Variable (depends on entropy) |
| Complexity | Simple | Complex (window, reordering) |
| Best for | Precision tasks | High-volume generation |
The paper emphasizes speedup; the accuracy trade-off is acknowledged but easy to overlook. For production, you need to evaluate on YOUR task, not benchmark sets.
Paper-to-Implementation Gap
The WeDLM paper reports 3x speedup; this port achieved 1.3x. The gap reflects:
- Different hardware (M3 Max vs A100)
- Different frameworks (MLX vs CUDA)
- Different evaluation methodology
- Missing optimizations (FlashAttention)
Reading ML papers pragmatically:
- Note the claims (3x speedup)
- Note the conditions (A100, FlashAttention, specific tasks)
- Expect 30-50% of claimed gains when porting
- Measure on YOUR setup and YOUR task
Local Inference Viability
Despite the caveats, the final system runs WeDLM-8B at 55 tokens/second on a laptop:
- Real-time conversation
- Private (no data leaves the machine)
- Offline-capable
- No API costs
Running 8B models locally at conversational speed is now practical on consumer hardware.
Part XI: Future Work
What Would Make This Better
1. FlashAttention for Metal
The single highest-impact optimization that doesn't exist yet. FlashAttention on CUDA achieves 2-4x attention speedup through:
- Tiled computation: Process attention in blocks that fit in SRAM
- Online softmax: Compute softmax incrementally without materializing the full attention matrix
- Memory efficiency: O(N) memory instead of O(N²) for sequence length N
A Metal implementation using threadgroup memory could provide similar benefits. The scaled_dot_product_attention in MLX is already fused, but doesn't use the tiled algorithm.
Potential impact:
Current attention: 2.1% of total time (after SDPA fusion)
With FlashAttention: Could reduce by 50-75%
Total speedup: ~1-2%... actually not huge for inference
BUT: For long sequences (8K+), memory savings are crucial
Current: O(N²) attention matrix
FlashAttention: O(N) intermediate storage
For inference, FlashAttention is more about enabling longer sequences than raw speed. Still valuable.
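To make the "online softmax" idea concrete, here is a toy single-query sketch of the streaming accumulation FlashAttention relies on (plain MLX ops, not a fused Metal kernel, and not production code):

```python
import mlx.core as mx

def online_softmax_attention(q, k_blocks, v_blocks, scale):
    """Single-query attention with K/V processed block by block.

    Keeps only O(1) extra state (running max, denominator, weighted sum)
    instead of materializing the full attention row.
    """
    m = mx.array(-float("inf"))              # running max of logits
    denom = mx.array(0.0)                    # running softmax denominator
    out = mx.zeros(v_blocks[0].shape[-1])    # running weighted sum of values
    for k_blk, v_blk in zip(k_blocks, v_blocks):
        s = (k_blk @ q) * scale              # logits for this block: (block,)
        m_new = mx.maximum(m, s.max())
        rescale = mx.exp(m - m_new)          # rescale previous accumulators
        p = mx.exp(s - m_new)
        denom = denom * rescale + p.sum()
        out = out * rescale + p @ v_blk
        m = m_new
    return out / denom
```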
2. Better Entropy Calibration
The current position penalty was chosen empirically through manual tuning. A more sophisticated approach:
class AdaptiveEntropySelector:
"""Learn optimal commitment thresholds from data."""
def __init__(self, calibration_data: List[str]):
# Collect entropy statistics on calibration prompts
self.entropy_stats = self._collect_stats(calibration_data)
# Learn task-dependent thresholds
self.code_threshold = self._compute_optimal_threshold("code")
self.prose_threshold = self._compute_optimal_threshold("prose")
self.math_threshold = self._compute_optimal_threshold("math")
def select_commitment_positions(self, entropies: mx.array, positions: mx.array,
task_type: str) -> List[int]:
"""Select positions to commit based on calibrated thresholds."""
threshold = getattr(self, f"{task_type}_threshold")
# Position penalty could also be learned
adjusted = entropies + self.learned_position_penalty * positions
return [i for i, e in enumerate(adjusted.tolist()) if e < threshold]
Different tasks have different entropy profiles. Code has low entropy (predictable tokens); creative writing has high entropy (many valid next tokens). A learned selector could adapt accordingly.
3. Speculative Decoding Integration
WeDLM generates multiple candidate tokens per step. This naturally fits speculative decoding:
Speculative decoding flow:
1. Draft model generates K tokens quickly: [t1, t2, t3, t4]
2. Target model verifies all K in single forward pass
3. Accept prefix that matches, regenerate from first mismatch
WeDLM hybrid:
1. WeDLM generates window: [t1, t2, t3, t4] with confidences
2. Verify uncertain tokens with AR pass on target model
3. Accept confident tokens, use AR for uncertain ones
This could preserve WeDLM's speed while recovering accuracy on low-confidence positions. Implementation complexity is high, but the potential is interesting.
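A sketch of the verification step in that hybrid (greedy acceptance only; names are hypothetical):

```python
import mlx.core as mx

def accept_draft_prefix(draft_tokens: mx.array, target_logits: mx.array):
    """Accept the longest prefix of drafted tokens the target model agrees with.

    draft_tokens: (K,) token ids proposed by the WeDLM window
    target_logits: (K, vocab) target-model logits at those positions
    Returns (accepted_tokens, num_accepted); generation resumes (with AR)
    from the first rejected position.
    """
    target_choice = mx.argmax(target_logits, axis=-1)
    agree = (target_choice == draft_tokens).tolist()
    num_accepted = 0
    for ok in agree:
        if not ok:
            break
        num_accepted += 1
    return draft_tokens[:num_accepted], num_accepted
```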
4. Continuous Batching
For server deployment, batching multiple requests is essential:
class ContinuousBatcher:
"""Batch multiple requests dynamically."""
def __init__(self, model, max_batch_size: int = 8):
self.model = model
self.pending_requests = []
self.active_states = []
def add_request(self, prompt: str, callback: Callable):
"""Add new request to the queue."""
self.pending_requests.append((prompt, callback))
self._maybe_promote_requests()
def step(self):
"""Process one step for all active requests."""
if not self.active_states:
return
# Batch all active states into single forward pass
batched_input = self._batch_states(self.active_states)
batched_output = self.model.forward_batch(batched_input)
# Unbatch and update individual states
finished = []
for state, output in zip(self.active_states, self._unbatch(batched_output)):
state.update(output)
if state.is_finished:
state.callback(state.result)
finished.append(state)
# Remove finished requests after iterating (don't mutate the list mid-loop)
self.active_states = [s for s in self.active_states if s not in finished]
MLX doesn't have built-in support for this (unlike vLLM's PagedAttention), but the primitives exist. It would require significant engineering.
5. Model Parallelism for Larger Models
For models larger than unified memory (70B+), model parallelism across multiple Macs could work:
# Hypothetical distributed inference
class DistributedWeDLM:
def __init__(self, model_shards: List[str], hosts: List[str]):
# Shard model across machines
# Each machine holds layers [start:end]
self.shards = [
RemoteShard(host, shard_path)
for host, shard_path in zip(hosts, model_shards)
]
def forward(self, x: mx.array) -> mx.array:
for shard in self.shards:
x = shard.forward(x) # Network transfer between shards
return x
The latency overhead of network transfers would be significant, but it's the only way to run models that don't fit in memory.
Conclusion
Summary
The WeDLM port to Apple Silicon achieved:
Performance:
- 55 tokens/second on M4 Max (4-bit selective quantization)
- 1.37x speedup over autoregressive baseline
- Time-to-first-token: ~180ms
Quality:
- 70% accuracy on GSM8K (vs 80% for AR baseline)
- Acceptable for conversational use, not ideal for precision tasks
What Worked:
- MLX provides 20%+ speedup over PyTorch MPS
- 4-bit quantization with selective exclusion balances speed and quality
- Batched operations with single sync points are critical
- Topological reordering genuinely enables parallel decoding with KV caching
What Didn't:
- mx.compile provides no benefit (matrix multiply dominated)
- Multi-threaded Metal access crashes (use sequential or multiprocessing)
- Diffusion decoding accuracy is lower on reasoning tasks
- The 3x speedup from the paper isn't achievable on consumer hardware