diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 9a9677c1..7ccbee54 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -40,6 +40,7 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLast{}) Walk(node, &filterFirst{}) Walk(node, &predicateCombination{}) + Walk(node, &sumRange{}) Walk(node, &sumArray{}) Walk(node, &sumMap{}) return nil diff --git a/optimizer/sum_range.go b/optimizer/sum_range.go new file mode 100644 index 00000000..92f6e1da --- /dev/null +++ b/optimizer/sum_range.go @@ -0,0 +1,172 @@ +package optimizer + +import ( + . "github.com/expr-lang/expr/ast" +) + +type sumRange struct{} + +func (*sumRange) Visit(node *Node) { + // Pattern 1: sum(m..n) or sum(m..n, predicate) where m and n are constant integers + if sumBuiltin, ok := (*node).(*BuiltinNode); ok && + sumBuiltin.Name == "sum" && + (len(sumBuiltin.Arguments) == 1 || len(sumBuiltin.Arguments) == 2) { + if rangeOp, ok := sumBuiltin.Arguments[0].(*BinaryNode); ok && rangeOp.Operator == ".." { + if from, ok := rangeOp.Left.(*IntegerNode); ok { + if to, ok := rangeOp.Right.(*IntegerNode); ok { + m := from.Value + n := to.Value + if n >= m { + count := n - m + 1 + // Use the arithmetic series formula: (n - m + 1) * (m + n) / 2 + sum := count * (m + n) / 2 + + if len(sumBuiltin.Arguments) == 1 { + // sum(m..n) + patchWithType(node, &IntegerNode{Value: sum}) + } else if len(sumBuiltin.Arguments) == 2 { + // sum(m..n, predicate) + if result, ok := applySumPredicate(sum, count, sumBuiltin.Arguments[1]); ok { + patchWithType(node, &IntegerNode{Value: result}) + } + } + } + } + } + } + } + + // Pattern 2: reduce(m..n, # + #acc) where m and n are constant integers + if reduceBuiltin, ok := (*node).(*BuiltinNode); ok && + reduceBuiltin.Name == "reduce" && + (len(reduceBuiltin.Arguments) == 2 || len(reduceBuiltin.Arguments) == 3) { + if rangeOp, ok := reduceBuiltin.Arguments[0].(*BinaryNode); ok && rangeOp.Operator == ".." { + if from, ok := rangeOp.Left.(*IntegerNode); ok { + if to, ok := rangeOp.Right.(*IntegerNode); ok { + if isPointerPlusAcc(reduceBuiltin.Arguments[1]) { + m := from.Value + n := to.Value + if n >= m { + // Use the arithmetic series formula: (n - m + 1) * (m + n) / 2 + sum := (n - m + 1) * (m + n) / 2 + + // Check for optional initialValue (3rd argument) + if len(reduceBuiltin.Arguments) == 3 { + if initialValue, ok := reduceBuiltin.Arguments[2].(*IntegerNode); ok { + result := initialValue.Value + sum + patchWithType(node, &IntegerNode{Value: result}) + } + } else { + patchWithType(node, &IntegerNode{Value: sum}) + } + } + } + } + } + } + } +} + +// isPointerPlusAcc checks if the node represents `# + #acc` pattern +func isPointerPlusAcc(node Node) bool { + predicate, ok := node.(*PredicateNode) + if !ok { + return false + } + + binary, ok := predicate.Node.(*BinaryNode) + if !ok { + return false + } + + if binary.Operator != "+" { + return false + } + + // Check for # + #acc (pointer + accumulator) + leftPointer, leftIsPointer := binary.Left.(*PointerNode) + rightPointer, rightIsPointer := binary.Right.(*PointerNode) + + if leftIsPointer && rightIsPointer { + // # + #acc: Left is pointer (Name=""), Right is acc (Name="acc") + if leftPointer.Name == "" && rightPointer.Name == "acc" { + return true + } + // #acc + #: Left is acc (Name="acc"), Right is pointer (Name="") + if leftPointer.Name == "acc" && rightPointer.Name == "" { + return true + } + } + + return false +} + +// applySumPredicate tries to compute the result of sum(m..n, predicate) at compile time. +// Returns (result, true) if optimization is possible, (0, false) otherwise. +// Supported predicates: +// - # (identity): result = sum +// - # * k (multiply by constant): result = k * sum +// - k * # (multiply by constant): result = k * sum +// - # + k (add constant): result = sum + count * k +// - k + # (add constant): result = sum + count * k +// - # - k (subtract constant): result = sum - count * k +func applySumPredicate(sum, count int, predicateArg Node) (int, bool) { + predicate, ok := predicateArg.(*PredicateNode) + if !ok { + return 0, false + } + + // Case 1: # (identity) - just return the sum + if pointer, ok := predicate.Node.(*PointerNode); ok && pointer.Name == "" { + return sum, true + } + + // Case 2: Binary operations with pointer and constant + binary, ok := predicate.Node.(*BinaryNode) + if !ok { + return 0, false + } + + pointer, constant, pointerOnLeft := extractPointerAndConstantWithPosition(binary) + if pointer == nil || constant == nil { + return 0, false + } + + switch binary.Operator { + case "*": + // # * k or k * # => k * sum + return constant.Value * sum, true + case "+": + // # + k or k + # => sum + count * k + return sum + count*constant.Value, true + case "-": + if pointerOnLeft { + // # - k => sum - count * k + return sum - count*constant.Value, true + } + // k - # => count * k - sum + return count*constant.Value - sum, true + } + + return 0, false +} + +// extractPointerAndConstantWithPosition extracts pointer (#) and integer constant from a binary node. +// Returns (pointer, constant, pointerOnLeft) or (nil, nil, false) if not matching the expected pattern. +func extractPointerAndConstantWithPosition(binary *BinaryNode) (*PointerNode, *IntegerNode, bool) { + // Try left=pointer, right=constant + if pointer, ok := binary.Left.(*PointerNode); ok && pointer.Name == "" { + if constant, ok := binary.Right.(*IntegerNode); ok { + return pointer, constant, true + } + } + + // Try left=constant, right=pointer + if constant, ok := binary.Left.(*IntegerNode); ok { + if pointer, ok := binary.Right.(*PointerNode); ok && pointer.Name == "" { + return pointer, constant, false + } + } + + return nil, nil, false +} diff --git a/optimizer/sum_range_test.go b/optimizer/sum_range_test.go new file mode 100644 index 00000000..87185b22 --- /dev/null +++ b/optimizer/sum_range_test.go @@ -0,0 +1,289 @@ +package optimizer_test + +import ( + "testing" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/internal/testify/assert" + "github.com/expr-lang/expr/internal/testify/require" + "github.com/expr-lang/expr/optimizer" + "github.com/expr-lang/expr/parser" +) + +func TestOptimize_sum_range(t *testing.T) { + tree, err := parser.Parse(`sum(1..100)`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.IntegerNode{Value: 5050} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_sum_range_different_values(t *testing.T) { + tests := []struct { + expr string + want int + }{ + {`sum(1..10)`, 55}, + {`sum(1..100)`, 5050}, + {`sum(5..10)`, 45}, + {`sum(0..100)`, 5050}, + {`sum(1..1)`, 1}, + {`sum(0..0)`, 0}, + {`sum(10..20)`, 165}, + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestOptimize_sum_range_with_predicate(t *testing.T) { + tests := []struct { + expr string + want int + }{ + // # (identity) - same as sum(m..n) + {`sum(1..10, #)`, 55}, + {`sum(1..100, #)`, 5050}, + + // # * k (multiply by constant) + {`sum(1..10, # * 2)`, 110}, // 2 * 55 + {`sum(1..100, # * 2)`, 10100}, // 2 * 5050 + {`sum(1..10, # * 0)`, 0}, + {`sum(1..10, # * 1)`, 55}, + + // k * # (multiply by constant, reversed) + {`sum(1..10, 2 * #)`, 110}, + {`sum(1..100, 3 * #)`, 15150}, // 3 * 5050 + + // # + k (add constant to each element) + {`sum(1..10, # + 1)`, 65}, // 55 + 10*1 + {`sum(1..100, # + 1)`, 5150}, // 5050 + 100*1 + {`sum(1..10, # + 0)`, 55}, + {`sum(1..10, # + 10)`, 155}, // 55 + 10*10 + + // k + # (add constant, reversed) + {`sum(1..10, 1 + #)`, 65}, + {`sum(1..100, 5 + #)`, 5550}, // 5050 + 100*5 + + // # - k (subtract constant from each element) + {`sum(1..10, # - 1)`, 45}, // 55 - 10*1 + {`sum(1..100, # - 1)`, 4950}, // 5050 - 100*1 + {`sum(1..10, # - 0)`, 55}, + + // k - # (constant minus each element) + {`sum(1..10, 10 - #)`, 45}, // 10*10 - 55 + {`sum(1..10, 0 - #)`, -55}, // 10*0 - 55 + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestOptimize_sum_range_with_predicate_ast(t *testing.T) { + // Verify that sum(1..10, # * 2) is optimized to a constant + tree, err := parser.Parse(`sum(1..10, # * 2)`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.IntegerNode{Value: 110} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_reduce_range_sum(t *testing.T) { + tree, err := parser.Parse(`reduce(1..100, # + #acc)`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.IntegerNode{Value: 5050} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_reduce_range_sum_different_values(t *testing.T) { + tests := []struct { + expr string + want int + }{ + {`reduce(1..10, # + #acc)`, 55}, + {`reduce(1..100, # + #acc)`, 5050}, + {`reduce(5..10, # + #acc)`, 45}, + {`reduce(0..100, # + #acc)`, 5050}, + {`reduce(1..1, # + #acc)`, 1}, + {`reduce(10..20, # + #acc)`, 165}, + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestOptimize_reduce_range_sum_reverse_order(t *testing.T) { + // Test #acc + # (reverse order) - should also be optimized + tree, err := parser.Parse(`reduce(1..100, #acc + #)`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.IntegerNode{Value: 5050} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_reduce_range_sum_with_initial_value(t *testing.T) { + // Test reduce with initialValue: reduce(1..100, # + #acc, 10) => 5050 + 10 = 5060 + tree, err := parser.Parse(`reduce(1..100, # + #acc, 10)`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.IntegerNode{Value: 5060} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_reduce_range_sum_with_initial_value_different_values(t *testing.T) { + tests := []struct { + expr string + want int + }{ + {`reduce(1..10, # + #acc, 0)`, 55}, + {`reduce(1..10, # + #acc, 10)`, 65}, + {`reduce(1..100, # + #acc, 0)`, 5050}, + {`reduce(1..100, # + #acc, 100)`, 5150}, + {`reduce(5..10, # + #acc, 5)`, 50}, + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestOptimize_sum_range_reversed(t *testing.T) { + // When n < m (e.g., 10..1), the range is empty and sum should return 0. + // The optimization should NOT apply (n >= m check), so runtime handles it. + tests := []struct { + expr string + want int + }{ + {`sum(10..1)`, 0}, + {`sum(5..3)`, 0}, + {`sum(100..1)`, 0}, + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestOptimize_sum_range_reversed_not_optimized(t *testing.T) { + // Verify that reversed ranges are NOT optimized (left as BuiltinNode) + tree, err := parser.Parse(`sum(10..1)`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Should still be a BuiltinNode, not an IntegerNode + _, isBuiltin := tree.Node.(*ast.BuiltinNode) + assert.True(t, isBuiltin, "reversed range should not be optimized") +} + +func TestOptimize_reduce_range_reversed_errors(t *testing.T) { + // reduce on empty range (reversed) should error at runtime + program, err := expr.Compile(`reduce(10..1, # + #acc)`) + require.NoError(t, err) + + _, err = expr.Run(program, nil) + require.Error(t, err, "reduce on empty range should error") +} + +func BenchmarkSumRange_Optimized(b *testing.B) { + program, err := expr.Compile(`sum(1..100)`) + require.NoError(b, err) + + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = expr.Run(program, nil) + } + b.StopTimer() + + require.Equal(b, 5050, out) +} + +func BenchmarkReduceRangeSum_Optimized(b *testing.B) { + program, err := expr.Compile(`reduce(1..100, # + #acc)`) + require.NoError(b, err) + + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = expr.Run(program, nil) + } + b.StopTimer() + + require.Equal(b, 5050, out) +} + +func BenchmarkSumRange_Unoptimized(b *testing.B) { + program, err := expr.Compile(`sum(1..100)`, expr.Optimize(false)) + require.NoError(b, err) + + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = expr.Run(program, nil) + } + b.StopTimer() + + require.Equal(b, 5050, out) +}