Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dataframely/columns/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def __init__(
max: dt.timedelta | None = None,
max_exclusive: dt.timedelta | None = None,
resolution: str | None = None,
time_unit: TimeUnit = "us",
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
Expand All @@ -462,6 +463,7 @@ def __init__(
the formatting language used by :mod:`polars` datetime `truncate` method.
For example, a value `1h` expects all durations to be full hours. Note
that this setting does *not* affect the storage resolution.
time_unit: Unit of time. Defaults to `us` (microseconds).
check: A custom rule or multiple rules to run for this column. This can be:
- A single callable that returns a non-aggregated boolean expression.
The name of the rule is derived from the callable name, or defaults to
Expand Down Expand Up @@ -504,10 +506,11 @@ def __init__(
metadata=metadata,
)
self.resolution = resolution
self.time_unit = time_unit

@property
def dtype(self) -> pl.DataType:
return pl.Duration()
return pl.Duration(time_unit=self.time_unit)

def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
result = super().validation_rules(expr)
Expand All @@ -526,7 +529,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:

@property
def pyarrow_dtype(self) -> pa.DataType:
return pa.duration("us")
return pa.duration(self.time_unit)

def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
# NOTE: If no duration is specified, we default to 100 years
Expand All @@ -543,6 +546,7 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
default=dt.timedelta(days=365 * 100),
),
resolution=self.resolution,
time_unit=self.time_unit,
null_probability=self._null_probability,
)

Expand Down
4 changes: 3 additions & 1 deletion dataframely/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def sample_duration(
min: dt.timedelta,
max: dt.timedelta,
resolution: str | None = None,
time_unit: TimeUnit = "us",
null_probability: float = 0.0,
) -> pl.Series:
"""Sample a list of durations in the provided range.
Expand All @@ -386,6 +387,7 @@ def sample_duration(
max: The maximum duration to sample (exclusive).
resolution: The resolution that durations in the column must have. This uses
the formatting language used by :mod:`polars` datetime `round` method.
time_unit: The time unit of the duration column. Defaults to `us` (microseconds).
null_probability: The probability of an element being `null`.

Returns:
Expand All @@ -410,7 +412,7 @@ def sample_duration(
max=max_microseconds,
null_probability=null_probability,
)
).cast(pl.Duration)
).cast(pl.Duration(time_unit=time_unit))

if resolution is not None:
ref_dt = pl.lit(EPOCH_DATETIME)
Expand Down
8 changes: 8 additions & 0 deletions tests/columns/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,11 @@ def test_datetime_time_unit(time_unit: TimeUnit) -> None:
"test", {"a": dy.Datetime(time_unit=time_unit, nullable=True)}
)
assert str(schema.to_pyarrow_schema()) == f"a: timestamp[{time_unit}]"


@pytest.mark.parametrize("time_unit", ["ns", "us", "ms"])
def test_duration_time_unit(time_unit: TimeUnit) -> None:
schema = create_schema(
"test", {"a": dy.Duration(time_unit=time_unit, nullable=True)}
)
assert str(schema.to_pyarrow_schema()) == f"a: duration[{time_unit}]"
Loading