diff --git a/graph/src/schema/input/mod.rs b/graph/src/schema/input/mod.rs index a512c050965..ac7a4284175 100644 --- a/graph/src/schema/input/mod.rs +++ b/graph/src/schema/input/mod.rs @@ -788,7 +788,7 @@ pub struct Aggregate { } impl Aggregate { - fn new(_schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self { + fn new(schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self { let func = dir .argument("fn") .unwrap() @@ -818,7 +818,7 @@ impl Aggregate { arg, cumulative, field_type: field_type.clone(), - value_type: field_type.get_base_type().parse().unwrap(), + value_type: Field::scalar_value_type(schema, field_type), } } @@ -2366,27 +2366,63 @@ mod validations { } } - fn aggregate_fields_are_numbers(agg_type: &s::ObjectType, errors: &mut Vec) { + fn aggregate_field_types( + schema: &Schema, + agg_type: &s::ObjectType, + errors: &mut Vec, + ) { + fn is_first_last(agg_directive: &s::Directive) -> bool { + match agg_directive.argument(kw::FUNC) { + Some(s::Value::Enum(func) | s::Value::String(func)) => { + func == AggregateFn::First.as_str() + || func == AggregateFn::Last.as_str() + } + _ => false, + } + } + let errs = agg_type .fields .iter() - .filter(|field| field.find_directive(kw::AGGREGATE).is_some()) - .map(|field| match field.field_type.value_type() { - Ok(vt) => { - if vt.is_numeric() { - Ok(()) - } else { - Err(Err::NonNumericAggregate( + .filter_map(|field| { + field + .find_directive(kw::AGGREGATE) + .map(|agg_directive| (field, agg_directive)) + }) + .map(|(field, agg_directive)| { + let is_first_last = is_first_last(agg_directive); + + match field.field_type.value_type() { + Ok(value_type) if value_type.is_numeric() => Ok(()), + Ok(ValueType::Bytes | ValueType::String) if is_first_last => Ok(()), + Ok(_) if is_first_last => Err(Err::InvalidFirstLastAggregate( + agg_type.name.clone(), + field.name.clone(), + )), + Ok(_) => Err(Err::NonNumericAggregate( + agg_type.name.to_owned(), + field.name.to_owned(), + 
)), + Err(_) => { + if is_first_last + && schema + .entity_types + .iter() + .find(|entity_type| { + entity_type.name.eq(field.field_type.get_base_type()) + }) + .is_some() + { + return Ok(()); + } + + Err(Err::FieldTypeUnknown( agg_type.name.to_owned(), field.name.to_owned(), + field.field_type.get_base_type().to_owned(), )) } } - Err(_) => Err(Err::FieldTypeUnknown( - agg_type.name.to_owned(), - field.name.to_owned(), - field.field_type.get_base_type().to_owned(), - )), }) .filter_map(|err| err.err()); errors.extend(errs); @@ -2519,16 +2555,10 @@ mod validations { continue; } }; - let field_type = match field.field_type.value_type() { - Ok(field_type) => field_type, - Err(_) => { - errors.push(Err::NonNumericAggregate( - agg_type.name.to_owned(), - field.name.to_owned(), - )); - continue; - } - }; + + let is_first_last = + matches!(func, AggregateFn::First | AggregateFn::Last); + // It would be nicer to use a proper struct here // and have that implement // `sqlexpr::ExprVisitor` but we need access to @@ -2539,6 +2569,18 @@ mod validations { let arg_type = match source.field(ident) { Some(arg_field) => match arg_field.field_type.value_type() { Ok(arg_type) if arg_type.is_numeric() => arg_type, + Ok(ValueType::Bytes | ValueType::String) + if is_first_last => + { + return Ok(()); + } + Err(_) + if is_first_last + && arg_field.field_type.get_base_type() + == field.field_type.get_base_type() => + { + return Ok(()); + } Ok(_) | Err(_) => { return Err(Err::AggregationNonNumericArg( agg_type.name.to_owned(), @@ -2556,15 +2598,27 @@ mod validations { )); } }; - if arg_type > field_type { - return Err(Err::AggregationNonMatchingArg( - agg_type.name.to_owned(), - field.name.to_owned(), - arg.to_owned(), - arg_type.to_str().to_owned(), - field_type.to_str().to_owned(), - )); + + match field.field_type.value_type() { + Ok(field_type) if field_type.is_numeric() => { + if arg_type > field_type { + return Err(Err::AggregationNonMatchingArg( + agg_type.name.to_owned(), + 
field.name.to_owned(), + arg.to_owned(), + arg_type.to_str().to_owned(), + field_type.to_str().to_owned(), + )); + } + } + Ok(_) | Err(_) => { + return Err(Err::NonNumericAggregate( + agg_type.name.to_owned(), + field.name.to_owned(), + )); + } } + Ok(()) }; if let Err(mut errs) = sqlexpr::parse(arg, check_ident) { @@ -2661,7 +2715,7 @@ mod validations { errors.push(err); } no_derived_fields(agg_type, &mut errors); - aggregate_fields_are_numbers(agg_type, &mut errors); + aggregate_field_types(self, agg_type, &mut errors); aggregate_directive(self, agg_type, &mut errors); // check timeseries directive has intervals and args aggregation_intervals(agg_type, &mut errors); diff --git a/graph/src/schema/mod.rs b/graph/src/schema/mod.rs index 0b1a12cd338..f4e098a4b3e 100644 --- a/graph/src/schema/mod.rs +++ b/graph/src/schema/mod.rs @@ -123,6 +123,8 @@ pub enum SchemaValidationError { TimestampFieldMissing(String), #[error("Aggregation {0}, field{1}: aggregates must use a numeric type, one of Int, Int8, BigInt, and BigDecimal")] NonNumericAggregate(String, String), + #[error("Aggregation '{0}', field '{1}': first/last aggregates must use a numeric, byte array, string or a reference type")] + InvalidFirstLastAggregate(String, String), #[error("Aggregation {0} is missing the `source` argument")] AggregationMissingSource(String), #[error( diff --git a/graph/src/schema/test_schemas/ts_invalid_first_type_reference.graphql b/graph/src/schema/test_schemas/ts_invalid_first_type_reference.graphql new file mode 100644 index 00000000000..3f259c732cc --- /dev/null +++ b/graph/src/schema/test_schemas/ts_invalid_first_type_reference.graphql @@ -0,0 +1,14 @@ +# fail: FieldTypeUnknown("Stats", "firstBlockNumber", "BlockNumber") + +type Data @entity(timeseries: true) { + id: Int8! + timestamp: Timestamp! + blockNumber: Int8! +} + +type Stats @aggregation(intervals: ["hour", "day"], source: "Data") { + id: Int8! + timestamp: Timestamp! + + firstBlockNumber: BlockNumber! 
@aggregate(fn: "first", arg: "blockNumber") +} diff --git a/graph/src/schema/test_schemas/ts_invalid_last_type_reference.graphql b/graph/src/schema/test_schemas/ts_invalid_last_type_reference.graphql new file mode 100644 index 00000000000..405fde6fd3f --- /dev/null +++ b/graph/src/schema/test_schemas/ts_invalid_last_type_reference.graphql @@ -0,0 +1,14 @@ +# fail: FieldTypeUnknown("Stats", "lastBlockNumber", "BlockNumber") + +type Data @entity(timeseries: true) { + id: Int8! + timestamp: Timestamp! + blockNumber: Int8! +} + +type Stats @aggregation(intervals: ["hour", "day"], source: "Data") { + id: Int8! + timestamp: Timestamp! + + lastBlockNumber: BlockNumber! @aggregate(fn: "last", arg: "blockNumber") +} diff --git a/graph/src/schema/test_schemas/ts_valid_non_numeric_first_last.graphql b/graph/src/schema/test_schemas/ts_valid_non_numeric_first_last.graphql new file mode 100644 index 00000000000..fe78dc7463c --- /dev/null +++ b/graph/src/schema/test_schemas/ts_valid_non_numeric_first_last.graphql @@ -0,0 +1,43 @@ +# valid: Non-numeric first and last aggregations + +type EntityA @entity(immutable: true) { + id: ID! +} + +type EntityB @entity(immutable: true) { + id: Int8! +} + +type EntityC @entity(immutable: true) { + id: Bytes! +} + +type Data @entity(timeseries: true) { + id: Int8! + timestamp: Timestamp! + fieldA: EntityA! + fieldB: EntityB! + fieldC: EntityC! + fieldD: String! + fieldE: Bytes! +} + +type Stats @aggregation(intervals: ["hour", "day"], source: "Data") { + id: Int8! + timestamp: Timestamp! + + firstA: EntityA! @aggregate(fn: "first", arg: "fieldA") + lastA: EntityA! @aggregate(fn: "last", arg: "fieldA") + + firstB: EntityB! @aggregate(fn: "first", arg: "fieldB") + lastB: EntityB! @aggregate(fn: "last", arg: "fieldB") + + firstC: EntityC! @aggregate(fn: "first", arg: "fieldC") + lastC: EntityC! @aggregate(fn: "last", arg: "fieldC") + + firstD: String! @aggregate(fn: "first", arg: "fieldD") + lastD: String! 
@aggregate(fn: "last", arg: "fieldD") + + firstE: Bytes! @aggregate(fn: "first", arg: "fieldE") + lastE: Bytes! @aggregate(fn: "last", arg: "fieldE") +} diff --git a/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/down.sql b/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/down.sql new file mode 100644 index 00000000000..90de2623327 --- /dev/null +++ b/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/down.sql @@ -0,0 +1,14 @@ +-- This file was generated by generate.sh in this directory +set search_path = public; +drop aggregate arg_min_text(text_and_value); +drop aggregate arg_max_text(text_and_value); +drop function arg_from_text_and_value(text_and_value); +drop function arg_max_agg_text(text_and_value, text_and_value); +drop function arg_min_agg_text(text_and_value, text_and_value); +drop type text_and_value; +drop aggregate arg_min_bytea(bytea_and_value); +drop aggregate arg_max_bytea(bytea_and_value); +drop function arg_from_bytea_and_value(bytea_and_value); +drop function arg_max_agg_bytea(bytea_and_value, bytea_and_value); +drop function arg_min_agg_bytea(bytea_and_value, bytea_and_value); +drop type bytea_and_value; diff --git a/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/generate.sh b/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/generate.sh new file mode 100755 index 00000000000..44ceb2fadd3 --- /dev/null +++ b/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/generate.sh @@ -0,0 +1,98 @@ +#! /bin/bash + +# Generate up and down migrations to define arg_min and arg_max functions +# for the types listed in `types`. +# +# The functions can all be used like +# +# select arg_min_text((arg, value)) from t +# +# and return the `arg text` for the smallest `value int8`. If there +# are several rows with the smallest value, we try hard to return the first +# one, but that also depends on how Postgres calculates these +# aggregations.
Note that the relation over which we are aggregating does +# not need to be ordered. +# +# Unfortunately, it is not possible to do this generically, so we have to +# monomorphize and define an aggregate for each data type that we want to +# use. The `value` is always an `int8` +# +# If changes to these functions are needed, copy this script to a new +# migration, change it and regenerate the up and down migrations + +types="text bytea" +dir=$(dirname $0) + +read -d '' -r prelude <<'EOF' +-- This file was generated by generate.sh in this directory +set search_path = public; +EOF + +read -d '' -r up_template <<'EOF' +create type public.@T@_and_value as ( + arg @T@, + value int8 +); + +create or replace function arg_min_agg_@T@ (a @T@_and_value, b @T@_and_value) + returns @T@_and_value + language sql immutable strict parallel safe as +'select case when a.arg is null then b + when b.arg is null then a + when a.value <= b.value then a + else b end'; + +create or replace function arg_max_agg_@T@ (a @T@_and_value, b @T@_and_value) + returns @T@_and_value + language sql immutable strict parallel safe as +'select case when a.arg is null then b + when b.arg is null then a + when a.value > b.value then a + else b end'; + +create function arg_from_@T@_and_value(a @T@_and_value) + returns @T@ + language sql immutable strict parallel safe as +'select a.arg'; + +create aggregate arg_min_@T@ (@T@_and_value) ( + sfunc = arg_min_agg_@T@, + stype = @T@_and_value, + finalfunc = arg_from_@T@_and_value, + parallel = safe +); + +comment on aggregate arg_min_@T@(@T@_and_value) is +'For ''select arg_min_@T@((arg, value)) from ..'' return the arg for the smallest value'; + +create aggregate arg_max_@T@ (@T@_and_value) ( + sfunc = arg_max_agg_@T@, + stype = @T@_and_value, + finalfunc = arg_from_@T@_and_value, + parallel = safe +); + +comment on aggregate arg_max_@T@(@T@_and_value) is +'For ''select arg_max_@T@((arg, value)) from ..'' return the arg for the largest value'; +EOF + +read -d '' 
-r down_template <<'EOF' +drop aggregate arg_min_@T@(@T@_and_value); +drop aggregate arg_max_@T@(@T@_and_value); +drop function arg_from_@T@_and_value(@T@_and_value); +drop function arg_max_agg_@T@(@T@_and_value, @T@_and_value); +drop function arg_min_agg_@T@(@T@_and_value, @T@_and_value); +drop type @T@_and_value; +EOF + +echo "$prelude" > $dir/up.sql +for typ in $types +do + echo "${up_template//@T@/$typ}" >> $dir/up.sql +done + +echo "$prelude" > $dir/down.sql +for typ in $types +do + echo "${down_template//@T@/$typ}" >> $dir/down.sql +done diff --git a/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/up.sql b/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/up.sql new file mode 100644 index 00000000000..fc9474011dd --- /dev/null +++ b/store/postgres/migrations/2025-12-09-164812_extend_arg_min_max/up.sql @@ -0,0 +1,92 @@ +-- This file was generated by generate.sh in this directory +set search_path = public; +create type public.text_and_value as ( + arg text, + value int8 +); + +create or replace function arg_min_agg_text (a text_and_value, b text_and_value) + returns text_and_value + language sql immutable strict parallel safe as +'select case when a.arg is null then b + when b.arg is null then a + when a.value <= b.value then a + else b end'; + +create or replace function arg_max_agg_text (a text_and_value, b text_and_value) + returns text_and_value + language sql immutable strict parallel safe as +'select case when a.arg is null then b + when b.arg is null then a + when a.value > b.value then a + else b end'; + +create function arg_from_text_and_value(a text_and_value) + returns text + language sql immutable strict parallel safe as +'select a.arg'; + +create aggregate arg_min_text (text_and_value) ( + sfunc = arg_min_agg_text, + stype = text_and_value, + finalfunc = arg_from_text_and_value, + parallel = safe +); + +comment on aggregate arg_min_text(text_and_value) is +'For ''select arg_min_text((arg, value)) from ..'' return the arg 
for the smallest value'; + +create aggregate arg_max_text (text_and_value) ( + sfunc = arg_max_agg_text, + stype = text_and_value, + finalfunc = arg_from_text_and_value, + parallel = safe +); + +comment on aggregate arg_max_text(text_and_value) is +'For ''select arg_max_text((arg, value)) from ..'' return the arg for the largest value'; +create type public.bytea_and_value as ( + arg bytea, + value int8 +); + +create or replace function arg_min_agg_bytea (a bytea_and_value, b bytea_and_value) + returns bytea_and_value + language sql immutable strict parallel safe as +'select case when a.arg is null then b + when b.arg is null then a + when a.value <= b.value then a + else b end'; + +create or replace function arg_max_agg_bytea (a bytea_and_value, b bytea_and_value) + returns bytea_and_value + language sql immutable strict parallel safe as +'select case when a.arg is null then b + when b.arg is null then a + when a.value > b.value then a + else b end'; + +create function arg_from_bytea_and_value(a bytea_and_value) + returns bytea + language sql immutable strict parallel safe as +'select a.arg'; + +create aggregate arg_min_bytea (bytea_and_value) ( + sfunc = arg_min_agg_bytea, + stype = bytea_and_value, + finalfunc = arg_from_bytea_and_value, + parallel = safe +); + +comment on aggregate arg_min_bytea(bytea_and_value) is +'For ''select arg_min_bytea((arg, value)) from ..'' return the arg for the smallest value'; + +create aggregate arg_max_bytea (bytea_and_value) ( + sfunc = arg_max_agg_bytea, + stype = bytea_and_value, + finalfunc = arg_from_bytea_and_value, + parallel = safe +); + +comment on aggregate arg_max_bytea(bytea_and_value) is +'For ''select arg_max_bytea((arg, value)) from ..'' return the arg for the largest value';