Zimeng Xiong's Weblog

About

Porting WeDLM to MLX

Tencent's WeDLM paper reports 3x speedup over vLLM using diffusion-style parallel decoding—multiple tokens generated per forward pass while maintaining KV cache compatibility. This post documents an independent port of WeDLM to Apple's MLX framework, covering the architecture decisions, performance optimizations, and implementation details from a week of development on an M4 Max.

The goal was to take the WeDLM research prototype and implement an efficient, production-style decoder on Apple's MLX stack for MacBook-class hardware. The work required re-implementing Qwen3-8B in MLX, adapting the diffusion-style decoding algorithm, and profiling MLX's execution model on Apple Silicon.

This post is structured chronologically, following the actual development process from December 29, 2025 through early January 2026. Each section corresponds to a specific problem encountered and the technical solution developed.

This shows the AR baseline (Qwen3-8B-IT) on the left and the MLX-optimized WeDLM-8B-IT on the right. Each model was run for three consecutive attempts. The AR baseline held steady at ~50 tok/s across all runs. WeDLM-MLX started slow at 26 tok/s on the first run, then averaged 90-100 tok/s on subsequent runs once the KV cache and JIT-compiled kernels had warmed up; paged attention is not implemented in MLX yet, so the first run pays the compilation and allocation cost. Paged attention and a Metal flash attention kernel are challenges for another time.


Part I: Understanding WeDLM

The Autoregressive Bottleneck

Before diving into implementation, it's worth understanding why WeDLM exists and what problem it solves.

Standard LLMs are autoregressive (AR): they generate one token at a time, with each token conditioned on all previous tokens. The joint probability of a sequence factorizes as:

P(\mathbf{x}) = \prod_{t=1}^{T} P(x_t \mid x_{<t}; \theta)

This sequential dependency has a critical performance implication. For an 8B parameter model generating 100 tokens:

  • 100 forward passes through 8B parameters
  • Each pass loads ~16GB of weights from memory
  • Total memory bandwidth: 100 × 16GB = 1.6TB

Modern GPUs have memory bandwidths around 1-2 TB/s. This means generating 100 tokens at 8B parameters takes roughly 1-2 seconds just for memory access—the actual compute is almost free. This is what we mean by memory-bound inference.

The obvious solution: generate multiple tokens per forward pass. If we can predict 4 tokens at once, we reduce memory bandwidth by 4x. This is exactly what diffusion language models attempt.
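
As a back-of-the-envelope check (a small sketch assuming BF16 weights and that every decoding step streams all weights from memory):

def memory_bound_decode_time(params_billions: float, bytes_per_param: int,
                             num_tokens: int, bandwidth_tb_s: float) -> float:
    """Estimate decode latency assuming inference is purely memory-bound."""
    bytes_per_pass = params_billions * 1e9 * bytes_per_param  # weights read per forward pass
    total_bytes = bytes_per_pass * num_tokens                 # one pass per generated token
    return total_bytes / (bandwidth_tb_s * 1e12)              # seconds

# 8B params in BF16, 100 tokens, ~1 TB/s of memory bandwidth → ~1.6 s
print(memory_bound_decode_time(8.0, 2, 100, 1.0))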

Why Traditional Diffusion Models Fail

Diffusion Language Models (DLLMs) like LLaDA and Dream work by predicting all masked positions simultaneously. The model sees:

Input:  [The] [MASK] [MASK] [on] [MASK] [mat]
Output: [The] [cat]  [sat]  [on] [the]  [mat]

The problem is bidirectional attention. To predict position 2 ("sat"), the model needs to see position 5 ("the"). To predict position 5, it needs to see position 2. Every token attends to every other token.

This completely breaks KV caching.

Why KV Caching Matters

In AR generation, we cache the key-value projections from previous tokens:

Step 1: Process [The] → Cache K₀, V₀
Step 2: Process [cat], attend to K₀,V₀ → Cache K₁, V₁
Step 3: Process [sat], attend to K₀,V₀,K₁,V₁ → Cache K₂, V₂

The cached K/V for token 0 never changes—it only depends on token 0's input. This is what makes AR inference tractable.

With bidirectional attention, K₀ depends on all tokens, including ones we haven't generated yet. If we change token 5, K₀ changes. Nothing can be cached until the entire sequence is finalized. We'd need to recompute the full attention matrix at every refinement step—exactly the O(n²) cost we were trying to avoid.

WeDLM's Insight: Topological Reordering

WeDLM's key insight is that mask recovery doesn't require bidirectional attention. Each masked position needs to see all observed (unmasked) tokens, but it doesn't need to see other masked positions bidirectionally.

The solution is Topological Reordering: physically move observed tokens to the prefix while preserving their logical positions through RoPE embeddings.

Original sequence:  [The] [cat] [MASK] [on] [MASK] [mat]
Positions:           0     1      2     3     4      5

After reordering:   [The] [cat] [on] [mat] [MASK] [MASK]
Physical index:       0     1    2    3      4      5
Logical position:     0     1    3    5      2      4    ← Preserved via RoPE!

The critical detail: each token's logical position is preserved through RoPE (Rotary Position Embeddings). RoPE applies position-dependent rotations to queries and keys:

q_{\text{rotated}} = q \cdot e^{i \cdot \text{pos} \cdot \theta}, \qquad k_{\text{rotated}} = k \cdot e^{i \cdot \text{pos} \cdot \theta}

The attention score between a query at position p_q and a key at position p_k depends only on (p_q - p_k), the relative position. By supplying the logical position to RoPE regardless of physical position, we preserve the correct relative position relationships.
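
A quick numeric check of this property (a standalone sketch, independent of the port): the dot product between a RoPE-rotated query and key stays the same when both positions are shifted by the same amount.

import mlx.core as mx
from mlx.core import fast

D = 128
q = mx.random.normal((1, 1, 1, D))  # (batch, heads, seq=1, head_dim)
k = mx.random.normal((1, 1, 1, D))

def score(q_pos: int, k_pos: int) -> float:
    """Unnormalized attention score between q at q_pos and k at k_pos after RoPE."""
    q_r = fast.rope(q, D, traditional=False, base=1e6, scale=1.0, offset=q_pos)
    k_r = fast.rope(k, D, traditional=False, base=1e6, scale=1.0, offset=k_pos)
    return mx.sum(q_r * k_r).item()

# Same relative distance (q_pos - k_pos = 2) at different absolute positions:
print(score(3, 1), score(103, 101))  # should agree up to floating-point error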

Under standard causal attention (physical index i attends to indices 0..i-1), position 4 (physical index 4, logical position 2) can attend to:

  • Physical indices 0-3
  • Which contain ALL observed tokens (logical positions 0, 1, 3, 5)

The mask tokens see the full observed context. Causal attention is preserved. KV caching works.
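
In code, the reordering itself is just a stable partition of window indices plus a record of each token's logical position. A minimal sketch (MASK_ID is a placeholder for the tokenizer's mask token id):

from typing import List, Tuple

MASK_ID = -1  # placeholder mask token id

def topological_reorder(tokens: List[int], start_pos: int) -> Tuple[List[int], List[int]]:
    """Move observed tokens to the front and masks to the back,
    recording each token's original (logical) position for RoPE."""
    observed = [(start_pos + i, t) for i, t in enumerate(tokens) if t != MASK_ID]
    masked = [(start_pos + i, t) for i, t in enumerate(tokens) if t == MASK_ID]
    ordered = observed + masked
    reordered = [t for _, t in ordered]   # physical order fed to the model
    positions = [p for p, _ in ordered]   # logical positions fed to RoPE
    return reordered, positions

# [The][cat][MASK][on][MASK][mat] at logical positions 0..5
print(topological_reorder([10, 11, MASK_ID, 12, MASK_ID, 13], start_pos=0))
# → ([10, 11, 12, 13, -1, -1], [0, 1, 3, 5, 2, 4])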

The Prefix Cacheability Metric

The paper introduces a useful metric: prefix cacheability, p_cache:

p_{\text{cache}} = \frac{N_{\text{gen}}}{N_{\text{fwd}}} \in (0, 1]

Where:

  • N_gen = tokens that become part of the final output
  • N_fwd = total tokens processed across all forward passes (including recomputation)

If p_cache = 1.0, every processed token becomes output—perfect efficiency (standard AR). If p_cache = 0.5, half the tokens are recomputed—50% wasted work.

WeDLM achieves high p_cache by committing tokens left-to-right into a growing prefix, ensuring most predictions become immediately cache-valid.
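
Tracking the metric during generation is straightforward; a small sketch, assuming the decode loop maintains the two counters:

def prefix_cacheability(num_generated: int, num_forwarded: int) -> float:
    """p_cache = committed output tokens / total tokens pushed through the model."""
    return num_generated / max(num_forwarded, 1)

# e.g. 256 committed tokens produced by forward passes over 320 window tokens
print(prefix_cacheability(256, 320))  # 0.8 → 20% of the work was recomputation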


Part II: The MLX Port

Day 1: Project Setup (December 29, 2025)

The first task was establishing the model architecture. WeDLM-8B is based on Qwen3-8B, a decoder-only transformer with these specifications:

Parameter           | Value     | Notes
--------------------+-----------+-----------------------------
Hidden size         | 4096      | Dimension of hidden states
Num layers          | 36        | Transformer blocks
Num attention heads | 32        | Query heads
Num KV heads        | 8         | Key/Value heads (GQA)
Head dimension      | 128       | 4096 / 32
Intermediate size   | 12288     | FFN hidden dimension
Vocab size          | 151,936   | Qwen's BPE vocabulary
RoPE θ              | 1,000,000 | Base frequency for RoPE
RMS norm ε          | 1e-6      | LayerNorm epsilon
QK norm             | True      | RMSNorm on Q and K

Several architectural choices are notable:

Grouped Query Attention (GQA): 8 KV heads shared across 32 query heads—a 4x reduction in KV cache memory. This is critical for inference efficiency.

QK Normalization: Qwen3 applies RMSNorm to queries and keys before attention. This isn't mentioned in most transformer tutorials but is essential for training stability at scale.

High RoPE θ: The 1M base frequency (vs. 10K in original RoPE) enables longer context windows through slower position frequency decay.
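
For reference, the code below assumes a ModelConfig dataclass mirroring this table (field names are this port's, not an official API):

from dataclasses import dataclass

@dataclass
class ModelConfig:
    hidden_size: int = 4096
    num_hidden_layers: int = 36
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    head_dim: int = 128
    intermediate_size: int = 12288
    vocab_size: int = 151936
    rope_theta: float = 1_000_000.0
    rms_norm_eps: float = 1e-6
    qk_norm: bool = True
    tie_word_embeddings: bool = False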

The port was structured as:

mlx_wedlm/
├── model.py      # WeDLMModel, KVCache variants, Attention, MLP
├── decoder.py    # WeDLMDecoder, sampling, topological reordering
├── generate.py   # High-level generation API, streaming
└── __init__.py

Initial Model Definition

The core model follows standard transformer structure but with MLX-specific optimizations:

class WeDLMModel(nn.Module):
    """WeDLM model optimized for MLX with fast operations."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [TransformerBlock(config, i) for i in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

    def __call__(
        self,
        input_ids: mx.array,
        offset: int = 0,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        # Embedding
        x = self.embed_tokens(input_ids)

        # Forward through layers
        for layer in self.layers:
            x = layer(x, offset, mask, cache)

        # Final norm and LM head
        x = self.norm(x)
        return self.lm_head(x)

The Attention Module

The attention implementation required careful consideration of MLX's API. Unlike PyTorch where you might manually implement attention, MLX provides fast.scaled_dot_product_attention with native GQA support:

class Attention(nn.Module):
    """Optimized attention using mx.fast.scaled_dot_product_attention."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5
        self.rope_theta = config.rope_theta

        # Projections - no bias per Qwen3 architecture
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # QK normalization - critical for Qwen3!
        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        offset: int = 0,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
        layer_idx: int = 0,
    ) -> mx.array:
        B, L, _ = x.shape

        # Project to Q, K, V
        q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim)

        # Apply QK norm BEFORE RoPE - this order matters!
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Transpose to (B, num_heads, L, head_dim) for RoPE and SDPA
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # Apply RoPE using fast.rope
        q = fast.rope(q, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=offset)
        k = fast.rope(k, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=offset)

        # Update KV cache if provided
        if cache is not None:
            k, v = cache.update(layer_idx, k, v)

        # SDPA with native GQA support - no need to repeat K/V!
        output = fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)

        # Reshape back and project
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)

A crucial detail: MLX's fast.scaled_dot_product_attention handles GQA automatically. In PyTorch, you'd need to manually repeat K/V heads to match Q heads. MLX does this internally, saving memory and compute.
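
For contrast, this is roughly the manual K/V head repetition a framework without native GQA support would need before calling attention (a hypothetical helper, not used in this port):

import mlx.core as mx

def repeat_kv(x: mx.array, n_rep: int) -> mx.array:
    """Tile KV heads to match query heads: (B, H_kv, L, D) → (B, H_kv * n_rep, L, D)."""
    if n_rep == 1:
        return x
    B, H_kv, L, D = x.shape
    x = mx.expand_dims(x, axis=2)                   # (B, H_kv, 1, L, D)
    x = mx.broadcast_to(x, (B, H_kv, n_rep, L, D))  # materialize the repeat
    return x.reshape(B, H_kv * n_rep, L, D)

# With fast.scaled_dot_product_attention this step is unnecessary:
# K and V can stay at 8 heads while Q has 32.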

The MLP Block

The MLP uses SwiGLU activation, which has become standard in modern LLMs:

class MLP(nn.Module):
    """MLP with SwiGLU activation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        # SwiGLU: down(silu(gate(x)) * up(x))
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

SwiGLU replaces the traditional ReLU-based MLP with a gated mechanism. The gate_proj and up_proj are both (4096 → 12288), and down_proj is (12288 → 4096). The element-wise multiplication creates an implicit gating mechanism.


Day 1 (continued): KV Cache Design

The KV cache implementation is where WeDLM diverges most from standard AR inference. Three variants handle different use cases:

Variant 1: Simple Concatenation Cache

The simplest approach—just concatenate new K/V to existing cache:

class KVCache:
    """Simple KV Cache with concatenation."""

    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        self.keys: List[Optional[mx.array]] = [None] * num_layers
        self.values: List[Optional[mx.array]] = [None] * num_layers
        self._offset = 0

    def update(self, layer_idx: int, key: mx.array, value: mx.array) -> Tuple[mx.array, mx.array]:
        """Update cache for a layer and return full K, V."""
        if self.keys[layer_idx] is None:
            self.keys[layer_idx] = key
            self.values[layer_idx] = value
        else:
            self.keys[layer_idx] = mx.concatenate([self.keys[layer_idx], key], axis=2)
            self.values[layer_idx] = mx.concatenate([self.values[layer_idx], value], axis=2)
        return self.keys[layer_idx], self.values[layer_idx]

    def get_seq_len(self) -> int:
        """Get current cached sequence length."""
        if self.keys[0] is None:
            return 0
        return self.keys[0].shape[2]

This is theoretically inefficient—concatenation copies the entire array. But MLX's lazy evaluation means the copy is often fused with downstream operations. Benchmarking against pre-allocated buffers showed <5% difference for typical generation lengths (<2K tokens).

The cache stores tensors of shape (B, num_kv_heads, seq_len, head_dim). For WeDLM-8B with batch size 1:

  • Per layer: 1 × 8 × seq_len × 128 × 2 bytes = 2048 × seq_len bytes (BF16)
  • Total (36 layers, K+V): 36 × 2 × 2048 × seq_len = 147,456 × seq_len bytes

At 8K sequence length: 147,456 × 8192 ≈ 1.2 GB just for KV cache.
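
The same arithmetic as a small reusable estimator (BF16 and batch size 1 assumed):

def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """B × H_kv × L × D × bytes, times 2 for K+V, times the number of layers."""
    return batch * num_kv_heads * seq_len * head_dim * bytes_per_elem * 2 * num_layers

# WeDLM-8B at 8K context in BF16
print(kv_cache_bytes(36, 8, 128, 8192) / 1e9)  # ≈ 1.21 GB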

Variant 2: Rotating Cache

For sequences exceeding memory limits, I implemented a rotating cache that keeps the first k tokens (typically system prompt) and rotates the rest:

class RotatingKVCache:
    """Rotating KV Cache with fixed-size circular buffer."""

    def __init__(self, num_layers: int, max_size: int = 4096, keep: int = 4):
        self.max_size = max_size
        self.keep = keep  # Always keep first 'keep' tokens (system prompt)
        self.keys: List[Optional[mx.array]] = [None] * num_layers
        self.values: List[Optional[mx.array]] = [None] * num_layers

    def update(self, layer_idx: int, key: mx.array, value: mx.array):
        prev_k = self.keys[layer_idx]
        n_new = key.shape[2]

        if prev_k is None:
            self.keys[layer_idx] = key
            self.values[layer_idx] = value
        else:
            prev_len = prev_k.shape[2]
            new_len = prev_len + n_new

            if new_len <= self.max_size:
                # Under max size - just concatenate
                self.keys[layer_idx] = mx.concatenate([prev_k, key], axis=2)
                self.values[layer_idx] = mx.concatenate([self.values[layer_idx], value], axis=2)
            else:
                # Over max size - keep first tokens, trim oldest from middle
                trim = min(n_new, prev_len - self.keep)
                kept_k = prev_k[:, :, :self.keep, :]
                rest_k = prev_k[:, :, self.keep + trim:, :]
                self.keys[layer_idx] = mx.concatenate([kept_k, rest_k, key], axis=2)
                kept_v = self.values[layer_idx][:, :, :self.keep, :]
                rest_v = self.values[layer_idx][:, :, self.keep + trim:, :]
                self.values[layer_idx] = mx.concatenate([kept_v, rest_v, value], axis=2)

        return self.keys[layer_idx], self.values[layer_idx]

The "keep first N" strategy ensures system prompts and initial context are preserved. This is important for chat applications where the system prompt defines behavior.

Variant 3: Quantized Cache

For extreme memory constraints, 8-bit KV cache quantization:

class QuantizedKVCache:
    """KV Cache with 8-bit quantization."""

    def __init__(self, num_layers: int, bits: int = 8, group_size: int = 64):
        self.bits = bits
        self.group_size = group_size
        self.keys: List[Optional[mx.array]] = [None] * num_layers
        self.values: List[Optional[mx.array]] = [None] * num_layers
        self.key_scales: List[Optional[mx.array]] = [None] * num_layers
        self.value_scales: List[Optional[mx.array]] = [None] * num_layers

    def _quantize(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """Per-group absmax quantization."""
        shape = x.shape
        # Reshape to groups
        x_flat = x.reshape(*shape[:-1], -1, self.group_size)
        # Compute scale per group
        x_max = mx.max(mx.abs(x_flat), axis=-1, keepdims=True)
        scale = x_max / 127.0  # For int8 range [-127, 127]
        scale = mx.where(scale == 0, mx.ones_like(scale), scale)  # Avoid div by zero
        # Quantize
        x_q = mx.round(x_flat / scale).astype(mx.int8)
        return x_q, scale.squeeze(-1)

    def _dequantize(self, x_q: mx.array, scale: mx.array) -> mx.array:
        """Dequantize back to float."""
        x = x_q.astype(mx.bfloat16) * mx.expand_dims(scale, axis=-1)
        return x.reshape(*x.shape[:-2], -1)

    def update(self, layer_idx: int, key: mx.array, value: mx.array):
        # Quantize new K/V
        k_q, k_s = self._quantize(key)
        v_q, v_s = self._quantize(value)

        # Concatenate quantized values
        if self.keys[layer_idx] is None:
            self.keys[layer_idx] = k_q
            self.key_scales[layer_idx] = k_s
            self.values[layer_idx] = v_q
            self.value_scales[layer_idx] = v_s
        else:
            self.keys[layer_idx] = mx.concatenate([self.keys[layer_idx], k_q], axis=2)
            self.key_scales[layer_idx] = mx.concatenate([self.key_scales[layer_idx], k_s], axis=2)
            self.values[layer_idx] = mx.concatenate([self.values[layer_idx], v_q], axis=2)
            self.value_scales[layer_idx] = mx.concatenate([self.value_scales[layer_idx], v_s], axis=2)

        # Return dequantized for attention (attention needs float)
        k_full = self._dequantize(self.keys[layer_idx], self.key_scales[layer_idx])
        v_full = self._dequantize(self.values[layer_idx], self.value_scales[layer_idx])
        return k_full, v_full

Trade-off: 2x memory savings (int8 vs BF16), ~3% quality degradation on GSM8K. The quantization error accumulates over long sequences, so this is best for shorter generations.


Part III: The RoPE Position Challenge

Day 2: Non-Sequential Positions (December 30, 2025)

This was the first major technical hurdle in translating the paper's "topological reordering" idea into an efficient RoPE implementation on MLX.

The Standard RoPE Assumption

RoPE is typically applied with a simple offset. Token 0 gets position 0, token 1 gets position 1, etc. MLX's fast.rope API reflects this:

# Standard RoPE application - sequential positions
q = fast.rope(q, head_dim, offset=cache_offset)  # Positions: offset, offset+1, offset+2, ...
k = fast.rope(k, head_dim, offset=cache_offset)

The offset parameter says "start counting from this position." Every subsequent position in the sequence is offset + 0, offset + 1, offset + 2, etc.

WeDLM's topological reordering breaks this assumption completely. After reordering:

Physical index:    0    1    2    3    4    5
Token:            [The][cat][on] [mat][M]  [M]
Logical position:  0    1    3    5    2    4    ← NOT sequential!

The mask at physical index 4 needs position 2. The mask at physical index 5 needs position 4. There's no single offset that produces [2, 4].

First Attempt: Python Loop (The Wrong Way)

The obvious initial approach—loop over positions:

def apply_rope_positions_slow(x, positions, head_dim, base):
    """WRONG: Per-token loop destroys performance."""
    B, H, L, D = x.shape
    result = mx.zeros_like(x)
    for i, pos in enumerate(positions):
        # Apply RoPE to each token individually
        token = x[:, :, i:i+1, :]  # Shape: (B, H, 1, D)
        rotated = fast.rope(token, head_dim, offset=int(pos), base=base)
        result[:, :, i, :] = rotated[:, :, 0, :]
    return result

Performance: 8 tokens/second.

To understand why this is catastrophic, you need to understand MLX's execution model. Each fast.rope call adds a node to MLX's computation graph. Python loops mean Python interpreter overhead between each node. When MLX evaluates the graph, each small operation might become a separate GPU kernel launch.

The overhead isn't the Python loop per se—it's that we're preventing MLX from fusing operations. Instead of one large RoPE kernel processing 16 positions, we have 16 tiny kernels with GPU launch overhead between each.

The Insight: Reshape to Fake Batching

The fast.rope API allows offset to be a 1D array with one offset per batch element, which enables per-position RoPE by reshaping the sequence dimension into the batch dimension.

# From MLX docs:
# offset: The position offset. Can be a scalar or a 1D array of shape (B,)

If offset can be per-batch, and I want per-position offsets... what if each position becomes its own batch element?

def apply_rope_positions(x: mx.array, positions: mx.array,
                         head_dim: int, base: float = 10000.0) -> mx.array:
    """Apply RoPE with explicit position indices - THE FAST WAY."""
    B, H, L, D = x.shape

    # Original shape: (B, H, L, D) where:
    #   B = batch size (usually 1)
    #   H = num heads
    #   L = sequence length (positions to process)
    #   D = head dimension

    # Step 1: Reshape so each position is its own "batch"
    # (B, H, L, D) → (B, L, H, D) → (B*L, H, 1, D)
    x_reshaped = x.transpose(0, 2, 1, 3)  # (B, L, H, D)
    x_reshaped = x_reshaped.reshape(B * L, H, 1, D)  # (B*L, H, 1, D)

    # Step 2: Expand positions for the fake batch dimension
    # positions: (L,) → (B*L,) by tiling for each original batch
    if B > 1:
        positions = mx.tile(positions, (B,))  # Now shape (B*L,)

    # Step 3: Apply RoPE with per-"batch" offsets
    # Each of the B*L "batch" elements gets its own offset from positions
    # The sequence dimension is 1, so fast.rope applies that offset to the single position
    result = fast.rope(
        x_reshaped,  # (B*L, H, 1, D)
        head_dim,
        traditional=False,
        base=base,
        scale=1.0,
        offset=positions  # (B*L,) - one offset per "batch" element
    )

    # Step 4: Reshape back to original layout
    # (B*L, H, 1, D) → (B, L, H, D) → (B, H, L, D)
    result = result.reshape(B, L, H, D)  # (B, L, H, D)
    result = result.transpose(0, 2, 1, 3)  # (B, H, L, D)

    return result

The trick: by reshaping so each position is a separate "batch" element with sequence length 1, the per-batch offset becomes a per-position offset.

Tracing through a concrete example:

Input: x of shape (1, 32, 4, 128)  # batch=1, 32 heads, 4 positions, dim=128
Positions: [0, 1, 3, 5]  # Non-sequential logical positions

Step 1 - Transpose: (1, 4, 32, 128)
Step 2 - Reshape: (4, 32, 1, 128)  # 4 "batches", 32 heads, 1 position each
Step 3 - Apply RoPE with offset=[0, 1, 3, 5]
         - "Batch" 0 gets position 0
         - "Batch" 1 gets position 1
         - "Batch" 2 gets position 3
         - "Batch" 3 gets position 5
Step 4 - Reshape back: (1, 32, 4, 128)

Performance: 35 tokens/second. A 4.4x improvement from changing the reshape pattern.
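
A useful sanity check (a standalone test sketch): when the positions happen to be sequential, the reshaped path must agree with plain fast.rope using a scalar offset.

import mlx.core as mx
from mlx.core import fast

B, H, L, D = 1, 32, 4, 128
x = mx.random.normal((B, H, L, D))

# Sequential positions starting at 10: both paths should produce matching results
ref = fast.rope(x, D, traditional=False, base=1e6, scale=1.0, offset=10)
out = apply_rope_positions(x, mx.arange(10, 10 + L), D, base=1e6)

print(mx.max(mx.abs(ref - out)))  # expected ≈ 0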

Why This Works

MLX's fast.rope is implemented as a single Metal kernel that processes all batch elements in parallel. By reshaping our positions into the batch dimension, we're telling MLX "run this kernel with B*L parallel work items, each with its own offset."

The alternative (Python loop) says "run L separate kernels sequentially." Even if each kernel is fast, the sequential launches and Python overhead kill performance.

This pattern—reshaping to exploit batch parallelism—is a recurring theme in MLX optimization.


Part IV: The Synchronization Trap

Understanding MLX's Lazy Evaluation

MLX's lazy evaluation is the single most impactful optimization target in this project.

How MLX Executes Operations

MLX uses lazy evaluation: operations don't execute immediately. Instead, they build a computation graph that's optimized and executed when you actually need the values.

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])

c = mx.add(a, b)      # No computation - graph node created
d = mx.multiply(c, 2)  # No computation - another graph node
e = mx.sum(d)          # No computation - another graph node

print("Graph built, nothing computed yet")

mx.eval(e)  # NOW computation happens
# MLX sees: sum(multiply(add(a, b), 2))
# Optimizes: potentially fuses operations, optimizes memory layout
# Executes: single optimized GPU dispatch

This is powerful because MLX can:

  1. Fuse operations: Combine add + multiply into a single kernel
  2. Eliminate intermediates: c and d might never be materialized in memory
  3. Optimize memory access: Read inputs once, write output once

The Trap: Implicit Synchronization

The problem: several Python operations require actual values, forcing MLX to synchronize:

# Each of these forces synchronization:
x.item()      # Returns Python scalar
x.tolist()    # Returns Python list
len(x)        # Needs to know actual size
if x > 0:     # Boolean requires evaluation
for i in x:   # Iteration requires values

An example problematic decode loop:

# ANTI-PATTERN: 32 synchronization points per decode step
def decode_step_slow(logits, window_size, threshold):
    confirmed = []
    for i in range(window_size):  # window_size = 16
        logits_i = logits[0, i]

        # Sync #1: .item() forces evaluation
        token_id = mx.argmax(logits_i).item()

        # Sync #2: Another .item()
        probs = mx.softmax(logits_i, axis=-1)
        entropy = -mx.sum(probs * mx.log(probs + 1e-10)).item()

        if entropy < threshold:
            confirmed.append((i, token_id))

    return confirmed

Each .item() call:

  1. Flushes the pending computation graph
  2. Dispatches work to the GPU
  3. Waits for GPU to complete
  4. Transfers scalar result to CPU

With 16 window positions × 2 .item() calls = 32 GPU round-trips per decode step.

If each round-trip takes even 0.5ms (and they often take more), that's 16ms overhead per step—completely dwarfing the actual compute time.

The Fix: Batched Operations with Single Sync

The solution is to batch all operations and synchronize once:

def sample_tokens_batch(logits: mx.array, temperature: float = 0.0) -> Tuple[mx.array, mx.array]:
    """Sample all tokens in batch - single sync at the end."""
    # logits shape: (window_size, vocab_size)

    # All operations stay on GPU, building computation graph
    if temperature == 0:
        # Greedy: argmax over vocabulary dimension
        token_ids = mx.argmax(logits, axis=-1)  # Shape: (W,)
    else:
        # Sampling: scale logits by temperature, sample categorically
        scaled = logits / temperature
        token_ids = mx.random.categorical(scaled, axis=-1)  # Shape: (W,)

    # Compute all entropies at once
    probs = mx.softmax(logits, axis=-1)  # Shape: (W, V)
    log_probs = mx.log(probs + 1e-10)    # Shape: (W, V)
    entropies = -mx.sum(probs * log_probs, axis=-1)  # Shape: (W,)

    return token_ids, entropies

# Usage in decode loop:
def decode_step_fast(logits, window_size, threshold):
    # Get mask position logits
    mask_logits = logits[0]  # Shape: (window_size, vocab_size)

    # Batch sample - no sync yet
    token_ids, entropies = sample_tokens_batch(mask_logits, temperature=0.0)

    # SINGLE sync point - entire graph is fused and executed
    mx.eval(token_ids, entropies)

    # NOW convert to Python (arrays are materialized, this is just memory copy)
    token_ids_list = token_ids.tolist()  # Fast - data already on CPU
    entropies_list = entropies.tolist()  # Fast - data already on CPU

    # Python logic with actual values
    confirmed = []
    for i in range(window_size):
        if entropies_list[i] < threshold:
            confirmed.append((i, token_ids_list[i]))

    return confirmed

Performance impact: 20 tok/s → 35 tok/s (75% improvement).

The key insight: MLX's graph optimization can fuse the softmax, log, multiply, and sum into a single efficient kernel. By synchronizing once instead of 32 times, we let MLX do what it's designed to do.

The General Rule

MLX Performance Principles:

1. Build graphs freely - it's just Python object creation
2. Minimize mx.eval() calls - each one dispatches work and waits
3. NEVER use .item() in loops - it's an implicit mx.eval()
4. Batch operations over dimensions instead of looping
5. .tolist() after eval is fine - data is already materialized

Sync points kill performance. Batch everything.

A Debugging Pattern

When debugging MLX performance issues, instrument sync points:

import time

def timed_eval(*arrays):
    """Evaluate arrays and print timing."""
    start = time.perf_counter()
    mx.eval(*arrays)
    elapsed = (time.perf_counter() - start) * 1000
    print(f"eval took {elapsed:.2f}ms for {len(arrays)} arrays")
    return arrays

# Use in code:
token_ids, entropies = sample_tokens_batch(logits)
timed_eval(token_ids, entropies)

If you see many eval took Xms prints per generation step, you have too many sync points.


Part V: The Diffusion Decoding Algorithm

Day 2 (continued): Implementing Streaming Parallel Decoding

With the core model and RoPE working, the next challenge was implementing WeDLM's actual decoding algorithm. This is where the paper's ideas become concrete code.

The Decoding State Machine

WeDLM maintains a sliding window of mask tokens. At each step:

  1. Forward pass on the window (with topological reordering)
  2. Sample tokens for each mask position
  3. Compute entropy (confidence) for each position
  4. Commit confident positions to the output
  5. Slide the window, adding new masks

The state definition:

@dataclass
class DecoderState:
    """State for diffusion decoding."""
    # Confirmed tokens (after prompt, before window)
    confirmed_tokens: List[int]
    # Current window tokens (being decoded)
    window_tokens: List[int]
    # Which positions in window are still masks
    window_mask_flags: List[bool]
    # Total generated count
    num_generated: int
    # Finished flag
    is_finished: bool

@dataclass
class SamplingParams:
    temperature: float = 0.0        # 0 = greedy
    top_p: float = 1.0              # Nucleus sampling threshold
    max_tokens: int = 512           # Maximum tokens to generate
    entropy_threshold: float = 0.3  # Below this = "confident"
    pos_penalty_factor: float = 0.01  # Penalty for later positions

Window Visualization

A concrete example illustrates the algorithm. Starting with prompt "Hello" and window size 4:

Initial State:
  Cache: [Hello]  (prefilled)
  Window: [MASK][MASK][MASK][MASK]
  Flags:  [True][True][True][True]

Step 1 - Forward pass on window:
  Input to model: [MASK][MASK][MASK][MASK] with positions [5, 6, 7, 8]
  Logits: predictions for all 4 positions

  Sampling results:
    Position 0: token="world", entropy=0.1  ← Confident (< 0.3)
    Position 1: token=",",     entropy=0.2  ← Confident
    Position 2: token="how",   entropy=0.8  ← Not confident
    Position 3: token="are",   entropy=0.9  ← Not confident

Step 1 - Update state:
  Window: [world][,][MASK][MASK]
  Flags:  [False][False][True][True]

Step 1 - Commit leading confirmed:
  Positions 0 and 1 are consecutive from start → commit both
  confirmed_tokens: [world, ,]
  Cache: [Hello, world, ,]

  Slide window:
  Window: [MASK][MASK][MASK][MASK]  ← Remaining masks + new masks
  Flags:  [True][True][True][True]

Step 2 - Forward pass:
  The window now covers logical positions [7, 8, 9, 10].
  Positions 7 and 8 were tentatively predicted as "how" and "are" in
  Step 1, but they were never committed, so they remain masks and are
  re-predicted, this time with "world" and "," already in the cache.

  Topological reorder for attention:
    Observed: none in window (all masks), so no reordering is needed
    Input: [MASK][MASK][MASK][MASK] with positions [7, 8, 9, 10]

  New sampling might produce:
    Position 0: token="how",  entropy=0.15 ← Now confident!
    Position 1: token="you",  entropy=0.1  ← Changed and confident
    Position 2: token="?",    entropy=0.2  ← Confident
    Position 3: token="I",    entropy=0.6  ← Not confident

  Window: [how][you][?][I]
  Flags:  [False][False][False][True]

  Commit leading: "how", "you", "?"
  confirmed_tokens: [world, ,, how, you, ?]

This illustrates two key WeDLM behaviors:

  1. Iterative refinement: "are" became "you" on the second pass because the model had more context
  2. Left-to-right commitment: Only leading consecutive confirmed tokens are committed, preserving cache validity

The Core Decode Step

Here's the actual implementation:

def decode_step(
    self,
    state: DecoderState,
    params: SamplingParams,
    cache: KVCache,
    prompt_len: int,
) -> Tuple[DecoderState, List[int]]:
    """Execute one decoding step with topological reordering."""

    if state.is_finished:
        return state, []

    # Separate filled and mask positions in current window
    non_mask_indices = [i for i, flag in enumerate(state.window_mask_flags) if not flag]
    mask_indices = [i for i, flag in enumerate(state.window_mask_flags) if flag]

    cache_len = cache.get_seq_len()  # Prompt + committed tokens

    # Two code paths depending on whether we have filled positions
    if len(non_mask_indices) == 0:
        # Fast path: all positions are masks, no reordering needed
        window_ids = mx.array([state.window_tokens])
        logits = self.model.decode_window(window_ids, cache_len, cache)
        mask_logits = logits[0]  # Shape: (window_size, vocab_size)
    else:
        # Topological reordering: filled positions first, masks last
        order = non_mask_indices + mask_indices
        ordered_tokens = [state.window_tokens[i] for i in order]
        ordered_positions = [cache_len + i for i in order]  # Logical positions!

        window_ids = mx.array([ordered_tokens])
        positions = mx.array(ordered_positions)

        # Forward with per-token positions (uses apply_rope_positions)
        logits = self.model.decode_window_reordered(window_ids, positions, cache)

        # Extract logits for mask positions only
        num_non_mask = len(non_mask_indices)
        mask_logits = logits[0, num_non_mask:]  # Only predict masks

    # Batched sampling - single sync point
    token_ids, entropies = self.sample_tokens_batch(
        mask_logits, params.temperature, params.top_p
    )
    mx.eval(token_ids, entropies)

    # Convert to Python for threshold logic
    token_ids_list = token_ids.tolist()
    entropies_list = entropies.tolist()

    # Apply position penalty to encourage left-to-right commitment
    adjusted_entropies = []
    for k, orig_idx in enumerate(mask_indices):
        token_id = token_ids_list[k]
        entropy = entropies_list[k]
        # Penalty increases with position in window
        adjusted = entropy + orig_idx * params.pos_penalty_factor
        adjusted_entropies.append((orig_idx, adjusted, token_id))

    # Update window with confident predictions
    new_window_tokens = list(state.window_tokens)
    new_mask_flags = list(state.window_mask_flags)

    positions_to_fill = []
    for orig_idx, adj_entropy, token_id in adjusted_entropies:
        if adj_entropy < params.entropy_threshold:
            positions_to_fill.append((orig_idx, token_id))

    # Always fill at least one position (to guarantee progress)
    if not positions_to_fill and adjusted_entropies:
        best = min(adjusted_entropies, key=lambda x: x[1])
        positions_to_fill = [(best[0], best[2])]

    for pos, token_id in positions_to_fill:
        new_window_tokens[pos] = token_id
        new_mask_flags[pos] = False

    # Count leading confirmed positions (consecutive non-masks from start)
    num_front_confirmed = 0
    for i in range(len(new_window_tokens)):
        if not new_mask_flags[i]:
            num_front_confirmed += 1
        else:
            break

    # Commit leading confirmed tokens
    new_confirmed = list(state.confirmed_tokens)
    committed_tokens = []

    if num_front_confirmed > 0:
        committed_tokens = new_window_tokens[:num_front_confirmed]
        new_confirmed.extend(committed_tokens)

        # Update cache with committed tokens' K/V
        self.model.extend_cache(mx.array([committed_tokens]), cache)

        # Slide window: remove committed, add new masks
        new_window_tokens = new_window_tokens[num_front_confirmed:] + \
                            [self.mask_token_id] * num_front_confirmed
        new_mask_flags = new_mask_flags[num_front_confirmed:] + \
                         [True] * num_front_confirmed

    # Check termination
    new_num_generated = state.num_generated + len(committed_tokens)
    is_finished = new_num_generated >= params.max_tokens

    new_state = DecoderState(
        confirmed_tokens=new_confirmed,
        window_tokens=new_window_tokens,
        window_mask_flags=new_mask_flags,
        num_generated=new_num_generated,
        is_finished=is_finished,
    )

    return new_state, committed_tokens
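
The decode step is driven by a simple outer loop. A minimal sketch of how the pieces fit together; self.window_size and self.eos_token_id are assumed decoder attributes not shown above:

def generate(self, prompt: str, params: SamplingParams) -> str:
    """Prefill the prompt, then run decode_step until generation finishes."""
    prompt_ids = self.tokenizer.encode(prompt)
    _, cache = self.model.prefill(mx.array([prompt_ids]))

    state = DecoderState(
        confirmed_tokens=[],
        window_tokens=[self.mask_token_id] * self.window_size,
        window_mask_flags=[True] * self.window_size,
        num_generated=0,
        is_finished=False,
    )

    while not state.is_finished:
        state, committed = self.decode_step(state, params, cache, len(prompt_ids))
        if self.eos_token_id in committed:  # stop once EOS is committed
            state.is_finished = True

    return self.tokenizer.decode(state.confirmed_tokens)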

The Position Penalty Explained

WeDLM uses distance-penalized selection to encourage left-to-right commitment:

\text{adjusted\_entropy}_i = \text{entropy}_i + i \cdot \lambda

Where λ is the position penalty factor (default 0.01).

Why does this matter? Consider a window where:

  • Position 0: entropy 0.25
  • Position 2: entropy 0.20
  • Position 5: entropy 0.15

Without penalty, position 5 would be most confident. But committing position 5 before 0, 1, 2, 3, 4 means those positions' K/V can't be cached yet—they depend on the uncommitted prefix.

With penalty (λ = 0.01):

  • Position 0: 0.25 + 0×0.01 = 0.250
  • Position 2: 0.20 + 2×0.01 = 0.220
  • Position 5: 0.15 + 5×0.01 = 0.200

Position 5 is still lowest! But now consider:

  • Position 0: entropy 0.20
  • Position 5: entropy 0.15

Without penalty: position 5 wins (0.15 < 0.20). With penalty: both adjust to 0.20, and the tie goes to the earlier position 0.

The penalty biases toward earlier positions when confidence is similar, maximizing p_cache.

The decode_window_reordered Method

The model needs a special method for reordered windows that applies per-token positions:

def decode_window_reordered(
    self,
    window_ids: mx.array,
    positions: mx.array,  # 1D array of logical positions
    cache: KVCache,
) -> mx.array:
    """Decode with explicit per-token positions (for topological reordering)."""
    B, W = window_ids.shape
    cache_len = cache.get_seq_len()

    # Create causal mask for window attending to cache + self
    total_len = cache_len + W
    mask = mx.triu(
        mx.full((W, total_len), float('-inf'), dtype=mx.bfloat16),
        k=cache_len + 1  # Window tokens attend to cache + prior window tokens
    )
    mask = mx.expand_dims(mask, axis=(0, 1))

    # Embedding
    x = self.embed_tokens(window_ids)

    # Forward through layers with per-token RoPE
    for layer in self.layers:
        layer_idx = layer.layer_idx
        cached_k, cached_v = cache.get_layer_kv(layer_idx)

        residual = x
        x_norm = layer.input_layernorm(x)

        attn = layer.self_attn
        q = attn.q_proj(x_norm).reshape(B, W, attn.num_heads, attn.head_dim)
        k = attn.k_proj(x_norm).reshape(B, W, attn.num_kv_heads, attn.head_dim)
        v = attn.v_proj(x_norm).reshape(B, W, attn.num_kv_heads, attn.head_dim)

        if attn.qk_norm:
            q = attn.q_norm(q)
            k = attn.k_norm(k)

        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # KEY DIFFERENCE: Apply RoPE with explicit positions instead of offset
        q = apply_rope_positions(q, positions, attn.head_dim, attn.rope_theta)
        k = apply_rope_positions(k, positions, attn.head_dim, attn.rope_theta)

        # Concatenate with cache
        if cached_k is not None:
            k_full = mx.concatenate([cached_k, k], axis=2)
            v_full = mx.concatenate([cached_v, v], axis=2)
        else:
            k_full, v_full = k, v

        attn_out = fast.scaled_dot_product_attention(q, k_full, v_full, scale=attn.scale, mask=mask)
        attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, W, -1)
        x = residual + attn.o_proj(attn_out)

        residual = x
        x = residual + layer.mlp(layer.post_attention_layernorm(x))

    return self.lm_head(self.norm(x))

The Prefill Phase

Before decode starts, we need to process the prompt and initialize the cache:

def prefill(
    self,
    input_ids: mx.array,
    cache: Optional[KVCache] = None,
    chunk_size: int = 512,
) -> Tuple[mx.array, KVCache]:
    """Process prompt and initialize KV cache."""
    B, L = input_ids.shape
    if cache is None:
        cache = self.create_cache(batch_size=B)

    # Short prompts: process all at once
    if L <= chunk_size:
        logits = self(input_ids, offset=0, mask="causal", cache=cache)
        cache.finalize_update(L)
        return logits, cache

    # Long prompts: chunked prefill to limit memory
    offset = 0
    for start in range(0, L, chunk_size):
        end = min(start + chunk_size, L)
        chunk = input_ids[:, start:end]
        logits = self(chunk, offset=offset, mask="causal", cache=cache)
        offset = cache.get_seq_len()
        mx.eval(logits)  # Prevent graph from growing too large

    return logits, cache

Chunked prefill is important for long prompts. Without it, the computation graph grows with prompt length, potentially exhausting memory before execution even starts.


Part VI: The Threading SIGSEGV Incident

Day 3, Morning: Building a Comparison Benchmark

January 2, 2026

After achieving 35 tok/s with WeDLM's parallel decoding, I needed to answer the obvious question: how does this compare to autoregressive decoding on the same hardware?

The plan was simple: build a benchmark tool to run both WeDLM and a standard Qwen3-8B through mlx-lm, then compare throughputs. For efficiency, running them in parallel seemed reasonable—two independent models should be able to run simultaneously.

# chat_bench.py - First (fatally flawed) version
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple
import time

# Global models (loaded once, reused)
wedlm_model = None
ar_model = None

def load_models():
    """Load both models at startup."""
    global wedlm_model, ar_model
    print("Loading WeDLM-8B...")
    wedlm_model = WeDLMDecoder.from_pretrained("./wedlm-8b-mlx-4bit")
    print("Loading Qwen3-8B baseline...")
    ar_model, ar_tokenizer = load("mlx-community/Qwen3-8B-4bit")
    print("Both models loaded.")

def run_comparison(prompt: str):
    wedlm_result = []
    qwen_result = []

    def run_wedlm():
        wedlm_result.append(wedlm_generate(prompt))

    def run_qwen():
        qwen_result.append(qwen_generate(prompt))

    t1 = threading.Thread(target=run_wedlm)
    t2 = threading.Thread(target=run_qwen)

    t1.start()
    t2.start()  # SIGSEGV here

    t1.join()
    t2.join()

Result: SIGSEGV (segmentation fault).

The failure was a process-terminating SIGSEGV rather than a Python exception, with no stack trace or diagnostic message. It reproduced consistently as soon as both threads initiated concurrent forward passes.

The Debugging Journey

Print statements helped isolate the problem:

def run_wedlm_generation(prompt: str):
    print(f"[WeDLM] Starting generation for: {prompt[:30]}...")
    print(f"[WeDLM] Tokenizing...")
    tokens = wedlm_model.tokenizer.encode(prompt)
    print(f"[WeDLM] Token count: {len(tokens)}")
    print(f"[WeDLM] Starting prefill...")
    # CRASH HAPPENS HERE OR IN AR'S EQUIVALENT

The crash consistently occurred when both threads tried to call mx.eval() simultaneously—specifically during their respective forward passes.

Verification that the issue wasn't application-specific came from trying the same pattern with two instances of the same mlx-lm model:

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Qwen3-8B-4bit")

def gen_thread_1():
    return generate(model, tokenizer, prompt="Hello", max_tokens=10)

def gen_thread_2():
    return generate(model, tokenizer, prompt="World", max_tokens=10)

with ThreadPoolExecutor(max_workers=2) as executor:
    f1 = executor.submit(gen_thread_1)
    f2 = executor.submit(gen_thread_2)
    # SIGSEGV

Same crash. Not application-specific—something fundamental about MLX and threading.

Root Cause: Metal's Threading Model

MLX GitHub issues and Metal documentation reference this problem. Metal has a specific threading model:

Thread-safe:

  • MTLDevice (the GPU device handle)
  • MTLCommandQueue (the queue for submitting work)

NOT thread-safe:

  • MTLCommandBuffer (individual buffers of GPU commands)
  • MTLCommandEncoder (encoders for filling buffers)

MLX's execution model:

  1. Build a lazy computation graph (thread-safe—just building Python objects)
  2. On mx.eval():
    • Compile graph to Metal commands
    • Create command buffers
    • Submit to queue
    • Wait for completion

The problem: MLX uses shared global state for command buffer management. When two threads call mx.eval() simultaneously:

Thread 1                          Thread 2
--------                          --------
Enter mx.eval()                   Enter mx.eval()
Get shared command buffer         Get shared command buffer (same one!)
Start encoding commands           Start encoding commands
  ↓                                 ↓
Both writing to same buffer → RACE CONDITION → Memory corruption → SIGSEGV

The Solution: Sequential Execution

def run_comparison_safe(prompt: str):
    """Sequential execution - boring but correct."""
    print("Running WeDLM...")
    wedlm_result = wedlm_generate(prompt)
    wedlm_tokens = count_tokens(wedlm_result)

    print("Running Qwen AR...")
    qwen_result = qwen_generate(prompt)
    qwen_tokens = count_tokens(qwen_result)

    return wedlm_result, qwen_result

It works. It's twice as slow in wall-clock time (running sequentially instead of in parallel), but benchmarking doesn't need parallelism—we just need accurate throughput measurements.

Alternative: Process-Level Parallelism

If you really need concurrent model execution, separate processes work because each gets its own Metal context:

import multiprocessing as mp
from multiprocessing import Queue
import time

def wedlm_process(prompt: str, result_queue: Queue):
    """Run WeDLM in a separate process."""
    # Load model fresh (can't share across processes)
    model = WeDLMDecoder.from_pretrained("./wedlm-8b-mlx-4bit")
    start = time.perf_counter()
    response = model.generate(prompt, max_tokens=256)
    elapsed = time.perf_counter() - start
    tokens = len(model.tokenizer.encode(response.text))
    result_queue.put(("wedlm", response.text, tokens / elapsed))

# Separate processes = separate Metal contexts
q1, q2 = mp.Queue(), mp.Queue()
p1 = mp.Process(target=wedlm_process, args=(prompt, q1))
p2 = mp.Process(target=qwen_process, args=(prompt, q2))  # qwen_process: analogous AR worker
p1.start(); p2.start()
p1.join(); p2.join()

Each process gets its own Metal context. But the memory overhead is significant—each process loads the full model. For benchmarking where we just want to measure throughputs, sequential is preferable.

GPU API Threading Model

GPU APIs are designed for explicit concurrency control, not implicit threading.

In CUDA/Metal/Vulkan, the programmer explicitly:

  • Creates command queues
  • Allocates command buffers
  • Submits work in a controlled order
  • Synchronizes across queues

High-level frameworks like MLX, PyTorch, and TensorFlow abstract this away for the common case (single-threaded Python), but the abstraction breaks with Python threading.

Investigating the crash clarified that MLX's Metal command buffer management is not thread-safe under concurrent mx.eval calls, which in turn informed later decisions to keep inference single-threaded within a process.


Part VII: Quantization — The Quality-Speed Trade-off

Day 3, Afternoon: Memory Pressure

January 2, 2026 (continued)

WeDLM-8B in BF16 requires approximately 16GB for weights alone:

Parameter count: 8.03B
Bytes per param (BF16): 2
Total weight memory: 8.03B × 2 = 16.06 GB

Plus runtime:
  KV Cache (grows with sequence): ~150MB per 1K tokens
  Activations (batch=1, seq=2K): ~200MB
  MLX overhead: ~500MB

Total for 2K context: ~17-18 GB

An M3 Max has 48GB unified memory, so this fits comfortably. But unified memory is shared with the system and other applications. Running Chrome, VSCode, and a 17GB model simultaneously creates memory pressure—swap usage appeared at longer contexts.

4-bit quantization reduces weight memory to ~4GB, making inference much more comfortable. The question is: what does it cost in quality?

Understanding Group-Wise Quantization

MLX uses group-wise min-max quantization:

# Conceptual implementation of 4-bit group quantization
def quantize_weights(weights: mx.array, group_size: int = 64, bits: int = 4):
    """Quantize weights using group-wise min-max scaling."""
    # weights shape: (out_features, in_features)

    # Reshape into groups along input dimension
    groups = weights.reshape(-1, group_size)  # (num_groups, group_size)

    # Compute per-group scale factors
    max_vals = mx.max(mx.abs(groups), axis=1, keepdims=True)

    # For 4-bit: representable values are -7 to +7
    num_levels = 2 ** (bits - 1) - 1  # 7 for 4-bit
    scales = max_vals / num_levels

    # Quantize to integers
    quantized = mx.round(groups / scales)
    quantized = mx.clip(quantized, -num_levels, num_levels).astype(mx.int8)

    return quantized, scales  # Store both for dequantization

def dequantize_weights(quantized: mx.array, scales: mx.array,
                       original_shape: tuple, group_size: int = 64):
    """Dequantize during forward pass."""
    # Multiply each group by its scale factor, then restore the original layout
    dequantized = quantized.astype(mx.float16) * scales
    return dequantized.reshape(original_shape)

The key insight: each group of 64 weights shares a single scale factor. This means:

  • Storage: 4 bits per weight + 16 bits per 64 weights for scale = 4.25 bits/weight effective
  • Reconstruction: weight_approx = quantized_int * scale
  • Error: bounded by scale / 2 per weight

Group size is a trade-off:

  • Smaller groups (e.g., 32): More scales stored, better precision, more memory
  • Larger groups (e.g., 128): Fewer scales, worse precision, less memory

MLX defaults to group_size=64, which is a reasonable middle ground.

First Attempt: Quantize Everything

import mlx.nn as nn
from mlx_wedlm.model import WeDLMModel

# Load BF16 model
model = WeDLMModel.from_pretrained("./wedlm-8b-bf16")

# Simple quantization - all linear layers
nn.quantize(model, bits=4, group_size=64)

# Save quantized model
model.save_weights("./wedlm-8b-4bit/weights.safetensors")

Performance was impressive: 60 tok/s (vs 35 tok/s BF16). A 1.7x speedup just from quantization!

GSM8K (grade school math) testing revealed problems:

Prompt: "What is 15 + 27?"
BF16:   "15 + 27 = 42. The answer is 42."
4-bit:  "15 + 27 = 44. The answer is 44."  # Wrong!

Prompt: "A baker has 23 cupcakes. She sells 8. How many are left?"
BF16:   "23 - 8 = 15. The baker has 15 cupcakes left."
4-bit:  "23 - 8 = 16. The baker has 16 cupcakes left."  # Wrong again!

Simple arithmetic errors. The model was no longer reliable for basic math.

Diagnosing the Problem

Understanding which layers were most sensitive to quantization required a layer-by-layer ablation:

def measure_layer_sensitivity(model, layer_name: str, test_prompts: List[str]):
    """Quantize a single layer and measure impact on output quality."""
    # Save original weights
    original_weights = get_layer_weights(model, layer_name).copy()

    # Quantize just this layer
    quantize_single_layer(model, layer_name, bits=4)

    # Measure quality degradation
    score = evaluate_on_prompts(model, test_prompts)

    # Restore original weights
    set_layer_weights(model, layer_name, original_weights)

    return score

# Test each layer type
sensitivity_scores = {}
for layer_name in ["embed_tokens", "lm_head", "q_proj", "k_proj", "v_proj",
                   "o_proj", "gate_proj", "up_proj", "down_proj"]:
    score = measure_layer_sensitivity(model, layer_name, math_test_prompts)
    sensitivity_scores[layer_name] = score
    print(f"{layer_name}: {score:.2f}%")

Results were illuminating:

Layer           | Accuracy (4-bit) | Accuracy (BF16) | Degradation
----------------+------------------+-----------------+------------
embed_tokens    | 75%              | 80%             | -5%
lm_head         | 68%              | 80%             | -12%
q_proj (all)    | 79%              | 80%             | -1%
k_proj (all)    | 79%              | 80%             | -1%
v_proj (all)    | 79%              | 80%             | -1%
o_proj (all)    | 79%              | 80%             | -1%
gate_proj (all) | 78%              | 80%             | -2%
up_proj (all)   | 78%              | 80%             | -2%
down_proj (all) | 78%              | 80%             | -2%

The lm_head layer was the culprit. Quantizing it alone caused a 12% accuracy drop, while attention projections caused only 1% each.

Why lm_head Is Special

The lm_head layer projects hidden states to vocabulary logits:

hidden_state (4096-dim) → lm_head (4096 × 151936) → logits (151936-dim)
                          ↑
                          This layer is HUGE: 620M parameters
                          And the final "decision boundary"

The size alone is concerning (620M params in one layer), but the real issue is the nature of the computation:

# At inference, we pick the argmax token
logits = lm_head(hidden_state)  # 151936 logits
token_id = mx.argmax(logits).item()

The argmax is discontinuous—a tiny change in logits can completely change which token wins:

Scenario 1 (BF16):
  logit[token_42] = 5.0001
  logit[token_44] = 5.0000
  argmax → token_42 ✓

Scenario 2 (quantized lm_head):
  logit[token_42] = 4.9998  # Quantization error shifted it down
  logit[token_44] = 5.0002  # Quantization error shifted it up
  argmax → token_44 ✗

A 0.0003 error flipped the answer. For attention layers, quantization error averages out over the softmax. For lm_head, it directly corrupts the output distribution.
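
The effect is easy to reproduce in isolation (a toy sketch, unrelated to the real model weights): a perturbation smaller than a typical quantization step flips the argmax between two near-tied logits.

import mlx.core as mx

logits = mx.array([5.0001, 5.0000])      # index 0 (token_42) narrowly wins
noise = mx.array([-0.0003, 0.0002])      # quantization-sized perturbation

print(mx.argmax(logits).item())          # 0
print(mx.argmax(logits + noise).item())  # 1 → the "wrong" token now wins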

Selective Quantization Implementation

The fix: keep sensitive layers in BF16, quantize everything else:

from mlx.utils import tree_flatten  # to walk the nested parameter dict

def quantize_model_selective(
    model: nn.Module,
    bits: int = 4,
    group_size: int = 64,
    exclude_patterns: List[str] = ["lm_head", "embed_tokens"]
) -> nn.Module:
    """Quantize with selective layer exclusion."""

    def should_quantize(name: str, module: nn.Module) -> bool:
        """Predicate: should this layer be quantized?"""
        # Only quantize Linear layers
        if not isinstance(module, nn.Linear):
            return False

        # Skip layers matching exclude patterns
        for pattern in exclude_patterns:
            if pattern in name:
                print(f"  Keeping {name} in BF16 (matches '{pattern}')")
                return False

        return True

    # Count parameters before quantization; model.parameters() is a nested
    # dict, so flatten it into (name, array) pairs first
    flat_params = tree_flatten(model.parameters())
    total_params = sum(p.size for _, p in flat_params)

    # Apply quantization with predicate
    nn.quantize(
        model,
        bits=bits,
        group_size=group_size,
        class_predicate=should_quantize
    )

    # Report what was excluded
    excluded_params = sum(
        p.size for name, p in flat_params
        if any(pat in name for pat in exclude_patterns)
    )
    quantized_params = total_params - excluded_params

    print(f"Quantized: {quantized_params/1e9:.2f}B params ({bits}-bit)")
    print(f"Excluded: {excluded_params/1e9:.2f}B params (BF16)")
    print(f"Effective bits per param: {(quantized_params * bits + excluded_params * 16) / total_params:.2f}")

    return model

# Usage
model = WeDLMModel.from_pretrained("./wedlm-8b-bf16")
model = quantize_model_selective(model, bits=4, exclude_patterns=["lm_head", "embed_tokens"])
model.save_weights("./wedlm-8b-4bit-selective/weights.safetensors")

Output:

  Keeping embed_tokens in BF16 (matches 'embed_tokens')
  Keeping lm_head in BF16 (matches 'lm_head')
Quantized: 7.03B params (4-bit)
Excluded: 1.00B params (BF16)
Effective bits per param: 5.50

Memory Analysis

Breaking down the memory savings:

Component          | BF16        | 4-bit selective | Savings
-------------------+-------------+-----------------+----------
Attention projs    | 2.4 GB      | 0.6 GB          | 75%
FFN (gate/up/down) | 11.2 GB     | 2.8 GB          | 75%
embed_tokens       | 1.2 GB      | 1.2 GB          | 0%
lm_head            | 1.2 GB      | 1.2 GB          | 0%
-------------------+-------------+-----------------+----------
Total weights      | 16.0 GB     | 5.8 GB          | 64%

The 64% reduction is less than pure 4-bit would give (75%), but we keep the quality-critical layers intact.
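
As a sanity check, the table can be reproduced with a few lines of arithmetic. This sketch takes the BF16 component sizes from the table as given and ignores the small per-group scale/bias overhead that quantized layers carry:

# BF16 sizes in GB, taken from the table above (approximate)
components = {
    "attention_projs": 2.4,
    "ffn": 11.2,
    "embed_tokens": 1.2,
    "lm_head": 1.2,
}
quantized = {"attention_projs", "ffn"}  # 4-bit; embeddings and lm_head stay BF16

total_bf16 = sum(components.values())
total_selective = sum(
    size / 4 if name in quantized else size  # 16-bit -> 4-bit is a 4x reduction
    for name, size in components.items()
)

print(f"BF16: {total_bf16:.1f} GB, selective 4-bit: {total_selective:.1f} GB")
print(f"Savings: {1 - total_selective / total_bf16:.0%}")
# BF16: 16.0 GB, selective 4-bit: 5.8 GB, Savings: 64%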

Final Results

Configuration      | Memory | Speed    | GSM8K Accuracy | Notes
-------------------+--------+----------+----------------+--------------------------
BF16 (baseline)    | 16 GB  | 35 tok/s | 80%            | Reference quality
4-bit (all layers) | 4 GB   | 60 tok/s | 68%            | Unacceptable math errors
4-bit (selective)  | 5.8 GB | 55 tok/s | 78%            | Good trade-off
8-bit (all layers) | 8 GB   | 45 tok/s | 79%            | Conservative option

The selective 4-bit approach recovers most of the quality (-2% vs BF16) while retaining most of the speedup (55 vs 60 tok/s). For a local chat assistant where occasional math errors are tolerable, this is the optimal trade-off.

The Quantization Trade-off Framework

The investigation yielded a framework for quantization decisions:

Layer Type              | Quantization Tolerance | Reasoning
------------------------+-----------------------+----------------------------------
Attention Q/K/V proj    | High                   | Error diffuses over softmax
Attention O proj        | High                   | Linear transform, error averages
FFN projections         | Medium                 | Non-linearity can amplify error
Layer norms             | Low                    | Statistics-dependent, keep BF16
Embeddings              | Medium                 | Large vocab, but averaged in
lm_head                 | Very Low               | Directly determines output token

For production models:

  • Always keep lm_head in high precision
  • Consider keeping embeddings if vocab size is large
  • Safe to quantize attention and FFN projections
  • Monitor downstream task accuracy
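
On the last point, a cheap first check before a full benchmark is to compare greedy outputs of the quantized and full-precision models on a handful of prompts. A minimal sketch, assuming both models expose the generate() interface used throughout this post:

def quick_quantization_check(model_bf16, model_quant, prompts, max_tokens=64):
    """Fraction of prompts where greedy decoding diverges after quantization."""
    diverged = 0
    for prompt in prompts:
        out_ref = model_bf16.generate(prompt, max_tokens=max_tokens).text
        out_q = model_quant.generate(prompt, max_tokens=max_tokens).text
        if out_ref != out_q:
            diverged += 1
    return diverged / len(prompts)

# A high divergence rate on simple prompts is a red flag worth a real eval run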

Part VIII: The `mx.compile` Investigation

Day 4: Chasing the Last Optimization

January 3, 2026

Every MLX tutorial mentions mx.compile as a performance optimization. The promise: JIT compilation reduces Python overhead, enables kernel fusion, and speeds up repeated operations. After the other optimizations, I was at 55 tok/s. Could mx.compile push me to 70+?

The Promise

The MLX documentation shows impressive examples:

import mlx.core as mx

def slow_function(x, y):
    # Multiple operations, multiple GPU dispatches
    z = mx.add(x, y)
    z = mx.multiply(z, x)
    z = mx.sin(z)
    z = mx.exp(z)
    return z

# Without compile: 4 separate operations
result = slow_function(x, y)

# With compile: fused into optimal kernel
@mx.compile
def fast_function(x, y):
    z = mx.add(x, y)
    z = mx.multiply(z, x)
    z = mx.sin(z)
    z = mx.exp(z)
    return z

result = fast_function(x, y)  # Single fused kernel

For functions called repeatedly with same-shaped inputs, mx.compile traces the computation graph once, optimizes it, and reuses the compiled version.

Implementation Attempt #1: Compile the Full Forward Pass

The first attempt was ambitious—compile the entire decode step:

from functools import partial

def create_compiled_forward(model: WeDLMModel):
    """Create a compiled forward pass for decoding."""

    def forward_fn(input_ids: mx.array, positions: mx.array,
                   cached_keys: List[mx.array], cached_values: List[mx.array],
                   mask: mx.array) -> Tuple[mx.array, List[mx.array], List[mx.array]]:
        """The function we want to compile."""
        B, L = input_ids.shape
        x = model.embed_tokens(input_ids)

        new_keys, new_values = [], []

        for i, layer in enumerate(model.layers):
            # Attention
            residual = x
            x_norm = layer.input_layernorm(x)

            attn = layer.self_attn
            q = attn.q_proj(x_norm).reshape(B, L, attn.num_heads, attn.head_dim)
            k = attn.k_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)
            v = attn.v_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)

            # Apply RoPE
            q = apply_rope_positions(q, positions, attn.head_dim, attn.rope_theta)
            k = apply_rope_positions(k, positions, attn.head_dim, attn.rope_theta)

            # Concatenate with cache
            k_full = mx.concatenate([cached_keys[i], k], axis=2)
            v_full = mx.concatenate([cached_values[i], v], axis=2)

            # Store new K, V
            new_keys.append(k_full)
            new_values.append(v_full)

            # Attention
            attn_out = fast.scaled_dot_product_attention(q, k_full, v_full, mask=mask)
            x = residual + attn.o_proj(attn_out.reshape(B, L, -1))

            # FFN
            residual = x
            x = residual + layer.mlp(layer.post_attention_layernorm(x))

        logits = model.lm_head(model.norm(x))
        return logits, new_keys, new_values

    # Compile with shapeless=True to handle variable sequence lengths
    compiled_fn = mx.compile(forward_fn, shapeless=True)
    return compiled_fn

Result: Error

ValueError: Cannot compile function with list arguments

mx.compile traces Python execution to build a computation graph. Python lists (cached_keys, cached_values) aren't part of the MLX graph—they're Python-level iteration. The function isn't traceable.

Implementation Attempt #2: Compile Per-Layer

Fine, let's compile smaller units:

def create_compiled_attention(layer: TransformerBlock):
    """Compile just the attention computation."""

    def attn_fn(x: mx.array, k_cache: mx.array, v_cache: mx.array,
                positions: mx.array, mask: mx.array) -> Tuple[mx.array, mx.array, mx.array]:
        attn = layer.self_attn
        B, L, _ = x.shape

        x_norm = layer.input_layernorm(x)

        q = attn.q_proj(x_norm).reshape(B, L, attn.num_heads, attn.head_dim)
        k = attn.k_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)
        v = attn.v_proj(x_norm).reshape(B, L, attn.num_kv_heads, attn.head_dim)

        q = fast.rope(q, attn.head_dim, offset=positions)
        k = fast.rope(k, attn.head_dim, offset=positions)

        k_full = mx.concatenate([k_cache, k], axis=2)
        v_full = mx.concatenate([v_cache, v], axis=2)

        attn_out = fast.scaled_dot_product_attention(q, k_full, v_full, mask=mask)
        return attn.o_proj(attn_out.reshape(B, L, -1)), k_full, v_full

    return mx.compile(attn_fn, shapeless=True)

# Create compiled versions for all layers
compiled_attns = [create_compiled_attention(layer) for layer in model.layers]

This worked! But the results were underwhelming.

Benchmarking Compiled vs Non-Compiled

Careful benchmarks:

def benchmark_decode(model, input_ids, cache, num_trials=100):
    """Benchmark decode step timing."""
    times = []

    for _ in range(num_trials):
        start = time.perf_counter()
        logits = model.decode_step(input_ids, cache)
        mx.eval(logits)  # force evaluation so we time the actual GPU work
        elapsed = time.perf_counter() - start
        times.append(elapsed)

    # Skip the first 10 trials to exclude warmup effects
    return np.mean(times[10:]), np.std(times[10:])

# Test on small model first (128M params for faster iteration)
small_model = load_small_test_model()

print("Small model (128M):")
non_compiled_time, nc_std = benchmark_decode(small_model, input_ids, cache)
print(f"  Non-compiled: {non_compiled_time*1000:.2f}ms ± {nc_std*1000:.2f}ms")

compiled_small = create_compiled_model(small_model)
compiled_time, c_std = benchmark_decode(compiled_small, input_ids, cache)
print(f"  Compiled: {compiled_time*1000:.2f}ms ± {c_std*1000:.2f}ms")

# Full WeDLM-8B
print("\nWeDLM-8B:")
non_compiled_time, nc_std = benchmark_decode(wedlm_model, input_ids, cache)
print(f"  Non-compiled: {non_compiled_time*1000:.2f}ms ± {nc_std*1000:.2f}ms")

compiled_wedlm = create_compiled_model(wedlm_model)
compiled_time, c_std = benchmark_decode(compiled_wedlm, input_ids, cache)
print(f"  Compiled: {compiled_time*1000:.2f}ms ± {c_std*1000:.2f}ms")

Results:

Model             | Non-Compiled | Compiled | Speedup
------------------+--------------+----------+---------
Small (128M test) | 0.67ms       | 0.74ms   | 0.91x
WeDLM-8B (full)   | 127ms        | 127ms    | 1.00x

The small model is actually slower with compilation (compilation overhead dominates). The full model shows no improvement at all.

Why Doesn't It Help?

After investigation, three factors explain the non-result:

Factor 1: MLX's Lazy Evaluation Already Optimizes

Without mx.compile, MLX already builds and optimizes computation graphs:

# This already builds a lazy graph
y = mx.add(mx.multiply(a, b), c)
# MLX sees: add(multiply(a, b), c)
# Optimizes before execution
mx.eval(y)

mx.compile adds explicit JIT tracing, but the lazy evaluation system already does similar optimizations. For straight-line numerical code, the difference is minimal.

Factor 2: The Bottleneck Is Matrix Multiplication

Profiling the decode step shows where time goes:

import mlx.profiler

with mlx.profiler.profile() as prof:
    for _ in range(50):
        logits = model.decode_step(input_ids, cache)
        mx.eval(logits)

print(prof.summary())

Output (simplified):

Operation            | Time    | % of Total
---------------------+---------+------------
MatMul (Q/K/V proj)  | 45ms    | 35%
MatMul (O proj)      | 15ms    | 12%
MatMul (FFN gate)    | 20ms    | 16%
MatMul (FFN up)      | 20ms    | 16%
MatMul (FFN down)    | 20ms    | 16%
Attention SDPA       | 5ms     | 4%
Other (RoPE, norm)   | 2ms     | 1%

95% of time is in matrix multiplications. These are already running optimized Metal kernels. mx.compile can't make them faster—they're hardware-bound.

The remaining 5% includes:

  • Python overhead: ~2%
  • Memory operations: ~3%

Even if mx.compile eliminated ALL Python overhead, you'd gain maybe 2%. And my careful batching has already minimized Python overhead.
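
A back-of-envelope bandwidth check makes the same point. This assumes the weights are streamed once per forward pass and an effective unified-memory bandwidth of a few hundred GB/s (the exact figure depends on the specific M-series chip and binning):

# Rough memory-bandwidth floor for one decode forward pass
bandwidth_gb_s = 350.0  # assumed effective bandwidth; varies by chip

for label, weights_gb in [("BF16", 16.0), ("4-bit selective", 5.8)]:
    floor_ms = weights_gb / bandwidth_gb_s * 1000
    print(f"{label}: ~{floor_ms:.0f} ms just to stream the weights")

# BF16: ~46 ms, 4-bit selective: ~17 ms per forward pass -- the same order of
# magnitude as the measured step times, which is why trimming ~2% of Python
# overhead cannot move the needle.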

Factor 3: The Cache Access Pattern

My KV cache uses Python lists with per-layer indexing:

# In forward pass
for i, layer in enumerate(self.layers):
    cached_k, cached_v = cache.layers[i].k, cache.layers[i].v
    #                     ↑ Python list indexing
    k_full = mx.concatenate([cached_k, k], axis=2)

mx.compile traces function execution to build a graph. But Python list indexing (cache.layers[i]) happens in Python, not MLX. Each layer access is a separate trace boundary.

A fully compiled version would need the cache as a single MLX array with layer indices as part of the computation. That requires a major refactor:

# Hypothetical fully-traceable cache
cache_all_k = mx.stack(all_layer_keys)  # Shape: (num_layers, B, H, L, D)
cache_all_v = mx.stack(all_layer_values)

# Index within MLX graph
layer_k = cache_all_k[layer_idx]  # This CAN be traced

But this makes cache updates more complex and doesn't address the fundamental bottleneck (matrix multiplies).

When DOES `mx.compile` Help?

The investigation clarified when compilation matters:

Good candidates for mx.compile:

  • Custom loss functions with many arithmetic operations
  • Preprocessing pipelines with lots of small ops
  • Research code with complex control flow that repeats
  • Functions where Python loop overhead is significant

Poor candidates for mx.compile:

  • LLM inference (matrix multiply dominated)
  • Single forward passes (no repetition to amortize compile cost)
  • Code with Python-level data structures in the hot path
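
As a concrete example of the first category above, a loss built from many small element-wise ops is the kind of function mx.compile can actually fuse. The function below is illustrative, not from this project:

import mlx.core as mx

@mx.compile
def smooth_l1_loss(pred, target, beta=1.0):
    # Several small element-wise ops -- a good fusion target for mx.compile
    diff = mx.abs(pred - target)
    loss = mx.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
    return mx.mean(loss)

pred = mx.random.normal((1024, 128))
target = mx.random.normal((1024, 128))
print(smooth_l1_loss(pred, target).item())  # first call traces and compiles; later calls reuse it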

The Decision: Skip Compilation

The final implementation omits the mx.compile wrapper. The benefits were zero, but the costs were real:

  • Added complexity to the code
  • Longer startup time (compilation happens on first call)
  • Debugging becomes harder (compiled traces are opaque)

In this project, mx.compile provided no measurable benefit; the broader lesson is that recognizing which parts of the pipeline are not bottlenecks is what prevents wasted optimization effort.


Part IX: Benchmark Results

Experimental Setup

January 3-4, 2026

With all optimizations complete, rigorous benchmarks were needed—not just "it feels fast," but reproducible numbers with proper methodology.

Hardware Configuration

  • Machine: MacBook Pro 16" (2023)
  • Chip: Apple M3 Max
    • 14-core CPU (10 performance + 4 efficiency)
    • 30-core GPU
    • 16-core Neural Engine
  • Memory: 48GB unified memory
  • Storage: 1TB SSD (irrelevant for inference, model fits in RAM)
  • Power: Always plugged in, "High Power" mode enabled
  • OS: macOS Sonoma 14.2

Model Configurations

Model Name        | Base     | Params | Precision | Size on Disk
------------------+----------+--------+-----------+--------------
wedlm-8b-bf16     | Qwen3-8B | 8.03B  | BF16      | 16.0 GB
wedlm-8b-4bit     | Qwen3-8B | 8.03B  | 4-bit     | 4.0 GB
wedlm-8b-4bit-sel | Qwen3-8B | 8.03B  | Mixed     | 5.8 GB
qwen3-8b-bf16     | Qwen3-8B | 8.03B  | BF16      | 16.0 GB
qwen3-8b-4bit     | Qwen3-8B | 8.03B  | 4-bit     | 4.0 GB

Benchmarking Methodology

Proper LLM benchmarks require careful methodology to avoid common pitfalls:

def benchmark_throughput(model, prompts: List[str], max_tokens: int = 256,
                         num_warmup: int = 3, num_trials: int = 10) -> List[dict]:
    """Benchmark generation throughput with proper methodology."""

    results = []

    for i, prompt in enumerate(prompts):
        prompt_results = {"prompt_idx": i, "trials": []}

        for trial in range(num_warmup + num_trials):
            # Force garbage collection between trials
            import gc
            gc.collect()

            # Clear any pending MLX operations
            mx.synchronize()

            # Time the generation
            start = time.perf_counter()

            output = model.generate(prompt, max_tokens=max_tokens)

            # Ensure all GPU work is complete
            mx.synchronize()

            elapsed = time.perf_counter() - start

            # Count actual generated tokens (not max_tokens)
            output_tokens = len(model.tokenizer.encode(output.text))
            input_tokens = len(model.tokenizer.encode(prompt))
            generated_tokens = output_tokens - input_tokens

            # Skip warmup trials
            if trial >= num_warmup:
                prompt_results["trials"].append({
                    "elapsed": elapsed,
                    "generated_tokens": generated_tokens,
                    "tok_per_s": generated_tokens / elapsed,
                    "prefill_time": output.prefill_time if hasattr(output, 'prefill_time') else None,
                })

        results.append(prompt_results)

    return results

def compute_statistics(results: List[dict]) -> dict:
    """Compute aggregate statistics from benchmark results."""
    all_tps = [t["tok_per_s"] for r in results for t in r["trials"]]

    return {
        "mean_tok_per_s": np.mean(all_tps),
        "std_tok_per_s": np.std(all_tps),
        "median_tok_per_s": np.median(all_tps),
        "p5_tok_per_s": np.percentile(all_tps, 5),
        "p95_tok_per_s": np.percentile(all_tps, 95),
        "num_samples": len(all_tps),
    }

Benchmark Prompts

A diverse prompt set captures different generation scenarios:

benchmark_prompts = [
    # Short factual
    "What is the capital of France?",

    # Medium explanation
    "Explain the difference between TCP and UDP in networking.",

    # Code generation
    "Write a Python function to compute the nth Fibonacci number using dynamic programming.",

    # Long-form reasoning
    "Describe the process of photosynthesis and explain why it's important for life on Earth.",

    # Math (tests accuracy)
    "Solve step by step: A train travels 120 miles in 2 hours. How long to travel 300 miles?",
]
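
Putting the pieces together, the driver is just a loop over configurations. A minimal sketch, assuming the local paths match the model table above and that generate() follows the interface used elsewhere in this post:

# Hypothetical driver tying benchmark_throughput and compute_statistics together
configs = {
    "wedlm-8b-bf16": "./wedlm-8b-bf16",
    "wedlm-8b-4bit-sel": "./wedlm-8b-4bit-selective",
}

for name, path in configs.items():
    model = WeDLMDecoder.from_pretrained(path)
    results = benchmark_throughput(model, benchmark_prompts, max_tokens=256)
    stats = compute_statistics(results)
    print(f"{name}: {stats['mean_tok_per_s']:.1f} ± {stats['std_tok_per_s']:.1f} tok/s "
          f"(median {stats['median_tok_per_s']:.1f}, n={stats['num_samples']})")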

MLX vs PyTorch MPS Comparison

First, comparing MLX to PyTorch's MPS backend for autoregressive Qwen3-8B:

# PyTorch MPS baseline
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("mps")
pt_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    torch_dtype=torch.bfloat16,
).to(device)  # place the model on MPS explicitly
pt_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# MLX baseline
from mlx_lm import load, generate as mlx_generate
mlx_model, mlx_tokenizer = load("mlx-community/Qwen3-8B-4bit")

Results:

Backend     | Precision | Throughput (tok/s) | Memory Usage | vs PyTorch BF16
------------+-----------+--------------------+--------------+-----------------
PyTorch MPS | BF16      | 29.1 ± 2.3         | ~18 GB       | baseline
PyTorch MPS | FP16      | 30.4 ± 2.1         | ~17 GB       | +4%
MLX         | BF16      | 35.0 ± 1.8         | ~17 GB       | +20%
MLX         | 4-bit     | 60.9 ± 3.2         | ~5 GB        | +109%

MLX consistently outperforms PyTorch MPS. The advantage comes from:

  1. Native Metal Integration: MLX is written for Metal, while PyTorch MPS is a translation layer
  2. Unified Memory Efficiency: MLX better exploits the unified memory architecture
  3. Lazy Evaluation: Graph optimization reduces memory copies and intermediate storage
  4. Optimized Ops: fast.rope, fast.scaled_dot_product_attention are Metal-native

WeDLM Diffusion vs Autoregressive

The main comparison: does WeDLM's parallel decoding actually help?

# WeDLM (diffusion decoding)
wedlm = WeDLMDecoder.from_pretrained("./wedlm-8b-4bit-selective")

# Baseline (autoregressive, same base model)
from mlx_lm import load, generate
ar_model, ar_tokenizer = load("mlx-community/Qwen3-8B-4bit")

Throughput Comparison:

Method            | Precision | Throughput | Tokens/Forward | Speedup vs AR
------------------+-----------+------------+----------------+---------------
Qwen3-8B AR       | 4-bit     | 40.2 tok/s | 1.0            | baseline
WeDLM (window=8)  | 4-bit sel | 50.1 tok/s | 1.8            | 1.25x
WeDLM (window=16) | 4-bit sel | 55.2 tok/s | 2.2            | 1.37x
WeDLM (window=32) | 4-bit sel | 52.8 tok/s | 2.0            | 1.31x

Interesting findings:

  • Window size 16 is the sweet spot on M3 Max
  • Window size 32 is actually slower—the overhead of managing a larger window outweighs the parallelism benefit
  • Average tokens per forward ranges from 1.8 to 2.2, not the 3-4 claimed in the paper

Why Fewer Tokens Per Forward?

The paper reports higher parallelism because:

  1. They use larger models with more confident predictions
  2. Their entropy threshold is tuned differently
  3. CUDA FlashAttention makes larger windows more efficient

On M3 Max, the attention computation per window token is more expensive relative to the matrix multiplies, so smaller windows are more efficient.

Accuracy on GSM8K

Speed is meaningless if quality suffers. Evaluation on GSM8K (grade school math):

def evaluate_gsm8k(model, num_samples: int = 200) -> float:
    """Evaluate on GSM8K subset."""
    from datasets import load_dataset

    dataset = load_dataset("gsm8k", "main", split="test")
    samples = dataset.select(range(num_samples))

    correct = 0
    for sample in samples:
        question = sample["question"]
        answer = sample["answer"]

        # Extract ground truth (last number in answer)
        ground_truth = extract_final_number(answer)

        # Generate response
        prompt = f"Q: {question}\nA: Let me solve this step by step.\n"
        response = model.generate(prompt, max_tokens=256)

        # Extract predicted answer
        predicted = extract_final_number(response.text)

        if predicted == ground_truth:
            correct += 1

    return correct / num_samples

Results (200-sample subset):

Model               | GSM8K Accuracy | Notes
--------------------+----------------+----------------------------
Qwen3-8B AR (BF16)  | 80.0%          | Reference quality
Qwen3-8B AR (4-bit) | 78.5%          | Minor quantization impact
WeDLM (BF16)        | 71.5%          | -8.5% vs AR
WeDLM (4-bit sel)   | 70.0%          | -10% vs AR
WeDLM (4-bit all)   | 64.0%          | Unacceptable for math

The accuracy gap is concerning. WeDLM's parallel decoding introduces errors because:

  1. Entropy-based commitment: Tokens are committed based on confidence, not correctness
  2. Early mistakes propagate: An early wrong digit affects all subsequent reasoning
  3. Position penalty bias: May commit wrong tokens early to maximize cache utilization

For math tasks where precision matters, AR decoding is still preferable.

Prefix Cacheability Analysis

WeDLM's key innovation is prefix cacheability. Measuring the actual $p_{\text{cache}}$:

def measure_prefix_cacheability(model, prompts: List[str], max_tokens: int = 128):
    """Measure actual prefix cache hit rate."""

    total_generated = 0
    total_forwards = 0

    for prompt in prompts:
        state = model.initialize_state(prompt)

        while not state.is_finished and state.num_generated < max_tokens:
            state, committed = model.decode_step(state)

            total_generated += len(committed)
            total_forwards += 1

    p_cache = total_generated / total_forwards
    return p_cache

# Measure across prompt types
cache_results = {}
for prompt_type, prompts in prompt_categories.items():
    p_cache = measure_prefix_cacheability(wedlm, prompts)
    cache_results[prompt_type] = p_cache

print("Prefix Cacheability by Prompt Type:")
for ptype, pcache in cache_results.items():
    print(f"  {ptype}: {pcache:.2f}")

Results:

Prompt Type      | p_cache | Notes
-----------------+---------+-------------------------------
Code generation  | 0.92    | Highly predictable tokens
Factual Q&A      | 0.88    | Common patterns
Creative writing | 0.75    | More varied vocabulary
Math reasoning   | 0.70    | Numbers are high-entropy
Overall average  | 0.81    | Weighted by prompt frequency

The paper reports $p_{\text{cache}} \approx 0.85$ on their benchmarks. We achieve 0.81, which is reasonable given different evaluation prompts and entropy thresholds.

Memory Scaling with Sequence Length

How does memory usage scale as we generate longer sequences?

def measure_memory_scaling(model, prompt: str, max_lengths: List[int]):
    """Measure peak unified-memory usage at different sequence lengths."""

    results = []
    for max_len in max_lengths:
        # tracemalloc only sees Python-heap allocations, not Metal buffers,
        # so use MLX's own peak-memory counters instead.
        mx.metal.reset_peak_memory()

        output = model.generate(prompt, max_tokens=max_len)
        mx.synchronize()

        peak = mx.metal.get_peak_memory()

        results.append({
            "max_tokens": max_len,
            "peak_memory_gb": peak / 1e9,
            "actual_tokens": len(model.tokenizer.encode(output.text))
        })

    return results

KV Cache Memory (BF16, batch=1):

Sequence Length | KV Cache Size | Total Memory | Notes
----------------+---------------+--------------+---------------------
512             | 288 MB        | 17.3 GB      | Short conversations
1024            | 576 MB        | 17.6 GB      | Typical chat
2048            | 1.15 GB       | 18.2 GB      | Long documents
4096            | 2.3 GB        | 19.3 GB      | Extended context
8192            | 4.6 GB        | 21.6 GB      | M3 Max comfortable
16384           | 9.2 GB        | 26.2 GB      | Approaching limits

Formula: $\text{KV cache} = 2 \times L \times H_{kv} \times D \times N \times B$

Where $L$ = num_layers (36), $H_{kv}$ = num_kv_heads (8), $D$ = head_dim (128), $N$ = sequence length, $B$ = bytes per element (2 for BF16).

For WeDLM-8B: $2 \times 36 \times 8 \times 128 \times N \times 2 = 147{,}456 \times N$ bytes ≈ 0.14 MB per token.
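
A small helper for the formula, using the Qwen3-8B GQA shape from above (36 layers, 8 KV heads, head_dim 128, BF16):

def kv_cache_bytes(seq_len: int, num_layers: int = 36, num_kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """KV cache size: 2 (K and V) x layers x kv_heads x head_dim x seq_len x bytes x batch."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem * batch

print(f"{kv_cache_bytes(1):,} bytes per token")   # 147,456 bytes ≈ 0.14 MB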

Comparison with Paper Claims

The WeDLM paper reports 3x speedup over vLLM on NVIDIA A100. Our MLX port achieves 1.25-1.37x speedup over MLX AR.

Why the gap?

Factor             | Paper (CUDA)                       | This Port (MLX)
-------------------+------------------------------------+----------------------------
Attention          | FlashAttention (2-4x faster)       | Standard SDPA
Hardware           | A100 (80GB, 312 TFLOPS)            | M3 Max (48GB, ~14 TFLOPS)
Parallelism        | Large batch, high GPU utilization  | Single stream, batch=1
Framework maturity | CUDA, 15+ years                    | MLX, ~1 year
Window overhead    | Amortized over larger batches      | Dominates at batch=1

The comparison isn't entirely fair—they're measuring different things. But it sets realistic expectations for MLX deployment.

End-to-End Latency

For interactive applications, latency matters as much as throughput:

def measure_latency_breakdown(model, prompt: str, max_tokens: int = 100):
    """Measure latency components."""

    # Tokenization
    t0 = time.perf_counter()
    input_ids = model.tokenizer.encode(prompt)
    tokenize_time = time.perf_counter() - t0

    # Prefill
    t1 = time.perf_counter()
    cache = model.prefill(mx.array([input_ids]))
    mx.synchronize()
    prefill_time = time.perf_counter() - t1

    # Decode
    decode_times = []
    tokens_per_step = []
    state = model.initialize_decode_state()

    while not state.is_finished and state.num_generated < max_tokens:
        t_start = time.perf_counter()
        state, committed = model.decode_step(state, cache)
        mx.synchronize()
        t_end = time.perf_counter()

        decode_times.append(t_end - t_start)
        tokens_per_step.append(len(committed))

    return {
        "tokenize_ms": tokenize_time * 1000,
        "prefill_ms": prefill_time * 1000,
        "decode_mean_ms": np.mean(decode_times) * 1000,
        "decode_total_ms": sum(decode_times) * 1000,
        "tokens_per_step_mean": np.mean(tokens_per_step),
        "total_tokens": sum(tokens_per_step),
    }

Latency Breakdown (100-token generation):

Component         | Time (ms) | % of Total
------------------+-----------+------------
Tokenization      | 2         | 0.1%
Prefill (512 tok) | 180       | 9.5%
Decode (total)    | 1700      | 89.9%
Total             | 1882      | 100%

Per-step decode latency: ~34ms average (with ~2 tokens committed per step).

Time-to-first-token: ~182ms (tokenize + prefill).
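
These numbers are mutually consistent, as a quick check shows: ~50 steps of ~34 ms each give the decode total, decode-only throughput lands just under 60 tok/s, and the end-to-end figure comes out in the low 50s, in line with the ~55 tok/s reported earlier.

tokens = 100
tokenize_s, prefill_s, decode_s = 0.002, 0.180, 1.700   # from the table above
tokens_per_step = 2.0

steps = tokens / tokens_per_step
print(f"steps ≈ {steps:.0f}, ≈ {decode_s * 1000 / steps:.0f} ms per step")
print(f"decode-only ≈ {tokens / decode_s:.0f} tok/s")
print(f"end-to-end ≈ {tokens / (tokenize_s + prefill_s + decode_s):.0f} tok/s")
# steps ≈ 50, ≈ 34 ms per step, decode-only ≈ 59 tok/s, end-to-end ≈ 53 tok/s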


Part X: Implementation Takeaways

Engineering Notes

January 4, 2026

This section summarizes the main engineering patterns that turned out to matter for the WeDLM→MLX port and are likely to generalize to other MLX deployments.

Profiling Before Optimization

Several hours early on went to optimizing KV cache update patterns—minimizing copies, batching updates, avoiding redundant operations. Then profiling revealed the actual breakdown:

import mlx.profiler

with mlx.profiler.profile() as prof:
    for _ in range(50):
        output = model.decode_step(input_ids, cache)
        mx.eval(output)

print(prof.summary())

Result:

Operation Breakdown:
  MatMul operations:  95.2%
  Attention (SDPA):    2.1%
  KV cache updates:    1.3%  ← All that work for 1.3%
  LayerNorm:           0.8%
  RoPE:                0.3%
  Other:               0.3%

Even a 10x improvement in KV cache updates would yield a 1.2% total speedup. The optimization effort targeted 1% of runtime, not the 95%.

In this project, most attempted micro-optimizations had negligible impact compared to addressing the dominant cost (matrix multiplications), so profiling was essential before modifying hot paths.

MLX Lazy Evaluation

MLX's lazy evaluation affects how optimization works:

PyTorch (Eager)                | MLX (Lazy)
-------------------------------+-------------------------------------------
Minimize operation count       | Minimize synchronization points
Each op executes immediately   | Ops build graph, execute on eval
.item() is fast (just returns) | .item() forces full graph evaluation
Small loops are fine           | Small loops with ops = many small graphs

Code that's idiomatic in PyTorch destroys MLX performance:

# PyTorch style - fine
for i in range(window_size):
    if logits[i].argmax().item() > threshold:
        results.append(i)

# MLX equivalent - TERRIBLE (window_size sync points)
for i in range(window_size):
    if mx.argmax(logits[i]).item() > threshold:
        results.append(i)

# MLX style - good (one sync point)
max_vals = mx.argmax(logits, axis=-1)
mx.eval(max_vals)
results = [i for i, v in enumerate(max_vals.tolist()) if v > threshold]

Same algorithm, but up to a 100x difference in practice, purely from how many synchronization points the code forces on the framework.

Selective Quantization Strategy

The "quantize everything to 4-bit" approach is not optimal:

Layer Type      | Quantization Impact | Recommendation
----------------+---------------------+-------------------------
Attention projs | Low                 | Safe to quantize
FFN layers      | Medium              | Usually safe
Embeddings      | Medium              | Consider keeping BF16
lm_head         | High                | Keep in high precision
Layer norms     | Very high           | Never quantize

The output layer (lm_head) is special because argmax is discontinuous—small errors flip token choices. Attention layers are forgiving because softmax smooths errors.

Recommended approach: Start with selective quantization (exclude lm_head and embeddings), measure accuracy on the target task, then decide if more aggressive quantization is acceptable.

Threading Constraints with Metal

The SIGSEGV crash revealed Metal's threading constraints:

✗ Two threads calling mx.eval() simultaneously → SIGSEGV
✗ Multiple MLX models in ThreadPoolExecutor → SIGSEGV
✓ Sequential model calls → Works
✓ Separate processes → Works (but memory overhead)

The underlying issue: Metal command buffers aren't thread-safe, and MLX uses global state. This isn't prominently documented because single-threaded use is the norm.

In practice, one MLX context per process is required. If concurrent GPU work is needed, multiprocessing (not threading) is necessary, or the application must be designed for sequential inference.
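
For the multiprocessing route, the essential property is that each worker process loads its own model and therefore owns its own Metal context. A minimal sketch, where load_model() is a hypothetical helper and every worker keeps a full copy of the weights in memory:

from concurrent.futures import ProcessPoolExecutor

def _worker_generate(prompt: str) -> str:
    # Each process lazily loads its own model on first use,
    # so no Metal state is ever shared across threads.
    global _MODEL
    if "_MODEL" not in globals():
        _MODEL = load_model("./wedlm-8b-4bit-selective")  # hypothetical helper
    return _MODEL.generate(prompt, max_tokens=128).text

if __name__ == "__main__":
    prompts = ["Explain TCP vs UDP.", "Write a haiku about caches."]
    # Two processes -> two independent MLX contexts (and two model copies in memory)
    with ProcessPoolExecutor(max_workers=2) as pool:
        for text in pool.map(_worker_generate, prompts):
            print(text[:80])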

Framework Limitations

The mx.compile investigation illustrates when to stop optimizing:

  • Hours spent trying to make compilation work
  • Result: 0% improvement (or negative for small models)
  • Root cause: The bottleneck (matrix multiplies) is already optimized

When workarounds become increasingly complex without performance improvement, it may indicate the framework's design constraints are fundamental rather than incidental.

Diffusion Decoding Trade-offs

WeDLM is not a pure upgrade over AR decoding:

Aspect     | AR Decoding     | Diffusion (WeDLM)
-----------+-----------------+------------------------------------
Throughput | Lower           | Higher (1.2-1.4x)
Accuracy   | Higher          | Lower (8-10 points worse on math)
Latency    | Predictable     | Variable (depends on entropy)
Complexity | Simple          | Complex (window, reordering)
Best for   | Precision tasks | High-volume generation

The paper emphasizes speedup; the accuracy trade-off is acknowledged but easy to overlook. For production, you need to evaluate on YOUR task, not benchmark sets.

Paper-to-Implementation Gap

The WeDLM paper reports 3x speedup; this port achieved 1.3x. The gap reflects:

  • Different hardware (M3 Max vs A100)
  • Different frameworks (MLX vs CUDA)
  • Different evaluation methodology
  • Missing optimizations (FlashAttention)

Reading ML papers pragmatically:

  1. Note the claims (3x speedup)
  2. Note the conditions (A100, FlashAttention, specific tasks)
  3. Expect 30-50% of claimed gains when porting
  4. Measure on YOUR setup and YOUR task

Local Inference Viability

Despite the caveats, the final system runs WeDLM-8B at 55 tokens/second on a laptop:

  • Real-time conversation
  • Private (no data leaves the machine)
  • Offline-capable
  • No API costs

Running 8B models locally at conversational speed is now practical on consumer hardware.


Part XI: Future Work

What Would Make This Better

1. FlashAttention for Metal

The single highest-impact optimization that doesn't exist yet. FlashAttention on CUDA achieves 2-4x attention speedup through:

  • Tiled computation: Process attention in blocks that fit in SRAM
  • Online softmax: Compute softmax incrementally without materializing the full attention matrix
  • Memory efficiency: O(N) memory instead of O(N²) for sequence length N

A Metal implementation using threadgroup memory could provide similar benefits. The scaled_dot_product_attention in MLX is already fused, but doesn't use the tiled algorithm.
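
The online-softmax trick is the core idea: attention can be accumulated block by block while carrying only a running max and a running normalizer, so the full N×N score matrix never needs to exist. A single-head sketch of the recurrence in plain MLX (it demonstrates the math, not the fused-kernel speed):

import mlx.core as mx

def attention_online(q, k, v, block=128):
    """One query vector q (d,), keys/values (N, d): block-wise softmax accumulation."""
    d = q.shape[-1]
    m = mx.array(float("-inf"))   # running max of the scores seen so far
    denom = mx.array(0.0)         # running softmax normalizer
    acc = mx.zeros_like(v[0])     # running weighted sum of values

    for start in range(0, k.shape[0], block):
        scores = (k[start:start + block] @ q) / (d ** 0.5)   # (block,)
        m_new = mx.maximum(m, mx.max(scores))
        scale = mx.exp(m - m_new)         # rescale previous accumulators to the new max
        p = mx.exp(scores - m_new)
        denom = denom * scale + mx.sum(p)
        acc = acc * scale + p @ v[start:start + block]
        m = m_new

    return acc / denom

# Matches softmax(q·Kᵀ/√d)·V without ever materializing all N scores at once
q, k, v = mx.random.normal((64,)), mx.random.normal((1024, 64)), mx.random.normal((1024, 64))
ref = mx.softmax((k @ q) / (64 ** 0.5)) @ v
print(mx.max(mx.abs(attention_online(q, k, v) - ref)).item())  # ≈ 0 (float32 rounding)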

Potential impact:
  Current attention: 2.1% of total time (after SDPA fusion)
  With FlashAttention: Could reduce by 50-75%
  Total speedup: ~1-2%... actually not huge for inference

  BUT: For long sequences (8K+), memory savings are crucial
  Current: O(N²) attention matrix
  FlashAttention: O(N) intermediate storage

For inference, FlashAttention is more about enabling longer sequences than raw speed. Still valuable.

2. Better Entropy Calibration

The current position penalty ($\lambda = 0.01$) was empirically chosen through manual tuning. A more sophisticated approach:

class AdaptiveEntropySelector:
    """Learn optimal commitment thresholds from data."""

    def __init__(self, calibration_data: List[str]):
        # Collect entropy statistics on calibration prompts
        self.entropy_stats = self._collect_stats(calibration_data)

        # Learn task-dependent thresholds
        self.code_threshold = self._compute_optimal_threshold("code")
        self.prose_threshold = self._compute_optimal_threshold("prose")
        self.math_threshold = self._compute_optimal_threshold("math")

        # The position penalty could also be learned; start from the hand-tuned value
        self.learned_position_penalty = 0.01

    def select_commitment_positions(self, entropies: mx.array,
                                    task_type: str) -> List[int]:
        """Select positions to commit based on calibrated thresholds."""
        threshold = getattr(self, f"{task_type}_threshold")

        # Penalize later window positions so early tokens commit first
        positions = mx.arange(entropies.shape[0])
        adjusted = entropies + self.learned_position_penalty * positions

        return [i for i, e in enumerate(adjusted.tolist()) if e < threshold]

Different tasks have different entropy profiles. Code has low entropy (predictable tokens), creative writing has high entropy (many valid next tokens). A learned selector could adapt accordingly.

3. Speculative Decoding Integration

WeDLM generates multiple candidate tokens per step. This naturally fits speculative decoding:

Speculative decoding flow:
1. Draft model generates K tokens quickly: [t1, t2, t3, t4]
2. Target model verifies all K in single forward pass
3. Accept prefix that matches, regenerate from first mismatch

WeDLM hybrid:
1. WeDLM generates window: [t1, t2, t3, t4] with confidences
2. Verify uncertain tokens with AR pass on target model
3. Accept confident tokens, use AR for uncertain ones

This could preserve WeDLM's speed while recovering accuracy on low-confidence positions. Implementation complexity is high, but the potential is interesting.
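
A rough sketch of steps 2-3 of the hybrid: accept the window's confident prefix and let a single AR step decide the first uncertain position. Names, thresholds, and the window/confidence interface here are all illustrative, not the project's actual API:

def hybrid_commit(window_tokens, confidences, ar_step, threshold=0.9):
    """Accept the confident prefix of the WeDLM window, fall back to AR at the first doubt."""
    committed = []
    for tok, conf in zip(window_tokens, confidences):
        if conf >= threshold:
            committed.append(tok)        # confident: keep the parallel prediction
        else:
            committed.append(ar_step())  # uncertain: one AR forward decides this token
            break                        # everything after it must be re-proposed anyway
    return committed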

4. Continuous Batching

For server deployment, batching multiple requests is essential:

from typing import Callable

class ContinuousBatcher:
    """Batch multiple requests dynamically."""

    def __init__(self, model, max_batch_size: int = 8):
        self.model = model
        self.max_batch_size = max_batch_size
        self.pending_requests = []
        self.active_states = []

    def add_request(self, prompt: str, callback: Callable):
        """Add new request to the queue."""
        self.pending_requests.append((prompt, callback))
        self._maybe_promote_requests()

    def step(self):
        """Process one step for all active requests."""
        if not self.active_states:
            return

        # Batch all active states into a single forward pass
        batched_input = self._batch_states(self.active_states)
        batched_output = self.model.forward_batch(batched_input)

        # Unbatch and update individual states
        finished = []
        for state, output in zip(self.active_states, self._unbatch(batched_output)):
            state.update(output)
            if state.is_finished:
                state.callback(state.result)
                finished.append(state)

        # Remove finished states after the loop to avoid mutating the list mid-iteration
        for state in finished:
            self.active_states.remove(state)
MLX doesn't have built-in support for this (unlike vLLM's PagedAttention), but the primitives exist. It would require significant engineering.

5. Model Parallelism for Larger Models

For models larger than unified memory (70B+), model parallelism across multiple Macs could work:

# Hypothetical distributed inference
class DistributedWeDLM:
    def __init__(self, model_shards: List[str], hosts: List[str]):
        # Shard model across machines
        # Each machine holds layers [start:end]
        self.shards = [
            RemoteShard(host, shard_path)
            for host, shard_path in zip(hosts, model_shards)
        ]

    def forward(self, x: mx.array) -> mx.array:
        for shard in self.shards:
            x = shard.forward(x)  # Network transfer between shards
        return x

The latency overhead of network transfers would be significant, but it's the only way to run models that don't fit in memory.


Conclusion

Summary

The WeDLM port to Apple Silicon achieved:

Performance:

  • 55 tokens/second on M4 Max (4-bit selective quantization)
  • 1.37x speedup over autoregressive baseline
  • Time-to-first-token: ~180ms

Quality:

  • 70% accuracy on GSM8K (vs 80% for AR baseline)
  • Acceptable for conversational use, not ideal for precision tasks

What Worked:

  • MLX provides 20%+ speedup over PyTorch MPS
  • 4-bit quantization with selective exclusion balances speed and quality
  • Batched operations with single sync points are critical
  • Topological reordering genuinely enables parallel decoding with KV caching

What Didn't:

  • mx.compile provides no benefit (matrix multiply dominated)
  • Multi-threaded Metal access crashes (use sequential or multiprocessing)
  • Diffusion decoding accuracy is lower on reasoning tasks
  • The 3x speedup from the paper isn't achievable on consumer hardware