This guide shows how to convert a BF16 GPT-OSS checkpoint (a directory of `.safetensors` shards) into an MXFP4-quantized checkpoint. The `to_mxfp4.py` script:
- Downcasts selected weights to MXFP4 using Triton kernels
- Keeps specified layers in full precision (router, norms, embeddings, etc.)
- Splits large weights into multiple `.safetensors` files
- Writes a `model.safetensors.index.json` and updates `config.json` with `quantization_config`
- Copies tokenizer files from the original checkpoint
This is a one-way conversion helper: it assumes your source checkpoint is already BF16 and on disk.
## Prerequisites
- Environment with at least 1 GPU (CUDA)
- Python 3.10+ (the script uses modern type-hint syntax such as `Path | str`)
- The following Python packages:
```text
# requirements.txt
triton==3.4.0
torch==2.7.1
safetensors==0.4.3
kernels==0.11.1
```
Install:
```bash
pip install -r requirements.txt
```
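Before running the converter, it is worth checking that PyTorch can see a CUDA device, since the script hard-codes `cuda:0`. A minimal sanity check, assuming `torch` from the requirements above is installed:

```python
# Verify that a CUDA GPU is visible before running to_mxfp4.py (it uses "cuda:0").
import torch

assert torch.cuda.is_available(), "to_mxfp4.py requires at least one CUDA GPU"
print(torch.cuda.get_device_name(0))  # name of the GPU that will be used
```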
## Usage
Assume you have the original BF16 checkpoint and have picked an output directory:
- `INPUT_CKPT_PATH=/path/to/gpt-oss-20b-BF16/` (directory containing the `.safetensors` shards, `config.json`, and tokenizer files)
- `QUANT_CKPT_PATH=/path/to/gpt-oss-20b-BF16-mxfp4/`
```bash
python3 to_mxfp4.py \
  -i "$INPUT_CKPT_PATH" \
  -o "$QUANT_CKPT_PATH"
```
`$QUANT_CKPT_PATH` will contain:
- MXFP4-quantized `.safetensors` shards
- `model.safetensors.index.json`
- `config.json` with `quantization_config`
- Tokenizer files copied from the original checkpoint

You can then load the converted model directly from `$QUANT_CKPT_PATH`.
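As a rough illustration, loading the result should work like any other sharded safetensors checkpoint. A minimal sketch using `transformers` (assuming a recent release with MXFP4 / gpt-oss support; the path below stands in for your `$QUANT_CKPT_PATH`):

```python
# Loading sketch (not part of to_mxfp4.py); assumes transformers with MXFP4 support.
from transformers import AutoModelForCausalLM, AutoTokenizer

QUANT_CKPT_PATH = "/path/to/gpt-oss-20b-BF16-mxfp4/"  # your $QUANT_CKPT_PATH

tokenizer = AutoTokenizer.from_pretrained(QUANT_CKPT_PATH)
model = AutoModelForCausalLM.from_pretrained(QUANT_CKPT_PATH, device_map="auto")
print(model.config.quantization_config)  # should report quant_method "mxfp4"
```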
## Script: to_mxfp4.py
```python
import argparse
import json
import re
from pathlib import Path

import torch
from kernels import get_kernel
from safetensors import safe_open
from safetensors.torch import save_file

# Device used for downcasting; single-GPU setup assumed
DEVICE = "cuda:0"

# Max size for a single .safetensors shard (in bytes)
ST_FILE_SIZE_LIMIT = 4.8 * 1024**3  # ~4.8 GB

# Files copied verbatim from the original checkpoint
COMMON_FILES = (
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
)

# Quantization config injected into config.json of the converted checkpoint.
# These modules are kept in higher precision (not quantized).
QUANT_CONFIG = {
    "modules_to_not_convert": [
        "model.layers.*.self_attn",
        "model.layers.*.mlp.router",
        "model.layers.*.input_layernorm",
        "model.layers.*.post_attention_layernorm",
        "model.layers.*.mlp.experts.down_proj_bias",
        "model.layers.*.mlp.experts.gate_up_proj_bias",
        "model.embed_tokens",
        "model.norm.weight",
        "lm_head",
    ],
    "quant_method": "mxfp4",
}


def _compile_patterns(patterns: list[str]) -> list[re.Pattern]:
    """Compile glob-style module patterns (with *) into regexes."""
    compiled = []
    for p in patterns:
        esc = re.escape(p).replace(r"\*", ".*")
        regex = r"^" + esc + r"(?:\..*)?$"
        compiled.append(re.compile(regex))
    return compiled


EXCLUDE_PATTERNS = _compile_patterns(QUANT_CONFIG["modules_to_not_convert"])


def _needs_quantization(key: str) -> bool:
    """Return True if a weight name should be quantized to MXFP4."""
    return not any(p.match(key) for p in EXCLUDE_PATTERNS)
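
# For example (hypothetical tensor names, shown for illustration only):
#   _needs_quantization("model.layers.0.mlp.experts.gate_up_proj")  -> True  (quantized to MXFP4)
#   _needs_quantization("model.layers.0.mlp.router.weight")         -> False (kept in BF16)
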
def _get_safetensors_out_name(idx: int) -> str:
    """Generate a deterministic safetensors filename for shard index."""
    return f"model-{str(idx).zfill(5)}.safetensors"
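
# Background note (general MXFP4 / OCP Microscaling behavior, not specific to this
# script): each quantized value is stored as a 4-bit float (E2M1), and each block of
# 32 consecutive values along the quantization axis shares a single 8-bit
# power-of-two scale. `downcast_to_mxfp` therefore produces two tensors per weight:
# a uint8 "blocks" tensor holding the packed 4-bit values and a per-block "scales" tensor.
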
def convert_to_mxfp4(orig_ckpt_path: Path | str, converted_ckpt_path: Path | str) -> None:
    """
    Convert a BF16 checkpoint to MXFP4-quantized format.

    - Reads *.safetensors weights from `orig_ckpt_path`
    - Quantizes selected tensors to MXFP4
    - Writes sharded safetensors files + index + updated config
    """
    orig_ckpt_path = Path(orig_ckpt_path)
    converted_ckpt_path = Path(converted_ckpt_path)

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Load MXFP kernels via kernels library (triton-based)
    triton_kernels = get_kernel("kernels-community/triton_kernels")
    mxfp_kernels = triton_kernels.numerics_details.mxfp

    def _downcast_to_mxfp4(name: str, tensor: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Downcast a BF16 tensor to MXFP4 blocks + scales.

        The conversion operates along the last dimension (`axis=-1`).
        """
        # Align layout for kernel: transpose last two dims to match expected shape
        tensor = tensor.swapaxes(-2, -1).contiguous()
        out_blocks, out_scales = mxfp_kernels.downcast_to_mxfp(
            tensor.to(DEVICE),
            out_quant_type=torch.uint8,
            axis=-1,
        )
        # Reshape blocks to match scales layout (implementation-specific detail)
        out_blocks = out_blocks.reshape(
            *out_blocks.shape[:-1],
            out_scales.shape[-1],
            -1,
        ).contiguous()
        return {
            f"{name}_blocks": out_blocks.to("cpu"),
            f"{name}_scales": out_scales.to("cpu").contiguous(),
        }

    converted_ckpt_path.mkdir(exist_ok=True, parents=True)

    metadata: dict[str, int | float] = {
        "total_parameters": 0,
        "total_size": 0,  # in bytes
    }
    weight_map: dict[str, str] = {}

    st_idx = 0
    cur_st_data_size = 0
    cur_st_data: dict[str, torch.Tensor] = {}

    # Iterate over all safetensors files in the original checkpoint
    for safetensors_path in orig_ckpt_path.glob("*.safetensors"):
        with safe_open(safetensors_path, framework="torch") as f:
            for tensor_name in f.keys():
                t = f.get_tensor(tensor_name)

                # Track original parameter stats
                metadata["total_parameters"] += t.numel()
                metadata["total_size"] += t.element_size() * t.numel()

                # Decide whether to quantize or keep as-is
                if _needs_quantization(tensor_name):
                    print(f"[MXFP4] Downcasting tensor: {tensor_name}")
                    params_upd = _downcast_to_mxfp4(tensor_name, t)
                else:
                    params_upd = {tensor_name: t}

                # Approximate size in bytes for the updated params
                params_upd_size_bytes = sum(
                    v.element_size() * v.numel() for v in params_upd.values()
                )
                # If the new tensors still fit under the shard size limit, add them
                # to the current shard; otherwise flush the shard and start a new one
                if cur_st_data_size + params_upd_size_bytes < ST_FILE_SIZE_LIMIT:
                    cur_st_data_size += params_upd_size_bytes
                    cur_st_data.update(params_upd)
                else:
                    fn = converted_ckpt_path / _get_safetensors_out_name(st_idx)
                    save_file(cur_st_data, filename=fn)
                    print(
                        f"Saved {fn} with size {cur_st_data_size / (1024**3):.2f} GB"
                    )
                    # Start a new shard
                    st_idx += 1
                    cur_st_data = params_upd
                    cur_st_data_size = params_upd_size_bytes

                # Record weight -> shard mapping
                for k in params_upd:
                    weight_map[k] = _get_safetensors_out_name(st_idx)

    # Flush last shard
    if len(cur_st_data) > 0:
        fn = converted_ckpt_path / _get_safetensors_out_name(st_idx)
        save_file(cur_st_data, filename=fn)
        print(f"Saved {fn} with size {cur_st_data_size / (1024**3):.2f} GB")

    print("Weights saved to MXFP4 quantized safetensors format.")

    # Write safetensors index
    st_index = {
        "metadata": metadata,
        "weight_map": weight_map,
    }
    with open(converted_ckpt_path / "model.safetensors.index.json", "w") as f:
        json.dump(st_index, f, indent=2, sort_keys=True)
    print("model.safetensors.index.json saved.")

    # Load original config and add quantization_config
    with open(orig_ckpt_path / "config.json") as f:
        config = json.load(f)
    config["quantization_config"] = QUANT_CONFIG
    with open(converted_ckpt_path / "config.json", "w") as f:
        json.dump(config, f, indent=2, sort_keys=True)
    print("config.json with quantization_config saved.")

    # Copy tokenizer and related files
    for common_file in COMMON_FILES:
        src = orig_ckpt_path / common_file
        if src.exists():
            with open(src, "rb") as fin, open(
                converted_ckpt_path / common_file, "wb"
            ) as fout:
                fout.write(fin.read())
        else:
            print(f"Warning: {common_file} not found in original checkpoint.")
    print(f"Common files {COMMON_FILES} copied (where present).")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a BF16 GPT-OSS checkpoint to an MXFP4-quantized checkpoint."
    )
    parser.add_argument(
        "-i",
        "--orig_ckpt_path",
        type=str,
        required=True,
        help="Path to the original checkpoint directory containing *.safetensors and config.json",
    )
    parser.add_argument(
        "-o",
        "--converted_ckpt_path",
        type=str,
        required=True,
        help="Path to the output directory where the MXFP4 checkpoint will be written",
    )
    args = parser.parse_args()

    convert_to_mxfp4(
        orig_ckpt_path=args.orig_ckpt_path,
        converted_ckpt_path=args.converted_ckpt_path,
    )
```
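After a run, a quick sanity check of the output directory can confirm that the index and config were written as expected. An illustrative sketch (shard files are named `model-00000.safetensors`, `model-00001.safetensors`, ... by `_get_safetensors_out_name`; the path stands in for your `$QUANT_CKPT_PATH`):

```python
# Sanity-check the converted checkpoint (not part of to_mxfp4.py).
import json
from pathlib import Path

ckpt = Path("/path/to/gpt-oss-20b-BF16-mxfp4/")  # your $QUANT_CKPT_PATH

index = json.loads((ckpt / "model.safetensors.index.json").read_text())
shards = sorted(set(index["weight_map"].values()))
print(f"{len(index['weight_map'])} tensors mapped across {len(shards)} shard(s)")

config = json.loads((ckpt / "config.json").read_text())
print(config["quantization_config"]["quant_method"])  # -> "mxfp4"
```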