diff --git a/CHANGELOG.md b/CHANGELOG.md index aeb30623..440a6e91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- **MariaDB dialect** (`--dialect mariadb`): New SQL dialect extending MySQL with support for SEQUENCE DDL (`CREATE/DROP/ALTER SEQUENCE` with full option set), temporal tables (`FOR SYSTEM_TIME`, `WITH SYSTEM VERSIONING`, `PERIOD FOR`), and `CONNECT BY` hierarchical queries with `PRIOR`, `START WITH`, and `NOCYCLE` + ## [1.13.0] - 2026-03-20 ### Added diff --git a/docs/SQL_COMPATIBILITY.md b/docs/SQL_COMPATIBILITY.md index 8a4ec57f..f62dcbf7 100644 --- a/docs/SQL_COMPATIBILITY.md +++ b/docs/SQL_COMPATIBILITY.md @@ -317,6 +317,27 @@ This matrix documents the comprehensive SQL feature support in GoSQLX across dif | **AUTO_INCREMENT** | ✅ Full | ✅ Full | 95% | Column property | | **Backtick identifiers** | ✅ Full | ✅ Full | 100% | `` `table`.`column` `` syntax | +### MariaDB-Specific Features (v1.14.0+) + +MariaDB inherits all MySQL features (SHOW, DESCRIBE, REPLACE INTO, ON DUPLICATE KEY UPDATE, GROUP_CONCAT, MATCH/AGAINST, REGEXP/RLIKE, backtick identifiers, etc.) and adds the following extensions: + +| Feature | Support Level | GoSQLX Parser | Test Coverage | Notes | +|---------|---------------|---------------|---------------|-------| +| **CREATE SEQUENCE** | ✅ Full | ✅ Full | 95% | Full DDL with all sequence options | +| **DROP SEQUENCE** | ✅ Full | ✅ Full | 95% | DROP SEQUENCE [IF EXISTS] | +| **ALTER SEQUENCE** | ✅ Full | ✅ Full | 90% | RESTART, RESTART WITH, and all options | +| **Sequence options** | ✅ Full | ✅ Full | 95% | START WITH, INCREMENT BY, MINVALUE, MAXVALUE, CACHE, CYCLE, NOCACHE, NOCYCLE, RESTART, RESTART WITH | +| **FOR SYSTEM_TIME AS OF** | ✅ Full | ✅ Full | 95% | Point-in-time query on system-versioned tables | +| **FOR SYSTEM_TIME BETWEEN** | ✅ Full | ✅ Full | 95% | Range query on system-versioned tables | +| **FOR SYSTEM_TIME FROM/TO** | ✅ Full | ✅ Full | 95% | Range query (inclusive/exclusive) | +| **FOR SYSTEM_TIME ALL** | ✅ Full | ✅ Full | 95% | All rows including historical | +| **WITH SYSTEM VERSIONING** | ✅ Full | ✅ Full | 90% | CREATE TABLE ... WITH SYSTEM VERSIONING | +| **PERIOD FOR** | ✅ Full | ✅ Full | 85% | Application-time period definitions | +| **CONNECT BY** | ✅ Full | ✅ Full | 90% | Hierarchical queries with PRIOR and NOCYCLE | +| **START WITH (CONNECT BY)** | ✅ Full | ✅ Full | 90% | Root condition for hierarchical traversal | +| **PRIOR operator** | ✅ Full | ✅ Full | 90% | Reference parent row in CONNECT BY | +| **NOCYCLE** | ✅ Full | ✅ Full | 85% | Prevent infinite loops in cyclic graphs | + ### SQL Server-Specific Features | Feature | Support Level | GoSQLX Parser | Test Coverage | Notes | @@ -549,6 +570,7 @@ GoSQLX v1.8.0 introduces a first-class dialect mode engine that threads the SQL | **SQLite** | `"sqlite"` | SQLite keywords | Flexible typing, simplified syntax | ⚠️ Keywords + basic parsing | | **Snowflake** | `"snowflake"` | Snowflake keywords | Stage operations, VARIANT type | ⚠️ Keyword detection only | | **ClickHouse** | `"clickhouse"` | ClickHouse keywords | PREWHERE, FINAL, GLOBAL IN/NOT IN, MergeTree keywords | ✅ v1.13.0 | +| **MariaDB** | `"mariadb"` | MariaDB keywords (superset of MySQL) | All MySQL features + SEQUENCE DDL, FOR SYSTEM_TIME, WITH SYSTEM VERSIONING, PERIOD FOR, CONNECT BY | ✅ v1.14.0 | ### Usage @@ -597,6 +619,12 @@ gosqlx format --dialect mysql query.sql - No Snowflake-specific parsing (stages, COPY INTO, VARIANT operations) - QUALIFY clause not supported +#### MariaDB +- Inherits all MySQL known gaps (stored procedures, HANDLER, XA transactions, CREATE EVENT) +- JSON_TABLE not supported +- Spider storage engine syntax not parsed +- ColumnStore-specific syntax not supported + #### ClickHouse - PREWHERE clause for pre-filter optimization before primary key scan - FINAL modifier on table references (forces MergeTree part merge) @@ -644,6 +672,7 @@ gosqlx format --dialect mysql query.sql | **SQL Server** | 85% | 65% | ⭐⭐⭐⭐ Very Good | Keywords + MERGE | | **Oracle** | 80% | 60% | ⭐⭐⭐⭐ Good | Keywords + basic features | | **SQLite** | 85% | 50% | ⭐⭐⭐⭐ Good | Keywords + basic features | +| **MariaDB** | 95% | 90% | ⭐⭐⭐⭐⭐ Excellent | MySQL superset + SEQUENCE DDL, temporal tables, CONNECT BY (v1.14.0) | | **Snowflake** | 80% | 30% | ⭐⭐⭐ Good | Keyword detection only | ## Performance Characteristics by Feature diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index 4284bdca..24b36a3c 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -228,6 +228,9 @@ type TableReference struct { Lateral bool // LATERAL keyword for correlated subqueries (PostgreSQL) TableHints []string // SQL Server table hints: WITH (NOLOCK), WITH (ROWLOCK, UPDLOCK), etc. Final bool // ClickHouse FINAL modifier: forces MergeTree part merge + // ForSystemTime is the MariaDB temporal table clause (10.3.4+). + // Example: SELECT * FROM t FOR SYSTEM_TIME AS OF '2024-01-01' + ForSystemTime *ForSystemTimeClause // MariaDB temporal query } func (t *TableReference) statementNode() {} @@ -397,13 +400,19 @@ type SelectStatement struct { Where Expression GroupBy []Expression Having Expression - Windows []WindowSpec - OrderBy []OrderByExpression - Limit *int - Offset *int - Fetch *FetchClause // SQL-99 FETCH FIRST/NEXT clause (F861, F862) - For *ForClause // Row-level locking clause (SQL:2003, PostgreSQL, MySQL) - Pos models.Location // Source position of the SELECT keyword (1-based line and column) + // StartWith is the optional seed condition for CONNECT BY (MariaDB 10.2+). + // Example: START WITH parent_id IS NULL + StartWith Expression // MariaDB hierarchical query seed + // ConnectBy holds the hierarchy traversal condition (MariaDB 10.2+). + // Example: CONNECT BY PRIOR id = parent_id + ConnectBy *ConnectByClause // MariaDB hierarchical query + Windows []WindowSpec + OrderBy []OrderByExpression + Limit *int + Offset *int + Fetch *FetchClause // SQL-99 FETCH FIRST/NEXT clause (F861, F862) + For *ForClause // Row-level locking clause (SQL:2003, PostgreSQL, MySQL) + Pos models.Location // Source position of the SELECT keyword (1-based line and column) } // TopClause represents SQL Server's TOP N [PERCENT] clause @@ -518,6 +527,12 @@ func (s SelectStatement) Children() []Node { if s.For != nil { children = append(children, s.For) } + if s.StartWith != nil { + children = append(children, s.StartWith) + } + if s.ConnectBy != nil { + children = append(children, s.ConnectBy) + } return children } @@ -1275,6 +1290,14 @@ type CreateTableStatement struct { Partitions []PartitionDefinition // Individual partition definitions Options []TableOption WithoutRowID bool // SQLite: CREATE TABLE ... WITHOUT ROWID + + // WithSystemVersioning enables system-versioned temporal history (MariaDB 10.3.4+). + // Example: CREATE TABLE t (...) WITH SYSTEM VERSIONING + WithSystemVersioning bool + + // PeriodDefinitions holds PERIOD FOR clauses for application-time or system-time periods. + // Example: PERIOD FOR app_time (start_col, end_col) + PeriodDefinitions []*PeriodDefinition } func (c *CreateTableStatement) statementNode() {} @@ -1815,3 +1838,187 @@ func (r ReplaceStatement) Children() []Node { } return children } + +// ── MariaDB SEQUENCE DDL (10.3+) ─────────────────────────────────────────── + +// CycleOption represents the CYCLE behavior for a sequence. +type CycleOption int + +const ( + // CycleUnspecified means no CYCLE or NOCYCLE clause was given (database default applies). + CycleUnspecified CycleOption = iota + // CycleBehavior means CYCLE — sequence wraps around when it reaches min/max. + CycleBehavior + // NoCycleBehavior means NOCYCLE / NO CYCLE — sequence errors on overflow. + NoCycleBehavior +) + +// SequenceOptions holds configuration for CREATE SEQUENCE and ALTER SEQUENCE. +// Fields are pointers so that unspecified options are distinguishable from zero values. +type SequenceOptions struct { + StartWith *LiteralValue // START WITH n + IncrementBy *LiteralValue // INCREMENT BY n (default 1) + MinValue *LiteralValue // MINVALUE n or nil when NO MINVALUE + MaxValue *LiteralValue // MAXVALUE n or nil when NO MAXVALUE + Cache *LiteralValue // CACHE n or nil when NO CACHE / NOCACHE + CycleMode CycleOption // CYCLE / NOCYCLE / NO CYCLE (CycleUnspecified if not specified) + NoCache bool // NOCACHE (explicit; Cache=nil alone is ambiguous) + Restart bool // bare RESTART (reset to start value) + RestartWith *LiteralValue // RESTART WITH n (explicit restart value) +} + +// CreateSequenceStatement represents: +// +// CREATE [OR REPLACE] SEQUENCE [IF NOT EXISTS] name [options...] +type CreateSequenceStatement struct { + Name *Identifier + OrReplace bool + IfNotExists bool + Options SequenceOptions + Pos models.Location // Source position of the CREATE keyword (1-based line and column) +} + +func (s *CreateSequenceStatement) statementNode() {} +func (s *CreateSequenceStatement) TokenLiteral() string { return "CREATE" } +func (s *CreateSequenceStatement) Children() []Node { + if s.Name != nil { + return []Node{s.Name} + } + return nil +} + +// DropSequenceStatement represents: +// +// DROP SEQUENCE [IF EXISTS | IF NOT EXISTS] name +type DropSequenceStatement struct { + Name *Identifier + IfExists bool + Pos models.Location // Source position of the DROP keyword (1-based line and column) +} + +func (s *DropSequenceStatement) statementNode() {} +func (s *DropSequenceStatement) TokenLiteral() string { return "DROP" } +func (s *DropSequenceStatement) Children() []Node { + if s.Name != nil { + return []Node{s.Name} + } + return nil +} + +// AlterSequenceStatement represents: +// +// ALTER SEQUENCE [IF EXISTS] name [options...] +type AlterSequenceStatement struct { + Name *Identifier + IfExists bool + Options SequenceOptions + Pos models.Location // Source position of the ALTER keyword (1-based line and column) +} + +func (s *AlterSequenceStatement) statementNode() {} +func (s *AlterSequenceStatement) TokenLiteral() string { return "ALTER" } +func (s *AlterSequenceStatement) Children() []Node { + if s.Name != nil { + return []Node{s.Name} + } + return nil +} + +// ── MariaDB Temporal Table Types (10.3.4+) ──────────────────────────────── + +// SystemTimeClauseType identifies the kind of FOR SYSTEM_TIME clause. +type SystemTimeClauseType int + +const ( + SystemTimeAsOf SystemTimeClauseType = iota // FOR SYSTEM_TIME AS OF + SystemTimeBetween // FOR SYSTEM_TIME BETWEEN AND + SystemTimeFromTo // FOR SYSTEM_TIME FROM TO + SystemTimeAll // FOR SYSTEM_TIME ALL +) + +// ForSystemTimeClause represents a temporal query on a system-versioned table. +// +// SELECT * FROM t FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-01'; +// SELECT * FROM t FOR SYSTEM_TIME BETWEEN '2020-01-01' AND '2024-01-01'; +// SELECT * FROM t FOR SYSTEM_TIME ALL; +type ForSystemTimeClause struct { + Type SystemTimeClauseType + Point Expression // used for AS OF + Start Expression // used for BETWEEN, FROM + End Expression // used for BETWEEN (AND), TO + Pos models.Location // Source position of the FOR keyword (1-based line and column) +} + +// expressionNode satisfies the Expression interface so ForSystemTimeClause can be +// stored in TableReference.ForSystemTime without a separate interface type. +// Semantically it is a table-level clause, not a scalar expression. +func (c *ForSystemTimeClause) expressionNode() {} +func (c ForSystemTimeClause) TokenLiteral() string { return "FOR SYSTEM_TIME" } +func (c ForSystemTimeClause) Children() []Node { + var nodes []Node + if c.Point != nil { + nodes = append(nodes, c.Point) + } + if c.Start != nil { + nodes = append(nodes, c.Start) + } + if c.End != nil { + nodes = append(nodes, c.End) + } + return nodes +} + +// PeriodDefinition represents a PERIOD FOR clause in CREATE TABLE. +// +// PERIOD FOR app_time (start_col, end_col) +// PERIOD FOR SYSTEM_TIME (row_start, row_end) +type PeriodDefinition struct { + Name *Identifier // period name (e.g., "app_time") or SYSTEM_TIME + StartCol *Identifier + EndCol *Identifier + Pos models.Location // Source position of the PERIOD FOR keyword (1-based line and column) +} + +// expressionNode satisfies the Expression interface so PeriodDefinition can be +// stored in CreateTableStatement.PeriodDefinitions without a separate interface type. +// Semantically it is a table column constraint, not a scalar expression. +func (p *PeriodDefinition) expressionNode() {} +func (p PeriodDefinition) TokenLiteral() string { return "PERIOD FOR" } +func (p PeriodDefinition) Children() []Node { + var nodes []Node + if p.Name != nil { + nodes = append(nodes, p.Name) + } + if p.StartCol != nil { + nodes = append(nodes, p.StartCol) + } + if p.EndCol != nil { + nodes = append(nodes, p.EndCol) + } + return nodes +} + +// ── MariaDB Hierarchical Query / CONNECT BY (10.2+) ─────────────────────── + +// ConnectByClause represents the CONNECT BY hierarchical query clause (MariaDB 10.2+). +// +// SELECT id, name FROM t +// START WITH parent_id IS NULL +// CONNECT BY NOCYCLE PRIOR id = parent_id; +type ConnectByClause struct { + NoCycle bool // NOCYCLE modifier — prevents loops in cyclic graphs + Condition Expression // the PRIOR expression (e.g., PRIOR id = parent_id) + Pos models.Location // Source position of the CONNECT BY keyword (1-based line and column) +} + +// expressionNode satisfies the Expression interface so ConnectByClause can be +// stored in SelectStatement.ConnectBy without a separate interface type. +// Semantically it is a query-level clause, not a scalar expression. +func (c *ConnectByClause) expressionNode() {} +func (c ConnectByClause) TokenLiteral() string { return "CONNECT BY" } +func (c ConnectByClause) Children() []Node { + if c.Condition != nil { + return []Node{c.Condition} + } + return nil +} diff --git a/pkg/sql/ast/ast_sequence_test.go b/pkg/sql/ast/ast_sequence_test.go new file mode 100644 index 00000000..9c46889f --- /dev/null +++ b/pkg/sql/ast/ast_sequence_test.go @@ -0,0 +1,239 @@ +package ast_test + +import ( + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" +) + +func TestCreateSequenceStatement_ToSQL(t *testing.T) { + tests := []struct { + name string + stmt *ast.CreateSequenceStatement + want string + }{ + { + name: "minimal", + stmt: &ast.CreateSequenceStatement{ + Name: &ast.Identifier{Name: "seq_orders"}, + }, + want: "CREATE SEQUENCE seq_orders", + }, + { + name: "or replace", + stmt: &ast.CreateSequenceStatement{ + Name: &ast.Identifier{Name: "seq_orders"}, + OrReplace: true, + }, + want: "CREATE OR REPLACE SEQUENCE seq_orders", + }, + { + name: "if not exists", + stmt: &ast.CreateSequenceStatement{ + Name: &ast.Identifier{Name: "seq_orders"}, + IfNotExists: true, + }, + want: "CREATE SEQUENCE IF NOT EXISTS seq_orders", + }, + { + name: "with options", + stmt: &ast.CreateSequenceStatement{ + Name: &ast.Identifier{Name: "s"}, + Options: ast.SequenceOptions{ + StartWith: &ast.LiteralValue{Value: "1"}, + IncrementBy: &ast.LiteralValue{Value: "1"}, + MinValue: &ast.LiteralValue{Value: "1"}, + MaxValue: &ast.LiteralValue{Value: "9999"}, + Cache: &ast.LiteralValue{Value: "100"}, + CycleMode: ast.CycleBehavior, + }, + }, + want: "CREATE SEQUENCE s START WITH 1 INCREMENT BY 1 MINVALUE 1 MAXVALUE 9999 CACHE 100 CYCLE", + }, + { + name: "nocycle", + stmt: &ast.CreateSequenceStatement{ + Name: &ast.Identifier{Name: "s"}, + Options: ast.SequenceOptions{CycleMode: ast.NoCycleBehavior}, + }, + want: "CREATE SEQUENCE s NOCYCLE", + }, + { + name: "nocache", + stmt: &ast.CreateSequenceStatement{ + Name: &ast.Identifier{Name: "s"}, + Options: ast.SequenceOptions{NoCache: true}, + }, + want: "CREATE SEQUENCE s NOCACHE", + }, + { + name: "nil name does not panic", + stmt: &ast.CreateSequenceStatement{}, + want: "CREATE SEQUENCE ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.stmt.ToSQL() + if got != tt.want { + t.Errorf("ToSQL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestDropSequenceStatement_ToSQL(t *testing.T) { + tests := []struct { + name string + stmt *ast.DropSequenceStatement + want string + }{ + { + name: "basic", + stmt: &ast.DropSequenceStatement{Name: &ast.Identifier{Name: "seq_orders"}}, + want: "DROP SEQUENCE seq_orders", + }, + { + name: "if exists", + stmt: &ast.DropSequenceStatement{Name: &ast.Identifier{Name: "seq_orders"}, IfExists: true}, + want: "DROP SEQUENCE IF EXISTS seq_orders", + }, + { + name: "nil name does not panic", + stmt: &ast.DropSequenceStatement{}, + want: "DROP SEQUENCE ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.stmt.ToSQL() + if got != tt.want { + t.Errorf("ToSQL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestAlterSequenceStatement_ToSQL(t *testing.T) { + tests := []struct { + name string + stmt *ast.AlterSequenceStatement + want string + }{ + { + name: "restart bare", + stmt: &ast.AlterSequenceStatement{ + Name: &ast.Identifier{Name: "s"}, + Options: ast.SequenceOptions{Restart: true}, + }, + want: "ALTER SEQUENCE s RESTART", + }, + { + name: "restart with value", + stmt: &ast.AlterSequenceStatement{ + Name: &ast.Identifier{Name: "s"}, + Options: ast.SequenceOptions{ + RestartWith: &ast.LiteralValue{Value: "1"}, + }, + }, + want: "ALTER SEQUENCE s RESTART WITH 1", + }, + { + name: "nil name does not panic", + stmt: &ast.AlterSequenceStatement{}, + want: "ALTER SEQUENCE ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.stmt.ToSQL() + if got != tt.want { + t.Errorf("ToSQL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestSequencePool_RoundTrip(t *testing.T) { + s := ast.NewCreateSequenceStatement() + if s == nil { + t.Fatal("NewCreateSequenceStatement() returned nil") + } + s.Name = &ast.Identifier{Name: "test"} + ast.ReleaseCreateSequenceStatement(s) + + s2 := ast.NewCreateSequenceStatement() + if s2 == nil { + t.Fatal("second NewCreateSequenceStatement() returned nil") + } + if s2.Name != nil { + t.Error("expected Name to be nil after release (pool zero-reset)") + } + ast.ReleaseCreateSequenceStatement(s2) +} + +func TestSelectStatement_ConnectBy_SQLOrder(t *testing.T) { + limit := 10 + stmt := &ast.SelectStatement{ + Columns: []ast.Expression{&ast.Identifier{Name: "*"}}, + From: []ast.TableReference{ + {Name: "employees"}, + }, + StartWith: &ast.BinaryExpression{ + Left: &ast.Identifier{Name: "parent_id"}, + Operator: "IS", + Right: &ast.Identifier{Name: "NULL"}, + }, + ConnectBy: &ast.ConnectByClause{ + NoCycle: true, + Condition: &ast.BinaryExpression{ + Left: &ast.UnaryExpression{Operator: ast.Prior, Expr: &ast.Identifier{Name: "id"}}, + Operator: "=", + Right: &ast.Identifier{Name: "parent_id"}, + }, + }, + OrderBy: []ast.OrderByExpression{ + {Expression: &ast.Identifier{Name: "id"}}, + }, + Limit: &limit, + } + got := stmt.SQL() + startIdx := strings.Index(got, "START WITH") + orderIdx := strings.Index(got, "ORDER BY") + if startIdx == -1 { + t.Fatal("SQL() missing START WITH") + } + if orderIdx == -1 { + t.Fatal("SQL() missing ORDER BY") + } + if startIdx > orderIdx { + t.Errorf("START WITH appears after ORDER BY in SQL():\n %s", got) + } +} + +func TestPeriodDefinition_SQL(t *testing.T) { + pd := &ast.PeriodDefinition{ + Name: &ast.Identifier{Name: "app_time"}, + StartCol: &ast.Identifier{Name: "valid_from"}, + EndCol: &ast.Identifier{Name: "valid_to"}, + } + got := pd.SQL() + want := "PERIOD FOR app_time (valid_from, valid_to)" + if got != want { + t.Errorf("PeriodDefinition.SQL() = %q, want %q", got, want) + } +} + +func TestPeriodDefinition_SQL_SystemTime(t *testing.T) { + pd := &ast.PeriodDefinition{ + Name: &ast.Identifier{Name: "SYSTEM_TIME"}, + StartCol: &ast.Identifier{Name: "row_start"}, + EndCol: &ast.Identifier{Name: "row_end"}, + } + got := pd.SQL() + want := "PERIOD FOR SYSTEM_TIME (row_start, row_end)" + if got != want { + t.Errorf("PeriodDefinition.SQL() = %q, want %q", got, want) + } +} diff --git a/pkg/sql/ast/operator.go b/pkg/sql/ast/operator.go index f6f5bb98..7d9a5d08 100644 --- a/pkg/sql/ast/operator.go +++ b/pkg/sql/ast/operator.go @@ -81,6 +81,8 @@ const ( PGAbs // BangNot represents Hive-specific logical NOT operator, e.g. ! false BangNot + // Prior represents MariaDB CONNECT BY parent reference operator, e.g. PRIOR id + Prior ) // String returns the string representation of the unary operator @@ -106,6 +108,8 @@ func (op UnaryOperator) String() string { return "@" case BangNot: return "!" + case Prior: + return "PRIOR" default: return "UNKNOWN" } diff --git a/pkg/sql/ast/pool.go b/pkg/sql/ast/pool.go index 4ee21729..f1411e1d 100644 --- a/pkg/sql/ast/pool.go +++ b/pkg/sql/ast/pool.go @@ -351,6 +351,18 @@ var ( return &s }, } + + createSequencePool = sync.Pool{ + New: func() interface{} { return &CreateSequenceStatement{} }, + } + + dropSequencePool = sync.Pool{ + New: func() interface{} { return &DropSequenceStatement{} }, + } + + alterSequencePool = sync.Pool{ + New: func() interface{} { return &AlterSequenceStatement{} }, + } ) // NewAST retrieves a new AST container from the pool. @@ -1794,3 +1806,37 @@ func PutAlterStatement(stmt *AlterStatement) { alterStmtPool.Put(stmt) } + +// NewCreateSequenceStatement retrieves a CreateSequenceStatement from the pool. +func NewCreateSequenceStatement() *CreateSequenceStatement { + return createSequencePool.Get().(*CreateSequenceStatement) +} + +// ReleaseCreateSequenceStatement returns a CreateSequenceStatement to the pool. +func ReleaseCreateSequenceStatement(s *CreateSequenceStatement) { + *s = CreateSequenceStatement{} // zero all fields + createSequencePool.Put(s) +} + +// NewDropSequenceStatement retrieves a DropSequenceStatement from the pool. +func NewDropSequenceStatement() *DropSequenceStatement { + return dropSequencePool.Get().(*DropSequenceStatement) +} + +// ReleaseDropSequenceStatement returns a DropSequenceStatement to the pool. +// Always call this with defer after parsing is complete. +func ReleaseDropSequenceStatement(s *DropSequenceStatement) { + *s = DropSequenceStatement{} // zero all fields + dropSequencePool.Put(s) +} + +// NewAlterSequenceStatement retrieves an AlterSequenceStatement from the pool. +func NewAlterSequenceStatement() *AlterSequenceStatement { + return alterSequencePool.Get().(*AlterSequenceStatement) +} + +// ReleaseAlterSequenceStatement returns an AlterSequenceStatement to the pool. +func ReleaseAlterSequenceStatement(s *AlterSequenceStatement) { + *s = AlterSequenceStatement{} // zero all fields + alterSequencePool.Put(s) +} diff --git a/pkg/sql/ast/sql.go b/pkg/sql/ast/sql.go index 4236d106..eea64033 100644 --- a/pkg/sql/ast/sql.go +++ b/pkg/sql/ast/sql.go @@ -202,6 +202,8 @@ func (u *UnaryExpression) SQL() string { return "+" + inner case Minus: return "-" + inner + case Prior: + return "PRIOR " + inner default: return u.Operator.String() + inner } @@ -580,6 +582,17 @@ func (s *SelectStatement) SQL() string { sb.WriteString(exprSQL(s.Having)) } + // MariaDB hierarchical query clauses (10.2+): START WITH ... CONNECT BY ... + // These must appear after HAVING and before ORDER BY per MariaDB grammar. + if s.StartWith != nil { + sb.WriteString(" START WITH ") + sb.WriteString(exprSQL(s.StartWith)) + } + if s.ConnectBy != nil { + sb.WriteString(" ") + sb.WriteString(s.ConnectBy.ToSQL()) + } + if len(s.Windows) > 0 { sb.WriteString(" WINDOW ") wins := make([]string, len(s.Windows)) @@ -1306,6 +1319,10 @@ func tableRefSQL(t *TableReference) string { if t.Final { sb.WriteString(" FINAL") } + if t.ForSystemTime != nil { + sb.WriteString(" ") + sb.WriteString(t.ForSystemTime.ToSQL()) + } return sb.String() } @@ -1585,3 +1602,151 @@ func mergeActionSQL(a *MergeAction) string { return a.ActionType } } + +// ToSQL returns the SQL string for CREATE SEQUENCE. +func (s *CreateSequenceStatement) ToSQL() string { + var b strings.Builder + b.WriteString("CREATE ") + if s.OrReplace { + b.WriteString("OR REPLACE ") + } + b.WriteString("SEQUENCE ") + if s.IfNotExists { + b.WriteString("IF NOT EXISTS ") + } + if s.Name != nil { + b.WriteString(s.Name.Name) + } + writeSequenceOptions(&b, s.Options) + return b.String() +} + +// ToSQL returns the SQL string for DROP SEQUENCE. +func (s *DropSequenceStatement) ToSQL() string { + var b strings.Builder + b.WriteString("DROP SEQUENCE ") + if s.IfExists { + b.WriteString("IF EXISTS ") + } + if s.Name != nil { + b.WriteString(s.Name.Name) + } + return b.String() +} + +// ToSQL returns the SQL string for ALTER SEQUENCE. +func (s *AlterSequenceStatement) ToSQL() string { + var b strings.Builder + b.WriteString("ALTER SEQUENCE ") + if s.IfExists { + b.WriteString("IF EXISTS ") + } + if s.Name != nil { + b.WriteString(s.Name.Name) + } + writeSequenceOptions(&b, s.Options) + return b.String() +} + +// writeSequenceOptions is a shared helper for CREATE/ALTER SEQUENCE serialization. +func writeSequenceOptions(b *strings.Builder, opts SequenceOptions) { + if opts.StartWith != nil { + b.WriteString(" START WITH ") + b.WriteString(opts.StartWith.TokenLiteral()) + } + if opts.IncrementBy != nil { + b.WriteString(" INCREMENT BY ") + b.WriteString(opts.IncrementBy.TokenLiteral()) + } + if opts.MinValue != nil { + b.WriteString(" MINVALUE ") + b.WriteString(opts.MinValue.TokenLiteral()) + } + if opts.MaxValue != nil { + b.WriteString(" MAXVALUE ") + b.WriteString(opts.MaxValue.TokenLiteral()) + } + if opts.Cache != nil { + b.WriteString(" CACHE ") + b.WriteString(opts.Cache.TokenLiteral()) + } else if opts.NoCache { + b.WriteString(" NOCACHE") + } + switch opts.CycleMode { + case CycleBehavior: + b.WriteString(" CYCLE") + case NoCycleBehavior: + b.WriteString(" NOCYCLE") + } + if opts.RestartWith != nil { + b.WriteString(" RESTART WITH ") + b.WriteString(opts.RestartWith.TokenLiteral()) + } else if opts.Restart { + b.WriteString(" RESTART") + } +} + +// SQL implements the Expression interface for ForSystemTimeClause. +func (c *ForSystemTimeClause) SQL() string { return c.ToSQL() } + +// ToSQL returns the SQL string for a FOR SYSTEM_TIME clause (MariaDB 10.3.4+). +func (c *ForSystemTimeClause) ToSQL() string { + var b strings.Builder + b.WriteString("FOR SYSTEM_TIME ") + switch c.Type { + case SystemTimeAsOf: + b.WriteString("AS OF ") + b.WriteString(exprSQL(c.Point)) + case SystemTimeBetween: + b.WriteString("BETWEEN ") + b.WriteString(exprSQL(c.Start)) + b.WriteString(" AND ") + b.WriteString(exprSQL(c.End)) + case SystemTimeFromTo: + b.WriteString("FROM ") + b.WriteString(exprSQL(c.Start)) + b.WriteString(" TO ") + b.WriteString(exprSQL(c.End)) + case SystemTimeAll: + b.WriteString("ALL") + } + return b.String() +} + +// SQL implements the Expression interface for ConnectByClause. +func (c *ConnectByClause) SQL() string { return c.ToSQL() } + +// SQL implements the Expression interface for PeriodDefinition (stub; not used as a standalone expression). +// SQL returns the SQL string for a PERIOD FOR clause in CREATE TABLE. +// Example: PERIOD FOR app_time (valid_from, valid_to) +func (p *PeriodDefinition) SQL() string { + if p == nil { + return "" + } + var b strings.Builder + b.WriteString("PERIOD FOR ") + if p.Name != nil { + b.WriteString(p.Name.Name) + } + b.WriteString(" (") + if p.StartCol != nil { + b.WriteString(p.StartCol.Name) + } + b.WriteString(", ") + if p.EndCol != nil { + b.WriteString(p.EndCol.Name) + } + b.WriteString(")") + return b.String() +} + +// ToSQL returns the SQL string for a CONNECT BY clause (MariaDB 10.2+). +func (c *ConnectByClause) ToSQL() string { + var b strings.Builder + b.WriteString("CONNECT BY ") + if c.NoCycle { + b.WriteString("NOCYCLE ") + } + b.WriteString(exprSQL(c.Condition)) + return b.String() +} diff --git a/pkg/sql/keywords/detect.go b/pkg/sql/keywords/detect.go index 8e6b9506..5b48c283 100644 --- a/pkg/sql/keywords/detect.go +++ b/pkg/sql/keywords/detect.go @@ -73,10 +73,22 @@ var dialectHints = []dialectHint{ // Oracle-specific (high confidence) {pattern: "ROWNUM", dialect: DialectOracle, weight: 5}, - {pattern: "CONNECT BY", dialect: DialectOracle, weight: 5}, + {pattern: "CONNECT BY", dialect: DialectOracle, weight: 3}, {pattern: "SYSDATE", dialect: DialectOracle, weight: 5}, {pattern: "DECODE", dialect: DialectOracle, weight: 3}, + // MariaDB-specific (high confidence — these features don't appear in MySQL or Oracle) + {pattern: "NEXTVAL", dialect: DialectMariaDB, weight: 5}, + {pattern: "LASTVAL", dialect: DialectMariaDB, weight: 5}, + {pattern: "SETVAL", dialect: DialectMariaDB, weight: 5}, + {pattern: "NEXT VALUE FOR", dialect: DialectMariaDB, weight: 5}, + {pattern: "SYSTEM VERSIONING", dialect: DialectMariaDB, weight: 5}, + {pattern: "FOR SYSTEM_TIME", dialect: DialectMariaDB, weight: 5}, + {pattern: "VERSIONING", dialect: DialectMariaDB, weight: 4}, + {pattern: "CONNECT BY", dialect: DialectMariaDB, weight: 2}, + {pattern: "CREATE SEQUENCE", dialect: DialectMariaDB, weight: 5}, + {pattern: "DROP SEQUENCE", dialect: DialectMariaDB, weight: 5}, + // SQLite-specific (high confidence) {pattern: "AUTOINCREMENT", dialect: DialectSQLite, weight: 5}, {pattern: "GLOB", dialect: DialectSQLite, weight: 4}, @@ -98,6 +110,7 @@ var dialectHints = []dialectHint{ // - MySQL: ZEROFILL, UNSIGNED, AUTO_INCREMENT, FORCE INDEX // - SQL Server: NOLOCK, TOP, NVARCHAR, GETDATE // - Oracle: ROWNUM, CONNECT BY, SYSDATE, DECODE +// - MariaDB: NEXTVAL, LASTVAL, SETVAL, NEXT VALUE FOR, SYSTEM VERSIONING, FOR SYSTEM_TIME, VERSIONING, CONNECT BY, CREATE SEQUENCE, DROP SEQUENCE // - SQLite: AUTOINCREMENT, GLOB, VACUUM // // The function also performs syntactic checks for identifier quoting styles: @@ -113,6 +126,9 @@ var dialectHints = []dialectHint{ // dialect = keywords.DetectDialect("SELECT DISTINCT ON (dept) * FROM emp") // // dialect == DialectPostgreSQL // +// dialect = keywords.DetectDialect("SELECT NEXTVAL(seq_orders)") +// // dialect == DialectMariaDB +// // dialect = keywords.DetectDialect("SELECT * FROM users") // // dialect == DialectGeneric func DetectDialect(sql string) SQLDialect { diff --git a/pkg/sql/keywords/dialect.go b/pkg/sql/keywords/dialect.go index 934b40ae..6062f40b 100644 --- a/pkg/sql/keywords/dialect.go +++ b/pkg/sql/keywords/dialect.go @@ -64,6 +64,13 @@ const ( // definitions (ENGINE, CODEC, TTL), ClickHouse data types (FixedString, // LowCardinality, Nullable, DateTime64), and replication keywords (ON CLUSTER, GLOBAL). DialectClickHouse SQLDialect = "clickhouse" + + // DialectMariaDB represents MariaDB-specific keywords and extensions. + // MariaDB is a superset of MySQL; this dialect includes all MySQL keywords + // (UNSIGNED, ZEROFILL, ON DUPLICATE KEY UPDATE, etc.) plus MariaDB-specific + // features: SEQUENCE DDL (10.3+), system-versioned temporal tables (10.3.4+), + // CONNECT BY hierarchical queries (10.2+), and index visibility (10.6+). + DialectMariaDB SQLDialect = "mariadb" ) // DialectKeywords returns the additional keywords for a specific dialect. @@ -84,6 +91,11 @@ func DialectKeywords(dialect SQLDialect) []Keyword { return SNOWFLAKE_SPECIFIC case DialectMySQL: return MYSQL_SPECIFIC + case DialectMariaDB: + combined := make([]Keyword, 0, len(MYSQL_SPECIFIC)+len(MARIADB_SPECIFIC)) + combined = append(combined, MYSQL_SPECIFIC...) + combined = append(combined, MARIADB_SPECIFIC...) + return combined case DialectPostgreSQL: return POSTGRESQL_SPECIFIC case DialectSQLite: @@ -133,6 +145,7 @@ func AllDialects() []SQLDialect { DialectGeneric, DialectPostgreSQL, DialectMySQL, + DialectMariaDB, DialectSQLServer, DialectOracle, DialectSQLite, diff --git a/pkg/sql/keywords/keywords.go b/pkg/sql/keywords/keywords.go index 5e93dfe5..3f24d34e 100644 --- a/pkg/sql/keywords/keywords.go +++ b/pkg/sql/keywords/keywords.go @@ -265,6 +265,10 @@ func New(dialect SQLDialect, ignoreCase bool) *Keywords { switch dialect { case DialectMySQL: k.addKeywordsWithCategory(MYSQL_SPECIFIC) + case DialectMariaDB: + // MariaDB is a superset of MySQL — load MySQL base first, then MariaDB extras + k.addKeywordsWithCategory(MYSQL_SPECIFIC) + k.addKeywordsWithCategory(MARIADB_SPECIFIC) case DialectPostgreSQL: k.addKeywordsWithCategory(POSTGRESQL_SPECIFIC) case DialectSQLite: diff --git a/pkg/sql/keywords/mariadb.go b/pkg/sql/keywords/mariadb.go new file mode 100644 index 00000000..b2feea23 --- /dev/null +++ b/pkg/sql/keywords/mariadb.go @@ -0,0 +1,63 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package keywords + +import "github.com/ajitpratap0/GoSQLX/pkg/models" + +// MARIADB_SPECIFIC contains MariaDB-specific SQL keywords beyond the MySQL base. +// When DialectMariaDB is active, both MYSQL_SPECIFIC and MARIADB_SPECIFIC are loaded +// (MariaDB is a superset of MySQL). +// +// Features covered: +// - SEQUENCE DDL (MariaDB 10.3+): CREATE/DROP/ALTER SEQUENCE, NEXTVAL, LASTVAL, SETVAL +// - Temporal tables (MariaDB 10.3.4+): WITH SYSTEM VERSIONING, FOR SYSTEM_TIME, PERIOD FOR +// - Hierarchical queries (MariaDB 10.2+): CONNECT BY, START WITH, PRIOR, NOCYCLE +// - Index visibility (MariaDB 10.6+): INVISIBLE, VISIBLE modifiers +// +// Note: MAXVALUE is already in ADDITIONAL_KEYWORDS (base list, all dialects). +// Note: MINVALUE is already in ORACLE_SPECIFIC. Neither needs repeating here. +// Note: INCREMENT, RESTART, NOCACHE are already in ADDITIONAL_KEYWORDS. +var MARIADB_SPECIFIC = []Keyword{ + // ── SEQUENCE DDL (MariaDB 10.3+) ─────────────────────────────────────── + // CREATE SEQUENCE s START WITH 1 INCREMENT BY 1 MINVALUE 1 MAXVALUE 9999 CYCLE CACHE 100; + // SELECT NEXT VALUE FOR s; -- ANSI style + // SELECT NEXTVAL(s); -- MariaDB style + // MINVALUE/MAXVALUE/INCREMENT/RESTART/NOCACHE covered by base or Oracle lists. + {Word: "SEQUENCE", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + {Word: "NEXTVAL", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + {Word: "LASTVAL", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + {Word: "SETVAL", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + {Word: "NOCYCLE", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + + // ── Temporal tables / System versioning (MariaDB 10.3.4+) ───────────── + // CREATE TABLE t (...) WITH SYSTEM VERSIONING; + // SELECT * FROM t FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-01'; + // PERIOD FOR app_time (start_col, end_col) + {Word: "VERSIONING", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + {Word: "PERIOD", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + {Word: "OVERLAPS", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + // SYSTEM_TIME is reserved so it doesn't collide as a table alias + {Word: "SYSTEM_TIME", Type: models.TokenTypeKeyword, Reserved: true, ReservedForTableAlias: true}, + + // ── Hierarchical queries / CONNECT BY (MariaDB 10.2+) ────────────────── + // SELECT id FROM t START WITH parent_id IS NULL CONNECT BY PRIOR id = parent_id; + {Word: "PRIOR", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + + // ── Index visibility (MariaDB 10.6+) ──────────────────────────────────── + // CREATE INDEX idx ON t (col) INVISIBLE; + // ALTER TABLE t ALTER INDEX idx VISIBLE; + {Word: "INVISIBLE", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, + {Word: "VISIBLE", Type: models.TokenTypeKeyword, Reserved: false, ReservedForTableAlias: false}, +} diff --git a/pkg/sql/keywords/mariadb_test.go b/pkg/sql/keywords/mariadb_test.go new file mode 100644 index 00000000..ece9df1f --- /dev/null +++ b/pkg/sql/keywords/mariadb_test.go @@ -0,0 +1,124 @@ +package keywords_test + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" +) + +func TestDialectMariaDB_Constant(t *testing.T) { + if string(keywords.DialectMariaDB) != "mariadb" { + t.Fatalf("expected DialectMariaDB = \"mariadb\", got %q", keywords.DialectMariaDB) + } +} + +func TestDialectMariaDB_InAllDialects(t *testing.T) { + found := false + for _, d := range keywords.AllDialects() { + if d == keywords.DialectMariaDB { + found = true + break + } + } + if !found { + t.Error("DialectMariaDB not found in AllDialects()") + } +} + +func TestDialectMariaDB_IsValidDialect(t *testing.T) { + if !keywords.IsValidDialect("mariadb") { + t.Error("IsValidDialect(\"mariadb\") returned false") + } +} + +func TestDialectMariaDB_InheritsMySQL(t *testing.T) { + kw := keywords.New(keywords.DialectMariaDB, true) + for _, word := range []string{"UNSIGNED", "ZEROFILL", "DATETIME"} { + if !kw.IsKeyword(word) { + t.Errorf("expected MariaDB to inherit MySQL keyword %q", word) + } + } +} + +func TestMariaDBKeywords_Recognized(t *testing.T) { + kw := keywords.New(keywords.DialectMariaDB, true) + + mariadbOnly := []string{ + // Sequence DDL + "SEQUENCE", "NEXTVAL", "LASTVAL", "SETVAL", + // Temporal tables + "VERSIONING", "PERIOD", "OVERLAPS", + // Hierarchical queries + "PRIOR", "NOCYCLE", + // Index visibility + "INVISIBLE", "VISIBLE", + } + for _, word := range mariadbOnly { + if !kw.IsKeyword(word) { + t.Errorf("expected %q to be a keyword in DialectMariaDB", word) + } + } +} + +func TestMariaDBKeywords_InheritsMySQLKeywords(t *testing.T) { + kw := keywords.New(keywords.DialectMariaDB, true) + + // These are MySQL-specific keywords that MariaDB must also recognize + mysqlKeywords := []string{"UNSIGNED", "ZEROFILL", "KILL", "PURGE", "STATUS", "VARIABLES"} + for _, word := range mysqlKeywords { + if !kw.IsKeyword(word) { + t.Errorf("MariaDB dialect must inherit MySQL keyword %q", word) + } + } +} + +func TestMariaDBKeywords_NotRecognizedInMySQLDialect(t *testing.T) { + kw := keywords.New(keywords.DialectMySQL, true) + + mariadbOnlyKeywords := []string{"VERSIONING", "PRIOR", "NOCYCLE", "INVISIBLE"} + for _, word := range mariadbOnlyKeywords { + if kw.IsKeyword(word) { + t.Errorf("keyword %q should NOT be recognized in pure MySQL dialect", word) + } + } +} + +func TestDetectDialect_MariaDB(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "CREATE SEQUENCE", + sql: "CREATE SEQUENCE seq_orders START WITH 1 INCREMENT BY 1", + }, + { + name: "WITH SYSTEM VERSIONING", + sql: "CREATE TABLE orders (id INT) WITH SYSTEM VERSIONING", + }, + { + name: "FOR SYSTEM_TIME", + sql: "SELECT * FROM orders FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-01'", + }, + { + name: "DROP SEQUENCE", + sql: "DROP SEQUENCE seq_orders", + }, + { + name: "NEXTVAL", + sql: "SELECT NEXTVAL(seq_orders)", + }, + { + name: "CONNECT BY with NEXTVAL (MariaDB wins on accumulation)", + sql: "SELECT NEXTVAL(s) FROM t CONNECT BY PRIOR id = parent_id", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := keywords.DetectDialect(tt.sql) + if got != keywords.DialectMariaDB { + t.Errorf("DetectDialect(%q) = %q, want %q", tt.sql, got, keywords.DialectMariaDB) + } + }) + } +} diff --git a/pkg/sql/keywords/snowflake_test.go b/pkg/sql/keywords/snowflake_test.go index 05c779cb..b1d24935 100644 --- a/pkg/sql/keywords/snowflake_test.go +++ b/pkg/sql/keywords/snowflake_test.go @@ -466,6 +466,7 @@ func TestDialectRegistry(t *testing.T) { DialectGeneric: false, DialectPostgreSQL: false, DialectMySQL: false, + DialectMariaDB: false, DialectSQLServer: false, DialectOracle: false, DialectSQLite: false, diff --git a/pkg/sql/parser/ddl.go b/pkg/sql/parser/ddl.go index 06caa179..69698ebf 100644 --- a/pkg/sql/parser/ddl.go +++ b/pkg/sql/parser/ddl.go @@ -82,6 +82,17 @@ func (p *Parser) parseCreateStatement() (ast.Statement, error) { } p.advance() // Consume INDEX return p.parseCreateIndex(true) // Unique + } else if p.isMariaDB() && p.isTokenMatch("SEQUENCE") { + seqPos := p.currentLocation() // position of SEQUENCE token + p.advance() // Consume SEQUENCE + stmt, err := p.parseCreateSequenceStatement(orReplace) + if err != nil { + return nil, err + } + if stmt.Pos.IsZero() { + stmt.Pos = seqPos + } + return stmt, nil } return nil, p.expectedError("TABLE, VIEW, MATERIALIZED VIEW, or INDEX after CREATE") } @@ -121,9 +132,18 @@ func (p *Parser) parseCreateTable(temporary bool) (*ast.CreateTableStatement, er // Parse column definitions and constraints for { - // Check for table-level constraints - if p.isAnyType(models.TokenTypePrimary, models.TokenTypeForeign, + // MariaDB: PERIOD FOR name (start_col, end_col) — application-time or system-time period + if p.isMariaDB() && p.isTokenMatch("PERIOD") { + periodPos := p.currentLocation() // position of PERIOD keyword + pd, err := p.parsePeriodDefinition() + if err != nil { + return nil, err + } + pd.Pos = periodPos + stmt.PeriodDefinitions = append(stmt.PeriodDefinitions, pd) + } else if p.isAnyType(models.TokenTypePrimary, models.TokenTypeForeign, models.TokenTypeUnique, models.TokenTypeCheck, models.TokenTypeConstraint) { + // Check for table-level constraints constraint, err := p.parseTableConstraint() if err != nil { return nil, err @@ -152,6 +172,21 @@ func (p *Parser) parseCreateTable(temporary bool) (*ast.CreateTableStatement, er } p.advance() // Consume ) + // MariaDB: WITH SYSTEM VERSIONING — enables system-versioned temporal history + if p.isMariaDB() && p.isType(models.TokenTypeWith) { + // peek ahead to check for SYSTEM VERSIONING (not WITH TIES or WITH CHECK etc.) + next := p.peekToken() + if strings.EqualFold(next.Token.Value, "SYSTEM") { + p.advance() // Consume WITH + p.advance() // Consume SYSTEM + if !strings.EqualFold(p.currentToken.Token.Value, "VERSIONING") { + return nil, p.expectedError("VERSIONING after WITH SYSTEM") + } + p.advance() // Consume VERSIONING + stmt.WithSystemVersioning = true + } + } + // Parse optional PARTITION BY clause if p.isType(models.TokenTypePartition) { p.advance() // Consume PARTITION diff --git a/pkg/sql/parser/ddl_columns.go b/pkg/sql/parser/ddl_columns.go index f09c3dd8..5118d5bc 100644 --- a/pkg/sql/parser/ddl_columns.go +++ b/pkg/sql/parser/ddl_columns.go @@ -18,6 +18,8 @@ package parser import ( + "strings" + goerrors "github.com/ajitpratap0/GoSQLX/pkg/errors" "github.com/ajitpratap0/GoSQLX/pkg/models" "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" @@ -247,6 +249,34 @@ func (p *Parser) parseColumnConstraint() (*ast.ColumnConstraint, bool, error) { return constraint, true, nil } + // GENERATED ALWAYS AS ROW START / ROW END (MariaDB system-versioned columns) + // Syntax: GENERATED ALWAYS AS ROW START | ROW END + if strings.EqualFold(p.currentToken.Token.Value, "GENERATED") { + p.advance() // Consume GENERATED + // Optional ALWAYS + if strings.EqualFold(p.currentToken.Token.Value, "ALWAYS") { + p.advance() // Consume ALWAYS + } + // Expect AS + if !p.isType(models.TokenTypeAs) { + return nil, false, p.expectedError("AS after GENERATED [ALWAYS]") + } + p.advance() // Consume AS + // Expect ROW + if !p.isType(models.TokenTypeRow) { + return nil, false, p.expectedError("ROW after GENERATED [ALWAYS] AS") + } + p.advance() // Consume ROW + // Expect START or END + rowRole := strings.ToUpper(p.currentToken.Token.Value) + if rowRole != "START" && rowRole != "END" { + return nil, false, p.expectedError("START or END after GENERATED [ALWAYS] AS ROW") + } + p.advance() // Consume START or END + constraint.Type = "GENERATED ALWAYS AS ROW " + rowRole + return constraint, true, nil + } + // No constraint found return nil, false, nil } diff --git a/pkg/sql/parser/dialect_test.go b/pkg/sql/parser/dialect_test.go index ed7036d6..f50f565e 100644 --- a/pkg/sql/parser/dialect_test.go +++ b/pkg/sql/parser/dialect_test.go @@ -170,14 +170,14 @@ func TestRejectUnknownDialect(t *testing.T) { func TestIsValidDialect(t *testing.T) { validDialects := []string{ "postgresql", "mysql", "sqlserver", "oracle", "sqlite", - "snowflake", "bigquery", "redshift", "generic", "", + "snowflake", "bigquery", "redshift", "generic", "mariadb", "", } for _, d := range validDialects { if !keywords.IsValidDialect(d) { t.Errorf("IsValidDialect(%q) should return true", d) } } - invalidDialects := []string{"fakesql", "postgres", "mssql", "pg", "mariadb", "db2"} + invalidDialects := []string{"fakesql", "postgres", "mssql", "pg", "db2"} for _, d := range invalidDialects { if keywords.IsValidDialect(d) { t.Errorf("IsValidDialect(%q) should return false", d) diff --git a/pkg/sql/parser/mariadb.go b/pkg/sql/parser/mariadb.go new file mode 100644 index 00000000..b7199a7d --- /dev/null +++ b/pkg/sql/parser/mariadb.go @@ -0,0 +1,480 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser + +import ( + "fmt" + "strings" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" +) + +// isMariaDB is a convenience helper used throughout the parser. +func (p *Parser) isMariaDB() bool { + return p.dialect == string(keywords.DialectMariaDB) +} + +// isMariaDBClauseStart returns true when the current token is the start of a +// MariaDB hierarchical-query clause (CONNECT BY or START WITH) rather than a +// table alias. Used to guard alias parsing in FROM and JOIN table references. +func (p *Parser) isMariaDBClauseStart() bool { + if !p.isMariaDB() { + return false + } + val := strings.ToUpper(p.currentToken.Token.Value) + if val == "CONNECT" { + next := p.peekToken() + return strings.EqualFold(next.Token.Value, "BY") + } + if val == "START" { + next := p.peekToken() + return strings.EqualFold(next.Token.Value, "WITH") + } + return false +} + +// parseCreateSequenceStatement parses: +// +// CREATE [OR REPLACE] SEQUENCE [IF NOT EXISTS] name [options...] +// +// The caller has already consumed CREATE and SEQUENCE. +func (p *Parser) parseCreateSequenceStatement(orReplace bool) (*ast.CreateSequenceStatement, error) { + stmt := ast.NewCreateSequenceStatement() + stmt.OrReplace = orReplace + + // IF NOT EXISTS + if strings.EqualFold(p.currentToken.Token.Value, "IF") { + p.advance() + if !strings.EqualFold(p.currentToken.Token.Value, "NOT") { + return nil, p.expectedError("NOT") + } + p.advance() + if !strings.EqualFold(p.currentToken.Token.Value, "EXISTS") { + return nil, p.expectedError("EXISTS") + } + p.advance() + stmt.IfNotExists = true + } + + name := p.parseIdent() + if name == nil || name.Name == "" { + return nil, p.expectedError("sequence name") + } + stmt.Name = name + + opts, err := p.parseSequenceOptions() + if err != nil { + return nil, err + } + stmt.Options = opts + return stmt, nil +} + +// parseDropSequenceStatement parses: DROP SEQUENCE [IF EXISTS | IF NOT EXISTS] name +// The caller has already consumed DROP and SEQUENCE. +func (p *Parser) parseDropSequenceStatement() (*ast.DropSequenceStatement, error) { + stmt := ast.NewDropSequenceStatement() + + if strings.EqualFold(p.currentToken.Token.Value, "IF") { + p.advance() + if strings.EqualFold(p.currentToken.Token.Value, "NOT") { + // IF NOT EXISTS is a non-standard permissive extension (MariaDB only supports + // IF EXISTS natively). We accept it and reuse the IfExists flag since both + // forms mean "suppress the error if the sequence is absent". + p.advance() + if !strings.EqualFold(p.currentToken.Token.Value, "EXISTS") { + return nil, p.expectedError("EXISTS") + } + p.advance() + stmt.IfExists = true + } else if strings.EqualFold(p.currentToken.Token.Value, "EXISTS") { + p.advance() + stmt.IfExists = true + } else { + return nil, p.expectedError("EXISTS or NOT EXISTS") + } + } + + name := p.parseIdent() + if name == nil || name.Name == "" { + return nil, p.expectedError("sequence name") + } + stmt.Name = name + return stmt, nil +} + +// parseAlterSequenceStatement parses: ALTER SEQUENCE [IF EXISTS] name [options...] +// The caller has already consumed ALTER and SEQUENCE. +func (p *Parser) parseAlterSequenceStatement() (*ast.AlterSequenceStatement, error) { + stmt := ast.NewAlterSequenceStatement() + + if strings.EqualFold(p.currentToken.Token.Value, "IF") { + p.advance() + if !strings.EqualFold(p.currentToken.Token.Value, "EXISTS") { + return nil, p.expectedError("EXISTS") + } + p.advance() + stmt.IfExists = true + } + + name := p.parseIdent() + if name == nil || name.Name == "" { + return nil, p.expectedError("sequence name") + } + stmt.Name = name + + opts, err := p.parseSequenceOptions() + if err != nil { + return nil, err + } + stmt.Options = opts + return stmt, nil +} + +// parseSequenceOptions parses sequence option keywords until no more are found. +func (p *Parser) parseSequenceOptions() (ast.SequenceOptions, error) { + var opts ast.SequenceOptions + for { + if p.isType(models.TokenTypeSemicolon) || p.isType(models.TokenTypeEOF) { + break + } + + word := strings.ToUpper(p.currentToken.Token.Value) + switch word { + case "START": + p.advance() + if strings.EqualFold(p.currentToken.Token.Value, "WITH") { + p.advance() + } + lit, err := p.parseNumericLit() + if err != nil { + return opts, err + } + opts.StartWith = lit + case "INCREMENT": + p.advance() + if strings.EqualFold(p.currentToken.Token.Value, "BY") { + p.advance() + } + lit, err := p.parseNumericLit() + if err != nil { + return opts, err + } + opts.IncrementBy = lit + case "MINVALUE": + p.advance() + lit, err := p.parseNumericLit() + if err != nil { + return opts, err + } + opts.MinValue = lit + case "MAXVALUE": + p.advance() + lit, err := p.parseNumericLit() + if err != nil { + return opts, err + } + opts.MaxValue = lit + case "NO": + p.advance() + sub := strings.ToUpper(p.currentToken.Token.Value) + p.advance() + switch sub { + case "MINVALUE": + opts.MinValue = nil + case "MAXVALUE": + opts.MaxValue = nil + case "CYCLE": + opts.CycleMode = ast.NoCycleBehavior + case "CACHE": + opts.Cache = nil + opts.NoCache = true + default: + return opts, p.expectedError("MINVALUE, MAXVALUE, CYCLE, or CACHE after NO") + } + case "CYCLE": + p.advance() + opts.CycleMode = ast.CycleBehavior + case "NOCYCLE": + p.advance() + opts.CycleMode = ast.NoCycleBehavior + case "CACHE": + p.advance() + lit, err := p.parseNumericLit() + if err != nil { + return opts, err + } + opts.Cache = lit + case "NOCACHE": + p.advance() + opts.NoCache = true + case "RESTART": + p.advance() + if strings.EqualFold(p.currentToken.Token.Value, "WITH") { + p.advance() + lit, err := p.parseNumericLit() + if err != nil { + return opts, err + } + opts.RestartWith = lit + } else { + opts.Restart = true + } + default: + return opts, nil + } + } + // Validate: CACHE n and NOCACHE are mutually exclusive. + if opts.Cache != nil && opts.NoCache { + return opts, fmt.Errorf("contradictory sequence options: CACHE and NOCACHE cannot both be specified") + } + return opts, nil +} + +// parseNumericLit reads a numeric literal token and returns a LiteralValue. +func (p *Parser) parseNumericLit() (*ast.LiteralValue, error) { + if !p.isNumericLiteral() { + return nil, p.expectedError("numeric literal") + } + value := p.currentToken.Token.Value + litType := "int" + if strings.ContainsAny(value, ".eE") { + litType = "float" + } + p.advance() + return &ast.LiteralValue{Value: value, Type: litType}, nil +} + +// parseForSystemTimeClause parses the FOR SYSTEM_TIME clause that follows a table reference. +// The caller has already consumed FOR. +func (p *Parser) parseForSystemTimeClause() (*ast.ForSystemTimeClause, error) { + if !strings.EqualFold(p.currentToken.Token.Value, "SYSTEM_TIME") { + return nil, fmt.Errorf("expected SYSTEM_TIME after FOR, got %q", p.currentToken.Token.Value) + } + sysTimePos := p.currentLocation() // position of SYSTEM_TIME token + p.advance() + + clause := &ast.ForSystemTimeClause{} + clause.Pos = sysTimePos + word := strings.ToUpper(p.currentToken.Token.Value) + + switch word { + case "AS": + p.advance() + if !strings.EqualFold(p.currentToken.Token.Value, "OF") { + return nil, fmt.Errorf("expected OF after AS, got %q", p.currentToken.Token.Value) + } + p.advance() + expr, err := p.parseTemporalPointExpression() + if err != nil { + return nil, err + } + clause.Type = ast.SystemTimeAsOf + clause.Point = expr + case "BETWEEN": + p.advance() + // Use parsePrimaryExpression to avoid consuming AND as a binary logical operator. + start, err := p.parseTemporalPointExpression() + if err != nil { + return nil, err + } + if !strings.EqualFold(p.currentToken.Token.Value, "AND") { + return nil, fmt.Errorf("expected AND in FOR SYSTEM_TIME BETWEEN, got %q", p.currentToken.Token.Value) + } + p.advance() + end, err := p.parseTemporalPointExpression() + if err != nil { + return nil, err + } + clause.Type = ast.SystemTimeBetween + clause.Start = start + clause.End = end + case "FROM": + p.advance() + start, err := p.parseTemporalPointExpression() + if err != nil { + return nil, err + } + if !strings.EqualFold(p.currentToken.Token.Value, "TO") { + return nil, fmt.Errorf("expected TO in FOR SYSTEM_TIME FROM, got %q", p.currentToken.Token.Value) + } + p.advance() + end, err := p.parseTemporalPointExpression() + if err != nil { + return nil, err + } + clause.Type = ast.SystemTimeFromTo + clause.Start = start + clause.End = end + case "ALL": + p.advance() + clause.Type = ast.SystemTimeAll + default: + return nil, fmt.Errorf("expected AS OF, BETWEEN, FROM, or ALL after FOR SYSTEM_TIME, got %q", word) + } + return clause, nil +} + +// parseTemporalPointExpression parses a temporal point expression for FOR SYSTEM_TIME clauses. +// Handles typed string literals like TIMESTAMP '2024-01-01' and DATE '2024-01-01', +// as well as plain string literals and other primary expressions. +func (p *Parser) parseTemporalPointExpression() (ast.Expression, error) { + // Handle TIMESTAMP 'str', DATE 'str', TIME 'str' typed literals. + word := strings.ToUpper(p.currentToken.Token.Value) + if word == "TIMESTAMP" || word == "DATE" || word == "TIME" { + typeKeyword := p.currentToken.Token.Value + p.advance() + if !p.isStringLiteral() { + return nil, fmt.Errorf("expected string literal after %s, got %q", typeKeyword, p.currentToken.Token.Value) + } + // The tokenizer strips surrounding single quotes from string literal tokens, + // so p.currentToken.Token.Value is the raw string content (e.g. "2023-01-01 00:00:00"). + // We reconstruct the canonical form: TYPE 'value'. + value := typeKeyword + " '" + p.currentToken.Token.Value + "'" + p.advance() + return &ast.LiteralValue{Value: value, Type: "timestamp"}, nil + } + // Fall back to primary expression (handles plain string literals, numbers, identifiers). + return p.parsePrimaryExpression() +} + +// parseConnectByCondition parses the condition expression for CONNECT BY. +// It handles the PRIOR prefix operator in either position: +// +// CONNECT BY PRIOR id = parent_id (PRIOR on left) +// CONNECT BY id = PRIOR parent_id (PRIOR on right) +// CONNECT BY PRIOR id = parent_id AND active = 1 (complex with AND/OR) +// +// PRIOR references the value from the parent row in the hierarchy. +// It is modeled as UnaryExpression{Operator: ast.Prior, Expr: }. +func (p *Parser) parseConnectByCondition() (ast.Expression, error) { + var base ast.Expression + + // Case 1: PRIOR col op col + if strings.EqualFold(p.currentToken.Token.Value, "PRIOR") { + p.advance() + priorIdent := p.parseIdent() + if priorIdent == nil || priorIdent.Name == "" { + return nil, p.expectedError("column name after PRIOR") + } + priorExpr := &ast.UnaryExpression{Operator: ast.Prior, Expr: priorIdent} + + if p.isType(models.TokenTypeEq) || p.isType(models.TokenTypeNeq) || + p.isType(models.TokenTypeLt) || p.isType(models.TokenTypeGt) || + p.isType(models.TokenTypeLtEq) || p.isType(models.TokenTypeGtEq) { + op := p.currentToken.Token.Value + p.advance() + right, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + base = &ast.BinaryExpression{Left: priorExpr, Operator: op, Right: right} + } else { + base = priorExpr + } + } else { + // Case 2: col op PRIOR col (PRIOR on the right-hand side) + // or plain expression (no PRIOR) + left, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + if p.isType(models.TokenTypeEq) || p.isType(models.TokenTypeNeq) || + p.isType(models.TokenTypeLt) || p.isType(models.TokenTypeGt) || + p.isType(models.TokenTypeLtEq) || p.isType(models.TokenTypeGtEq) { + op := p.currentToken.Token.Value + p.advance() + // Check for PRIOR on the right side + if strings.EqualFold(p.currentToken.Token.Value, "PRIOR") { + p.advance() + priorIdent := p.parseIdent() + if priorIdent == nil || priorIdent.Name == "" { + return nil, p.expectedError("column name after PRIOR") + } + priorExpr := &ast.UnaryExpression{Operator: ast.Prior, Expr: priorIdent} + base = &ast.BinaryExpression{Left: left, Operator: op, Right: priorExpr} + } else { + right, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + base = &ast.BinaryExpression{Left: left, Operator: op, Right: right} + } + } else { + base = left + } + } + + // Handle AND/OR chaining for complex conditions like: + // PRIOR id = parent_id AND active = 1 + for strings.EqualFold(p.currentToken.Token.Value, "AND") || + strings.EqualFold(p.currentToken.Token.Value, "OR") { + logicOp := p.currentToken.Token.Value + p.advance() + rest, err := p.parseConnectByCondition() + if err != nil { + return nil, err + } + base = &ast.BinaryExpression{Left: base, Operator: logicOp, Right: rest} + } + + return base, nil +} + +// parsePeriodDefinition parses: PERIOD FOR name (start_col, end_col) +// The caller positions the parser at the PERIOD keyword; this function advances past it. +func (p *Parser) parsePeriodDefinition() (*ast.PeriodDefinition, error) { + // current token is PERIOD; advance past it + p.advance() + if !strings.EqualFold(p.currentToken.Token.Value, "FOR") { + return nil, p.expectedError("FOR") + } + p.advance() + + // Use parseColumnName so that reserved-keyword period names like SYSTEM_TIME are accepted. + name := p.parseColumnName() + if name == nil || name.Name == "" { + return nil, p.expectedError("period name") + } + + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("(") + } + p.advance() + + startCol := p.parseIdent() + if startCol == nil || startCol.Name == "" { + return nil, p.expectedError("start column name") + } + + if !p.isType(models.TokenTypeComma) { + return nil, p.expectedError(",") + } + p.advance() + + endCol := p.parseIdent() + if endCol == nil || endCol.Name == "" { + return nil, p.expectedError("end column name") + } + + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(")") + } + p.advance() + + return &ast.PeriodDefinition{Name: name, StartCol: startCol, EndCol: endCol}, nil +} diff --git a/pkg/sql/parser/mariadb_bench_test.go b/pkg/sql/parser/mariadb_bench_test.go new file mode 100644 index 00000000..616657a7 --- /dev/null +++ b/pkg/sql/parser/mariadb_bench_test.go @@ -0,0 +1,260 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser_test + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" + "github.com/ajitpratap0/GoSQLX/pkg/sql/parser" + "github.com/ajitpratap0/GoSQLX/pkg/sql/tokenizer" +) + +// BenchmarkMariaDB_Sequence benchmarks MariaDB SEQUENCE DDL parsing. +func BenchmarkMariaDB_Sequence(b *testing.B) { + benchmarks := []struct { + name string + sql string + }{ + { + name: "create_minimal", + sql: "CREATE SEQUENCE seq_orders", + }, + { + name: "create_all_options", + sql: "CREATE SEQUENCE s START WITH 1000 INCREMENT BY 5 MINVALUE 1 MAXVALUE 9999 CACHE 20 CYCLE", + }, + { + name: "create_or_replace_nocache", + sql: "CREATE OR REPLACE SEQUENCE s NOCACHE NOCYCLE", + }, + { + name: "alter_restart_with", + sql: "ALTER SEQUENCE s RESTART WITH 5000", + }, + { + name: "drop_if_exists", + sql: "DROP SEQUENCE IF EXISTS seq_orders", + }, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(bm.sql)) + if err != nil { + b.Fatalf("Tokenize error: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + p := parser.NewParser(parser.WithDialect(string(keywords.DialectMariaDB))) + result, err := p.ParseFromModelTokens(tokens) + if err != nil { + b.Fatalf("Parse error: %v", err) + } + ast.ReleaseAST(result) + p.Release() + } + }) + } +} + +// BenchmarkMariaDB_ForSystemTime benchmarks MariaDB temporal table query parsing. +func BenchmarkMariaDB_ForSystemTime(b *testing.B) { + benchmarks := []struct { + name string + sql string + }{ + { + name: "as_of_timestamp", + sql: "SELECT * FROM t FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-01 00:00:00'", + }, + { + name: "all", + sql: "SELECT id, name FROM orders FOR SYSTEM_TIME ALL", + }, + { + name: "between", + sql: "SELECT * FROM t FOR SYSTEM_TIME BETWEEN TIMESTAMP '2023-01-01' AND TIMESTAMP '2023-12-31'", + }, + { + name: "from_to", + sql: "SELECT * FROM t FOR SYSTEM_TIME FROM TIMESTAMP '2023-01-01' TO TIMESTAMP '2024-01-01'", + }, + { + name: "join_with_system_time", + sql: `SELECT o.id, h.status + FROM orders o + JOIN order_history h FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-01' + ON o.id = h.order_id`, + }, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(bm.sql)) + if err != nil { + b.Fatalf("Tokenize error: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + p := parser.NewParser(parser.WithDialect(string(keywords.DialectMariaDB))) + result, err := p.ParseFromModelTokens(tokens) + if err != nil { + b.Fatalf("Parse error: %v", err) + } + ast.ReleaseAST(result) + p.Release() + } + }) + } +} + +// BenchmarkMariaDB_ConnectBy benchmarks MariaDB CONNECT BY hierarchical query parsing. +func BenchmarkMariaDB_ConnectBy(b *testing.B) { + benchmarks := []struct { + name string + sql string + }{ + { + name: "simple_prior_left", + sql: `SELECT id, name FROM employees + START WITH parent_id IS NULL + CONNECT BY PRIOR id = parent_id`, + }, + { + name: "prior_right", + sql: `SELECT id, name FROM employees + START WITH id = 1 + CONNECT BY id = PRIOR parent_id`, + }, + { + name: "nocycle", + sql: `SELECT id, name, level FROM employees + START WITH parent_id IS NULL + CONNECT BY NOCYCLE PRIOR id = parent_id`, + }, + { + name: "with_where_and_order", + sql: `SELECT id, name FROM employees + WHERE active = 1 + START WITH parent_id IS NULL + CONNECT BY PRIOR id = parent_id + ORDER BY id`, + }, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(bm.sql)) + if err != nil { + b.Fatalf("Tokenize error: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + p := parser.NewParser(parser.WithDialect(string(keywords.DialectMariaDB))) + result, err := p.ParseFromModelTokens(tokens) + if err != nil { + b.Fatalf("Parse error: %v", err) + } + ast.ReleaseAST(result) + p.Release() + } + }) + } +} + +// BenchmarkMariaDB_Mixed benchmarks parsing of queries that combine multiple +// MariaDB-specific features in a single statement. +func BenchmarkMariaDB_Mixed(b *testing.B) { + benchmarks := []struct { + name string + sql string + }{ + { + name: "temporal_with_cte", + sql: `WITH history AS ( + SELECT * FROM orders FOR SYSTEM_TIME ALL + ) + SELECT id, status FROM history WHERE status = 'cancelled'`, + }, + { + name: "hierarchical_with_cte", + sql: `WITH RECURSIVE org AS ( + SELECT id, name, parent_id FROM employees + START WITH parent_id IS NULL + CONNECT BY PRIOR id = parent_id + ) + SELECT * FROM org ORDER BY id`, + }, + { + name: "create_table_versioned", + sql: `CREATE TABLE orders ( + id INT PRIMARY KEY, + status VARCHAR(50), + row_start DATETIME(6) GENERATED ALWAYS AS ROW START, + row_end DATETIME(6) GENERATED ALWAYS AS ROW END, + PERIOD FOR SYSTEM_TIME(row_start, row_end) + ) WITH SYSTEM VERSIONING`, + }, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(bm.sql)) + if err != nil { + b.Fatalf("Tokenize error: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + p := parser.NewParser(parser.WithDialect(string(keywords.DialectMariaDB))) + result, err := p.ParseFromModelTokens(tokens) + if err != nil { + b.Fatalf("Parse error: %v", err) + } + ast.ReleaseAST(result) + p.Release() + } + }) + } +} diff --git a/pkg/sql/parser/mariadb_test.go b/pkg/sql/parser/mariadb_test.go new file mode 100644 index 00000000..bd9f7a4d --- /dev/null +++ b/pkg/sql/parser/mariadb_test.go @@ -0,0 +1,356 @@ +package parser_test + +import ( + "os" + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" + "github.com/ajitpratap0/GoSQLX/pkg/sql/parser" +) + +// ── Task 7: SEQUENCE Tests ──────────────────────────────────────────────────── + +func TestMariaDB_CreateSequence_Basic(t *testing.T) { + sql := "CREATE SEQUENCE seq_orders START WITH 1 INCREMENT BY 1" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + stmt, ok := tree.Statements[0].(*ast.CreateSequenceStatement) + if !ok { + t.Fatalf("expected CreateSequenceStatement, got %T", tree.Statements[0]) + } + if stmt.Name.Name != "seq_orders" { + t.Errorf("expected name %q, got %q", "seq_orders", stmt.Name.Name) + } + if stmt.Options.StartWith == nil { + t.Error("expected StartWith to be set") + } +} + +func TestMariaDB_CreateSequence_AllOptions(t *testing.T) { + sql := `CREATE SEQUENCE s START WITH 100 INCREMENT BY 5 MINVALUE 1 MAXVALUE 9999 CYCLE CACHE 20` + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt := tree.Statements[0].(*ast.CreateSequenceStatement) + if stmt.Options.CycleMode != ast.CycleBehavior { + t.Error("expected CycleMode = CycleBehavior") + } + if stmt.Options.Cache == nil { + t.Error("expected Cache to be set") + } +} + +func TestMariaDB_CreateSequence_IfNotExists(t *testing.T) { + sql := "CREATE SEQUENCE IF NOT EXISTS my_seq" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt := tree.Statements[0].(*ast.CreateSequenceStatement) + if !stmt.IfNotExists { + t.Error("expected IfNotExists = true") + } +} + +func TestMariaDB_DropSequence(t *testing.T) { + sql := "DROP SEQUENCE IF EXISTS seq_orders" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt, ok := tree.Statements[0].(*ast.DropSequenceStatement) + if !ok { + t.Fatalf("expected DropSequenceStatement, got %T", tree.Statements[0]) + } + if !stmt.IfExists { + t.Error("expected IfExists = true") + } +} + +func TestMariaDB_AlterSequence_Restart(t *testing.T) { + sql := "ALTER SEQUENCE seq_orders RESTART WITH 500" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt, ok := tree.Statements[0].(*ast.AlterSequenceStatement) + if !ok { + t.Fatalf("expected AlterSequenceStatement, got %T", tree.Statements[0]) + } + if stmt.Options.RestartWith == nil { + t.Error("expected RestartWith to be set") + } +} + +func TestMariaDB_AlterSequence_RestartBare(t *testing.T) { + sql := "ALTER SEQUENCE seq_orders RESTART" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt, ok := tree.Statements[0].(*ast.AlterSequenceStatement) + if !ok { + t.Fatalf("expected AlterSequenceStatement, got %T", tree.Statements[0]) + } + if !stmt.Options.Restart { + t.Error("expected Restart = true") + } + if stmt.Options.RestartWith != nil { + t.Error("expected RestartWith = nil for bare RESTART") + } +} + +func TestMariaDB_SequenceNotRecognizedInMySQL(t *testing.T) { + sql := "CREATE SEQUENCE seq1 START WITH 1" + _, err := parser.ParseWithDialect(sql, keywords.DialectMySQL) + if err == nil { + t.Error("expected error when parsing CREATE SEQUENCE in MySQL dialect") + } +} + +// ── Task 8: Temporal Table Tests ────────────────────────────────────────────── + +func TestMariaDB_CreateTable_WithSystemVersioning(t *testing.T) { + sql := "CREATE TABLE orders (id INT PRIMARY KEY, total DECIMAL(10,2)) WITH SYSTEM VERSIONING" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt, ok := tree.Statements[0].(*ast.CreateTableStatement) + if !ok { + t.Fatalf("expected CreateTableStatement, got %T", tree.Statements[0]) + } + if !stmt.WithSystemVersioning { + t.Error("expected WithSystemVersioning = true") + } +} + +func TestMariaDB_SelectForSystemTime_AsOf(t *testing.T) { + sql := "SELECT id FROM orders FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-15 10:00:00'" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + sel := tree.Statements[0].(*ast.SelectStatement) + if len(sel.From) == 0 { + t.Fatal("expected FROM clause") + } + ref := &sel.From[0] + if ref.ForSystemTime == nil { + t.Error("expected ForSystemTime to be set") + } + if ref.ForSystemTime.Type != ast.SystemTimeAsOf { + t.Errorf("expected AS OF, got %v", ref.ForSystemTime.Type) + } +} + +func TestMariaDB_SelectForSystemTime_All(t *testing.T) { + sql := "SELECT * FROM orders FOR SYSTEM_TIME ALL" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + sel := tree.Statements[0].(*ast.SelectStatement) + ref := &sel.From[0] + if ref.ForSystemTime == nil || ref.ForSystemTime.Type != ast.SystemTimeAll { + t.Error("expected SystemTimeAll") + } +} + +func TestMariaDB_SelectForSystemTime_Between(t *testing.T) { + sql := "SELECT * FROM orders FOR SYSTEM_TIME BETWEEN '2020-01-01' AND '2024-01-01'" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + sel := tree.Statements[0].(*ast.SelectStatement) + ref := &sel.From[0] + if ref.ForSystemTime == nil || ref.ForSystemTime.Type != ast.SystemTimeBetween { + t.Error("expected SystemTimeBetween") + } +} + +// ── Task 9: CONNECT BY Tests ────────────────────────────────────────────────── + +func TestMariaDB_ConnectBy_Basic(t *testing.T) { + sql := `SELECT id, name FROM category START WITH parent_id IS NULL CONNECT BY PRIOR id = parent_id` + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + sel, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + if sel.StartWith == nil { + t.Error("expected StartWith to be set") + } + if sel.ConnectBy == nil { + t.Error("expected ConnectBy to be set") + } + if sel.ConnectBy.NoCycle { + t.Error("expected NoCycle = false") + } +} + +func TestMariaDB_ConnectBy_NoCycle(t *testing.T) { + sql := `SELECT id FROM t CONNECT BY NOCYCLE PRIOR id = parent_id` + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + sel := tree.Statements[0].(*ast.SelectStatement) + if sel.ConnectBy == nil || !sel.ConnectBy.NoCycle { + t.Error("expected NoCycle = true") + } +} + +func TestMariaDB_ConnectBy_NoStartWith(t *testing.T) { + sql := `SELECT id FROM t CONNECT BY PRIOR id = parent_id` + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + sel := tree.Statements[0].(*ast.SelectStatement) + if sel.ConnectBy == nil { + t.Error("expected ConnectBy to be set") + } +} + +// TestMariaDB_ConnectBy_PriorOnRight verifies PRIOR on the right-hand side of the condition. +func TestMariaDB_ConnectBy_PriorOnRight(t *testing.T) { + sql := "SELECT id FROM employees CONNECT BY id = PRIOR parent_id" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + sel, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement") + } + if sel.ConnectBy == nil { + t.Fatal("expected ConnectBy clause") + } + bin, ok := sel.ConnectBy.Condition.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected BinaryExpression, got %T", sel.ConnectBy.Condition) + } + // Right side should be PRIOR parent_id + unary, ok := bin.Right.(*ast.UnaryExpression) + if !ok { + t.Fatalf("expected UnaryExpression on right, got %T", bin.Right) + } + if unary.Operator != ast.Prior { + t.Errorf("expected Prior operator, got %v", unary.Operator) + } +} + +// TestMariaDB_DropSequence_IfNotExists verifies DROP SEQUENCE IF NOT EXISTS is accepted. +func TestMariaDB_DropSequence_IfNotExists(t *testing.T) { + sql := "DROP SEQUENCE IF NOT EXISTS my_seq" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + stmt, ok := tree.Statements[0].(*ast.DropSequenceStatement) + if !ok { + t.Fatalf("expected DropSequenceStatement, got %T", tree.Statements[0]) + } + if !stmt.IfExists { + t.Error("expected IfExists=true") + } + if stmt.Name == nil || stmt.Name.Name != "my_seq" { + t.Errorf("expected name my_seq, got %v", stmt.Name) + } +} + +// TestMariaDB_Sequence_NoCache verifies NOCACHE sets the NoCache field. +func TestMariaDB_Sequence_NoCache(t *testing.T) { + sql := "CREATE SEQUENCE s NOCACHE" + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + stmt, ok := tree.Statements[0].(*ast.CreateSequenceStatement) + if !ok { + t.Fatalf("expected CreateSequenceStatement") + } + if !stmt.Options.NoCache { + t.Error("expected NoCache=true") + } +} + +// ── Task 10: File-based Integration Tests ───────────────────────────────────── + +func TestMariaDB_SQLFiles(t *testing.T) { + files := []string{ + "testdata/mariadb/sequences.sql", + "testdata/mariadb/temporal.sql", + "testdata/mariadb/connect_by.sql", + "testdata/mariadb/mixed.sql", + } + for _, f := range files { + t.Run(f, func(t *testing.T) { + data, err := os.ReadFile(f) + if err != nil { + t.Fatalf("failed to read %s: %v", f, err) + } + // Split on semicolons to get individual statements + stmts := strings.Split(string(data), ";") + for _, raw := range stmts { + sql := strings.TrimSpace(raw) + if sql == "" { + continue + } + _, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Errorf("failed to parse %q: %v", sql, err) + } + } + }) + } +} + +func TestMariaDB_CreateTable_PeriodForSystemTime(t *testing.T) { + sql := `CREATE TABLE t ( + id INT, + row_start DATETIME(6) GENERATED ALWAYS AS ROW START, + row_end DATETIME(6) GENERATED ALWAYS AS ROW END, + PERIOD FOR SYSTEM_TIME(row_start, row_end) + ) WITH SYSTEM VERSIONING` + tree, err := parser.ParseWithDialect(sql, keywords.DialectMariaDB) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + stmt, ok := tree.Statements[0].(*ast.CreateTableStatement) + if !ok { + t.Fatalf("expected CreateTableStatement, got %T", tree.Statements[0]) + } + if len(stmt.PeriodDefinitions) == 0 { + t.Fatal("expected at least one PeriodDefinition") + } + pd := stmt.PeriodDefinitions[0] + if pd.Name == nil || !strings.EqualFold(pd.Name.Name, "SYSTEM_TIME") { + t.Errorf("expected period name SYSTEM_TIME, got %v", pd.Name) + } + if pd.StartCol == nil || pd.StartCol.Name != "row_start" { + t.Errorf("expected StartCol=row_start, got %v", pd.StartCol) + } + if pd.EndCol == nil || pd.EndCol.Name != "row_end" { + t.Errorf("expected EndCol=row_end, got %v", pd.EndCol) + } + if !stmt.WithSystemVersioning { + t.Error("expected WithSystemVersioning = true") + } +} diff --git a/pkg/sql/parser/parser.go b/pkg/sql/parser/parser.go index 8c12f013..678b1714 100644 --- a/pkg/sql/parser/parser.go +++ b/pkg/sql/parser/parser.go @@ -629,7 +629,20 @@ func (p *Parser) parseStatement() (ast.Statement, error) { } return stmt, nil case models.TokenTypeAlter: + stmtPos := p.currentLocation() p.advance() + // MariaDB: ALTER SEQUENCE [IF EXISTS] name [options...] + if p.isMariaDB() && p.isTokenMatch("SEQUENCE") { + p.advance() // Consume SEQUENCE + stmt, err := p.parseAlterSequenceStatement() + if err != nil { + return nil, err + } + if stmt.Pos.IsZero() { + stmt.Pos = stmtPos + } + return stmt, nil + } return p.parseAlterTableStmt() case models.TokenTypeMerge: p.advance() @@ -638,7 +651,20 @@ func (p *Parser) parseStatement() (ast.Statement, error) { p.advance() return p.parseCreateStatement() case models.TokenTypeDrop: + stmtPos := p.currentLocation() p.advance() + // MariaDB: DROP SEQUENCE [IF EXISTS | IF NOT EXISTS] name + if p.isMariaDB() && p.isTokenMatch("SEQUENCE") { + p.advance() // Consume SEQUENCE + stmt, err := p.parseDropSequenceStatement() + if err != nil { + return nil, err + } + if stmt.Pos.IsZero() { + stmt.Pos = stmtPos + } + return stmt, nil + } return p.parseDropStatement() case models.TokenTypeRefresh: p.advance() diff --git a/pkg/sql/parser/select.go b/pkg/sql/parser/select.go index 0abb6825..63c8eeda 100644 --- a/pkg/sql/parser/select.go +++ b/pkg/sql/parser/select.go @@ -109,6 +109,45 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { return nil, err } + // MariaDB: START WITH ... CONNECT BY hierarchical queries (10.2+) + if p.isMariaDB() { + if strings.EqualFold(p.currentToken.Token.Value, "START") { + p.advance() // Consume START + if !strings.EqualFold(p.currentToken.Token.Value, "WITH") { + return nil, fmt.Errorf("expected WITH after START, got %q", p.currentToken.Token.Value) + } + p.advance() // Consume WITH + startExpr, startErr := p.parseExpression() + if startErr != nil { + return nil, startErr + } + selectStmt.StartWith = startExpr + } + if strings.EqualFold(p.currentToken.Token.Value, "CONNECT") { + connectPos := p.currentLocation() // position of CONNECT keyword + p.advance() // Consume CONNECT + if !strings.EqualFold(p.currentToken.Token.Value, "BY") { + return nil, fmt.Errorf("expected BY after CONNECT, got %q", p.currentToken.Token.Value) + } + p.advance() // Consume BY + cb := &ast.ConnectByClause{} + cb.Pos = connectPos + if strings.EqualFold(p.currentToken.Token.Value, "NOCYCLE") { + cb.NoCycle = true + p.advance() // Consume NOCYCLE + } + cond, condErr := p.parseConnectByCondition() + if condErr != nil { + return nil, condErr + } + if cond == nil { + return nil, fmt.Errorf("expected condition after CONNECT BY") + } + cb.Condition = cond + selectStmt.ConnectBy = cb + } + } + // ORDER BY if selectStmt.OrderBy, err = p.parseOrderByClause(); err != nil { return nil, err diff --git a/pkg/sql/parser/select_subquery.go b/pkg/sql/parser/select_subquery.go index 61ee478b..f1eff0ae 100644 --- a/pkg/sql/parser/select_subquery.go +++ b/pkg/sql/parser/select_subquery.go @@ -84,8 +84,10 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { } } - // Check for table alias (required for derived tables, optional for regular tables) - if p.isIdentifier() || p.isType(models.TokenTypeAs) { + // Check for table alias (required for derived tables, optional for regular tables). + // Guard: in MariaDB, CONNECT followed by BY is a hierarchical query clause, not an alias. + // Similarly, START followed by WITH is a hierarchical query seed, not an alias. + if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() { if p.isType(models.TokenTypeAs) { p.advance() // Consume AS if !p.isIdentifier() { @@ -98,6 +100,20 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { } } + // MariaDB FOR SYSTEM_TIME temporal query (10.3.4+) + if p.isMariaDB() && p.isType(models.TokenTypeFor) { + // Only parse as FOR SYSTEM_TIME if next token is SYSTEM_TIME + next := p.peekToken() + if strings.EqualFold(next.Token.Value, "SYSTEM_TIME") { + p.advance() // Consume FOR + sysTime, err := p.parseForSystemTimeClause() + if err != nil { + return tableRef, err + } + tableRef.ForSystemTime = sysTime + } + } + // SQL Server table hints: WITH (NOLOCK), WITH (ROWLOCK, UPDLOCK), etc. if p.dialect == string(keywords.DialectSQLServer) && p.isType(models.TokenTypeWith) { if p.peekToken().Token.Type == models.TokenTypeLParen { @@ -160,8 +176,10 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error ref = ast.TableReference{Name: joinedName, Lateral: isLateral} } - // Optional alias - if p.isIdentifier() || p.isType(models.TokenTypeAs) { + // Optional alias. + // Guard: in MariaDB, CONNECT followed by BY is a hierarchical query clause, not an alias. + // Similarly, START followed by WITH is a hierarchical query seed, not an alias. + if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() { if p.isType(models.TokenTypeAs) { p.advance() if !p.isIdentifier() { @@ -174,6 +192,20 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error } } + // MariaDB FOR SYSTEM_TIME temporal query (10.3.4+) + if p.isMariaDB() && p.isType(models.TokenTypeFor) { + // Only parse as FOR SYSTEM_TIME if next token is SYSTEM_TIME + next := p.peekToken() + if strings.EqualFold(next.Token.Value, "SYSTEM_TIME") { + p.advance() // Consume FOR + sysTime, err := p.parseForSystemTimeClause() + if err != nil { + return ref, err + } + ref.ForSystemTime = sysTime + } + } + // SQL Server table hints if p.dialect == string(keywords.DialectSQLServer) && p.isType(models.TokenTypeWith) { if p.peekToken().Token.Type == models.TokenTypeLParen { diff --git a/pkg/sql/parser/testdata/mariadb/connect_by.sql b/pkg/sql/parser/testdata/mariadb/connect_by.sql new file mode 100644 index 00000000..df918111 --- /dev/null +++ b/pkg/sql/parser/testdata/mariadb/connect_by.sql @@ -0,0 +1,9 @@ +SELECT id, name, parent_id FROM categories START WITH parent_id IS NULL CONNECT BY PRIOR id = parent_id; +SELECT id, name FROM employees CONNECT BY NOCYCLE PRIOR manager_id = id; +SELECT id, name, parent_id +FROM employees +CONNECT BY id = PRIOR parent_id; +SELECT id, name, parent_id +FROM employees +START WITH id = 1 +CONNECT BY NOCYCLE id = PRIOR parent_id; diff --git a/pkg/sql/parser/testdata/mariadb/mixed.sql b/pkg/sql/parser/testdata/mariadb/mixed.sql new file mode 100644 index 00000000..f351b66c --- /dev/null +++ b/pkg/sql/parser/testdata/mariadb/mixed.sql @@ -0,0 +1,7 @@ +CREATE SEQUENCE IF NOT EXISTS order_seq START WITH 1 INCREMENT BY 1; +CREATE TABLE orders ( + id INT NOT NULL, + customer_id INT NOT NULL, + total DECIMAL(12,2) +) WITH SYSTEM VERSIONING; +DROP SEQUENCE IF NOT EXISTS order_seq; diff --git a/pkg/sql/parser/testdata/mariadb/sequences.sql b/pkg/sql/parser/testdata/mariadb/sequences.sql new file mode 100644 index 00000000..2273718a --- /dev/null +++ b/pkg/sql/parser/testdata/mariadb/sequences.sql @@ -0,0 +1,9 @@ +CREATE SEQUENCE seq_orders START WITH 1 INCREMENT BY 1; +CREATE SEQUENCE IF NOT EXISTS seq_invoices START WITH 1000 MAXVALUE 99999 CYCLE; +CREATE OR REPLACE SEQUENCE seq_users START WITH 1 INCREMENT BY 1 NOCACHE; +DROP SEQUENCE seq_orders; +DROP SEQUENCE IF EXISTS seq_invoices; +ALTER SEQUENCE seq_orders RESTART WITH 1; +ALTER SEQUENCE s2 MINVALUE 10 MAXVALUE 99999; +ALTER SEQUENCE s2 NO MINVALUE NO MAXVALUE; +CREATE SEQUENCE s6 NO MINVALUE NO MAXVALUE NOCACHE NOCYCLE; diff --git a/pkg/sql/parser/testdata/mariadb/temporal.sql b/pkg/sql/parser/testdata/mariadb/temporal.sql new file mode 100644 index 00000000..e0d8a2f4 --- /dev/null +++ b/pkg/sql/parser/testdata/mariadb/temporal.sql @@ -0,0 +1,12 @@ +CREATE TABLE prices ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + item VARCHAR(100), + price DECIMAL(10,2) +) WITH SYSTEM VERSIONING; +SELECT id, price FROM prices FOR SYSTEM_TIME AS OF TIMESTAMP '2023-06-15 12:00:00'; +SELECT id, price FROM prices FOR SYSTEM_TIME ALL; +SELECT id, price FROM prices FOR SYSTEM_TIME BETWEEN '2022-01-01' AND '2023-01-01'; +SELECT o.id, o.status, o.created_at +FROM orders AS o FOR SYSTEM_TIME AS OF TIMESTAMP '2023-06-01 00:00:00' +JOIN customers AS c ON o.customer_id = c.id +WHERE o.id = 1;