Zimeng Xiong's Weblog


Porting Meta SAM-3D to Apple Silicon: Custom Metal Kernels and Memory Magic

TL;DR: This post documents the process of porting Meta's SAM-3D Objects (a 12GB foundation model for single-image 3D reconstruction) from CUDA/Linux to Apple Silicon macOS. The work involved rebuilding sparse convolution backends, implementing custom Metal compute shaders, and engineering a sequential model loading strategy that reduced peak memory from 61GB to 17GB.


Part 1: Understanding SAM 3D from the Paper

Overview

SAM 3D is a generative model for visually grounded 3D object reconstruction.

"We present SAM 3D, a generative model for visually grounded 3D object reconstruction, predicting geometry, texture, and layout from a single image. SAM 3D excels in natural images, where occlusion and scene clutter are common and visual recognition cues from context play a larger role."

The model differs from prior work (Hunyuan3D, TRELLIS) by conditioning on scene context rather than just isolated object crops. This enables reconstruction of occluded objects using recognition-based priors.

Figure: SAM 3D reconstructions from single images with occlusion and scene clutter.

The Problem SAM 3D Solves

The fundamental challenge of single-image 3D reconstruction is inverting a lossy process. As the paper states:

"The act of taking a photograph maps a 3D object to a set of 2D pixels, specified by a mask M in an image I. We seek to invert this map."

Mathematically, given an image I and object mask M, SAM 3D models the conditional distribution:

p(S, T, R, t, s | I, M)

Where:

  • S = 3D Shape
  • T = Texture
  • R = Rotation (6D representation)
  • t = Translation
  • s = Scale
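The 6D rotation representation (Zhou et al., 2019) stores the first two columns of a rotation matrix and recovers the third by Gram-Schmidt orthogonalization. This avoids the discontinuities of Euler angles and quaternions. Below is a minimal NumPy sketch of that mapping. `rot6d_to_matrix` is a hypothetical helper, and whether SAM 3D packs the six numbers as columns or as rows is an assumption here.

```python
import numpy as np

def rot6d_to_matrix(r6: np.ndarray) -> np.ndarray:
    """Map a 6D rotation (two stacked 3-vectors) to a 3x3 rotation matrix.

    Hypothetical helper: treats the 6 numbers as the first two columns of the
    matrix (some codebases use rows instead) and orthonormalizes them with
    Gram-Schmidt, as in Zhou et al. (2019).
    """
    a1, a2 = r6[:3], r6[3:]
    b1 = a1 / np.linalg.norm(a1)            # first axis: normalized a1
    a2_perp = a2 - np.dot(b1, a2) * b1      # remove a2's component along b1
    b2 = a2_perp / np.linalg.norm(a2_perp)  # second axis, orthogonal to b1
    b3 = np.cross(b1, b2)                   # third axis completes a right-handed frame
    return np.stack([b1, b2, b3], axis=1)   # b1, b2, b3 become the columns
```

Any six numbers whose two halves are not parallel map to a valid rotation. That is why a network can regress them directly.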

Architecture Overview

SAM 3D uses a two-stage latent flow matching architecture:

Figure: The SAM 3D model architecture, showing the two-stage pipeline.

"SAM 3D first jointly predicts object pose and coarse shape, then refines the shapes by integrating pictorial cues."
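Flow matching trains a network to predict a velocity field that carries noise to data. Sampling then integrates that field over a fixed number of steps, which is presumably what the diffusion step counts later in this post (25, 40) refer to.

The sketch below uses plain forward Euler, with a toy velocity field standing in for the network. SAM 3D's actual solver, time schedule, and any classifier-free guidance are not reproduced, and `euler_flow_sample` is a hypothetical name.

```python
import numpy as np

def euler_flow_sample(velocity, x0, num_steps):
    """Integrate dx/dt = velocity(x, t) from t = 0 (noise) to t = 1 (data)
    with fixed-step forward Euler. `velocity` stands in for the learned
    network; hypothetical helper, not SAM 3D's sampler."""
    x = np.array(x0, dtype=np.float64)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity(x, t)  # one Euler step along the predicted velocity
    return x

# Toy field: velocity of the straight path from the current point to a fixed target
target = np.array([1.0, -2.0, 0.5])

def toy_velocity(x, t):
    return (target - x) / (1.0 - t)

noise = np.random.default_rng(0).standard_normal(3)
sample = euler_flow_sample(toy_velocity, noise, num_steps=25)
```

With this toy field the path is a straight line, so Euler lands exactly on the target at t = 1 for any step count. A learned field is only approximately straight, which is one reason low step counts cost quality.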

Stage 1: The Geometry Model (1.2B parameters)

The Geometry Model predicts:

  • Coarse shape O ∈ ℝ^(64³) — a 64x64x64 voxel grid
  • 6D rotation R
  • Translation t
  • Scale s

It uses a Mixture-of-Transformers (MoT) architecture, conditioning on DINOv2-encoded image features.

Stage 2: The Texture & Refinement Model (600M parameters)

"We first extract active voxels from the coarse shape O predicted by Geometry model. A 600M parameter sparse latent flow transformer refines geometric details and synthesizes object texture."

This stage operates on sparse tensors (only active voxels), which is why we need sparse convolution support—and why the CUDA dependency exists.

Input Encoding: Why Masks Matter

SAM 3D uses dual conditioning from two views:

"We use DINOv2 as an encoder to extract features from two pairs of images, resulting in 4 sets of conditioning tokens:

  • Cropped object: We encode the cropped image I by mask M and its corresponding cropped binary mask, providing a focused, high-resolution view of the object.
  • Full image: We encode the full image I and its full image binary mask, providing global scene context and recognition cues absent from the cropped view."

This dual encoding is critical—the crop provides detail while the full image provides context.


Mask Conditioning

SAM 3D's input consists of an image I and binary mask M specifying the target object:

Figure: Input image, a children's room scene.

Figure: Binary mask (index 14), with the teddy bear isolated for reconstruction.

The model encodes both the masked crop and full image via DINOv2, providing local detail and global context simultaneously. This dual conditioning enables reconstruction of partially occluded objects.
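The cropped view can be produced by cutting the image and mask down to the mask's bounding box. Below is a minimal NumPy sketch, assuming a boolean H×W mask. The `pad` margin is a hypothetical parameter, and SAM 3D's actual crop margin and its resize to DINOv2 input resolution are not shown.

```python
import numpy as np

def crop_to_mask(image, mask, pad=0.1):
    """Crop image and mask to the mask's bounding box, enlarged by `pad`
    (a fraction of the box size) and clipped to the image bounds.
    Hypothetical helper; `pad` is not taken from SAM 3D."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        raise ValueError("mask is empty")
    y0, y1 = ys.min(), ys.max() + 1          # tight box, end-exclusive
    x0, x1 = xs.min(), xs.max() + 1
    py, px = int(round((y1 - y0) * pad)), int(round((x1 - x0) * pad))
    y0, y1 = max(0, y0 - py), min(mask.shape[0], y1 + py)
    x0, x1 = max(0, x0 - px), min(mask.shape[1], x1 + px)
    return image[y0:y1, x0:x1], mask[y0:y1, x0:x1]
```

The four conditioning inputs would then be:

  • the cropped image
  • the cropped mask
  • the full image
  • the full mask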

"Mid-training builds up foundational skills: Mask-following, occlusion robustness, and layout estimation."

The model was trained on 61M render-paste samples with synthetic occlusions.


Part 2: The Port

The pipeline runs a depth-estimation step followed by three main stages:

Input Image + Mask
        │
        ▼
   MoGe Depth Model (Stage 0)
   Monocular depth estimation → Pointmap (3D)
        │
        ▼
   Sparse Structure Generator (Stage 1)
   Diffusion: image → 25k voxel coordinates
        │
        ▼
   Structured Latent Generator (Stage 2)
   Diffusion: voxels → 8D latent features
        │
        ▼
   Mesh Decoder (Stage 3)
   SparseTensor → FlexiCubes → GLB Mesh

Model Components

Component        Size      Function
SS Generator     6.2 GB    Generates sparse voxel structure
SLAT Generator   4.6 GB    Creates structured latents
Mesh Decoder     346 MB    Decodes to 3D mesh
GS Decoder       163 MB    Decodes to Gaussian splat
SS Decoder       140 MB    Decodes voxel coordinates
MoGe             ~2 GB     Monocular depth estimation
Total            ~13.4 GB  All model weights (~11.4 GB without MoGe)

The core problem: SAM-3D only runs on Linux with NVIDIA GPUs, and requires 61GB of RAM at peak. My MacBook Pro has 48GB of unified memory.

First Attempt: The spconv CPU Fork

My first approach was to use an experimental CPU-only spconv fork for macOS. A GitHub PR (#616) by @tilmantroester showed that CPU-only builds on macOS ARM64 were possible.

Building spconv CPU

export CPU_ONLY=1

# Clone forked dependencies
git clone https://github.com/tilmantroester/cumm --branch macos_cpu_only
uv pip install ./cumm

git clone https://github.com/tilmantroester/ccimport --branch macos_support
uv pip install ./ccimport

git clone https://github.com/haixuanTao/spconv --branch master
uv pip install ./spconv

The build succeeded, but there was a problem:

>>> from spconv.pytorch import SubMConv3d
Traceback (most recent call last):
  File "...cumm/tensorview_bind.py", line 51, in __init__
    cuda_ver = get_cuda_version_by_nvcc().split(".")
FileNotFoundError: [Errno 2] No such file or directory: 'nvcc'

The cumm library still tried to call nvcc during initialization, even with CPU_ONLY=1. The forked dependencies had drifted from each other since the PR was created.

Lessons from spconv

After two days of debugging build issues, I realized:

  1. The spconv fork is fragile: the dependencies (cumm, ccimport, spconv) must all come from mutually compatible forks
  2. CPU performance is poor: even if it worked, CPU-only sparse convolution would likely be 10-100x slower than a GPU implementation
  3. A better approach exists: write custom Metal kernels for the specific operations SAM-3D needs

The Metal Kernel Approach

I decided to implement custom Metal compute shaders for the two main bottlenecks:

  1. Sparse 3D Convolution — the core operation in SAM-3D's diffusion backbone
  2. Flash Attention — memory-efficient attention for the transformer blocks

Sparse Convolution Architecture

The key insight is that sparse 3D convolution can be decomposed into:

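  1. Build Hash Table: map each voxel's linearized coordinate to its row index (a dense direct-address table, so neighbor lookups are O(1) with no collisions)
  2. Gather Features: collect neighbor voxel features using the table
  3. Apply Weights: multiply the gathered features by the convolution weights
  4. Scatter Output: write results back to the sparse tensor

Both kernels share a precomputed table of the 27 neighbor offsets and a helper that linearizes a (batch, x, y, z) coordinate: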
#include <metal_stdlib>
using namespace metal;

// 3x3x3 kernel offsets (pre-computed), ordered x-slowest, z-fastest
constant int3 KERNEL_OFFSETS[27] = {
    int3(-1, -1, -1), int3(-1, -1, 0), int3(-1, -1, 1),
    int3(-1, 0, -1),  int3(-1, 0, 0),  int3(-1, 0, 1),
    int3(-1, 1, -1),  int3(-1, 1, 0),  int3(-1, 1, 1),
    int3(0, -1, -1),  int3(0, -1, 0),  int3(0, -1, 1),
    int3(0, 0, -1),   int3(0, 0, 0),   int3(0, 0, 1),
    int3(0, 1, -1),   int3(0, 1, 0),   int3(0, 1, 1),
    int3(1, -1, -1),  int3(1, -1, 0),  int3(1, -1, 1),
    int3(1, 0, -1),   int3(1, 0, 0),   int3(1, 0, 1),
    int3(1, 1, -1),   int3(1, 1, 0),   int3(1, 1, 1)
};

// Linearize a (batch, x, y, z) coordinate into an index of a dense table
// holding B * D * H * W entries. Despite the name, this is a direct address
// rather than a hash, so distinct in-bounds coordinates never collide.
inline uint coord_to_hash(int4 coord, int3 spatial_shape) {
    int D = spatial_shape.x, H = spatial_shape.y, W = spatial_shape.z;
    return uint(coord.x * D * H * W + coord.y * H * W + coord.z * W + coord.w);
}

Hash Table Kernel

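// Record each voxel's row index at its linearized coordinate.
// Assumes hash_table holds B * D * H * W ints and was pre-filled with -1 on
// the host (e.g. a blit fillBuffer with byte value 0xFF, which yields -1 in
// every int), because the conv kernel treats -1 as "no voxel here".
kernel void build_hash_table(
    device const int4* coords [[buffer(0)]],
    device int* hash_table [[buffer(1)]],
    constant int& N [[buffer(2)]],
    constant int3& spatial_shape [[buffer(3)]],  // int3 occupies 16 bytes in MSL
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= uint(N)) return;
    
    int4 coord = coords[gid];
    uint hash = coord_to_hash(coord, spatial_shape);
    
    hash_table[hash] = int(gid);  // Store voxel index
}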

Sparse Convolution Kernel

The main kernel iterates over 27 neighbors per voxel:

// Submanifold 3x3x3 sparse convolution: output voxels = input voxels.
// Each thread computes 8 consecutive output channels of one voxel, so the
// grid is (N, C_out / 8) and C_out is assumed to be a multiple of 8.
// Weight layout: [27][C_in][C_out]. buffer(3) is not used by this kernel.
kernel void sparse_conv3x3x3_subm(
    device const float* features [[buffer(0)]],
    device const int4* coords [[buffer(1)]],
    device const float* weights [[buffer(2)]],
    device const int* hash_table [[buffer(4)]],
    device float* output [[buffer(5)]],
    constant int& N [[buffer(6)]],
    constant int& C_in [[buffer(7)]],
    constant int& C_out [[buffer(8)]],
    constant int3& spatial_shape [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint voxel_idx = gid.x;
    uint out_ch_start = gid.y * 8;  // Process 8 channels per thread
    
    if (voxel_idx >= uint(N) || out_ch_start >= uint(C_out)) return;
    
    int4 center_coord = coords[voxel_idx];
    float accum[8] = {0.0f};
    
    for (int k = 0; k < 27; k++) {
        int3 offset = KERNEL_OFFSETS[k];
        int4 neighbor_coord = center_coord + int4(0, offset.x, offset.y, offset.z);
        
        // Boundary check
        if (neighbor_coord.y < 0 || neighbor_coord.y >= spatial_shape.x ||
            neighbor_coord.z < 0 || neighbor_coord.z >= spatial_shape.y ||
            neighbor_coord.w < 0 || neighbor_coord.w >= spatial_shape.z)
            continue;
        
        // O(1) table lookup; -1 marks an empty cell
        uint hash = coord_to_hash(neighbor_coord, spatial_shape);
        int neighbor_idx = hash_table[hash];
        if (neighbor_idx == -1) continue;
        
        // Gather features and apply weights. Each input feature is loaded
        // once and reused for all 8 output channels of this thread.
        device const float* feat_row = features + neighbor_idx * C_in;
        device const float* w_k = weights + k * C_in * C_out + out_ch_start;
        for (int c_in = 0; c_in < C_in; c_in++) {
            float feat = feat_row[c_in];
            for (uint c_out = 0; c_out < 8; c_out++) {
                accum[c_out] += feat * w_k[c_in * C_out + c_out];
            }
        }
    }
    
    // Write output
    for (uint c_out = 0; c_out < 8; c_out++) {
        output[voxel_idx * C_out + out_ch_start + c_out] = accum[c_out];
    }
}
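A slow CPU reference is handy for checking a hand-written kernel element by element. The NumPy sketch below mirrors the kernel's conventions as written:

  • coordinates as (batch, x, y, z)
  • the same linearization and neighbor ordering
  • a -1-filled lookup table
  • weights laid out as [27, C_in, C_out]

`sparse_conv3x3x3_subm_ref` is a hypothetical helper, not part of the repo.

```python
import itertools
import numpy as np

# Same ordering as KERNEL_OFFSETS in the Metal source: x slowest, z fastest.
OFFSETS = np.array(list(itertools.product((-1, 0, 1), repeat=3)))

def sparse_conv3x3x3_subm_ref(features, coords, weights, spatial_shape):
    """CPU reference for the submanifold 3x3x3 sparse conv (hypothetical helper).
    features: [N, C_in], coords: [N, 4] ints as (batch, x, y, z),
    weights: [27, C_in, C_out]. Returns [N, C_out]."""
    D, H, W = spatial_shape
    B = int(coords[:, 0].max()) + 1
    table = np.full(B * D * H * W, -1, dtype=np.int64)  # dense lookup, -1 = empty
    lin = ((coords[:, 0] * D + coords[:, 1]) * H + coords[:, 2]) * W + coords[:, 3]
    table[lin] = np.arange(len(coords))
    out = np.zeros((len(coords), weights.shape[2]), dtype=features.dtype)
    for k, (dx, dy, dz) in enumerate(OFFSETS):
        nb = coords + np.array([0, dx, dy, dz])
        inside = ((nb[:, 1] >= 0) & (nb[:, 1] < D) & (nb[:, 2] >= 0) & (nb[:, 2] < H)
                  & (nb[:, 3] >= 0) & (nb[:, 3] < W))
        idx = np.full(len(coords), -1)
        nb_in = nb[inside]
        idx[inside] = table[((nb_in[:, 0] * D + nb_in[:, 1]) * H + nb_in[:, 2]) * W + nb_in[:, 3]]
        hit = idx >= 0  # neighbor exists
        out[hit] += features[idx[hit]] @ weights[k]
    return out
```

To validate the kernel, run both implementations on random inputs and compare the results with `np.allclose`. This catches indexing mistakes such as a transposed weight layout or a mismatched offset order.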

Metal-Python Integration

The Python wrapper uses PyObjC to interface with Metal:

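from pathlib import Path

import Metal  # PyObjC bridge: pip install pyobjc-framework-Metal

_metal_device = None
_metal_library = None
_metal_functions = {}

def _ensure_metal_initialized():
    global _metal_device, _metal_library
    
    if _metal_device is not None:
        return
    
    # Get GPU device
    device = Metal.MTLCreateSystemDefaultDevice()
    if device is None:
        raise RuntimeError("No Metal device available")
    
    # Compile shader (the .metal source lives next to this module)
    source = (Path(__file__).parent / "sparse_conv.metal").read_text()
    options = Metal.MTLCompileOptions.alloc().init()
    # PyObjC returns the NSError** out-parameter as a second return value;
    # check the returned object rather than `error`
    library, error = device.newLibraryWithSource_options_error_(
        source, options, None
    )
    if library is None:
        raise RuntimeError(f"Shader compile error: {error}")
    
    # Create pipeline states
    for kernel_name in ["build_hash_table", "sparse_conv3x3x3_subm"]:
        fn = library.newFunctionWithName_(kernel_name)
        if fn is None:
            raise RuntimeError(f"Kernel {kernel_name} not found in library")
        pipeline, error = device.newComputePipelineStateWithFunction_error_(
            fn, None
        )
        if pipeline is None:
            raise RuntimeError(f"Pipeline creation failed for {kernel_name}: {error}")
        _metal_functions[kernel_name] = pipeline
    
    # Publish the globals only after everything succeeded
    _metal_library = library
    _metal_device = device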
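One host-side detail is easy to get wrong. In the Metal Shading Language an `int3` has the size and alignment of an `int4` (16 bytes). When `spatial_shape` is passed as raw bytes, it therefore needs a fourth padding int. Raw bytes are passed, for example, through the compute encoder's `setBytes:length:atIndex:`, which PyObjC exposes as `setBytes_length_atIndex_`.

The sketch also computes the 2D grid that `sparse_conv3x3x3_subm` expects: one thread per voxel per group of 8 output channels. Both helpers are hypothetical and assume C_out is a multiple of 8.

```python
import struct

def pack_int3(x, y, z):
    """Pack an MSL int3 for a constant argument. int3 occupies 16 bytes
    (same size/alignment as int4), so a zero padding int is appended.
    Hypothetical helper."""
    return struct.pack("<4i", x, y, z, 0)

def sparse_conv_grid(num_voxels, c_out, channels_per_thread=8):
    """Grid size (width, height, depth) for sparse_conv3x3x3_subm:
    x indexes voxels, y indexes groups of output channels. Hypothetical helper."""
    if c_out % channels_per_thread:
        raise ValueError("C_out must be a multiple of channels_per_thread")
    return (num_voxels, c_out // channels_per_thread, 1)
```

With a grid like this, dispatching by exact thread count (`dispatchThreads:threadsPerThreadgroup:`) relies on the kernel's own bounds checks rather than on padding N.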

Test Results

=== Testing Metal Backend Integration ===
✓ Metal framework: Apple M4 Max
✓ Metal Sparse Conv: AVAILABLE
✓ Metal Flash Attn: AVAILABLE

Metal sparse conv test (1000 voxels, 64 channels): 9.19ms

The Metal kernel completed the 1000-voxel, 64-channel test in 9.19ms, which should be a large improvement over a CPU implementation.

Figure: Teddy bear 3D reconstruction from the children's room test image.

The Memory Crisis

With Metal acceleration working, I ran the full pipeline—and it OOM'd. The pipeline needed 61GB of RAM, but I only had 48GB.

Original Memory Profile

┌─────────────────────────────────────────────┐
│  All 12GB models in memory                  │
│  Plus activations: ~45GB more               │
│  Total peak: 61GB                           │
└─────────────────────────────────────────────┘

The original pipeline loaded ALL models simultaneously during initialization:

# Original: Everything in memory at once
self.models = {
    "ss_generator": load(...),      # 6.2 GB
    "slat_generator": load(...),    # 4.6 GB
    "slat_decoder_mesh": load(...), # 350 MB
    ...
}
# Peak: ~14 GB just for weights, 61 GB+ with activations

The Solution: Sequential Model Loading

I created InferencePipelineLowMemory that loads models on-demand and deletes them immediately after use:

from omegaconf import OmegaConf

class InferencePipelineLowMemory:
    """Load models on-demand, delete after use."""
    
    def __init__(self, config_path, cache_dir=None):
        # DON'T load models here!
        self.config = OmegaConf.load(config_path)
        self.cache_dir = cache_dir
        self.models = {}  # Empty
    
    def run(self, image, mask, **kwargs):
        # Simplified: the real stage calls take more conditioning inputs
        # (e.g. mask, pointmap). _load_model/_delete_model are pipeline
        # helpers not shown here.
        
        # Stage 0: Depth
        depth_model = self._load_model("depth")  # ~2GB
        pointmap = depth_model(image)
        self._delete_model(depth_model)  # FREE 2GB
        del depth_model  # also drop run()'s own reference
        
        # Stage 1: Sparse Structure
        ss_generator = self._load_model("ss_generator")  # ~6GB
        coords = ss_generator(image)
        self._delete_model(ss_generator)  # FREE 6GB
        del ss_generator
        
        # Stage 2: Structured Latent
        slat_generator = self._load_model("slat_generator")  # ~5GB
        slat = slat_generator(coords)
        self._delete_model(slat_generator)  # FREE 5GB
        del slat_generator
        
        # Stage 3: Decode
        decoder = self._load_model("mesh_decoder")  # ~350MB
        mesh = decoder(slat)
        self._delete_model(decoder)  # FREE 350MB
        del decoder
        
        return mesh
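The load/run/delete pattern in `run()` can be factored into one helper. The model reference is local to the helper, so it cannot outlive its stage, and the release hook still runs if a stage raises.

This is a minimal sketch with hypothetical `loader`/`release` callables. In this pipeline, `release` could be something like the `delete_model_completely` helper from the next section.

```python
import gc

def run_stage(loader, stage_fn, release=None):
    """Load a model, run one pipeline stage with it, and free it.

    Hypothetical helper: `loader()` builds the model, `stage_fn(model)` returns
    the stage output, and optional `release(model)` frees tensor storage.
    Only the output escapes, so the model's last reference is dropped
    before this function returns.
    """
    model = loader()
    try:
        return stage_fn(model)
    finally:
        if release is not None:
            release(model)  # runs even if stage_fn raised
        del model
        gc.collect()  # collect any reference cycles the model was part of
```

Each stage then becomes one call, such as `pointmap = run_stage(load_depth, lambda m: m(image))`, where `load_depth` is hypothetical.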

Aggressive Garbage Collection

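CPython frees a tensor as soon as its last reference disappears. However, tensors caught in reference cycles wait for the cyclic garbage collector, and PyTorch's MPS caching allocator keeps freed blocks reserved until its cache is emptied. I implemented aggressive deletion: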

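import gc

import torch

def force_gc():
    """Triple garbage collection + MPS cache clear."""
    # Several passes, so objects released by one pass (e.g. by finalizers)
    # can be collected by the next
    gc.collect()
    gc.collect()
    gc.collect()
    
    if torch.backends.mps.is_available():
        torch.mps.synchronize()  # wait for queued GPU work that may still use the memory
        torch.mps.empty_cache()  # return cached blocks held by PyTorch's MPS allocator

def delete_model_completely(model, name="model"):
    """Fully delete model from memory."""
    model.cpu()  # Move to CPU first
    
    # Swap every tensor's storage for an empty tensor, so the memory is
    # released even if other references to the module still exist
    for param in model.parameters():
        param.data = torch.empty(0)
        param.grad = None
    
    for buffer in model.buffers():
        buffer.data = torch.empty(0)
    
    del model  # removes only this function's local name; callers must drop theirs
    gc.collect()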

Memory Profile After Sequential Loading

Stage            Original   Low-Memory   Reduction
Init             45 GB+     0.7 GB       98%
After Depth      48 GB      5.5 GB       89%
After Stage 1    52 GB      17.3 GB      67%
After Stage 2    58 GB      17.3 GB      70%
Peak             61 GB      17.3 GB      72%
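The Reduction column is 1 − (low-memory / original), rounded to the nearest percent. Recomputing it from the table rows checks the figures. `reduction_pct` is a hypothetical helper.

```python
def reduction_pct(original_gb, low_mem_gb):
    """Percent reduction as used in the table: 1 - low/original, rounded."""
    return round(100 * (1 - low_mem_gb / original_gb))

rows = {"Init": (45, 0.7), "After Depth": (48, 5.5), "After Stage 1": (52, 17.3),
        "After Stage 2": (58, 17.3), "Peak": (61, 17.3)}
for stage, (orig, low) in rows.items():
    print(f"{stage:14s} {reduction_pct(orig, low)}%")
```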

Figure: Voxel output at 40 diffusion steps with sequential loading.

The Mesh Decoder Problem

Sequential loading fixed the diffusion model memory issue. But the mesh decoder itself still OOM'd—the SparseSubdivide operation was exploding memory:

Input: 20K voxels
       │
       ▼ SparseSubdivide (8x)
161K voxels (~4GB)
       │
       ▼ SparseSubdivide (8x)
1.3M voxels (~48GB) ← OOM!
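Each subdivision doubles the resolution along every axis, so every voxel becomes 2×2×2 = 8 children. That is where the 20K → 161K → 1.3M growth comes from.

Below is a minimal NumPy sketch of the coordinate bookkeeping, assuming children simply copy their parent's features. The real SparseSubdivide may handle features differently, and `subdivide` is a hypothetical name.

```python
import numpy as np

# The 8 child offsets of a voxel when the grid resolution doubles.
CHILD_OFFSETS = np.array([[dx, dy, dz] for dx in (0, 1) for dy in (0, 1) for dz in (0, 1)])

def subdivide(coords, feats):
    """2x sparse subdivision (hypothetical helper).
    coords: [N, 4] ints as (batch, x, y, z); feats: [N, C].
    Voxel i becomes rows 8*i .. 8*i+7 at (batch, 2x+dx, 2y+dy, 2z+dz),
    each child copying the parent's features."""
    child_xyz = (2 * coords[:, None, 1:] + CHILD_OFFSETS[None]).reshape(-1, 3)
    child_batch = np.repeat(coords[:, 0], 8)[:, None]
    child_coords = np.concatenate([child_batch, child_xyz], axis=1)
    child_feats = np.repeat(feats, 8, axis=0)  # 8x the feature memory
    return child_coords, child_feats
```

Because `child_feats` is 8x larger than `feats`, the two arrays together hold 9x the input's feature memory until the input is released.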

The problem: both the input tensor (161K) and output tensor (1.3M) exist simultaneously during the second subdivision.

First Attempt: Chunked Mesh Decoding

My first instinct was to process the mesh in chunks rather than all at once:

# BROKEN: Chunked mesh decoding (first attempt)
def chunked_mesh_decode(slat, chunk_size=64):
    chunks = split_spatial(slat.coords, chunk_size)
    meshes = []
    
    for chunk in chunks:
        mesh = decoder(chunk)
        meshes.append(mesh)
    
    return merge_meshes(meshes)

Attempt 1: Initial Chunking — Complete Failure

The first attempt produced garbage. Chunks were scattered in 3D space:

Figure (two panels): Chunks scattered randomly because of a coordinate system mismatch.

The issue was a coordinate transformation bug — I was splitting by voxel index rather than spatial coordinates.
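Splitting by spatial coordinates means assigning each voxel to the chunk that contains it, by integer-dividing its (x, y, z) by the chunk size. Below is a minimal NumPy sketch. `split_by_space` is a hypothetical helper, not the author's `split_spatial`.

```python
import numpy as np

def split_by_space(coords, chunk_size):
    """Group voxel row indices by spatial chunk (hypothetical helper).

    coords: [N, 4] integer array of (batch, x, y, z). Each voxel goes to chunk
    (batch, x // s, y // s, z // s), so every chunk is a compact box in space
    regardless of the order in which the rows are stored.
    """
    keys = np.concatenate([coords[:, :1], coords[:, 1:] // chunk_size], axis=1)
    uniq, inverse = np.unique(keys, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)  # flatten; the returned shape varies across NumPy versions
    return {tuple(int(v) for v in key): np.nonzero(inverse == i)[0]
            for i, key in enumerate(uniq)}
```

Correct grouping alone does not remove the seams, because each chunk is still meshed independently.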

Attempt 2: Fixed Coordinates — Visible Seams

After fixing the coordinate transformation, chunks aligned properly:

Figure: Chunks aligned, but with visible seam artifacts at the boundaries.

FlexiCubes is built on (dual) marching cubes, which requires a continuous voxel grid. Splitting the grid creates discontinuities at chunk boundaries.

Attempt 3: Overlapping Chunks — Made It Worse

I tried increasing chunk overlap to smooth the transitions:

Figure: Increased overlap created competing geometry artifacts.

Overlapping regions had competing geometry, creating jagged artifacts. Chunking was fundamentally incompatible with marching cubes.

The Real Fix: Intermediate Tensor Deletion

Instead of chunking, I modified SparseSubdivideBlock3d to delete its input before returning:

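import torch
from torch import nn

class SparseSubdivideBlock3d(nn.Module):
    # __init__ (defining self.sub, self.norm, self.activation, self.out) omitted

    def forward(self, x, delete_input=False):
        """Forward with optional input deletion for memory."""
        h = self.sub(x)  # e.g. input: 161K voxels, output: 1.3M voxels
        
        if delete_input:
            # Release the input's storage BEFORE norm/activation/out allocate
            # more output-sized tensors. Swapping .data frees the memory even
            # though the caller still holds a reference to x.
            x.feats.data = torch.empty(0)
            x.coords.data = torch.empty(0)
            del x
            _force_gc()  # the force_gc() helper shown earlier
        
        h = self.norm(h)
        h = self.activation(h)
        h = self.out(h)
        
        return h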

And in the mesh decoder:

def forward(self, x):
    """Forward pass with aggressive memory management."""
    h = super().forward(x)
    
    # Delete input after transformer forward
    x.feats.data = torch.empty(0)
    del x
    _force_gc()
    
    # Upsample with intermediate deletion
    for i, block in enumerate(self.upsample):
        h = block(h, delete_input=(i < len(self.upsample) - 1))
        _force_gc()
    
    # Extract mesh (already uses FlexiCubes)
    return self.mesh_extractor(h)

After Optimization

2025-12-19 21:45:08 | [LOW-MEM] After mesh decoding: 4.1 GB
2025-12-19 21:45:17 | [LOW-MEM] End of run(): 4.1 GB
2025-12-19 21:45:17 | [LOW-MEM] Pipeline complete!

Mesh decoding now completes, with memory at 4.1GB afterward. Before this change, the decoder needed ~50GB and ran out of memory.

Figure: Mesh output at 25 diffusion steps.

Performance Summary

Configuration          Time      Memory
CPU (float32)          ~15 min   61 GB (OOM)
MPS + Low-Memory       ~4 min    17.3 GB peak
MPS + Metal Kernels    ~3 min    17.3 GB peak

The final pipeline runs in approximately 3-4 minutes on an M4 Max, with a peak memory of 17.3GB.

Key Findings

1. spconv CPU Fork is Unreliable

The experimental macOS spconv fork has dependency version drift issues. Building custom Metal kernels is more reliable for production use.

2. Sequential Loading is Essential

Keeping all ~12GB of weights resident alongside their activations pushed peak memory to 61GB, more than a 48GB machine has. Sequential load/run/delete cuts peak memory by 72%.

3. Intermediate Tensor Deletion Matters

During SparseSubdivide, both input and output tensors exist simultaneously. Explicit deletion before returning frees memory for the next stage.

4. FlexiCubes Needs Continuous Grids

Chunked mesh extraction breaks marching cubes. Use a single pass with memory optimization instead.

5. Diffusion Steps Affect Mesh Quality

Low step counts produce binary-like SDF values. Use 25+ steps for smooth FlexiCubes output.

Files Structure

Sam3Dv11/sam-3d-objects-github/
├── MPS_Pipeline.py              # Main entry point
├── sam3d_objects/
│   ├── model/backbone/tdfy_dit/
│   │   └── modules/sparse/
│   │       ├── conv/
│   │       │   ├── sparse_conv.metal    # Metal kernels
│   │       │   └── conv_metal.py        # Python wrapper
│   │       └── attention/
│   │           ├── flash_attn.metal     # Metal attention
│   │           └── metal_flash_attn.py  # Python wrapper
│   └── pipeline/
│       └── inference_pipeline_low_memory.py  # Sequential loading
└── archived_pipelines/           # Old demos

Usage

# Set environment
export PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
export SPARSE_BACKEND=metal
export SPARSE_ATTN_BACKEND=metal_fa

# Run pipeline
python MPS_Pipeline.py \
    --image input.png \
    --mask-dir masks/ \
    --mask-index 0 \
    --mesh \
    --output output.glb \
    --steps 25
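The `SPARSE_BACKEND` variable selects the sparse convolution implementation. The sketch below shows one way such a switch can be read and validated. It assumes the following, none of which is confirmed by this post:

  • the set of accepted names other than `metal` (chosen to resemble the backends TRELLIS-style code offers)
  • the `spconv` default
  • the name `sparse_backend` itself, which is hypothetical

```python
import os

# Assumed option set; only "metal" is confirmed by this post.
KNOWN_SPARSE_BACKENDS = {"metal", "spconv", "torchsparse"}

def sparse_backend(default="spconv"):
    """Read SPARSE_BACKEND from the environment and validate it (hypothetical)."""
    name = os.environ.get("SPARSE_BACKEND", default).strip().lower()
    if name not in KNOWN_SPARSE_BACKENDS:
        raise ValueError(
            f"Unknown SPARSE_BACKEND={name!r}; expected one of {sorted(KNOWN_SPARSE_BACKENDS)}"
        )
    return name
```

Failing loudly on an unknown name is safer than silently falling back, because a typo would otherwise route everything to a slower or CUDA-only path.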

The Code

Sam3D-MLX is Open Source :)

Future Work

  1. MLX Conversion: a native MLX implementation could further improve performance
  2. Quantization: INT8/INT4 weights would cut weight memory by 4-8x relative to float32
  3. CoreML Export: deploy on iOS devices
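The 4-8x estimate holds relative to float32 weights: INT8 stores one byte per weight instead of four, and INT4 stores half a byte. Below is a minimal NumPy sketch of symmetric per-tensor INT8 quantization, with hypothetical helper names. It uses a single scale per tensor, whereas practical schemes often use per-channel or per-group scales.

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor INT8 quantization: w ≈ scale * q, q in [-127, 127].
    Hypothetical helper; real schemes often use per-channel or per-group scales."""
    max_abs = float(np.abs(w).max())
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid division by zero for all-zero tensors
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    """Recover approximate float32 weights from INT8 codes and the scale."""
    return q.astype(np.float32) * np.float32(scale)

w = np.random.default_rng(0).standard_normal((256, 256)).astype(np.float32)
q, scale = quantize_int8(w)
print(w.nbytes // q.nbytes)  # 4: one byte per weight instead of four
```

Rounding bounds the per-weight error at half a quantization step (scale / 2). Accuracy loss comes mainly from outliers that inflate the scale, which is what finer-grained scales address.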
