Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/modelinfo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.exit()


def _positive_int(value: str) -> int:
    """Convert a command-line string into an integer that is at least 1.

    Intended for use as an argparse ``type=`` callable (it backs the
    ``--batch-size`` option).

    Args:
        value: Raw argument text as supplied on the command line.

    Returns:
        The parsed integer, guaranteed to be >= 1.

    Raises:
        argparse.ArgumentTypeError: If *value* is not an integer literal or
            is smaller than 1.
    """
    try:
        ivalue = int(value)
    except ValueError:
        # A bare ValueError makes argparse print the generic
        # "invalid _positive_int value: ..." text, which exposes this private
        # helper's name to the user; raise a readable message instead.
        raise argparse.ArgumentTypeError(
            f"batch size must be an integer, got {value!r}"
        ) from None
    if ivalue < 1:
        raise argparse.ArgumentTypeError("batch size must be at least 1")
    return ivalue


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="modelinfo",
Expand All @@ -52,6 +59,12 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
default=None,
help="Context length for dynamic KV cache footprint calculation.",
)
parser.add_argument(
"--batch-size",
type=_positive_int,
default=1,
help="Batch size for dynamic KV cache footprint calculation.",
)
parser.add_argument(
"--max-vram",
type=float,
Expand Down Expand Up @@ -106,7 +119,8 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
def analyze_model(
file_path: str,
context_override: int | None,
gpu_count: int = 1,
gpu_count: int = 1,
batch_size: int = 1,
fetch_tensors: bool = False,
topology: str = "pcie4",
strategy: str = "tp",
Expand Down Expand Up @@ -164,6 +178,7 @@ def analyze_model(
footprint = calculate_footprint(
tensors,
context_length=context_length,
batch_size=batch_size,
config=config,
gpu_count=gpu_count,
topology=topology,
Expand Down Expand Up @@ -222,7 +237,8 @@ def main(argv: Sequence[str] | None = None) -> int:
info = analyze_model(
model_path,
args.context,
gpu_count,
gpu_count=gpu_count,
batch_size=args.batch_size,
fetch_tensors=args.tensors,
topology=args.topology,
strategy=args.strategy,
Expand All @@ -240,7 +256,8 @@ def main(argv: Sequence[str] | None = None) -> int:
info = analyze_model(
file_path,
args.context,
gpu_count,
gpu_count=gpu_count,
batch_size=args.batch_size,
fetch_tensors=args.tensors,
topology=args.topology,
strategy=args.strategy,
Expand Down
67 changes: 67 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import modelinfo.cli as cli
from modelinfo import __version__
from modelinfo.cli import parse_args

Expand All @@ -10,3 +11,69 @@ def test_version_flag_prints_installed_version(capsys):

assert exc_info.value.code == 0
assert f"modelinfo {__version__}" in capsys.readouterr().out


def test_batch_size_flag_defaults_to_one():
    """Omitting ``--batch-size`` leaves the parsed batch size at 1."""
    namespace = parse_args(["model.gguf"])
    assert namespace.batch_size == 1


def test_batch_size_flag_accepts_integer():
    """An explicit ``--batch-size 4`` is parsed into the integer 4."""
    argv = ["--batch-size", "4", "model.gguf"]
    namespace = parse_args(argv)
    assert namespace.batch_size == 4


def test_batch_size_flag_rejects_zero():
    """A batch size of zero makes the parser exit with status 2."""
    argv = ["--batch-size", "0", "model.gguf"]

    with pytest.raises(SystemExit) as raised:
        parse_args(argv)

    assert raised.value.code == 2


def test_batch_size_flag_rejects_negative():
    """A negative batch size makes the parser exit with status 2."""
    argv = ["--batch-size", "-1", "model.gguf"]

    with pytest.raises(SystemExit) as raised:
        parse_args(argv)

    assert raised.value.code == 2


def test_analyze_model_passes_batch_size_to_footprint(monkeypatch, tmp_path):
    """``analyze_model`` hands its ``batch_size`` on to ``calculate_footprint``.

    The header parser, footprint calculator and architecture lookup on the
    ``cli`` module are replaced with stubs; the footprint stub records the
    keyword arguments it receives so they can be checked afterwards.
    """
    gguf_file = tmp_path / "model.gguf"
    gguf_file.write_bytes(b"mock")
    expected_path = str(gguf_file)
    seen = {}

    def stub_parse_gguf_header(file_path):
        # The path must reach the parser unchanged.
        assert file_path == expected_path
        return {
            "model.layers.0.self_attn.k_proj.weight": {"shape": [1, 1], "dtype": "F16"}
        }

    def stub_calculate_footprint(tensors, *, context_length, batch_size, **kwargs):
        seen.update(batch_size=batch_size, context_length=context_length)
        # Echo batch_size into the result so it is visible in the returned info.
        return {
            "total_params": 1,
            "base_memory_bytes": 2.0,
            "kv_cache_bytes": float(batch_size),
            "overhead_bytes": 0.0,
            "total_memory_bytes": 2.0 + batch_size,
            "num_layers": 1,
            "kv_dim": 1,
            "primary_dtype": "F16",
            "kv_is_estimate": False,
            "penalty_percentage": 0.0,
            "vllm_metrics": {},
        }

    def stub_identify_architecture_name(tensors, num_layers, config):
        return "Mock"

    replacements = {
        "parse_gguf_header": stub_parse_gguf_header,
        "calculate_footprint": stub_calculate_footprint,
        "identify_architecture_name": stub_identify_architecture_name,
    }
    for attribute, stub in replacements.items():
        monkeypatch.setattr(cli, attribute, stub)

    result = cli.analyze_model(expected_path, context_override=128, batch_size=4)

    assert seen == {"batch_size": 4, "context_length": 128}
    assert result["footprint"]["kv_cache_bytes"] == 4.0
Loading