Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
410 changes: 367 additions & 43 deletions CMakeLists.txt

Large diffs are not rendered by default.

52 changes: 40 additions & 12 deletions src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,27 +239,47 @@ struct TimeTZAverageOperation : public BaseSumOperation<AverageSetOperation, Add
}
};

LogicalType GetAvgStateType(const AggregateFunction &function) {
child_list_t<LogicalType> children;
children.emplace_back("count", LogicalType::UBIGINT);
children.emplace_back("value", function.arguments[0]);
return LogicalType::STRUCT(std::move(children));
}

LogicalType GetKahanAvgStateType(const AggregateFunction &function) {
child_list_t<LogicalType> children;
children.emplace_back("count", LogicalType::UBIGINT);
children.emplace_back("value", LogicalType::DOUBLE);
children.emplace_back("err", LogicalType::DOUBLE);
return LogicalType::STRUCT(std::move(children));
}

AggregateFunction GetAverageAggregate(PhysicalType type) {
switch (type) {
case PhysicalType::INT16: {
return AggregateFunction::UnaryAggregate<AvgState<int64_t>, int16_t, double, IntegerAverageOperation>(
LogicalType::SMALLINT, LogicalType::DOUBLE);
LogicalType::SMALLINT, LogicalType::DOUBLE)
.SetStructStateExport(GetAvgStateType);
}
case PhysicalType::INT32: {
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int32_t, double, IntegerAverageOperationHugeint>(
LogicalType::INTEGER, LogicalType::DOUBLE);
LogicalType::INTEGER, LogicalType::DOUBLE)
.SetStructStateExport(GetAvgStateType);
}
case PhysicalType::INT64: {
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, double, IntegerAverageOperationHugeint>(
LogicalType::BIGINT, LogicalType::DOUBLE);
LogicalType::BIGINT, LogicalType::DOUBLE)
.SetStructStateExport(GetAvgStateType);
}
case PhysicalType::INT128: {
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, hugeint_t, double, HugeintAverageOperation>(
LogicalType::HUGEINT, LogicalType::DOUBLE);
LogicalType::HUGEINT, LogicalType::DOUBLE)
.SetStructStateExport(GetAvgStateType);
}
case PhysicalType::INTERVAL: {
return AggregateFunction::UnaryAggregate<IntervalAvgState, interval_t, interval_t, IntervalAverageOperation>(
LogicalType::INTERVAL, LogicalType::INTERVAL);
LogicalType::INTERVAL, LogicalType::INTERVAL)
.SetStructStateExport(GetAvgStateType);
}
default:
throw InternalException("Unimplemented average aggregate");
Expand All @@ -282,6 +302,7 @@ unique_ptr<FunctionData> BindDecimalAvg(ClientContext &context, AggregateFunctio
AggregateFunctionSet AvgFun::GetFunctions() {
AggregateFunctionSet avg;

// The first is already opted-in during `BindDecimalAvg`
avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
BindDecimalAvg));
Expand All @@ -291,24 +312,31 @@ AggregateFunctionSet AvgFun::GetFunctions() {
avg.AddFunction(GetAverageAggregate(PhysicalType::INT128));
avg.AddFunction(GetAverageAggregate(PhysicalType::INTERVAL));
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<double>, double, double, NumericAverageOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE));
LogicalType::DOUBLE, LogicalType::DOUBLE)
.SetStructStateExport(GetAvgStateType));

avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
LogicalType::TIMESTAMP, LogicalType::TIMESTAMP));
LogicalType::TIMESTAMP, LogicalType::TIMESTAMP)
.SetStructStateExport(GetAvgStateType));
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ));
LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ)
.SetStructStateExport(GetAvgStateType));
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
LogicalType::TIME, LogicalType::TIME));
LogicalType::TIME, LogicalType::TIME)
.SetStructStateExport(GetAvgStateType));
avg.AddFunction(
AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, dtime_tz_t, dtime_tz_t, TimeTZAverageOperation>(
LogicalType::TIME_TZ, LogicalType::TIME_TZ));
LogicalType::TIME_TZ, LogicalType::TIME_TZ)
.SetStructStateExport(GetAvgStateType));

return avg;
}

