Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/duckdb_py/pyrelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,36 @@ string DuckDBPyRelation::GenerateExpressionList(const string &function_name, vec
function_name + "(" + function_parameter + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
}
for (idx_t i = 0; i < input.size(); i++) {
// We parse the input as an expression to validate it.
auto trimmed_input = input[i];
StringUtil::Trim(trimmed_input);

unique_ptr<ParsedExpression> expression;
try {
auto expressions = Parser::ParseExpressionList(trimmed_input);
if (expressions.size() == 1) {
expression = std::move(expressions[0]);
}
} catch (const ParserException &) {
// First attempt at parsing failed, the input might be a column name that needs quoting.
auto quoted_input = KeywordHelper::WriteQuoted(trimmed_input, '"');
auto expressions = Parser::ParseExpressionList(quoted_input);
if (expressions.size() == 1 && expressions[0]->GetExpressionClass() == ExpressionClass::COLUMN_REF) {
expression = std::move(expressions[0]);
}
}

if (!expression) {
throw ParserException("Invalid column expression: %s", trimmed_input);
}

// ToString() handles escaping for all expression types
auto escaped_input = expression->ToString();

if (function_parameter.empty()) {
expr += function_name + "(" + input[i] + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
expr += function_name + "(" + escaped_input + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
} else {
expr += function_name + "(" + input[i] + "," + function_parameter +
expr += function_name + "(" + escaped_input + "," + function_parameter +
((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
}

Expand Down Expand Up @@ -587,7 +613,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Product(const std::string &column
unique_ptr<DuckDBPyRelation> DuckDBPyRelation::StringAgg(const std::string &column, const std::string &sep,
const std::string &groups, const std::string &window_spec,
const std::string &projected_columns) {
auto string_agg_params = "\'" + sep + "\'";
auto string_agg_params = KeywordHelper::WriteOptionallyQuoted(sep, '\'');
return ApplyAggOrWin("string_agg", column, string_agg_params, groups, window_spec, projected_columns);
}

Expand Down
143 changes: 143 additions & 0 deletions tests/fast/relational_api/test_rapi_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,146 @@ def test_var_samp(self, table, f):

def test_describe(self, table):
assert table.describe().fetchall() is not None


class TestRAPIAggregationsColumnEscaping:
"""Test that aggregate functions properly escape column names that need quoting."""

def test_reserved_keyword_column_name(self, duckdb_cursor):
# Column name "select" is a reserved SQL keyword
rel = duckdb_cursor.sql('select 1 as "select", 2 as "order"')
result = rel.sum("select").fetchall()
assert result == [(1,)]

result = rel.avg("order").fetchall()
assert result == [(2.0,)]

def test_column_name_with_space(self, duckdb_cursor):
rel = duckdb_cursor.sql('select 10 as "my column"')
result = rel.sum("my column").fetchall()
assert result == [(10,)]

def test_column_name_with_quotes(self, duckdb_cursor):
# Column name containing a double quote
rel = duckdb_cursor.sql('select 5 as "col""name"')
result = rel.sum('col"name').fetchall()
assert result == [(5,)]

def test_qualified_column_name(self, duckdb_cursor):
# Qualified column name like table.column
rel = duckdb_cursor.sql("select 42 as value")
# When using qualified names, they should be properly escaped
result = rel.sum("value").fetchall()
assert result == [(42,)]


class TestRAPIAggregationsExpressionPassthrough:
"""Test that aggregate functions correctly pass through SQL expressions without escaping."""

def test_cast_expression(self, duckdb_cursor):
# Cast expressions should pass through without being quoted
rel = duckdb_cursor.sql("select 1 as v, 0 as f")
result = rel.bool_and("v::BOOL").fetchall()
assert result == [(True,)]

result = rel.bool_or("f::BOOL").fetchall()
assert result == [(False,)]

def test_star_expression(self, duckdb_cursor):
# Star (*) should pass through for count
rel = duckdb_cursor.sql("select 1 as a union all select 2")
result = rel.count("*").fetchall()
assert result == [(2,)]

def test_arithmetic_expression(self, duckdb_cursor):
# Arithmetic expressions should pass through
rel = duckdb_cursor.sql("select 10 as a, 5 as b")
result = rel.sum("a + b").fetchall()
assert result == [(15,)]

def test_function_expression(self, duckdb_cursor):
# Function calls should pass through
rel = duckdb_cursor.sql("select -5 as v")
result = rel.sum("abs(v)").fetchall()
assert result == [(5,)]

def test_case_expression(self, duckdb_cursor):
# CASE expressions should pass through
rel = duckdb_cursor.sql("select 1 as v union all select 2 union all select 3")
result = rel.sum("case when v > 1 then v else 0 end").fetchall()
assert result == [(5,)]


class TestRAPIAggregationsWithInvalidInput:
"""Test that only expression can be used."""

def test_injection_with_semicolon_is_neutralized(self, duckdb_cursor):
# Semicolon injection fails to parse as expression, gets quoted as identifier
rel = duckdb_cursor.sql("select 1 as v")
with pytest.raises(duckdb.BinderException, match="not found in FROM clause"):
rel.sum("v; drop table agg; --").fetchall()

def test_injection_with_union_is_neutralized(self, duckdb_cursor):
# UNION fails to parse as single expression, gets quoted
rel = duckdb_cursor.sql("select 1 as v")
with pytest.raises(duckdb.BinderException, match="not found in FROM clause"):
rel.sum("v union select * from agg").fetchall()

def test_subquery_is_contained(self, duckdb_cursor):
# Subqueries are valid expressions - they're contained within the aggregate
# and cannot break out of the expression context
rel = duckdb_cursor.sql("select 1 as v")
# This executes sum((select 1)) = sum(1) = 1 - contained, not an injection
result = rel.sum("(select 1)").fetchall()
assert result == [(1,)]

def test_injection_closing_paren_is_neutralized(self, duckdb_cursor):
# Adding a closing paren fails to parse, gets quoted
rel = duckdb_cursor.sql("select 1 as v")
with pytest.raises(duckdb.BinderException, match="not found in FROM clause"):
rel.sum("v) from agg; drop table agg; --").fetchall()

def test_comment_is_harmless(self, duckdb_cursor):
# SQL comments are stripped during parsing, so "v -- comment" parses as just "v"
rel = duckdb_cursor.sql("select 1 as v")
result = rel.sum("v -- this is ignored").fetchall()
assert result == [(1,)]

def test_empty_expression_rejected(self, duckdb_cursor):
# Empty or whitespace-only expressions should be rejected
rel = duckdb_cursor.sql("select 1 as v")
with pytest.raises(duckdb.ParserException):
rel.sum("").fetchall()

def test_whitespace_only_expression_rejected(self, duckdb_cursor):
# Whitespace-only expressions should be rejected
rel = duckdb_cursor.sql("select 1 as v")
with pytest.raises(duckdb.ParserException):
rel.sum(" ").fetchall()


class TestRAPIStringAggSeparatorEscaping:
"""Test that string_agg separator is properly escaped as a string literal."""

def test_simple_separator(self, duckdb_cursor):
rel = duckdb_cursor.sql("select 'a' as s union all select 'b' union all select 'c'")
result = rel.string_agg("s", ",").fetchall()
assert result == [("a,b,c",)]

def test_separator_with_single_quote(self, duckdb_cursor):
# Separator containing a single quote should be properly escaped
rel = duckdb_cursor.sql("select 'a' as s union all select 'b'")
result = rel.string_agg("s", "','").fetchall()
assert result == [("a','b",)]

def test_separator_with_special_chars(self, duckdb_cursor):
rel = duckdb_cursor.sql("select 'x' as s union all select 'y'")
result = rel.string_agg("s", " | ").fetchall()
assert result == [("x | y",)]

def test_separator_injection_attempt(self, duckdb_cursor):
# Attempt to inject via separator - should be safely quoted as string literal
rel = duckdb_cursor.sql("select 'a' as s union all select 'b'")
# This should NOT execute the injection - separator becomes a literal string
result = rel.string_agg("s", "'); drop table agg; --").fetchall()
assert result == [("a'); drop table agg; --b",)]
Loading