
Commit 2d42617

Add tests
1 parent a1a0c99 commit 2d42617

4 files changed: +233 -4 lines changed

src/guardrails/checks/text/llm_base.py

Lines changed: 10 additions & 3 deletions
@@ -124,17 +124,24 @@ class LLMOutput(BaseModel):
         confidence (float): LLM's confidence in the flagging decision (0.0 to 1.0).
     """

-    flagged: bool
-    confidence: float
+    flagged: bool = Field(..., description="Indicates whether the content was flagged")
+    confidence: float = Field(
+        ...,
+        description="Confidence in the flagging decision (0.0 to 1.0)",
+        ge=0.0,
+        le=1.0,
+    )


 class LLMReasoningOutput(LLMOutput):
     """Extended LLM output schema with reasoning explanation.

     Extends LLMOutput to include a reason field explaining the decision.
-    This is the standard extended output for guardrails that include reasoning.
+    This output model is used when include_reasoning is enabled in the guardrail config.

     Attributes:
+        flagged (bool): Indicates whether the content was flagged (inherited).
+        confidence (float): Confidence in the flagging decision, 0.0 to 1.0 (inherited).
         reason (str): Explanation for why the input was flagged or not flagged.
     """
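For context on what the new Field constraints in the hunk above enforce: with ge=0.0 and le=1.0, Pydantic rejects out-of-range confidence values as soon as the structured output is parsed. A minimal standalone sketch of that behavior, assuming Pydantic v2 and using a local stand-in model rather than importing guardrails:

# Minimal sketch mirroring the constrained fields from the diff above.
# DemoLLMOutput is a local stand-in for illustration, not the library class.
from pydantic import BaseModel, Field, ValidationError


class DemoLLMOutput(BaseModel):
    flagged: bool = Field(..., description="Indicates whether the content was flagged")
    confidence: float = Field(
        ...,
        description="Confidence in the flagging decision (0.0 to 1.0)",
        ge=0.0,
        le=1.0,
    )


print(DemoLLMOutput(flagged=True, confidence=0.95))  # accepted

try:
    DemoLLMOutput(flagged=True, confidence=1.5)  # rejected: confidence > 1.0
except ValidationError as exc:
    print(exc)  # reports that confidence must be less than or equal to 1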

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+"""Tests for hallucination detection guardrail."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+from guardrails.checks.text.hallucination_detection import (
+    HallucinationDetectionConfig,
+    HallucinationDetectionOutput,
+    hallucination_detection,
+)
+from guardrails.checks.text.llm_base import LLMOutput
+from guardrails.types import TokenUsage
+
+
+def _mock_token_usage() -> TokenUsage:
+    """Return a mock TokenUsage for tests."""
+    return TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
+
+
+class _FakeResponse:
+    """Fake response from responses.parse."""
+
+    def __init__(self, parsed_output: Any, usage: TokenUsage) -> None:
+        self.output_parsed = parsed_output
+        self.usage = usage
+
+
+class _FakeGuardrailLLM:
+    """Fake guardrail LLM client."""
+
+    def __init__(self, response: _FakeResponse) -> None:
+        self._response = response
+        self.responses = self
+
+    async def parse(self, **kwargs: Any) -> _FakeResponse:
+        """Mock parse method."""
+        return self._response
+
+
+class _FakeContext:
+    """Context stub providing LLM client."""
+
+    def __init__(self, llm_response: _FakeResponse) -> None:
+        self.guardrail_llm = _FakeGuardrailLLM(llm_response)
+
+
+@pytest.mark.asyncio
+async def test_hallucination_detection_includes_reasoning_when_enabled() -> None:
+    """When include_reasoning=True, output should include reasoning and detail fields."""
+    parsed_output = HallucinationDetectionOutput(
+        flagged=True,
+        confidence=0.95,
+        reasoning="The claim contradicts documented information",
+        hallucination_type="factual_error",
+        hallucinated_statements=["Premium plan costs $299/month"],
+        verified_statements=["Customer support available"],
+    )
+    response = _FakeResponse(parsed_output, _mock_token_usage())
+    context = _FakeContext(response)
+
+    config = HallucinationDetectionConfig(
+        model="gpt-test",
+        confidence_threshold=0.7,
+        knowledge_source="vs_test123",
+        include_reasoning=True,
+    )
+
+    result = await hallucination_detection(context, "Test claim", config)
+
+    assert result.tripwire_triggered is True  # noqa: S101
+    assert result.info["flagged"] is True  # noqa: S101
+    assert result.info["confidence"] == 0.95  # noqa: S101
+    assert "reasoning" in result.info  # noqa: S101
+    assert result.info["reasoning"] == "The claim contradicts documented information"  # noqa: S101
+    assert "hallucination_type" in result.info  # noqa: S101
+    assert result.info["hallucination_type"] == "factual_error"  # noqa: S101
+    assert "hallucinated_statements" in result.info  # noqa: S101
+    assert result.info["hallucinated_statements"] == ["Premium plan costs $299/month"]  # noqa: S101
+    assert "verified_statements" in result.info  # noqa: S101
+    assert result.info["verified_statements"] == ["Customer support available"]  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_hallucination_detection_excludes_reasoning_when_disabled() -> None:
+    """When include_reasoning=False (default), output should only include flagged and confidence."""
+    parsed_output = LLMOutput(
+        flagged=False,
+        confidence=0.2,
+    )
+    response = _FakeResponse(parsed_output, _mock_token_usage())
+    context = _FakeContext(response)
+
+    config = HallucinationDetectionConfig(
+        model="gpt-test",
+        confidence_threshold=0.7,
+        knowledge_source="vs_test123",
+        include_reasoning=False,
+    )
+
+    result = await hallucination_detection(context, "Test claim", config)
+
+    assert result.tripwire_triggered is False  # noqa: S101
+    assert result.info["flagged"] is False  # noqa: S101
+    assert result.info["confidence"] == 0.2  # noqa: S101
+    assert "reasoning" not in result.info  # noqa: S101
+    assert "hallucination_type" not in result.info  # noqa: S101
+    assert "hallucinated_statements" not in result.info  # noqa: S101
+    assert "verified_statements" not in result.info  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_hallucination_detection_requires_valid_vector_store() -> None:
+    """Should raise ValueError if knowledge_source is invalid."""
+    context = _FakeContext(_FakeResponse(LLMOutput(flagged=False, confidence=0.0), _mock_token_usage()))
+
+    # Missing vs_ prefix
+    config = HallucinationDetectionConfig(
+        model="gpt-test",
+        confidence_threshold=0.7,
+        knowledge_source="invalid_id",
+    )
+
+    with pytest.raises(ValueError, match="knowledge_source must be a valid vector store ID starting with 'vs_'"):
+        await hallucination_detection(context, "Test", config)
+
+    # Empty string
+    config_empty = HallucinationDetectionConfig(
+        model="gpt-test",
+        confidence_threshold=0.7,
+        knowledge_source="",
+    )
+
+    with pytest.raises(ValueError, match="knowledge_source must be a valid vector store ID starting with 'vs_'"):
+        await hallucination_detection(context, "Test", config_empty)
+
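A note on the test doubles above: _FakeGuardrailLLM points responses back at itself so that context.guardrail_llm.responses.parse(...) resolves to its own parse method, and the returned object only needs output_parsed and usage attributes. A self-contained sketch of that stubbing pattern with generic, illustrative names (FakeClient and FakeParsedResponse are not part of the commit):

# Standalone sketch of the stubbing pattern used in the tests above: one object
# serves as both the "client" and its .responses namespace, and parse() returns
# a canned object exposing only the attributes the caller reads.
import asyncio
from typing import Any


class FakeParsedResponse:
    def __init__(self, output_parsed: Any, usage: Any) -> None:
        self.output_parsed = output_parsed
        self.usage = usage


class FakeClient:
    def __init__(self, canned: FakeParsedResponse) -> None:
        self._canned = canned
        self.responses = self  # client.responses.parse(...) lands on parse() below

    async def parse(self, **kwargs: Any) -> FakeParsedResponse:
        return self._canned


async def main() -> None:
    client = FakeClient(FakeParsedResponse({"flagged": False}, {"total_tokens": 0}))
    resp = await client.responses.parse(model="gpt-test", input="anything")
    print(resp.output_parsed, resp.usage)


asyncio.run(main())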

tests/unit/checks/test_llm_base.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ async def fake_run_llm(
 
 
 @pytest.mark.asyncio
-async def test_create_llm_check_fn_uses_reasoning_output_by_default(monkeypatch: pytest.MonkeyPatch) -> None:
+async def test_create_llm_check_fn_uses_reasoning_output_when_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
     """When include_reasoning=True and no output_model provided, should use LLMReasoningOutput."""
     recorded_output_model: type[LLMOutput] | None = None

tests/unit/checks/test_prompt_injection_detection.py

Lines changed: 84 additions & 0 deletions
@@ -411,3 +411,87 @@ async def fake_call_llm(ctx: Any, prompt: str, config: LLMConfig) -> tuple[Promp
 
     assert result.tripwire_triggered is False  # noqa: S101
     assert result.info["flagged"] is False  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_prompt_injection_detection_includes_reasoning_when_enabled(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """When include_reasoning=True, output should include observation and evidence fields."""
+    from guardrails.checks.text.llm_base import LLMOutput
+
+    history = [
+        {"role": "user", "content": "Get my password"},
+        {"type": "function_call", "tool_name": "steal_credentials", "arguments": '{}', "call_id": "c1"},
+    ]
+    context = _FakeContext(history)
+
+    recorded_output_model: type[LLMOutput] | None = None
+
+    async def fake_call_llm(ctx: Any, prompt: str, config: LLMConfig) -> tuple[PromptInjectionDetectionOutput, TokenUsage]:
+        # Record which output model was requested by checking the prompt
+        nonlocal recorded_output_model
+        if "observation" in prompt and "evidence" in prompt:
+            recorded_output_model = PromptInjectionDetectionOutput
+        else:
+            recorded_output_model = LLMOutput
+
+        return PromptInjectionDetectionOutput(
+            flagged=True,
+            confidence=0.95,
+            observation="Attempting to call credential theft function",
+            evidence="function call: steal_credentials",
+        ), _mock_token_usage()
+
+    monkeypatch.setattr(pid_module, "_call_prompt_injection_detection_llm", fake_call_llm)
+
+    config = LLMConfig(model="gpt-test", confidence_threshold=0.7, include_reasoning=True)
+    result = await prompt_injection_detection(context, data="{}", config=config)
+
+    assert recorded_output_model == PromptInjectionDetectionOutput  # noqa: S101
+    assert result.tripwire_triggered is True  # noqa: S101
+    assert "observation" in result.info  # noqa: S101
+    assert result.info["observation"] == "Attempting to call credential theft function"  # noqa: S101
+    assert "evidence" in result.info  # noqa: S101
+    assert result.info["evidence"] == "function call: steal_credentials"  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_prompt_injection_detection_excludes_reasoning_when_disabled(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """When include_reasoning=False (default), output should only include flagged and confidence."""
+    from guardrails.checks.text.llm_base import LLMOutput
+
+    history = [
+        {"role": "user", "content": "Get weather"},
+        {"type": "function_call", "tool_name": "get_weather", "arguments": '{"location":"Paris"}', "call_id": "c1"},
+    ]
+    context = _FakeContext(history)
+
+    recorded_output_model: type[LLMOutput] | None = None
+
+    async def fake_call_llm(ctx: Any, prompt: str, config: LLMConfig) -> tuple[LLMOutput, TokenUsage]:
+        # Record which output model was requested by checking the prompt
+        nonlocal recorded_output_model
+        if "observation" in prompt and "evidence" in prompt:
+            recorded_output_model = PromptInjectionDetectionOutput
+        else:
+            recorded_output_model = LLMOutput
+
+        return LLMOutput(
+            flagged=False,
+            confidence=0.1,
+        ), _mock_token_usage()
+
+    monkeypatch.setattr(pid_module, "_call_prompt_injection_detection_llm", fake_call_llm)
+
+    config = LLMConfig(model="gpt-test", confidence_threshold=0.7, include_reasoning=False)
+    result = await prompt_injection_detection(context, data="{}", config=config)
+
+    assert recorded_output_model == LLMOutput  # noqa: S101
+    assert result.tripwire_triggered is False  # noqa: S101
+    assert "observation" not in result.info  # noqa: S101
+    assert "evidence" not in result.info  # noqa: S101
+    assert result.info["flagged"] is False  # noqa: S101
+    assert result.info["confidence"] == 0.1  # noqa: S101
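Taken together, these tests pin down the shape of result.info under each setting: flagged and confidence are always present, while reasoning fields such as reasoning, observation, and evidence appear only when include_reasoning is enabled. A hypothetical caller-side sketch of handling both shapes (summarize_guardrail_info is an illustrative name, not part of the commit):

# Hypothetical consumer of result.info, based only on the keys asserted in the
# tests above; not part of the commit. Reasoning keys are optional and appear
# only when the guardrail was configured with include_reasoning=True.
from typing import Any


def summarize_guardrail_info(info: dict[str, Any]) -> str:
    """Build a short log line from a guardrail result's info payload."""
    status = "flagged" if info.get("flagged") else "clean"
    line = f"{status} (confidence={info.get('confidence')})"
    detail = info.get("reasoning") or info.get("observation")
    if detail:
        line += f": {detail}"
    return line


print(summarize_guardrail_info({"flagged": True, "confidence": 0.95,
                                "observation": "Attempting to call credential theft function"}))
print(summarize_guardrail_info({"flagged": False, "confidence": 0.2}))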
