# Converting a BF16 GPT-OSS checkpoint to MXFP4
The script below converts a BF16 GPT-OSS checkpoint (Hugging Face layout with .safetensors) into an MXFP4-quantized checkpoint.
  • Downcasts selected weights to MXFP4 using Triton kernels
  • Keeps specified layers in full precision (router, norms, embeddings, etc.)
  • Shards the converted weights across multiple .safetensors files of roughly 4.8 GB each (individual tensors are never split)
  • Writes a model.safetensors.index.json and updates config.json with quantization_config
  • Copies tokenizer files from the original checkpoint
This is a one-way conversion helper: it assumes your source checkpoint is already BF16 and on disk.

Prerequisites

  • At least one CUDA GPU (the downcast runs on `cuda:0`)
  • Python 3.10+ (required: the script uses `X | Y` union syntax in function annotations)
  • The following Python packages:

```text
# requirements.txt
triton==3.4.0
torch==2.7.1
safetensors==0.4.3
kernels==0.11.1
```


Install:

```bash
pip install -r requirements.txt
```

Usage

Assume you have the original BF16 checkpoint:
  • INPUT_CKPT_PATH=/path/to/gpt-oss-20b-BF16/ (directory containing .safetensors, config.json, tokenizer files)
Choose an output directory for the quantized checkpoint:
  • QUANT_CKPT_PATH=/path/to/gpt-oss-20b-BF16-mxfp4/
Run:

```bash
python3 to_mxfp4.py \
  -i "$INPUT_CKPT_PATH" \
  -o "$QUANT_CKPT_PATH"
```
After this finishes, $QUANT_CKPT_PATH will contain:
  • MXFP4-quantized .safetensors shards
  • model.safetensors.index.json
  • config.json with quantization_config
  • Tokenizer files copied from the original checkpoint
You can then point your runtime / serving stack at QUANT_CKPT_PATH.

Script: to_mxfp4.py

import argparse
import json
import re
import shutil
from pathlib import Path

import torch
from kernels import get_kernel
from safetensors import safe_open
from safetensors.torch import save_file

# All MXFP4 downcasting happens on this single GPU.
DEVICE = "cuda:0"

# Cap on the tensor payload, in bytes, accumulated into one output shard
# before it is written out: 4.8 GiB (1 << 30 bytes per GiB).
ST_FILE_SIZE_LIMIT = 4.8 * (1 << 30)

# Auxiliary files copied unchanged from the source checkpoint directory.
COMMON_FILES = ("tokenizer_config.json", "tokenizer.json", "special_tokens_map.json")

# Stored as "quantization_config" in the converted checkpoint's config.json.
# Entries of "modules_to_not_convert" are glob patterns ("*" is a wildcard)
# naming modules whose tensors keep their original precision instead of
# being downcast to MXFP4.
QUANT_CONFIG = dict(
    modules_to_not_convert=[
        # Per-decoder-layer modules (any layer index)
        "model.layers.*.self_attn",
        "model.layers.*.mlp.router",
        "model.layers.*.input_layernorm",
        "model.layers.*.post_attention_layernorm",
        "model.layers.*.mlp.experts.down_proj_bias",
        "model.layers.*.mlp.experts.gate_up_proj_bias",
        # Model-level modules
        "model.embed_tokens",
        "model.norm.weight",
        "lm_head",
    ],
    quant_method="mxfp4",
)


def _compile_patterns(patterns: list[str]) -> list[re.Pattern]:
    """Compile glob-style module patterns (with *) into regexes."""
    compiled = []
    for p in patterns:
        esc = re.escape(p).replace(r"\*", ".*")
        regex = r"^" + esc + r"(?:\..*)?$"
        compiled.append(re.compile(regex))
    return compiled


# Compiled once at import time; tensor names matching any of these regexes
# are kept in their original precision.
EXCLUDE_PATTERNS = _compile_patterns(
    patterns=QUANT_CONFIG["modules_to_not_convert"],
)