AggregateFunction FAvgFun::GetFunction() {
return AggregateFunction::UnaryAggregate<KahanAvgState, double, double, KahanAverageOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
auto function = AggregateFunction::UnaryAggregate<KahanAvgState, double, double, KahanAverageOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE)
.SetStructStateExport(GetKahanAvgStateType);
return function;
}

} // namespace duckdb
28 changes: 27 additions & 1 deletion src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,34 @@

namespace duckdb {

LogicalType GetCorrExportStateType(const AggregateFunction &function) {
auto state_children = child_list_t<LogicalType> {};

auto covar_pop_children = child_list_t<LogicalType> {};
covar_pop_children.emplace_back("count", LogicalType::UBIGINT);
covar_pop_children.emplace_back("mean_x", LogicalType::DOUBLE);
covar_pop_children.emplace_back("mean_y", LogicalType::DOUBLE);
covar_pop_children.emplace_back("co_moment", LogicalType::DOUBLE);
state_children.emplace_back("cov_pop", LogicalType::STRUCT(std::move(covar_pop_children)));

auto dev_pop_x_children = child_list_t<LogicalType> {};
dev_pop_x_children.emplace_back("count", LogicalType::UBIGINT);
dev_pop_x_children.emplace_back("mean", LogicalType::DOUBLE);
dev_pop_x_children.emplace_back("dsquared", LogicalType::DOUBLE);
state_children.emplace_back("dev_pop_x", LogicalType::STRUCT(std::move(dev_pop_x_children)));

auto dev_pop_y_children = child_list_t<LogicalType> {};
dev_pop_y_children.emplace_back("count", LogicalType::UBIGINT);
dev_pop_y_children.emplace_back("mean", LogicalType::DOUBLE);
dev_pop_y_children.emplace_back("dsquared", LogicalType::DOUBLE);
state_children.emplace_back("dev_pop_y", LogicalType::STRUCT(std::move(dev_pop_y_children)));

return LogicalType::STRUCT(std::move(state_children));
}

AggregateFunction CorrFun::GetFunction() {
return AggregateFunction::BinaryAggregate<CorrState, double, double, double, CorrOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE)
.SetStructStateExport(GetCorrExportStateType);
}
} // namespace duckdb
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/hash.hpp"
#include "duckdb/common/types/hyperloglog.hpp"
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "hyperloglog.hpp"

namespace duckdb {

Expand All @@ -14,7 +10,7 @@ namespace duckdb {
namespace {

struct ApproxDistinctCountState {
HyperLogLog hll;
HyperLogLogP<10> hll;
};

struct ApproxCountDistinctFunction {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,57 @@ struct BitState {
T value;
};

template <class T>
LogicalType GetBitStateType(const AggregateFunction &function) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("is_set", LogicalType::BOOLEAN);

LogicalType value_type = function.return_type;
child_types.emplace_back("value", value_type);

return LogicalType::STRUCT(std::move(child_types));
}

LogicalType GetBitStringStateType(const AggregateFunction &function) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("is_set", LogicalType::BOOLEAN);
child_types.emplace_back("value", function.return_type);
return LogicalType::STRUCT(std::move(child_types));
}

template <class OP>
AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) {
switch (type.id()) {
case LogicalTypeId::TINYINT:
return AggregateFunction::UnaryAggregate<BitState<uint8_t>, int8_t, int8_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint8_t>, int8_t, int8_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint8_t>);
case LogicalTypeId::SMALLINT:
return AggregateFunction::UnaryAggregate<BitState<uint16_t>, int16_t, int16_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint16_t>, int16_t, int16_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint16_t>);
case LogicalTypeId::INTEGER:
return AggregateFunction::UnaryAggregate<BitState<uint32_t>, int32_t, int32_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint32_t>, int32_t, int32_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint32_t>);
case LogicalTypeId::BIGINT:
return AggregateFunction::UnaryAggregate<BitState<uint64_t>, int64_t, int64_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint64_t>, int64_t, int64_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint64_t>);
case LogicalTypeId::HUGEINT:
return AggregateFunction::UnaryAggregate<BitState<hugeint_t>, hugeint_t, hugeint_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<hugeint_t>, hugeint_t, hugeint_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<hugeint_t>);
case LogicalTypeId::UTINYINT:
return AggregateFunction::UnaryAggregate<BitState<uint8_t>, uint8_t, uint8_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint8_t>, uint8_t, uint8_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint8_t>);
case LogicalTypeId::USMALLINT:
return AggregateFunction::UnaryAggregate<BitState<uint16_t>, uint16_t, uint16_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint16_t>, uint16_t, uint16_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint16_t>);
case LogicalTypeId::UINTEGER:
return AggregateFunction::UnaryAggregate<BitState<uint32_t>, uint32_t, uint32_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint32_t>, uint32_t, uint32_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint32_t>);
case LogicalTypeId::UBIGINT:
return AggregateFunction::UnaryAggregate<BitState<uint64_t>, uint64_t, uint64_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uint64_t>, uint64_t, uint64_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uint64_t>);
case LogicalTypeId::UHUGEINT:
return AggregateFunction::UnaryAggregate<BitState<uhugeint_t>, uhugeint_t, uhugeint_t, OP>(type, type);
return AggregateFunction::UnaryAggregate<BitState<uhugeint_t>, uhugeint_t, uhugeint_t, OP>(type, type)
.SetStructStateExport(GetBitStateType<uhugeint_t>);
default:
throw InternalException("Unimplemented bitfield type for unary aggregate");
}
Expand Down Expand Up @@ -202,9 +230,11 @@ AggregateFunctionSet BitAndFun::GetFunctions() {
bit_and.AddFunction(GetBitfieldUnaryAggregate<BitAndOperation>(type));
}

