Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 84 additions & 14 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,9 @@ def _validate_loader_tasks_for_model(
help="Maximum autoconf re-optimization rounds (default: 3). --no-analyze sets this to 0.",
)
@cli_utils.allow_unsupported_nodes_option()
@cli_utils.precision_option(
optional_message="When fp16, applies FP16 conversion during optimization."
)
@cli_utils.trust_remote_code_option(
optional_message="Trust remote code for custom model architectures (e.g., Mu2)."
)
Expand All @@ -514,6 +517,7 @@ def build(
analyze: bool,
max_optim_iterations: int | None,
allow_unsupported_nodes: bool,
precision: str | None,
trust_remote_code: bool,
verbose: int,
quiet: bool,
Expand Down Expand Up @@ -674,6 +678,8 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
# on the key being present, matching the module-mode path which passes
# allow_unsupported_nodes explicitly regardless of its value.
extra_kwargs["allow_unsupported_nodes"] = allow_unsupported_nodes
if precision == "fp16":
extra_kwargs["precision"] = "fp16"

if isinstance(config_or_configs, list):
# ---- MODULE MODE: array config, one build per submodule ----
Expand Down Expand Up @@ -1119,6 +1125,45 @@ def _on_reoptimize(autoconf_dict: dict) -> None:
return current_path, opt_elapsed


def _run_fp16_stage(
    *,
    model_path: Path,
    stage_timings: list[tuple[str, float | None]],
) -> Path:
    """Convert the ONNX file at ``model_path`` to FP16, rewriting it in place.

    The model is loaded, passed through ``convert_to_fp16`` with
    ``keep_io_types=True``, and saved back to the same path. Progress and the
    resulting artifact are reported through a ``StageLive("fp16", ...)`` display.

    Args:
        model_path: Location of the ONNX model; the file is overwritten with
            the converted model.
        stage_timings: Accumulator that receives one ``("FP16", seconds)`` entry.

    Returns:
        ``model_path`` itself (the file it points to now holds the FP16 model).
    """
    from ..onnx import load_onnx, save_onnx
    from ..optim.fp16 import convert_to_fp16
    from ..utils.console import StageLive

    with StageLive("fp16", console) as stage:
        stage.set_status("Converting to FP16...")
        started = time.monotonic()

        # Load -> convert -> write back over the same file; all three steps
        # are included in the measured duration.
        save_onnx(
            convert_to_fp16(load_onnx(model_path), keep_io_types=True),
            model_path,
        )

        duration = time.monotonic() - started
        stage.set_done(duration)
        stage.detail("[dim]I/O types preserved as FP32[/dim]")
        stage.artifact(str(model_path), _safe_size(model_path))
        stage.blank()

    stage_timings.append(("FP16", duration))
    return model_path


def _run_quantize_stage(
*,
config: WinMLBuildConfig,
Expand Down Expand Up @@ -1378,6 +1423,8 @@ def _name(base: str) -> str:

stage_timings.append(("Export", _export_elapsed))

_precision = extra_kwargs.pop("precision", None)

# ── Optimize stage ───────────────────────────────────────────
current_path, _ = _run_optimize_stage(
config=config,
Expand All @@ -1395,13 +1442,24 @@ def _name(base: str) -> str:
# Persist config after autoconf
config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ───────────────────────────────────────────
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
quantized_path=quantized_path,
stage_timings=stage_timings,
)
# ── FP16 conversion (when --precision fp16) ──────────────────
if _precision == "fp16":
current_path = _run_fp16_stage(
model_path=current_path,
stage_timings=stage_timings,
)

# ── Quantize stage (skipped when FP16 — incompatible) ────────
if _precision == "fp16" and config.quant is not None:
print_stage_skip(console, "quantize", "(incompatible with --precision fp16)")
stage_timings.append(("Quantize", None))
else:
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
quantized_path=quantized_path,
stage_timings=stage_timings,
)

# ── Compile stage ────────────────────────────────────────────
current_path = _run_compile_stage(
Expand Down Expand Up @@ -1437,6 +1495,7 @@ def _build_onnx_pipeline(

max_iters: int = extra_kwargs.pop("hack_max_optim_iterations", 3)
allow_unsupported_nodes: bool = extra_kwargs.pop("allow_unsupported_nodes", False)
_precision: str | None = extra_kwargs.pop("precision", None)

# ── Validate + setup ─────────────────────────────────────────
if not onnx_path.exists():
Expand Down Expand Up @@ -1490,13 +1549,24 @@ def _build_onnx_pipeline(

config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ───────────────────────────────────────────
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
quantized_path=quantized_path,
stage_timings=stage_timings,
)
# ── FP16 conversion (when --precision fp16) ──────────────────
if _precision == "fp16":
current_path = _run_fp16_stage(
model_path=current_path,
stage_timings=stage_timings,
)

# ── Quantize stage (skipped when FP16 — incompatible) ────────
if _precision == "fp16" and config.quant is not None:
print_stage_skip(console, "quantize", "(incompatible with --precision fp16)")
stage_timings.append(("Quantize", None))
else:
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
quantized_path=quantized_path,
stage_timings=stage_timings,
)

# ── Compile stage ────────────────────────────────────────────
current_path = _run_compile_stage(
Expand Down
13 changes: 13 additions & 0 deletions src/winml/modelkit/commands/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _delete_onnx_with_external_data(onnx_path: Path) -> None:
help='JSON with shape overrides (e.g., {"sequence_length": 2048, "height": 640}).',
)
@cli_utils.build_config_option()
@cli_utils.precision_option(optional_message="When fp16, applies FP16 conversion after export.")
@cli_utils.verbosity_options()
@click.pass_context
def export(
Expand All @@ -148,6 +149,7 @@ def export(
export_config: Path | None,
shape_config: Path | None,
config_file: Path | None,
precision: str | None,
) -> None:
r"""Export HuggingFace model to ONNX format with HTP.

Expand Down Expand Up @@ -420,6 +422,17 @@ def export(
)
logger.debug("Export stats: %s", export_stats)

# Post-export FP16 conversion when --precision fp16 is specified
if precision == "fp16":
console.print("[bold]Converting to FP16...[/bold]")
from ..onnx import load_onnx, save_onnx
from ..optim.fp16 import convert_to_fp16

fp16_model = load_onnx(output_path)
fp16_model = convert_to_fp16(fp16_model, keep_io_types=True)
save_onnx(fp16_model, output_path)
console.print("[dim]FP16 conversion applied (I/O kept as FP32)[/dim]")

# TODO: re-enable post-export optimization (shape inference, constant folding)
# Disabled: needs validation that optimize_onnx preserves HTP hierarchy tags.
# from ..optim.api import optimize_onnx
Expand Down
45 changes: 45 additions & 0 deletions src/winml/modelkit/commands/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,21 @@ def capability_options(func: F) -> F:
default=None,
help="Configuration file (YAML/JSON)",
)
@cli_utils.precision_option(optional_message="Applies FP16 conversion after graph optimization.")
@click.option(
"--fp16-keep-io-types/--no-fp16-keep-io-types",
"fp16_keep_io_types",
default=True,
show_default=True,
help="Keep model I/O as FP32 when --precision fp16 (insert Cast at boundary)",
)
@click.option(
"--fp16-op-block-list",
"fp16_op_block_list",
type=str,
default=None,
help="Comma-separated list of op types to keep in FP32 (e.g., LayerNorm,Softmax)",
)
@cli_utils.verbosity_options()
@capability_options
@click.pass_context # type: ignore[arg-type] # capability_options widens the signature; click stubs want positional-only ctx but we keep it keyword-callable for back-compat
Expand All @@ -190,6 +205,9 @@ def optimize(
model: Path | None,
output: Path | None,
config: Path | None,
precision: str | None,
fp16_keep_io_types: bool,
fp16_op_block_list: str | None,
verbose: int,
quiet: bool,
**kwargs: Any,
Expand Down Expand Up @@ -224,6 +242,17 @@ def optimize(
# Basic optimization with GELU fusion
winml optimize -m model.onnx -o model_opt.onnx --enable-gelu-fusion

# Convert model to FP16 (after graph optimization)
winml optimize -m model.onnx -o fp16.onnx --precision fp16

# FP16 without preserving I/O types
winml optimize -m model.onnx -o fp16.onnx --precision fp16 \
--no-fp16-keep-io-types

# FP16 with specific ops kept in FP32
winml optimize -m model.onnx -o fp16.onnx --precision fp16 \
--fp16-op-block-list LayerNorm,Softmax

# Use config file
winml optimize -m model.onnx -c config.toml
"""
Expand Down Expand Up @@ -406,6 +435,22 @@ def optimize(
optimizer = Optimizer()
optimized_model = optimizer.optimize(onnx_model, **optimizer_kwargs)

# Post-optimization FP16 conversion (command-layer, not a pipe)
if precision == "fp16":
from ..optim.fp16 import convert_to_fp16

console.print("[bold]Converting to FP16...[/bold]")
op_block = (
[s.strip() for s in fp16_op_block_list.split(",") if s.strip()]
if fp16_op_block_list
else None
)
optimized_model = convert_to_fp16(
optimized_model,
keep_io_types=fp16_keep_io_types,
op_block_list=op_block,
)

console.print("[bold]Saving optimized model...[/bold]")
save_onnx(optimized_model, output)

Expand Down
2 changes: 2 additions & 0 deletions src/winml/modelkit/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .api import optimize_onnx
from .config import WinMLOptimizationConfig
from .errors import ConfigurationError, ModelValidationError, OptimizationError
from .fp16 import convert_to_fp16
from .optimizer import Optimizer
from .registry import (
BoolCapability,
Expand All @@ -48,6 +49,7 @@
"Optimizer",
"WinMLOptimizationConfig",
"auto_enable_dependencies",
"convert_to_fp16",
"optimize_onnx",
"validate",
"validate_dependencies",
Expand Down
91 changes: 91 additions & 0 deletions src/winml/modelkit/optim/fp16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""FP16 conversion utility for ONNX models.

Provides a single entry point for FP32→FP16 model conversion, used by
all CLI commands (optimize, build, export) at the command layer.

This is NOT an optimizer pipe — FP16 is a precision transformation (like
quantization), not a graph optimization. In the build pipeline it runs after
optimization; the quantize stage that would follow is skipped when a
quantization config is present, because the two are incompatible.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING


if TYPE_CHECKING:
import onnx

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'onnx' is imported with both 'import' and 'import from'.
Module 'tests.unit.onnx' is imported with both 'import' and 'import from'.

logger = logging.getLogger(__name__)


def convert_to_fp16(
    model: onnx.ModelProto,
    *,
    keep_io_types: bool = True,
    op_block_list: list[str] | None = None,
) -> onnx.ModelProto:
    """Convert an ONNX model from FP32 to FP16 precision.

    Uses onnxruntime.transformers.float16.convert_float_to_float16 internally.
    No new dependencies — ORT is already a project dependency.

    Note: treat the input as consumed. ORT's converter modifies the proto it
    works on, and whether the returned object is the same instance as the
    input is an ORT implementation detail — always use the return value.

    Args:
        model: Input ONNX ModelProto (may be mutated by ORT).
        keep_io_types: If True, preserve FP32 model inputs/outputs by inserting
            Cast nodes at boundaries. Recommended for CPU-safe inference.
        op_block_list: ONNX op types to keep in FP32
            (e.g., ["LayerNormalization", "Softmax"]). When None, ORT uses its
            DEFAULT_OP_BLOCK_LIST which includes ops known to be numerically
            unsafe in FP16 (e.g., TopK, CumSum, etc.).

    Returns:
        The converted model, with its top-level graph topologically sorted.
        If every floating-point initializer is already FP16 the input model is
        returned untouched.
    """
    from onnx import TensorProto
    from onnxruntime.transformers.float16 import convert_float_to_float16
    from onnxruntime.transformers.onnx_model import OnnxModel

    # Skip if model is already FP16 (check floating-point initializer dtypes).
    # The set is built once rather than per initializer.
    float_types = frozenset(
        {
            TensorProto.FLOAT,
            TensorProto.DOUBLE,
            TensorProto.BFLOAT16,
            TensorProto.FLOAT16,
        }
    )
    float_inits = [t for t in model.graph.initializer if t.data_type in float_types]
    if float_inits and all(t.data_type == TensorProto.FLOAT16 for t in float_inits):
        logger.info("Model is already FP16 — skipping conversion.")
        return model

    original_nodes = len(model.graph.node)

    logger.info("Converting model to FP16...")
    if keep_io_types:
        logger.info("  Keeping I/O types as FP32")
    if op_block_list:
        logger.info("  Keeping ops in FP32: %s", op_block_list)

    converted = convert_float_to_float16(
        model,
        keep_io_types=keep_io_types,
        op_block_list=op_block_list,
    )

    # ORT's converter appends the Cast nodes it inserts to the end of the node
    # list, which can break topological ordering. That happens for the I/O
    # boundary casts (keep_io_types) and also for casts around ops kept in
    # FP32 by the block list — and a default block list applies even when
    # op_block_list is None. So re-sort unconditionally, using ORT's own
    # topological sort utility; sorting an already-ordered graph is harmless.
    OnnxModel.graph_topological_sort(converted.graph)

    converted_nodes = len(converted.graph.node)
    if converted_nodes != original_nodes:
        logger.info("FP16 conversion complete: %d -> %d nodes", original_nodes, converted_nodes)
    else:
        logger.info("FP16 conversion complete: %d nodes", converted_nodes)

    return converted
Loading
Loading