diff --git a/chatcompletion.go b/chatcompletion.go index 962b69d0..e6cc1bf0 100644 --- a/chatcompletion.go +++ b/chatcompletion.go @@ -2742,6 +2742,15 @@ func (u ChatCompletionToolChoiceOptionUnionParam) GetType() *string { return nil } +func init() { + apijson.RegisterUnion[ChatCompletionToolChoiceOptionUnionParam]( + "type", + apijson.Discriminator[ChatCompletionAllowedToolChoiceParam]("allowed_tools"), + apijson.Discriminator[ChatCompletionNamedToolChoiceParam]("function"), + apijson.Discriminator[ChatCompletionNamedToolChoiceCustomParam]("custom"), + ) +} + // `none` means the model will not call any tool and instead generates a message. // `auto` means the model can pick between generating a message or calling one or // more tools. `required` means the model must call one or more tools. diff --git a/internal/apijson/union.go b/internal/apijson/union.go index ce0b6f25..075f87d5 100644 --- a/internal/apijson/union.go +++ b/internal/apijson/union.go @@ -2,9 +2,10 @@ package apijson import ( "errors" - "github.com/openai/openai-go/v3/packages/param" "reflect" + "github.com/openai/openai-go/v3/packages/param" + "github.com/tidwall/gjson" ) @@ -66,7 +67,11 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc { for _, variant := range unionEntry.variants { // For each union variant, find a matching decoder and save it for _, decoder := range decoders { - if decoder.field.Type.Elem() == variant.Type { + fieldType := decoder.field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + if fieldType == variant.Type { discriminatedDecoders = append(discriminatedDecoders, discriminatedDecoder{ decoder, variant.DiscriminatorValue,