Swizzle and Shuffle Weights Before Calling FlashInfer's Fused MoE Kernel

Dec 23, 2025 · Zhiyao Ma

I recently worked on supporting GPT OSS models in our Pie LLM serving system. It was extremely challenging to get FlashInfer’s fused MoE kernel to work correctly. The main reason is that FlashInfer’s documentation doesn’t mention the required physical layout of the expert weights in memory. If the weights are in the wrong layout, the kernel produces incorrect numerical results but won’t crash or emit any error messages.

I’m writing this post to share the solution I found, in the hope that it will save others some time. I have also put together a complete end-to-end example that you can download.

Challenges in Getting FlashInfer’s Fused MoE to Work

There are several requirements for FlashInfer’s fused MoE kernel to produce correct results:

  • (*) The activation is bfloat16, the MoE weight matrix is mxfp4, but the MoE bias vector must be float32.
  • (**) mxfp4 quantization must use a 128x4 swizzled layout.
  • (***) Matrices must be padded to a multiple of 256 in each dimension.
  • (**) Matrix rows must be shuffled in specific ways before being passed into the kernel.
  • (***) Bias weights must be shuffled using the same pattern as the matrix rows.
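The padding requirement (***) applies to each weight dimension before quantization. The arithmetic below is a small sketch that assumes GPT OSS 20B's configuration (32 experts, hidden and intermediate size 2880). It computes the shapes the quantized gemm1 tensors end up with. The factor of 2 stacks the up-projection and gate halves. The divisors come from mxfp4, which packs two values per byte and shares one scale per 32 elements. The helper names are hypothetical.

```python
ALIGN = 256  # every weight dimension is padded to a multiple of this


def round_up(x: int, multiple: int = ALIGN) -> int:
    """Return the smallest multiple of `multiple` that is >= x."""
    return (x + multiple - 1) // multiple * multiple


def padded_gemm1_shapes(num_experts: int, hidden_size: int, intermediate_size: int):
    """Shapes of the packed mxfp4 gemm1 weights and their block scales after padding."""
    hidden = round_up(hidden_size)
    inter = round_up(intermediate_size)
    weights = (num_experts, 2 * inter, hidden // 2)   # two fp4 values per byte
    scales = (num_experts, 2 * inter, hidden // 32)   # one scale per 32-element block
    return weights, scales


print(padded_gemm1_shapes(32, 2880, 2880))
# ((32, 6144, 1536), (32, 6144, 96))
```

Both 2880-wide dimensions pad to 3072 (12 * 256). These are the same target shapes used by the reshape calls in the quantization snippet in the Solution section.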

The stars represent the difficulty level:

  • 1 star: Documented, but counterintuitive, so AI agents hallucinated and didn’t follow it.
  • 2 stars: Not documented, but can be inferred or reverse-engineered from existing unit tests.
  • 3 stars: Not documented, and existing unit tests don’t cover it. Also, AI agents failed to suggest the correct fix.

You might be wondering what swizzling and shuffling are. They permute elements in the matrix to improve memory-access performance. You may find this blog post by Yifan Yang (a FlashInfer developer) helpful.
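For the (**) requirement, the 128x4 swizzle rearranges the mxfp4 block scaling factors. Below is a minimal pure-Python sketch of the index mapping. It assumes FlashInfer's 128x4 layout matches the tiled scale-factor layout NVIDIA documents for block-scaled GEMMs in cuBLAS. In that layout, scales are grouped into tiles of 128 rows by 4 columns, and within a tile, rows r, r+32, r+64, and r+96 are stored next to each other. Treat the formula as an assumption to check against FlashInfer's own block_scale_interleave, not as its implementation. The function names are hypothetical.

```python
def swizzled_sf_offset(row: int, col: int, num_cols: int) -> int:
    """Position of scale factor (row, col) in a 128x4-swizzled flat buffer.

    Assumes the scale matrix is padded to a multiple of 128 rows and 4 columns;
    num_cols is the padded column count.
    """
    k_tiles = num_cols // 4
    m_tile, r = divmod(row, 128)   # which 128-row tile, and the row inside it
    k_tile, c = divmod(col, 4)     # which 4-column tile, and the column inside it
    tile_base = (m_tile * k_tiles + k_tile) * 512   # each tile holds 128 * 4 entries
    # Inside a tile, rows r, r+32, r+64, r+96 share one 16-entry group.
    return tile_base + (r % 32) * 16 + (r // 32) * 4 + c


def swizzle_scales(scales):
    """Flatten a padded row-major scale matrix (list of lists) into swizzled order."""
    rows, cols = len(scales), len(scales[0])
    assert rows % 128 == 0 and cols % 4 == 0, "pad to 128 rows and 4 columns first"
    out = [None] * (rows * cols)
    for i in range(rows):
        for j in range(cols):
            out[swizzled_sf_offset(i, j, cols)] = scales[i][j]
    return out
```

With 8 scale columns, entry (1, 0) lands at offset 16 and entry (32, 0) lands at offset 4. Elements that sit next to each other in memory are therefore not neighbors in the original matrix. This is why a wrong layout produces garbage numbers instead of an error.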

Given that FlashInfer is a low-level building block designed for maximum performance, it’s understandable that the kernel requires the caller to prepare weights in a specific in-memory layout. If the kernel had to convert layouts internally on every call, performance would suffer. Still, it’s unfortunate that the documentation doesn’t mention the required layout. Next, I’ll walk through the solution I found and how I discovered it.

Solution

The kernel function we use is trtllm_fp4_block_scale_moe. For GPT OSS, the input activation tensor is bfloat16 and the MoE weight matrix is mxfp4. In the absence of documentation, the best way to determine the correct usage is to read the unit test code. At the time of writing, I used FlashInfer 0.5.3, and I worked from the fused MoE unit test in that version of the FlashInfer repository.

The method FP4Moe.prepare_static_weights_for_kernel is the key. It shows that weights must be swizzled and shuffled before being passed into the kernel. If you find the unit test hard to read, the end-to-end example linked above contains a simplified version of the same logic.

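Here are the key snippets from the working example. First, after the weights are padded to the alignment requirements, we quantize the matrix to mxfp4. Note that the two outputs come from different calls. We keep the packed fp4 values from the call with swizzling enabled, but we take the block scaling factors from the call with swizzling disabled, because the scales are permuted and interleaved separately in the next step.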

# Quantize the weights
gemm1_weights_quant_padded, _    = quant_mxfp4_batches(gemm1_weights_padded, True)  # swizzling
_, gemm1_scales_linear_fp4_bytes = quant_mxfp4_batches(gemm1_weights_padded, False) # not swizzling

# Convert quantized weights to proper shapes
gemm1_weights_fp4 = gemm1_weights_quant_padded.view(torch.float8_e4m3fn).reshape(
    num_experts, 2 * padded_intermediate_size, padded_hidden_size // 2
)  # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
    torch.float8_e4m3fn
).reshape(
    num_experts, 2 * padded_intermediate_size, padded_hidden_size // 32
)  # fp4 block scaling factors
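The reshape targets in this snippet follow from the mxfp4 format itself. The OCP Microscaling (MX) specification defines MXFP4 as follows: each element is stored as a 4-bit E2M1 value, and each block of 32 elements shares one power-of-two scale stored as an 8-bit E8M0 exponent. That gives padded_hidden_size // 2 bytes of packed values and padded_hidden_size // 32 scales per row. Below is a pure-Python reference quantizer for one row, written only to illustrate the format. It is not FlashInfer's quantizer, ties round toward the smaller magnitude, and the nibble order (even-indexed element in the low nibble) is an assumption.

```python
import math

# Magnitudes representable by FP4 E2M1, indexed by their 3-bit code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK = 32  # MX block size: 32 elements share one scale


def encode_e2m1(x: float) -> int:
    """Round x to the closest E2M1 value (ties go to the smaller one); return its 4-bit code."""
    sign = 0x8 if x < 0 else 0x0
    mag = min(abs(x), 6.0)  # saturate to the largest representable magnitude
    code = min(range(8), key=lambda c: abs(E2M1_VALUES[c] - mag))
    return sign | code


def quantize_mxfp4_row(row):
    """Quantize one row; return (packed fp4 bytes, E8M0 scale bytes)."""
    assert len(row) % BLOCK == 0, "pad the row to a multiple of 32 first"
    packed, scales = [], []
    for start in range(0, len(row), BLOCK):
        block = row[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Shared exponent: 2 is the exponent of E2M1's largest value (6 = 1.5 * 2**2).
        exp = math.floor(math.log2(amax)) - 2 if amax > 0 else 0
        exp = max(-127, min(127, exp))
        scales.append(exp + 127)  # E8M0 stores the exponent with a bias of 127
        codes = [encode_e2m1(v / 2.0 ** exp) for v in block]
        # Pack two 4-bit codes per byte (assumed order: even index in the low nibble).
        for lo, hi in zip(codes[0::2], codes[1::2]):
            packed.append(lo | (hi << 4))
    return bytes(packed), bytes(scales)
```

Consider a row of 32 ones followed by 32 sixes. The first block gets scale exponent -2 (stored as 125), and every value encodes as 4.0. The second block gets exponent 0 (stored as 127) and encodes 6.0 directly.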

Next, we shuffle the matrix rows. If biases are used, we must shuffle the biases using the same row permutation. The official unit test doesn’t cover the bias case, so the bias handling only appears in my working example.

gemm1_weights_fp4_shuffled_list = []
gemm1_scales_fp4_shuffled_list = []
gemm1_bias_shuffled_list = []

# Calculate the permute indices for the following:
# 1. Reorder rows of W1 and scales for fused gated activation
# 2. Shuffle weights and scaling factors for transposed mma output
#    for both w3_w1 and w2 weights and scale factors
for i in range(num_experts):
    permute_indices = _maybe_get_cached_w3_w1_permute_indices(
        CACHE_PERMUTE_INDICES,
        gemm1_weights_fp4[i].view(torch.uint8),
        epilogue_tile_m,
    )
    gemm1_weights_fp4_shuffled_list.append(
        gemm1_weights_fp4[i]
        .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
        .contiguous()
    )

    # Shuffle gemm1 bias using row permutation derived from weight permutation
    if gemm1_bias_padded is not None:
        gemm1_bias_shuffled_list.append(
            gemm1_bias_padded[i][
                permute_indices.to(gemm1_bias_padded.device)
            ].contiguous()
        )

    permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
        CACHE_PERMUTE_INDICES,
        gemm1_scales_linear_fp4[i].view(torch.uint8),
        epilogue_tile_m,
        num_elts_per_sf=16,
    )
    gemm1_scales_fp4_shuffled_list.append(
        block_scale_interleave(
            gemm1_scales_linear_fp4[i]
            .view(torch.uint8)[
                permute_sf_indices.to(gemm1_scales_linear_fp4.device)
            ]
            .contiguous()
        )
    )

# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled_list)
gemm1_scales_fp4_shuffled = (
    torch.stack(gemm1_scales_fp4_shuffled_list)
    .view(torch.float8_e4m3fn)
    .reshape(num_experts, 2 * padded_intermediate_size, padded_hidden_size // 32)
)

gemm1_bias_shuffled = None
if gemm1_bias_padded is not None:
    gemm1_bias_shuffled = torch.stack(gemm1_bias_shuffled_list)
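Here is why the bias has to follow the rows. A row permutation moves each output element together with its weight row, so the bias entry for that output must move the same way. The toy check below uses plain lists and a made-up permutation (the real indices come from _maybe_get_cached_w3_w1_permute_indices). Permuting weights and bias together only permutes the output. Permuting only the weights silently pairs rows with the wrong bias, and nothing raises an error.

```python
def matvec_bias(weight, x, bias):
    """y[r] = dot(weight[r], x) + bias[r] for each output row r."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def permute(items, perm):
    """Gather items by index, like tensor[perm] in PyTorch."""
    return [items[p] for p in perm]


weight = [[1, 2], [3, 4], [5, 6], [7, 8]]
bias = [10, 20, 30, 40]
x = [1, 1]
perm = [2, 0, 3, 1]  # hypothetical permutation for illustration only

reference = matvec_bias(weight, x, bias)  # [13, 27, 41, 55]

# Weights and bias shuffled with the same indices: output is just reordered.
shuffled = matvec_bias(permute(weight, perm), x, permute(bias, perm))
assert shuffled == permute(reference, perm)  # [41, 13, 55, 27]

# Bias left unshuffled: every row picks up the wrong bias, with no error.
stale = matvec_bias(permute(weight, perm), x, bias)  # [21, 23, 45, 47]
assert stale != permute(reference, perm)
```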

Finally, we pass the *_shuffled tensors into the kernel. At this point, they have the expected physical layout in memory.

More Caveats

Beyond the above, there are additional pitfalls specific to the GPT OSS model weights we’re using. In particular, the official weights hosted on Hugging Face store the MoE up-projection weights and gate weights in a single matrix, and they are interleaved: counting from zero, odd columns hold up-projection weights and even columns hold gate weights. FlashInfer’s fused MoE kernel also expects a single matrix, but not an interleaved one: the first half of the columns holds up-projection weights and the second half holds gate weights. Therefore, we also need to deinterleave the weights before passing them into the kernel. The model definition and the weight preparation code are available in Pie’s source code.
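The deinterleaving step is a strided slice. The sketch below uses plain lists to show the reordering for one row of the fused matrix. It assumes, as described above, that even positions hold gate weights and odd positions hold up-projection weights. For a PyTorch tensor interleaved along its last dimension, the equivalent would be torch.cat([w[..., 1::2], w[..., 0::2]], dim=-1). If the checkpoint's fused gate/up bias is interleaved the same way, it needs the same reordering. The function name is hypothetical.

```python
def deinterleave_gate_up(row):
    """Convert [g0, u0, g1, u1, ...] into [u0, u1, ..., g0, g1, ...]."""
    gate = row[0::2]  # even positions hold gate weights
    up = row[1::2]    # odd positions hold up-projection weights
    return up + gate  # up-projection half first, then gate half


print(deinterleave_gate_up(["g0", "u0", "g1", "u1", "g2", "u2"]))
# ['u0', 'u1', 'u2', 'g0', 'g1', 'g2']
```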

Conclusion

The lack of documentation makes it difficult to get FlashInfer’s fused MoE kernel working correctly. However, the unit tests are a good source of truth. If you’re facing the same issue, I hope this post saves you time.

If I were designing the API, I’d use type information to automatically validate whether the weights are in the expected layout. An interpretable error message when they’re not would save developers a lot of time debugging silent numerical errors.
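Here is a minimal sketch of that idea in plain Python. Every name in it (Layout, TaggedWeights, LayoutError, fused_moe) is hypothetical and not part of FlashInfer. A layout tag travels with the tensor, so a mismatch raises a readable error instead of producing silently wrong numbers.

```python
from dataclasses import dataclass
from enum import Enum


class Layout(Enum):
    LINEAR = "linear"      # as loaded from the checkpoint
    SHUFFLED = "shuffled"  # permuted into the kernel's expected order


class LayoutError(ValueError):
    """Raised when a tensor is not in the layout the kernel expects."""


@dataclass(frozen=True)
class TaggedWeights:
    data: object  # the underlying tensor; untyped to keep the sketch framework-free
    layout: Layout


def check_layout(name: str, weights: TaggedWeights, expected: Layout) -> None:
    if weights.layout is not expected:
        raise LayoutError(
            f"{name} is in {weights.layout.value} layout, but the kernel expects "
            f"{expected.value} layout; shuffle it before calling the kernel"
        )


def fused_moe(gemm1_weights: TaggedWeights, gemm1_bias: TaggedWeights):
    """Hypothetical wrapper: validate layouts before dispatching to the real kernel."""
    check_layout("gemm1_weights", gemm1_weights, Layout.SHUFFLED)
    check_layout("gemm1_bias", gemm1_bias, Layout.SHUFFLED)
    # Placeholder for the actual kernel call; returns the raw tensors here.
    return gemm1_weights.data, gemm1_bias.data
```

Calling fused_moe with a bias tagged Layout.LINEAR raises a LayoutError that names gemm1_bias. That message would point directly at an unshuffled bias.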