def _positive_int(value: str) -> int:
    """Convert a ``--batch-size`` command-line string to an integer >= 1.

    Intended for use as an argparse ``type=`` callable.

    Args:
        value: Raw argument text taken from the command line.

    Returns:
        The parsed integer, guaranteed to be at least 1.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a valid integer
            literal or if the parsed integer is smaller than 1.
    """
    try:
        ivalue = int(value)
    except ValueError:
        # A bare ValueError makes argparse fall back to a generic message
        # built from this helper's name ("invalid _positive_int value");
        # ArgumentTypeError lets us show a readable message instead.
        raise argparse.ArgumentTypeError(
            f"batch size must be an integer, got {value!r}"
        ) from None
    if ivalue < 1:
        raise argparse.ArgumentTypeError("batch size must be at least 1")
    return ivalue
def test_batch_size_flag_defaults_to_one():
    """Omitting ``--batch-size`` leaves the parsed value at 1."""
    args = parse_args(["model.gguf"])

    assert args.batch_size == 1


def test_batch_size_flag_accepts_integer():
    """A positive integer given to ``--batch-size`` is parsed as an ``int``."""
    args = parse_args(["--batch-size", "4", "model.gguf"])

    assert args.batch_size == 4


def test_batch_size_flag_rejects_zero(capsys):
    """Zero is rejected by the batch-size validator with its own message."""
    with pytest.raises(SystemExit) as exc_info:
        parse_args(["--batch-size", "0", "model.gguf"])

    assert exc_info.value.code == 2
    # argparse exits with status 2 for every usage error, so also pin the
    # message to prove the rejection came from the batch-size check.
    assert "batch size must be at least 1" in capsys.readouterr().err


def test_batch_size_flag_rejects_negative(capsys):
    """A negative value is rejected by the batch-size validator."""
    with pytest.raises(SystemExit) as exc_info:
        parse_args(["--batch-size", "-1", "model.gguf"])

    assert exc_info.value.code == 2
    # Without this check the test would also pass if "-1" were rejected for
    # an unrelated reason (e.g. being treated as an unknown option).
    assert "batch size must be at least 1" in capsys.readouterr().err


def test_analyze_model_passes_batch_size_to_footprint(monkeypatch, tmp_path):
    """``analyze_model`` forwards ``batch_size`` to ``calculate_footprint``."""
    model_path = tmp_path / "model.gguf"
    model_path.write_bytes(b"mock")
    captured = {}

    def fake_parse_gguf_header(file_path):
        assert file_path == str(model_path)
        return {
            "model.layers.0.self_attn.k_proj.weight": {"shape": [1, 1], "dtype": "F16"}
        }

    def fake_calculate_footprint(tensors, *, context_length, batch_size, **kwargs):
        # Record what analyze_model passed so the assertions below can check it.
        captured["batch_size"] = batch_size
        captured["context_length"] = context_length
        return {
            "total_params": 1,
            "base_memory_bytes": 2.0,
            "kv_cache_bytes": float(batch_size),
            "overhead_bytes": 0.0,
            "total_memory_bytes": 2.0 + batch_size,
            "num_layers": 1,
            "kv_dim": 1,
            "primary_dtype": "F16",
            "kv_is_estimate": False,
            "penalty_percentage": 0.0,
            "vllm_metrics": {},
        }

    monkeypatch.setattr(cli, "parse_gguf_header", fake_parse_gguf_header)
    monkeypatch.setattr(cli, "calculate_footprint", fake_calculate_footprint)
    monkeypatch.setattr(
        cli, "identify_architecture_name", lambda tensors, num_layers, config: "Mock"
    )

    info = cli.analyze_model(str(model_path), context_override=128, batch_size=4)

    assert captured == {"batch_size": 4, "context_length": 128}
    assert info["footprint"]["kv_cache_bytes"] == 4.0