diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index cc0a5101..fec37767 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -51,36 +51,125 @@ clients; set `model_params.name` in the YAML to the same value. The runnable config is `examples/09_MultiTurn/kimi_agentic_benchmark.yaml`. -Key fields: - -- `type: online`: runs the benchmark through the online scheduler. -- `model_params.name`: model name sent in each OpenAI request. Keep it aligned - with the served model name. -- `model_params.temperature`, `top_p`, `max_new_tokens`: sampling settings sent - to the server. `max_new_tokens` is large because agent turns can be long. -- `model_params.chat_template_kwargs`: Kimi-specific template options for - reasoning preservation. -- First `datasets` entry `name`: label used in benchmark outputs. -- First `datasets` entry `type: performance`: multi-turn datasets are replayed as - performance datasets. -- First `datasets` entry `path`: JSONL dataset path to run. -- First `datasets` entry `multi_turn.turn_timeout_s`: per-turn deadline. A - timeout aborts the remaining turns in that conversation. -- First `datasets` entry `multi_turn.enable_salt`: appends a deterministic cache - salt to each conversation system prompt. -- First `datasets` entry `multi_turn.inject_tool_delay`: honors positive - `delay_seconds` values from client turns before issuing those turns. -- `settings.runtime.min_duration_ms`: minimum run duration. With no max duration - override, the run finishes when the dataset is exhausted. -- `settings.load_pattern.type: multi_turn`: enables conversation-aware issuing. -- `settings.load_pattern.target_concurrency`: maximum active conversations. - Each active conversation has at most one in-flight request. -- `settings.client.warmup_connections: 0`: avoids stale pre-warmed sockets with - servers that close idle connections quickly. -- `settings.client.max_idle_time`: connection idle lifetime. 
-- `endpoint_config.endpoints`: server URL list. -- `endpoint_config.api_type: openai`: use `/v1/chat/completions`. -- `report_dir`: output directory for events, snapshots, and reports. +### Fields + +- `name`: human-readable run name written to reports and logs. Change this when + creating a distinct benchmark config. +- `version`: config version label for this example. +- `type`: scheduler mode for the run. +- `model_params.name`: model name sent in each OpenAI request. Set this to the + model name served by the endpoint. +- `model_params.temperature`: sampling temperature sent to the server. +- `model_params.top_p`: nucleus sampling value sent to the server. +- `model_params.max_new_tokens`: per-turn generation cap. +- `model_params.chat_template_kwargs.thinking`: Kimi chat-template option. +- `model_params.chat_template_kwargs.preserve_thinking`: preserves + reasoning content in the rendered prompt. +- First dataset `name`: label used in benchmark outputs. Change this to match + the dataset variant being run. +- First dataset `type`: dataset role for this entry. +- First dataset `path`: JSONL dataset path to run. Set this to a real local or + mounted dataset path, for example `/path/to/agentic_combined.jsonl`. +- First dataset `accuracy_config.eval_method`: scorer used during finalization. + `multi_turn_inline` scores the performance replay outputs without issuing a + separate accuracy phase. +- First dataset `multi_turn.enable_salt`: applies deterministic salt + markers when issuing conversation instances so repeats do not reuse KV cache + by accident. +- First dataset `multi_turn.inject_tool_delay`: honors positive + `delay_seconds` values from the dataset before issuing user/tool turns. +- First dataset `multi_turn.num_trajectories_to_issue`: total number of + trajectories to start. Change this to scale runtime. 
+- First dataset `multi_turn.stop_issuing_on_first_user_complete`: controls only + whether the client keeps issuing after the measurement window ends. Performance + tracking always stops when the first concurrency slot finishes a trajectory and + there is no next trajectory left to assign. If this field is `true`, the client + stops issuing future turns at that point and drains already in-flight turns. If + this field is `false`, the client keeps replaying already-started active + trajectories to completion for accuracy/log coverage, but those later-issued + turns are outside the performance measurement window. +- `settings.runtime.min_duration_ms`: minimum run duration. Multi-turn replay + completion is controlled by trajectory budget and active conversation drain. +- `settings.load_pattern.type`: enables conversation-aware issuing. +- `settings.load_pattern.target_concurrency`: maximum active conversations. Each + active conversation has at most one in-flight request. Change this for the + target concurrency of the run. +- `settings.client.warmup_connections`: disables pre-warmed HTTP sockets. +- `settings.client.max_idle_time`: connection idle lifetime in seconds. +- `endpoint_config.endpoints`: server URL list. Replace with the endpoint URLs + for the run. +- `endpoint_config.api_type`: selects the endpoint protocol and route. +- `report_dir`: output directory for events, snapshots, scores, and reports. + Change this per run so outputs are not overwritten. 
+ +### Benchmark Invariants + +For official Kimi agentic benchmark runs, keep these values fixed: + +- `version: "1.0"` +- `type: "online"` +- `model_params.temperature: 1.0` +- `model_params.top_p: 0.95` +- `model_params.max_new_tokens: 8192` +- `model_params.chat_template_kwargs.thinking: true` +- `model_params.chat_template_kwargs.preserve_thinking: true` +- First dataset `type: performance` +- First dataset `accuracy_config.eval_method: multi_turn_inline` +- `settings.runtime.min_duration_ms: 0` +- `settings.load_pattern.type: multi_turn` +- `settings.client.warmup_connections: 0` +- `settings.client.max_idle_time: 0.5` +- `endpoint_config.api_type: openai` + +The required multi-turn dataset defaults are: + +- First dataset `multi_turn.enable_salt: true` +- First dataset `multi_turn.inject_tool_delay: true` +- First dataset `multi_turn.stop_issuing_on_first_user_complete: false` + +Set `multi_turn.num_trajectories_to_issue` to an integer multiple of the +dataset trajectory count so every trajectory is issued the same number of times. Use +`multi_turn.stop_issuing_on_first_user_complete: true` only for faster +optimization/debug runs, not official benchmark runs. + +### Salting Mechanism + +When `multi_turn.enable_salt: true`, the strategy adds a short deterministic +`[salt: ...]` marker before the system prompt for the trajectory repeat and +another after the system prompt for the conversation. Each salt is four hex characters. +This restricts KV-cache reuse as follows: + +1. Within a trajectory, reuse is fully allowed. +2. Within the same iteration of the dataset, only system-prompt reuse is allowed. +3. Across multiple iterations of the dataset, reuse is disallowed. + +### Inline Accuracy + +When `accuracy_config.eval_method: multi_turn_inline` is set on the performance +dataset, the benchmark scores the generated `events.jsonl` during finalization +and writes `scores.json` under `report_dir`. 
The scorer uses the loaded +multi-turn dataset as ground truth, matches completed assistant responses back +to their conversation/turn ids, and compares them with the expected assistant +turns embedded in the dataset. It does not issue a separate accuracy phase. + +### Tail Management + +Multi-turn benchmarks can have a long tail because different users receive +trajectories with very different turn counts, delays, and generated lengths. In +large runs this tail can last up to an hour after steady-state work has already +ended, so the benchmark separates the performance window from the remaining +accuracy/logging drain. + +The benchmark stops performance tracking when the first active user finishes its +final assigned trajectory. It emits `STOP_PERFORMANCE_TRACKING` at that point to +avoid measuring the tail. Turns issued before this event remain in the +performance window even if they finish later; turns issued after it are excluded +from performance metrics. + +For final submissions, keep +`multi_turn.stop_issuing_on_first_user_complete: false` so the client finishes +already-started trajectories for accuracy. During optimization, set it to `true` +to stop issuing future turns at the performance boundary and shorten the tail. ## Run The Client diff --git a/examples/09_MultiTurn/kimi_agentic_benchmark.yaml b/examples/09_MultiTurn/kimi_agentic_benchmark.yaml index 273a46a2..a6570496 100644 --- a/examples/09_MultiTurn/kimi_agentic_benchmark.yaml +++ b/examples/09_MultiTurn/kimi_agentic_benchmark.yaml @@ -1,26 +1,28 @@ name: "kimi-agentic-benchmark" version: "1.0" -type: "online" +type: "online" # do not change. model_params: name: "/model" - temperature: 1.0 - top_p: 0.95 - max_new_tokens: 20000 # covers longest observed assistant turn (~18k tokens) + temperature: 1.0 # do not change. + top_p: 0.95 # do not change. + max_new_tokens: 8192 # do not change. chat_template_kwargs: - thinking: true - preserve_thinking: true + thinking: true # do not change. 
+ preserve_thinking: true # do not change. datasets: - # Select the dataset to run by updating both `name` and `path`. - # Use agentic_coding for coding traces or agentic_workflow for workflow traces. - - name: agentic_coding - type: performance - path: /path/to/agentic_dataset.jsonl + - name: agentic_combined + type: performance # do not change. + path: /path/to/agentic_combined.jsonl + accuracy_config: + eval_method: multi_turn_inline # required benchmark default. multi_turn: - turn_timeout_s: 600.0 - enable_salt: true # add salt after system prompt to prevent cache reuse across trajectories - inject_tool_delay: true # add delay before user/tool turns + enable_salt: true # required benchmark default. + inject_tool_delay: true # required benchmark default. + num_trajectories_to_issue: 990 # Should be integer multiple of 990. + # Required benchmark default; set to true only for faster optimization/debug runs. + stop_issuing_on_first_user_complete: false settings: runtime: @@ -28,10 +30,8 @@ settings: load_pattern: type: multi_turn - target_concurrency: 8 + target_concurrency: 8 # Submission-specific concurrency. - # Mandatory: with the default warmup behaviour, every request fails with - # ConnectionResetError because uvicorn closes pre-warmed idle sockets after 5s. client: warmup_connections: 0 max_idle_time: 0.5 @@ -39,6 +39,6 @@ settings: endpoint_config: endpoints: - "http://localhost:8000" - api_type: openai + api_type: openai # do not change. 
report_dir: logs/kimi_agentic diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py index 78ee09e2..3411d506 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py @@ -121,7 +121,7 @@ def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase: """Return the tokenizer for the current thread, loading it if needed.""" if getattr(self._thread_local, "tokenizer", None) is None: self._thread_local.tokenizer = AutoTokenizer.from_pretrained( - self._tokenizer_name + self._tokenizer_name, trust_remote_code=True ) # Baseline = tokens contributed by a [user, empty-assistant] pair minus # the [user] prefix alone. Some templates (Qwen3-Coder, etc.) reject diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 7019ef44..2dfabeb3 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -228,6 +228,35 @@ def _check_tokenizer_exists(model_name: str) -> bool: return False +def _resolve_accuracy_components( + dataset_name: str, accuracy_config: Any | None +) -> tuple[type[Scorer], type[Extractor] | None]: + """Validate scorer/extractor config and return resolved classes.""" + if accuracy_config is None or accuracy_config.eval_method is None: + raise InputValidationError( + f"Dataset '{dataset_name}' requires accuracy_config with eval_method" + ) + + try: + scorer_cls = Scorer.get(accuracy_config.eval_method) + except KeyError as exc: + raise InputValidationError(str(exc)) from exc + extractor_name = accuracy_config.extractor + if extractor_name is None: + if scorer_cls.REQUIRES_EXTRACTOR: + raise InputValidationError( + f"Dataset '{dataset_name}' uses scorer " + f"'{accuracy_config.eval_method}' which 
requires an extractor" + ) + extractor_cls: type[Extractor] | None = None + else: + try: + extractor_cls = Extractor.get(extractor_name) + except KeyError as exc: + raise InputValidationError(str(exc)) from exc + return scorer_cls, extractor_cls + + def _load_datasets( config: BenchmarkConfig, report_dir: Path ) -> tuple[Dataset, list[Dataset], list[AccuracyConfiguration]]: @@ -247,25 +276,10 @@ def _load_datasets( # Pack the evaluation parameters for each accuracy dataset for acc_cfg in accuracy_cfgs: - if ( - acc_cfg.accuracy_config is None - or acc_cfg.accuracy_config.eval_method is None - ): - raise InputValidationError( - f"Dataset '{acc_cfg.name}' requires accuracy_config with eval_method" - ) - - scorer_cls = Scorer.get(acc_cfg.accuracy_config.eval_method) - extractor_name = acc_cfg.accuracy_config.extractor - if extractor_name is None: - if scorer_cls.REQUIRES_EXTRACTOR: - raise InputValidationError( - f"Dataset '{acc_cfg.name}' uses scorer " - f"'{acc_cfg.accuracy_config.eval_method}' which requires an extractor" - ) - extractor_cls: type[Extractor] | None = None - else: - extractor_cls = Extractor.get(extractor_name) + scorer_cls, extractor_cls = _resolve_accuracy_components( + acc_cfg.name, acc_cfg.accuracy_config + ) + assert acc_cfg.accuracy_config is not None ds = DataLoaderFactory.create_loader( acc_cfg, num_repeats=acc_cfg.accuracy_config.num_repeats @@ -290,12 +304,13 @@ def _load_datasets( logger.info(f"Loaded {ds} - {ds.num_samples()} samples") if not accuracy_cfgs: - logger.info("No accuracy datasets provided") + logger.info("No separate accuracy datasets provided") if len(performance_cfgs) > 1: raise InputValidationError("Multiple performance datasets not supported") + perf_cfg = performance_cfgs[0] try: - dataloader = DataLoaderFactory.create_loader(performance_cfgs[0]) + dataloader = DataLoaderFactory.create_loader(perf_cfg) dataloader.load( api_type=config.endpoint_config.api_type, model_params=config.model_params ) @@ -307,6 +322,31 @@ def 
_load_datasets( except Exception as e: raise SetupError(f"Failed to load dataset: {e}") from e + if perf_cfg.accuracy_config is not None: + accuracy_config = perf_cfg.accuracy_config + if accuracy_config.num_repeats != 1: + raise InputValidationError( + f"Dataset '{perf_cfg.name}' is a performance dataset; " + "accuracy_config.num_repeats must be 1 because scoring runs on " + "already-issued performance outputs" + ) + scorer_cls, extractor_cls = _resolve_accuracy_components( + perf_cfg.name, accuracy_config + ) + + eval_configs.append( + AccuracyConfiguration( + scorer_cls, + extractor_cls, + "performance", + dataloader, + report_dir, + accuracy_config.ground_truth, + accuracy_config.num_repeats, + accuracy_config.extras or {}, + ) + ) + return dataloader, accuracy_datasets, eval_configs @@ -433,6 +473,8 @@ def _build_phases( # Accuracy phases — use eval_cfg.dataset_name as phase name so it matches # what Scorer._load_sample_index_map() looks up in sample_idx_map.json for eval_cfg in ctx.eval_configs: + if eval_cfg.dataset_name == "performance": + continue acc_ds = eval_cfg.dataset if isinstance(acc_ds, MultiTurnDataset): raise InputValidationError( @@ -859,15 +901,17 @@ def finalize_benchmark(ctx: BenchmarkContext, bench: BenchmarkResult) -> None: ) score, n_repeats = scorer_instance.score() assert eval_cfg.dataset.data is not None + num_samples = len(eval_cfg.dataset.data) + if eval_cfg.dataset_name == "performance": + num_samples = sum(phase.issued_count for phase in result.perf_results) accuracy_scores[eval_cfg.dataset_name] = { "dataset_name": eval_cfg.dataset_name, - "num_samples": len(eval_cfg.dataset.data), + "num_samples": num_samples, "extractor": ( eval_cfg.extractor.__name__ if eval_cfg.extractor is not None else None ), "ground_truth_column": eval_cfg.ground_truth_column, "score": score, - "n_repeats": n_repeats, } logger.info(f"Score for {eval_cfg.dataset_name}: {score} ({n_repeats} repeats)") diff --git a/src/inference_endpoint/config/schema.py 
b/src/inference_endpoint/config/schema.py index 2cfa35d7..2b3790eb 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -97,6 +97,7 @@ class ScorerMethod(str, Enum): ROUGE = "rouge" CODE_BENCH = "code_bench_scorer" SHOPIFY_CATEGORY_F1 = "shopify_category_f1" + MULTI_TURN_INLINE = "multi_turn_inline" VBENCH = "vbench" @@ -257,18 +258,23 @@ class MultiTurnConfig(BaseModel): response. A timeout aborts that turn and all remaining client turns of the same conversation because subsequent turns depend on the timed-out response. - use_dataset_history: If True, use pre-built message history from dataset. """ model_config = ConfigDict(extra="forbid", frozen=True) - turn_timeout_s: float = Field(default=300.0, gt=0) - use_dataset_history: bool = True + turn_timeout_s: float = Field( + default=86400.0, + gt=0, + description=( + "Per-turn timeout in seconds. A timeout aborts that turn and all " + "remaining turns in the same conversation." + ), + ) enable_salt: bool = Field( False, description=( - "Enable salt addition after system prompt to prevent KV cache reuse " - "across trajectories in multi-turn setting." + "Add deterministic salt markers before and after the system prompt " + "to prevent KV cache reuse across trajectories in multi-turn setting." ), ) inject_tool_delay: bool = Field( @@ -278,6 +284,24 @@ class MultiTurnConfig(BaseModel): "in dataset." ), ) + num_trajectories_to_issue: int | None = Field( + default=None, + gt=0, + description=( + "Number of conversation trajectories to start. Defaults to one pass " + "over the dataset; values above the dataset size repeat trajectories " + "with unique logical conversation ids." + ), + ) + stop_issuing_on_first_user_complete: bool = Field( + False, + description=( + "When performance tracking stops because the first concurrency slot " + "has no next trajectory left to assign, also stop issuing future " + "turns. 
If false, replay continues outside the performance window " + "for accuracy/log coverage." + ), + ) class Dataset(BaseModel): @@ -811,6 +835,14 @@ def _resolve_and_validate(self) -> Self: raise ValueError( "load_pattern.type=multi_turn requires the performance dataset to have multi_turn config" ) + if ( + lp.type == LoadPatternType.MULTI_TURN + and self.settings.runtime.n_samples_to_issue is not None + ): + raise ValueError( + "runtime.n_samples_to_issue is not supported for multi-turn runs; " + "use datasets[].multi_turn.num_trajectories_to_issue instead" + ) if has_multi_turn_perf_dataset and lp.type != LoadPatternType.MULTI_TURN: raise ValueError( f"Performance dataset with multi_turn config requires load_pattern.type=multi_turn, " diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 4fef4afc..450aacd2 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -37,7 +37,7 @@ datasets: # Dataset configs prompt: question system: system_prompt accuracy_config: # Accuracy evaluation settings - eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1, vbench + eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1, multi_turn_inline, vbench ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index 1f61837f..059bd0aa 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml 
+++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -37,7 +37,7 @@ datasets: # Dataset configs prompt: question system: system_prompt accuracy_config: # Accuracy evaluation settings - eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1, vbench + eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1, multi_turn_inline, vbench ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index a212fa95..4d6101fa 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -37,7 +37,7 @@ datasets: # Dataset configs prompt: question system: system_prompt accuracy_config: # Accuracy evaluation settings - eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1, vbench + eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1, multi_turn_inline, vbench ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index d41f51f4..80f875ec 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -23,7 +23,6 @@ from 
inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat -from inference_endpoint.exceptions import InputValidationError from .multi_turn_dataset import MultiTurnDataset from .transforms import ColumnRemap, MakeAdapterCompatible, Transform @@ -118,10 +117,4 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data dataset_id=dataset_id, num_repeats=num_repeats, ) - if config.multi_turn is not None and config.multi_turn.enable_salt: - if not isinstance(dataloader, MultiTurnDataset): - raise InputValidationError( - "enable_salt requires a multi-turn dataset loader" - ) - dataloader.enable_salt() return dataloader diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index bb20153a..b082978a 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -15,7 +15,6 @@ """Multi-turn conversation dataset for conversational AI benchmarking.""" -import hashlib import logging from dataclasses import dataclass, field, replace from typing import Any @@ -63,10 +62,6 @@ class ConversationMetadata: pre_built_messages_by_key: dict[tuple[str, int], list[dict]] = field( default_factory=dict ) - current_turn_messages_by_key: dict[tuple[str, int], list[dict]] = field( - default_factory=dict - ) - system_prompts_by_conv: dict[str, str | None] = field(default_factory=dict) delay_seconds_by_key: dict[tuple[str, int], float] = field(default_factory=dict) @@ -116,20 +111,17 @@ def _expand_tool_results(row: dict | pd.Series) -> list[dict]: def _build_conversation_metadata( conv_id: Any, group: Any, - enable_salt: bool, ) -> tuple[ str, dict[tuple[str, int], list[dict]], - dict[tuple[str, int], list[dict]], - str | None, dict[tuple[str, int], float], list[ConversationSampleEntry], int, ]: """Build message history for 
all client turns in a single conversation. - Returns a tuple of (str_conv_id, pre_built_messages, current_turn_messages, - system_prompt, delay_seconds, samples, client_turns_count). + Returns a tuple of (str_conv_id, pre_built_messages, delay_seconds, samples, + client_turns_count). """ str_conv_id = str(conv_id) sorted_group = group.sort_values("turn") @@ -141,20 +133,8 @@ def _build_conversation_metadata( if val and isinstance(val, str): system_content = val break - if enable_salt and system_content: - salt_hex = hashlib.blake2b( - str_conv_id.encode("utf-8"), digest_size=8 - ).hexdigest() - system_content = f"{system_content}\n\n[cache_salt: {salt_hex}]" - elif enable_salt: - logger.warning( - "multi_turn.enable_salt requested but conversation %s has no " - "system prompt; cache salt not applied", - conv_id, - ) pre_built_messages_by_key: dict[tuple[str, int], list[dict]] = {} - current_turn_messages_by_key: dict[tuple[str, int], list[dict]] = {} delay_seconds_by_key: dict[tuple[str, int], float] = {} samples: list[ConversationSampleEntry] = [] @@ -201,7 +181,6 @@ def _build_conversation_metadata( pre_built_messages_by_key[(str_conv_id, t_n)] = ( list(history) + current_turn_msgs ) - current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs history.extend(current_turn_msgs) delay_val = row.get("delay_seconds") @@ -226,8 +205,6 @@ def _build_conversation_metadata( return ( str_conv_id, pre_built_messages_by_key, - current_turn_messages_by_key, - system_content, delay_seconds_by_key, samples, client_turns_count, @@ -283,7 +260,6 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): self.dataframe = self.dataframe.loc[~metadata_rows].reset_index( drop=True ) - self._enable_salt = False self._conv_groups = dict( list(self.dataframe.groupby("conversation_id", sort=False, dropna=False)) ) @@ -293,9 +269,6 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): # Populated by load() after transforms; None until then. 
self.conversation_metadata: ConversationMetadata | None = None - def enable_salt(self) -> None: - self._enable_salt = True - def _validate_conversation_grouping(self) -> None: """Validate that all rows for each conversation_id appear consecutively in file order. @@ -510,23 +483,17 @@ def _build_metadata(self) -> ConversationMetadata: samples: list[ConversationSampleEntry] = [] client_turns_per_conv: dict[str, int] = {} pre_built_messages_by_key: dict[tuple[str, int], list[dict]] = {} - current_turn_messages_by_key: dict[tuple[str, int], list[dict]] = {} - system_prompts_by_conv: dict[str, str | None] = {} delay_seconds_by_key: dict[tuple[str, int], float] = {} for conv_id, group in self._conv_groups.items(): ( str_conv_id, partial_pre_built, - partial_current_turn, - system_prompt, partial_delay, conv_samples, client_turns_count, - ) = _build_conversation_metadata(conv_id, group, self._enable_salt) + ) = _build_conversation_metadata(conv_id, group) pre_built_messages_by_key.update(partial_pre_built) - current_turn_messages_by_key.update(partial_current_turn) - system_prompts_by_conv[str_conv_id] = system_prompt delay_seconds_by_key.update(partial_delay) samples.extend(conv_samples) client_turns_per_conv[str_conv_id] = client_turns_count @@ -537,8 +504,6 @@ def _build_metadata(self) -> ConversationMetadata: max_turns_per_conv=max(g["turn"].max() for g in self._conv_groups.values()), client_turns_per_conversation=client_turns_per_conv, pre_built_messages_by_key=pre_built_messages_by_key, - current_turn_messages_by_key=current_turn_messages_by_key, - system_prompts_by_conv=system_prompts_by_conv, delay_seconds_by_key=delay_seconds_by_key, ) diff --git a/src/inference_endpoint/evaluation/scoring.py b/src/inference_endpoint/evaluation/scoring.py index 911ec629..f58150e9 100644 --- a/src/inference_endpoint/evaluation/scoring.py +++ b/src/inference_endpoint/evaluation/scoring.py @@ -15,6 +15,7 @@ import inspect +import json import logging import os import re @@ -24,10 
+25,11 @@ import tempfile import uuid from abc import ABC, abstractmethod -from collections import defaultdict +from collections import Counter, defaultdict from pathlib import Path from typing import Any, ClassVar +import msgspec import msgspec.json import numpy as np import pandas as pd @@ -47,7 +49,9 @@ _nltk = None from ..core.record import EventRecord, EventType, SampleEventType +from ..core.types import TextModelOutput from ..dataset_manager.dataset import Dataset +from ..dataset_manager.multi_turn_dataset import MultiTurnDataset from ..dataset_manager.predefined.shopify_product_catalogue import ProductMetadata from .extractor import Extractor, PythonCodeExtractor @@ -314,6 +318,501 @@ def score(self) -> tuple[float, int]: return result, 1 +class MultiTurnInlineScorer(Scorer, scorer_id="multi_turn_inline"): + """Score multi-turn performance replay outputs without issuing another phase.""" + + REQUIRES_EXTRACTOR = False + _EXECUTABLE_ALIASES: ClassVar[dict[str, str]] = { + "python": "python", + "python2": "python", + "python3": "python", + "py": "python", + "pip": "pip", + "pip3": "pip", + "pytest": "pytest", + "pylint": "pylint", + "sphinx-build": "sphinx", + "sphinx-quickstart": "sphinx", + "cython": "cython", + "make": "make", + "conda": "conda", + "cat": "cat", + "head": "head", + "tail": "tail", + "less": "cat", + "more": "cat", + "wc": "wc", + "diff": "diff", + "grep": "grep", + "egrep": "grep", + "fgrep": "grep", + "rg": "grep", + "ag": "grep", + "sed": "sed", + "awk": "awk", + "gawk": "awk", + "tr": "tr", + "sort": "sort", + "uniq": "uniq", + "cut": "cut", + "find": "find", + "ls": "ls", + "locate": "find", + "xargs": "xargs", + "cp": "cp", + "mv": "mv", + "rm": "rm", + "mkdir": "mkdir", + "touch": "touch", + "tee": "tee", + "source": "source", + ".": "source", + "which": "which", + "alias": "alias", + "unset": "unset", + "export": "export", + "git": "git", + "curl": "curl", + "wget": "curl", + "true": "true", + "false": "false", + "timeout": 
"timeout", + "date": "date", + "apt-get": "apt", + "apt": "apt", + "yum": "yum", + } + _SHELL_WRAPPERS: ClassVar[set[str]] = { + "env", + "time", + "nice", + "sudo", + "exec", + "command", + } + _REPEAT_SUFFIX_RE: ClassVar[re.Pattern[str]] = re.compile(r"__repeat_(\d+)$") + _WORKFLOW_CONVERSATION_RE: ClassVar[re.Pattern[str]] = re.compile(r"^sim_\d+$") + _INTENT_RE: ClassVar[re.Pattern[str]] = re.compile( + r"\bintent:\s*(I\d{3})\b", re.IGNORECASE + ) + _BARE_INTENT_RE: ClassVar[re.Pattern[str]] = re.compile(r"\bI(\d{3})\b") + _COMMAND_SEPARATOR_RE: ClassVar[re.Pattern[str]] = re.compile(r"\|\||\||&&|;|\n") + _QUOTED_RE: ClassVar[re.Pattern[str]] = re.compile( + r"'[^']*'|\"(?:[^\"\\]|\\.)*\"|`[^`]*`" + ) + _ENV_ASSIGNMENT_RE: ClassVar[re.Pattern[str]] = re.compile( + r"^[A-Za-z_][A-Za-z0-9_]*=" + ) + _PY_VERSION_SUFFIX_RE: ClassVar[re.Pattern[str]] = re.compile(r"\.\d+(\.\d+)?$") + + def __init__( + self, + dataset_name: str, + dataset: Dataset, + report_dir: os.PathLike, + extractor: type[Extractor] | None = None, + ground_truth_column: str | None = None, + scores_filename: str = "scores.json", + ): + """Initialize a scorer for already-issued multi-turn performance events. + + The scorer intentionally does not use an extractor or a single + ``ground_truth`` column. Ground truth is derived from expected assistant + turns in the loaded ``MultiTurnDataset`` dataframe. + + Example: + A performance dataset config such as + ``accuracy_config.eval_method: multi_turn_inline`` instantiates this + scorer with ``dataset_name="performance"`` so it reads the + performance phase's entries from ``sample_idx_map.json``. 
+ """ + if extractor is not None: + raise ValueError("MultiTurnInlineScorer does not use an extractor") + super().__init__( + dataset_name=dataset_name, + dataset=dataset, + report_dir=report_dir, + extractor=None, + ground_truth_column=ground_truth_column, + ) + self.scores_filename = scores_filename + + def score_single_sample(self, value: str, ground_truth: str) -> float: + """Reject single-sample scoring for the conversation-level scorer. + + Multi-turn accuracy depends on neighboring turns and conversation ids, + so a single output string cannot be scored in isolation. + + Example: + ``score_single_sample("answer", "expected")`` raises + ``RuntimeError``; callers should use ``score()``. + """ + raise RuntimeError( + "MultiTurnInlineScorer scores whole conversations; call score()." + ) + + def score(self) -> tuple[float | None, int]: + """Score completed multi-turn performance outputs. + + The method builds expected assistant turns from the loaded dataset, + reads issued turns and model assistant completions from ``events.jsonl``, + identifies each conversation as workflow or coding, and averages issued + turns with scorable ground truth. Issued turns without a model output + contribute score ``0``. + + Examples: + A workflow turn with ``intent_codes=["I042"]`` scores ``1.0`` when + the model text contains ``intent: I042``. + + A coding turn with expected bash command ``{"cmd": "python test.py"}`` + is scored by comparing normalized executables such as ``["python"]`` + against the model's bash tool calls. 
+ """ + if not isinstance(self.dataset, MultiTurnDataset): + raise TypeError("MultiTurnInlineScorer requires a MultiTurnDataset") + assert ( + self.dataset.dataframe is not None + ), f"Dataset {self.dataset} has no dataframe loaded" + + expected = self._expected_assistant_turns() + scorable_expected: dict[tuple[str, int], dict[str, Any]] = {} + excluded_turns: list[dict[str, Any]] = [] + for (conversation_id, client_turn), ground_truth in sorted(expected.items()): + domain = ( + "workflow" + if self._WORKFLOW_CONVERSATION_RE.match(conversation_id) + else "coding" + ) + has_ground_truth = ( + bool(self._ground_truth_intents(ground_truth)) + if domain == "workflow" + else bool(self._bash_actions(ground_truth)) + ) + if has_ground_truth: + scorable_expected[(conversation_id, client_turn)] = ground_truth + else: + excluded_turns.append( + { + "conversation_id": conversation_id, + "turn": ground_truth["_assistant_turn"], + "domain": domain, + "exclude_reason": "no ground truth", + } + ) + + issued_turns, model_turns = self._issued_and_completed_model_turns( + set(expected) + ) + issued_repeats = sorted({key[1] for key in issued_turns}) + scorable_issued_turns = sorted( + key for key in issued_turns if (key[0], key[2]) in scorable_expected + ) + + total_score = 0.0 + n_scored = 0 + domain_totals = {"coding": 0.0, "workflow": 0.0} + domain_counts = {"coding": 0, "workflow": 0} + per_turn: list[dict[str, Any]] = [] + + for conversation_id, repeat_id, client_turn in scorable_issued_turns: + ground_truth = scorable_expected[(conversation_id, client_turn)] + key = (conversation_id, repeat_id, client_turn) + model = model_turns.get(key) + domain = ( + "workflow" + if self._WORKFLOW_CONVERSATION_RE.match(conversation_id) + else "coding" + ) + row: dict[str, Any] = { + "conversation_id": conversation_id, + "repeat": repeat_id, + "turn": ground_truth["_assistant_turn"], + "domain": domain, + } + + if model is None: + row["missing"] = True + model = {"role": "assistant"} + + 
score: float + if domain == "workflow": + gt_intents = self._ground_truth_intents(ground_truth) + model_intent = self._model_intent(model) + row["gt_intents"] = sorted(gt_intents) + row["model_intent"] = model_intent + score = 1.0 if model_intent in gt_intents else 0.0 + else: + gt_actions = self._bash_actions(ground_truth) + model_actions = self._bash_actions(model) + row["gt_actions"] = gt_actions + row["model_actions"] = model_actions + gt_counts = Counter(gt_actions) + model_counts = Counter(model_actions) + union = sum((gt_counts | model_counts).values()) + score = sum((gt_counts & model_counts).values()) / union + + row["score"] = round(score, 4) + per_turn.append(row) + total_score += score + n_scored += 1 + domain_totals[domain] += score + domain_counts[domain] += 1 + + expected_outputs = set(scorable_issued_turns) + observed_outputs = { + key + for key, model in model_turns.items() + if model and key in expected_outputs + } + missing_outputs = len(expected_outputs - observed_outputs) + final_score = round(total_score / n_scored, 4) if n_scored else None + result: dict[str, Any] = { + "score": final_score, + "turns": { + "issued": len(issued_turns), + "expected": len(expected_outputs), + "observed": len(observed_outputs), + "missing": missing_outputs, + "scored": n_scored, + }, + "domains": { + domain: { + "score": round(domain_totals[domain] / domain_counts[domain], 4), + "scored": domain_counts[domain], + } + for domain in ("coding", "workflow") + if domain_counts[domain] + }, + "per_turn": per_turn, + } + if excluded_turns: + result["excluded_turns"] = excluded_turns + + out_path = self.report_dir / self.scores_filename + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(result, indent=2)) + return final_score, len(issued_repeats) + + def _expected_assistant_turns(self) -> dict[tuple[str, int], dict[str, Any]]: + """Return expected assistant turns keyed by source conversation and turn. 
+ + The dataset stores alternating client-side rows and expected assistant + rows. This method pairs each ``user`` or ``tool`` row with the following + ``assistant`` row and uses the client row's turn as the event-log turn + to match. + + Example: + Rows ``conv1/user/turn=1`` followed by ``conv1/assistant/turn=2`` + produce ``expected[("conv1", 1)]`` with ``"_assistant_turn": 2``. + """ + assert ( + self.dataset.dataframe is not None + ), f"Dataset {self.dataset} has no dataframe loaded" + + rows_by_conversation: dict[str, list[dict[str, Any]]] = defaultdict(list) + for raw_row in self.dataset.dataframe.to_dict("records"): + row: dict[str, Any] = {} + for field, value in raw_row.items(): + try: + row[field] = None if value != value else value + except (TypeError, ValueError): + row[field] = value + conversation_id = row.get("conversation_id") + if conversation_id is not None: + rows_by_conversation[str(conversation_id)].append(row) + + expected: dict[tuple[str, int], dict[str, Any]] = {} + for conversation_id, rows in rows_by_conversation.items(): + rows.sort(key=lambda row: int(row.get("turn") or 0)) + for row, next_row in zip(rows, rows[1:], strict=False): + if row.get("role") not in ("user", "tool"): + continue + if next_row.get("role") != "assistant": + continue + try: + client_turn = int(row.get("turn") or 0) + assistant_turn = int(next_row.get("turn") or 0) + except (TypeError, ValueError): + continue + expected[(conversation_id, client_turn)] = { + **next_row, + "_assistant_turn": assistant_turn, + } + return expected + + def _issued_and_completed_model_turns( + self, expected_keys: set[tuple[str, int]] + ) -> tuple[ + set[tuple[str, int, int]], + dict[tuple[str, int, int], dict[str, Any] | None], + ]: + """Read issued turns and completed model outputs from ``events.jsonl``. + + ISSUED records define the scoring denominator. 
COMPLETE records are + joined by ``sample_uuid`` and may carry ``None`` data for failed turns, + which keeps those turns in the denominator with score ``0``. + + Example: + ISSUED conversation id ``"conv1__repeat_3"`` and turn ``1`` becomes + issued key ``("conv1", 3, 1)``. A matching COMPLETE record with + ``data=None`` is returned as ``model_turns[("conv1", 3, 1)] = None``. + """ + events_path = self.report_dir / "events.jsonl" + if not events_path.exists(): + raise FileNotFoundError(f"Events log file not found at {events_path}") + + decoder = msgspec.json.Decoder(type=EventRecord, dec_hook=EventType.decode_hook) + uuid_to_key: dict[str, tuple[str, int, int]] = {} + completed_by_uuid: dict[str, dict[str, Any] | None] = {} + issued_turns: set[tuple[str, int, int]] = set() + model_turns: dict[tuple[str, int, int], dict[str, Any] | None] = {} + with events_path.open() as f: + for line_no, line in enumerate(f, 1): + stripped = line.strip() + if not stripped: + continue + try: + record = decoder.decode(stripped) + except msgspec.DecodeError as exc: + logger.warning( + "Skipping malformed event log line %d in %s: %s", + line_no, + events_path, + exc, + ) + continue + if record.event_type not in ( + SampleEventType.ISSUED, + SampleEventType.COMPLETE, + ): + continue + if record.turn is None or not record.conversation_id: + continue + + conversation_id = record.conversation_id + repeat_id = 1 + repeat_match = self._REPEAT_SUFFIX_RE.search(conversation_id) + if repeat_match is not None: + conversation_id = conversation_id[: repeat_match.start()] + repeat_id = int(repeat_match.group(1)) + turn = int(record.turn) + if (conversation_id, turn) not in expected_keys: + continue + + key = (conversation_id, repeat_id, turn) + if record.event_type == SampleEventType.ISSUED: + uuid_to_key[record.sample_uuid] = key + issued_turns.add(key) + if record.sample_uuid in completed_by_uuid: + model_turns[key] = completed_by_uuid[record.sample_uuid] + continue + + model: dict[str, Any] | 
None = None + if isinstance(record.data, TextModelOutput): + content, reasoning, tool_calls = record.data.as_message_parts() + model = { + "role": "assistant", + "content": content, + "reasoning_content": reasoning, + "tool_calls": list(tool_calls) if tool_calls else None, + } + if model is not None or record.sample_uuid not in completed_by_uuid: + completed_by_uuid[record.sample_uuid] = model + if record.sample_uuid in uuid_to_key: + key = uuid_to_key[record.sample_uuid] + if model is not None or key not in model_turns: + model_turns[key] = model + return issued_turns, model_turns + + def _ground_truth_intents(self, turn: dict[str, Any]) -> set[str]: + """Extract valid workflow intent codes from a ground-truth turn. + + Example: + ``{"intent_codes": ["i001", "I002", None]}`` returns + ``{"I001", "I002"}``. + """ + codes = turn.get("intent_codes") + if not isinstance(codes, list | tuple): + return set() + return {code.upper() for code in codes if isinstance(code, str) and code} + + def _model_intent(self, turn: dict[str, Any]) -> str | None: + """Extract the model's workflow intent code from text fields. + + The explicit ``intent: I123`` form is preferred. If absent, the last bare + ``I123`` token in ``reasoning_content`` or ``content`` is used. + + Example: + ``{"content": "final intent: I042"}`` returns ``"I042"``. + """ + for field in ("reasoning_content", "content"): + text = turn.get(field) or "" + if not isinstance(text, str): + continue + match = self._INTENT_RE.search(text) + if match is not None: + return match.group(1).upper() + for field in ("reasoning_content", "content"): + text = turn.get(field) or "" + if not isinstance(text, str): + continue + matches = list(self._BARE_INTENT_RE.finditer(text)) + if matches: + return f"I{matches[-1].group(1)}" + return None + + def _bash_actions(self, turn: dict[str, Any]) -> list[str]: + """Extract normalized bash executable names from assistant tool calls. + + Only ``bash`` function tool calls are considered. 
Shell wrappers, + leading environment assignments, command paths, and common aliases are + normalized before scoring. + + Example: + A tool call with ``{"cmd": "CUDA_VISIBLE_DEVICES=0 /usr/bin/python3 -m pytest"}`` + returns ``["python"]``. + """ + tool_calls = turn.get("tool_calls") + if not isinstance(tool_calls, list | tuple): + return [] + + actions: list[str] = [] + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + fn = tool_call.get("function") or {} + if not isinstance(fn, dict) or fn.get("name") != "bash": + continue + args = fn.get("arguments") + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + continue + if not isinstance(args, dict): + continue + command = args.get("command") or args.get("cmd") + if not isinstance(command, str): + continue + + command = self._QUOTED_RE.sub(" ", command) + for stage in self._COMMAND_SEPARATOR_RE.split(command): + tokens = stage.split() + while tokens and ( + self._ENV_ASSIGNMENT_RE.match(tokens[0]) + or tokens[0] in self._SHELL_WRAPPERS + ): + tokens = tokens[1:] + if not tokens: + continue + executable = tokens[0].rsplit("/", 1)[-1].lower() + executable = self._PY_VERSION_SUFFIX_RE.sub("", executable) + action = self._EXECUTABLE_ALIASES.get(executable) + if action: + actions.append(action) + return actions + + class LiveCodeBenchScorer(Scorer, scorer_id="code_bench_scorer"): """Scorer for LiveCodeBench code generation tasks. 
diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 1b0834bb..8ba9a3bf 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -16,8 +16,7 @@ """Conversation state management for multi-turn benchmarking.""" import logging -from dataclasses import dataclass, field -from typing import Any +from dataclasses import dataclass logger = logging.getLogger(__name__) @@ -28,15 +27,12 @@ class ConversationState: Attributes: conversation_id: Unique identifier for this conversation. - message_history: Accumulated message list (populated only when - use_dataset_history=False; empty otherwise). completed_turns: Turns with responses (success or failure) — observability only. failed_turns: Turns that failed — observability only. expected_client_turns: Expected total client turns (for completion detection). """ conversation_id: str - message_history: list[dict[str, Any]] = field(default_factory=list) completed_turns: int = 0 failed_turns: int = 0 expected_client_turns: int | None = None @@ -70,27 +66,20 @@ def get_or_create( self, conversation_id: str, expected_client_turns: int | None = None, - system_message: dict[str, Any] | None = None, ) -> ConversationState: """Return existing state or create a new one. Args: conversation_id: Unique identifier for conversation. expected_client_turns: Expected number of client turns. - system_message: System message to prepend to message_history - (only used when use_dataset_history=False and state is new). Returns: ConversationState for this conversation. 
""" if conversation_id not in self._conversations: - initial_history: list[dict[str, Any]] = ( - [system_message] if system_message is not None else [] - ) self._conversations[conversation_id] = ConversationState( conversation_id=conversation_id, expected_client_turns=expected_client_turns, - message_history=initial_history, ) return self._conversations[conversation_id] @@ -114,18 +103,11 @@ def _log_if_complete(self, state: ConversationState, conversation_id: str) -> No def mark_turn_complete( self, conversation_id: str, - response: str, - store_in_history: bool = False, - metadata: dict[str, Any] | None = None, ) -> None: """Record a successful response. Args: conversation_id: Conversation ID. - response: Model output (appended to history when store_in_history=True). - store_in_history: When True, append response to message_history. - metadata: Optional response metadata; tool_calls are preserved in history - when present (only used when store_in_history=True). Raises: KeyError: If conversation_id not found. @@ -133,20 +115,12 @@ def mark_turn_complete( state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - if store_in_history: - tool_calls = metadata.get("tool_calls") if metadata else None - if response or tool_calls: - msg: dict[str, Any] = {"role": "assistant", "content": response or None} - if tool_calls: - msg["tool_calls"] = tool_calls - state.message_history.append(msg) state.completed_turns += 1 self._log_if_complete(state, conversation_id) def mark_turn_failed( self, conversation_id: str, - store_in_history: bool = False, ) -> None: """Record a failed response. @@ -154,7 +128,6 @@ def mark_turn_failed( Args: conversation_id: Conversation ID. - store_in_history: When True, append error placeholder to message_history. Raises: KeyError: If conversation_id not found. 
@@ -162,10 +135,6 @@ def mark_turn_failed( state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - if store_in_history: - state.message_history.append( - {"role": "assistant", "content": "[ERROR: Turn failed or timed out]"} - ) state.completed_turns += 1 state.failed_turns += 1 logger.warning(f"Turn failed for conversation {conversation_id}") diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 4cecccd1..7092a188 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -16,28 +16,29 @@ """Async multi-turn load strategy implementing the LoadStrategy protocol.""" import asyncio +import hashlib import logging import time -from collections import defaultdict, deque +from collections import defaultdict from typing import Any from ..config.schema import MultiTurnConfig from ..core.record import ErrorEventType, EventRecord, SampleEventType -from ..core.types import ErrorData, QueryResult, TextModelOutput +from ..core.types import ErrorData, QueryResult from ..dataset_manager.multi_turn_dataset import ConversationMetadata from ..exceptions import InputValidationError -from .conversation_manager import ConversationManager, ConversationState +from .conversation_manager import ConversationManager from .strategy import PhaseIssuerProtocol logger = logging.getLogger(__name__) # Default turn timeout when no MultiTurnConfig is provided. 
-_DEFAULT_TURN_TIMEOUT_S = 300.0 +_DEFAULT_TURN_TIMEOUT_S = 86400.0 ConversationTurn = tuple[int, int] ConversationTurns = list[ConversationTurn] -PendingConversation = tuple[str, ConversationTurns] -ActiveConversationState = tuple[ConversationTurns, int] +ActiveConversationState = tuple[str, ConversationTurns, int, int] +ConversationInstance = tuple[str, str, ConversationTurns, int] class MultiTurnStrategy: @@ -88,42 +89,26 @@ def __init__( self._conv_manager = conversation_manager self._dataset_metadata = dataset_metadata self._multi_turn_config = multi_turn_config + self._num_trajectories_to_issue = ( + multi_turn_config.num_trajectories_to_issue + if multi_turn_config is not None + else None + ) + self._stop_issuing_on_first_user_complete = ( + multi_turn_config.stop_issuing_on_first_user_complete + if multi_turn_config is not None + else False + ) self._turn_timeout_s = ( multi_turn_config.turn_timeout_s if multi_turn_config is not None else _DEFAULT_TURN_TIMEOUT_S ) self._target_concurrency = target_concurrency - self._store_in_history = ( - not multi_turn_config.use_dataset_history - if multi_turn_config is not None - else False + self._enable_salt = ( + multi_turn_config.enable_salt if multi_turn_config is not None else False ) - # Dataset-supplied `role: tool` turns carry baked tool_call_ids that - # cannot reference the live model's freshly generated ids — reject them. - # Datasets that only declare `tools` on user turns or have `tool_calls` - # in scripted assistant rows are fine: live history never replays - # scripted assistant rows, so model-generated ids stay self-consistent. - if self._store_in_history: - tool_turn_keys = [ - key - for key, msgs in dataset_metadata.current_turn_messages_by_key.items() - if any(m.get("role") == "tool" for m in msgs) - ] - if tool_turn_keys: - raise InputValidationError( - "Multi-turn with tool result rows requires use_dataset_history=True. 
" - "Live-history mode cannot replay dataset tool_call_ids against " - "freshly generated model responses. " - f"Offending turn(s): {tool_turn_keys[:5]}" - + ( - f" (+{len(tool_turn_keys) - 5} more)" - if len(tool_turn_keys) > 5 - else "" - ) - ) - # Composite on_sample_complete callback set by execute.py; used by # _handle_timeout to route synthetic failure results. self._session_on_sample_complete: Any | None = None @@ -131,11 +116,9 @@ def __init__( # Maps query_id -> conversation_id for routing completions. self._inflight: dict[str, str] = {} - # Cached ConversationState refs for O(1) lookup in on_sample_complete. - self._conv_states: dict[str, ConversationState] = {} # Event-driven state — populated in execute(). - self._pending_convs: deque[PendingConversation] = deque() + self._base_convs: list[tuple[str, ConversationTurns]] = [] self._active_iters: dict[str, ActiveConversationState] = {} self._timeout_handles: dict[str, asyncio.TimerHandle] = {} self._delay_handles: dict[str, asyncio.TimerHandle] = {} @@ -143,6 +126,9 @@ def __init__( self._all_done: asyncio.Event | None = None self._loop: asyncio.AbstractEventLoop | None = None self._phase_issuer: PhaseIssuerProtocol | None = None + self._stopping = False + self._started_trajectory_count = 0 + self._performance_tracking_stopped = False async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: """Drive multi-turn sample issuance. 
@@ -157,6 +143,11 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: self._loop = asyncio.get_running_loop() self._all_done = asyncio.Event() self._error = None + self._stopping = False + self._active_iters.clear() + self._inflight.clear() + self._started_trajectory_count = 0 + self._performance_tracking_stopped = False conv_samples: dict[str, ConversationTurns] = defaultdict(list) for sample_meta in self._dataset_metadata.samples: @@ -164,31 +155,12 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: assert sample_meta.sample_index is not None conv_samples[conv_id].append((sample_meta.sample_index, sample_meta.turn)) - # Pre-create all conversation states before issuing any turns (no locking needed). - sys_prompts = self._dataset_metadata.system_prompts_by_conv - for conv_id, turns in conv_samples.items(): - sys_content = sys_prompts.get(conv_id) if self._store_in_history else None - system_message = ( - {"role": "system", "content": sys_content} - if sys_content is not None - else None - ) - state = self._conv_manager.get_or_create( - conv_id, - expected_client_turns=len(turns), - system_message=system_message, - ) - self._conv_states[conv_id] = state - - # Build pending queue (sorted turns per conversation). 
- for conv_id, turns in conv_samples.items(): - self._pending_convs.append((conv_id, sorted(turns, key=lambda x: x[1]))) - - n_to_start = ( - min(self._target_concurrency, len(self._pending_convs)) - if self._target_concurrency is not None and self._target_concurrency > 0 - else len(self._pending_convs) - ) + self._base_convs = [ + (conv_id, sorted(turns, key=lambda x: x[1])) + for conv_id, turns in conv_samples.items() + ] + self._validate_salt_system_prompts() + n_to_start = self._initial_conversations_to_start() try: for _ in range(n_to_start): self._start_conversation() @@ -215,30 +187,84 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: ) self._inflight.clear() + def _initial_conversations_to_start(self) -> int: + if not self._base_convs or not self._has_trajectory_budget(): + return 0 + trajectory_budget = self._trajectory_budget() + if self._target_concurrency is not None and self._target_concurrency > 0: + n_to_start = self._target_concurrency + else: + n_to_start = len(self._base_convs) + return min(n_to_start, trajectory_budget) + + def _trajectory_budget(self) -> int: + if self._num_trajectories_to_issue is not None: + return self._num_trajectories_to_issue + return len(self._base_convs) + + def _has_trajectory_budget(self) -> bool: + return self._started_trajectory_count < self._trajectory_budget() + + def _has_more_conversation_instances(self) -> bool: + return bool(self._base_convs and self._has_trajectory_budget()) + + def _next_conversation_instance(self) -> ConversationInstance | None: + if not self._has_more_conversation_instances(): + return None + + source_index = self._started_trajectory_count % len(self._base_convs) + source_id, turns = self._base_convs[source_index] + instance_id = self._started_trajectory_count // len(self._base_convs) + 1 + logical_id = ( + source_id if instance_id == 1 else f"{source_id}__repeat_{instance_id}" + ) + self._started_trajectory_count += 1 + return logical_id, source_id, turns, instance_id + 
def _has_work_remaining(self) -> bool: + if self._stopping: + return bool(self._inflight or self._delay_handles) return bool( - self._pending_convs - or self._active_iters + self._active_iters or self._inflight or self._delay_handles + or self._has_more_conversation_instances() ) def _start_conversation(self) -> None: """Pop the next conversation from the pending queue and issue its first turn.""" - conv_id, turns = self._pending_convs.popleft() - self._active_iters[conv_id] = (turns, 0) - self._issue_next_turn(conv_id) + if self._stopping: + return + instance = self._next_conversation_instance() + if instance is None: + self._fill_slot() + return + logical_id, source_id, turns, repeat_id = instance + self._create_conversation_state(logical_id, turns) + self._active_iters[logical_id] = (source_id, turns, 0, repeat_id) + self._issue_next_turn(logical_id) + + def _create_conversation_state( + self, + logical_id: str, + turns: ConversationTurns, + ) -> None: + self._conv_manager.get_or_create( + logical_id, + expected_client_turns=len(turns), + ) def _issue_next_turn(self, conv_id: str) -> None: """Schedule the next turn for conv_id, applying inter-turn delay if set.""" + if self._stopping: + return state = self._active_iters.get(conv_id) if state is None: return - turns, cursor = state + source_id, turns, cursor, repeat_id = state if cursor >= len(turns): - del self._active_iters[conv_id] - self._fill_slot() + self._finish_conversation(conv_id) return idx, turn = turns[cursor] @@ -249,25 +275,37 @@ def _issue_next_turn(self, conv_id: str) -> None: and self._multi_turn_config.inject_tool_delay ): delay_map = self._dataset_metadata.delay_seconds_by_key - delay = float(delay_map.get((conv_id, turn), 0.0)) + delay = float(delay_map.get((source_id, turn), 0.0)) if delay > 0.0: assert self._loop is not None handle = self._loop.call_later( - delay, self._issue_turn_now, conv_id, idx, turn + delay, self._issue_turn_now_safely, conv_id, idx, turn ) 
self._delay_handles[conv_id] = handle else: self._issue_turn_now(conv_id, idx, turn) + def _issue_turn_now_safely(self, conv_id: str, idx: int, turn: int) -> None: + """Issue a delayed turn and surface callback failures to execute().""" + try: + self._issue_turn_now(conv_id, idx, turn) + except Exception as exc: + logger.exception("Error issuing delayed turn for %s", conv_id) + self._error = exc + if self._all_done is not None: + self._all_done.set() + def _issue_turn_now(self, conv_id: str, idx: int, turn: int) -> None: """Issue a single turn to the phase issuer.""" self._delay_handles.pop(conv_id, None) + if self._stopping: + return active_iter = self._active_iters.get(conv_id) if active_iter is None: return - turns, cursor = active_iter + source_id, turns, cursor, repeat_id = active_iter if cursor >= len(turns): return expected_idx, expected_turn = turns[cursor] @@ -279,23 +317,12 @@ def _issue_turn_now(self, conv_id: str, idx: int, turn: int) -> None: turn, ) return - self._active_iters[conv_id] = (turns, cursor + 1) - state = self._conv_states[conv_id] - - data_override: dict[str, Any] | None = None - current_turn_messages: list[dict[str, Any]] | None = None - if self._store_in_history: - current_turn_messages = ( - self._dataset_metadata.current_turn_messages_by_key.get((conv_id, turn)) - ) - if current_turn_messages: - live_messages = state.message_history.copy() + current_turn_messages - data_override = { - "messages": live_messages, - "input_tokens": None, - "token_ids": None, - } + data_override = self._build_data_override( + source_id=source_id, + turn=turn, + repeat_id=repeat_id, + ) assert self._phase_issuer is not None query_id = self._phase_issuer.issue( @@ -305,27 +332,102 @@ def _issue_turn_now(self, conv_id: str, idx: int, turn: int) -> None: turn=turn, ) if query_id is None: - # Session stopping — signal done. 
- assert self._all_done is not None - self._all_done.set() + # Session stopping due to wall-clock limit, signal, or a generic + # stop check. Do not synthesize failures for unissued turns. + self._request_stop_issuing() return + self._active_iters[conv_id] = (source_id, turns, cursor + 1, repeat_id) self._inflight[query_id] = conv_id - if self._store_in_history and current_turn_messages: - state.message_history.extend(current_turn_messages) - assert self._loop is not None handle = self._loop.call_later( self._turn_timeout_s, self._handle_timeout, query_id, conv_id ) self._timeout_handles[query_id] = handle + def _build_data_override( + self, + source_id: str, + turn: int, + repeat_id: int, + ) -> dict[str, Any] | None: + if not self._enable_salt: + return None + + messages = self._dataset_metadata.pre_built_messages_by_key.get( + (source_id, turn) + ) + if not messages: + raise InputValidationError( + "multi_turn.enable_salt requires pre-built messages for every " + f"client turn; conversation {source_id!r} turn {turn} has none" + ) + + salted_messages = self._messages_with_trajectory_salt( + messages, + repeat_id=repeat_id, + conversation_id=source_id, + ) + return { + "messages": salted_messages, + "input_tokens": None, + "token_ids": None, + } + + def _messages_with_trajectory_salt( + self, + messages: list[dict], + repeat_id: int, + conversation_id: str, + ) -> list[dict]: + salted_messages = [dict(message) for message in messages] + for message in salted_messages: + if message.get("role") != "system": + continue + content = message.get("content") + if isinstance(content, str): + repeat_salt = hashlib.blake2b( + str(repeat_id).encode("utf-8"), digest_size=2 + ).hexdigest() + conv_salt = hashlib.blake2b( + conversation_id.encode("utf-8"), digest_size=2 + ).hexdigest() + message["content"] = ( + f"[salt: {repeat_salt}]\n\n" f"{content}\n\n" f"[salt: {conv_salt}]" + ) + return salted_messages + raise InputValidationError( + "multi_turn.enable_salt requires a 
system prompt for every " + f"conversation; conversation {conversation_id!r} has no system prompt" + ) + + def _validate_salt_system_prompts(self) -> None: + """Fail before issuing if salting cannot be applied to every turn.""" + if not self._enable_salt: + return + for source_id, turns in self._base_convs: + for _idx, turn in turns: + messages = self._dataset_metadata.pre_built_messages_by_key.get( + (source_id, turn) + ) + if not messages or not any( + message.get("role") == "system" for message in messages + ): + raise InputValidationError( + "multi_turn.enable_salt requires a system prompt for every " + f"conversation; conversation {source_id!r} turn {turn} " + "has no system prompt" + ) + def _fill_slot(self) -> None: """Start a new conversation from the pending queue, or signal all done.""" # Errors here must not leave _all_done unset — that would hang execute(). try: - if self._pending_convs: + if self._stopping: + self._signal_done_if_no_inflight() + return + if self._has_more_conversation_instances(): self._start_conversation() elif not self._has_work_remaining(): assert self._all_done is not None @@ -336,6 +438,40 @@ def _fill_slot(self) -> None: if self._all_done is not None: self._all_done.set() + def _finish_conversation(self, conv_id: str) -> None: + """Mark one trajectory done and optionally stop tracking/issuing.""" + self._active_iters.pop(conv_id, None) + if not self._has_more_conversation_instances(): + self._stop_performance_tracking_once() + if self._stop_issuing_on_first_user_complete: + self._request_stop_issuing() + return + self._fill_slot() + + def _stop_performance_tracking_once(self) -> None: + if self._performance_tracking_stopped: + return + self._performance_tracking_stopped = True + if self._phase_issuer is None: + return + self._phase_issuer.stop_performance_tracking() + + def _request_stop_issuing(self) -> None: + """Stop issuing future turns and wait only for already in-flight work.""" + if self._stopping: + 
self._signal_done_if_no_inflight() + return + self._stopping = True + for handle in self._delay_handles.values(): + handle.cancel() + self._delay_handles.clear() + self._active_iters.clear() + self._signal_done_if_no_inflight() + + def _signal_done_if_no_inflight(self) -> None: + if not self._inflight and self._all_done is not None: + self._all_done.set() + def _handle_timeout(self, query_id: str, conv_id: str) -> None: """Called by the event loop when a turn response does not arrive in time.""" if self._inflight.pop(query_id, None) is None: @@ -361,9 +497,7 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: "Turn timed out for conversation %s (query=%s)", conv_id, query_id ) - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) + self._conv_manager.mark_turn_failed(conv_id) self._publish_synthetic_failure( query_id, @@ -383,7 +517,7 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: dropped, ) - self._fill_slot() + self._finish_conversation(conv_id) def _publish_synthetic_failure( self, @@ -440,13 +574,11 @@ def _abort_remaining_turns(self, conv_id: str, reason: str) -> int: state = self._active_iters.pop(conv_id, None) if state is None: return 0 - turns, cursor = state + _source_id, turns, cursor, _repeat_id = state assert self._phase_issuer is not None dropped = 0 for idx, turn in turns[cursor:]: - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) + self._conv_manager.mark_turn_failed(conv_id) skipped_id = self._phase_issuer.register_skipped( idx, conversation_id=conv_id, turn=turn ) @@ -488,28 +620,11 @@ def on_sample_complete(self, result: QueryResult) -> None: if handle is not None: handle.cancel() - output = result.response_output - if isinstance(output, TextModelOutput): - response_text: str | None = ( - "".join(output.output) - if isinstance(output.output, tuple) - else output.output - ) or None - else: - response_text = output if 
isinstance(output, str) else None - try: if result.error is not None: - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) + self._conv_manager.mark_turn_failed(conv_id) else: - self._conv_manager.mark_turn_complete( - conv_id, - response_text or "", - store_in_history=self._store_in_history, - metadata=result.metadata, - ) + self._conv_manager.mark_turn_complete(conv_id) except KeyError: self._active_iters.pop(conv_id, None) self._fill_slot() @@ -520,9 +635,9 @@ def on_sample_complete(self, result: QueryResult) -> None: ) return - # If this turn failed, abandon the rest of the conversation: replaying - # later turns against a corrupt history (assistant placeholder / - # missing tool result) is meaningless and matches the timeout path. + # If this turn failed, abandon the rest of the conversation: later + # client turns depend on the failed prior response, matching timeout + # handling. if result.error is not None: err_type = ( result.error.error_type if result.error is not None else "unknown" @@ -536,7 +651,11 @@ def on_sample_complete(self, result: QueryResult) -> None: conv_id, dropped, ) - self._fill_slot() + self._finish_conversation(conv_id) + return + + if self._stopping: + self._signal_done_if_no_inflight() return try: diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 3b80420b..f4df856e 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -31,6 +31,7 @@ from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings +from ..config.schema import LoadPatternType from ..core.record import ( ErrorEventType, EventRecord, @@ -177,6 +178,7 @@ class PhaseIssuer: "_dataset", "_issuer", "_on_inflight_drained", + "_performance_tracking_stopped", "_publisher", "_stop_check", "uuid_to_index", @@ -204,12 +206,26 @@ def __init__( self.completed_uuids: set[str] = set() 
self.inflight: int = 0 self.issued_count: int = 0 + self._performance_tracking_stopped = False def mark_inflight_complete(self) -> None: self.inflight -= 1 if self.inflight <= 0: self._on_inflight_drained() + def stop_performance_tracking(self) -> None: + """Publish STOP_PERFORMANCE_TRACKING once for this phase.""" + if self._performance_tracking_stopped: + return + self._performance_tracking_stopped = True + self._publisher.publish( + EventRecord( + event_type=SessionEventType.STOP_PERFORMANCE_TRACKING, + timestamp_ns=time.monotonic_ns(), + ) + ) + self._publisher.flush() + def issue( self, sample_index: int, @@ -225,7 +241,7 @@ def issue( sample_index: Index into the dataset. data_override: If provided, merged over the loaded sample data. Keys in data_override take precedence. Used by MultiTurnStrategy - to substitute live-accumulated message history. + to override pre-baked messages when trajectory salting is enabled. Note: load_sample() runs synchronously before the ISSUED timestamp. For accurate timing, datasets MUST be pre-loaded into memory. 
@@ -433,7 +449,7 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: await self._drain_inflight(phase_issuer, phase.drain_timeout) if phase.phase_type == PhaseType.PERFORMANCE: - self._publish_session_event(SessionEventType.STOP_PERFORMANCE_TRACKING) + phase_issuer.stop_performance_tracking() phase_end = time.monotonic_ns() logger.info( @@ -609,12 +625,17 @@ def _make_stop_check( else 0 ) total_samples = settings.total_samples_to_issue() + stop_on_sample_count = not ( + settings.load_pattern is not None + and settings.load_pattern.type == LoadPatternType.MULTI_TURN + ) def check() -> bool: if self._stop_requested: return True if ( - self._current_phase_issuer + stop_on_sample_count + and self._current_phase_issuer and self._current_phase_issuer.issued_count >= total_samples ): return True diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index 5ab6ddfa..08c4f836 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -60,7 +60,7 @@ def issue( sample_index: Index into the dataset. data_override: If provided, merged over the loaded sample — keys in data_override take precedence. Used by MultiTurnStrategy to inject - a runtime-assembled `messages` array while still inheriting + a salted `messages` array while still inheriting `model`/`max_completion_tokens`/`tools`/`stream` from the dataset row. conversation_id: Conversation identifier (multi-turn). Empty string for single-turn issues; propagated onto the published EventRecords @@ -84,6 +84,10 @@ def mark_inflight_complete(self) -> None: """Record completion of one HTTP-issued sample.""" ... + def stop_performance_tracking(self) -> None: + """Stop counting subsequently issued samples in tracked metrics.""" + ... 
+ issued_count: int diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 8834a58d..4b7bfd3b 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -19,13 +19,10 @@ correctly together against a real HTTP echo server. Tests cover: - 1. Dataset-history mode (use_dataset_history=True): pre-built messages are - issued as-is; each turn is issued sequentially per conversation. - 2. Live-history mode (use_dataset_history=False): messages are built at - runtime from ConversationManager.message_history; the injected messages - grow with each turn. - 3. Multiple concurrent conversations complete successfully. - 4. Turn ordering: turn N+1 is never issued before turn N completes. + 1. Pre-built messages are issued as-is; each turn is issued sequentially per + conversation. + 2. Multiple concurrent conversations complete successfully. + 3. Turn ordering: turn N+1 is never issued before turn N completes. """ import asyncio @@ -48,7 +45,6 @@ from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer -from inference_endpoint.exceptions import InputValidationError from inference_endpoint.load_generator.conversation_manager import ConversationManager from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy from inference_endpoint.load_generator.session import ( @@ -89,13 +85,11 @@ def _make_dataset(rows: list[dict]) -> MultiTurnDataset: def _make_strategy( ds: MultiTurnDataset, - use_dataset_history: bool = True, target_concurrency: int | None = None, inject_tool_delay: bool = False, ) -> MultiTurnStrategy: mt_cfg = MultiTurnConfig( turn_timeout_s=10.0, - use_dataset_history=use_dataset_history, inject_tool_delay=inject_tool_delay, ) assert ds.conversation_metadata is not None @@ -299,7 +293,7 @@ 
async def _handle_echo_chat_completions_request(self, request): }, ] ds = _make_dataset(rows) - strategy = _make_strategy(ds, use_dataset_history=True) + strategy = _make_strategy(ds) responses: dict = {} count = await _run_session(server.url, ds, strategy, responses) @@ -321,8 +315,8 @@ async def _handle_echo_chat_completions_request(self, request): @pytest.mark.integration @pytest.mark.asyncio -async def test_live_history_messages_grow_each_turn(echo_server): - """Live-history mode: messages array grows with each completed turn.""" +async def test_pre_built_messages_grow_each_turn(echo_server): + """Pre-built messages array grows with each dataset turn.""" received_payloads: list[dict] = [] class CapturingEchoServer(EchoServer): @@ -348,7 +342,7 @@ async def _handle_echo_chat_completions_request(self, request): {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Turn two"}, ] ds = _make_dataset(rows) - strategy = _make_strategy(ds, use_dataset_history=False) + strategy = _make_strategy(ds) responses: dict = {} count = await _run_session(server.url, ds, strategy, responses) @@ -379,7 +373,7 @@ async def test_turn_ordering_enforced_end_to_end(echo_server): {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Second"}, ] ds = _make_dataset(rows) - mt_cfg = MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=True) + mt_cfg = MultiTurnConfig(turn_timeout_s=10.0) conv_manager = ConversationManager() strategy = MultiTurnStrategy( conversation_manager=conv_manager, @@ -742,11 +736,7 @@ def failing_issue_next_turn(*args, **kwargs): @pytest.mark.integration @pytest.mark.asyncio async def test_tools_field_forwarded_to_endpoint(echo_server): - """The 'tools' array from the dataset reaches the endpoint in every request payload. - - TODO: Add a tool-call-aware server that returns dynamic tool_call_ids to - validate live-history mode with real tool_call_id round-tripping. 
- """ + """The 'tools' array from the dataset reaches the endpoint in every request payload.""" received_payloads: list[dict] = [] class CapturingEchoServer(EchoServer): @@ -808,7 +798,7 @@ async def _handle_echo_chat_completions_request(self, request): }, ] ds = _make_dataset(rows) - strategy = _make_strategy(ds, use_dataset_history=True) + strategy = _make_strategy(ds) responses: dict = {} count = await _run_session(server.url, ds, strategy, responses) @@ -823,45 +813,6 @@ async def _handle_echo_chat_completions_request(self, request): server.stop() -@pytest.mark.integration -def test_live_history_rejects_tool_turns(): - """MultiTurnStrategy raises InputValidationError at __init__ when use_dataset_history=False - and the dataset contains tool-role turns. - """ - tool_calls = [ - { - "id": "call_1", - "type": "function", - "function": {"name": "search", "arguments": '{"q": "hello"}'}, - } - ] - tool_results = [{"tool_call_id": "call_1", "content": "result"}] - rows = [ - {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Search"}, - { - "conversation_id": "c1", - "turn": 2, - "role": "assistant", - "content": None, - "tool_calls": tool_calls, - }, - { - "conversation_id": "c1", - "turn": 3, - "role": "tool", - "tool_results": tool_results, - }, - ] - ds = _make_dataset(rows) - mt_cfg = MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=False) - with pytest.raises(InputValidationError, match="use_dataset_history=True"): - MultiTurnStrategy( - conversation_manager=ConversationManager(), - dataset_metadata=ds.conversation_metadata, - multi_turn_config=mt_cfg, - ) - - def _tool_use_rows_with_delays( tool_row_delays: dict[int, float | None], ) -> list[dict]: diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py index 51c6e80d..e25bf002 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py +++ 
b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py @@ -40,7 +40,8 @@ def tokenize(self, text: str) -> list[str]: return text.split() @classmethod - def from_pretrained(cls, name: str) -> "_FakeTokenizer": + def from_pretrained(cls, name: str, **kwargs: object) -> "_FakeTokenizer": + assert kwargs == {"trust_remote_code": True} return cls() diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index a6077012..6a02879d 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -501,6 +501,17 @@ def test_multi_turn_valid_config(self): assert config.settings.load_pattern.type == LoadPatternType.MULTI_TURN assert config.settings.load_pattern.target_concurrency == 16 + @pytest.mark.unit + def test_multi_turn_rejects_removed_stop_on_first_empty_slot_as_extra(self): + # Legacy multi-turn knobs should remain rejected by extra="forbid". + with pytest.raises(ValueError, match="stop_on_first_empty_slot"): + BenchmarkConfig( + **self._make_online_multi_turn( + concurrency=16, + multi_turn={"stop_on_first_empty_slot": True}, + ) + ) + @pytest.mark.unit def test_multi_turn_requires_target_concurrency(self): with pytest.raises(ValueError, match="Multi-turn requires --concurrency"): @@ -533,6 +544,23 @@ def test_multi_turn_dataset_without_multi_turn_load_pattern_rejected(self): settings={"load_pattern": {"type": "poisson", "target_qps": 10}}, ) + @pytest.mark.unit + def test_multi_turn_rejects_runtime_num_samples_override(self): + with pytest.raises(ValueError, match="num_trajectories_to_issue"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={ + "load_pattern": { + "type": "multi_turn", + "target_concurrency": 4, + }, + "runtime": {"n_samples_to_issue": 200}, + }, + ) + class TestMultiTurnTotalSamples: """Tests for total_samples_to_issue() with multi_turn load pattern.""" 
diff --git a/tests/unit/dataset_manager/test_factory.py b/tests/unit/dataset_manager/test_factory.py index 262f9f28..c4d0ee46 100644 --- a/tests/unit/dataset_manager/test_factory.py +++ b/tests/unit/dataset_manager/test_factory.py @@ -18,12 +18,15 @@ from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.config.schema import MultiTurnConfig from inference_endpoint.dataset_manager.dataset import Dataset -from inference_endpoint.exceptions import InputValidationError +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset @pytest.mark.unit -def test_enable_salt_requires_multi_turn_loader(monkeypatch): +def test_multi_turn_config_selects_multi_turn_loader(monkeypatch): + captured: dict[str, object] = {} + def fake_load_from_file(*args, **kwargs): + captured["dataset_id"] = kwargs.get("dataset_id") return Dataset() monkeypatch.setattr( @@ -38,5 +41,6 @@ def fake_load_from_file(*args, **kwargs): multi_turn=MultiTurnConfig(enable_salt=True), ) - with pytest.raises(InputValidationError, match="enable_salt"): - factory_module.DataLoaderFactory.create_loader(config) + factory_module.DataLoaderFactory.create_loader(config) + + assert captured["dataset_id"] == MultiTurnDataset.DATASET_ID diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index 6615036d..7b4ff5bf 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -1242,8 +1242,8 @@ def test_jsonl_round_trip_with_tools_field(): @pytest.mark.unit -def test_current_turn_messages_by_key_parallel_tools(): - """current_turn_messages_by_key stores all expanded messages for a tool turn.""" +def test_pre_built_messages_parallel_tools(): + """pre_built_messages_by_key stores all expanded messages for a tool turn.""" df = pd.DataFrame( [ {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, @@ -1284,91 
+1284,18 @@ def test_current_turn_messages_by_key_parallel_tools(): ) ds = MultiTurnDataset(df) ds.load() - ctm = ds.conversation_metadata.current_turn_messages_by_key + pbm = ds.conversation_metadata.pre_built_messages_by_key # user(1) current turn is 1 message - assert len(ctm[("c1", 1)]) == 1 - assert ctm[("c1", 1)][0] == {"role": "user", "content": "Go"} - - # tool(3) current turn has 2 expanded messages (parallel tool_results) - assert len(ctm[("c1", 3)]) == 2 - assert ctm[("c1", 3)][0]["tool_call_id"] == "c_0" - assert ctm[("c1", 3)][1]["tool_call_id"] == "c_1" - - -# ============================================================================ -# Fix 1: system_prompts_by_conv in metadata (live-history mode) -# ============================================================================ - - -@pytest.mark.unit -def test_metadata_contains_system_prompts_by_conv(): - """_build_metadata exposes system_prompts_by_conv keyed by conversation_id.""" - data = [ - { - "conversation_id": "c1", - "turn": 1, - "role": "user", - "content": "Hi", - "system": "Be concise", - }, - {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Ok"}, - # c2 has no system prompt - {"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hello"}, - ] - df = pd.DataFrame(data) - ds = MultiTurnDataset(df) - ds.load() - - spc = ds.conversation_metadata.system_prompts_by_conv - assert spc["c1"] == "Be concise" - assert spc["c2"] is None - - -@pytest.mark.unit -def test_metadata_system_prompts_multiple_convs(): - """Each conversation gets its own system prompt entry.""" - data = [ - { - "conversation_id": "c1", - "turn": 1, - "role": "user", - "content": "A", - "system": "Sys1", - }, - {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, - { - "conversation_id": "c2", - "turn": 1, - "role": "user", - "content": "C", - "system": "Sys2", - }, - {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "D"}, - ] - df = pd.DataFrame(data) - 
ds = MultiTurnDataset(df) - ds.load() - - spc = ds.conversation_metadata.system_prompts_by_conv - assert spc["c1"] == "Sys1" - assert spc["c2"] == "Sys2" - - -@pytest.mark.unit -def test_enable_salt_warns_when_conversation_has_no_system_prompt(caplog): - rows = [ - {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, - {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Ok"}, - ] - df = pd.DataFrame(rows) - ds = MultiTurnDataset(df) - ds.enable_salt() - - ds.load() - - assert "cache salt not applied" in caplog.text - assert ds.conversation_metadata.system_prompts_by_conv["c1"] is None + assert len(pbm[("c1", 1)]) == 1 + assert pbm[("c1", 1)][0] == {"role": "user", "content": "Go"} + + # tool(3) pre-built history includes user, assistant tool call, and both + # expanded tool result messages. + tool_messages = pbm[("c1", 3)] + assert len(tool_messages) == 4 + assert tool_messages[2]["tool_call_id"] == "c_0" + assert tool_messages[3]["tool_call_id"] == "c_1" # ============================================================================ diff --git a/tests/unit/evaluation/test_scoring.py b/tests/unit/evaluation/test_scoring.py index 3b7a5a90..5d558cef 100644 --- a/tests/unit/evaluation/test_scoring.py +++ b/tests/unit/evaluation/test_scoring.py @@ -24,12 +24,14 @@ import pytest from inference_endpoint.core.record import EventRecord, EventType, SampleEventType from inference_endpoint.core.types import TextModelOutput +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset from inference_endpoint.dataset_manager.predefined.shopify_product_catalogue import ( ProductMetadata, ) from inference_endpoint.evaluation import scoring as scoring_mod from inference_endpoint.evaluation.scoring import ( _PRED_CATEGORY_PAD, + MultiTurnInlineScorer, Scorer, ShopifyCategoryF1Scorer, VBenchScorer, @@ -223,6 +225,310 @@ def test_score_requires_sample_index_map_and_events(self, mock_dataset, report_d ) +@pytest.mark.unit +class 
TestMultiTurnInlineScorer: + """MultiTurnInlineScorer unit tests.""" + + @staticmethod + def _bash_tool_call(call_id: str, command: str) -> dict: + return { + "id": call_id, + "type": "function", + "function": { + "name": "bash", + "arguments": json.dumps({"cmd": command}), + }, + } + + @staticmethod + def _write_report(report_dir: Path, records: list[EventRecord]) -> None: + report_dir.mkdir(parents=True, exist_ok=True) + (report_dir / "sample_idx_map.json").write_bytes( + msgspec.json.encode({"performance": {}}) + ) + encoder = msgspec.json.Encoder(enc_hook=EventType.encode_hook) + with (report_dir / "events.jsonl").open("wb") as f: + for record in records: + f.write(encoder.encode(record) + b"\n") + + def test_scores_coding_and_workflow_turns(self, tmp_path): + dataset = MultiTurnDataset( + pd.DataFrame( + [ + { + "conversation_id": "code1", + "turn": 1, + "role": "user", + "content": "run the tests", + }, + { + "conversation_id": "code1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + self._bash_tool_call("expected-code", "python -m pytest") + ], + }, + { + "conversation_id": "sim_1", + "turn": 1, + "role": "user", + "content": "choose the next workflow action", + }, + { + "conversation_id": "sim_1", + "turn": 2, + "role": "assistant", + "content": "expected workflow action", + "intent_codes": ["I042"], + }, + ] + ) + ) + report_dir = tmp_path / "report" + self._write_report( + report_dir, + [ + EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="code-response", + conversation_id="code1", + turn=1, + ), + EventRecord( + event_type=SampleEventType.COMPLETE, + sample_uuid="code-response", + conversation_id="code1", + turn=1, + data=TextModelOutput( + tool_calls=[ + self._bash_tool_call("model-code", "python test.py") + ] + ), + ), + EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="workflow-response", + conversation_id="sim_1", + turn=1, + ), + EventRecord( + event_type=SampleEventType.COMPLETE, + 
sample_uuid="workflow-response", + conversation_id="sim_1", + turn=1, + data=TextModelOutput(output="intent: I042"), + ), + ], + ) + + score, repeats = MultiTurnInlineScorer( + "performance", dataset, report_dir + ).score() + + assert score == 1.0 + assert repeats == 1 + scores = json.loads((report_dir / "scores.json").read_text()) + assert "valid" not in scores + assert scores["turns"]["scored"] == 2 + assert scores["domains"] == { + "coding": {"score": 1.0, "scored": 1}, + "workflow": {"score": 1.0, "scored": 1}, + } + + def test_turns_without_ground_truth_are_excluded(self, tmp_path): + dataset = MultiTurnDataset( + pd.DataFrame( + [ + { + "conversation_id": "code1", + "turn": 1, + "role": "user", + "content": "run the tests", + }, + { + "conversation_id": "code1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + self._bash_tool_call("expected-code", "python -m pytest") + ], + }, + { + "conversation_id": "code2", + "turn": 1, + "role": "user", + "content": "summarize the repository", + }, + { + "conversation_id": "code2", + "turn": 2, + "role": "assistant", + "content": "This turn has no bash action ground truth.", + }, + ] + ) + ) + report_dir = tmp_path / "report" + self._write_report( + report_dir, + [ + EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="code-response", + conversation_id="code1", + turn=1, + ), + EventRecord( + event_type=SampleEventType.COMPLETE, + sample_uuid="code-response", + conversation_id="code1", + turn=1, + data=TextModelOutput( + tool_calls=[ + self._bash_tool_call("model-code", "python test.py") + ] + ), + ), + EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="unscored-code-response", + conversation_id="code2", + turn=1, + ), + ], + ) + + score, repeats = MultiTurnInlineScorer( + "performance", dataset, report_dir + ).score() + + assert score == 1.0 + assert repeats == 1 + scores = json.loads((report_dir / "scores.json").read_text()) + assert "valid" not in scores + assert 
scores["turns"] == { + "issued": 2, + "expected": 1, + "observed": 1, + "missing": 0, + "scored": 1, + } + assert scores["excluded_turns"] == [ + { + "conversation_id": "code2", + "turn": 2, + "domain": "coding", + "exclude_reason": "no ground truth", + } + ] + + def test_scores_issued_turns_without_rounding_to_full_repeats(self, tmp_path): + dataset = MultiTurnDataset( + pd.DataFrame( + [ + { + "conversation_id": "code1", + "turn": 1, + "role": "user", + "content": "run the tests", + }, + { + "conversation_id": "code1", + "turn": 2, + "role": "assistant", + "tool_calls": [ + self._bash_tool_call("expected-code", "python -m pytest") + ], + }, + { + "conversation_id": "sim_1", + "turn": 1, + "role": "user", + "content": "choose the next workflow action", + }, + { + "conversation_id": "sim_1", + "turn": 2, + "role": "assistant", + "content": "expected workflow action", + "intent_codes": ["I042"], + }, + ] + ) + ) + report_dir = tmp_path / "report" + self._write_report( + report_dir, + [ + EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="code-r1", + conversation_id="code1", + turn=1, + ), + EventRecord( + event_type=SampleEventType.COMPLETE, + sample_uuid="code-r1", + conversation_id="code1", + turn=1, + data=TextModelOutput( + tool_calls=[ + self._bash_tool_call("model-code-r1", "python test.py") + ] + ), + ), + EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="workflow-r1", + conversation_id="sim_1", + turn=1, + ), + EventRecord( + event_type=SampleEventType.COMPLETE, + sample_uuid="workflow-r1", + conversation_id="sim_1", + turn=1, + data=None, + ), + EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="code-r2", + conversation_id="code1__repeat_2", + turn=1, + ), + EventRecord( + event_type=SampleEventType.COMPLETE, + sample_uuid="code-r2", + conversation_id="code1__repeat_2", + turn=1, + data=TextModelOutput( + tool_calls=[ + self._bash_tool_call("model-code-r2", "python test.py") + ] + ), + ), + ], + ) + + score, 
repeats = MultiTurnInlineScorer( + "performance", dataset, report_dir + ).score() + + assert score == 0.6667 + assert repeats == 2 + scores = json.loads((report_dir / "scores.json").read_text()) + assert scores["turns"] == { + "issued": 3, + "expected": 3, + "observed": 2, + "missing": 1, + "scored": 3, + } + + @pytest.mark.unit class TestVBenchScorerRegistration: def test_scorer_registered(self): diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py index c389fb5f..c764b5cc 100644 --- a/tests/unit/load_generator/test_multi_turn_conversation_manager.py +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -29,7 +29,6 @@ def test_conversation_state_initialization(): state = ConversationState(conversation_id="conv_001") assert state.conversation_id == "conv_001" - assert state.message_history == [] assert state.completed_turns == 0 assert state.failed_turns == 0 assert state.expected_client_turns is None @@ -86,7 +85,7 @@ def test_conversation_manager_multiple_conversations(): assert state1 is not state2 - manager.mark_turn_complete("conv_001", "Response to conv_001") + manager.mark_turn_complete("conv_001") assert state1.completed_turns == 1 assert state2.completed_turns == 0 @@ -94,26 +93,14 @@ def test_conversation_manager_multiple_conversations(): @pytest.mark.unit def test_conversation_manager_mark_turn_complete(): - """mark_turn_complete increments counter and appends history.""" + """mark_turn_complete increments the completion counter.""" manager = ConversationManager() state = manager.get_or_create("conv_001") - manager.mark_turn_complete("conv_001", "Assistant response") + manager.mark_turn_complete("conv_001") assert state.completed_turns == 1 assert state.failed_turns == 0 - assert state.message_history == [] # store_in_history=False by default - - -@pytest.mark.unit -def test_conversation_manager_mark_turn_complete_stores_history(): - 
"""mark_turn_complete appends to history when store_in_history=True.""" - manager = ConversationManager() - state = manager.get_or_create("conv_001") - - manager.mark_turn_complete("conv_001", "Hello", store_in_history=True) - - assert state.message_history == [{"role": "assistant", "content": "Hello"}] @pytest.mark.unit @@ -135,9 +122,9 @@ def test_conversation_completion_tracking(): state = manager.get_or_create("conv_001", expected_client_turns=2) assert not state.is_complete() - manager.mark_turn_complete("conv_001", "r1") + manager.mark_turn_complete("conv_001") assert not state.is_complete() - manager.mark_turn_complete("conv_001", "r2") + manager.mark_turn_complete("conv_001") assert state.is_complete() @@ -147,7 +134,7 @@ def test_conversation_completion_without_expected_turns(): manager = ConversationManager() state = manager.get_or_create("conv_001", expected_client_turns=None) - manager.mark_turn_complete("conv_001", "r1") + manager.mark_turn_complete("conv_001") assert not state.is_complete() @@ -158,13 +145,13 @@ def test_conversation_completion_with_failures(): manager = ConversationManager() state = manager.get_or_create("conv1", expected_client_turns=3) - manager.mark_turn_complete("conv1", "Hi") + manager.mark_turn_complete("conv1") assert not state.is_complete() manager.mark_turn_failed("conv1") assert not state.is_complete() - manager.mark_turn_complete("conv1", "Bye") + manager.mark_turn_complete("conv1") assert state.is_complete() assert state.failed_turns == 1 assert state.completed_turns == 3 @@ -202,7 +189,7 @@ async def process_conversation(conv_id: str): state = manager.get_state(conv_id) assert state is not None for _ in range(turns_per_conv): - manager.mark_turn_complete(conv_id, "response") + manager.mark_turn_complete(conv_id) await asyncio.sleep(0.001) except Exception as e: errors.append(f"{conv_id} error: {e}") diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py 
index 81485dc9..d1112edb 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -16,15 +16,18 @@ """Unit tests for MultiTurnStrategy.""" import asyncio +import hashlib from unittest.mock import MagicMock import pytest +from inference_endpoint.config.schema import MultiTurnConfig from inference_endpoint.core.record import ErrorEventType, SampleEventType from inference_endpoint.core.types import ErrorData, QueryResult, TextModelOutput from inference_endpoint.dataset_manager.multi_turn_dataset import ( ConversationMetadata, ConversationSampleEntry, ) +from inference_endpoint.exceptions import InputValidationError from inference_endpoint.load_generator.conversation_manager import ConversationManager from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy @@ -42,6 +45,7 @@ def __init__(self, stop_after: int | None = None): self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} self.completed_uuids: set[str] = set() self.drained = False + self.stop_tracking_count = 0 def issue( self, @@ -77,6 +81,53 @@ def mark_inflight_complete(self) -> None: if self.inflight <= 0: self.drained = True + def stop_performance_tracking(self) -> None: + self.stop_tracking_count += 1 + + +class RecordingPhaseIssuer: + """Phase issuer with unique query IDs for repeated sample indices.""" + + def __init__(self): + self.issued_count = 0 + self.issued: list[int] = [] + self.records: list[tuple[str, int, str, int | None, dict | None]] = [] + self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} + self.uuid_to_index: dict[str, int] = {} + self.completed_uuids: set[str] = set() + self.stop_tracking_count = 0 + + def issue( + self, + sample_index: int, + data_override: dict | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: + query_id = f"q{self.issued_count:04d}" + self.issued_count += 1 + self.issued.append(sample_index) + 
self.records.append( + (query_id, sample_index, conversation_id, turn, data_override) + ) + self.uuid_to_index[query_id] = sample_index + self.uuid_to_conv_info[query_id] = (conversation_id, turn) + return query_id + + def register_skipped( + self, + sample_index: int, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: + raise AssertionError("budget stops must not register skipped turns") + + def mark_inflight_complete(self) -> None: + pass + + def stop_performance_tracking(self) -> None: + self.stop_tracking_count += 1 + def _make_dataset_metadata(conversations: dict[str, list[int]]) -> ConversationMetadata: """Build ConversationMetadata from {conv_id: [turn_numbers]} mapping.""" @@ -100,6 +151,205 @@ def _make_dataset_metadata(conversations: dict[str, list[int]]) -> ConversationM ) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_first_user_complete_stops_tracking_but_can_continue_for_accuracy(): + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1], "conv2": [1, 2]}) + cfg = MultiTurnConfig(num_trajectories_to_issue=2) + strategy = MultiTurnStrategy( + conv_manager, + metadata, + multi_turn_config=cfg, + target_concurrency=2, + ) + issuer = RecordingPhaseIssuer() + + execute_task = asyncio.create_task(strategy.execute(issuer)) + await asyncio.sleep(0.01) + + assert [(idx, conv, turn) for _, idx, conv, turn, _ in issuer.records] == [ + (0, "conv1", 1), + (1, "conv2", 1), + ] + + strategy.on_sample_complete( + QueryResult(id="q0000", response_output=TextModelOutput(output="conv1")) + ) + await asyncio.sleep(0.01) + + assert issuer.stop_tracking_count == 1 + assert not execute_task.done() + + strategy.on_sample_complete( + QueryResult(id="q0001", response_output=TextModelOutput(output="conv2-turn1")) + ) + await asyncio.sleep(0.01) + + assert [(idx, conv, turn) for _, idx, conv, turn, _ in issuer.records] == [ + (0, "conv1", 1), + (1, "conv2", 1), + (2, "conv2", 2), + ] + + 
strategy.on_sample_complete( + QueryResult(id="q0002", response_output=TextModelOutput(output="conv2-turn2")) + ) + count = await asyncio.wait_for(execute_task, timeout=1.0) + + assert count == 3 + assert issuer.stop_tracking_count == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_stop_on_first_user_complete_refills_until_budget_exhausted(): + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1], "conv2": [1], "conv3": [1]}) + cfg = MultiTurnConfig( + stop_issuing_on_first_user_complete=True, + num_trajectories_to_issue=3, + ) + strategy = MultiTurnStrategy( + conv_manager, + metadata, + multi_turn_config=cfg, + target_concurrency=2, + ) + issuer = RecordingPhaseIssuer() + + execute_task = asyncio.create_task(strategy.execute(issuer)) + await asyncio.sleep(0.01) + + assert [(idx, conv, turn) for _, idx, conv, turn, _ in issuer.records] == [ + (0, "conv1", 1), + (1, "conv2", 1), + ] + + strategy.on_sample_complete( + QueryResult(id="q0000", response_output=TextModelOutput(output="conv1")) + ) + await asyncio.sleep(0.01) + + assert issuer.stop_tracking_count == 0 + assert [(idx, conv, turn) for _, idx, conv, turn, _ in issuer.records] == [ + (0, "conv1", 1), + (1, "conv2", 1), + (2, "conv3", 1), + ] + + strategy.on_sample_complete( + QueryResult(id="q0001", response_output=TextModelOutput(output="conv2")) + ) + await asyncio.sleep(0.01) + + assert issuer.stop_tracking_count == 1 + assert not execute_task.done() + + strategy.on_sample_complete( + QueryResult(id="q0002", response_output=TextModelOutput(output="conv3")) + ) + count = await asyncio.wait_for(execute_task, timeout=1.0) + + assert count == 3 + assert issuer.stop_tracking_count == 1 + assert [(idx, conv, turn) for _, idx, conv, turn, _ in issuer.records] == [ + (0, "conv1", 1), + (1, "conv2", 1), + (2, "conv3", 1), + ] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_salted_turns_use_repeat_and_conversation_salts(): + """Pre-baked prompts get 
repeat and conversation salt markers when enabled.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + base_messages = [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "hello"}, + ] + metadata.pre_built_messages_by_key = {("conv1", 1): base_messages} + cfg = MultiTurnConfig(enable_salt=True, num_trajectories_to_issue=2) + strategy = MultiTurnStrategy( + conv_manager, + metadata, + target_concurrency=1, + multi_turn_config=cfg, + ) + issuer = RecordingPhaseIssuer() + + execute_task = asyncio.create_task(strategy.execute(issuer)) + await asyncio.sleep(0.01) + strategy.on_sample_complete( + QueryResult(id="q0000", response_output=TextModelOutput(output="first")) + ) + await asyncio.sleep(0.01) + strategy.on_sample_complete( + QueryResult(id="q0001", response_output=TextModelOutput(output="repeat")) + ) + await asyncio.wait_for(execute_task, timeout=1.0) + + first_override = issuer.records[0][4] + repeat_override = issuer.records[1][4] + assert first_override is not None + assert repeat_override is not None + first_messages = first_override["messages"] + repeat_messages = repeat_override["messages"] + first_system = first_messages[0]["content"] + repeat_system = repeat_messages[0]["content"] + + repeat1_salt = hashlib.blake2b(b"1", digest_size=2).hexdigest() + repeat2_salt = hashlib.blake2b(b"2", digest_size=2).hexdigest() + conversation_salt = hashlib.blake2b(b"conv1", digest_size=2).hexdigest() + assert first_system == ( + f"[salt: {repeat1_salt}]\n\n" f"Be helpful\n\n" f"[salt: {conversation_salt}]" + ) + assert repeat_system == ( + f"[salt: {repeat2_salt}]\n\n" f"Be helpful\n\n" f"[salt: {conversation_salt}]" + ) + assert repeat_system != base_messages[0]["content"] + assert base_messages[0]["content"] == "Be helpful" + + +@pytest.mark.unit +def test_enable_salt_requires_system_prompt(): + """Salting is invalid when the conversation has no system prompt.""" + conv_manager = 
ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + cfg = MultiTurnConfig(enable_salt=True) + strategy = MultiTurnStrategy( + conv_manager, + metadata, + multi_turn_config=cfg, + ) + + with pytest.raises(InputValidationError, match="no system prompt"): + strategy._messages_with_trajectory_salt( + [{"role": "user", "content": "hello"}], + repeat_id=1, + conversation_id="conv1", + ) + + +@pytest.mark.unit +def test_enable_salt_requires_pre_built_messages(): + """Salting is invalid when metadata has no pre-built messages for a turn.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + metadata.pre_built_messages_by_key = {} + cfg = MultiTurnConfig(enable_salt=True) + strategy = MultiTurnStrategy( + conv_manager, + metadata, + multi_turn_config=cfg, + ) + + with pytest.raises(InputValidationError, match="pre-built messages"): + strategy._build_data_override("conv1", 1, repeat_id=1) + + @pytest.mark.unit @pytest.mark.asyncio async def test_single_conversation_single_turn(): @@ -216,6 +466,9 @@ def issue( self.issued_count += 1 return f"q{idx:04d}" + def stop_performance_tracking(self) -> None: + pass + issuer = TimedIssuer() async def simulate_responses(): @@ -302,125 +555,6 @@ async def test_error_response_marks_turn_failed(): assert state.failed_turns == 1 -def _make_metadata_with_system( - conversations: dict[str, list[int]], - system_prompts: dict[str, str | None] | None = None, -) -> ConversationMetadata: - """Build ConversationMetadata including system_prompts_by_conv.""" - samples = [] - sample_index = 0 - for conv_id, turns in conversations.items(): - for turn in turns: - samples.append( - ConversationSampleEntry( - conversation_id=conv_id, - turn=turn, - sample_index=sample_index, - ) - ) - sample_index += 1 - return ConversationMetadata( - samples=samples, - num_conversations=len(conversations), - max_turns_per_conv=max((max(t) for t in conversations.values()), default=0), - 
client_turns_per_conversation={c: len(t) for c, t in conversations.items()}, - system_prompts_by_conv=system_prompts or {}, - ) - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_live_history_initializes_system_prompt(): - """In live-history mode, ConversationManager.message_history starts with system message.""" - from inference_endpoint.config.schema import MultiTurnConfig - - conv_manager = ConversationManager() - metadata = _make_metadata_with_system( - {"conv1": [1]}, - system_prompts={"conv1": "Be helpful"}, - ) - mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) - strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) - issuer = FakePhaseIssuer() - - async def complete_turn(): - await asyncio.sleep(0.01) - result = QueryResult( - id="q0000", response_output=TextModelOutput(output="response") - ) - strategy.on_sample_complete(result) - - asyncio.create_task(complete_turn()) - await strategy.execute(issuer) - - state = conv_manager.get_state("conv1") - assert state is not None - # message_history[0] must be the system message - assert len(state.message_history) >= 1 - assert state.message_history[0] == {"role": "system", "content": "Be helpful"} - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_live_history_no_system_prompt_when_none(): - """In live-history mode, no system message is prepended when system_prompt is None.""" - from inference_endpoint.config.schema import MultiTurnConfig - - conv_manager = ConversationManager() - metadata = _make_metadata_with_system( - {"conv1": [1]}, - system_prompts={"conv1": None}, - ) - mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) - strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) - issuer = FakePhaseIssuer() - - async def complete_turn(): - await asyncio.sleep(0.01) - result = QueryResult( - id="q0000", response_output=TextModelOutput(output="response") - ) - strategy.on_sample_complete(result) - - 
asyncio.create_task(complete_turn()) - await strategy.execute(issuer) - - state = conv_manager.get_state("conv1") - assert state is not None - # No system message should be in history - system_msgs = [m for m in state.message_history if m.get("role") == "system"] - assert len(system_msgs) == 0 - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_dataset_history_mode_does_not_inject_system_prompt(): - """In dataset-history mode (use_dataset_history=True), system_message is not passed.""" - conv_manager = ConversationManager() - metadata = _make_metadata_with_system( - {"conv1": [1]}, - system_prompts={"conv1": "Some system"}, - ) - # Default: use_dataset_history=True → _store_in_history=False - strategy = MultiTurnStrategy(conv_manager, metadata) - issuer = FakePhaseIssuer() - - async def complete_turn(): - await asyncio.sleep(0.01) - result = QueryResult( - id="q0000", response_output=TextModelOutput(output="response") - ) - strategy.on_sample_complete(result) - - asyncio.create_task(complete_turn()) - await strategy.execute(issuer) - - state = conv_manager.get_state("conv1") - assert state is not None - # message_history should be empty (dataset-history mode doesn't accumulate) - assert len(state.message_history) == 0 - - @pytest.mark.unit @pytest.mark.asyncio async def test_pipeline_error_propagated(): @@ -442,115 +576,13 @@ def issue( ) -> str | None: raise RuntimeError("simulated pipeline error") + def stop_performance_tracking(self) -> None: + pass + with pytest.raises(RuntimeError, match="simulated pipeline error"): await strategy.execute(ErrorIssuer()) -@pytest.mark.unit -def test_mark_turn_complete_preserves_tool_calls(): - """mark_turn_complete stores tool_calls in history when metadata contains them.""" - conv_manager = ConversationManager() - conv_manager.get_or_create("conv1", expected_client_turns=1) - - tool_calls = [ - { - "id": "call_1", - "type": "function", - "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, - } - ] - 
conv_manager.mark_turn_complete( - "conv1", - response="", - store_in_history=True, - metadata={"tool_calls": tool_calls}, - ) - - state = conv_manager.get_state("conv1") - assert state is not None - assert len(state.message_history) == 1 - msg = state.message_history[0] - assert msg["role"] == "assistant" - assert msg["content"] is None - assert msg["tool_calls"] == tool_calls - - -@pytest.mark.unit -def test_mark_turn_complete_with_response_and_tool_calls(): - """mark_turn_complete stores both content and tool_calls when both are present.""" - conv_manager = ConversationManager() - conv_manager.get_or_create("conv1", expected_client_turns=1) - - tool_calls = [ - { - "id": "call_1", - "type": "function", - "function": {"name": "search", "arguments": "{}"}, - } - ] - conv_manager.mark_turn_complete( - "conv1", - response="Calling search...", - store_in_history=True, - metadata={"tool_calls": tool_calls}, - ) - - state = conv_manager.get_state("conv1") - assert state is not None - msg = state.message_history[0] - assert msg["content"] == "Calling search..." 
- assert msg["tool_calls"] == tool_calls - - -@pytest.mark.unit -def test_mark_turn_complete_no_history_when_empty(): - """mark_turn_complete does not append when response is empty and no tool_calls.""" - conv_manager = ConversationManager() - conv_manager.get_or_create("conv1", expected_client_turns=1) - - conv_manager.mark_turn_complete("conv1", response="", store_in_history=True) - - state = conv_manager.get_state("conv1") - assert state is not None - assert len(state.message_history) == 0 - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_on_sample_complete_passes_metadata(): - """on_sample_complete forwards result.metadata (including tool_calls) to ConversationManager.""" - from inference_endpoint.config.schema import MultiTurnConfig - - conv_manager = ConversationManager() - metadata_dict = _make_metadata_with_system({"conv1": [1]}) - mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) - strategy = MultiTurnStrategy(conv_manager, metadata_dict, multi_turn_config=mt_cfg) - - conv_manager.get_or_create("conv1", expected_client_turns=1) - strategy._inflight["q0001"] = "conv1" - - tool_calls = [ - { - "id": "call_1", - "type": "function", - "function": {"name": "bash", "arguments": "{}"}, - } - ] - result = QueryResult( - id="q0001", - response_output=TextModelOutput(output=""), - metadata={"tool_calls": tool_calls}, - ) - strategy.on_sample_complete(result) - - state = conv_manager.get_state("conv1") - assert state is not None - assert state.completed_turns == 1 - assert len(state.message_history) == 1 - assert state.message_history[0]["tool_calls"] == tool_calls - assert state.message_history[0]["content"] is None - - @pytest.mark.unit @pytest.mark.asyncio async def test_concurrency_limits_active_conversations(): @@ -645,7 +677,7 @@ async def test_timeout_publishes_error_and_complete_events(): # Seed: turn 1 in-flight, turns 2+3 still pending strategy._inflight["q-x"] = "conv-x" - strategy._active_iters["conv-x"] = ([(1, 2), (2, 
3)], 0) + strategy._active_iters["conv-x"] = ("conv-x", [(1, 2), (2, 3)], 0, 1) issuer = FakePhaseIssuer() issuer.uuid_to_index["q-x"] = 0 @@ -742,7 +774,7 @@ async def test_abort_remaining_turns_includes_pending_delayed_turn(): strategy._phase_issuer = issuer strategy._session_publisher = publisher strategy._session_on_sample_complete = on_sample_complete - strategy._active_iters["c1"] = ([(0, 1), (1, 2), (2, 3)], 1) + strategy._active_iters["c1"] = ("c1", [(0, 1), (1, 2), (2, 3)], 1, 1) strategy._issue_next_turn("c1") @@ -886,6 +918,9 @@ def issue(self, idx, data_override=None, conversation_id="", turn=None): # Raises on the second call, which is triggered by _fill_slot after conv1 completes. raise RuntimeError("simulated slot-refill failure") + def stop_performance_tracking(self) -> None: + pass + issuer = RaisingIssuer() async def complete_conv1(): @@ -918,10 +953,9 @@ async def test_error_turn_aborts_remaining_turns(): strategy._phase_issuer = issuer # Seed: conv1 is active with turns 2 and 3 still pending - remaining_turns = ([(1, 2), (2, 3)], 0) + remaining_turns = ("conv1", [(1, 2), (2, 3)], 0, 1) strategy._active_iters["conv1"] = remaining_turns strategy._inflight["q0001"] = "conv1" - strategy._conv_states["conv1"] = conv_manager.get_state("conv1") result = QueryResult( id="q0001",