Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions go/ai/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,15 @@ func TestGenerateAction(t *testing.T) {
t.Fatalf("action failed: %v", err)
}

if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" {
if diff := cmp.Diff(tc.ExpectChunks, chunks, cmp.Options{
cmpopts.IgnoreFields(ModelResponseChunk{}, "formatHandler"),
}); diff != "" {
t.Errorf("chunks mismatch (-want +got):\n%s", diff)
}

if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{
cmpopts.EquateEmpty(),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"),
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
cmpopts.IgnoreFields(ToolDefinition{}, "Metadata"),
}); diff != "" {
Expand All @@ -156,7 +158,7 @@ func TestGenerateAction(t *testing.T) {

if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{
cmpopts.EquateEmpty(),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"),
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
cmpopts.IgnoreFields(ToolDefinition{}, "Metadata"),
}); diff != "" {
Expand Down
70 changes: 30 additions & 40 deletions go/ai/format_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ package ai

import (
"encoding/json"
"errors"
"fmt"
"strings"

"github.com/firebase/genkit/go/internal/base"
)
Expand All @@ -45,6 +43,7 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) {
handler := &arrayHandler{
instructions: instructions,
config: ModelOutputConfig{
Constrained: true,
Format: OutputFormatArray,
Schema: schema,
ContentType: "application/json",
Expand All @@ -55,58 +54,49 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) {
}

type arrayHandler struct {
instructions string
config ModelOutputConfig
instructions string
config ModelOutputConfig
accumulatedText string
currentIndex int
cursor int
}

// Instructions returns the instructions for the formatter.
func (a arrayHandler) Instructions() string {
func (a *arrayHandler) Instructions() string {
return a.instructions
}

// Config returns the output config for the formatter.
func (a arrayHandler) Config() ModelOutputConfig {
func (a *arrayHandler) Config() ModelOutputConfig {
return a.config
}

// ParseMessage parses the message and returns the formatted message.
func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {
if a.config.Format == OutputFormatArray {
if m == nil {
return nil, errors.New("message is empty")
}
if len(m.Content) == 0 {
return nil, errors.New("message has no content")
}

var nonTextParts []*Part
accumulatedText := strings.Builder{}
// ParseOutput parses the final message and returns the parsed array.
//
// The message's text is handed to base.ExtractItems with 0 as its second
// argument, and the Items field of the result is returned. The returned
// error is always nil.
func (a *arrayHandler) ParseOutput(m *Message) (any, error) {
	text := m.Text()
	return base.ExtractItems(text, 0).Items, nil
}

for _, part := range m.Content {
if !part.IsText() {
nonTextParts = append(nonTextParts, part)
} else {
accumulatedText.WriteString(part.Text)
}
}
// ParseChunk processes a streaming chunk and returns parsed output.
func (a *arrayHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) {
if chunk.Index != a.currentIndex {
a.accumulatedText = ""
a.currentIndex = chunk.Index
a.cursor = 0
}

var newParts []*Part
lines := base.GetJSONObjectLines(accumulatedText.String())
for _, line := range lines {
var schemaBytes []byte
schemaBytes, err := json.Marshal(a.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}

newParts = append(newParts, NewJSONPart(line))
for _, part := range chunk.Content {
if part.IsText() {
a.accumulatedText += part.Text
}

m.Content = append(newParts, nonTextParts...)
}

result := base.ExtractItems(a.accumulatedText, a.cursor)
a.cursor = result.Cursor
return result.Items, nil
}

// ParseMessage returns the message it is given unchanged, together with a
// nil error; this handler does not rewrite the message content here.
func (a *arrayHandler) ParseMessage(m *Message) (*Message, error) {
return m, nil
}
101 changes: 85 additions & 16 deletions go/ai/format_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"regexp"
"slices"
"strings"

"github.com/firebase/genkit/go/core"
)

type enumFormatter struct{}
Expand All @@ -33,14 +35,15 @@ func (e enumFormatter) Name() string {
func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
enums := objectEnums(schema)
if schema == nil || len(enums) == 0 {
return nil, fmt.Errorf("schema is not valid JSON enum")
return nil, core.NewError(core.INVALID_ARGUMENT, "schema must be an object with an 'enum' property for enum format")
}

instructions := fmt.Sprintf("Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n```%s```", strings.Join(enums, "\n"))

handler := &enumHandler{
instructions: instructions,
config: ModelOutputConfig{
Constrained: true,
Format: OutputFormatEnum,
Schema: schema,
ContentType: "text/enum",
Expand All @@ -52,23 +55,49 @@ func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
}

type enumHandler struct {
instructions string
config ModelOutputConfig
enums []string
instructions string
config ModelOutputConfig
enums []string
accumulatedText string
currentIndex int
}

// Instructions returns the instructions for the formatter.
func (e enumHandler) Instructions() string {
func (e *enumHandler) Instructions() string {
return e.instructions
}

// Config returns the output config for the formatter.
func (e enumHandler) Config() ModelOutputConfig {
func (e *enumHandler) Config() ModelOutputConfig {
return e.config
}

// ParseOutput extracts the enum value from the final message.
//
// The message's text is run through parseEnum; the resulting value and any
// error parseEnum reports are returned as-is.
func (e *enumHandler) ParseOutput(m *Message) (any, error) {
	value, err := e.parseEnum(m.Text())
	return value, err
}

// ParseChunk folds a streaming chunk into the handler's accumulated text and
// returns the enum parsed, on a best-effort basis, from everything
// accumulated so far.
//
// When the chunk's Index differs from the index currently being tracked, the
// accumulated text is discarded and tracking switches to the new index. Only
// text parts of the chunk contribute to the accumulated text. The returned
// error is always nil.
func (e *enumHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) {
	if e.currentIndex != chunk.Index {
		e.currentIndex = chunk.Index
		e.accumulatedText = ""
	}

	for _, p := range chunk.Content {
		if !p.IsText() {
			continue
		}
		e.accumulatedText += p.Text
	}

	// Parsing is best effort while streaming, so a parse error is deliberately
	// dropped; parseEnum yields "" in that case.
	parsed, _ := e.parseEnum(e.accumulatedText)
	return parsed, nil
}

// ParseMessage parses the message and returns the formatted message.
func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
func (e *enumHandler) ParseMessage(m *Message) (*Message, error) {
if e.config.Format == OutputFormatEnum {
if m == nil {
return nil, errors.New("message is empty")
Expand Down Expand Up @@ -107,23 +136,63 @@ func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
return m, nil
}

// Get enum strings from json schema
// Get enum strings from json schema.
// Supports both top-level enum (e.g. {"type": "string", "enum": ["a", "b"]})
// and nested property enum (e.g. {"properties": {"value": {"enum": ["a", "b"]}}}).
func objectEnums(schema map[string]any) []string {
var enums []string
if enums := extractEnumStrings(schema["enum"]); len(enums) > 0 {
return enums
}

if properties, ok := schema["properties"].(map[string]any); ok {
for _, propValue := range properties {
if propMap, ok := propValue.(map[string]any); ok {
if enumSlice, ok := propMap["enum"].([]any); ok {
for _, enumVal := range enumSlice {
if enumStr, ok := enumVal.(string); ok {
enums = append(enums, enumStr)
}
}
if enums := extractEnumStrings(propMap["enum"]); len(enums) > 0 {
return enums
}
}
}
}

return enums
return nil
}

// extractEnumStrings converts the value of a schema "enum" field into a
// slice of strings.
//
// A []string (as built in Go code) is returned as-is, without copying.
// A []any (as produced by JSON decoding) is reduced to its string elements;
// elements of any other type are dropped. Every other value, including nil,
// yields nil.
func extractEnumStrings(v any) []string {
	switch vals := v.(type) {
	case []string:
		return vals
	case []any:
		out := make([]string, 0, len(vals))
		for i := range vals {
			str, isString := vals[i].(string)
			if !isString {
				continue
			}
			out = append(out, str)
		}
		return out
	default:
		return nil
	}
}

// enumQuotePattern matches the single and double quote characters that are
// stripped from model output before it is compared against the valid enums.
// It is compiled once at package initialization rather than on every
// parseEnum call (parseEnum runs once per streaming chunk).
var enumQuotePattern = regexp.MustCompile(`['"]`)

// parseEnum is the shared parsing logic used by both ParseOutput and ParseChunk.
//
// It removes every single and double quote from text, trims surrounding
// whitespace, and checks the result against e.enums. Empty text yields
// ("", nil). Text that does not match any valid enum yields "" and an error
// listing the valid values.
func (e *enumHandler) parseEnum(text string) (string, error) {
	if text == "" {
		return "", nil
	}

	trimmed := strings.TrimSpace(enumQuotePattern.ReplaceAllString(text, ""))
	if !slices.Contains(e.enums, trimmed) {
		return "", fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
	}

	return trimmed, nil
}
Loading
Loading