bit_and.AddFunction(
auto bit_string_fun =
AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringAndOperation>(
LogicalType::BIT, LogicalType::BIT));
LogicalType::BIT, LogicalType::BIT);
bit_string_fun.SetStructStateExport(GetBitStringStateType);
bit_and.AddFunction(bit_string_fun);
return bit_and;
}

Expand All @@ -213,9 +243,11 @@ AggregateFunctionSet BitOrFun::GetFunctions() {
for (auto &type : LogicalType::Integral()) {
bit_or.AddFunction(GetBitfieldUnaryAggregate<BitOrOperation>(type));
}
bit_or.AddFunction(
auto bit_string_fun =
AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringOrOperation>(
LogicalType::BIT, LogicalType::BIT));
LogicalType::BIT, LogicalType::BIT);
bit_string_fun.SetStructStateExport(GetBitStringStateType);
bit_or.AddFunction(bit_string_fun);
return bit_or;
}

Expand All @@ -224,9 +256,11 @@ AggregateFunctionSet BitXorFun::GetFunctions() {
for (auto &type : LogicalType::Integral()) {
bit_xor.AddFunction(GetBitfieldUnaryAggregate<BitXorOperation>(type));
}
bit_xor.AddFunction(
auto bit_string_fun =
AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringXorOperation>(
LogicalType::BIT, LogicalType::BIT));
LogicalType::BIT, LogicalType::BIT);
bit_string_fun.SetStructStateExport(GetBitStringStateType);
bit_xor.AddFunction(bit_string_fun);
return bit_xor;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,29 @@ struct BoolOrFunFunction {
}
};

LogicalType GetBoolAndStateType(const AggregateFunction &function) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("empty", LogicalType::BOOLEAN);
child_types.emplace_back("val", LogicalType::BOOLEAN);
return LogicalType::STRUCT(std::move(child_types));
}

} // namespace

AggregateFunction BoolOrFun::GetFunction() {
auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolOrFunFunction>(
LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT);
fun.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT);
return fun;
return fun.SetStructStateExport(GetBoolAndStateType);
}

AggregateFunction BoolAndFun::GetFunction() {
auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolAndFunFunction>(
LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT);
fun.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT);
return fun;
return fun.SetStructStateExport(GetBoolAndStateType);
}

} // namespace duckdb
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,19 @@ struct ProductFunction {
}
};

LogicalType GetProductStateType(const AggregateFunction &function) {
child_list_t<LogicalType> children;
children.emplace_back("empty", LogicalType::BOOLEAN);
children.emplace_back("val", LogicalType::DOUBLE);
return LogicalType::STRUCT(std::move(children));
}

} // namespace

AggregateFunction ProductFun::GetFunction() {
return AggregateFunction::UnaryAggregate<ProductState, double, double, ProductFunction>(
LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE);
LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE)
.SetStructStateExport(GetProductStateType);
}

} // namespace duckdb
Loading
Loading