def _needs_quantization(key: str) -> bool:
    """
    Decide whether the tensor named ``key`` should be downcast to MXFP4.

    Returns False as soon as ``key`` matches one of ``EXCLUDE_PATTERNS``
    (i.e. it belongs to a module listed in ``modules_to_not_convert``),
    and True otherwise.
    """
    for pattern in EXCLUDE_PATTERNS:
        if pattern.match(key):
            return False
    return True


def _get_safetensors_out_name(idx: int) -> str:
    """Generate a deterministic safetensors filename for shard index."""
    return f"model-{str(idx).zfill(5)}.safetensors"


def convert_to_mxfp4(orig_ckpt_path: Path | str, converted_ckpt_path: Path | str) -> None:
    """
    Convert a BF16 checkpoint to MXFP4-quantized format.

    - Reads every ``*.safetensors`` file in ``orig_ckpt_path`` (in sorted order)
    - Downcasts tensors selected by ``_needs_quantization`` to MXFP4, storing
      each as ``<name>_blocks`` + ``<name>_scales``; other tensors are kept as-is
    - Writes size-capped safetensors shards, ``model.safetensors.index.json``
      and a ``config.json`` extended with ``quantization_config``
    - Copies the files listed in ``COMMON_FILES`` when they exist

    Args:
        orig_ckpt_path: Directory holding the source ``*.safetensors`` files
            and ``config.json``.
        converted_ckpt_path: Output directory (created if missing). Must be a
            different directory from ``orig_ckpt_path``.

    Raises:
        ValueError: If both paths resolve to the same directory.
        FileNotFoundError: If the source directory has no ``*.safetensors``
            files or no ``config.json``.
    """
    orig_ckpt_path = Path(orig_ckpt_path)
    converted_ckpt_path = Path(converted_ckpt_path)

    # Converting in place would overwrite the source config.json, and the
    # auxiliary-file copy would target its own source files.
    if converted_ckpt_path.resolve() == orig_ckpt_path.resolve():
        raise ValueError(
            f"Output directory must differ from the input directory: {orig_ckpt_path}"
        )

    # Path.glob yields entries in filesystem order; sorting makes the shard
    # layout (and therefore the index) reproducible across runs.
    src_files = sorted(orig_ckpt_path.glob("*.safetensors"))
    if not src_files:
        raise FileNotFoundError(f"No *.safetensors files found in {orig_ckpt_path}")
    # Check up front so a missing config does not surface only after the
    # whole (long) conversion has already run.
    orig_config_path = orig_ckpt_path / "config.json"
    if not orig_config_path.is_file():
        raise FileNotFoundError(f"config.json not found in {orig_ckpt_path}")

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Load MXFP kernels via kernels library (triton-based)
    triton_kernels = get_kernel("kernels-community/triton_kernels")
    mxfp_kernels = triton_kernels.numerics_details.mxfp

    def _downcast_to_mxfp4(name: str, tensor: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Downcast a tensor (presumably BF16) to MXFP4 blocks + scales.

        The last two dims are swapped first, so quantization along the
        kernel's ``axis=-1`` runs over the tensor's original second-to-last
        dim. Requires rank >= 2 (``swapaxes(-2, -1)``). Results are
        returned on the CPU under ``<name>_blocks`` and ``<name>_scales``.
        """
        # Align layout for kernel: transpose last two dims to match expected shape
        tensor = tensor.swapaxes(-2, -1).contiguous()

        out_blocks, out_scales = mxfp_kernels.downcast_to_mxfp(
            tensor.to(DEVICE),
            out_quant_type=torch.uint8,
            axis=-1,
        )

        # Split the packed last dim so it has one group per scale entry
        # (out_scales.shape[-1] groups; group width inferred via -1).
        out_blocks = out_blocks.reshape(
            *out_blocks.shape[:-1],
            out_scales.shape[-1],
            -1,
        ).contiguous()

        return {
            f"{name}_blocks": out_blocks.to("cpu"),
            f"{name}_scales": out_scales.to("cpu").contiguous(),
        }

    def _save_shard(idx: int, tensors: dict[str, torch.Tensor], size_bytes: int) -> None:
        """Write one output shard and report its tensor payload size."""
        fn = converted_ckpt_path / _get_safetensors_out_name(idx)
        # {"format": "pt"} is the metadata Hugging Face's save_pretrained writes.
        save_file(tensors, filename=fn, metadata={"format": "pt"})
        print(f"Saved {fn} with size {size_bytes / (1024**3):.2f} GB")

    converted_ckpt_path.mkdir(exist_ok=True, parents=True)

    metadata: dict[str, int | float] = {
        "total_parameters": 0,  # element count of the *source* tensors
        "total_size": 0,  # bytes of tensor data actually written to the shards
    }
    weight_map: dict[str, str] = {}

    st_idx = 0
    cur_st_data_size = 0
    cur_st_data: dict[str, torch.Tensor] = {}

    for safetensors_path in src_files:
        with safe_open(safetensors_path, framework="torch") as f:
            for tensor_name in f.keys():
                t = f.get_tensor(tensor_name)
                metadata["total_parameters"] += t.numel()

                # Decide whether to quantize or keep as-is
                if _needs_quantization(tensor_name):
                    print(f"[MXFP4] Downcasting tensor: {tensor_name}")
                    params_upd = _downcast_to_mxfp4(tensor_name, t)
                else:
                    params_upd = {tensor_name: t}

                # Payload in bytes of the tensors about to be stored; the index
                # must describe the converted checkpoint, not the BF16 source.
                params_upd_size_bytes = sum(
                    v.element_size() * v.numel() for v in params_upd.values()
                )
                metadata["total_size"] += params_upd_size_bytes

                # Flush the current shard once these tensors would push it to
                # the limit. An empty shard is never flushed: an oversized tensor
                # group simply gets a shard of its own instead of leaving a
                # tensor-less file behind.
                if (
                    cur_st_data
                    and cur_st_data_size + params_upd_size_bytes >= ST_FILE_SIZE_LIMIT
                ):
                    _save_shard(st_idx, cur_st_data, cur_st_data_size)
                    st_idx += 1
                    cur_st_data = {}
                    cur_st_data_size = 0

                cur_st_data.update(params_upd)
                cur_st_data_size += params_upd_size_bytes

                # Record weight -> shard mapping
                shard_name = _get_safetensors_out_name(st_idx)
                for k in params_upd:
                    weight_map[k] = shard_name

    # Flush last shard
    if cur_st_data:
        _save_shard(st_idx, cur_st_data, cur_st_data_size)

    print("Weights saved to MXFP4 quantized safetensors format.")

    # Write safetensors index
    st_index = {
        "metadata": metadata,
        "weight_map": weight_map,
    }
    with open(
        converted_ckpt_path / "model.safetensors.index.json", "w", encoding="utf-8"
    ) as f:
        json.dump(st_index, f, indent=2, sort_keys=True)
    print("model.safetensors.index.json saved.")

    # Load original config and add quantization_config
    with open(orig_config_path, encoding="utf-8") as f:
        config = json.load(f)
    config["quantization_config"] = QUANT_CONFIG

    with open(converted_ckpt_path / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, sort_keys=True)
    print("config.json with quantization_config saved.")

    # Copy tokenizer and related files (best effort: missing ones only warn)
    for common_file in COMMON_FILES:
        src = orig_ckpt_path / common_file
        if src.is_file():
            shutil.copyfile(src, converted_ckpt_path / common_file)
        else:
            print(f"Warning: {common_file} not found in original checkpoint.")

    print(f"Common files {COMMON_FILES} copied (where present).")


if __name__ == "__main__":
    # Command-line entry point: both checkpoint directories are mandatory.
    cli = argparse.ArgumentParser(
        description="Convert a BF16 GPT-OSS checkpoint to an MXFP4-quantized checkpoint."
    )
    cli_options = (
        (
            "-i",
            "--orig_ckpt_path",
            "Path to the original checkpoint directory containing *.safetensors and config.json",
        ),
        (
            "-o",
            "--converted_ckpt_path",
            "Path to the output directory where the MXFP4 checkpoint will be written",
        ),
    )
    for short_flag, long_flag, help_text in cli_options:
        cli.add_argument(short_flag, long_flag, type=str, required=True, help=help_text)

    cli_args = cli.parse_args()
    convert_to_mxfp4(cli_args.orig_ckpt_path, cli_args.converted_ckpt_path)