Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import builtins
import ipaddress
import uuid
import warnings
import weakref
from collections.abc import Mapping, Sequence, Set
from datetime import date, datetime, time, timedelta
Expand Down Expand Up @@ -214,6 +215,7 @@ def Field(
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
const: Optional[bool] = None,
coerce_numbers_to_str: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
Expand All @@ -226,9 +228,12 @@ def Field(
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
fail_fast: Optional[bool] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
validate_default: Optional[bool] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,
Expand Down Expand Up @@ -257,6 +262,7 @@ def Field(
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
const: Optional[bool] = None,
coerce_numbers_to_str: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
Expand All @@ -269,9 +275,12 @@ def Field(
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
fail_fast: Optional[bool] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
validate_default: Optional[bool] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: str,
Expand Down Expand Up @@ -309,6 +318,7 @@ def Field(
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
const: Optional[bool] = None,
coerce_numbers_to_str: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
Expand All @@ -321,9 +331,12 @@ def Field(
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
fail_fast: Optional[bool] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
validate_default: Optional[bool] = None,
repr: bool = True,
sa_column: Union[Column[Any], UndefinedType] = Undefined,
schema_extra: Optional[dict[str, Any]] = None,
Expand All @@ -342,6 +355,7 @@ def Field(
exclude: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
include: Union[Set[Union[int, str]], Mapping[Union[int, str], Any], Any] = None,
const: Optional[bool] = None,
coerce_numbers_to_str: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
Expand All @@ -354,9 +368,12 @@ def Field(
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
union_mode: Optional[Literal["smart", "left_to_right"]] = None,
fail_fast: Optional[bool] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
validate_default: Optional[bool] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,
Expand All @@ -371,16 +388,36 @@ def Field(
schema_extra: Optional[dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}

for param_name in (
"coerce_numbers_to_str",
"validate_default",
"union_mode",
"fail_fast",
):
if param_name in current_schema_extra:
msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
warnings.warn(msg, UserWarning, stacklevel=2)

# Extract possible alias settings from schema_extra so we can control precedence
schema_validation_alias = current_schema_extra.pop("validation_alias", None)
schema_serialization_alias = current_schema_extra.pop("serialization_alias", None)
current_coerce_numbers_to_str = coerce_numbers_to_str or current_schema_extra.pop(
"coerce_numbers_to_str", None
)
current_validate_default = validate_default or current_schema_extra.pop(
"validate_default", None
)
current_fail_fast = fail_fast or current_schema_extra.pop("fail_fast", None)
field_info_kwargs = {
"alias": alias,
"title": title,
"description": description,
"exclude": exclude,
"include": include,
"const": const,
"coerce_numbers_to_str": current_coerce_numbers_to_str,
"validate_default": current_validate_default,
"gt": gt,
"ge": ge,
"lt": lt,
Expand All @@ -393,6 +430,7 @@ def Field(
"unique_items": unique_items,
"min_length": min_length,
"max_length": max_length,
"fail_fast": current_fail_fast,
"allow_mutation": allow_mutation,
"regex": regex,
"discriminator": discriminator,
Expand All @@ -418,6 +456,10 @@ def Field(
serialization_alias or schema_serialization_alias or alias
)

current_union_mode = union_mode or current_schema_extra.pop("union_mode", None)
if current_union_mode is not None:
field_info_kwargs["union_mode"] = current_union_mode

field_info = FieldInfo(
default,
default_factory=default_factory,
Expand Down
191 changes: 190 additions & 1 deletion tests/test_pydantic/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from pydantic import ValidationError
from sqlmodel import Field, SQLModel
from sqlmodel import Field, Session, SQLModel, create_engine


def test_decimal():
Expand Down Expand Up @@ -54,3 +54,192 @@ class Model(SQLModel):

instance = Model(id=123, foo="bar")
assert "foo=" not in repr(instance)


def test_coerce_numbers_to_str_true():
class Model(SQLModel):
val: str = Field(coerce_numbers_to_str=True)

assert Model.model_validate({"val": 123}).val == "123"
assert Model.model_validate({"val": 45.67}).val == "45.67"


@pytest.mark.parametrize("coerce_numbers_to_str", [None, False])
def test_coerce_numbers_to_str_false(coerce_numbers_to_str: Optional[bool]):
class Model2(SQLModel):
val: str = Field(coerce_numbers_to_str=coerce_numbers_to_str)

with pytest.raises(ValidationError):
Model2.model_validate({"val": 123})


def test_coerce_numbers_to_str_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `coerce_numbers_to_str` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: str = Field(schema_extra={"coerce_numbers_to_str": True})

assert Model.model_validate({"val": 123}).val == "123"
assert Model.model_validate({"val": 45.67}).val == "45.67"


def test_validate_default_true():
class Model(SQLModel):
val: int = Field(default="123", validate_default=True)

assert Model.model_validate({}).val == 123

class Model2(SQLModel):
val: int = Field(default=None, validate_default=True)

with pytest.raises(ValidationError):
Model2.model_validate({})


def test_validate_default_table_model():
class Model(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
val: int = Field(default="123", validate_default=True)

class ModelDB(Model, table=True):
pass

engine = create_engine("sqlite://", echo=True)

SQLModel.metadata.create_all(engine)

model = ModelDB()
with Session(engine) as session:
session.add(model)
session.commit()
session.refresh(model)

assert model.val == 123


@pytest.mark.parametrize("validate_default", [None, False])
def test_validate_default_false(validate_default: Optional[bool]):
class Model3(SQLModel):
val: int = Field(default="123", validate_default=validate_default)

assert Model3().val == "123"


def test_validate_default_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `validate_default` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: int = Field(default="123", schema_extra={"validate_default": True})

assert Model.model_validate({}).val == 123


@pytest.mark.parametrize("union_mode", [None, "smart"])
def test_union_mode_smart(union_mode: Optional[Literal["smart"]]):
class Model(SQLModel):
val: Union[float, int] = Field(union_mode=union_mode)

a = Model.model_validate({"val": 123})
assert isinstance(a.val, int) # float is first, but int is more precise

b = Model.model_validate({"val": 123.0})
assert isinstance(b.val, float)

c = Model.model_validate({"val": 123.1})
assert isinstance(c.val, float)


def test_union_mode_left_to_right():
class Model(SQLModel):
val: Union[float, int] = Field(union_mode="left_to_right")

a = Model.model_validate({"val": 123})
assert isinstance(a.val, float)

b = Model.model_validate({"val": 123.0})
assert isinstance(b.val, float)

c = Model.model_validate({"val": 123.1})
assert isinstance(c.val, float)


def test_union_mode_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `union_mode` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: Union[float, int] = Field(schema_extra={"union_mode": "smart"})

a = Model.model_validate({"val": 123})
assert isinstance(a.val, int) # float is first, but int is more precise

b = Model.model_validate({"val": 123.0})
assert isinstance(b.val, float)

c = Model.model_validate({"val": 123.1})
assert isinstance(c.val, float)


def test_fail_fast_true():
class Model(SQLModel):
val: list[int] = Field(fail_fast=True)

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({"val": [1.1, "not an int"]})

errors = exc_info.value.errors()
assert len(errors) == 1
assert errors[0]["type"] == "int_from_float"


@pytest.mark.parametrize("fail_fast", [None, False])
def test_fail_fast_false(fail_fast: Optional[bool]):
class Model(SQLModel):
val: list[int] = Field(fail_fast=fail_fast)

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({"val": [1.1, "not an int"]})

errors = exc_info.value.errors()
assert len(errors) == 2
error_types = {error["type"] for error in errors}

assert "int_from_float" in error_types
assert "int_parsing" in error_types


def test_fail_fast_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `fail_fast` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: list[int] = Field(schema_extra={"fail_fast": True})

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({"val": [1.1, "not an int"]})

errors = exc_info.value.errors()
assert len(errors) == 1
assert errors[0]["type"] == "int_from_float"