Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 88 additions & 34 deletions graph/src/schema/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ pub struct Aggregate {
}

impl Aggregate {
fn new(_schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self {
fn new(schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self {
let func = dir
.argument("fn")
.unwrap()
Expand Down Expand Up @@ -818,7 +818,7 @@ impl Aggregate {
arg,
cumulative,
field_type: field_type.clone(),
value_type: field_type.get_base_type().parse().unwrap(),
value_type: Field::scalar_value_type(schema, field_type),
}
}

Expand Down Expand Up @@ -2366,27 +2366,63 @@ mod validations {
}
}

fn aggregate_fields_are_numbers(agg_type: &s::ObjectType, errors: &mut Vec<Err>) {
fn aggregate_field_types(
schema: &Schema,
agg_type: &s::ObjectType,
errors: &mut Vec<Err>,
) {
fn is_first_last(agg_directive: &s::Directive) -> bool {
match agg_directive.argument(kw::FUNC) {
Some(s::Value::Enum(func) | s::Value::String(func)) => {
func == AggregateFn::First.as_str()
|| func == AggregateFn::Last.as_str()
}
_ => false,
}
}

let errs = agg_type
.fields
.iter()
.filter(|field| field.find_directive(kw::AGGREGATE).is_some())
.map(|field| match field.field_type.value_type() {
Ok(vt) => {
if vt.is_numeric() {
Ok(())
} else {
Err(Err::NonNumericAggregate(
.filter_map(|field| {
field
.find_directive(kw::AGGREGATE)
.map(|agg_directive| (field, agg_directive))
})
.map(|(field, agg_directive)| {
let is_first_last = is_first_last(agg_directive);

match field.field_type.value_type() {
Ok(value_type) if value_type.is_numeric() => Ok(()),
Ok(ValueType::Bytes | ValueType::String) if is_first_last => Ok(()),
Ok(_) if is_first_last => Err(Err::InvalidFirstLastAggregate(
agg_type.name.clone(),
field.name.clone(),
)),
Ok(_) => Err(Err::NonNumericAggregate(
agg_type.name.to_owned(),
field.name.to_owned(),
)),
Err(_) => {
if is_first_last
&& schema
.entity_types
.iter()
.find(|entity_type| {
entity_type.name.eq(field.field_type.get_base_type())
})
.is_some()
{
return Ok(());
}

Err(Err::FieldTypeUnknown(
agg_type.name.to_owned(),
field.name.to_owned(),
field.field_type.get_base_type().to_owned(),
))
}
}
Err(_) => Err(Err::FieldTypeUnknown(
agg_type.name.to_owned(),
field.name.to_owned(),
field.field_type.get_base_type().to_owned(),
)),
})
.filter_map(|err| err.err());
errors.extend(errs);
Expand Down Expand Up @@ -2519,16 +2555,10 @@ mod validations {
continue;
}
};
let field_type = match field.field_type.value_type() {
Ok(field_type) => field_type,
Err(_) => {
errors.push(Err::NonNumericAggregate(
agg_type.name.to_owned(),
field.name.to_owned(),
));
continue;
}
};

let is_first_last =
matches!(func, AggregateFn::First | AggregateFn::Last);

// It would be nicer to use a proper struct here
// and have that implement
// `sqlexpr::ExprVisitor` but we need access to
Expand All @@ -2539,6 +2569,18 @@ mod validations {
let arg_type = match source.field(ident) {
Some(arg_field) => match arg_field.field_type.value_type() {
Ok(arg_type) if arg_type.is_numeric() => arg_type,
Ok(ValueType::Bytes | ValueType::String)
if is_first_last =>
{
return Ok(());
}
Err(_)
if is_first_last
&& arg_field.field_type.get_base_type()
== field.field_type.get_base_type() =>
{
return Ok(());
}
Ok(_) | Err(_) => {
return Err(Err::AggregationNonNumericArg(
agg_type.name.to_owned(),
Expand All @@ -2556,15 +2598,27 @@ mod validations {
));
}
};
if arg_type > field_type {
return Err(Err::AggregationNonMatchingArg(
agg_type.name.to_owned(),
field.name.to_owned(),
arg.to_owned(),
arg_type.to_str().to_owned(),
field_type.to_str().to_owned(),
));

match field.field_type.value_type() {
Ok(field_type) if field_type.is_numeric() => {
if arg_type > field_type {
return Err(Err::AggregationNonMatchingArg(
agg_type.name.to_owned(),
field.name.to_owned(),
arg.to_owned(),
arg_type.to_str().to_owned(),
field_type.to_str().to_owned(),
));
}
}
Ok(_) | Err(_) => {
return Err(Err::NonNumericAggregate(
agg_type.name.to_owned(),
field.name.to_owned(),
));
}
}

Ok(())
};
if let Err(mut errs) = sqlexpr::parse(arg, check_ident) {
Expand Down Expand Up @@ -2661,7 +2715,7 @@ mod validations {
errors.push(err);
}
no_derived_fields(agg_type, &mut errors);
aggregate_fields_are_numbers(agg_type, &mut errors);
aggregate_field_types(self, agg_type, &mut errors);
aggregate_directive(self, agg_type, &mut errors);
// check timeseries directive has intervals and args
aggregation_intervals(agg_type, &mut errors);
Expand Down
2 changes: 2 additions & 0 deletions graph/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ pub enum SchemaValidationError {
TimestampFieldMissing(String),
#[error("Aggregation {0}, field{1}: aggregates must use a numeric type, one of Int, Int8, BigInt, and BigDecimal")]
NonNumericAggregate(String, String),
#[error("Aggregation '{0}', field '{1}': first/last aggregates must use a numeric, byte array, string or a reference type")]
InvalidFirstLastAggregate(String, String),
#[error("Aggregation {0} is missing the `source` argument")]
AggregationMissingSource(String),
#[error(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# fail: FieldTypeUnknown("Stats", "firstBlockNumber", "BlockNumber")
# `BlockNumber` is not declared anywhere in this schema (neither a scalar
# nor an entity type), so the validator must reject this `first` aggregate
# with FieldTypeUnknown rather than accept it as an entity-reference
# aggregate.

type Data @entity(timeseries: true) {
  id: Int8!
  timestamp: Timestamp!
  blockNumber: Int8!
}

type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
  id: Int8!
  timestamp: Timestamp!

  firstBlockNumber: BlockNumber! @aggregate(fn: "first", arg: "blockNumber")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# fail: FieldTypeUnknown("Stats", "lastBlockNumber", "BlockNumber")
# Mirror of the `first` case: `BlockNumber` names no type declared in this
# schema, so the `last` aggregate must be rejected with FieldTypeUnknown
# instead of being treated as a valid entity reference.

type Data @entity(timeseries: true) {
  id: Int8!
  timestamp: Timestamp!
  blockNumber: Int8!
}

type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
  id: Int8!
  timestamp: Timestamp!

  lastBlockNumber: BlockNumber! @aggregate(fn: "last", arg: "blockNumber")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# valid: Non-numeric first and last aggregations
# Exercises the widened rules for `first`/`last` aggregates: besides the
# numeric types, they may also target entity references (EntityA/B/C with
# differently-typed ids), String, and Bytes. In every case the aggregate
# field's type matches the type of the source field named in `arg`.

type EntityA @entity(immutable: true) {
  id: ID!
}

type EntityB @entity(immutable: true) {
  id: Int8!
}

type EntityC @entity(immutable: true) {
  id: Bytes!
}

type Data @entity(timeseries: true) {
  id: Int8!
  timestamp: Timestamp!
  fieldA: EntityA!
  fieldB: EntityB!
  fieldC: EntityC!
  fieldD: String!
  fieldE: Bytes!
}

type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
  id: Int8!
  timestamp: Timestamp!

  # entity references, keyed by three different id types
  firstA: EntityA! @aggregate(fn: "first", arg: "fieldA")
  lastA: EntityA! @aggregate(fn: "last", arg: "fieldA")

  firstB: EntityB! @aggregate(fn: "first", arg: "fieldB")
  lastB: EntityB! @aggregate(fn: "last", arg: "fieldB")

  firstC: EntityC! @aggregate(fn: "first", arg: "fieldC")
  lastC: EntityC! @aggregate(fn: "last", arg: "fieldC")

  # plain string / byte-array aggregates
  firstD: String! @aggregate(fn: "first", arg: "fieldD")
  lastD: String! @aggregate(fn: "last", arg: "fieldD")

  firstE: Bytes! @aggregate(fn: "first", arg: "fieldE")
  lastE: Bytes! @aggregate(fn: "last", arg: "fieldE")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- This file was generated by generate.sh in this directory
-- NOTE(review): do not hand-edit; change generate.sh and regenerate so
-- up.sql and down.sql stay in sync. Objects are dropped in reverse
-- dependency order: the aggregates first, then the helper functions they
-- reference, then the composite type they are all built on.
set search_path = public;
drop aggregate arg_min_text(text_and_value);
drop aggregate arg_max_text(text_and_value);
drop function arg_from_text_and_value(text_and_value);
drop function arg_max_agg_text(text_and_value, text_and_value);
drop function arg_min_agg_text(text_and_value, text_and_value);
drop type text_and_value;
drop aggregate arg_min_bytea(bytea_and_value);
drop aggregate arg_max_bytea(bytea_and_value);
drop function arg_from_bytea_and_value(bytea_and_value);
drop function arg_max_agg_bytea(bytea_and_value, bytea_and_value);
drop function arg_min_agg_bytea(bytea_and_value, bytea_and_value);
drop type bytea_and_value;
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#! /bin/bash

# Generate up and down migrations to define arg_min and arg_max aggregates
# for the types listed in `types`.
#
# The aggregates can all be used like
#
#    select arg_min_text((arg, value)) from t
#
# and return the `arg` for the smallest `value int8`. If there
# are several rows with the smallest value, we try hard to return the first
# one, but that also depends on how Postgres calculates these
# aggregations. Note that the relation over which we are aggregating does
# not need to be ordered.
#
# Unfortunately, it is not possible to do this generically, so we have to
# monomorphize and define an aggregate for each data type that we want to
# use. The `value` is always an `int8`
#
# If changes to these functions are needed, copy this script to a new
# migration, change it and regenerate the up and down migrations

types="text bytea"
# Resolve the migration directory relative to this script; quote so the
# script also works when the checkout path contains whitespace.
dir=$(dirname "$0")

# `read -d ''` reads up to a NUL byte that never arrives, so it always
# exits non-zero once the heredoc is exhausted; that is expected and the
# variable is still populated.
read -d '' -r prelude <<'EOF'
-- This file was generated by generate.sh in this directory
set search_path = public;
EOF

# SQL template for the "up" migration; every @T@ is replaced by one of the
# element types from $types. The template text is emitted verbatim into
# up.sql, so it must not be reformatted without regenerating both files.
read -d '' -r up_template <<'EOF'
create type public.@T@_and_value as (
arg @T@,
value int8
);

create or replace function arg_min_agg_@T@ (a @T@_and_value, b @T@_and_value)
returns @T@_and_value
language sql immutable strict parallel safe as
'select case when a.arg is null then b
when b.arg is null then a
when a.value <= b.value then a
else b end';

create or replace function arg_max_agg_@T@ (a @T@_and_value, b @T@_and_value)
returns @T@_and_value
language sql immutable strict parallel safe as
'select case when a.arg is null then b
when b.arg is null then a
when a.value > b.value then a
else b end';

create function arg_from_@T@_and_value(a @T@_and_value)
returns @T@
language sql immutable strict parallel safe as
'select a.arg';

create aggregate arg_min_@T@ (@T@_and_value) (
sfunc = arg_min_agg_@T@,
stype = @T@_and_value,
finalfunc = arg_from_@T@_and_value,
parallel = safe
);

comment on aggregate arg_min_@T@(@T@_and_value) is
'For ''select arg_min_@T@((arg, value)) from ..'' return the arg for the smallest value';

create aggregate arg_max_@T@ (@T@_and_value) (
sfunc = arg_max_agg_@T@,
stype = @T@_and_value,
finalfunc = arg_from_@T@_and_value,
parallel = safe
);

comment on aggregate arg_max_@T@(@T@_and_value) is
'For ''select arg_max_@T@((arg, value)) from ..'' return the arg for the largest value';
EOF

# SQL template for the "down" migration: drop everything the up template
# created, in reverse dependency order (aggregates, functions, type).
read -d '' -r down_template <<'EOF'
drop aggregate arg_min_@T@(@T@_and_value);
drop aggregate arg_max_@T@(@T@_and_value);
drop function arg_from_@T@_and_value(@T@_and_value);
drop function arg_max_agg_@T@(@T@_and_value, @T@_and_value);
drop function arg_min_agg_@T@(@T@_and_value, @T@_and_value);
drop type @T@_and_value;
EOF

# Write up.sql: prelude once, then the template monomorphized per type.
# $types is deliberately unquoted so it word-splits into one type per
# iteration; the file paths are quoted.
echo "$prelude" > "$dir/up.sql"
for typ in $types
do
    echo "${up_template//@T@/$typ}" >> "$dir/up.sql"
done

# Write down.sql the same way.
echo "$prelude" > "$dir/down.sql"
for typ in $types
do
    echo "${down_template//@T@/$typ}" >> "$dir/down.sql"
done
Loading