diff --git a/cassandra/ast/node.go b/cassandra/ast/node.go new file mode 100644 index 00000000..27911edd --- /dev/null +++ b/cassandra/ast/node.go @@ -0,0 +1,1042 @@ +// Package ast defines parse-tree node types for the omni Cassandra CQL parser. +package ast + +// Loc represents a source location range (byte offsets). +type Loc struct { + Start int // inclusive byte offset, -1 = unknown + End int // exclusive byte offset, -1 = unknown +} + +func NoLoc() Loc { return Loc{Start: -1, End: -1} } + +// Node is the interface implemented by all parse-tree nodes. +type Node interface { + nodeTag() + GetLoc() Loc +} + +// StmtNode marks top-level statement nodes. +type StmtNode interface { + Node + stmtNode() +} + +// ExprNode marks expression nodes. +type ExprNode interface { + Node + exprNode() +} + +// List is a generic ordered collection of nodes. +type List struct { + Items []Node + Loc Loc +} + +func (*List) nodeTag() {} +func (l *List) GetLoc() Loc { return l.Loc } +func (l *List) Len() int { + if l == nil { + return 0 + } + return len(l.Items) +} + +// RawStmt wraps a top-level statement with its location in the input. +type RawStmt struct { + Stmt Node + StmtLocation int + StmtLen int +} + +func (*RawStmt) nodeTag() {} +func (r *RawStmt) GetLoc() Loc { return Loc{Start: r.StmtLocation, End: r.StmtLocation + r.StmtLen} } + +// --------------------------------------------------------------------------- +// Identifier / Name nodes +// --------------------------------------------------------------------------- + +// Identifier represents a simple name (column, table, keyspace, etc.). +type Identifier struct { + Name string + Quoted bool + Loc Loc +} + +func (*Identifier) nodeTag() {} +func (n *Identifier) GetLoc() Loc { return n.Loc } +func (*Identifier) exprNode() {} + +// QualifiedName represents a dotted name like keyspace.table. 
+type QualifiedName struct { + Parts []*Identifier + Loc Loc +} + +func (*QualifiedName) nodeTag() {} +func (n *QualifiedName) GetLoc() Loc { return n.Loc } +func (*QualifiedName) exprNode() {} + +// --------------------------------------------------------------------------- +// Literal nodes +// --------------------------------------------------------------------------- + +type StringLit struct { + Val string + Loc Loc +} + +func (*StringLit) nodeTag() {} +func (n *StringLit) GetLoc() Loc { return n.Loc } +func (*StringLit) exprNode() {} + +type IntegerLit struct { + Val string + Loc Loc +} + +func (*IntegerLit) nodeTag() {} +func (n *IntegerLit) GetLoc() Loc { return n.Loc } +func (*IntegerLit) exprNode() {} + +type FloatLit struct { + Val string + Loc Loc +} + +func (*FloatLit) nodeTag() {} +func (n *FloatLit) GetLoc() Loc { return n.Loc } +func (*FloatLit) exprNode() {} + +type BoolLit struct { + Val bool + Loc Loc +} + +func (*BoolLit) nodeTag() {} +func (n *BoolLit) GetLoc() Loc { return n.Loc } +func (*BoolLit) exprNode() {} + +type NullLit struct { + Loc Loc +} + +func (*NullLit) nodeTag() {} +func (n *NullLit) GetLoc() Loc { return n.Loc } +func (*NullLit) exprNode() {} + +type UUIDLit struct { + Val string + Loc Loc +} + +func (*UUIDLit) nodeTag() {} +func (n *UUIDLit) GetLoc() Loc { return n.Loc } +func (*UUIDLit) exprNode() {} + +type HexLit struct { + Val string + Loc Loc +} + +func (*HexLit) nodeTag() {} +func (n *HexLit) GetLoc() Loc { return n.Loc } +func (*HexLit) exprNode() {} + +type CodeBlock struct { + Val string + Loc Loc +} + +func (*CodeBlock) nodeTag() {} +func (n *CodeBlock) GetLoc() Loc { return n.Loc } +func (*CodeBlock) exprNode() {} + +// --------------------------------------------------------------------------- +// Collection literals +// --------------------------------------------------------------------------- + +type MapLit struct { + Keys []ExprNode + Values []ExprNode + Loc Loc +} + +func (*MapLit) nodeTag() {} +func (n 
*MapLit) GetLoc() Loc { return n.Loc } +func (*MapLit) exprNode() {} + +type SetLit struct { + Elements []ExprNode + Loc Loc +} + +func (*SetLit) nodeTag() {} +func (n *SetLit) GetLoc() Loc { return n.Loc } +func (*SetLit) exprNode() {} + +type ListLit struct { + Elements []ExprNode + Loc Loc +} + +func (*ListLit) nodeTag() {} +func (n *ListLit) GetLoc() Loc { return n.Loc } +func (*ListLit) exprNode() {} + +type TupleLit struct { + Elements []ExprNode + Loc Loc +} + +func (*TupleLit) nodeTag() {} +func (n *TupleLit) GetLoc() Loc { return n.Loc } +func (*TupleLit) exprNode() {} + +type VectorLit struct { + Elements []ExprNode + Loc Loc +} + +func (*VectorLit) nodeTag() {} +func (n *VectorLit) GetLoc() Loc { return n.Loc } +func (*VectorLit) exprNode() {} + +// --------------------------------------------------------------------------- +// Expression nodes +// --------------------------------------------------------------------------- + +// FunctionCall represents a function invocation like token(col) or now(). +type FunctionCall struct { + Name *Identifier + Args []ExprNode + Star bool // e.g. count(*) + Loc Loc +} + +func (*FunctionCall) nodeTag() {} +func (n *FunctionCall) GetLoc() Loc { return n.Loc } +func (*FunctionCall) exprNode() {} + +// BinaryExpr represents a binary operation (used in WHERE conditions). +type BinaryExpr struct { + Left ExprNode + Op string // =, <, >, <=, >=, +, - + Right ExprNode + Loc Loc +} + +func (*BinaryExpr) nodeTag() {} +func (n *BinaryExpr) GetLoc() Loc { return n.Loc } +func (*BinaryExpr) exprNode() {} + +// InExpr represents col IN (val1, val2, ...). +type InExpr struct { + Column ExprNode + Values []ExprNode + Loc Loc +} + +func (*InExpr) nodeTag() {} +func (n *InExpr) GetLoc() Loc { return n.Loc } +func (*InExpr) exprNode() {} + +// ContainsExpr represents col CONTAINS value or col CONTAINS KEY value. 
+type ContainsExpr struct { + Column ExprNode + Value ExprNode + IsKey bool + Loc Loc +} + +func (*ContainsExpr) nodeTag() {} +func (n *ContainsExpr) GetLoc() Loc { return n.Loc } +func (*ContainsExpr) exprNode() {} + +// TupleCompareExpr represents (col1, col2) op (val1, val2). +type TupleCompareExpr struct { + Columns []ExprNode + Op string + Values []ExprNode + Loc Loc +} + +func (*TupleCompareExpr) nodeTag() {} +func (n *TupleCompareExpr) GetLoc() Loc { return n.Loc } +func (*TupleCompareExpr) exprNode() {} + +// TupleInExpr represents (col1, col2) IN ((v1, v2), (v3, v4)). +type TupleInExpr struct { + Columns []ExprNode + Tuples []*TupleLit + Loc Loc +} + +func (*TupleInExpr) nodeTag() {} +func (n *TupleInExpr) GetLoc() Loc { return n.Loc } +func (*TupleInExpr) exprNode() {} + +// IndexAccess represents col[index] (map/list element access). +type IndexAccess struct { + Collection ExprNode + Index ExprNode + Loc Loc +} + +func (*IndexAccess) nodeTag() {} +func (n *IndexAccess) GetLoc() Loc { return n.Loc } +func (*IndexAccess) exprNode() {} + +// DotAccess represents keyspace.table or similar dotted access in FROM. +type DotAccess struct { + Object ExprNode + Field *Identifier + Loc Loc +} + +func (*DotAccess) nodeTag() {} +func (n *DotAccess) GetLoc() Loc { return n.Loc } +func (*DotAccess) exprNode() {} + +// StarExpr represents * in SELECT *. +type StarExpr struct { + Loc Loc +} + +func (*StarExpr) nodeTag() {} +func (n *StarExpr) GetLoc() Loc { return n.Loc } +func (*StarExpr) exprNode() {} + +// CastExpr represents CAST(expr AS type). +type CastExpr struct { + Expr ExprNode + Type *DataType + Loc Loc +} + +func (*CastExpr) nodeTag() {} +func (n *CastExpr) GetLoc() Loc { return n.Loc } +func (*CastExpr) exprNode() {} + +// BindMarker represents ? (positional) or :name (named) bind markers. +type BindMarker struct { + Name string // empty for positional ? 
+ Loc Loc +} + +func (*BindMarker) nodeTag() {} +func (n *BindMarker) GetLoc() Loc { return n.Loc } +func (*BindMarker) exprNode() {} + +// --------------------------------------------------------------------------- +// Type nodes +// --------------------------------------------------------------------------- + +// DataType represents a CQL data type like MAP, FROZEN, VECTOR. +type DataType struct { + Name *Identifier + TypeParams []*DataType + Dimension *IntegerLit // for VECTOR + Loc Loc +} + +func (*DataType) nodeTag() {} +func (n *DataType) GetLoc() Loc { return n.Loc } + +// --------------------------------------------------------------------------- +// Clause / helper nodes +// --------------------------------------------------------------------------- + +// ColumnDef represents a column definition in CREATE TABLE. +type ColumnDef struct { + Name *Identifier + Type *DataType + PrimaryKey bool + Static bool + Loc Loc +} + +func (*ColumnDef) nodeTag() {} +func (n *ColumnDef) GetLoc() Loc { return n.Loc } + +// PrimaryKeyDef represents a PRIMARY KEY definition. +type PrimaryKeyDef struct { + PartitionKeys []*Identifier + ClusteringKeys []*Identifier + Loc Loc +} + +func (*PrimaryKeyDef) nodeTag() {} +func (n *PrimaryKeyDef) GetLoc() Loc { return n.Loc } + +// ClusteringOrder represents CLUSTERING ORDER BY (col ASC/DESC). +type ClusteringOrder struct { + Column *Identifier + Direction string // "ASC" or "DESC" + Loc Loc +} + +func (*ClusteringOrder) nodeTag() {} +func (n *ClusteringOrder) GetLoc() Loc { return n.Loc } + +// TableOption represents a table option like key = value. +type TableOption struct { + Name *Identifier + Value ExprNode // StringLit, FloatLit, or OptionHash + Loc Loc +} + +func (*TableOption) nodeTag() {} +func (n *TableOption) GetLoc() Loc { return n.Loc } + +// OptionHash represents { 'key': 'value', ... }. 
+type OptionHash struct { + Items []*OptionHashItem + Loc Loc +} + +func (*OptionHash) nodeTag() {} +func (n *OptionHash) GetLoc() Loc { return n.Loc } +func (*OptionHash) exprNode() {} + +type OptionHashItem struct { + Key ExprNode + Value ExprNode + Loc Loc +} + +func (*OptionHashItem) nodeTag() {} +func (n *OptionHashItem) GetLoc() Loc { return n.Loc } + +// SelectElement represents a single item in a SELECT clause. +type SelectElement struct { + Expr ExprNode + Alias *Identifier + Loc Loc +} + +func (*SelectElement) nodeTag() {} +func (n *SelectElement) GetLoc() Loc { return n.Loc } + +// AssignmentElement represents col = expr in UPDATE SET. +type AssignmentElement struct { + Target ExprNode + Value ExprNode + Operator string // "=", "+=", "-=" (desugared from col = col + val) + Loc Loc +} + +func (*AssignmentElement) nodeTag() {} +func (n *AssignmentElement) GetLoc() Loc { return n.Loc } + +// IfCondition represents a LWT condition: col op value, col IN (...), col CONTAINS [KEY] value. +type IfCondition struct { + Column *Identifier + Op string // comparison operator, "IN", "CONTAINS", "CONTAINS KEY" + Value ExprNode + InValues []ExprNode // for IN conditions + Loc Loc +} + +func (*IfCondition) nodeTag() {} +func (n *IfCondition) GetLoc() Loc { return n.Loc } + +// UsingClause represents USING TTL n AND TIMESTAMP m. +type UsingClause struct { + TTL ExprNode + Timestamp ExprNode + Loc Loc +} + +func (*UsingClause) nodeTag() {} +func (n *UsingClause) GetLoc() Loc { return n.Loc } + +// OrderByElement represents a single ORDER BY column direction. 
+type OrderByElement struct { + Column ExprNode + Direction string // "ASC", "DESC", or "" + // ANN OF support + IsANN bool + AnnVector ExprNode + AnnLimit ExprNode + Loc Loc +} + +func (*OrderByElement) nodeTag() {} +func (n *OrderByElement) GetLoc() Loc { return n.Loc } + +// --------------------------------------------------------------------------- +// DML statement nodes +// --------------------------------------------------------------------------- + +type SelectStmt struct { + Distinct bool + JSON bool + Elements []*SelectElement + From *QualifiedName + Where []ExprNode // relation elements connected by AND + GroupBy []*Identifier + OrderBy []*OrderByElement + PerPartitionLimit ExprNode + Limit ExprNode + AllowFiltering bool + Loc Loc +} + +func (*SelectStmt) nodeTag() {} +func (n *SelectStmt) GetLoc() Loc { return n.Loc } +func (*SelectStmt) stmtNode() {} + +type InsertStmt struct { + Table *QualifiedName + Columns []*Identifier + Values []ExprNode + IsJSON bool + JSONValue ExprNode + DefaultUnset bool + DefaultNull bool + IfNotExists bool + Using *UsingClause + Loc Loc +} + +func (*InsertStmt) nodeTag() {} +func (n *InsertStmt) GetLoc() Loc { return n.Loc } +func (*InsertStmt) stmtNode() {} + +type UpdateStmt struct { + Table *QualifiedName + Using *UsingClause + Assignments []*AssignmentElement + Where []ExprNode + IfExists bool + IfConditions []*IfCondition + Loc Loc +} + +func (*UpdateStmt) nodeTag() {} +func (n *UpdateStmt) GetLoc() Loc { return n.Loc } +func (*UpdateStmt) stmtNode() {} + +type DeleteStmt struct { + Columns []ExprNode + From *QualifiedName + Using *UsingClause + Where []ExprNode + IfExists bool + IfConditions []*IfCondition + Loc Loc +} + +func (*DeleteStmt) nodeTag() {} +func (n *DeleteStmt) GetLoc() Loc { return n.Loc } +func (*DeleteStmt) stmtNode() {} + +// BatchType enumerates BATCH types. 
+type BatchType int + +const ( + BatchDefault BatchType = iota + BatchLogged + BatchUnlogged + BatchCounter +) + +type BatchStmt struct { + Type BatchType + Using *UsingClause + Statements []StmtNode + Loc Loc +} + +func (*BatchStmt) nodeTag() {} +func (n *BatchStmt) GetLoc() Loc { return n.Loc } +func (*BatchStmt) stmtNode() {} + +type TruncateStmt struct { + Table *QualifiedName + Loc Loc +} + +func (*TruncateStmt) nodeTag() {} +func (n *TruncateStmt) GetLoc() Loc { return n.Loc } +func (*TruncateStmt) stmtNode() {} + +type UseStmt struct { + Keyspace *Identifier + Loc Loc +} + +func (*UseStmt) nodeTag() {} +func (n *UseStmt) GetLoc() Loc { return n.Loc } +func (*UseStmt) stmtNode() {} + +// --------------------------------------------------------------------------- +// DDL statement nodes +// --------------------------------------------------------------------------- + +type CreateKeyspaceStmt struct { + IfNotExists bool + Name *Identifier + Replication *OptionHash + DurableWrites *BoolLit + Loc Loc +} + +func (*CreateKeyspaceStmt) nodeTag() {} +func (n *CreateKeyspaceStmt) GetLoc() Loc { return n.Loc } +func (*CreateKeyspaceStmt) stmtNode() {} + +type AlterKeyspaceStmt struct { + IfExists bool + Name *Identifier + Replication *OptionHash + DurableWrites *BoolLit + Loc Loc +} + +func (*AlterKeyspaceStmt) nodeTag() {} +func (n *AlterKeyspaceStmt) GetLoc() Loc { return n.Loc } +func (*AlterKeyspaceStmt) stmtNode() {} + +type DropKeyspaceStmt struct { + IfExists bool + Name *Identifier + Loc Loc +} + +func (*DropKeyspaceStmt) nodeTag() {} +func (n *DropKeyspaceStmt) GetLoc() Loc { return n.Loc } +func (*DropKeyspaceStmt) stmtNode() {} + +type CreateTableStmt struct { + IfNotExists bool + Name *QualifiedName + Columns []*ColumnDef + PrimaryKey *PrimaryKeyDef + Options []*TableOption + ClusteringOrders []*ClusteringOrder + CompactStorage bool + Loc Loc +} + +func (*CreateTableStmt) nodeTag() {} +func (n *CreateTableStmt) GetLoc() Loc { return n.Loc } +func 
(*CreateTableStmt) stmtNode() {} + +// AlterTableOp enumerates ALTER TABLE operations. +type AlterTableOp int + +const ( + AlterTableAdd AlterTableOp = iota + AlterTableDrop + AlterTableRename + AlterTableWith + AlterTableDropCompactStorage +) + +type AlterTableStmt struct { + IfExists bool + Name *QualifiedName + Op AlterTableOp + AddIfNotExists bool + AddColumns []*ColumnDef + DropIfExists bool + DropColumns []*Identifier + RenameIfExists bool + RenameItems []*AlterTableRenameItem + Options []*TableOption + Loc Loc +} + +func (*AlterTableStmt) nodeTag() {} +func (n *AlterTableStmt) GetLoc() Loc { return n.Loc } +func (*AlterTableStmt) stmtNode() {} + +type AlterTableRenameItem struct { + From *Identifier + To *Identifier + Loc Loc +} + +func (*AlterTableRenameItem) nodeTag() {} +func (n *AlterTableRenameItem) GetLoc() Loc { return n.Loc } + +type DropTableStmt struct { + IfExists bool + Name *QualifiedName + Loc Loc +} + +func (*DropTableStmt) nodeTag() {} +func (n *DropTableStmt) GetLoc() Loc { return n.Loc } +func (*DropTableStmt) stmtNode() {} + +type CreateIndexStmt struct { + IsCustom bool + IfNotExists bool + IndexName *Identifier + Table *QualifiedName + Column ExprNode // Identifier or FunctionCall (e.g., FULL(col)) + UsingClass ExprNode + Options *OptionHash + Loc Loc +} + +func (*CreateIndexStmt) nodeTag() {} +func (n *CreateIndexStmt) GetLoc() Loc { return n.Loc } +func (*CreateIndexStmt) stmtNode() {} + +type DropIndexStmt struct { + IfExists bool + Name *QualifiedName + Loc Loc +} + +func (*DropIndexStmt) nodeTag() {} +func (n *DropIndexStmt) GetLoc() Loc { return n.Loc } +func (*DropIndexStmt) stmtNode() {} + +type CreateTypeStmt struct { + IfNotExists bool + Name *QualifiedName + Fields []*ColumnDef + Loc Loc +} + +func (*CreateTypeStmt) nodeTag() {} +func (n *CreateTypeStmt) GetLoc() Loc { return n.Loc } +func (*CreateTypeStmt) stmtNode() {} + +// AlterTypeOp enumerates ALTER TYPE operations. 
+type AlterTypeOp int + +const ( + AlterTypeAlter AlterTypeOp = iota + AlterTypeAdd + AlterTypeRename +) + +type AlterTypeStmt struct { + IfExists bool + Name *QualifiedName + Op AlterTypeOp + AlterColumn *Identifier + AlterType *DataType + AddIfNotExists bool + AddFields []*ColumnDef + RenameIfExists bool + Renames []*AlterTypeRenameItem + Loc Loc +} + +func (*AlterTypeStmt) nodeTag() {} +func (n *AlterTypeStmt) GetLoc() Loc { return n.Loc } +func (*AlterTypeStmt) stmtNode() {} + +type AlterTypeRenameItem struct { + From *Identifier + To *Identifier + Loc Loc +} + +func (*AlterTypeRenameItem) nodeTag() {} +func (n *AlterTypeRenameItem) GetLoc() Loc { return n.Loc } + +type DropTypeStmt struct { + IfExists bool + Name *QualifiedName + Loc Loc +} + +func (*DropTypeStmt) nodeTag() {} +func (n *DropTypeStmt) GetLoc() Loc { return n.Loc } +func (*DropTypeStmt) stmtNode() {} + +type CreateMVStmt struct { + IfNotExists bool + Name *QualifiedName + SelectAll bool + SelectColumns []*Identifier + FromTable *QualifiedName + WhereNotNull []*Identifier + WhereRelations []ExprNode + PrimaryKey *PrimaryKeyDef + Options []*TableOption + ClusteringOrders []*ClusteringOrder + Loc Loc +} + +func (*CreateMVStmt) nodeTag() {} +func (n *CreateMVStmt) GetLoc() Loc { return n.Loc } +func (*CreateMVStmt) stmtNode() {} + +type AlterMVStmt struct { + IfExists bool + Name *QualifiedName + Options []*TableOption + Loc Loc +} + +func (*AlterMVStmt) nodeTag() {} +func (n *AlterMVStmt) GetLoc() Loc { return n.Loc } +func (*AlterMVStmt) stmtNode() {} + +type DropMVStmt struct { + IfExists bool + Name *QualifiedName + Loc Loc +} + +func (*DropMVStmt) nodeTag() {} +func (n *DropMVStmt) GetLoc() Loc { return n.Loc } +func (*DropMVStmt) stmtNode() {} + +// ReturnMode for CREATE FUNCTION. 
+type ReturnMode int + +const ( + ReturnCalledOnNull ReturnMode = iota + ReturnNullOnNull +) + +type CreateFunctionStmt struct { + OrReplace bool + IfNotExists bool + Name *QualifiedName + Params []*FunctionParam + ReturnMode ReturnMode + ReturnType *DataType + Language *Identifier + Body ExprNode // CodeBlock or StringLit + Loc Loc +} + +func (*CreateFunctionStmt) nodeTag() {} +func (n *CreateFunctionStmt) GetLoc() Loc { return n.Loc } +func (*CreateFunctionStmt) stmtNode() {} + +type FunctionParam struct { + Name *Identifier + Type *DataType + Loc Loc +} + +func (*FunctionParam) nodeTag() {} +func (n *FunctionParam) GetLoc() Loc { return n.Loc } + +type DropFunctionStmt struct { + IfExists bool + Name *QualifiedName + ArgTypes []*DataType + Loc Loc +} + +func (*DropFunctionStmt) nodeTag() {} +func (n *DropFunctionStmt) GetLoc() Loc { return n.Loc } +func (*DropFunctionStmt) stmtNode() {} + +type CreateAggregateStmt struct { + OrReplace bool + IfNotExists bool + Name *QualifiedName + ParamType *DataType + SFunc *Identifier + SType *DataType + FinalFunc *Identifier + InitCond ExprNode + Loc Loc +} + +func (*CreateAggregateStmt) nodeTag() {} +func (n *CreateAggregateStmt) GetLoc() Loc { return n.Loc } +func (*CreateAggregateStmt) stmtNode() {} + +type DropAggregateStmt struct { + IfExists bool + Name *QualifiedName + ArgTypes []*DataType + Loc Loc +} + +func (*DropAggregateStmt) nodeTag() {} +func (n *DropAggregateStmt) GetLoc() Loc { return n.Loc } +func (*DropAggregateStmt) stmtNode() {} + +type CreateTriggerStmt struct { + IfNotExists bool + Name *Identifier + Table *QualifiedName + UsingClass ExprNode + Loc Loc +} + +func (*CreateTriggerStmt) nodeTag() {} +func (n *CreateTriggerStmt) GetLoc() Loc { return n.Loc } +func (*CreateTriggerStmt) stmtNode() {} + +type DropTriggerStmt struct { + IfExists bool + Name *Identifier + Table *QualifiedName + Loc Loc +} + +func (*DropTriggerStmt) nodeTag() {} +func (n *DropTriggerStmt) GetLoc() Loc { return n.Loc } +func 
(*DropTriggerStmt) stmtNode() {} + +// --------------------------------------------------------------------------- +// Auth / Role / User statement nodes +// --------------------------------------------------------------------------- + +type CreateRoleStmt struct { + IfNotExists bool + Name *Identifier + Options []*RoleOption + Loc Loc +} + +func (*CreateRoleStmt) nodeTag() {} +func (n *CreateRoleStmt) GetLoc() Loc { return n.Loc } +func (*CreateRoleStmt) stmtNode() {} + +type RoleOption struct { + Key string // "PASSWORD", "HASHED PASSWORD", "LOGIN", "SUPERUSER", "OPTIONS", "ACCESS" + Value ExprNode + Loc Loc +} + +func (*RoleOption) nodeTag() {} +func (n *RoleOption) GetLoc() Loc { return n.Loc } + +type AlterRoleStmt struct { + IfExists bool + Name *Identifier + Options []*RoleOption + Loc Loc +} + +func (*AlterRoleStmt) nodeTag() {} +func (n *AlterRoleStmt) GetLoc() Loc { return n.Loc } +func (*AlterRoleStmt) stmtNode() {} + +type DropRoleStmt struct { + IfExists bool + Name *Identifier + Loc Loc +} + +func (*DropRoleStmt) nodeTag() {} +func (n *DropRoleStmt) GetLoc() Loc { return n.Loc } +func (*DropRoleStmt) stmtNode() {} + +type CreateUserStmt struct { + IfNotExists bool + Name *Identifier + Password ExprNode + Hashed bool + Superuser *bool + Loc Loc +} + +func (*CreateUserStmt) nodeTag() {} +func (n *CreateUserStmt) GetLoc() Loc { return n.Loc } +func (*CreateUserStmt) stmtNode() {} + +type AlterUserStmt struct { + IfExists bool + Name *Identifier + Password ExprNode + Hashed bool + Superuser *bool + Loc Loc +} + +func (*AlterUserStmt) nodeTag() {} +func (n *AlterUserStmt) GetLoc() Loc { return n.Loc } +func (*AlterUserStmt) stmtNode() {} + +type DropUserStmt struct { + IfExists bool + Name *Identifier + Loc Loc +} + +func (*DropUserStmt) nodeTag() {} +func (n *DropUserStmt) GetLoc() Loc { return n.Loc } +func (*DropUserStmt) stmtNode() {} + +type GrantStmt struct { + Privilege string + Resource *Resource + Role *Identifier + Loc Loc +} + +func (*GrantStmt) 
nodeTag() {} +func (n *GrantStmt) GetLoc() Loc { return n.Loc } +func (*GrantStmt) stmtNode() {} + +type RevokeStmt struct { + Privilege string + Resource *Resource + Role *Identifier + Loc Loc +} + +func (*RevokeStmt) nodeTag() {} +func (n *RevokeStmt) GetLoc() Loc { return n.Loc } +func (*RevokeStmt) stmtNode() {} + +type GrantRoleStmt struct { + RoleName *Identifier + Grantee *Identifier + Loc Loc +} + +func (*GrantRoleStmt) nodeTag() {} +func (n *GrantRoleStmt) GetLoc() Loc { return n.Loc } +func (*GrantRoleStmt) stmtNode() {} + +type RevokeRoleStmt struct { + RoleName *Identifier + Revokee *Identifier + Loc Loc +} + +func (*RevokeRoleStmt) nodeTag() {} +func (n *RevokeRoleStmt) GetLoc() Loc { return n.Loc } +func (*RevokeRoleStmt) stmtNode() {} + +// Resource represents a CQL resource (ALL KEYSPACES, KEYSPACE ks, TABLE ks.t, etc.) +type Resource struct { + Type string // "ALL KEYSPACES", "KEYSPACE", "TABLE", "ALL FUNCTIONS", "FUNCTION", "ALL ROLES", "ROLE", "ALL MBEANS", "MBEAN", "MBEANS" + Name *QualifiedName + ArgTypes []*DataType // for FUNCTION resource with type signature + Loc Loc +} + +func (*Resource) nodeTag() {} +func (n *Resource) GetLoc() Loc { return n.Loc } + +type ListPermissionsStmt struct { + Privilege string + Resource *Resource + Role *Identifier + Loc Loc +} + +func (*ListPermissionsStmt) nodeTag() {} +func (n *ListPermissionsStmt) GetLoc() Loc { return n.Loc } +func (*ListPermissionsStmt) stmtNode() {} + +type ListRolesStmt struct { + Of *Identifier + NoRecursive bool + Loc Loc +} + +func (*ListRolesStmt) nodeTag() {} +func (n *ListRolesStmt) GetLoc() Loc { return n.Loc } +func (*ListRolesStmt) stmtNode() {} diff --git a/cassandra/ast/walk.go b/cassandra/ast/walk.go new file mode 100644 index 00000000..d35f8c8c --- /dev/null +++ b/cassandra/ast/walk.go @@ -0,0 +1,35 @@ +package ast + +// Visitor defines the interface for AST traversal. +type Visitor interface { + Visit(node Node) Visitor +} + +// Walk traverses an AST in depth-first order. 
It calls v.Visit(node); if that +// returns a non-nil Visitor w, Walk recurses into the children of node with w, +// then calls w.Visit(nil). +func Walk(v Visitor, node Node) { + if node == nil { + return + } + if v = v.Visit(node); v == nil { + return + } + walkChildren(v, node) + v.Visit(nil) +} + +// Inspect traverses an AST in depth-first order, calling f for each node. +// If f returns true, Inspect recurses into the children of node. +func Inspect(node Node, f func(Node) bool) { + Walk(inspector(f), node) +} + +type inspector func(Node) bool + +func (f inspector) Visit(node Node) Visitor { + if node == nil || !f(node) { + return nil + } + return f +} diff --git a/cassandra/ast/walk_children.go b/cassandra/ast/walk_children.go new file mode 100644 index 00000000..cc57f1a4 --- /dev/null +++ b/cassandra/ast/walk_children.go @@ -0,0 +1,481 @@ +package ast + +func walkChildren(v Visitor, node Node) { + switch n := node.(type) { + case *List: + for _, item := range n.Items { + Walk(v, item) + } + case *RawStmt: + Walk(v, n.Stmt) + + // Identifiers / Names + case *Identifier: + case *QualifiedName: + for _, p := range n.Parts { + Walk(v, p) + } + + // Literals + case *StringLit: + case *IntegerLit: + case *FloatLit: + case *BoolLit: + case *NullLit: + case *UUIDLit: + case *HexLit: + case *CodeBlock: + case *StarExpr: + case *CastExpr: + Walk(v, n.Expr) + Walk(v, n.Type) + case *BindMarker: + + // Collections + case *MapLit: + for _, k := range n.Keys { + Walk(v, k) + } + for _, val := range n.Values { + Walk(v, val) + } + case *SetLit: + for _, e := range n.Elements { + Walk(v, e) + } + case *ListLit: + for _, e := range n.Elements { + Walk(v, e) + } + case *TupleLit: + for _, e := range n.Elements { + Walk(v, e) + } + case *VectorLit: + for _, e := range n.Elements { + Walk(v, e) + } + + // Expressions + case *FunctionCall: + Walk(v, n.Name) + for _, a := range n.Args { + Walk(v, a) + } + case *BinaryExpr: + Walk(v, n.Left) + Walk(v, n.Right) + case *InExpr: + 
Walk(v, n.Column) + for _, val := range n.Values { + Walk(v, val) + } + case *ContainsExpr: + Walk(v, n.Column) + Walk(v, n.Value) + case *TupleCompareExpr: + for _, c := range n.Columns { + Walk(v, c) + } + for _, val := range n.Values { + Walk(v, val) + } + case *TupleInExpr: + for _, c := range n.Columns { + Walk(v, c) + } + for _, t := range n.Tuples { + Walk(v, t) + } + case *IndexAccess: + Walk(v, n.Collection) + Walk(v, n.Index) + case *DotAccess: + Walk(v, n.Object) + Walk(v, n.Field) + + // Types + case *DataType: + Walk(v, n.Name) + for _, p := range n.TypeParams { + Walk(v, p) + } + if n.Dimension != nil { + Walk(v, n.Dimension) + } + + // Clauses / helpers + case *ColumnDef: + Walk(v, n.Name) + Walk(v, n.Type) + case *PrimaryKeyDef: + for _, k := range n.PartitionKeys { + Walk(v, k) + } + for _, k := range n.ClusteringKeys { + Walk(v, k) + } + case *ClusteringOrder: + Walk(v, n.Column) + case *TableOption: + Walk(v, n.Name) + Walk(v, n.Value) + case *OptionHash: + for _, item := range n.Items { + Walk(v, item) + } + case *OptionHashItem: + Walk(v, n.Key) + Walk(v, n.Value) + case *SelectElement: + Walk(v, n.Expr) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *AssignmentElement: + Walk(v, n.Target) + Walk(v, n.Value) + case *IfCondition: + Walk(v, n.Column) + if n.Value != nil { + Walk(v, n.Value) + } + for _, val := range n.InValues { + Walk(v, val) + } + case *UsingClause: + if n.TTL != nil { + Walk(v, n.TTL) + } + if n.Timestamp != nil { + Walk(v, n.Timestamp) + } + case *OrderByElement: + Walk(v, n.Column) + if n.AnnVector != nil { + Walk(v, n.AnnVector) + } + if n.AnnLimit != nil { + Walk(v, n.AnnLimit) + } + + // DML + case *SelectStmt: + for _, e := range n.Elements { + Walk(v, e) + } + if n.From != nil { + Walk(v, n.From) + } + for _, w := range n.Where { + Walk(v, w) + } + for _, g := range n.GroupBy { + Walk(v, g) + } + for _, o := range n.OrderBy { + Walk(v, o) + } + if n.PerPartitionLimit != nil { + Walk(v, n.PerPartitionLimit) + } + if 
n.Limit != nil { + Walk(v, n.Limit) + } + case *InsertStmt: + Walk(v, n.Table) + for _, c := range n.Columns { + Walk(v, c) + } + for _, val := range n.Values { + Walk(v, val) + } + if n.JSONValue != nil { + Walk(v, n.JSONValue) + } + if n.Using != nil { + Walk(v, n.Using) + } + case *UpdateStmt: + Walk(v, n.Table) + if n.Using != nil { + Walk(v, n.Using) + } + for _, a := range n.Assignments { + Walk(v, a) + } + for _, w := range n.Where { + Walk(v, w) + } + for _, c := range n.IfConditions { + Walk(v, c) + } + case *DeleteStmt: + for _, c := range n.Columns { + Walk(v, c) + } + Walk(v, n.From) + if n.Using != nil { + Walk(v, n.Using) + } + for _, w := range n.Where { + Walk(v, w) + } + for _, c := range n.IfConditions { + Walk(v, c) + } + case *BatchStmt: + if n.Using != nil { + Walk(v, n.Using) + } + for _, s := range n.Statements { + Walk(v, s) + } + case *TruncateStmt: + Walk(v, n.Table) + case *UseStmt: + Walk(v, n.Keyspace) + + // DDL + case *CreateKeyspaceStmt: + Walk(v, n.Name) + if n.Replication != nil { + Walk(v, n.Replication) + } + if n.DurableWrites != nil { + Walk(v, n.DurableWrites) + } + case *AlterKeyspaceStmt: + Walk(v, n.Name) + if n.Replication != nil { + Walk(v, n.Replication) + } + if n.DurableWrites != nil { + Walk(v, n.DurableWrites) + } + case *DropKeyspaceStmt: + Walk(v, n.Name) + case *CreateTableStmt: + Walk(v, n.Name) + for _, c := range n.Columns { + Walk(v, c) + } + if n.PrimaryKey != nil { + Walk(v, n.PrimaryKey) + } + for _, o := range n.Options { + Walk(v, o) + } + for _, co := range n.ClusteringOrders { + Walk(v, co) + } + case *AlterTableStmt: + Walk(v, n.Name) + for _, c := range n.AddColumns { + Walk(v, c) + } + for _, c := range n.DropColumns { + Walk(v, c) + } + for _, r := range n.RenameItems { + Walk(v, r) + } + for _, o := range n.Options { + Walk(v, o) + } + case *AlterTableRenameItem: + Walk(v, n.From) + Walk(v, n.To) + case *DropTableStmt: + Walk(v, n.Name) + case *CreateIndexStmt: + if n.IndexName != nil { + Walk(v, 
n.IndexName) + } + Walk(v, n.Table) + Walk(v, n.Column) + if n.UsingClass != nil { + Walk(v, n.UsingClass) + } + if n.Options != nil { + Walk(v, n.Options) + } + case *DropIndexStmt: + Walk(v, n.Name) + case *CreateTypeStmt: + Walk(v, n.Name) + for _, f := range n.Fields { + Walk(v, f) + } + case *AlterTypeStmt: + Walk(v, n.Name) + if n.AlterColumn != nil { + Walk(v, n.AlterColumn) + } + if n.AlterType != nil { + Walk(v, n.AlterType) + } + for _, f := range n.AddFields { + Walk(v, f) + } + for _, r := range n.Renames { + Walk(v, r) + } + case *AlterTypeRenameItem: + Walk(v, n.From) + Walk(v, n.To) + case *DropTypeStmt: + Walk(v, n.Name) + case *CreateMVStmt: + Walk(v, n.Name) + for _, c := range n.SelectColumns { + Walk(v, c) + } + Walk(v, n.FromTable) + for _, c := range n.WhereNotNull { + Walk(v, c) + } + for _, r := range n.WhereRelations { + Walk(v, r) + } + if n.PrimaryKey != nil { + Walk(v, n.PrimaryKey) + } + for _, o := range n.Options { + Walk(v, o) + } + for _, co := range n.ClusteringOrders { + Walk(v, co) + } + case *AlterMVStmt: + Walk(v, n.Name) + for _, o := range n.Options { + Walk(v, o) + } + case *DropMVStmt: + Walk(v, n.Name) + case *CreateFunctionStmt: + Walk(v, n.Name) + for _, p := range n.Params { + Walk(v, p) + } + if n.ReturnType != nil { + Walk(v, n.ReturnType) + } + if n.Language != nil { + Walk(v, n.Language) + } + if n.Body != nil { + Walk(v, n.Body) + } + case *FunctionParam: + Walk(v, n.Name) + Walk(v, n.Type) + case *DropFunctionStmt: + Walk(v, n.Name) + for _, at := range n.ArgTypes { + Walk(v, at) + } + case *CreateAggregateStmt: + Walk(v, n.Name) + if n.ParamType != nil { + Walk(v, n.ParamType) + } + if n.SFunc != nil { + Walk(v, n.SFunc) + } + if n.SType != nil { + Walk(v, n.SType) + } + if n.FinalFunc != nil { + Walk(v, n.FinalFunc) + } + if n.InitCond != nil { + Walk(v, n.InitCond) + } + case *DropAggregateStmt: + Walk(v, n.Name) + for _, at := range n.ArgTypes { + Walk(v, at) + } + case *CreateTriggerStmt: + Walk(v, n.Name) + 
if n.Table != nil { + Walk(v, n.Table) + } + if n.UsingClass != nil { + Walk(v, n.UsingClass) + } + case *DropTriggerStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Table != nil { + Walk(v, n.Table) + } + + // Auth + case *CreateRoleStmt: + Walk(v, n.Name) + for _, o := range n.Options { + Walk(v, o) + } + case *RoleOption: + if n.Value != nil { + Walk(v, n.Value) + } + case *AlterRoleStmt: + Walk(v, n.Name) + for _, o := range n.Options { + Walk(v, o) + } + case *DropRoleStmt: + Walk(v, n.Name) + case *CreateUserStmt: + Walk(v, n.Name) + if n.Password != nil { + Walk(v, n.Password) + } + case *AlterUserStmt: + Walk(v, n.Name) + if n.Password != nil { + Walk(v, n.Password) + } + case *DropUserStmt: + Walk(v, n.Name) + case *GrantStmt: + if n.Resource != nil { + Walk(v, n.Resource) + } + Walk(v, n.Role) + case *RevokeStmt: + if n.Resource != nil { + Walk(v, n.Resource) + } + Walk(v, n.Role) + case *GrantRoleStmt: + Walk(v, n.RoleName) + Walk(v, n.Grantee) + case *RevokeRoleStmt: + Walk(v, n.RoleName) + Walk(v, n.Revokee) + case *Resource: + if n.Name != nil { + Walk(v, n.Name) + } + for _, at := range n.ArgTypes { + Walk(v, at) + } + case *ListPermissionsStmt: + if n.Resource != nil { + Walk(v, n.Resource) + } + if n.Role != nil { + Walk(v, n.Role) + } + case *ListRolesStmt: + if n.Of != nil { + Walk(v, n.Of) + } + } +} diff --git a/cassandra/ast/walk_test.go b/cassandra/ast/walk_test.go new file mode 100644 index 00000000..868f0fb7 --- /dev/null +++ b/cassandra/ast/walk_test.go @@ -0,0 +1,102 @@ +package ast + +import ( + "reflect" + "testing" +) + +func TestWalkNil(t *testing.T) { + var count int + Inspect(nil, func(n Node) bool { + count++ + return true + }) + if count != 0 { + t.Errorf("expected 0 visits for nil, got %d", count) + } +} + +func TestInspectSelectStmt(t *testing.T) { + sel := &SelectStmt{ + Elements: []*SelectElement{ + { + Expr: &Identifier{Name: "name", Loc: Loc{Start: 7, End: 11}}, + Alias: &Identifier{Name: "n", Loc: Loc{Start: 15, End: 
16}}, + Loc: Loc{Start: 7, End: 16}, + }, + }, + From: &QualifiedName{ + Parts: []*Identifier{ + {Name: "users", Loc: Loc{Start: 22, End: 27}}, + }, + Loc: Loc{Start: 22, End: 27}, + }, + Where: []ExprNode{ + &BinaryExpr{ + Left: &Identifier{Name: "id", Loc: Loc{Start: 34, End: 36}}, + Op: "=", + Right: &IntegerLit{Val: "1", Loc: Loc{Start: 39, End: 40}}, + Loc: Loc{Start: 34, End: 40}, + }, + }, + Loc: Loc{Start: 0, End: 40}, + } + + var visited []string + Inspect(sel, func(n Node) bool { + if n == nil { + return false + } + visited = append(visited, reflect.TypeOf(n).Elem().Name()) + return true + }) + + expected := []string{ + "SelectStmt", + "SelectElement", + "Identifier", // name + "Identifier", // alias n + "QualifiedName", + "Identifier", // users + "BinaryExpr", + "Identifier", // id + "IntegerLit", // 1 + } + if len(visited) != len(expected) { + t.Fatalf("expected %d visits, got %d: %v", len(expected), len(visited), visited) + } + for i, e := range expected { + if visited[i] != e { + t.Errorf("visit[%d] = %q, want %q", i, visited[i], e) + } + } +} + +func TestInspectPruning(t *testing.T) { + sel := &SelectStmt{ + From: &QualifiedName{ + Parts: []*Identifier{ + {Name: "users", Loc: Loc{Start: 14, End: 19}}, + }, + Loc: Loc{Start: 14, End: 19}, + }, + Loc: Loc{Start: 0, End: 19}, + } + + var visited []string + Inspect(sel, func(n Node) bool { + if n == nil { + return false + } + name := reflect.TypeOf(n).Elem().Name() + visited = append(visited, name) + if name == "QualifiedName" { + return false + } + return true + }) + + if len(visited) != 2 { + t.Fatalf("expected 2 visits (SelectStmt, QualifiedName), got %d: %v", len(visited), visited) + } +} diff --git a/cassandra/compatibility_test.go b/cassandra/compatibility_test.go new file mode 100644 index 00000000..ad323da1 --- /dev/null +++ b/cassandra/compatibility_test.go @@ -0,0 +1,85 @@ +package cassandra + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func testdataDir() string 
{ + _, file, _, _ := runtime.Caller(0) + return filepath.Join(filepath.Dir(file), "testdata", "cql", "examples") +} + +var expectedFailures = map[string]string{ + "applyBatch.cql": "standalone APPLY BATCH is not valid CQL without BEGIN BATCH", +} + +func TestCompatibilityHarness(t *testing.T) { + dir := testdataDir() + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("CQL examples corpus missing at %s: %v", dir, err) + } + + var cqlFiles []os.DirEntry + for _, e := range entries { + if strings.HasSuffix(e.Name(), ".cql") { + cqlFiles = append(cqlFiles, e) + } + } + if len(cqlFiles) == 0 { + t.Fatal("CQL examples corpus is empty") + } + + var ( + totalFiles = len(cqlFiles) + passedFiles int + expectedFailureFiles int + totalStmts int + ) + + for _, entry := range cqlFiles { + t.Run(entry.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join(dir, entry.Name())) + if err != nil { + t.Fatal(err) + } + + content := string(data) + stmts, err := Parse(content) + if err != nil { + if reason, ok := expectedFailures[entry.Name()]; ok { + expectedFailureFiles++ + t.Skipf("expected failure: %s (%v)", reason, err) + return + } + t.Errorf("Parse failed: %v", err) + return + } + + passedFiles++ + totalStmts += len(stmts) + + for i, s := range stmts { + if s.AST == nil { + t.Errorf("statement %d has nil AST", i) + } + loc := s.AST.GetLoc() + if loc.Start < 0 || loc.End <= loc.Start { + t.Errorf("statement %d has invalid Loc: %+v", i, loc) + } + } + + violations := CheckLocations(t, content) + for _, v := range violations { + t.Errorf("Loc violation: %s", v) + } + }) + } + + t.Logf("Compatibility: %d/%d files passed, %d expected failures, %d total statements", + passedFiles, totalFiles, expectedFailureFiles, totalStmts) +} diff --git a/cassandra/loc_test.go b/cassandra/loc_test.go new file mode 100644 index 00000000..2a8e414f --- /dev/null +++ b/cassandra/loc_test.go @@ -0,0 +1,185 @@ +package cassandra + +import ( + "fmt" + "reflect" + "testing" + + 
"github.com/bytebase/omni/cassandra/ast" +) + +type LocViolation struct { + Path string + Start int + End int + Reason string +} + +func (v LocViolation) String() string { + return fmt.Sprintf("%s: %s (Start=%d, End=%d)", v.Path, v.Reason, v.Start, v.End) +} + +type locContext struct { + sqlLen int + violations *[]LocViolation +} + +func CheckLocations(t *testing.T, sql string) []LocViolation { + t.Helper() + stmts, err := Parse(sql) + if err != nil { + t.Fatalf("Parse(%q): %v", sql, err) + } + var violations []LocViolation + ctx := &locContext{sqlLen: len(sql), violations: &violations} + for i, s := range stmts { + path := fmt.Sprintf("stmts[%d]", i) + parentLoc := ast.Loc{Start: s.ByteStart, End: s.ByteEnd} + walkNodeLocs(reflect.ValueOf(s.AST), path, parentLoc, ctx) + } + return violations +} + +var locType = reflect.TypeOf(ast.Loc{}) + +func walkNodeLocs(v reflect.Value, path string, parentLoc ast.Loc, ctx *locContext) { + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return + } + walkNodeLocs(v.Elem(), path, parentLoc, ctx) + case reflect.Interface: + if v.IsNil() { + return + } + elem := v.Elem() + typeName := elem.Type().Name() + if elem.Kind() == reflect.Ptr { + typeName = elem.Type().Elem().Name() + } + walkNodeLocs(elem, path+"("+typeName+")", parentLoc, ctx) + case reflect.Struct: + t := v.Type() + locField := v.FieldByName("Loc") + currentLoc := parentLoc + if locField.IsValid() && locField.Type() == locType { + loc := locField.Interface().(ast.Loc) + if loc.Start >= 0 && loc.End >= 0 { + if loc.End <= loc.Start { + *ctx.violations = append(*ctx.violations, LocViolation{ + Path: path, Start: loc.Start, End: loc.End, + Reason: "End <= Start", + }) + } + if loc.End > ctx.sqlLen { + *ctx.violations = append(*ctx.violations, LocViolation{ + Path: path, Start: loc.Start, End: loc.End, + Reason: fmt.Sprintf("End > len(sql) (%d)", ctx.sqlLen), + }) + } + if loc.Start > ctx.sqlLen { + *ctx.violations = append(*ctx.violations, LocViolation{ + Path: 
path, Start: loc.Start, End: loc.End, + Reason: fmt.Sprintf("Start > len(sql) (%d)", ctx.sqlLen), + }) + } + if parentLoc.Start >= 0 && parentLoc.End >= 0 { + if loc.Start < parentLoc.Start { + *ctx.violations = append(*ctx.violations, LocViolation{ + Path: path, Start: loc.Start, End: loc.End, + Reason: fmt.Sprintf("Start < parent Start (%d)", parentLoc.Start), + }) + } + if loc.End > parentLoc.End { + *ctx.violations = append(*ctx.violations, LocViolation{ + Path: path, Start: loc.Start, End: loc.End, + Reason: fmt.Sprintf("End > parent End (%d)", parentLoc.End), + }) + } + } + currentLoc = loc + } else if (loc.Start < 0) != (loc.End < 0) { + *ctx.violations = append(*ctx.violations, LocViolation{ + Path: path, Start: loc.Start, End: loc.End, + Reason: "mixed unknown sentinel", + }) + } + } + for i := range t.NumField() { + field := t.Field(i) + if !field.IsExported() || field.Name == "Loc" { + continue + } + walkNodeLocs(v.Field(i), path+"."+field.Name, currentLoc, ctx) + } + case reflect.Slice: + for i := range v.Len() { + elem := v.Index(i) + walkNodeLocs(elem, fmt.Sprintf("%s[%d]", path, i), parentLoc, ctx) + } + } +} + +func TestCheckLocations(t *testing.T) { + tests := []string{ + // DML + "SELECT * FROM users", + "SELECT DISTINCT name, age FROM ks.users WHERE id = 1 LIMIT 10 ALLOW FILTERING", + "SELECT JSON name AS n FROM users", + "INSERT INTO users (id, name) VALUES (1, 'Alice')", + "INSERT INTO ks.users (id, name) VALUES (1, 'Bob') IF NOT EXISTS", + "INSERT INTO users (id, name) VALUES (1, 'Charlie') USING TTL 86400", + "INSERT INTO users JSON '{\"id\": 1, \"name\": \"Dave\"}'", + "UPDATE users SET name = 'Alice' WHERE id = 1", + "UPDATE ks.users USING TTL 3600 SET name = 'Bob' WHERE id = 2", + "UPDATE users SET name = 'Charlie' WHERE id = 3 IF EXISTS", + "UPDATE users SET name = 'Dave' WHERE id = 4 IF name = 'old'", + "DELETE FROM users WHERE id = 1", + "DELETE name FROM ks.users WHERE id = 2", + `BEGIN BATCH + INSERT INTO users (id, name) VALUES (1, 
'Alice'); + UPDATE users SET name = 'Bob' WHERE id = 2; + DELETE FROM users WHERE id = 3; + APPLY BATCH`, + "TRUNCATE users", + "USE my_keyspace", + + // DDL + "CREATE KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'}", + "ALTER KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '3'}", + "DROP KEYSPACE IF EXISTS ks", + "CREATE TABLE users (id int, name text, PRIMARY KEY (id))", + "CREATE TABLE t (id int, name text, age int, PRIMARY KEY ((id, name), age)) WITH CLUSTERING ORDER BY (age DESC)", + "ALTER TABLE users ADD email text", + "DROP TABLE IF EXISTS users", + "CREATE INDEX ON users (name)", + "DROP INDEX IF EXISTS users_name_idx", + "CREATE TYPE address (street text, city text)", + "ALTER TYPE address ADD zip text", + "DROP TYPE IF EXISTS address", + "CREATE MATERIALIZED VIEW mv AS SELECT * FROM users WHERE id IS NOT NULL PRIMARY KEY (id)", + "CREATE MATERIALIZED VIEW mv AS SELECT col1, col2 FROM users WHERE col1 IS NOT NULL AND col2 IS NOT NULL PRIMARY KEY (col1, col2)", + + // Auth + "GRANT SELECT ON TABLE users TO reader", + "REVOKE ALL ON ALL KEYSPACES FROM admin", + "LIST ALL PERMISSIONS OF admin", + "LIST ROLES", + "LIST ROLES OF admin NORECURSIVE", + "CREATE ROLE myrole WITH PASSWORD = 'secret' AND LOGIN = true", + "ALTER ROLE myrole WITH PASSWORD = 'newsecret'", + "DROP ROLE IF EXISTS myrole", + + // Multi-statement + "SELECT * FROM users; INSERT INTO users (id) VALUES (1); USE ks", + } + for _, sql := range tests { + t.Run(sql, func(t *testing.T) { + violations := CheckLocations(t, sql) + for _, v := range violations { + t.Errorf("Loc violation: %s", v) + } + }) + } +} diff --git a/cassandra/parse.go b/cassandra/parse.go new file mode 100644 index 00000000..ef50a772 --- /dev/null +++ b/cassandra/parse.go @@ -0,0 +1,92 @@ +// Package cassandra provides a parser for Apache Cassandra CQL (Cassandra Query Language). 
+package cassandra + +import ( + "sort" + "strings" + + "github.com/bytebase/omni/cassandra/ast" + "github.com/bytebase/omni/cassandra/parser" +) + +// Statement represents a single parsed CQL statement with position information. +type Statement struct { + Text string + AST ast.Node + ByteStart int + ByteEnd int + Start Position + End Position +} + +// Position represents a line/column location in source text. +type Position struct { + Line int // 1-based + Column int // 1-based, bytes +} + +// Empty reports whether the statement is empty (no AST). +func (s Statement) Empty() bool { + return s.AST == nil +} + +// Parse parses a CQL input containing one or more statements. +func Parse(sql string) ([]Statement, error) { + if strings.TrimSpace(sql) == "" { + return nil, nil + } + + list, err := parser.Parse(sql) + if err != nil { + return nil, err + } + if list.Len() == 0 { + return nil, nil + } + + idx := buildLineIndex(sql) + var stmts []Statement + for _, item := range list.Items { + raw, ok := item.(*ast.RawStmt) + if !ok { + continue + } + byteStart := raw.StmtLocation + byteEnd := byteStart + raw.StmtLen + // Trim trailing whitespace from statement text. 
+ for byteEnd > byteStart && isSpace(sql[byteEnd-1]) { + byteEnd-- + } + stmts = append(stmts, Statement{ + Text: sql[byteStart:byteEnd], + AST: raw.Stmt, + ByteStart: byteStart, + ByteEnd: byteEnd, + Start: offsetToPosition(idx, byteStart), + End: offsetToPosition(idx, byteEnd), + }) + } + return stmts, nil +} + +func isSpace(c byte) bool { + return c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == '\f' +} + +type lineIndex []int + +func buildLineIndex(s string) lineIndex { + idx := lineIndex{0} + for i := 0; i < len(s); i++ { + if s[i] == '\n' { + idx = append(idx, i+1) + } + } + return idx +} + +func offsetToPosition(idx lineIndex, offset int) Position { + line := sort.SearchInts(idx, offset+1) + col := offset - idx[line-1] + 1 + return Position{Line: line, Column: col} +} diff --git a/cassandra/parse_test.go b/cassandra/parse_test.go new file mode 100644 index 00000000..30b9ed3a --- /dev/null +++ b/cassandra/parse_test.go @@ -0,0 +1,992 @@ +package cassandra + +import ( + "errors" + "strings" + "testing" + + "github.com/bytebase/omni/cassandra/ast" + "github.com/bytebase/omni/cassandra/parser" +) + +func TestParseEmpty(t *testing.T) { + stmts, err := Parse("") + if err != nil { + t.Fatal(err) + } + if len(stmts) != 0 { + t.Fatalf("expected 0 statements, got %d", len(stmts)) + } +} + +func TestParseBlank(t *testing.T) { + stmts, err := Parse(" \n\t ") + if err != nil { + t.Fatal(err) + } + if len(stmts) != 0 { + t.Fatalf("expected 0 statements, got %d", len(stmts)) + } +} + +func TestParseSelect(t *testing.T) { + tests := []struct { + input string + check func(t *testing.T, s Statement) + }{ + { + input: "SELECT * FROM users", + check: func(t *testing.T, s Statement) { + sel, ok := s.AST.(*ast.SelectStmt) + if !ok { + t.Fatalf("expected *ast.SelectStmt, got %T", s.AST) + } + if sel.From == nil || len(sel.From.Parts) != 1 || sel.From.Parts[0].Name != "users" { + t.Fatal("expected FROM users") + } + }, + }, + { + input: "SELECT DISTINCT name, age FROM 
ks.users WHERE id = 1 LIMIT 10 ALLOW FILTERING", + check: func(t *testing.T, s Statement) { + sel := s.AST.(*ast.SelectStmt) + if !sel.Distinct { + t.Fatal("expected DISTINCT") + } + if len(sel.Elements) != 2 { + t.Fatalf("expected 2 select elements, got %d", len(sel.Elements)) + } + if sel.From == nil || len(sel.From.Parts) != 2 { + t.Fatal("expected qualified table name ks.users") + } + if len(sel.Where) != 1 { + t.Fatalf("expected 1 WHERE condition, got %d", len(sel.Where)) + } + if sel.Limit == nil { + t.Fatal("expected LIMIT") + } + if !sel.AllowFiltering { + t.Fatal("expected ALLOW FILTERING") + } + }, + }, + { + input: "SELECT JSON name AS n FROM users", + check: func(t *testing.T, s Statement) { + sel := s.AST.(*ast.SelectStmt) + if !sel.JSON { + t.Fatal("expected JSON") + } + if len(sel.Elements) != 1 { + t.Fatalf("expected 1 element, got %d", len(sel.Elements)) + } + if sel.Elements[0].Alias == nil || sel.Elements[0].Alias.Name != "n" { + t.Fatal("expected alias 'n'") + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + stmts, err := Parse(tt.input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + tt.check(t, stmts[0]) + }) + } +} + +func TestParseInsert(t *testing.T) { + tests := []string{ + "INSERT INTO users (id, name) VALUES (1, 'Alice')", + "INSERT INTO ks.users (id, name) VALUES (1, 'Bob') IF NOT EXISTS", + "INSERT INTO users (id, name) VALUES (1, 'Charlie') USING TTL 86400", + "INSERT INTO users JSON '{\"id\": 1, \"name\": \"Dave\"}'", + "INSERT INTO users JSON '{\"id\": 1}' DEFAULT UNSET", + } + for _, input := range tests { + t.Run(input, func(t *testing.T) { + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + if _, ok := stmts[0].AST.(*ast.InsertStmt); !ok { + t.Fatalf("expected *ast.InsertStmt, got %T", stmts[0].AST) + } + }) + } +} + +func 
TestParseUpdate(t *testing.T) { + tests := []string{ + "UPDATE users SET name = 'Alice' WHERE id = 1", + "UPDATE ks.users USING TTL 3600 SET name = 'Bob' WHERE id = 2", + "UPDATE users SET name = 'Charlie' WHERE id = 3 IF EXISTS", + "UPDATE users SET name = 'Dave' WHERE id = 4 IF name = 'old'", + } + for _, input := range tests { + t.Run(input, func(t *testing.T) { + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + if _, ok := stmts[0].AST.(*ast.UpdateStmt); !ok { + t.Fatalf("expected *ast.UpdateStmt, got %T", stmts[0].AST) + } + }) + } +} + +func TestParseDelete(t *testing.T) { + tests := []string{ + "DELETE FROM users WHERE id = 1", + "DELETE name FROM ks.users WHERE id = 2", + "DELETE FROM users WHERE id = 3 IF EXISTS", + } + for _, input := range tests { + t.Run(input, func(t *testing.T) { + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + if _, ok := stmts[0].AST.(*ast.DeleteStmt); !ok { + t.Fatalf("expected *ast.DeleteStmt, got %T", stmts[0].AST) + } + }) + } +} + +func TestParseBatch(t *testing.T) { + input := `BEGIN BATCH + INSERT INTO users (id, name) VALUES (1, 'Alice'); + UPDATE users SET name = 'Bob' WHERE id = 2; + DELETE FROM users WHERE id = 3; + APPLY BATCH` + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + batch, ok := stmts[0].AST.(*ast.BatchStmt) + if !ok { + t.Fatalf("expected *ast.BatchStmt, got %T", stmts[0].AST) + } + if len(batch.Statements) != 3 { + t.Fatalf("expected 3 inner statements, got %d", len(batch.Statements)) + } +} + +func TestParseDDL(t *testing.T) { + tests := []struct { + input string + nodeType string + }{ + {"CREATE KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'}", "CreateKeyspaceStmt"}, + 
{"ALTER KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '3'}", "AlterKeyspaceStmt"}, + {"DROP KEYSPACE IF EXISTS ks", "DropKeyspaceStmt"}, + {"CREATE TABLE users (id int, name text, PRIMARY KEY (id))", "CreateTableStmt"}, + {"ALTER TABLE users ADD email text", "AlterTableStmt"}, + {"DROP TABLE IF EXISTS users", "DropTableStmt"}, + {"CREATE INDEX ON users (name)", "CreateIndexStmt"}, + {"DROP INDEX IF EXISTS users_name_idx", "DropIndexStmt"}, + {"CREATE TYPE address (street text, city text)", "CreateTypeStmt"}, + {"ALTER TYPE address ADD zip text", "AlterTypeStmt"}, + {"DROP TYPE IF EXISTS address", "DropTypeStmt"}, + {"TRUNCATE users", "TruncateStmt"}, + {"TRUNCATE TABLE ks.users", "TruncateStmt"}, + {"USE my_keyspace", "UseStmt"}, + {"CREATE TABLE t (id int PRIMARY KEY, v vector)", "CreateTableStmt"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + stmts, err := Parse(tt.input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + }) + } +} + +func TestParseAuth(t *testing.T) { + tests := []string{ + "GRANT SELECT ON TABLE users TO reader", + "REVOKE ALL ON ALL KEYSPACES FROM admin", + "LIST ALL PERMISSIONS OF admin", + "LIST ROLES", + "LIST ROLES OF admin NORECURSIVE", + "CREATE ROLE myrole WITH PASSWORD = 'secret' AND LOGIN = true", + "ALTER ROLE myrole WITH PASSWORD = 'newsecret'", + "DROP ROLE IF EXISTS myrole", + "CREATE USER myuser WITH PASSWORD 'secret' SUPERUSER", + "DROP USER IF EXISTS myuser", + } + for _, input := range tests { + t.Run(input, func(t *testing.T) { + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + }) + } +} + +func TestParseMultipleStatements(t *testing.T) { + input := "SELECT * FROM users; INSERT INTO users (id) VALUES (1); USE ks" + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 
3 { + t.Fatalf("expected 3 statements, got %d", len(stmts)) + } + if _, ok := stmts[0].AST.(*ast.SelectStmt); !ok { + t.Fatalf("stmt 0: expected SelectStmt, got %T", stmts[0].AST) + } + if _, ok := stmts[1].AST.(*ast.InsertStmt); !ok { + t.Fatalf("stmt 1: expected InsertStmt, got %T", stmts[1].AST) + } + if _, ok := stmts[2].AST.(*ast.UseStmt); !ok { + t.Fatalf("stmt 2: expected UseStmt, got %T", stmts[2].AST) + } +} + +func TestParseMV(t *testing.T) { + tests := []struct { + input string + check func(t *testing.T, s Statement) + }{ + { + input: "CREATE MATERIALIZED VIEW mv AS SELECT * FROM users WHERE id IS NOT NULL PRIMARY KEY (id)", + check: func(t *testing.T, s Statement) { + mv, ok := s.AST.(*ast.CreateMVStmt) + if !ok { + t.Fatalf("expected *ast.CreateMVStmt, got %T", s.AST) + } + if !mv.SelectAll { + t.Fatal("expected SelectAll = true") + } + if len(mv.SelectColumns) != 0 { + t.Fatalf("expected 0 SelectColumns, got %d", len(mv.SelectColumns)) + } + if len(mv.WhereNotNull) != 1 || mv.WhereNotNull[0].Name != "id" { + t.Fatal("expected WHERE id IS NOT NULL") + } + }, + }, + { + input: "CREATE MATERIALIZED VIEW mv AS SELECT col1, col2 FROM users WHERE col1 IS NOT NULL AND col2 IS NOT NULL PRIMARY KEY (col1, col2)", + check: func(t *testing.T, s Statement) { + mv := s.AST.(*ast.CreateMVStmt) + if mv.SelectAll { + t.Fatal("expected SelectAll = false") + } + if len(mv.SelectColumns) != 2 { + t.Fatalf("expected 2 SelectColumns, got %d", len(mv.SelectColumns)) + } + if len(mv.WhereNotNull) != 2 { + t.Fatalf("expected 2 WhereNotNull, got %d", len(mv.WhereNotNull)) + } + }, + }, + { + input: "CREATE MATERIALIZED VIEW IF NOT EXISTS ks.mv AS SELECT * FROM users WHERE id IS NOT NULL PRIMARY KEY (id) WITH comment = 'test'", + check: func(t *testing.T, s Statement) { + mv := s.AST.(*ast.CreateMVStmt) + if !mv.IfNotExists { + t.Fatal("expected IfNotExists") + } + if len(mv.Name.Parts) != 2 { + t.Fatal("expected qualified name ks.mv") + } + if len(mv.Options) != 1 { + 
t.Fatalf("expected 1 option, got %d", len(mv.Options)) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + stmts, err := Parse(tt.input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + tt.check(t, stmts[0]) + }) + } +} + +func TestParseErrors(t *testing.T) { + tests := []struct { + name string + input string + }{ + // IF NOT EXISTS strict validation + {"IF NOT GARBAGE", "INSERT INTO users (id) VALUES (1) IF NOT GARBAGE"}, + {"IF NOT without EXISTS in CREATE", "CREATE TABLE IF NOT GARBAGE users (id int PRIMARY KEY)"}, + + // Truncated/malformed DML + {"truncated SELECT", "SELECT"}, + {"truncated SELECT FROM", "SELECT * FROM"}, + {"truncated INSERT", "INSERT INTO"}, + {"truncated INSERT no VALUES", "INSERT INTO users (id)"}, + {"truncated UPDATE", "UPDATE"}, + {"truncated UPDATE no SET", "UPDATE users"}, + {"truncated DELETE", "DELETE FROM"}, + {"truncated BATCH", "BEGIN BATCH"}, + + // Truncated/malformed DDL + {"truncated CREATE TABLE", "CREATE TABLE"}, + {"truncated CREATE KEYSPACE", "CREATE KEYSPACE"}, + {"truncated DROP", "DROP"}, + {"truncated ALTER", "ALTER"}, + {"CREATE without object", "CREATE"}, + + // Invalid tokens + {"bare operator", "< >"}, + {"invalid statement start", "123"}, + + // MV IS NOT NULL with wrong tokens + {"MV IS LOL NOPE", "CREATE MATERIALIZED VIEW mv AS SELECT * FROM t WHERE id IS LOL NOPE PRIMARY KEY (id)"}, + + // Type generic validation + {"vector missing element type", "CREATE TABLE t (id int PRIMARY KEY, v vector<3>)"}, + {"vector extra dimension", "CREATE TABLE t (id int PRIMARY KEY, v vector)"}, + {"map with integer param", "CREATE TABLE t (id int PRIMARY KEY, m map)"}, + + // LIMIT only accepts integer or bind marker + {"LIMIT string", "SELECT * FROM t LIMIT 'abc'"}, + {"LIMIT bool", "SELECT * FROM t LIMIT true"}, + {"LIMIT null", "SELECT * FROM t LIMIT null"}, + {"LIMIT float", "SELECT * FROM t LIMIT 3.14"}, + {"PER 
PARTITION LIMIT string", "SELECT * FROM t PER PARTITION LIMIT 'abc'"}, + {"PER PARTITION LIMIT bool", "SELECT * FROM t PER PARTITION LIMIT false"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Parse(tt.input) + if err == nil { + t.Fatalf("expected error for %q, got nil", tt.input) + } + }) + } +} + +func TestParseNoPanic(t *testing.T) { + inputs := []string{ + "", + " ", + "\t\n", + ";", + ";;;", + "SELECT", + "INSERT", + "CREATE", + "DROP TABLE", + "ALTER TABLE users", + "BEGIN BATCH APPLY BATCH", + "SELECT * FROM users WHERE", + "CREATE TABLE t (", + "UPDATE users SET name =", + "'unterminated string", + `"unterminated ident`, + "$$unterminated code block", + "/* unterminated block comment", + "SELECT 1e", + "SELECT 1e+", + "GRANT SELECT ON MBEAN 123 TO r", + "GRANT SELECT ON MBEANS 456 TO r", + } + for _, input := range inputs { + t.Run(input, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("Parse(%q) panicked: %v", input, r) + } + }() + Parse(input) // error is fine, panic is not + }) + } +} + +func TestParseLocWalker(t *testing.T) { + tests := []string{ + "SELECT * FROM users", + "SELECT name, age FROM ks.users WHERE id = 1", + "INSERT INTO users (id, name) VALUES (1, 'Alice')", + "UPDATE users SET name = 'Bob' WHERE id = 2", + "DELETE FROM users WHERE id = 3", + "CREATE TABLE t (id int, name text, PRIMARY KEY (id))", + "CREATE KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'}", + "DROP TABLE IF EXISTS users", + "USE my_keyspace", + "TRUNCATE TABLE users", + "CREATE MATERIALIZED VIEW mv AS SELECT * FROM users WHERE id IS NOT NULL PRIMARY KEY (id)", + } + for _, input := range tests { + t.Run(input, func(t *testing.T) { + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + for _, s := range stmts { + if s.ByteStart < 0 { + t.Errorf("ByteStart = %d, want >= 0", s.ByteStart) + } + if s.ByteEnd <= s.ByteStart { + t.Errorf("ByteEnd = %d <= 
ByteStart = %d", s.ByteEnd, s.ByteStart) + } + if s.ByteEnd > len(input) { + t.Errorf("ByteEnd = %d > len(input) = %d", s.ByteEnd, len(input)) + } + text := input[s.ByteStart:s.ByteEnd] + if text != s.Text { + t.Errorf("input[%d:%d] = %q, s.Text = %q", s.ByteStart, s.ByteEnd, text, s.Text) + } + if s.Start.Line < 1 || s.Start.Column < 1 { + t.Errorf("Start = %+v, want line >= 1 and column >= 1", s.Start) + } + loc := s.AST.GetLoc() + if loc.Start < 0 { + t.Errorf("AST Loc.Start = %d, want >= 0", loc.Start) + } + if loc.End <= loc.Start { + t.Errorf("AST Loc.End = %d <= Loc.Start = %d", loc.End, loc.Start) + } + } + }) + } +} + +func TestParsePositions(t *testing.T) { + input := "SELECT * FROM users" + stmts, err := Parse(input) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + s := stmts[0] + if s.ByteStart != 0 { + t.Errorf("ByteStart = %d, want 0", s.ByteStart) + } + if s.ByteEnd != 19 { + t.Errorf("ByteEnd = %d, want 19", s.ByteEnd) + } + if s.Start.Line != 1 || s.Start.Column != 1 { + t.Errorf("Start = %+v, want {1 1}", s.Start) + } + if s.Text != "SELECT * FROM users" { + t.Errorf("Text = %q", s.Text) + } +} + +// --------------------------------------------------------------------------- +// Phase 4: L3 Error Quality +// --------------------------------------------------------------------------- + +func TestErrorLineColumn(t *testing.T) { + tests := []struct { + name string + sql string + wantLine int + wantCol int + wantNear string + }{ + { + name: "single line error", + sql: "SELECT * FORM users", + wantLine: 1, + wantCol: 10, + wantNear: "FORM", + }, + { + name: "second line error", + sql: "SELECT *\nFORM users", + wantLine: 2, + wantCol: 1, + wantNear: "FORM", + }, + { + name: "deep in statement", + sql: "INSERT INTO users (id) VALUE (1)", + wantLine: 1, + wantCol: 24, + wantNear: "VALUE", + }, + { + name: "unterminated string", + sql: "SELECT 'abc", + wantLine: 1, + wantCol: 8, + }, + 
} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Parse(tt.sql) + if err == nil { + t.Fatal("expected error") + } + var pe *parser.ParseError + if !errors.As(err, &pe) { + t.Fatalf("expected *parser.ParseError, got %T: %v", err, err) + } + if pe.Line != tt.wantLine { + t.Errorf("Line = %d, want %d (error: %s)", pe.Line, tt.wantLine, pe.Error()) + } + if pe.Column != tt.wantCol { + t.Errorf("Column = %d, want %d (error: %s)", pe.Column, tt.wantCol, pe.Error()) + } + if tt.wantNear != "" && pe.Near != tt.wantNear { + t.Errorf("Near = %q, want %q (error: %s)", pe.Near, tt.wantNear, pe.Error()) + } + if !strings.Contains(pe.Error(), "line") { + t.Errorf("error message missing 'line': %s", pe.Error()) + } + if !strings.Contains(pe.Error(), "column") { + t.Errorf("error message missing 'column': %s", pe.Error()) + } + }) + } +} + +func TestErrorAtOrNear(t *testing.T) { + _, err := Parse("SELECT * FORM users") + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + if !strings.Contains(msg, "at or near") { + t.Errorf("error message missing 'at or near': %s", msg) + } + if !strings.Contains(msg, "FORM") { + t.Errorf("error message missing token text 'FORM': %s", msg) + } +} + +func TestTruncationFuzz(t *testing.T) { + validSQL := []string{ + "SELECT * FROM users WHERE id = 1 ORDER BY name ASC LIMIT 10", + "INSERT INTO users (id, name) VALUES (1, 'Alice') IF NOT EXISTS USING TTL 86400", + "UPDATE users USING TTL 3600 SET name = 'Bob' WHERE id = 2 IF name = 'old'", + "DELETE name FROM ks.users WHERE id = 2 IF EXISTS", + "BEGIN UNLOGGED BATCH USING TIMESTAMP 12345 INSERT INTO t (id) VALUES (1); DELETE FROM t WHERE id = 2; APPLY BATCH", + "CREATE TABLE t (id int, name text, age int, PRIMARY KEY ((id, name), age)) WITH CLUSTERING ORDER BY (age DESC) AND comment = 'test'", + "CREATE KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND DURABLE_WRITES = true", + "CREATE MATERIALIZED VIEW mv AS 
SELECT col1, col2 FROM t WHERE col1 IS NOT NULL AND col2 IS NOT NULL PRIMARY KEY (col1, col2)", + "CREATE FUNCTION ks.f(input text) CALLED ON NULL INPUT RETURNS text LANGUAGE java AS $$return input;$$", + "CREATE AGGREGATE ks.agg(int) SFUNC plus STYPE int FINALFUNC fin INITCOND 0", + "GRANT SELECT ON TABLE users TO reader", + "CREATE ROLE myrole WITH PASSWORD = 'secret' AND LOGIN = true AND SUPERUSER = false", + } + for _, sql := range validSQL { + for i := 0; i <= len(sql); i++ { + truncated := sql[:i] + func() { + defer func() { + if r := recover(); r != nil { + t.Fatalf("Parse(%q) panicked (truncated from %q at byte %d): %v", truncated, sql, i, r) + } + }() + Parse(truncated) + }() + } + } +} + +func TestBinaryInputNoPanic(t *testing.T) { + inputs := []string{ + "\x00\x01\x02\x03", + string([]byte{0xFF, 0xFE, 0xFD}), + "\x00SELECT * FROM users", + "SELECT\x00FROM\x00users", + string(make([]byte, 1024)), + } + for _, input := range inputs { + func() { + defer func() { + if r := recover(); r != nil { + t.Fatalf("Parse(binary) panicked: %v", r) + } + }() + Parse(input) + }() + } +} + +func TestParseAlterIfExists(t *testing.T) { + tests := []struct { + name string + sql string + check func(t *testing.T, s Statement) + }{ + { + name: "ALTER KEYSPACE IF EXISTS", + sql: "ALTER KEYSPACE IF EXISTS ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'}", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterKeyspaceStmt) + if !stmt.IfExists { + t.Fatal("expected IfExists=true") + } + }, + }, + { + name: "ALTER TABLE IF EXISTS", + sql: "ALTER TABLE IF EXISTS t ADD col text", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTableStmt) + if !stmt.IfExists { + t.Fatal("expected IfExists=true") + } + }, + }, + { + name: "ALTER TABLE ADD IF NOT EXISTS", + sql: "ALTER TABLE t ADD IF NOT EXISTS col text", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTableStmt) + if !stmt.AddIfNotExists { + 
t.Fatal("expected AddIfNotExists=true") + } + }, + }, + { + name: "ALTER TABLE DROP IF EXISTS", + sql: "ALTER TABLE t DROP IF EXISTS col", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTableStmt) + if !stmt.DropIfExists { + t.Fatal("expected DropIfExists=true") + } + }, + }, + { + name: "ALTER TABLE RENAME IF EXISTS", + sql: "ALTER TABLE t RENAME IF EXISTS a TO b", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTableStmt) + if !stmt.RenameIfExists { + t.Fatal("expected RenameIfExists=true") + } + if len(stmt.RenameItems) != 1 { + t.Fatalf("expected 1 rename item, got %d", len(stmt.RenameItems)) + } + }, + }, + { + name: "ALTER TABLE RENAME multiple pairs", + sql: "ALTER TABLE t RENAME a TO b AND c TO d", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTableStmt) + if len(stmt.RenameItems) != 2 { + t.Fatalf("expected 2 rename items, got %d", len(stmt.RenameItems)) + } + if stmt.RenameItems[0].From.Name != "a" || stmt.RenameItems[0].To.Name != "b" { + t.Fatal("first rename pair wrong") + } + if stmt.RenameItems[1].From.Name != "c" || stmt.RenameItems[1].To.Name != "d" { + t.Fatal("second rename pair wrong") + } + }, + }, + { + name: "ALTER TABLE IF EXISTS + ADD IF NOT EXISTS combined", + sql: "ALTER TABLE IF EXISTS t ADD IF NOT EXISTS col text", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTableStmt) + if !stmt.IfExists { + t.Fatal("expected IfExists=true") + } + if !stmt.AddIfNotExists { + t.Fatal("expected AddIfNotExists=true") + } + }, + }, + { + name: "ALTER TYPE IF EXISTS", + sql: "ALTER TYPE IF EXISTS mytype ADD f2 int", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTypeStmt) + if !stmt.IfExists { + t.Fatal("expected IfExists=true") + } + }, + }, + { + name: "ALTER TYPE ADD IF NOT EXISTS", + sql: "ALTER TYPE mytype ADD IF NOT EXISTS f2 int", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTypeStmt) + if !stmt.AddIfNotExists 
{ + t.Fatal("expected AddIfNotExists=true") + } + }, + }, + { + name: "ALTER TYPE RENAME IF EXISTS", + sql: "ALTER TYPE mytype RENAME IF EXISTS f1 TO field1", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTypeStmt) + if !stmt.RenameIfExists { + t.Fatal("expected RenameIfExists=true") + } + }, + }, + { + name: "ALTER MATERIALIZED VIEW IF EXISTS", + sql: "ALTER MATERIALIZED VIEW IF EXISTS mv WITH comment = 'test'", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterMVStmt) + if !stmt.IfExists { + t.Fatal("expected IfExists=true") + } + }, + }, + { + name: "ALTER ROLE IF EXISTS", + sql: "ALTER ROLE IF EXISTS r WITH PASSWORD = 'x'", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterRoleStmt) + if !stmt.IfExists { + t.Fatal("expected IfExists=true") + } + }, + }, + { + name: "ALTER USER IF EXISTS", + sql: "ALTER USER IF EXISTS u WITH PASSWORD 'y'", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterUserStmt) + if !stmt.IfExists { + t.Fatal("expected IfExists=true") + } + }, + }, + { + name: "ALTER TABLE without IF EXISTS", + sql: "ALTER TABLE t ADD col text", + check: func(t *testing.T, s Statement) { + stmt := s.AST.(*ast.AlterTableStmt) + if stmt.IfExists { + t.Fatal("expected IfExists=false") + } + if stmt.AddIfNotExists { + t.Fatal("expected AddIfNotExists=false") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := Parse(tt.sql) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + tt.check(t, stmts[0]) + }) + } +} + +func TestParseP2Features(t *testing.T) { + cases := []struct { + name string + sql string + check func(t *testing.T, s Statement) + }{ + {"CAST", "SELECT CAST(col AS int) FROM t", func(t *testing.T, s Statement) { + sel := s.AST.(*ast.SelectStmt) + if _, ok := sel.Elements[0].Expr.(*ast.CastExpr); !ok { + t.Fatalf("expected CastExpr, got %T", 
sel.Elements[0].Expr) + } + }}, + {"bind ?", "INSERT INTO t (id) VALUES (?)", func(t *testing.T, s Statement) { + ins := s.AST.(*ast.InsertStmt) + bm := ins.Values[0].(*ast.BindMarker) + if bm.Name != "" { + t.Fatal("expected positional") + } + }}, + {"bind :name", "SELECT * FROM t WHERE id = :myid", func(t *testing.T, s Statement) { + sel := s.AST.(*ast.SelectStmt) + bm := sel.Where[0].(*ast.BinaryExpr).Right.(*ast.BindMarker) + if bm.Name != "myid" { + t.Fatalf("got %s", bm.Name) + } + }}, + {"NaN", "INSERT INTO t (v) VALUES (NaN)", func(t *testing.T, s Statement) { + fl := s.AST.(*ast.InsertStmt).Values[0].(*ast.FloatLit) + if fl.Val != "NaN" { + t.Fatalf("got %s", fl.Val) + } + }}, + {"Infinity", "INSERT INTO t (v) VALUES (Infinity)", func(t *testing.T, s Statement) { + fl := s.AST.(*ast.InsertStmt).Values[0].(*ast.FloatLit) + if fl.Val != "Infinity" { + t.Fatalf("got %s", fl.Val) + } + }}, + {"DROP FUNCTION with types", "DROP FUNCTION IF EXISTS ks.f(int, text)", func(t *testing.T, s Statement) { + if len(s.AST.(*ast.DropFunctionStmt).ArgTypes) != 2 { + t.Fatal("expected 2 arg types") + } + }}, + {"IF IN", "UPDATE t SET x = 1 WHERE id = 1 IF x IN (1, 2)", func(t *testing.T, s Statement) { + c := s.AST.(*ast.UpdateStmt).IfConditions[0] + if c.Op != "IN" || len(c.InValues) != 2 { + t.Fatal("wrong IF IN") + } + }}, + {"IF CONTAINS", "UPDATE t SET x = 1 WHERE id = 1 IF tags CONTAINS 'a'", func(t *testing.T, s Statement) { + if s.AST.(*ast.UpdateStmt).IfConditions[0].Op != "CONTAINS" { + t.Fatal("wrong op") + } + }}, + {"IF CONTAINS KEY", "UPDATE t SET x = 1 WHERE id = 1 IF m CONTAINS KEY 'k'", func(t *testing.T, s Statement) { + if s.AST.(*ast.UpdateStmt).IfConditions[0].Op != "CONTAINS KEY" { + t.Fatal("wrong op") + } + }}, + {"JSON DEFAULT NULL", "INSERT INTO t JSON '{\"id\":1}' DEFAULT NULL", func(t *testing.T, s Statement) { + ins := s.AST.(*ast.InsertStmt) + if !ins.DefaultNull || ins.DefaultUnset { + t.Fatal("wrong defaults") + } + }}, + {"HASHED PASSWORD", 
"CREATE ROLE r WITH HASHED PASSWORD = 'h'", func(t *testing.T, s Statement) { + if s.AST.(*ast.CreateRoleStmt).Options[0].Key != "HASHED PASSWORD" { + t.Fatal("wrong key") + } + }}, + {"UDT field UPDATE", "UPDATE t SET a.b = 1 WHERE id = 1", func(t *testing.T, s Statement) { + if _, ok := s.AST.(*ast.UpdateStmt).Assignments[0].Target.(*ast.DotAccess); !ok { + t.Fatal("expected DotAccess") + } + }}, + {"UDT field DELETE", "DELETE a.b FROM t WHERE id = 1", func(t *testing.T, s Statement) { + if _, ok := s.AST.(*ast.DeleteStmt).Columns[0].(*ast.DotAccess); !ok { + t.Fatal("expected DotAccess") + } + }}, + {"MBEAN resource", "GRANT SELECT ON MBEAN 'x' TO r", func(t *testing.T, s Statement) { + if s.AST.(*ast.GrantStmt).Resource.Type != "MBEAN" { + t.Fatal("wrong type") + } + }}, + {"FUNCTION resource with types", "GRANT EXECUTE ON FUNCTION f(int, text) TO r", func(t *testing.T, s Statement) { + r := s.AST.(*ast.GrantStmt).Resource + if r.Type != "FUNCTION" || len(r.ArgTypes) != 2 { + t.Fatal("wrong func resource") + } + }}, + {"singular PERMISSION", "GRANT SELECT PERMISSION ON TABLE t TO r", func(t *testing.T, s Statement) { + g := s.AST.(*ast.GrantStmt) + if g.Privilege != "SELECT" { + t.Fatalf("expected SELECT, got %s", g.Privilege) + } + }}, + {"ALL PERMISSION singular", "GRANT ALL PERMISSION ON TABLE t TO r", func(t *testing.T, s Statement) { + g := s.AST.(*ast.GrantStmt) + if g.Privilege != "ALL PERMISSIONS" { + t.Fatalf("expected ALL PERMISSIONS, got %s", g.Privilege) + } + }}, + {"REVOKE singular PERMISSION", "REVOKE MODIFY PERMISSION ON TABLE t FROM r", func(t *testing.T, s Statement) { + r := s.AST.(*ast.RevokeStmt) + if r.Privilege != "MODIFY" { + t.Fatalf("expected MODIFY, got %s", r.Privilege) + } + }}, + {"LIMIT ?", "SELECT * FROM t LIMIT ?", func(t *testing.T, s Statement) { + sel := s.AST.(*ast.SelectStmt) + if _, ok := sel.Limit.(*ast.BindMarker); !ok { + t.Fatalf("expected BindMarker, got %T", sel.Limit) + } + }}, + {"LIMIT :name", "SELECT * FROM t 
LIMIT :n", func(t *testing.T, s Statement) { + sel := s.AST.(*ast.SelectStmt) + bm := sel.Limit.(*ast.BindMarker) + if bm.Name != "n" { + t.Fatalf("expected name 'n', got %s", bm.Name) + } + }}, + {"PER PARTITION LIMIT ?", "SELECT * FROM t PER PARTITION LIMIT ?", func(t *testing.T, s Statement) { + sel := s.AST.(*ast.SelectStmt) + if _, ok := sel.PerPartitionLimit.(*ast.BindMarker); !ok { + t.Fatalf("expected BindMarker, got %T", sel.PerPartitionLimit) + } + }}, + {"CREATE CUSTOM INDEX WITH OPTIONS", "CREATE CUSTOM INDEX idx ON t (col) USING 'org.Custom' WITH OPTIONS = {'key': 'val'}", func(t *testing.T, s Statement) { + idx := s.AST.(*ast.CreateIndexStmt) + if !idx.IsCustom { + t.Fatal("expected IsCustom") + } + if idx.Options == nil || len(idx.Options.Items) != 1 { + t.Fatal("expected 1 option") + } + }}, + {"CREATE USER HASHED PASSWORD", "CREATE USER u WITH HASHED PASSWORD 'h' SUPERUSER", func(t *testing.T, s Statement) { + u := s.AST.(*ast.CreateUserStmt) + if !u.Hashed { + t.Fatal("expected Hashed=true") + } + if u.Superuser == nil || !*u.Superuser { + t.Fatal("expected Superuser=true") + } + }}, + {"CREATE USER no password", "CREATE USER u SUPERUSER", func(t *testing.T, s Statement) { + u := s.AST.(*ast.CreateUserStmt) + if u.Password != nil { + t.Fatal("expected no password") + } + }}, + {"ALTER USER HASHED PASSWORD", "ALTER USER u WITH HASHED PASSWORD 'h' NOSUPERUSER", func(t *testing.T, s Statement) { + u := s.AST.(*ast.AlterUserStmt) + if !u.Hashed { + t.Fatal("expected Hashed=true") + } + if u.Superuser == nil || *u.Superuser { + t.Fatal("expected Superuser=false") + } + }}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + stmts, err := Parse(tt.sql) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatalf("expected 1 stmt, got %d", len(stmts)) + } + tt.check(t, stmts[0]) + }) + } +} diff --git a/cassandra/parser/auth.go b/cassandra/parser/auth.go new file mode 100644 index 00000000..ef242cb8 --- /dev/null +++ 
b/cassandra/parser/auth.go
@@ -0,0 +1,660 @@
+package parser
+
+import (
+	"strings"
+
+	"github.com/bytebase/omni/cassandra/ast"
+)
+
+// parseGrant parses a statement introduced by GRANT; the current token must
+// be the GRANT keyword. Two forms are recognized:
+//
+//	GRANT role TO grantee               -> *ast.GrantRoleStmt
+//	GRANT privilege ON resource TO role -> *ast.GrantStmt
+func (p *Parser) parseGrant() (ast.StmtNode, error) {
+	start := p.curLoc()
+	p.advance() // GRANT
+
+	// An identifier-like token directly followed by TO is treated as a role
+	// grant. The privilege form continues with PERMISSION(S) or ON after its
+	// first token, so one token of lookahead separates the two forms.
+	if isIdentLike(p.cur.Type) && p.peekNext().Type == tokTO {
+		granted, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		if err := p.expectKeyword(tokTO); err != nil {
+			return nil, err
+		}
+		grantee, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.GrantRoleStmt{
+			RoleName: granted,
+			Grantee:  grantee,
+			Loc:      p.makeLoc(start),
+		}, nil
+	}
+
+	privilege, err := p.parsePrivilege()
+	if err != nil {
+		return nil, err
+	}
+	if err := p.expectKeyword(tokON); err != nil {
+		return nil, err
+	}
+	target, err := p.parseResource()
+	if err != nil {
+		return nil, err
+	}
+	if err := p.expectKeyword(tokTO); err != nil {
+		return nil, err
+	}
+	recipient, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+
+	return &ast.GrantStmt{
+		Privilege: privilege,
+		Resource:  target,
+		Role:      recipient,
+		Loc:       p.makeLoc(start),
+	}, nil
+}
+
+// parseRevoke parses a statement introduced by REVOKE; the current token must
+// be the REVOKE keyword. Two forms are recognized:
+//
+//	REVOKE role FROM revokee               -> *ast.RevokeRoleStmt
+//	REVOKE privilege ON resource FROM role -> *ast.RevokeStmt
+func (p *Parser) parseRevoke() (ast.StmtNode, error) {
+	start := p.curLoc()
+	p.advance() // REVOKE
+
+	// Same disambiguation as parseGrant, with FROM in place of TO.
+	if isIdentLike(p.cur.Type) && p.peekNext().Type == tokFROM {
+		revoked, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		if err := p.expectKeyword(tokFROM); err != nil {
+			return nil, err
+		}
+		revokee, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.RevokeRoleStmt{
+			RoleName: revoked,
+			Revokee:  revokee,
+			Loc:      p.makeLoc(start),
+		}, nil
+	}
+
+	privilege, err := p.parsePrivilege()
+	if err != nil {
+		return nil, err
+	}
+	if err := p.expectKeyword(tokON); err != nil {
+		return nil, err
+	}
+	target, err := p.parseResource()
+	if err != nil {
+		return nil, err
+	}
+	if err := p.expectKeyword(tokFROM); err != nil {
+		return nil, err
+	}
+	holder, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+
+	return &ast.RevokeStmt{
+		Privilege: privilege,
+		Resource:  target,
+		Role:      holder,
+		Loc:       p.makeLoc(start),
+	}, nil
+}
+
+// parseList parses a statement introduced by LIST; the current token must be
+// the LIST keyword. LIST ROLES is delegated to parseListRoles; anything else
+// is parsed as
+//
+//	LIST privilege [ON resource] [OF role]
+func (p *Parser) parseList() (ast.StmtNode, error) {
+	start := p.curLoc()
+	p.advance() // LIST
+
+	if p.cur.Type == tokROLES {
+		// Convert explicitly: returning parseListRoles' results directly
+		// would wrap a nil *ast.ListRolesStmt in a non-nil ast.StmtNode
+		// whenever parsing fails.
+		roles, err := p.parseListRoles(start)
+		if err != nil {
+			return nil, err
+		}
+		return roles, nil
+	}
+
+	privilege, err := p.parsePrivilege()
+	if err != nil {
+		return nil, err
+	}
+	stmt := &ast.ListPermissionsStmt{Privilege: privilege}
+
+	if p.match(tokON) {
+		target, err := p.parseResource()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Resource = target
+	}
+	if p.match(tokOF) {
+		owner, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Role = owner
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseListRoles parses the tail of LIST ROLES [OF role] [NORECURSIVE]; the
+// current token must be ROLES. start is the location captured by the caller
+// at the LIST keyword and anchors the statement's Loc.
+func (p *Parser) parseListRoles(start int) (*ast.ListRolesStmt, error) {
+	p.advance() // ROLES
+
+	stmt := &ast.ListRolesStmt{}
+	if p.match(tokOF) {
+		of, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Of = of
+	}
+	if p.match(tokNORECURSIVE) {
+		stmt.NoRecursive = true
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parsePrivilege parses a privilege keyword and returns its canonical
+// upper-case name.
+//
+// ALL PERMISSIONS and ALL PERMISSION both yield "ALL PERMISSIONS", a bare ALL
+// yields "ALL", and a PERMISSION keyword following any other privilege is
+// consumed and dropped. Any other token is reported as an error.
+func (p *Parser) parsePrivilege() (string, error) {
+	var name string
+	switch p.cur.Type {
+	case tokALL:
+		p.advance()
+		// Singular and plural spellings normalize to the plural form.
+		if p.cur.Type == tokPERMISSIONS || p.cur.Type == tokPERMISSION {
+			p.advance()
+			return "ALL PERMISSIONS", nil
+		}
+		return "ALL", nil
+	case tokALTER:
+		name = "ALTER"
+	case tokAUTHORIZE:
+		name = "AUTHORIZE"
+	case tokDESCRIBE:
+		name = "DESCRIBE"
+	case tokEXECUTE:
+		name = "EXECUTE"
+	case tokCREATE:
+		name = "CREATE"
+	case tokDROP:
+		name = "DROP"
+	case tokMODIFY:
+		name = "MODIFY"
+	case tokSELECT:
+		name = "SELECT"
+	default:
+		return "", p.errorf("expected privilege keyword, got %s", p.tokenDesc())
+	}
+
+	p.advance()
+	p.consumeOptionalPermission()
+	return name, nil
+}
+
+// consumeOptionalPermission skips a PERMISSION keyword if it is the current
+// token and does nothing otherwise.
+func (p *Parser) consumeOptionalPermission() {
+	if p.cur.Type == tokPERMISSION {
+		p.advance()
+	}
+}
+
+// parseResource parses the resource of a GRANT, REVOKE or LIST statement and
+// returns it with Type set to one of:
+//
+//	"ALL FUNCTIONS", "ALL FUNCTIONS IN KEYSPACE", "ALL KEYSPACES",
+//	"ALL ROLES", "ALL MBEANS", "FUNCTION", "MBEAN", "MBEANS",
+//	"KEYSPACE", "ROLE", "TABLE"
+//
+// A name that is not introduced by a resource keyword is treated as a table.
+func (p *Parser) parseResource() (*ast.Resource, error) {
+	start := p.curLoc()
+
+	// single wraps one identifier as a one-part qualified name that shares
+	// the identifier's location.
+	single := func(id *ast.Identifier) *ast.QualifiedName {
+		return &ast.QualifiedName{Parts: []*ast.Identifier{id}, Loc: id.Loc}
+	}
+
+	switch p.cur.Type {
+	case tokALL:
+		p.advance()
+		switch p.cur.Type {
+		case tokFUNCTIONS:
+			p.advance()
+			if !p.match(tokIN) {
+				return &ast.Resource{Type: "ALL FUNCTIONS", Loc: p.makeLoc(start)}, nil
+			}
+			if err := p.expectKeyword(tokKEYSPACE); err != nil {
+				return nil, err
+			}
+			ks, err := p.parseIdentifier()
+			if err != nil {
+				return nil, err
+			}
+			return &ast.Resource{
+				Type: "ALL FUNCTIONS IN KEYSPACE",
+				Name: single(ks),
+				Loc:  p.makeLoc(start),
+			}, nil
+		case tokKEYSPACES:
+			p.advance()
+			return &ast.Resource{Type: "ALL KEYSPACES", Loc: p.makeLoc(start)}, nil
+		case tokROLES:
+			p.advance()
+			return &ast.Resource{Type: "ALL ROLES", Loc: p.makeLoc(start)}, nil
+		case tokMBEANS:
+			p.advance()
+			return &ast.Resource{Type: "ALL MBEANS", Loc: p.makeLoc(start)}, nil
+		default:
+			return nil, p.errorf("expected FUNCTIONS, KEYSPACES, ROLES, or MBEANS after ALL, got %s", p.tokenDesc())
+		}
+
+	case tokFUNCTION:
+		p.advance()
+		name, err := p.parseQualifiedName()
+		if err != nil {
+			return nil, err
+		}
+		res := &ast.Resource{Type: "FUNCTION", Name: name, Loc: p.makeLoc(start)}
+		// An argument type list is optional; "()" leaves ArgTypes unset.
+		if p.match(tokLPAREN) {
+			if p.cur.Type != tokRPAREN {
+				argTypes, err := p.parseTypeList()
+				if err != nil {
+					return nil, err
+				}
+				res.ArgTypes = argTypes
+			}
+			if _, err := p.expect(tokRPAREN); err != nil {
+				return nil, err
+			}
+			// Extend the location over the closing parenthesis.
+			res.Loc = p.makeLoc(start)
+		}
+		return res, nil
+
+	case tokMBEAN, tokMBEANS:
+		// Both take a string literal, stored as a one-part name.
+		kind, what := "MBEAN", "name"
+		if p.cur.Type == tokMBEANS {
+			kind, what = "MBEANS", "pattern"
+		}
+		p.advance()
+		val, err := p.parseConstant()
+		if err != nil {
+			return nil, err
+		}
+		lit, ok := val.(*ast.StringLit)
+		if !ok {
+			return nil, p.errorf("expected string literal for %s %s, got %T", kind, what, val)
+		}
+		id := &ast.Identifier{Name: lit.Val, Loc: val.GetLoc()}
+		return &ast.Resource{Type: kind, Name: single(id), Loc: p.makeLoc(start)}, nil
+
+	case tokKEYSPACE, tokROLE:
+		kind := "KEYSPACE"
+		if p.cur.Type == tokROLE {
+			kind = "ROLE"
+		}
+		p.advance()
+		id, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.Resource{Type: kind, Name: single(id), Loc: p.makeLoc(start)}, nil
+
+	default:
+		// [TABLE] [keyspace.]table: the TABLE keyword is optional.
+		p.match(tokTABLE)
+		name, err := p.parseQualifiedName()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.Resource{Type: "TABLE", Name: name, Loc: p.makeLoc(start)}, nil
+	}
+}
+
+// parseCreateRole parses the remainder of a CREATE ROLE statement; the current
+// token must be ROLE. The role name may be preceded by whatever
+// parseIfNotExists accepts and followed by WITH and a list of role options.
+func (p *Parser) parseCreateRole() (*ast.CreateRoleStmt, error) {
+	start := p.curLoc()
+	p.advance() // ROLE
+
+	ifNotExists, err := p.parseIfNotExists()
+	if err != nil {
+		return nil, err
+	}
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+	stmt := &ast.CreateRoleStmt{IfNotExists: ifNotExists, Name: name}
+
+	if p.match(tokWITH) {
+		options, err := p.parseRoleWithOptions()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Options = options
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseAlterRole parses the remainder of an ALTER ROLE statement; the current
+// token must be ROLE. The role name may be followed by WITH and a list of
+// role options.
+func (p *Parser) parseAlterRole() (*ast.AlterRoleStmt, error) {
+	start := p.curLoc()
+	p.advance() // ROLE
+
+	ifExists := p.parseIfExists()
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+	stmt := &ast.AlterRoleStmt{IfExists: ifExists, Name: name}
+
+	if p.match(tokWITH) {
+		options, err := p.parseRoleWithOptions()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Options = options
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseDropRole parses the remainder of a DROP ROLE statement; the current
+// token must be ROLE.
+func (p *Parser) parseDropRole() (*ast.DropRoleStmt, error) {
+	start := p.curLoc()
+	p.advance() // ROLE
+
+	ifExists := p.parseIfExists()
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+
+	return &ast.DropRoleStmt{
+		IfExists: ifExists,
+		Name:     name,
+		Loc:      p.makeLoc(start),
+	}, nil
+}
+
+// parseRoleWithOptions parses one or more role options separated by AND. The
+// caller has already consumed the WITH keyword.
+func (p *Parser) parseRoleWithOptions() ([]*ast.RoleOption, error) {
+	var options []*ast.RoleOption
+	for {
+		option, err := p.parseRoleOption()
+		if err != nil {
+			return nil, err
+		}
+		options = append(options, option)
+		if !p.match(tokAND) {
+			return options, nil
+		}
+	}
+}
+
+// parseRoleOption parses a single role option:
+//
+//	PASSWORD = constant
+//	HASHED PASSWORD = constant
+//	LOGIN = bool
+//	SUPERUSER = bool
+//	OPTIONS = option_hash
+//	ACCESS TO ALL DATACENTERS
+//	ACCESS TO DATACENTERS collection_literal
+//
+// Key holds the upper-cased option name. ACCESS TO ALL DATACENTERS carries no
+// Value.
+func (p *Parser) parseRoleOption() (*ast.RoleOption, error) {
+	start := p.curLoc()
+	tok := p.cur
+	// Upper-cased spelling of the leading keyword; it is the Key of the
+	// single-keyword options.
+	key := strings.ToUpper(tok.Str)
+
+	switch tok.Type {
+	case tokHASHED:
+		p.advance() // HASHED
+		if err := p.expectKeyword(tokPASSWORD); err != nil {
+			return nil, err
+		}
+		if _, err := p.expect(tokEQ); err != nil {
+			return nil, err
+		}
+		val, err := p.parseConstant()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.RoleOption{Key: "HASHED PASSWORD", Value: val, Loc: p.makeLoc(start)}, nil
+
+	case tokPASSWORD:
+		p.advance()
+		if _, err := p.expect(tokEQ); err != nil {
+			return nil, err
+		}
+		val, err := p.parseConstant()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.RoleOption{Key: key, Value: val, Loc: p.makeLoc(start)}, nil
+
+	case tokACCESS:
+		p.advance() // ACCESS
+		if err := p.expectKeyword(tokTO); err != nil {
+			return nil, err
+		}
+		if p.match(tokALL) {
+			if err := p.expectKeyword(tokDATACENTERS); err != nil {
+				return nil, err
+			}
+			return &ast.RoleOption{Key: "ACCESS TO ALL DATACENTERS", Loc: p.makeLoc(start)}, nil
+		}
+		if err := p.expectKeyword(tokDATACENTERS); err != nil {
+			return nil, err
+		}
+		// Set of datacenter names: { 'dc1', 'dc2' }
+		val, err := p.parseCollectionLiteral()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.RoleOption{Key: "ACCESS TO DATACENTERS", Value: val, Loc: p.makeLoc(start)}, nil
+
+	case tokLOGIN, tokSUPERUSER:
+		p.advance()
+		if _, err := p.expect(tokEQ); err != nil {
+			return nil, err
+		}
+		val, err := p.parseBoolLit()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.RoleOption{Key: key, Value: val, Loc: p.makeLoc(start)}, nil
+
+	case tokOPTIONS:
+		p.advance()
+		if _, err := p.expect(tokEQ); err != nil {
+			return nil, err
+		}
+		val, err := p.parseOptionHash()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.RoleOption{Key: key, Value: val, Loc: p.makeLoc(start)}, nil
+
+	default:
+		// Name every option accepted above so the message is not misleading.
+		return nil, p.errorf("expected PASSWORD, HASHED PASSWORD, LOGIN, SUPERUSER, OPTIONS, or ACCESS TO DATACENTERS, got %s", p.tokenDesc())
+	}
+}
+
+// parseCreateUser parses the remainder of a CREATE USER statement; the current
+// token must be USER. After the user name it accepts an optional
+// WITH [HASHED] PASSWORD clause and then an optional SUPERUSER or NOSUPERUSER
+// keyword. Superuser stays nil when neither keyword is present.
+func (p *Parser) parseCreateUser() (*ast.CreateUserStmt, error) {
+	start := p.curLoc()
+	p.advance() // USER
+
+	ifNotExists, err := p.parseIfNotExists()
+	if err != nil {
+		return nil, err
+	}
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+	stmt := &ast.CreateUserStmt{IfNotExists: ifNotExists, Name: name}
+
+	if p.match(tokWITH) {
+		if err := p.parseUserPasswordClause(stmt, nil); err != nil {
+			return nil, err
+		}
+	}
+
+	switch p.cur.Type {
+	case tokSUPERUSER, tokNOSUPERUSER:
+		super := p.cur.Type == tokSUPERUSER
+		stmt.Superuser = &super
+		p.advance()
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseUserPasswordClause parses [HASHED] PASSWORD string after WITH.
+// It stores the parsed password and the hashed flag on whichever of the
+// create and alter statements is non-nil.
+func (p *Parser) parseUserPasswordClause(create *ast.CreateUserStmt, alter *ast.AlterUserStmt) error { + hashed := false + if p.cur.Type == tokHASHED { + hashed = true + p.advance() + } + if err := p.expectKeyword(tokPASSWORD); err != nil { + return err + } + pwd, err := p.parseConstant() + if err != nil { + return err + } + if create != nil { + create.Password = pwd + create.Hashed = hashed + } + if alter != nil { + alter.Password = pwd + alter.Hashed = hashed + } + return nil +} + +func (p *Parser) parseAlterUser() (*ast.AlterUserStmt, error) { + start := p.curLoc() + p.advance() // USER + + ifExists := p.parseIfExists() + + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + stmt := &ast.AlterUserStmt{IfExists: ifExists, Name: name} + + if p.cur.Type == tokWITH { + p.advance() + if err := p.parseUserPasswordClause(nil, stmt); err != nil { + return nil, err + } + } + + if p.cur.Type == tokSUPERUSER { + v := true + stmt.Superuser = &v + p.advance() + } else if p.cur.Type == tokNOSUPERUSER { + v := false + stmt.Superuser = &v + p.advance() + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +func (p *Parser) parseDropUser() (*ast.DropUserStmt, error) { + start := p.curLoc() + p.advance() // USER + + ifExists := p.parseIfExists() + + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + return &ast.DropUserStmt{ + IfExists: ifExists, + Name: name, + Loc: p.makeLoc(start), + }, nil +} diff --git a/cassandra/parser/batch.go b/cassandra/parser/batch.go new file mode 100644 index 00000000..a7a64c72 --- /dev/null +++ b/cassandra/parser/batch.go @@ -0,0 +1,105 @@ +package parser + +import ( + "github.com/bytebase/omni/cassandra/ast" +) + +// parseBatch parses a BATCH statement: +// +// BEGIN [UNLOGGED|LOGGED] BATCH [USING TIMESTAMP n] +// (insert | update | delete) ';' +// ... 
+// APPLY BATCH +func (p *Parser) parseBatch() (*ast.BatchStmt, error) { + start := p.curLoc() + if err := p.expectKeyword(tokBEGIN); err != nil { + return nil, err + } + + // Parse optional batch type: UNLOGGED, LOGGED, or COUNTER + batchType := ast.BatchDefault + switch p.cur.Type { + case tokUNLOGGED: + batchType = ast.BatchUnlogged + p.advance() + case tokLOGGED: + batchType = ast.BatchLogged + p.advance() + case tokCOUNTER: + batchType = ast.BatchCounter + p.advance() + } + + if err := p.expectKeyword(tokBATCH); err != nil { + return nil, err + } + + // Optional USING clause (typically just TIMESTAMP for batches) + using, err := p.parseUsingClause() + if err != nil { + return nil, err + } + + // Parse inner DML statements until APPLY BATCH + var stmts []ast.StmtNode + for { + // Check for APPLY BATCH + if p.cur.Type == tokAPPLY { + break + } + if p.cur.Type == tokEOF { + return nil, p.errorf("unexpected end of input, expected APPLY BATCH") + } + + // Skip any stray semicolons between statements + if p.cur.Type == tokSEMI { + p.advance() + continue + } + + // Parse an inner DML statement (INSERT, UPDATE, or DELETE) + var stmt ast.StmtNode + switch p.cur.Type { + case tokINSERT: + s, err := p.parseInsert() + if err != nil { + return nil, err + } + stmt = s + case tokUPDATE: + s, err := p.parseUpdate() + if err != nil { + return nil, err + } + stmt = s + case tokDELETE: + s, err := p.parseDelete() + if err != nil { + return nil, err + } + stmt = s + default: + return nil, p.errorf("expected INSERT, UPDATE, or DELETE inside BATCH, got %s", p.tokenDesc()) + } + + stmts = append(stmts, stmt) + + // Consume optional semicolon after each statement + p.match(tokSEMI) + } + + // Expect APPLY BATCH + if err := p.expectKeyword(tokAPPLY); err != nil { + return nil, err + } + if err := p.expectKeyword(tokBATCH); err != nil { + return nil, err + } + + return &ast.BatchStmt{ + Type: batchType, + Using: using, + Statements: stmts, + Loc: p.makeLoc(start), + }, nil +} diff 
--git a/cassandra/parser/ddl_index.go b/cassandra/parser/ddl_index.go
new file mode 100644
index 00000000..e4212be0
--- /dev/null
+++ b/cassandra/parser/ddl_index.go
@@ -0,0 +1,131 @@
+package parser
+
+import (
+	"github.com/bytebase/omni/cassandra/ast"
+)
+
+// parseCreateIndex parses the remainder of a CREATE [CUSTOM] INDEX statement;
+// the current token must be CUSTOM or INDEX:
+//
+//	[CUSTOM] INDEX [IF NOT EXISTS] [name] ON table '(' column ')'
+//	    [USING class] [WITH [OPTIONS '='] option_hash]
+//
+// column is a plain identifier, or FULL, KEYS, VALUES or ENTRIES applied to
+// one; the latter is stored as the node built by parseFunctionCallWithName.
+func (p *Parser) parseCreateIndex() (*ast.CreateIndexStmt, error) {
+	start := p.curLoc()
+	custom := p.match(tokCUSTOM)
+
+	// Check the INDEX keyword instead of skipping one token blindly: once
+	// CUSTOM has been consumed, nothing here has verified what follows it.
+	if err := p.expectKeyword(tokINDEX); err != nil {
+		return nil, err
+	}
+
+	ifNotExists, err := p.parseIfNotExists()
+	if err != nil {
+		return nil, err
+	}
+	stmt := &ast.CreateIndexStmt{IsCustom: custom, IfNotExists: ifNotExists}
+
+	// The index name is optional: anything other than ON here is the name.
+	if p.cur.Type != tokON {
+		name, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		stmt.IndexName = name
+	}
+	if err := p.expectKeyword(tokON); err != nil {
+		return nil, err
+	}
+
+	table, err := p.parseQualifiedName()
+	if err != nil {
+		return nil, err
+	}
+	stmt.Table = table
+
+	if _, err := p.expect(tokLPAREN); err != nil {
+		return nil, err
+	}
+	switch p.cur.Type {
+	case tokFULL, tokKEYS, tokVALUES, tokENTRIES:
+		// FULL(col), KEYS(col), VALUES(col) or ENTRIES(col).
+		fnStart := p.curLoc()
+		fn, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		call, err := p.parseFunctionCallWithName(fn)
+		if err != nil {
+			return nil, err
+		}
+		// Make the call's location begin at the function keyword.
+		call.Loc.Start = fnStart
+		stmt.Column = call
+	default:
+		col, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Column = col
+	}
+	if _, err := p.expect(tokRPAREN); err != nil {
+		return nil, err
+	}
+
+	// Optional USING class: a string, the SAI / StorageAttachedIndex
+	// keywords, or any other constant.
+	if p.match(tokUSING) {
+		tok := p.cur
+		loc := ast.Loc{Start: tok.Loc, End: tok.End}
+		switch tok.Type {
+		case tokSTRING:
+			stmt.UsingClass = &ast.StringLit{Val: tok.Str, Loc: loc}
+			p.advance()
+		case tokSAI, tokSTORAGEATTACHEDINDEX:
+			stmt.UsingClass = &ast.Identifier{Name: tok.Str, Loc: loc}
+			p.advance()
+		default:
+			val, err := p.parseConstant()
+			if err != nil {
+				return nil, err
+			}
+			stmt.UsingClass = val
+		}
+	}
+
+	// Optional WITH OPTIONS = { ... } or WITH { ... }.
+	if p.match(tokWITH) {
+		if p.match(tokOPTIONS) {
+			if _, err := p.expect(tokEQ); err != nil {
+				return nil, err
+			}
+		}
+		options, err := p.parseOptionHash()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Options = options
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseDropIndex parses the remainder of a DROP INDEX statement; the current
+// token must be INDEX. The index name may be qualified.
+func (p *Parser) parseDropIndex() (*ast.DropIndexStmt, error) {
+	start := p.curLoc()
+	p.advance() // INDEX
+
+	ifExists := p.parseIfExists()
+	name, err := p.parseQualifiedName()
+	if err != nil {
+		return nil, err
+	}
+
+	stmt := &ast.DropIndexStmt{IfExists: ifExists, Name: name}
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
diff --git a/cassandra/parser/ddl_keyspace.go b/cassandra/parser/ddl_keyspace.go
new file mode 100644
index 00000000..524e4105
--- /dev/null
+++ b/cassandra/parser/ddl_keyspace.go
@@ -0,0 +1,136 @@
+package parser
+
+import (
+	"github.com/bytebase/omni/cassandra/ast"
+)
+
+// parseCreateKeyspace parses the remainder of a CREATE KEYSPACE statement; the
+// current token must be KEYSPACE:
+//
+//	KEYSPACE [IF NOT EXISTS] name WITH REPLICATION = option_hash
+//	    [AND DURABLE_WRITES = bool]
+//
+// NOTE(review): REPLICATION must come first and DURABLE_WRITES may only follow
+// it. The CQL grammar appears to accept these options in any order; confirm
+// against the Cassandra documentation before relaxing this.
+func (p *Parser) parseCreateKeyspace() (*ast.CreateKeyspaceStmt, error) {
+	start := p.curLoc()
+	p.advance() // KEYSPACE
+
+	ifNotExists, err := p.parseIfNotExists()
+	if err != nil {
+		return nil, err
+	}
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+	if err := p.expectKeyword(tokWITH); err != nil {
+		return nil, err
+	}
+	stmt := &ast.CreateKeyspaceStmt{IfNotExists: ifNotExists, Name: name}
+
+	// REPLICATION = { ... }
+	if err := p.expectKeyword(tokREPLICATION); err != nil {
+		return nil, err
+	}
+	if _, err := p.expect(tokEQ); err != nil {
+		return nil, err
+	}
+	replication, err := p.parseOptionHash()
+	if err != nil {
+		return nil, err
+	}
+	stmt.Replication = replication
+
+	// [AND DURABLE_WRITES = bool]
+	if p.match(tokAND) {
+		if err := p.expectKeyword(tokDURABLE_WRITES); err != nil {
+			return nil, err
+		}
+		if _, err := p.expect(tokEQ); err != nil {
+			return nil, err
+		}
+		durable, err := p.parseBoolLit()
+		if err != nil {
+			return nil, err
+		}
+		stmt.DurableWrites = durable
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseAlterKeyspace parses the remainder of an ALTER KEYSPACE statement; the
+// current token must be KEYSPACE:
+//
+//	KEYSPACE [IF EXISTS] name WITH REPLICATION = option_hash
+//	    [AND DURABLE_WRITES = bool]
+//
+// NOTE(review): REPLICATION is mandatory here. CQL appears to allow ALTER
+// KEYSPACE with DURABLE_WRITES alone; before accepting that, confirm that
+// consumers of AlterKeyspaceStmt tolerate an unset Replication.
+func (p *Parser) parseAlterKeyspace() (*ast.AlterKeyspaceStmt, error) {
+	start := p.curLoc()
+	p.advance() // KEYSPACE
+
+	ifExists := p.parseIfExists()
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+	stmt := &ast.AlterKeyspaceStmt{IfExists: ifExists, Name: name}
+
+	// WITH REPLICATION = { ... }
+	if err := p.expectKeyword(tokWITH); err != nil {
+		return nil, err
+	}
+	if err := p.expectKeyword(tokREPLICATION); err != nil {
+		return nil, err
+	}
+	if _, err := p.expect(tokEQ); err != nil {
+		return nil, err
+	}
+	replication, err := p.parseOptionHash()
+	if err != nil {
+		return nil, err
+	}
+	stmt.Replication = replication
+
+	// [AND DURABLE_WRITES = bool]
+	if p.match(tokAND) {
+		if err := p.expectKeyword(tokDURABLE_WRITES); err != nil {
+			return nil, err
+		}
+		if _, err := p.expect(tokEQ); err != nil {
+			return nil, err
+		}
+		durable, err := p.parseBoolLit()
+		if err != nil {
+			return nil, err
+		}
+		stmt.DurableWrites = durable
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseDropKeyspace parses the remainder of a DROP KEYSPACE statement; the
+// current token must be KEYSPACE.
+func (p *Parser) parseDropKeyspace() (*ast.DropKeyspaceStmt, error) {
+	start := p.curLoc()
+	p.advance() // KEYSPACE
+
+	ifExists := p.parseIfExists()
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+
+	stmt := &ast.DropKeyspaceStmt{IfExists: ifExists, Name: name}
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseBoolLit parses a TRUE or FALSE token into a BoolLit that spans exactly
+// that token. Any other token is reported as an error and is not consumed.
+func (p *Parser) parseBoolLit() (*ast.BoolLit, error) {
+	tok := p.cur
+	if tok.Type != tokTRUE && tok.Type != tokFALSE {
+		return nil, p.errorf("expected TRUE or FALSE, got %s", p.tokenDesc())
+	}
+	p.advance()
+	return &ast.BoolLit{
+		Val: tok.Type == tokTRUE,
+		Loc: ast.Loc{Start: tok.Loc, End: tok.End},
+	}, nil
+}
diff --git a/cassandra/parser/ddl_misc.go
b/cassandra/parser/ddl_misc.go new file mode 100644 index 00000000..acaa12ec --- /dev/null +++ b/cassandra/parser/ddl_misc.go @@ -0,0 +1,377 @@ +package parser + +import ( + "github.com/bytebase/omni/cassandra/ast" +) + +func (p *Parser) parseCreateFunction() (*ast.CreateFunctionStmt, error) { + start := p.curLoc() + p.advance() // FUNCTION + + ifNotExists, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + stmt := &ast.CreateFunctionStmt{ + IfNotExists: ifNotExists, + Name: name, + } + + // Parameter list + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + if p.cur.Type != tokRPAREN { + for { + pStart := p.curLoc() + pName, err := p.parseIdentifier() + if err != nil { + return nil, err + } + pType, err := p.parseDataType() + if err != nil { + return nil, err + } + stmt.Params = append(stmt.Params, &ast.FunctionParam{ + Name: pName, + Type: pType, + Loc: p.makeLoc(pStart), + }) + if !p.match(tokCOMMA) { + break + } + } + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + + // Return mode: CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT + if p.cur.Type == tokCALLED { + stmt.ReturnMode = ast.ReturnCalledOnNull + p.advance() // CALLED + if err := p.expectKeyword(tokON); err != nil { + return nil, err + } + if err := p.expectKeyword(tokNULL); err != nil { + return nil, err + } + if err := p.expectKeyword(tokINPUT); err != nil { + return nil, err + } + } else if p.cur.Type == tokRETURNS && p.peekNext().Type == tokNULL { + stmt.ReturnMode = ast.ReturnNullOnNull + p.advance() // RETURNS + p.advance() // NULL + if err := p.expectKeyword(tokON); err != nil { + return nil, err + } + if err := p.expectKeyword(tokNULL); err != nil { + return nil, err + } + if err := p.expectKeyword(tokINPUT); err != nil { + return nil, err + } + } + + // RETURNS dataType + if err := p.expectKeyword(tokRETURNS); err != nil { + return nil, err + } + 
+ retType, err := p.parseDataType() + if err != nil { + return nil, err + } + stmt.ReturnType = retType + + if err := p.expectKeyword(tokLANGUAGE); err != nil { + return nil, err + } + lang, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Language = lang + + if err := p.expectKeyword(tokAS); err != nil { + return nil, err + } + + // Body: code block or string literal + switch p.cur.Type { + case tokCODEBLOCK: + stmt.Body = &ast.CodeBlock{Val: p.cur.Str, Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}} + p.advance() + case tokSTRING: + stmt.Body = &ast.StringLit{Val: p.cur.Str, Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}} + p.advance() + default: + return nil, p.errorf("expected code block or string literal for function body, got %s", p.tokenDesc()) + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +func (p *Parser) parseDropFunction() (*ast.DropFunctionStmt, error) { + start := p.curLoc() + p.advance() // FUNCTION + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + var argTypes []*ast.DataType + if p.cur.Type == tokLPAREN { + p.advance() + if p.cur.Type != tokRPAREN { + argTypes, err = p.parseTypeList() + if err != nil { + return nil, err + } + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + } + + return &ast.DropFunctionStmt{ + IfExists: ifExists, + Name: name, + ArgTypes: argTypes, + Loc: p.makeLoc(start), + }, nil +} + +func (p *Parser) parseCreateAggregate() (*ast.CreateAggregateStmt, error) { + start := p.curLoc() + p.advance() // AGGREGATE + + ifNotExists, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + paramType, err := p.parseDataType() + if err != nil { + return nil, err + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + + if err 
:= p.expectKeyword(tokSFUNC); err != nil { + return nil, err + } + sfunc, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + if err := p.expectKeyword(tokSTYPE); err != nil { + return nil, err + } + stype, err := p.parseDataType() + if err != nil { + return nil, err + } + + var finalfunc *ast.Identifier + if p.cur.Type == tokFINALFUNC { + p.advance() + ff, err := p.parseIdentifier() + if err != nil { + return nil, err + } + finalfunc = ff + } + + var initcond ast.ExprNode + if p.cur.Type == tokINITCOND { + p.advance() + ic, err := p.parseInitCondDefinition() + if err != nil { + return nil, err + } + initcond = ic + } + + return &ast.CreateAggregateStmt{ + IfNotExists: ifNotExists, + Name: name, + ParamType: paramType, + SFunc: sfunc, + SType: stype, + FinalFunc: finalfunc, + InitCond: initcond, + Loc: p.makeLoc(start), + }, nil +} + +func (p *Parser) parseInitCondDefinition() (ast.ExprNode, error) { + switch p.cur.Type { + case tokLBRACE: + return p.parseMapOrSetLiteral() + case tokLPAREN: + return p.parseInitCondListOrNested() + default: + return p.parseConstant() + } +} + +func (p *Parser) parseInitCondListOrNested() (ast.ExprNode, error) { + start := p.curLoc() + p.advance() // ( + var elems []ast.ExprNode + for p.cur.Type != tokRPAREN && p.cur.Type != tokEOF { + elem, err := p.parseInitCondDefinition() + if err != nil { + return nil, err + } + elems = append(elems, elem) + if !p.match(tokCOMMA) { + break + } + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + return &ast.TupleLit{Elements: elems, Loc: p.makeLoc(start)}, nil +} + +func (p *Parser) parseDropAggregate() (*ast.DropAggregateStmt, error) { + start := p.curLoc() + p.advance() // AGGREGATE + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + var argTypes []*ast.DataType + if p.cur.Type == tokLPAREN { + p.advance() + if p.cur.Type != tokRPAREN { + var err error + argTypes, err = 
p.parseTypeList() + if err != nil { + return nil, err + } + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + } + + return &ast.DropAggregateStmt{ + IfExists: ifExists, + Name: name, + ArgTypes: argTypes, + Loc: p.makeLoc(start), + }, nil +} + +// parseTypeList parses a comma-separated list of data types. +func (p *Parser) parseTypeList() ([]*ast.DataType, error) { + var types []*ast.DataType + for { + dt, err := p.parseDataType() + if err != nil { + return nil, err + } + types = append(types, dt) + if !p.match(tokCOMMA) { + break + } + } + return types, nil +} + +func (p *Parser) parseCreateTrigger() (*ast.CreateTriggerStmt, error) { + start := p.curLoc() + p.advance() // TRIGGER + + ifNotExists, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + if err := p.expectKeyword(tokON); err != nil { + return nil, err + } + + table, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + if err := p.expectKeyword(tokUSING); err != nil { + return nil, err + } + + usingClass, err := p.parseConstant() + if err != nil { + return nil, err + } + + return &ast.CreateTriggerStmt{ + IfNotExists: ifNotExists, + Name: name, + Table: table, + UsingClass: usingClass, + Loc: p.makeLoc(start), + }, nil +} + +func (p *Parser) parseDropTrigger() (*ast.DropTriggerStmt, error) { + start := p.curLoc() + p.advance() // TRIGGER + + ifExists := p.parseIfExists() + + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + if err := p.expectKeyword(tokON); err != nil { + return nil, err + } + + table, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + return &ast.DropTriggerStmt{ + IfExists: ifExists, + Name: name, + Table: table, + Loc: p.makeLoc(start), + }, nil +} diff --git a/cassandra/parser/ddl_mv.go b/cassandra/parser/ddl_mv.go new file mode 100644 index 00000000..9ed03eaf --- /dev/null +++ 
b/cassandra/parser/ddl_mv.go @@ -0,0 +1,88 @@ +package parser + +import ( + "github.com/bytebase/omni/cassandra/ast" +) + +func (p *Parser) parseMVOptions(stmt *ast.CreateMVStmt) error { + for { + if p.cur.Type == tokCLUSTERING { + co, err := p.parseClusteringOrderClause() + if err != nil { + return err + } + stmt.ClusteringOrders = co + } else { + opt, err := p.parseTableOption() + if err != nil { + return err + } + stmt.Options = append(stmt.Options, opt) + } + if !p.match(tokAND) { + break + } + } + return nil +} + +func (p *Parser) isColumnNotNull() bool { + if !isIdentLike(p.cur.Type) { + return false + } + // We need to look ahead: ident IS NOT NULL + // This is tricky with limited lookahead. Use peekNext for the IS keyword. + next := p.peekNext() + return next.Type == tokIS +} + +func (p *Parser) isNextColumnNotNull(next Token) bool { + // After AND, check if the next token starts ident IS NOT NULL + return isIdentLike(next.Type) +} + +func (p *Parser) parseAlterMV() (*ast.AlterMVStmt, error) { + start := p.curLoc() + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + stmt := &ast.AlterMVStmt{IfExists: ifExists, Name: name} + + if p.match(tokWITH) { + for { + opt, err := p.parseTableOption() + if err != nil { + return nil, err + } + stmt.Options = append(stmt.Options, opt) + if !p.match(tokAND) { + break + } + } + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +func (p *Parser) parseDropMV() (*ast.DropMVStmt, error) { + start := p.curLoc() + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + return &ast.DropMVStmt{ + IfExists: ifExists, + Name: name, + Loc: p.makeLoc(start), + }, nil +} diff --git a/cassandra/parser/ddl_table.go b/cassandra/parser/ddl_table.go new file mode 100644 index 00000000..d5c3d8cc --- /dev/null +++ b/cassandra/parser/ddl_table.go @@ -0,0 +1,356 @@ +package parser + +import ( + 
"github.com/bytebase/omni/cassandra/ast" +) + +func (p *Parser) parseCreateTable() (*ast.CreateTableStmt, error) { + start := p.curLoc() + p.advance() // TABLE + + ifNotExists, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + + stmt := &ast.CreateTableStmt{ + IfNotExists: ifNotExists, + Name: name, + } + + // Parse column definitions and optional trailing PRIMARY KEY. + for { + if p.cur.Type == tokPRIMARY { + pk, err := p.parsePrimaryKeyElement() + if err != nil { + return nil, err + } + stmt.PrimaryKey = pk + break + } + col, err := p.parseColumnDefinition() + if err != nil { + return nil, err + } + stmt.Columns = append(stmt.Columns, col) + if !p.match(tokCOMMA) { + break + } + } + + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + + // Parse optional WITH clause. + if p.match(tokWITH) { + if err := p.parseTableOptions(stmt); err != nil { + return nil, err + } + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +func (p *Parser) parseColumnDefinition() (*ast.ColumnDef, error) { + start := p.curLoc() + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + dt, err := p.parseDataType() + if err != nil { + return nil, err + } + col := &ast.ColumnDef{Name: name, Type: dt} + + if p.cur.Type == tokSTATIC { + col.Static = true + p.advance() + } + if p.cur.Type == tokPRIMARY && p.peekNext().Type == tokKEY { + col.PrimaryKey = true + p.advance() // PRIMARY + p.advance() // KEY + } + + col.Loc = p.makeLoc(start) + return col, nil +} + +func (p *Parser) parsePrimaryKeyElement() (*ast.PrimaryKeyDef, error) { + start := p.curLoc() + p.advance() // PRIMARY + if err := p.expectKeyword(tokKEY); err != nil { + return nil, err + } + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + + pk := &ast.PrimaryKeyDef{} + + // Composite key: ((pk1, 
pk2), ck1, ck2) + if p.cur.Type == tokLPAREN { + p.advance() + for { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + pk.PartitionKeys = append(pk.PartitionKeys, col) + if !p.match(tokCOMMA) { + break + } + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + } else { + // Single or compound key: pk, ck1, ck2 + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + pk.PartitionKeys = append(pk.PartitionKeys, col) + } + + // Clustering keys + for p.match(tokCOMMA) { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + pk.ClusteringKeys = append(pk.ClusteringKeys, col) + } + + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + + pk.Loc = p.makeLoc(start) + return pk, nil +} + +func (p *Parser) parseTableOptions(stmt *ast.CreateTableStmt) error { + for { + // Check for COMPACT STORAGE + if p.cur.Type == tokCOMPACT { + p.advance() + if err := p.expectKeyword(tokSTORAGE); err != nil { + return err + } + stmt.CompactStorage = true + } else if p.cur.Type == tokCLUSTERING { + co, err := p.parseClusteringOrderClause() + if err != nil { + return err + } + stmt.ClusteringOrders = co + } else { + opt, err := p.parseTableOption() + if err != nil { + return err + } + stmt.Options = append(stmt.Options, opt) + } + if !p.match(tokAND) { + break + } + } + return nil +} + +func (p *Parser) parseTableOption() (*ast.TableOption, error) { + start := p.curLoc() + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + if _, err := p.expect(tokEQ); err != nil { + return nil, err + } + var value ast.ExprNode + if p.cur.Type == tokLBRACE { + hash, err := p.parseOptionHash() + if err != nil { + return nil, err + } + value = hash + } else { + val, err := p.parseConstant() + if err != nil { + return nil, err + } + value = val + } + return &ast.TableOption{Name: name, Value: value, Loc: p.makeLoc(start)}, nil +} + +func (p *Parser) parseClusteringOrderClause() 
([]*ast.ClusteringOrder, error) { + p.advance() // CLUSTERING + if err := p.expectKeyword(tokORDER); err != nil { + return nil, err + } + if err := p.expectKeyword(tokBY); err != nil { + return nil, err + } + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + + var orders []*ast.ClusteringOrder + for { + start := p.curLoc() + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + dir := "" + if p.cur.Type == tokASC { + dir = "ASC" + p.advance() + } else if p.cur.Type == tokDESC { + dir = "DESC" + p.advance() + } + orders = append(orders, &ast.ClusteringOrder{Column: col, Direction: dir, Loc: p.makeLoc(start)}) + if !p.match(tokCOMMA) { + break + } + } + + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + return orders, nil +} + +func (p *Parser) parseAlterTable() (*ast.AlterTableStmt, error) { + start := p.curLoc() + p.advance() // TABLE + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + stmt := &ast.AlterTableStmt{IfExists: ifExists, Name: name} + + switch p.cur.Type { + case tokADD: + stmt.Op = ast.AlterTableAdd + p.advance() + ine, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + stmt.AddIfNotExists = ine + for { + col, err := p.parseColumnDefinition() + if err != nil { + return nil, err + } + stmt.AddColumns = append(stmt.AddColumns, col) + if !p.match(tokCOMMA) { + break + } + } + case tokDROP: + p.advance() + if p.cur.Type == tokCOMPACT { + stmt.Op = ast.AlterTableDropCompactStorage + p.advance() // COMPACT + if err := p.expectKeyword(tokSTORAGE); err != nil { + return nil, err + } + } else { + stmt.Op = ast.AlterTableDrop + stmt.DropIfExists = p.parseIfExists() + for { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.DropColumns = append(stmt.DropColumns, col) + if !p.match(tokCOMMA) { + break + } + } + } + case tokRENAME: + stmt.Op = ast.AlterTableRename + p.advance() + 
stmt.RenameIfExists = p.parseIfExists() + for { + rStart := p.curLoc() + from, err := p.parseIdentifier() + if err != nil { + return nil, err + } + if err := p.expectKeyword(tokTO); err != nil { + return nil, err + } + to, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.RenameItems = append(stmt.RenameItems, &ast.AlterTableRenameItem{ + From: from, + To: to, + Loc: p.makeLoc(rStart), + }) + if !p.match(tokAND) { + break + } + } + case tokWITH: + stmt.Op = ast.AlterTableWith + p.advance() + for { + opt, err := p.parseTableOption() + if err != nil { + return nil, err + } + stmt.Options = append(stmt.Options, opt) + if !p.match(tokAND) { + break + } + } + default: + return nil, p.errorf("expected ADD, DROP, RENAME, or WITH after ALTER TABLE, got %s", p.tokenDesc()) + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +func (p *Parser) parseDropTable() (*ast.DropTableStmt, error) { + start := p.curLoc() + p.advance() // TABLE + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + return &ast.DropTableStmt{ + IfExists: ifExists, + Name: name, + Loc: p.makeLoc(start), + }, nil +} diff --git a/cassandra/parser/ddl_type.go b/cassandra/parser/ddl_type.go new file mode 100644 index 00000000..e589aa18 --- /dev/null +++ b/cassandra/parser/ddl_type.go @@ -0,0 +1,168 @@ +package parser + +import ( + "github.com/bytebase/omni/cassandra/ast" +) + +func (p *Parser) parseCreateType() (*ast.CreateTypeStmt, error) { + start := p.curLoc() + p.advance() // TYPE + + ifNotExists, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + + stmt := &ast.CreateTypeStmt{ + IfNotExists: ifNotExists, + Name: name, + } + + for { + fStart := p.curLoc() + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + dt, err := 
p.parseDataType() + if err != nil { + return nil, err + } + stmt.Fields = append(stmt.Fields, &ast.ColumnDef{ + Name: col, + Type: dt, + Loc: p.makeLoc(fStart), + }) + if !p.match(tokCOMMA) { + break + } + } + + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +func (p *Parser) parseAlterType() (*ast.AlterTypeStmt, error) { + start := p.curLoc() + p.advance() // TYPE + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + stmt := &ast.AlterTypeStmt{IfExists: ifExists, Name: name} + + switch p.cur.Type { + case tokALTER: + stmt.Op = ast.AlterTypeAlter + p.advance() + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + if err := p.expectKeyword(tokTYPE); err != nil { + return nil, err + } + dt, err := p.parseDataType() + if err != nil { + return nil, err + } + stmt.AlterColumn = col + stmt.AlterType = dt + + case tokADD: + stmt.Op = ast.AlterTypeAdd + p.advance() + ine, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + stmt.AddIfNotExists = ine + for { + fStart := p.curLoc() + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + dt, err := p.parseDataType() + if err != nil { + return nil, err + } + stmt.AddFields = append(stmt.AddFields, &ast.ColumnDef{ + Name: col, + Type: dt, + Loc: p.makeLoc(fStart), + }) + if !p.match(tokCOMMA) { + break + } + } + + case tokRENAME: + stmt.Op = ast.AlterTypeRename + p.advance() + stmt.RenameIfExists = p.parseIfExists() + for { + rStart := p.curLoc() + from, err := p.parseIdentifier() + if err != nil { + return nil, err + } + if err := p.expectKeyword(tokTO); err != nil { + return nil, err + } + to, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Renames = append(stmt.Renames, &ast.AlterTypeRenameItem{ + From: from, + To: to, + Loc: p.makeLoc(rStart), + }) + if !p.match(tokAND) { + break + } + } + + default: 
+ return nil, p.errorf("expected ALTER, ADD, or RENAME after ALTER TYPE, got %s", p.tokenDesc()) + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +func (p *Parser) parseDropType() (*ast.DropTypeStmt, error) { + start := p.curLoc() + p.advance() // TYPE + + ifExists := p.parseIfExists() + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + return &ast.DropTypeStmt{ + IfExists: ifExists, + Name: name, + Loc: p.makeLoc(start), + }, nil +} diff --git a/cassandra/parser/delete.go b/cassandra/parser/delete.go new file mode 100644 index 00000000..ef4bd240 --- /dev/null +++ b/cassandra/parser/delete.go @@ -0,0 +1,131 @@ +package parser + +import ( + "github.com/bytebase/omni/cassandra/ast" +) + +// parseDelete parses a DELETE statement: +// +// DELETE [deleteColumnList] FROM [keyspace.]table [USING TIMESTAMP n] WHERE relationElements [IF EXISTS | IF ifConditionList] +// deleteColumnList: deleteColumnItem (',' deleteColumnItem)* +// deleteColumnItem: IDENT | IDENT '[' (string|decimal) ']' +func (p *Parser) parseDelete() (*ast.DeleteStmt, error) { + start := p.curLoc() + if err := p.expectKeyword(tokDELETE); err != nil { + return nil, err + } + + stmt := &ast.DeleteStmt{} + + // Parse optional delete column list (appears before FROM). + // If the next token is not FROM, we have a column list. 
+ if p.cur.Type != tokFROM { + cols, err := p.parseDeleteColumnList() + if err != nil { + return nil, err + } + stmt.Columns = cols + } + + // FROM table + if err := p.expectKeyword(tokFROM); err != nil { + return nil, err + } + + table, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + stmt.From = table + + // Optional USING TIMESTAMP clause + using, err := p.parseUsingClause() + if err != nil { + return nil, err + } + stmt.Using = using + + // WHERE clause + where, err := p.parseWhereClause() + if err != nil { + return nil, err + } + stmt.Where = where + + // Optional IF EXISTS or IF conditions + if p.cur.Type == tokIF { + if p.peekNext().Type == tokEXISTS { + p.advance() // IF + p.advance() // EXISTS + stmt.IfExists = true + } else { + conds, err := p.parseIfConditions() + if err != nil { + return nil, err + } + stmt.IfConditions = conds + } + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +// parseDeleteColumnList parses: deleteColumnItem (',' deleteColumnItem)* +func (p *Parser) parseDeleteColumnList() ([]ast.ExprNode, error) { + var cols []ast.ExprNode + first, err := p.parseDeleteColumnItem() + if err != nil { + return nil, err + } + cols = append(cols, first) + for p.match(tokCOMMA) { + item, err := p.parseDeleteColumnItem() + if err != nil { + return nil, err + } + cols = append(cols, item) + } + return cols, nil +} + +// parseDeleteColumnItem parses: IDENT | IDENT '.' 
IDENT | IDENT '[' expression ']' +func (p *Parser) parseDeleteColumnItem() (ast.ExprNode, error) { + start := p.curLoc() + + ident, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + if p.cur.Type == tokDOT { + p.advance() + field, err := p.parseIdentifier() + if err != nil { + return nil, err + } + return &ast.DotAccess{ + Object: ident, + Field: field, + Loc: p.makeLoc(start), + }, nil + } + + if p.cur.Type == tokLBRACK { + p.advance() // [ + idx, err := p.parseExpression() + if err != nil { + return nil, err + } + if _, err := p.expect(tokRBRACK); err != nil { + return nil, err + } + return &ast.IndexAccess{ + Collection: ident, + Index: idx, + Loc: p.makeLoc(start), + }, nil + } + + return ident, nil +} diff --git a/cassandra/parser/dispatch.go b/cassandra/parser/dispatch.go new file mode 100644 index 00000000..634e6901 --- /dev/null +++ b/cassandra/parser/dispatch.go @@ -0,0 +1,435 @@ +package parser + +import ( + "github.com/bytebase/omni/cassandra/ast" +) + +// parseCreate dispatches CREATE statements. +func (p *Parser) parseCreate() (ast.StmtNode, error) { + start := p.curLoc() + p.advance() // CREATE + + // Handle OR REPLACE for FUNCTION/AGGREGATE. + orReplace := false + if p.cur.Type == tokOR { + p.advance() // OR + if err := p.expectKeyword(tokREPLACE); err != nil { + return nil, err + } + orReplace = true + } + + // Handle CUSTOM INDEX. 
+ if p.cur.Type == tokCUSTOM { + stmt, err := p.parseCreateIndex() + if err != nil { + return nil, err + } + stmt.IsCustom = true + stmt.Loc.Start = start + return stmt, nil + } + + switch p.cur.Type { + case tokKEYSPACE: + stmt, err := p.parseCreateKeyspace() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokTABLE: + stmt, err := p.parseCreateTable() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokINDEX: + stmt, err := p.parseCreateIndex() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokTYPE: + stmt, err := p.parseCreateType() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokMATERIALIZED: + p.advance() // MATERIALIZED + if err := p.expectKeyword(tokVIEW); err != nil { + return nil, err + } + return p.parseCreateMVInline(start) + case tokFUNCTION: + stmt, err := p.parseCreateFunction() + if err != nil { + return nil, err + } + stmt.OrReplace = orReplace + stmt.Loc.Start = start + return stmt, nil + case tokAGGREGATE: + stmt, err := p.parseCreateAggregate() + if err != nil { + return nil, err + } + stmt.OrReplace = orReplace + stmt.Loc.Start = start + return stmt, nil + case tokTRIGGER: + stmt, err := p.parseCreateTrigger() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokROLE: + stmt, err := p.parseCreateRole() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokUSER: + stmt, err := p.parseCreateUser() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + default: + return nil, p.errorf("expected object type after CREATE, got %s", p.tokenDesc()) + } +} + +// parseCreateMVInline handles CREATE MATERIALIZED VIEW when we've already consumed +// CREATE MATERIALIZED VIEW tokens. 
+func (p *Parser) parseCreateMVInline(start int) (*ast.CreateMVStmt, error) { + ifNotExists, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + if err := p.expectKeyword(tokAS); err != nil { + return nil, err + } + if err := p.expectKeyword(tokSELECT); err != nil { + return nil, err + } + + stmt := &ast.CreateMVStmt{ + IfNotExists: ifNotExists, + Name: name, + } + + // Select columns or * + if p.match(tokSTAR) { + stmt.SelectAll = true + } else { + for { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.SelectColumns = append(stmt.SelectColumns, col) + if !p.match(tokCOMMA) { + break + } + } + } + + if err := p.expectKeyword(tokFROM); err != nil { + return nil, err + } + from, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + stmt.FromTable = from + + // WHERE + if err := p.expectKeyword(tokWHERE); err != nil { + return nil, err + } + + for { + if !p.isColumnNotNull() { + break + } + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + if err := p.expectKeyword(tokIS); err != nil { + return nil, err + } + if err := p.expectKeyword(tokNOT); err != nil { + return nil, err + } + if err := p.expectKeyword(tokNULL); err != nil { + return nil, err + } + stmt.WhereNotNull = append(stmt.WhereNotNull, col) + if p.cur.Type != tokAND { + break + } + next := p.peekNext() + if !p.isNextColumnNotNull(next) { + break + } + p.advance() // AND + } + + if p.match(tokAND) { + rels, err := p.parseRelationElements() + if err != nil { + return nil, err + } + stmt.WhereRelations = rels + } + + // PRIMARY KEY + if err := p.expectKeyword(tokPRIMARY); err != nil { + return nil, err + } + if err := p.expectKeyword(tokKEY); err != nil { + return nil, err + } + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + + pk := &ast.PrimaryKeyDef{Loc: ast.Loc{Start: p.curLoc()}} + if p.cur.Type == tokLPAREN { + 
p.advance() + for { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + pk.PartitionKeys = append(pk.PartitionKeys, col) + if !p.match(tokCOMMA) { + break + } + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + } else { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + pk.PartitionKeys = append(pk.PartitionKeys, col) + } + + for p.match(tokCOMMA) { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + pk.ClusteringKeys = append(pk.ClusteringKeys, col) + } + + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + pk.Loc = p.makeLoc(pk.Loc.Start) + stmt.PrimaryKey = pk + + if p.match(tokWITH) { + if err := p.parseMVOptions(stmt); err != nil { + return nil, err + } + } + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} + +// parseAlter dispatches ALTER statements. +func (p *Parser) parseAlter() (ast.StmtNode, error) { + start := p.curLoc() + p.advance() // ALTER + + switch p.cur.Type { + case tokKEYSPACE: + stmt, err := p.parseAlterKeyspace() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokTABLE: + stmt, err := p.parseAlterTable() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokTYPE: + stmt, err := p.parseAlterType() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokMATERIALIZED: + p.advance() // MATERIALIZED + if err := p.expectKeyword(tokVIEW); err != nil { + return nil, err + } + stmt, err := p.parseAlterMV() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokROLE: + stmt, err := p.parseAlterRole() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokUSER: + stmt, err := p.parseAlterUser() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + default: + return nil, p.errorf("expected object type after 
ALTER, got %s", p.tokenDesc()) + } +} + +// parseDrop dispatches DROP statements. +func (p *Parser) parseDrop() (ast.StmtNode, error) { + start := p.curLoc() + p.advance() // DROP + + switch p.cur.Type { + case tokKEYSPACE: + stmt, err := p.parseDropKeyspace() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokTABLE: + stmt, err := p.parseDropTable() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokINDEX: + stmt, err := p.parseDropIndex() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokTYPE: + stmt, err := p.parseDropType() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokMATERIALIZED: + p.advance() // MATERIALIZED + if err := p.expectKeyword(tokVIEW); err != nil { + return nil, err + } + stmt, err := p.parseDropMV() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokFUNCTION: + stmt, err := p.parseDropFunction() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokAGGREGATE: + stmt, err := p.parseDropAggregate() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokTRIGGER: + stmt, err := p.parseDropTrigger() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokROLE: + stmt, err := p.parseDropRole() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + case tokUSER: + stmt, err := p.parseDropUser() + if err != nil { + return nil, err + } + stmt.Loc.Start = start + return stmt, nil + default: + return nil, p.errorf("expected object type after DROP, got %s", p.tokenDesc()) + } +} + +// parseTruncate parses TRUNCATE [TABLE] [keyspace.]table. 
+func (p *Parser) parseTruncate() (*ast.TruncateStmt, error) { + start := p.curLoc() + p.advance() // TRUNCATE + + // Optional TABLE keyword + p.match(tokTABLE) + + name, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + return &ast.TruncateStmt{ + Table: name, + Loc: p.makeLoc(start), + }, nil +} + +// parseUse parses USE keyspace. +func (p *Parser) parseUse() (*ast.UseStmt, error) { + start := p.curLoc() + p.advance() // USE + + ks, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + return &ast.UseStmt{ + Keyspace: ks, + Loc: p.makeLoc(start), + }, nil +} diff --git a/cassandra/parser/errors.go b/cassandra/parser/errors.go new file mode 100644 index 00000000..b7e4cdc7 --- /dev/null +++ b/cassandra/parser/errors.go @@ -0,0 +1,46 @@ +package parser + +import ( + "fmt" + "sort" + + "github.com/bytebase/omni/cassandra/ast" +) + +// ParseError represents a syntax error during CQL parsing. +type ParseError struct { + Message string + Loc ast.Loc + Line int + Column int + Near string +} + +func (e *ParseError) Error() string { + if e.Near != "" { + return fmt.Sprintf("line %d column %d: %s at or near %q", e.Line, e.Column, e.Message, e.Near) + } + return fmt.Sprintf("line %d column %d: %s", e.Line, e.Column, e.Message) +} + +type lineIndex []int + +func buildLineIndex(s string) lineIndex { + idx := lineIndex{0} + for i := 0; i < len(s); i++ { + if s[i] == '\n' { + idx = append(idx, i+1) + } + } + return idx +} + +func offsetToLineCol(idx lineIndex, offset int) (int, int) { + line := sort.SearchInts(idx, offset+1) + col := offset - idx[line-1] + 1 + return line, col +} + +func locFromOffsets(start, end int) ast.Loc { + return ast.Loc{Start: start, End: end} +} diff --git a/cassandra/parser/expr.go b/cassandra/parser/expr.go new file mode 100644 index 00000000..66b2383e --- /dev/null +++ b/cassandra/parser/expr.go @@ -0,0 +1,609 @@ +package parser + +import ( + "strings" + + "github.com/bytebase/omni/cassandra/ast" +) + +// 
--------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +// parseConstant parses a CQL constant literal: string, integer, float, uuid, +// hex, boolean (true/false), null, or code block ($$...$$). +func (p *Parser) parseConstant() (ast.ExprNode, error) { + tok := p.cur + switch tok.Type { + case tokSTRING: + p.advance() + return &ast.StringLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokINTEGER: + p.advance() + return &ast.IntegerLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokFLOAT: + p.advance() + return &ast.FloatLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokUUID: + p.advance() + return &ast.UUIDLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokHEX: + p.advance() + return &ast.HexLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokTRUE: + p.advance() + return &ast.BoolLit{Val: true, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokFALSE: + p.advance() + return &ast.BoolLit{Val: false, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokNULL: + p.advance() + return &ast.NullLit{Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokNAN: + p.advance() + return &ast.FloatLit{Val: "NaN", Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokINFINITY: + p.advance() + return &ast.FloatLit{Val: "Infinity", Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokCODEBLOCK: + p.advance() + return &ast.CodeBlock{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + + case tokMINUS: + start := tok.Loc + p.advance() + next := p.cur + switch next.Type { + case tokINTEGER: + p.advance() + return &ast.IntegerLit{Val: "-" + next.Str, Loc: ast.Loc{Start: start, End: next.End}}, nil + case tokFLOAT: + p.advance() + return &ast.FloatLit{Val: "-" + next.Str, Loc: 
ast.Loc{Start: start, End: next.End}}, nil + case tokINFINITY: + p.advance() + return &ast.FloatLit{Val: "-Infinity", Loc: ast.Loc{Start: start, End: next.End}}, nil + case tokNAN: + p.advance() + return &ast.FloatLit{Val: "-NaN", Loc: ast.Loc{Start: start, End: next.End}}, nil + default: + return nil, p.errorf("expected number after '-', got %s", p.tokenDesc()) + } + + default: + return nil, p.errorf("expected constant, got %s", p.tokenDesc()) + } +} + +// --------------------------------------------------------------------------- +// Expressions +// --------------------------------------------------------------------------- + +// parseExpression parses a CQL expression: constant, function call, +// identifier, or collection literal (map, set, list, tuple). After parsing +// the primary expression it checks for index access (e.g. col[0]). +func (p *Parser) parseExpression() (ast.ExprNode, error) { + var expr ast.ExprNode + var err error + + switch { + // Collection literals. + case p.cur.Type == tokLBRACE || p.cur.Type == tokLBRACK: + expr, err = p.parseCollectionLiteral() + + // Tuple literal: (val, val, ...) + case p.cur.Type == tokLPAREN: + start := p.curLoc() + p.advance() // ( + elems, err2 := p.parseExpressionList() + if err2 != nil { + return nil, err2 + } + if _, err2 = p.expect(tokRPAREN); err2 != nil { + return nil, err2 + } + expr = &ast.TupleLit{Elements: elems, Loc: p.makeLoc(start)} + + // Literal keywords that would otherwise be swallowed by isIdentLike. + case p.cur.Type == tokNULL || p.cur.Type == tokTRUE || p.cur.Type == tokFALSE: + expr, err = p.parseConstant() + + // NaN and Infinity literals + case p.cur.Type == tokNAN || p.cur.Type == tokINFINITY: + expr, err = p.parseConstant() + + // CAST(expr AS type) + case p.cur.Type == tokCAST: + expr, err = p.parseCast() + + // Bind markers: ? 
or :name + case p.cur.Type == tokQMARK: + expr = &ast.BindMarker{Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}} + p.advance() + + case p.cur.Type == tokCOLON: + start := p.curLoc() + p.advance() // : + name, err2 := p.parseIdentifier() + if err2 != nil { + return nil, err2 + } + expr = &ast.BindMarker{Name: name.Name, Loc: p.makeLoc(start)} + + // Identifier or function call. + case isIdentLike(p.cur.Type): + if p.peekNext().Type == tokLPAREN { + expr, err = p.parseFunctionCall() + } else { + expr, err = p.parseIdentifier() + } + + // Everything else: try as a constant. + default: + expr, err = p.parseConstant() + } + + if err != nil { + return nil, err + } + + // Post-fix: index access col[idx] + for p.cur.Type == tokLBRACK { + start := expr.GetLoc().Start + p.advance() // [ + idx, err := p.parseExpression() + if err != nil { + return nil, err + } + if _, err := p.expect(tokRBRACK); err != nil { + return nil, err + } + expr = &ast.IndexAccess{Collection: expr, Index: idx, Loc: p.makeLoc(start)} + } + + return expr, nil +} + +// parseCast parses CAST(expr AS type). +func (p *Parser) parseCast() (*ast.CastExpr, error) { + start := p.curLoc() + p.advance() // CAST + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + expr, err := p.parseExpression() + if err != nil { + return nil, err + } + if err := p.expectKeyword(tokAS); err != nil { + return nil, err + } + dt, err := p.parseDataType() + if err != nil { + return nil, err + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + return &ast.CastExpr{Expr: expr, Type: dt, Loc: p.makeLoc(start)}, nil +} + +// parseExpressionList parses a comma-separated list of expressions. 
+func (p *Parser) parseExpressionList() ([]ast.ExprNode, error) { + first, err := p.parseExpression() + if err != nil { + return nil, err + } + list := []ast.ExprNode{first} + for p.match(tokCOMMA) { + next, err := p.parseExpression() + if err != nil { + return nil, err + } + list = append(list, next) + } + return list, nil +} + +// --------------------------------------------------------------------------- +// Function calls +// --------------------------------------------------------------------------- + +// builtinFuncToken maps keyword token types that represent built-in CQL +// functions to the canonical (lowercase) function name used in the AST. +var builtinFuncToken = map[int]string{ + tokNOW: "now", + tokUUID_KW: "uuid", + tokFROMJSON: "fromjson", + tokTOJSON: "tojson", + tokMINTIMEUUID: "mintimeuuid", + tokMAXTIMEUUID: "maxtimeuuid", + tokDATETIMENOW: "datetimenow", + tokCURRENTTIMESTAMP: "currenttimestamp", + tokCURRENTDATE: "currentdate", + tokCURRENTTIME: "currenttime", + tokCURRENTTIMEUUID: "currenttimeuuid", +} + +// parseFunctionCall parses a function call. It handles both built-in +// function keywords (now(), uuid(), fromJson(), ...) and generic +// identifier-based calls (token(col), writetime(col), ...). +func (p *Parser) parseFunctionCall() (ast.ExprNode, error) { + tok := p.cur + + // Built-in function keywords. + if name, ok := builtinFuncToken[tok.Type]; ok { + start := tok.Loc + ident := &ast.Identifier{Name: name, Loc: ast.Loc{Start: tok.Loc, End: tok.End}} + p.advance() // consume keyword + + fc, err := p.parseFunctionCallWithName(ident) + if err != nil { + return nil, err + } + fc.Loc = p.makeLoc(start) + return fc, nil + } + + // Generic function call: name(args...) or ks.name(args...). + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + // Handle dotted function name: keyspace.func(...) 
+ if p.cur.Type == tokDOT { + p.advance() + second, err := p.parseIdentifier() + if err != nil { + return nil, err + } + // Combine into "ks.func" identifier preserving full location. + combined := &ast.Identifier{ + Name: name.Name + "." + second.Name, + Loc: ast.Loc{Start: name.Loc.Start, End: second.Loc.End}, + } + return p.parseFunctionCallWithName(combined) + } + + return p.parseFunctionCallWithName(name) +} + +// parseFunctionCallWithName parses the (args...) or (*) part of a function +// call, given the already-consumed function name identifier. +func (p *Parser) parseFunctionCallWithName(name *ast.Identifier) (*ast.FunctionCall, error) { + start := name.Loc.Start + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + + // count(*) or similar star argument. + if p.cur.Type == tokSTAR { + p.advance() + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + return &ast.FunctionCall{ + Name: name, + Star: true, + Loc: p.makeLoc(start), + }, nil + } + + // Empty argument list: func() + if p.cur.Type == tokRPAREN { + p.advance() + return &ast.FunctionCall{ + Name: name, + Loc: p.makeLoc(start), + }, nil + } + + args, err := p.parseFunctionArgs() + if err != nil { + return nil, err + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + return &ast.FunctionCall{ + Name: name, + Args: args, + Loc: p.makeLoc(start), + }, nil +} + +// parseFunctionArgs parses comma-separated function arguments. Each argument +// can be a constant, an identifier, a collection literal, or a nested +// function call. +func (p *Parser) parseFunctionArgs() ([]ast.ExprNode, error) { + var args []ast.ExprNode + for { + arg, err := p.parseFunctionArg() + if err != nil { + return nil, err + } + args = append(args, arg) + if !p.match(tokCOMMA) { + break + } + } + return args, nil +} + +// parseFunctionArg parses a single function argument. +func (p *Parser) parseFunctionArg() (ast.ExprNode, error) { + // Collection literals. 
+ if p.cur.Type == tokLBRACE || p.cur.Type == tokLBRACK { + return p.parseCollectionLiteral() + } + + // Tuple literal inside function args. + if p.cur.Type == tokLPAREN { + start := p.curLoc() + p.advance() // ( + elems, err := p.parseExpressionList() + if err != nil { + return nil, err + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + return &ast.TupleLit{Elements: elems, Loc: p.makeLoc(start)}, nil + } + + // Bind markers + if p.cur.Type == tokQMARK { + m := &ast.BindMarker{Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}} + p.advance() + return m, nil + } + if p.cur.Type == tokCOLON { + start := p.curLoc() + p.advance() + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + return &ast.BindMarker{Name: name.Name, Loc: p.makeLoc(start)}, nil + } + + // CAST + if p.cur.Type == tokCAST { + return p.parseCast() + } + + // Identifier or nested function call. + if isIdentLike(p.cur.Type) { + if p.peekNext().Type == tokLPAREN { + return p.parseFunctionCall() + } + return p.parseIdentifier() + } + + // Constant. + return p.parseConstant() +} + +// --------------------------------------------------------------------------- +// Collection literals +// --------------------------------------------------------------------------- + +// parseCollectionLiteral dispatches to map/set literal for { ... } or list +// literal for [ ... ]. +func (p *Parser) parseCollectionLiteral() (ast.ExprNode, error) { + switch p.cur.Type { + case tokLBRACE: + return p.parseMapOrSetLiteral() + case tokLBRACK: + return p.parseListLiteral() + default: + return nil, p.errorf("expected '{' or '[' for collection literal, got %s", p.tokenDesc()) + } +} + +// parseMapOrSetLiteral parses a { ... } literal. After consuming the +// opening brace it looks ahead: if the first element is followed by a colon, +// it is a map literal {k:v, ...}; otherwise it is a set literal {v, ...}. +// An empty {} is treated as an empty map. 
+func (p *Parser) parseMapOrSetLiteral() (ast.ExprNode, error) { + start := p.curLoc() + p.advance() // { + + // Empty braces: empty map. + if p.cur.Type == tokRBRACE { + p.advance() + return &ast.MapLit{Loc: p.makeLoc(start)}, nil + } + + // Parse the first element to decide map vs set. + first, err := p.parseExpression() + if err != nil { + return nil, err + } + + if p.cur.Type == tokCOLON { + // Map literal. + return p.finishMapLiteral(start, first) + } + + // Set literal. + return p.finishSetLiteral(start, first) +} + +// finishMapLiteral continues parsing a map literal after the first key has +// been consumed and the colon has been seen (but not consumed). +func (p *Parser) finishMapLiteral(start int, firstKey ast.ExprNode) (ast.ExprNode, error) { + p.advance() // consume : + firstVal, err := p.parseExpression() + if err != nil { + return nil, err + } + + keys := []ast.ExprNode{firstKey} + values := []ast.ExprNode{firstVal} + + for p.match(tokCOMMA) { + k, err := p.parseExpression() + if err != nil { + return nil, err + } + if _, err := p.expect(tokCOLON); err != nil { + return nil, err + } + v, err := p.parseExpression() + if err != nil { + return nil, err + } + keys = append(keys, k) + values = append(values, v) + } + + if _, err := p.expect(tokRBRACE); err != nil { + return nil, err + } + return &ast.MapLit{Keys: keys, Values: values, Loc: p.makeLoc(start)}, nil +} + +// finishSetLiteral continues parsing a set literal after the first element +// has been consumed. +func (p *Parser) finishSetLiteral(start int, firstElem ast.ExprNode) (ast.ExprNode, error) { + elems := []ast.ExprNode{firstElem} + + for p.match(tokCOMMA) { + e, err := p.parseExpression() + if err != nil { + return nil, err + } + elems = append(elems, e) + } + + if _, err := p.expect(tokRBRACE); err != nil { + return nil, err + } + return &ast.SetLit{Elements: elems, Loc: p.makeLoc(start)}, nil +} + +// parseListLiteral parses a [ elem, elem, ... ] list literal. 
+func (p *Parser) parseListLiteral() (ast.ExprNode, error) { + start := p.curLoc() + p.advance() // [ + + // Empty list. + if p.cur.Type == tokRBRACK { + p.advance() + return &ast.ListLit{Loc: p.makeLoc(start)}, nil + } + + elems, err := p.parseExpressionList() + if err != nil { + return nil, err + } + + if _, err := p.expect(tokRBRACK); err != nil { + return nil, err + } + return &ast.ListLit{Elements: elems, Loc: p.makeLoc(start)}, nil +} + +// --------------------------------------------------------------------------- +// Option hash (used in table / index options) +// --------------------------------------------------------------------------- + +// parseOptionHash parses a { 'key' : 'value', ... } option hash used in +// WITH clauses for table and index options. +func (p *Parser) parseOptionHash() (*ast.OptionHash, error) { + start := p.curLoc() + if _, err := p.expect(tokLBRACE); err != nil { + return nil, err + } + + var items []*ast.OptionHashItem + + // Empty hash. + if p.cur.Type == tokRBRACE { + p.advance() + return &ast.OptionHash{Items: items, Loc: p.makeLoc(start)}, nil + } + + for { + item, err := p.parseOptionHashItem() + if err != nil { + return nil, err + } + items = append(items, item) + if !p.match(tokCOMMA) { + break + } + } + + if _, err := p.expect(tokRBRACE); err != nil { + return nil, err + } + return &ast.OptionHash{Items: items, Loc: p.makeLoc(start)}, nil +} + +// parseOptionHashItem parses a single key : value pair inside an option hash. +// Keys and values are typically string literals but may also be identifiers, +// integers, floats, or booleans. 
+func (p *Parser) parseOptionHashItem() (*ast.OptionHashItem, error) { + start := p.curLoc() + key, err := p.parseOptionHashValue() + if err != nil { + return nil, err + } + if _, err := p.expect(tokCOLON); err != nil { + return nil, err + } + val, err := p.parseOptionHashValue() + if err != nil { + return nil, err + } + return &ast.OptionHashItem{Key: key, Value: val, Loc: p.makeLoc(start)}, nil +} + +// parseOptionHashValue parses a value that can appear as a key or value in an +// option hash. This is more permissive than parseConstant because option +// hashes accept unquoted identifiers as values (e.g. class : 'SimpleStrategy'). +func (p *Parser) parseOptionHashValue() (ast.ExprNode, error) { + tok := p.cur + switch { + case tok.Type == tokSTRING: + p.advance() + return &ast.StringLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case tok.Type == tokINTEGER: + p.advance() + return &ast.IntegerLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case tok.Type == tokFLOAT: + p.advance() + return &ast.FloatLit{Val: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case tok.Type == tokTRUE: + p.advance() + return &ast.BoolLit{Val: true, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case tok.Type == tokFALSE: + p.advance() + return &ast.BoolLit{Val: false, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case tok.Type == tokNULL: + p.advance() + return &ast.NullLit{Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case isIdentLike(tok.Type): + p.advance() + return &ast.Identifier{ + Name: strings.ToLower(tok.Str), + Quoted: tok.Type == tokQUOTED, + Loc: ast.Loc{Start: tok.Loc, End: tok.End}, + }, nil + default: + return nil, p.errorf("expected value in option hash, got %s", p.tokenDesc()) + } +} diff --git a/cassandra/parser/insert.go b/cassandra/parser/insert.go new file mode 100644 index 00000000..29c3c936 --- /dev/null +++ b/cassandra/parser/insert.go @@ -0,0 +1,100 @@ +package parser + +import ( + 
"github.com/bytebase/omni/cassandra/ast" +) + +// parseInsert parses an INSERT statement: +// +// INSERT INTO [keyspace.]table [(columnList)] insertValuesSpec [IF NOT EXISTS] [USING TTL n [AND TIMESTAMP m]] +// insertValuesSpec: VALUES '(' expressionList ')' | JSON constant [DEFAULT UNSET] +func (p *Parser) parseInsert() (*ast.InsertStmt, error) { + start := p.curLoc() + if err := p.expectKeyword(tokINSERT); err != nil { + return nil, err + } + if err := p.expectKeyword(tokINTO); err != nil { + return nil, err + } + + table, err := p.parseQualifiedName() + if err != nil { + return nil, err + } + + stmt := &ast.InsertStmt{Table: table} + + // Parse optional column list: '(' identifier (',' identifier)* ')' + if p.cur.Type == tokLPAREN { + p.advance() // ( + for { + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Columns = append(stmt.Columns, col) + if !p.match(tokCOMMA) { + break + } + } + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + } + + // Parse insert values spec: VALUES (...) 
or JSON constant + switch p.cur.Type { + case tokVALUES: + p.advance() // VALUES + if _, err := p.expect(tokLPAREN); err != nil { + return nil, err + } + values, err := p.parseExpressionList() + if err != nil { + return nil, err + } + stmt.Values = values + if _, err := p.expect(tokRPAREN); err != nil { + return nil, err + } + case tokJSON: + p.advance() // JSON + stmt.IsJSON = true + jsonVal, err := p.parseConstant() + if err != nil { + return nil, err + } + stmt.JSONValue = jsonVal + // Optional DEFAULT UNSET | DEFAULT NULL + if p.cur.Type == tokDEFAULT { + p.advance() // DEFAULT + if p.cur.Type == tokNULL { + p.advance() + stmt.DefaultNull = true + } else if err := p.expectKeyword(tokUNSET); err != nil { + return nil, err + } else { + stmt.DefaultUnset = true + } + } + default: + return nil, p.errorf("expected VALUES or JSON, got %s", p.tokenDesc()) + } + + // Optional IF NOT EXISTS + ifne, err := p.parseIfNotExists() + if err != nil { + return nil, err + } + stmt.IfNotExists = ifne + + // Optional USING clause + using, err := p.parseUsingClause() + if err != nil { + return nil, err + } + stmt.Using = using + + stmt.Loc = p.makeLoc(start) + return stmt, nil +} diff --git a/cassandra/parser/keywords.go b/cassandra/parser/keywords.go new file mode 100644 index 00000000..c54c6c55 --- /dev/null +++ b/cassandra/parser/keywords.go @@ -0,0 +1,167 @@ +package parser + +// keywords maps lowercase keyword strings to their token types. +// CQL keywords are case-insensitive. 
var keywords = map[string]int{
	"access": tokACCESS,
	"add": tokADD,
	"aggregate": tokAGGREGATE,
	"all": tokALL,
	"allow": tokALLOW,
	"alter": tokALTER,
	"and": tokAND,
	"ann": tokANN,
	"apply": tokAPPLY,
	"as": tokAS,
	"asc": tokASC,
	"ascii": tokASCII,
	"authorize": tokAUTHORIZE,
	"batch": tokBATCH,
	"begin": tokBEGIN,
	"bigint": tokBIGINT,
	"blob": tokBLOB,
	"boolean": tokBOOLEAN,
	"by": tokBY,
	"called": tokCALLED,
	"cast": tokCAST,
	"clustering": tokCLUSTERING,
	"compact": tokCOMPACT,
	"contains": tokCONTAINS,
	"counter": tokCOUNTER,
	"create": tokCREATE,
	"currentdate": tokCURRENTDATE,
	"currenttime": tokCURRENTTIME,
	"currenttimestamp": tokCURRENTTIMESTAMP,
	"currenttimeuuid": tokCURRENTTIMEUUID,
	"custom": tokCUSTOM,
	"datacenters": tokDATACENTERS,
	"date": tokDATE,
	"datetimenow": tokDATETIMENOW,
	"decimal": tokDECIMAL,
	"default": tokDEFAULT,
	"delete": tokDELETE,
	"desc": tokDESC,
	"describe": tokDESCRIBE,
	"distinct": tokDISTINCT,
	"double": tokDOUBLE,
	"drop": tokDROP,
	"durable_writes": tokDURABLE_WRITES,
	"duration": tokDURATION,
	"entries": tokENTRIES,
	"execute": tokEXECUTE,
	"exists": tokEXISTS,
	"false": tokFALSE,
	"filtering": tokFILTERING,
	"finalfunc": tokFINALFUNC,
	"float": tokFLOATKW,
	"from": tokFROM,
	"fromjson": tokFROMJSON,
	"frozen": tokFROZEN,
	"full": tokFULL,
	"function": tokFUNCTION,
	"functions": tokFUNCTIONS,
	"grant": tokGRANT,
	"group": tokGROUP,
	"hashed": tokHASHED,
	"if": tokIF,
	"in": tokIN,
	"index": tokINDEX,
	"inet": tokINET,
	"infinity": tokINFINITY,
	"initcond": tokINITCOND,
	"input": tokINPUT,
	"insert": tokINSERT,
	"int": tokINT,
	"into": tokINTO,
	"is": tokIS,
	"json": tokJSON,
	"key": tokKEY,
	"keys": tokKEYS,
	"keyspace": tokKEYSPACE,
	"keyspaces": tokKEYSPACES,
	"language": tokLANGUAGE,
	"limit": tokLIMIT,
	"list": tokLIST,
	"logged": tokLOGGED,
	"login": tokLOGIN,
	"map": tokMAP,
	"materialized": tokMATERIALIZED,
	"mbean": tokMBEAN,
	"mbeans": tokMBEANS,
	"maxtimeuuid": tokMAXTIMEUUID,
	"mintimeuuid": tokMINTIMEUUID,
	"modify": tokMODIFY,
	"nan": tokNAN,
	"norecursive": tokNORECURSIVE,
	"nosuperuser": tokNOSUPERUSER,
	"not": tokNOT,
	"now": tokNOW,
	"null": tokNULL,
	"of": tokOF,
	"on": tokON,
	"options": tokOPTIONS,
	"or": tokOR,
	"order": tokORDER,
	"partition": tokPARTITION,
	"password": tokPASSWORD,
	"per": tokPER,
	"permission": tokPERMISSION,
	"permissions": tokPERMISSIONS,
	"primary": tokPRIMARY,
	"rename": tokRENAME,
	"replace": tokREPLACE,
	"replication": tokREPLICATION,
	"returns": tokRETURNS,
	"revoke": tokREVOKE,
	"role": tokROLE,
	"roles": tokROLES,
	"sai": tokSAI,
	"select": tokSELECT,
	"set": tokSET,
	"sfunc": tokSFUNC,
	"smallint": tokSMALLINT,
	"static": tokSTATIC,
	"storage": tokSTORAGE,
	"storageattachedindex": tokSTORAGEATTACHEDINDEX,
	"stype": tokSTYPE,
	"superuser": tokSUPERUSER,
	"table": tokTABLE,
	"text": tokTEXT,
	"time": tokTIME,
	"timestamp": tokTIMESTAMP,
	"timeuuid": tokTIMEUUID,
	"tinyint": tokTINYINT,
	"to": tokTO,
	"tojson": tokTOJSON,
	"trigger": tokTRIGGER,
	"true": tokTRUE,
	"truncate": tokTRUNCATE,
	"ttl": tokTTL,
	"tuple": tokTUPLE,
	"type": tokTYPE,
	"unlogged": tokUNLOGGED,
	"unset": tokUNSET,
	"update": tokUPDATE,
	"use": tokUSE,
	"user": tokUSER,
	"using": tokUSING,
	"uuid": tokUUID_KW,
	"values": tokVALUES,
	"varchar": tokVARCHAR,
	"varint": tokVARINT,
	"vector": tokVECTOR,
	"view": tokVIEW,
	"where": tokWHERE,
	"with": tokWITH,
}

// isKeyword returns true if the token type is a keyword. Any token type whose
// numeric value is 3000 or greater counts as a keyword.
// NOTE(review): presumably the keyword token constants are numbered from 3000
// upward and nothing else is — confirm against the token definitions.
func isKeyword(tok int) bool {
	return tok >= 3000
}

// isIdentLike returns true if the token is an identifier or a keyword that
// can be used as an identifier in certain contexts (like column/table names).
+func isIdentLike(tok int) bool { + return tok == tokIDENT || tok == tokQUOTED || isKeyword(tok) +} diff --git a/cassandra/parser/lexer.go b/cassandra/parser/lexer.go new file mode 100644 index 00000000..e26b45c7 --- /dev/null +++ b/cassandra/parser/lexer.go @@ -0,0 +1,391 @@ +package parser + +import ( + "strings" + "unicode/utf8" +) + +// Lexer tokenizes CQL input. +type Lexer struct { + input string + lineIdx lineIndex + pos int + Err error +} + +// NewLexer creates a new Lexer for the given input. +func NewLexer(input string) *Lexer { + return &Lexer{input: input, lineIdx: buildLineIndex(input)} +} + +func (l *Lexer) makeError(msg string, start, end int) *ParseError { + line, col := offsetToLineCol(l.lineIdx, start) + nearEnd := start + 30 + if nearEnd > end { + nearEnd = end + } + return &ParseError{ + Message: msg, + Loc: locFromOffsets(start, end), + Line: line, + Column: col, + Near: l.input[start:nearEnd], + } +} + +// Next returns the next token from the input. +func (l *Lexer) Next() Token { + l.skipWhitespaceAndComments() + + if l.pos >= len(l.input) { + return Token{Type: tokEOF, Loc: l.pos, End: l.pos} + } + + start := l.pos + ch := l.input[l.pos] + + switch { + case ch == '\'': + return l.scanString(start) + case ch == '"': + return l.scanQuotedIdentifier(start) + case ch == '$' && l.pos+1 < len(l.input) && l.input[l.pos+1] == '$': + return l.scanCodeBlock(start) + case ch == '0' && l.pos+1 < len(l.input) && (l.input[l.pos+1] == 'x' || l.input[l.pos+1] == 'X'): + return l.scanHex(start) + case isDigit(ch): + return l.scanNumber(start) + case isIdentStart(ch): + return l.scanIdentOrKeyword(start) + default: + return l.scanOperator(start) + } +} + +func (l *Lexer) skipWhitespaceAndComments() { + for l.pos < len(l.input) { + ch := l.input[l.pos] + if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' { + l.pos++ + continue + } + if ch == '-' && l.pos+1 < len(l.input) && l.input[l.pos+1] == '-' { + l.pos += 2 + for l.pos < len(l.input) && 
l.input[l.pos] != '\n' { + l.pos++ + } + continue + } + if ch == '/' && l.pos+1 < len(l.input) && l.input[l.pos+1] == '/' { + l.pos += 2 + for l.pos < len(l.input) && l.input[l.pos] != '\n' { + l.pos++ + } + continue + } + if ch == '/' && l.pos+1 < len(l.input) && l.input[l.pos+1] == '*' { + l.pos += 2 + for l.pos+1 < len(l.input) { + if l.input[l.pos] == '*' && l.input[l.pos+1] == '/' { + l.pos += 2 + break + } + l.pos++ + } + if l.pos >= len(l.input) { + // Unterminated block comment — just stop skipping + } + continue + } + break + } +} + +func (l *Lexer) scanString(start int) Token { + l.pos++ // skip opening ' + var b strings.Builder + for l.pos < len(l.input) { + ch := l.input[l.pos] + if ch == '\'' { + if l.pos+1 < len(l.input) && l.input[l.pos+1] == '\'' { + b.WriteByte('\'') + l.pos += 2 + continue + } + l.pos++ + return Token{Type: tokSTRING, Str: b.String(), Loc: start, End: l.pos} + } + b.WriteByte(ch) + l.pos++ + } + l.Err = l.makeError("unterminated string literal", start, l.pos) + return Token{Type: tokEOF, Loc: l.pos, End: l.pos} +} + +func (l *Lexer) scanQuotedIdentifier(start int) Token { + l.pos++ // skip opening " + var b strings.Builder + for l.pos < len(l.input) { + ch := l.input[l.pos] + if ch == '"' { + if l.pos+1 < len(l.input) && l.input[l.pos+1] == '"' { + b.WriteByte('"') + l.pos += 2 + continue + } + l.pos++ + return Token{Type: tokQUOTED, Str: b.String(), Loc: start, End: l.pos} + } + b.WriteByte(ch) + l.pos++ + } + l.Err = l.makeError("unterminated quoted identifier", start, l.pos) + return Token{Type: tokEOF, Loc: l.pos, End: l.pos} +} + +func (l *Lexer) scanCodeBlock(start int) Token { + l.pos += 2 // skip opening $$ + idx := strings.Index(l.input[l.pos:], "$$") + if idx < 0 { + l.Err = l.makeError("unterminated code block", start, len(l.input)) + l.pos = len(l.input) + return Token{Type: tokEOF, Loc: l.pos, End: l.pos} + } + val := l.input[l.pos : l.pos+idx] + l.pos += idx + 2 + return Token{Type: tokCODEBLOCK, Str: val, Loc: 
start, End: l.pos} +} + +func (l *Lexer) scanHex(start int) Token { + l.pos += 2 // skip 0x + for l.pos < len(l.input) && isHexDigit(l.input[l.pos]) { + l.pos++ + } + return Token{Type: tokHEX, Str: l.input[start:l.pos], Loc: start, End: l.pos} +} + +func (l *Lexer) scanNumber(start int) Token { + isFloat := false + for l.pos < len(l.input) && isDigit(l.input[l.pos]) { + l.pos++ + } + + // UUID check: digits followed by hex letters (a-f) may form the first 8-char group. + numLen := l.pos - start + if numLen < 8 && l.pos < len(l.input) && isHexLetter(l.input[l.pos]) { + savedPos := l.pos + for l.pos < len(l.input) && l.pos-start < 8 && isHexDigit(l.input[l.pos]) { + l.pos++ + } + if l.pos-start == 8 { + str := l.input[start:l.pos] + if isUUIDCandidate(str, l) { + return l.scanUUID(start, str) + } + } + l.pos = savedPos + } + + if l.pos < len(l.input) && l.input[l.pos] == '.' { + next := l.pos + 1 + if next < len(l.input) && isDigit(l.input[next]) { + isFloat = true + l.pos++ // skip . + for l.pos < len(l.input) && isDigit(l.input[l.pos]) { + l.pos++ + } + } + } + if l.pos < len(l.input) && (l.input[l.pos] == 'e' || l.input[l.pos] == 'E') { + ePos := l.pos + l.pos++ + if l.pos < len(l.input) && (l.input[l.pos] == '+' || l.input[l.pos] == '-') { + l.pos++ + } + if l.pos >= len(l.input) || !isDigit(l.input[l.pos]) { + l.pos = ePos + } else { + isFloat = true + for l.pos < len(l.input) && isDigit(l.input[l.pos]) { + l.pos++ + } + } + } + str := l.input[start:l.pos] + + // A number like 550e8400 may actually be the first group of a UUID. 
+ if isUUIDCandidate(str, l) { + return l.scanUUID(start, str) + } + + tok := tokINTEGER + if isFloat { + tok = tokFLOAT + } + return Token{Type: tok, Str: str, Loc: start, End: l.pos} +} + +func (l *Lexer) scanIdentOrKeyword(start int) Token { + for l.pos < len(l.input) { + ch := l.input[l.pos] + if isIdentPart(ch) { + l.pos++ + } else { + break + } + } + str := l.input[start:l.pos] + lower := strings.ToLower(str) + + // Check if this looks like a UUID: 8-4-4-4-12 hex pattern. + if isUUIDCandidate(str, l) { + return l.scanUUID(start, str) + } + + if tok, ok := keywords[lower]; ok { + return Token{Type: tok, Str: str, Loc: start, End: l.pos} + } + + return Token{Type: tokIDENT, Str: str, Loc: start, End: l.pos} +} + +// isUUIDCandidate checks if the current identifier followed by upcoming chars +// forms a UUID pattern (xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx). +func isUUIDCandidate(str string, l *Lexer) bool { + if len(str) != 8 { + return false + } + for _, c := range str { + if !isHexDigitRune(c) { + return false + } + } + // Check if followed by - and more hex groups + remaining := l.input[l.pos:] + if len(remaining) < 28 { // -xxxx-xxxx-xxxx-xxxxxxxxxxxx + return false + } + // Pattern: -XXXX-XXXX-XXXX-XXXXXXXXXXXX + if remaining[0] != '-' { + return false + } + expected := []int{4, 4, 4, 12} + off := 1 + for i, groupLen := range expected { + for j := 0; j < groupLen; j++ { + if off >= len(remaining) || !isHexDigitByte(remaining[off]) { + return false + } + off++ + } + if i < len(expected)-1 { + if off >= len(remaining) || remaining[off] != '-' { + return false + } + off++ + } + } + // Make sure the UUID isn't followed by more identifier chars + if off < len(remaining) && isIdentPart(remaining[off]) { + return false + } + return true +} + +func (l *Lexer) scanUUID(start int, firstGroup string) Token { + // We already consumed the first 8-char group. Consume the rest. 
+	// Pattern: -xxxx-xxxx-xxxx-xxxxxxxxxxxx
+	groups := []int{4, 4, 4, 12}
+	for _, groupLen := range groups {
+		l.pos++ // skip -
+		l.pos += groupLen
+	}
+	return Token{Type: tokUUID, Str: l.input[start:l.pos], Loc: start, End: l.pos}
+}
+
+// scanOperator lexes the operator or punctuation token whose first byte is at
+// l.pos. The returned token has Loc set to start and End set to the lexer
+// position just past the token.
+//
+// One byte is consumed for single-character tokens and two for "!=", "<=" and
+// ">=". A '!' that is not followed by '=' and any unrecognized character yield
+// tokILLEGAL; for an unrecognized character the whole UTF-8 rune decoded at
+// start is consumed.
+func (l *Lexer) scanOperator(start int) Token {
+	ch := l.input[l.pos]
+	l.pos++
+	switch ch {
+	case '.':
+		return Token{Type: tokDOT, Str: ".", Loc: start, End: l.pos}
+	case ',':
+		return Token{Type: tokCOMMA, Str: ",", Loc: start, End: l.pos}
+	case ';':
+		return Token{Type: tokSEMI, Str: ";", Loc: start, End: l.pos}
+	case ':':
+		return Token{Type: tokCOLON, Str: ":", Loc: start, End: l.pos}
+	case '(':
+		return Token{Type: tokLPAREN, Str: "(", Loc: start, End: l.pos}
+	case ')':
+		return Token{Type: tokRPAREN, Str: ")", Loc: start, End: l.pos}
+	case '{':
+		return Token{Type: tokLBRACE, Str: "{", Loc: start, End: l.pos}
+	case '}':
+		return Token{Type: tokRBRACE, Str: "}", Loc: start, End: l.pos}
+	case '[':
+		return Token{Type: tokLBRACK, Str: "[", Loc: start, End: l.pos}
+	case ']':
+		return Token{Type: tokRBRACK, Str: "]", Loc: start, End: l.pos}
+	case '*':
+		return Token{Type: tokSTAR, Str: "*", Loc: start, End: l.pos}
+	case '+':
+		return Token{Type: tokPLUS, Str: "+", Loc: start, End: l.pos}
+	case '-':
+		return Token{Type: tokMINUS, Str: "-", Loc: start, End: l.pos}
+	case '!':
+		// '!' is only meaningful as the first half of "!=".
+		if l.pos < len(l.input) && l.input[l.pos] == '=' {
+			l.pos++
+			return Token{Type: tokNE, Str: "!=", Loc: start, End: l.pos}
+		}
+		return Token{Type: tokILLEGAL, Str: "!", Loc: start, End: l.pos}
+	case '?':
+		return Token{Type: tokQMARK, Str: "?", Loc: start, End: l.pos}
+	case '=':
+		return Token{Type: tokEQ, Str: "=", Loc: start, End: l.pos}
+	case '<':
+		if l.pos < len(l.input) && l.input[l.pos] == '=' {
+			l.pos++
+			return Token{Type: tokLTE, Str: "<=", Loc: start, End: l.pos}
+		}
+		return Token{Type: tokLT, Str: "<", Loc: start, End: l.pos}
+	case '>':
+		if l.pos < len(l.input) && l.input[l.pos] == '=' {
+			l.pos++
+			return Token{Type: tokGTE, Str: ">=", Loc: start, End: l.pos}
+		}
+		return Token{Type: tokGT, Str: ">", Loc: start, End: l.pos}
+	default:
+		// Re-decode from start so a multi-byte rune is reported (and skipped)
+		// as a single illegal token rather than byte by byte.
+		_, size := utf8.DecodeRuneInString(l.input[start:])
+		l.pos = start + size
+		return Token{Type: tokILLEGAL, Str: l.input[start:l.pos], Loc: start, End: l.pos}
+	}
+}
+
+// isDigit reports whether ch is an ASCII decimal digit.
+func isDigit(ch byte) bool {
+	return ch >= '0' && ch <= '9'
+}
+
+// isHexDigit reports whether ch is an ASCII hexadecimal digit (0-9, a-f, A-F).
+func isHexDigit(ch byte) bool {
+	return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F')
+}
+
+// isHexLetter reports whether ch is one of the letter hex digits (a-f, A-F).
+func isHexLetter(ch byte) bool {
+	return (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F')
+}
+
+// isHexDigitRune is the rune-typed counterpart of isHexDigit.
+func isHexDigitRune(ch rune) bool {
+	return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F')
+}
+
+// isHexDigitByte is an alias for isHexDigit.
+func isHexDigitByte(ch byte) bool {
+	return isHexDigit(ch)
+}
+
+// isIdentStart reports whether ch may begin an unquoted identifier: an ASCII
+// letter or an underscore.
+func isIdentStart(ch byte) bool {
+	return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
+}
+
+// isIdentPart reports whether ch may appear after the first character of an
+// unquoted identifier: anything isIdentStart accepts, or an ASCII digit.
+func isIdentPart(ch byte) bool {
+	return isIdentStart(ch) || isDigit(ch)
+}
diff --git a/cassandra/parser/parser.go b/cassandra/parser/parser.go
new file mode 100644
index 00000000..09275584
--- /dev/null
+++ b/cassandra/parser/parser.go
@@ -0,0 +1,576 @@
+package parser
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/bytebase/omni/cassandra/ast"
+)
+
+// Parser is the recursive-descent parser for Cassandra CQL.
+type Parser struct {
+	// lexer produces the token stream.
+	lexer *Lexer
+	// source is the original input text.
+	source string
+	// cur is the token currently being examined.
+	cur Token
+	// prev is the token that was current before the last advance.
+	prev Token
+	// nextBuf holds one token of lookahead; it is valid only while hasNext is true.
+	nextBuf Token
+	hasNext bool
+}
+
+// Parse parses a CQL input containing one or more statements and returns a List of RawStmt.
+func Parse(sql string) (*ast.List, error) { + p := &Parser{ + lexer: NewLexer(sql), + source: sql, + } + p.advance() + + var items []ast.Node + for p.cur.Type != tokEOF { + if p.cur.Type == tokSEMI { + p.advance() + continue + } + + stmtStart := p.cur.Loc + stmt, err := p.parseStatement() + if err != nil { + return nil, err + } + if stmt == nil { + break + } + + stmtEnd := p.prev.End + raw := &ast.RawStmt{ + Stmt: stmt, + StmtLocation: stmtStart, + StmtLen: stmtEnd - stmtStart, + } + items = append(items, raw) + + // Consume optional trailing semicolons between statements. + for p.cur.Type == tokSEMI { + p.advance() + } + } + + if p.lexer.Err != nil { + return nil, p.lexer.Err + } + + return &ast.List{Items: items}, nil +} + +// parseStatement dispatches to the appropriate statement parser. +func (p *Parser) parseStatement() (ast.StmtNode, error) { + switch p.cur.Type { + case tokSELECT: + return p.parseSelect() + case tokINSERT: + return p.parseInsert() + case tokUPDATE: + return p.parseUpdate() + case tokDELETE: + return p.parseDelete() + case tokBEGIN: + return p.parseBatch() + case tokTRUNCATE: + return p.parseTruncate() + case tokUSE: + return p.parseUse() + case tokCREATE: + return p.parseCreate() + case tokALTER: + return p.parseAlter() + case tokDROP: + return p.parseDrop() + case tokGRANT: + return p.parseGrant() + case tokREVOKE: + return p.parseRevoke() + case tokLIST: + return p.parseList() + case tokAPPLY: + // APPLY BATCH is handled within parseBatch; standalone APPLY is an error + return nil, p.errorf("unexpected APPLY without matching BEGIN BATCH") + default: + return nil, p.errorf("expected statement, got %s", p.tokenDesc()) + } +} + +// --------------------------------------------------------------------------- +// Token manipulation helpers +// --------------------------------------------------------------------------- + +func (p *Parser) advance() { + p.prev = p.cur + if p.hasNext { + p.cur = p.nextBuf + p.hasNext = false + } else { + p.cur = 
p.lexer.Next() + } +} + +func (p *Parser) peekNext() Token { + if !p.hasNext { + p.nextBuf = p.lexer.Next() + p.hasNext = true + } + return p.nextBuf +} + +func (p *Parser) match(types ...int) bool { + for _, t := range types { + if p.cur.Type == t { + p.advance() + return true + } + } + return false +} + +func (p *Parser) expect(typ int) (Token, error) { + if p.cur.Type == typ { + tok := p.cur + p.advance() + return tok, nil + } + return Token{}, p.errorf("expected %s, got %s", tokenName(typ), p.tokenDesc()) +} + +func (p *Parser) expectKeyword(typ int) error { + _, err := p.expect(typ) + return err +} + +func (p *Parser) curLoc() int { + return p.cur.Loc +} + +func (p *Parser) prevEnd() int { + return p.prev.End +} + +func (p *Parser) makeLoc(start int) ast.Loc { + return ast.Loc{Start: start, End: p.prevEnd()} +} + +func (p *Parser) errorf(format string, args ...any) *ParseError { + if pe, ok := p.lexer.Err.(*ParseError); ok { + return pe + } + line, col := offsetToLineCol(p.lexer.lineIdx, p.cur.Loc) + near := p.cur.Str + if near == "" && p.cur.Type != tokEOF { + near = p.extractNear(p.cur.Loc) + } + return &ParseError{ + Message: fmt.Sprintf(format, args...), + Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}, + Line: line, + Column: col, + Near: near, + } +} + +func (p *Parser) extractNear(offset int) string { + if offset >= len(p.source) { + return "" + } + end := offset + for end < len(p.source) && end-offset < 30 && p.source[end] != ' ' && p.source[end] != '\n' && p.source[end] != '\t' { + end++ + } + return p.source[offset:end] +} + +func (p *Parser) tokenDesc() string { + switch p.cur.Type { + case tokEOF: + return "end of input" + case tokILLEGAL: + return fmt.Sprintf("illegal character %q", p.cur.Str) + default: + if p.cur.Str != "" { + return fmt.Sprintf("%q", p.cur.Str) + } + return tokenName(p.cur.Type) + } +} + +// --------------------------------------------------------------------------- +// Common parsing helpers +// 
--------------------------------------------------------------------------- + +// parseIdentifier parses an identifier (unquoted, quoted, or keyword-as-ident). +func (p *Parser) parseIdentifier() (*ast.Identifier, error) { + tok := p.cur + switch { + case tok.Type == tokIDENT: + p.advance() + return &ast.Identifier{Name: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case tok.Type == tokQUOTED: + p.advance() + return &ast.Identifier{Name: tok.Str, Quoted: true, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + case isKeyword(tok.Type): + p.advance() + return &ast.Identifier{Name: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil + default: + return nil, p.errorf("expected identifier, got %s", p.tokenDesc()) + } +} + +// parseQualifiedName parses name or keyspace.name. +func (p *Parser) parseQualifiedName() (*ast.QualifiedName, error) { + start := p.curLoc() + first, err := p.parseIdentifier() + if err != nil { + return nil, err + } + parts := []*ast.Identifier{first} + if p.cur.Type == tokDOT { + p.advance() + second, err := p.parseIdentifier() + if err != nil { + return nil, err + } + parts = append(parts, second) + } + return &ast.QualifiedName{Parts: parts, Loc: p.makeLoc(start)}, nil +} + +// parseIfNotExists parses optional IF NOT EXISTS, returning (found, error). +func (p *Parser) parseIfNotExists() (bool, error) { + if p.cur.Type == tokIF { + next := p.peekNext() + if next.Type == tokNOT { + p.advance() // IF + p.advance() // NOT + if err := p.expectKeyword(tokEXISTS); err != nil { + return false, err + } + return true, nil + } + } + return false, nil +} + +// parseIfExists parses optional IF EXISTS. +func (p *Parser) parseIfExists() bool { + if p.cur.Type == tokIF && p.peekNext().Type == tokEXISTS { + p.advance() // IF + p.advance() // EXISTS + return true + } + return false +} + +// parseUsingClause parses optional USING TTL n [AND TIMESTAMP m] / USING TIMESTAMP m [AND TTL n]. 
+func (p *Parser) parseUsingClause() (*ast.UsingClause, error) {
+	if p.cur.Type != tokUSING {
+		return nil, nil
+	}
+	start := p.curLoc()
+	p.advance() // USING
+
+	clause := &ast.UsingClause{}
+	for {
+		switch p.cur.Type {
+		case tokTTL:
+			p.advance()
+			val, err := p.parseConstant()
+			if err != nil {
+				return nil, err
+			}
+			clause.TTL = val
+		case tokTIMESTAMP:
+			p.advance()
+			val, err := p.parseConstant()
+			if err != nil {
+				return nil, err
+			}
+			clause.Timestamp = val
+		default:
+			return nil, p.errorf("expected TTL or TIMESTAMP after USING")
+		}
+		// Options are chained with AND; anything else ends the clause.
+		if !p.match(tokAND) {
+			break
+		}
+	}
+	clause.Loc = p.makeLoc(start)
+	return clause, nil
+}
+
+// parseWhereClause parses WHERE relationElement (AND relationElement)*.
+// The WHERE keyword is required.
+func (p *Parser) parseWhereClause() ([]ast.ExprNode, error) {
+	if err := p.expectKeyword(tokWHERE); err != nil {
+		return nil, err
+	}
+	return p.parseRelationElements()
+}
+
+// parseRelationElements parses relation (AND relation)*.
+func (p *Parser) parseRelationElements() ([]ast.ExprNode, error) {
+	var relations []ast.ExprNode
+	first, err := p.parseRelationElement()
+	if err != nil {
+		return nil, err
+	}
+	relations = append(relations, first)
+	for p.match(tokAND) {
+		rel, err := p.parseRelationElement()
+		if err != nil {
+			return nil, err
+		}
+		relations = append(relations, rel)
+	}
+	return relations, nil
+}
+
+// parseRelationElement parses a single WHERE condition: a comparison, an IN
+// list, a CONTAINS [KEY] test, or (when it starts with '(') a tuple relation.
+func (p *Parser) parseRelationElement() (ast.ExprNode, error) {
+	start := p.curLoc()
+
+	// Handle tuple comparison: (col1, col2, ...) op/IN (...)
+	if p.cur.Type == tokLPAREN {
+		return p.parseTupleRelation(start)
+	}
+
+	// Could be: identifier op constant, identifier IN (...), identifier CONTAINS [KEY] constant,
+	// or function op constant/function.
+	left, err := p.parseRelationLeft()
+	if err != nil {
+		return nil, err
+	}
+
+	switch p.cur.Type {
+	case tokEQ, tokLT, tokGT, tokLTE, tokGTE, tokNE:
+		op := p.cur.Str
+		p.advance()
+		right, err := p.parseRelationRight()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.BinaryExpr{Left: left, Op: op, Right: right, Loc: p.makeLoc(start)}, nil
+
+	case tokIN:
+		p.advance()
+		if _, err := p.expect(tokLPAREN); err != nil {
+			return nil, err
+		}
+		// An empty list "IN ()" is accepted and leaves values nil.
+		var values []ast.ExprNode
+		if p.cur.Type != tokRPAREN {
+			values, err = p.parseExpressionList()
+			if err != nil {
+				return nil, err
+			}
+		}
+		if _, err := p.expect(tokRPAREN); err != nil {
+			return nil, err
+		}
+		return &ast.InExpr{Column: left, Values: values, Loc: p.makeLoc(start)}, nil
+
+	case tokCONTAINS:
+		p.advance()
+		isKey := false
+		if p.cur.Type == tokKEY {
+			isKey = true
+			p.advance()
+		}
+		val, err := p.parseConstant()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.ContainsExpr{Column: left, Value: val, IsKey: isKey, Loc: p.makeLoc(start)}, nil
+
+	default:
+		return nil, p.errorf("expected operator after expression in WHERE clause, got %s", p.tokenDesc())
+	}
+}
+
+// parseRelationLeft parses the left-hand side of a relation: a plain
+// identifier, a function call name(...), or a dotted access name.field.
+func (p *Parser) parseRelationLeft() (ast.ExprNode, error) {
+	if isIdentLike(p.cur.Type) {
+		name, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		// Check for function call: name(...)
+		if p.cur.Type == tokLPAREN {
+			return p.parseFunctionCallWithName(name)
+		}
+		// Check for dotted name: name.field
+		if p.cur.Type == tokDOT {
+			start := name.Loc.Start
+			p.advance()
+			field, err := p.parseIdentifier()
+			if err != nil {
+				return nil, err
+			}
+			return &ast.DotAccess{Object: name, Field: field, Loc: p.makeLoc(start)}, nil
+		}
+		return name, nil
+	}
+	return nil, p.errorf("expected identifier or function call in relation, got %s", p.tokenDesc())
+}
+
+// parseRelationRight parses the right-hand side of a comparison: a positional
+// bind marker (?), a named bind marker (:name), a function call, or a constant.
+func (p *Parser) parseRelationRight() (ast.ExprNode, error) {
+	// Bind markers
+	if p.cur.Type == tokQMARK {
+		m := &ast.BindMarker{Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}}
+		p.advance()
+		return m, nil
+	}
+	if p.cur.Type == tokCOLON {
+		start := p.curLoc()
+		p.advance()
+		name, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.BindMarker{Name: name.Name, Loc: p.makeLoc(start)}, nil
+	}
+	// Could be a constant or a function call.
+	if isIdentLike(p.cur.Type) && p.peekNext().Type == tokLPAREN {
+		name, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		return p.parseFunctionCallWithName(name)
+	}
+	return p.parseConstant()
+}
+
+// parseTupleRelation parses a relation whose left-hand side is a parenthesized
+// column list: (c1, c2, ...) IN ((...), (...)) or (c1, c2, ...) op (v1, v2, ...).
+// The current token must be the opening '('; start is its offset.
+func (p *Parser) parseTupleRelation(start int) (ast.ExprNode, error) {
+	p.advance() // (
+	var cols []ast.ExprNode
+	for {
+		col, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		cols = append(cols, col)
+		if !p.match(tokCOMMA) {
+			break
+		}
+	}
+	if _, err := p.expect(tokRPAREN); err != nil {
+		return nil, err
+	}
+
+	if p.cur.Type == tokIN {
+		p.advance()
+		if _, err := p.expect(tokLPAREN); err != nil {
+			return nil, err
+		}
+		var tuples []*ast.TupleLit
+		for {
+			t, err := p.parseTupleLit()
+			if err != nil {
+				return nil, err
+			}
+			tuples = append(tuples, t)
+			if !p.match(tokCOMMA) {
+				break
+			}
+		}
+		if _, err := p.expect(tokRPAREN); err != nil {
+			return nil, err
+		}
+		return &ast.TupleInExpr{Columns: cols, Tuples: tuples, Loc: p.makeLoc(start)}, nil
+	}
+
+	// Tuple comparison: (col1, col2) op (val1, val2)
+	// Capture the operator text before match consumes the token.
+	op := p.cur.Str
+	if !p.match(tokEQ, tokLT, tokGT, tokLTE, tokGTE, tokNE) {
+		return nil, p.errorf("expected operator or IN after tuple columns")
+	}
+	var values []ast.ExprNode
+	// The right-hand side is a single tuple literal.
+	t, err := p.parseTupleLit()
+	if err != nil {
+		return nil, err
+	}
+	values = t.Elements
+	return &ast.TupleCompareExpr{Columns: cols, Op: op, Values: values, Loc: p.makeLoc(start)}, nil
+}
+
+// parseTupleLit parses a parenthesized, non-empty expression list.
+func (p *Parser) parseTupleLit() (*ast.TupleLit, error) {
+	start := p.curLoc()
+	if _, err := p.expect(tokLPAREN); err != nil {
+		return nil, err
+	}
+	elems, err := p.parseExpressionList()
+	if err != nil {
+		return nil, err
+	}
+	if _, err := p.expect(tokRPAREN); err != nil {
+		return nil, err
+	}
+	return &ast.TupleLit{Elements: elems, Loc: p.makeLoc(start)}, nil
+}
+
+// tokenName returns a human-readable name for a token type, for use in error
+// messages. Literal and punctuation tokens have fixed names; any other type is
+// looked up in the keywords table and, failing that, rendered as "token(N)".
+func tokenName(typ int) string {
+	switch typ {
+	case tokEOF:
+		return "EOF"
+	case tokILLEGAL:
+		return "ILLEGAL"
+	case tokIDENT:
+		return "identifier"
+	case tokQUOTED:
+		return "quoted identifier"
+	case tokSTRING:
+		return "string"
+	case tokINTEGER:
+		return "integer"
+	case tokFLOAT:
+		return "float"
+	case tokUUID:
+		return "UUID"
+	case tokHEX:
+		return "hex literal"
+	case tokCODEBLOCK:
+		return "code block"
+	case tokSEMI:
+		return "';'"
+	case tokCOLON:
+		return "':'"
+	case tokLPAREN:
+		return "'('"
+	case tokRPAREN:
+		return "')'"
+	case tokLBRACE:
+		return "'{'"
+	case tokRBRACE:
+		return "'}'"
+	case tokLBRACK:
+		return "'['"
+	case tokRBRACK:
+		return "']'"
+	case tokCOMMA:
+		return "','"
+	case tokDOT:
+		return "'.'"
+	case tokSTAR:
+		return "'*'"
+	case tokPLUS:
+		return "'+'"
+	case tokMINUS:
+		return "'-'"
+	case tokMINUSMINUS:
+		return "'--'"
+	case tokEQ:
+		return "'='"
+	case tokLT:
+		return "'<'"
+	case tokGT:
+		return "'>'"
+	case tokLTE:
+		return "'<='"
+	case tokGTE:
+		return "'>='"
+	case tokNE:
+		return "'!='"
+	case tokQMARK:
+		return "'?'"
+	default:
+		// For keywords, reverse lookup
+		for name, t := range keywords {
+			if t == typ {
+				return strings.ToUpper(name)
+			}
+		}
+		return fmt.Sprintf("token(%d)", typ)
+	}
+}
diff --git a/cassandra/parser/select.go b/cassandra/parser/select.go
new file mode 100644
index 00000000..7fa6d81e
--- /dev/null
+++ b/cassandra/parser/select.go
@@ -0,0 +1,400 @@
+package parser
+
+import (
+	"github.com/bytebase/omni/cassandra/ast"
+)
+
+// parseSelect parses a SELECT statement. The optional clauses must appear in
+// the order shown.
+//
+//	SELECT [DISTINCT] [JSON] selectElements
+//	FROM fromSpecElement
+//	[WHERE relationElements]
+//	[GROUP BY column (',' column)*]
+//	[ORDER BY orderSpecElement (',' orderSpecElement)*]
+//	[PER PARTITION LIMIT decimal]
+//	[LIMIT decimal]
+//	[ALLOW FILTERING]
+func (p *Parser) parseSelect() (*ast.SelectStmt, error) {
+	start := p.curLoc()
+	if err := p.expectKeyword(tokSELECT); err != nil {
+		return nil, err
+	}
+
+	stmt := &ast.SelectStmt{}
+
+	// Optional DISTINCT.
+	if p.cur.Type == tokDISTINCT {
+		stmt.Distinct = true
+		p.advance()
+	}
+
+	// Optional JSON.
+	if p.cur.Type == tokJSON {
+		stmt.JSON = true
+		p.advance()
+	}
+
+	// Parse select elements.
+	elements, err := p.parseSelectElements()
+	if err != nil {
+		return nil, err
+	}
+	stmt.Elements = elements
+
+	// FROM tableName.
+	if err := p.expectKeyword(tokFROM); err != nil {
+		return nil, err
+	}
+	from, err := p.parseQualifiedName()
+	if err != nil {
+		return nil, err
+	}
+	stmt.From = from
+
+	// Optional WHERE clause.
+	if p.cur.Type == tokWHERE {
+		where, err := p.parseWhereClause()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Where = where
+	}
+
+	// Optional GROUP BY clause.
+	if p.cur.Type == tokGROUP {
+		groupBy, err := p.parseGroupByClause()
+		if err != nil {
+			return nil, err
+		}
+		stmt.GroupBy = groupBy
+	}
+
+	// Optional ORDER BY clause.
+	if p.cur.Type == tokORDER {
+		orderBy, err := p.parseOrderByClause()
+		if err != nil {
+			return nil, err
+		}
+		stmt.OrderBy = orderBy
+	}
+
+	// Optional PER PARTITION LIMIT clause.
+	if p.cur.Type == tokPER {
+		perPartLimit, err := p.parsePerPartitionLimitClause()
+		if err != nil {
+			return nil, err
+		}
+		stmt.PerPartitionLimit = perPartLimit
+	}
+
+	// Optional LIMIT clause.
+	if p.cur.Type == tokLIMIT {
+		limit, err := p.parseLimitClause()
+		if err != nil {
+			return nil, err
+		}
+		stmt.Limit = limit
+	}
+
+	// Optional ALLOW FILTERING.
+	if p.cur.Type == tokALLOW {
+		p.advance()
+		if err := p.expectKeyword(tokFILTERING); err != nil {
+			return nil, err
+		}
+		stmt.AllowFiltering = true
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseSelectElements parses the select clause elements:
+//
+//	'*' | selectElement (',' selectElement)*
+func (p *Parser) parseSelectElements() ([]*ast.SelectElement, error) {
+	// Handle bare '*'.
+	if p.cur.Type == tokSTAR {
+		elemStart := p.curLoc()
+		p.advance()
+		star := &ast.StarExpr{Loc: p.makeLoc(elemStart)}
+		elem := &ast.SelectElement{Expr: star, Loc: p.makeLoc(elemStart)}
+		return []*ast.SelectElement{elem}, nil
+	}
+
+	var elements []*ast.SelectElement
+	first, err := p.parseSelectElement()
+	if err != nil {
+		return nil, err
+	}
+	elements = append(elements, first)
+
+	for p.cur.Type == tokCOMMA {
+		p.advance()
+		el, err := p.parseSelectElement()
+		if err != nil {
+			return nil, err
+		}
+		elements = append(elements, el)
+	}
+	return elements, nil
+}
+
+// parseSelectElement parses a single select element:
+//
+//	CAST '(' ... ')' [AS IDENT] -> cast expression with optional alias
+//	IDENT '.' '*'               -> DotAccess with StarExpr-like field
+//	IDENT [AS IDENT]            -> Identifier with optional alias
+//	functionCall [AS IDENT]     -> FunctionCall with optional alias
+func (p *Parser) parseSelectElement() (*ast.SelectElement, error) {
+	elemStart := p.curLoc()
+
+	// CAST(expr AS type) in SELECT element
+	if p.cur.Type == tokCAST {
+		castExpr, err := p.parseCast()
+		if err != nil {
+			return nil, err
+		}
+		alias, err := p.parseOptionalAlias()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.SelectElement{Expr: castExpr, Alias: alias, Loc: p.makeLoc(elemStart)}, nil
+	}
+
+	// We need an identifier-like token to start.
+	if !isIdentLike(p.cur.Type) {
+		return nil, p.errorf("expected column name or function call in SELECT, got %s", p.tokenDesc())
+	}
+
+	name, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+
+	// Check for IDENT '.' '*' (qualified star, e.g. ks.*).
+	if p.cur.Type == tokDOT {
+		next := p.peekNext()
+		if next.Type == tokSTAR {
+			p.advance() // consume '.'
+			starStart := p.curLoc()
+			p.advance() // consume '*'
+			// The star is represented as an Identifier named "*".
+			starField := &ast.Identifier{Name: "*", Loc: p.makeLoc(starStart)}
+			dotAccess := &ast.DotAccess{Object: name, Field: starField, Loc: p.makeLoc(elemStart)}
+			elem := &ast.SelectElement{Expr: dotAccess, Loc: p.makeLoc(elemStart)}
+			return elem, nil
+		}
+	}
+
+	// Check for function call: IDENT '(' ... ')'.
+	if p.cur.Type == tokLPAREN {
+		fc, err := p.parseFunctionCallWithName(name)
+		if err != nil {
+			return nil, err
+		}
+		alias, err := p.parseOptionalAlias()
+		if err != nil {
+			return nil, err
+		}
+		elem := &ast.SelectElement{Expr: fc, Alias: alias, Loc: p.makeLoc(elemStart)}
+		return elem, nil
+	}
+
+	// Plain identifier, optionally with alias.
+	alias, err := p.parseOptionalAlias()
+	if err != nil {
+		return nil, err
+	}
+	elem := &ast.SelectElement{Expr: name, Alias: alias, Loc: p.makeLoc(elemStart)}
+	return elem, nil
+}
+
+// parseOptionalAlias parses an optional AS IDENT clause.
+func (p *Parser) parseOptionalAlias() (*ast.Identifier, error) { + if p.cur.Type != tokAS { + return nil, nil + } + p.advance() // consume AS + return p.parseIdentifier() +} + +// parseOrderByClause parses ORDER BY orderSpecElement (',' orderSpecElement)*. +func (p *Parser) parseOrderByClause() ([]*ast.OrderByElement, error) { + if err := p.expectKeyword(tokORDER); err != nil { + return nil, err + } + if err := p.expectKeyword(tokBY); err != nil { + return nil, err + } + + var elements []*ast.OrderByElement + first, err := p.parseOrderByElement() + if err != nil { + return nil, err + } + elements = append(elements, first) + + for p.cur.Type == tokCOMMA { + p.advance() + el, err := p.parseOrderByElement() + if err != nil { + return nil, err + } + elements = append(elements, el) + } + return elements, nil +} + +// parseOrderByElement parses a single ORDER BY element: +// +// IDENT [ASC | DESC] +// IDENT ANN OF vectorLiteral [LIMIT DECIMAL] +func (p *Parser) parseOrderByElement() (*ast.OrderByElement, error) { + elemStart := p.curLoc() + + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + elem := &ast.OrderByElement{Column: col} + + // Check for ANN OF ... ordering. + if p.cur.Type == tokANN { + p.advance() // consume ANN + if err := p.expectKeyword(tokOF); err != nil { + return nil, err + } + elem.IsANN = true + + vec, err := p.parseVectorLiteral() + if err != nil { + return nil, err + } + elem.AnnVector = vec + + // Optional LIMIT within ANN ordering. + if p.cur.Type == tokLIMIT { + p.advance() + limit, err := p.parseConstant() + if err != nil { + return nil, err + } + elem.AnnLimit = limit + } + + elem.Loc = p.makeLoc(elemStart) + return elem, nil + } + + // Optional ASC / DESC. 
+ switch p.cur.Type { + case tokASC: + elem.Direction = "ASC" + p.advance() + case tokDESC: + elem.Direction = "DESC" + p.advance() + } + + elem.Loc = p.makeLoc(elemStart) + return elem, nil +} + +// parseVectorLiteral parses a vector literal: '[' constant (',' constant)* ']'. +func (p *Parser) parseVectorLiteral() (*ast.VectorLit, error) { + start := p.curLoc() + if _, err := p.expect(tokLBRACK); err != nil { + return nil, err + } + + var elements []ast.ExprNode + first, err := p.parseConstant() + if err != nil { + return nil, err + } + elements = append(elements, first) + + for p.cur.Type == tokCOMMA { + p.advance() + val, err := p.parseConstant() + if err != nil { + return nil, err + } + elements = append(elements, val) + } + + if _, err := p.expect(tokRBRACK); err != nil { + return nil, err + } + + return &ast.VectorLit{Elements: elements, Loc: p.makeLoc(start)}, nil +} + +// parseGroupByClause parses GROUP BY column (',' column)*. +func (p *Parser) parseGroupByClause() ([]*ast.Identifier, error) { + if err := p.expectKeyword(tokGROUP); err != nil { + return nil, err + } + if err := p.expectKeyword(tokBY); err != nil { + return nil, err + } + + var cols []*ast.Identifier + first, err := p.parseIdentifier() + if err != nil { + return nil, err + } + cols = append(cols, first) + + for p.cur.Type == tokCOMMA { + p.advance() + col, err := p.parseIdentifier() + if err != nil { + return nil, err + } + cols = append(cols, col) + } + return cols, nil +} + +// parsePerPartitionLimitClause parses PER PARTITION LIMIT decimal. +func (p *Parser) parsePerPartitionLimitClause() (ast.ExprNode, error) { + if err := p.expectKeyword(tokPER); err != nil { + return nil, err + } + if err := p.expectKeyword(tokPARTITION); err != nil { + return nil, err + } + if err := p.expectKeyword(tokLIMIT); err != nil { + return nil, err + } + return p.parseLimitValue() +} + +// parseLimitClause parses LIMIT decimal. 
+func (p *Parser) parseLimitClause() (ast.ExprNode, error) { + p.advance() // consume LIMIT + return p.parseLimitValue() +} + +// parseLimitValue parses an integer literal, bind marker, or named bind marker. +func (p *Parser) parseLimitValue() (ast.ExprNode, error) { + if p.cur.Type == tokQMARK { + m := &ast.BindMarker{Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}} + p.advance() + return m, nil + } + if p.cur.Type == tokCOLON { + start := p.curLoc() + p.advance() + name, err := p.parseIdentifier() + if err != nil { + return nil, err + } + return &ast.BindMarker{Name: name.Name, Loc: p.makeLoc(start)}, nil + } + if p.cur.Type == tokINTEGER { + lit := &ast.IntegerLit{Val: p.cur.Str, Loc: ast.Loc{Start: p.cur.Loc, End: p.cur.End}} + p.advance() + return lit, nil + } + return nil, p.errorf("expected integer literal or bind marker for LIMIT, got %s", p.tokenDesc()) +} diff --git a/cassandra/parser/split.go b/cassandra/parser/split.go new file mode 100644 index 00000000..4a2c4291 --- /dev/null +++ b/cassandra/parser/split.go @@ -0,0 +1,133 @@ +package parser + +// Split splits a CQL input into segments at top-level semicolons, +// correctly handling string literals, quoted identifiers, code blocks, +// and comments. +func Split(sql string) []Segment { + var segments []Segment + start := 0 + i := 0 + for i < len(sql) { + ch := sql[i] + switch { + case ch == '\'': + i = skipString(sql, i) + case ch == '"': + i = skipQuotedIdent(sql, i) + case ch == '$' && i+1 < len(sql) && sql[i+1] == '$': + i = skipCodeBlock(sql, i) + case ch == '-' && i+1 < len(sql) && sql[i+1] == '-': + i = skipLineComment(sql, i) + case ch == '/' && i+1 < len(sql) && sql[i+1] == '*': + i = skipBlockComment(sql, i) + case ch == ';': + seg := makeSegment(sql, start, i) + if !seg.Empty { + segments = append(segments, seg) + } + i++ + start = i + default: + i++ + } + } + // Trailing segment after last semicolon (or the whole input if no semicolons). 
+ if start < len(sql) { + seg := makeSegment(sql, start, len(sql)) + if !seg.Empty { + segments = append(segments, seg) + } + } + return segments +} + +// Segment represents a single SQL segment from splitting. +type Segment struct { + Text string + ByteStart int + ByteEnd int + Empty bool +} + +func makeSegment(sql string, start, end int) Segment { + for start < end && isWhitespace(sql[start]) { + start++ + } + for end > start && isWhitespace(sql[end-1]) { + end-- + } + text := sql[start:end] + empty := len(text) == 0 + return Segment{ + Text: text, + ByteStart: start, + ByteEnd: end, + Empty: empty, + } +} + +func isWhitespace(c byte) bool { + return c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == '\f' +} + +func skipString(sql string, i int) int { + i++ // skip opening ' + for i < len(sql) { + if sql[i] == '\'' { + i++ + if i < len(sql) && sql[i] == '\'' { + i++ // escaped quote '' + continue + } + return i + } + i++ + } + return i +} + +func skipQuotedIdent(sql string, i int) int { + i++ // skip opening " + for i < len(sql) { + if sql[i] == '"' { + i++ + if i < len(sql) && sql[i] == '"' { + i++ // escaped quote "" + continue + } + return i + } + i++ + } + return i +} + +func skipCodeBlock(sql string, i int) int { + i += 2 // skip opening $$ + for i+1 < len(sql) { + if sql[i] == '$' && sql[i+1] == '$' { + return i + 2 + } + i++ + } + return len(sql) +} + +func skipLineComment(sql string, i int) int { + i += 2 // skip -- + for i < len(sql) && sql[i] != '\n' { + i++ + } + return i +} + +func skipBlockComment(sql string, i int) int { + i += 2 // skip /* + for i+1 < len(sql) { + if sql[i] == '*' && sql[i+1] == '/' { + return i + 2 + } + i++ + } + return len(sql) +} diff --git a/cassandra/parser/tokens.go b/cassandra/parser/tokens.go new file mode 100644 index 00000000..5c1677f2 --- /dev/null +++ b/cassandra/parser/tokens.go @@ -0,0 +1,206 @@ +package parser + +// Token type constants. +const ( + tokEOF = 0 + tokILLEGAL = -1 +) + +// Literal token types. 
+// Values start at 1000 and follow declaration order.
+const (
+	tokIDENT     = iota + 1000 // unquoted identifier
+	tokQUOTED                  // double-quoted identifier
+	tokSTRING                  // single-quoted string
+	tokINTEGER                 // integer constant
+	tokFLOAT                   // float constant
+	tokUUID                    // UUID literal
+	tokHEX                     // hex literal 0xABCD
+	tokCODEBLOCK               // $$...$$ code block
+)
+
+// Operator / punctuation token types.
+// Values start at 2000 and follow declaration order.
+const (
+	tokDOT        = iota + 2000 // .
+	tokCOMMA                    // ,
+	tokSEMI                     // ;
+	tokCOLON                    // :
+	tokLPAREN                   // (
+	tokRPAREN                   // )
+	tokLBRACE                   // {
+	tokRBRACE                   // }
+	tokLBRACK                   // [
+	tokRBRACK                   // ]
+	tokSTAR                     // *
+	tokPLUS                     // +
+	tokMINUS                    // -
+	tokEQ                       // =
+	tokLT                       // <
+	tokGT                       // >
+	tokLTE                      // <=
+	tokGTE                      // >=
+	tokNE                       // !=
+	tokMINUSMINUS               // --
+	tokQMARK                    // ?
+)
+
+// Keyword token types.
+// Values start at 3000 and follow declaration order, so inserting or
+// reordering entries renumbers every keyword after the change.
+const (
+	tokACCESS = iota + 3000
+	tokADD
+	tokAGGREGATE
+	tokALL
+	tokALLOW
+	tokALTER
+	tokAND
+	tokANN
+	tokAPPLY
+	tokAS
+	tokASC
+	tokASCII
+	tokAUTHORIZE
+	tokBATCH
+	tokBEGIN
+	tokBIGINT
+	tokBLOB
+	tokBOOLEAN
+	tokBY
+	tokCALLED
+	tokCAST
+	tokCLUSTERING
+	tokCOMPACT
+	tokCONTAINS
+	tokCOUNTER
+	tokCREATE
+	tokCURRENTDATE
+	tokCURRENTTIME
+	tokCURRENTTIMESTAMP
+	tokCURRENTTIMEUUID
+	tokCUSTOM
+	tokDATACENTERS
+	tokDATE
+	tokDATETIMENOW
+	tokDECIMAL
+	tokDEFAULT
+	tokDELETE
+	tokDESC
+	tokDESCRIBE
+	tokDISTINCT
+	tokDOUBLE
+	tokDROP
+	tokDURABLE_WRITES
+	tokDURATION
+	tokENTRIES
+	tokEXECUTE
+	tokEXISTS
+	tokFALSE
+	tokFILTERING
+	tokFINALFUNC
+	tokFLOATKW
+	tokFROM
+	tokFROMJSON
+	tokFROZEN
+	tokFULL
+	tokFUNCTION
+	tokFUNCTIONS
+	tokGRANT
+	tokGROUP
+	tokHASHED
+	tokIF
+	tokIN
+	tokINDEX
+	tokINET
+	tokINFINITY
+	tokINITCOND
+	tokINPUT
+	tokINSERT
+	tokINT
+	tokINTO
+	tokIS
+	tokJSON
+	tokKEY
+	tokKEYS
+	tokKEYSPACE
+	tokKEYSPACES
+	tokLANGUAGE
+	tokLIMIT
+	tokLIST
+	tokLOGGED
+	tokLOGIN
+	tokMAP
+	tokMATERIALIZED
+	tokMBEAN
+	tokMBEANS
+	tokMAXTIMEUUID
+	tokMINTIMEUUID
+	tokMODIFY
+	tokNAN
+	tokNORECURSIVE
+	tokNOSUPERUSER
+	tokNOT
+	tokNOW
+	tokNULL
+	tokOF
+	tokON
+	tokOPTIONS
+	tokOR
+	tokORDER
+	tokPARTITION
+	tokPASSWORD
+	tokPER
+	tokPERMISSION
+	tokPERMISSIONS
+	tokPRIMARY
+	tokRENAME
+	tokREPLACE
+	tokREPLICATION
+	tokRETURNS
+	tokREVOKE
+	tokROLE
+	tokROLES
+	tokSAI
+	tokSELECT
+	tokSET
+	tokSFUNC
+	tokSMALLINT
+	tokSTATIC
+	tokSTORAGE
+	tokSTORAGEATTACHEDINDEX
+	tokSTYPE
+	tokSUPERUSER
+	tokTABLE
+	tokTEXT
+	tokTIME
+	tokTIMESTAMP
+	tokTIMEUUID
+	tokTINYINT
+	tokTO
+	tokTOJSON
+	tokTRIGGER
+	tokTRUE
+	tokTRUNCATE
+	tokTTL
+	tokTUPLE
+	tokTYPE
+	tokUNLOGGED
+	tokUNSET
+	tokUPDATE
+	tokUSE
+	tokUSER
+	tokUSING
+	tokUUID_KW
+	tokVALUES
+	tokVARCHAR
+	tokVARINT
+	tokVECTOR
+	tokVIEW
+	tokWHERE
+	tokWITH
+)
+
+// Token represents a single lexical token.
+type Token struct {
+	Type int    // one of the tok* constants
+	Str  string // token text
+	Loc  int    // byte offset in source
+	End  int    // exclusive end byte offset
+}
diff --git a/cassandra/parser/types.go b/cassandra/parser/types.go
new file mode 100644
index 00000000..13fc5b8b
--- /dev/null
+++ b/cassandra/parser/types.go
@@ -0,0 +1,70 @@
+package parser
+
+import (
+	"strings"
+
+	"github.com/bytebase/omni/cassandra/ast"
+)
+
+// parseDataType parses a CQL data type: a bare type name, or a type name
+// followed by a parameter list in angle brackets.
+func (p *Parser) parseDataType() (*ast.DataType, error) {
+	begin := p.curLoc()
+	typeName, err := p.parseDataTypeName()
+	if err != nil {
+		return nil, err
+	}
+	result := &ast.DataType{Name: typeName, Loc: p.makeLoc(begin)}
+
+	// No '<' means a plain, unparameterized type.
+	if !p.match(tokLT) {
+		return result, nil
+	}
+
+	if strings.EqualFold(typeName.Name, "vector") {
+		// VECTOR takes exactly one element type and an integer dimension.
+		element, err := p.parseDataType()
+		if err != nil {
+			return nil, err
+		}
+		result.TypeParams = append(result.TypeParams, element)
+
+		if _, err := p.expect(tokCOMMA); err != nil {
+			return nil, err
+		}
+		dim := p.cur
+		if dim.Type != tokINTEGER {
+			return nil, p.errorf("expected integer dimension for VECTOR, got %s", p.tokenDesc())
+		}
+		p.advance()
+		result.Dimension = &ast.IntegerLit{Val: dim.Str, Loc: ast.Loc{Start: dim.Loc, End: dim.End}}
+	} else {
+		// Any other parameterized type takes a comma-separated list of types.
+		for {
+			arg, err := p.parseDataType()
+			if err != nil {
+				return nil, err
+			}
+			result.TypeParams = append(result.TypeParams, arg)
+			if !p.match(tokCOMMA) {
+				break
+			}
+		}
+	}
+
+	if _, err := p.expect(tokGT); err != nil {
+		return nil, err
+	}
+	// Extend the location to cover the closing '>'.
+	result.Loc = p.makeLoc(begin)
+	return result, nil
+}
+
+// parseDataTypeName parses a type name which can be a keyword or identifier.
+func (p *Parser) parseDataTypeName() (*ast.Identifier, error) {
+	tok := p.cur
+	switch tok.Type {
+	case tokIDENT, tokQUOTED:
+		p.advance()
+		return &ast.Identifier{Name: tok.Str, Quoted: tok.Type == tokQUOTED, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil
+	case tokASCII, tokBIGINT, tokBLOB, tokBOOLEAN, tokCOUNTER, tokDATE, tokDECIMAL,
+		tokDOUBLE, tokDURATION, tokFLOATKW, tokFROZEN, tokINET, tokINT, tokLIST,
+		tokMAP, tokSET, tokSMALLINT, tokTEXT, tokTIME, tokTIMESTAMP, tokTIMEUUID,
+		tokTINYINT, tokTUPLE, tokVARCHAR, tokVARINT, tokUUID_KW, tokVECTOR:
+		p.advance()
+		return &ast.Identifier{Name: tok.Str, Loc: ast.Loc{Start: tok.Loc, End: tok.End}}, nil
+	default:
+		return nil, p.errorf("expected data type name, got %s", p.tokenDesc())
+	}
+}
diff --git a/cassandra/parser/update.go b/cassandra/parser/update.go
new file mode 100644
index 00000000..92eb9c64
--- /dev/null
+++ b/cassandra/parser/update.go
@@ -0,0 +1,311 @@
+package parser
+
+import (
+	"strings"
+
+	"github.com/bytebase/omni/cassandra/ast"
+)
+
+// parseUpdate parses an UPDATE statement:
+//
+//	UPDATE [keyspace.]table [USING TTL n AND TIMESTAMP m] SET assignments WHERE relationElements [IF EXISTS | IF ifConditionList]
+func (p *Parser) parseUpdate() (*ast.UpdateStmt, error) {
+	start := p.curLoc()
+	if err := p.expectKeyword(tokUPDATE); err != nil {
+		return nil, err
+	}
+
+	table, err := p.parseQualifiedName()
+	if err != nil {
+		return nil, err
+	}
+
+	stmt := &ast.UpdateStmt{Table: table}
+
+	// Optional USING clause (before SET)
+	using, err := p.parseUsingClause()
+	if err != nil {
+		return nil, err
+	}
+	stmt.Using = using
+
+	// SET assignments
+	if err := p.expectKeyword(tokSET); err != nil {
+		return nil, err
+	}
+
+	assignments, err := p.parseAssignments()
+	if err != nil {
+		return nil, err
+	}
+	stmt.Assignments = assignments
+
+	// WHERE clause
+	where, err := p.parseWhereClause()
+	if err != nil {
+		return nil, err
+	}
+	stmt.Where = where
+
+	// Optional IF EXISTS or IF conditions
+	if p.cur.Type == tokIF {
+		if p.peekNext().Type == tokEXISTS {
+			p.advance() // IF
+			p.advance() // EXISTS
+			stmt.IfExists = true
+		} else {
+			conds, err := p.parseIfConditions()
+			if err != nil {
+				return nil, err
+			}
+			stmt.IfConditions = conds
+		}
+	}
+
+	stmt.Loc = p.makeLoc(start)
+	return stmt, nil
+}
+
+// parseAssignments parses: assignmentElement (',' assignmentElement)*
+func (p *Parser) parseAssignments() ([]*ast.AssignmentElement, error) {
+	var assignments []*ast.AssignmentElement
+	first, err := p.parseAssignmentElement()
+	if err != nil {
+		return nil, err
+	}
+	assignments = append(assignments, first)
+	for p.match(tokCOMMA) {
+		elem, err := p.parseAssignmentElement()
+		if err != nil {
+			return nil, err
+		}
+		assignments = append(assignments, elem)
+	}
+	return assignments, nil
+}
+
+// parseAssignmentElement parses a single assignment in the SET clause.
+//
+// Forms:
+//
+//	IDENT '=' expression
+//	IDENT '=' IDENT ('+' | '-') expression
+//	IDENT '=' expression ('+' | '-') IDENT
+//	IDENT '[' expression ']' '=' expression
+//
+// When the right-hand side combines the target column itself with another
+// operand (col = col + x, col = col - x, col = x + col), the result carries
+// that operator in Operator and the other operand in Value. Any other
+// right-hand side is returned whole in Value with Operator "=".
+func (p *Parser) parseAssignmentElement() (*ast.AssignmentElement, error) {
+	start := p.curLoc()
+
+	target, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+
+	// Check for field access (col.field) or index access (col[idx])
+	var assignTarget ast.ExprNode = target
+	if p.cur.Type == tokDOT {
+		p.advance()
+		field, err := p.parseIdentifier()
+		if err != nil {
+			return nil, err
+		}
+		assignTarget = &ast.DotAccess{
+			Object: target,
+			Field:  field,
+			Loc:    p.makeLoc(start),
+		}
+	} else if p.cur.Type == tokLBRACK {
+		p.advance() // [
+		idx, err := p.parseExpression()
+		if err != nil {
+			return nil, err
+		}
+		if _, err := p.expect(tokRBRACK); err != nil {
+			return nil, err
+		}
+		assignTarget = &ast.IndexAccess{
+			Collection: target,
+			Index:      idx,
+			Loc:        p.makeLoc(start),
+		}
+	}
+
+	if _, err := p.expect(tokEQ); err != nil {
+		return nil, err
+	}
+
+	// Parse the right-hand side. We need to detect patterns like:
+	//   IDENT (+|-) expression  => counter/collection increment
+	//   expression + IDENT      => collection prepend
+	//   plain expression        => simple assignment
+	rhs, err := p.parseExpression()
+	if err != nil {
+		return nil, err
+	}
+
+	// Check for arithmetic operator after the first RHS expression.
+	if p.cur.Type == tokPLUS || p.cur.Type == tokMINUS {
+		op := p.cur.Str
+		isPlus := p.cur.Type == tokPLUS
+		p.advance()
+		right, err := p.parseExpression()
+		if err != nil {
+			return nil, err
+		}
+
+		// Determine if this is col = col + val or col = val + col.
+		// If lhs is an identifier matching the target, it's col = col + val (operator is += or -=).
+		// If rhs is an identifier matching the target, it's col = val + col (also += but collection prepend).
+		// In either case, we represent it using the Operator field and the non-target side as Value.
+		if ident, ok := rhs.(*ast.Identifier); ok && matchesTarget(ident, target) {
+			// col = col + val => Operator: "+"
+			return &ast.AssignmentElement{
+				Target:   assignTarget,
+				Value:    right,
+				Operator: op,
+				Loc:      p.makeLoc(start),
+			}, nil
+		}
+		// Only "+" is folded when the target is the right-hand operand.
+		// Subtraction is not commutative: recording col = val - col with
+		// Operator "-" would make it indistinguishable from col = col - val,
+		// so that form falls through and is kept as a binary expression.
+		if ident, ok := right.(*ast.Identifier); ok && isPlus && matchesTarget(ident, target) {
+			// col = val + col => Operator: "+"
+			return &ast.AssignmentElement{
+				Target:   assignTarget,
+				Value:    rhs,
+				Operator: op,
+				Loc:      p.makeLoc(start),
+			}, nil
+		}
+		// General arithmetic expression; wrap as binary expression value with simple "=".
+		return &ast.AssignmentElement{
+			Target:   assignTarget,
+			Value:    &ast.BinaryExpr{Left: rhs, Op: op, Right: right, Loc: p.makeLoc(rhs.GetLoc().Start)},
+			Operator: "=",
+			Loc:      p.makeLoc(start),
+		}, nil
+	}
+
+	return &ast.AssignmentElement{
+		Target:   assignTarget,
+		Value:    rhs,
+		Operator: "=",
+		Loc:      p.makeLoc(start),
+	}, nil
+}
+
+// matchesTarget reports whether ident names the same column as the assignment
+// target. CQL treats unquoted identifiers as case-insensitive and quoted ones
+// as case-sensitive, so unquoted names are lower-cased before comparing;
+// otherwise SET Total = total + 1 would not be recognised as an increment.
+// NOTE(review): assumes Identifier.Name holds the identifier text without its
+// surrounding double quotes; confirm against the lexer.
+func matchesTarget(ident *ast.Identifier, target *ast.Identifier) bool {
+	fold := func(id *ast.Identifier) string {
+		if id.Quoted {
+			return id.Name
+		}
+		return strings.ToLower(id.Name)
+	}
+	return fold(ident) == fold(target)
+}
+
+// parseIfConditions parses IF col op val [AND col op val ...].
+func (p *Parser) parseIfConditions() ([]*ast.IfCondition, error) {
+	if err := p.expectKeyword(tokIF); err != nil {
+		return nil, err
+	}
+
+	var conditions []*ast.IfCondition
+	for {
+		cond, err := p.parseIfCondition()
+		if err != nil {
+			return nil, err
+		}
+		conditions = append(conditions, cond)
+		if !p.match(tokAND) {
+			break
+		}
+	}
+	return conditions, nil
+}
+
+// parseIfCondition parses a single LWT condition:
+//
+//	col op value | col IN (values) | col CONTAINS [KEY] value
+func (p *Parser) parseIfCondition() (*ast.IfCondition, error) {
+	start := p.curLoc()
+
+	col, err := p.parseIdentifier()
+	if err != nil {
+		return nil, err
+	}
+
+	// IN condition
+	if p.cur.Type == tokIN {
+		p.advance()
+		if _, err := p.expect(tokLPAREN); err != nil {
+			return nil, err
+		}
+		var values []ast.ExprNode
+		if p.cur.Type != tokRPAREN {
+			values, err = p.parseExpressionList()
+			if err != nil {
+				return nil, err
+			}
+		}
+		if _, err := p.expect(tokRPAREN); err != nil {
+			return nil, err
+		}
+		return &ast.IfCondition{
+			Column:   col,
+			Op:       "IN",
+			InValues: values,
+			Loc:      p.makeLoc(start),
+		}, nil
+	}
+
+	// CONTAINS [KEY] condition
+	if p.cur.Type == tokCONTAINS {
+		p.advance()
+		op := "CONTAINS"
+		if p.cur.Type == tokKEY {
+			op = "CONTAINS KEY"
+			p.advance()
+		}
+		val, err := p.parseExpression()
+		if err != nil {
+			return nil, err
+		}
+		return &ast.IfCondition{
+			Column: col,
+			Op:     op,
+			Value:  val,
+			Loc:    p.makeLoc(start),
+		}, nil
+	}
+
+	// Comparison operator
+	op := p.cur.Str
+	if !p.match(tokEQ, tokLT, tokGT, tokLTE, tokGTE, tokNE) {
+		return nil, p.errorf("expected comparison operator, IN, or CONTAINS in IF condition, got %s", p.tokenDesc())
+	}
+
+	val, err := p.parseExpression()
+	if err != nil {
+		return nil, err
+	}
+
+	return &ast.IfCondition{
+		Column: col,
+		Op:     op,
+		Value:  val,
+		Loc:    p.makeLoc(start),
+	}, nil
+}
diff --git a/cassandra/split.go b/cassandra/split.go
new file mode 100644
index 00000000..9f095e36
--- /dev/null
+++ b/cassandra/split.go
@@ -0,0 +1,28 @@
+package cassandra
+
+import (
+	"github.com/bytebase/omni/cassandra/parser"
+)
+
+// Segment represents a single SQL segment from splitting at top-level semicolons.
+type Segment struct {
+	Text      string
+	ByteStart int
+	ByteEnd   int
+	Empty     bool
+}
+
+// Split splits a CQL input into segments at top-level semicolons.
+func Split(sql string) []Segment {
+	internal := parser.Split(sql)
+	result := make([]Segment, len(internal))
+	for i, seg := range internal {
+		result[i] = Segment{
+			Text:      seg.Text,
+			ByteStart: seg.ByteStart,
+			ByteEnd:   seg.ByteEnd,
+			Empty:     seg.Empty,
+		}
+	}
+	return result
+}
diff --git a/cassandra/split_test.go b/cassandra/split_test.go
new file mode 100644
index 00000000..9a1983ee
--- /dev/null
+++ b/cassandra/split_test.go
@@ -0,0 +1,51 @@
+package cassandra
+
+import (
+	"testing"
+)
+
+func TestSplit(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected int
+	}{
+		{"", 0},
+		{"SELECT * FROM users", 1},
+		{"SELECT * FROM users;", 1},
+		{"SELECT * FROM users; INSERT INTO t (id) VALUES (1)", 2},
+		{"SELECT * FROM users; ; INSERT INTO t (id) VALUES (1)", 2}, // empty segment filtered
+		{"SELECT * FROM users WHERE name = 'hello;world'", 1}, // semicolon inside string
+		{`SELECT * FROM users WHERE name = "test;col"`, 1}, // semicolon inside quoted ident
+		{"SELECT * FROM users; -- comment with ; inside\nINSERT INTO t (id) VALUES (1)", 2},
+		{"SELECT * FROM users /* ; */ ; SELECT 1", 2},
+		{"CREATE FUNCTION f() RETURNS NULL ON NULL INPUT RETURNS text LANGUAGE java AS $$return \";\";$$", 1}, // semicolon in code block
+	}
+	for _, tt := range tests {
+		t.Run(tt.input, func(t *testing.T) {
+			segs := Split(tt.input)
+			if len(segs) != tt.expected {
+				t.Errorf("Split(%q): got %d segments, want %d", tt.input, len(segs), tt.expected)
+				for i, s := range segs {
+					t.Logf(" segment %d: %q (empty=%v)", i, s.Text, s.Empty)
+				}
+			}
+		})
+	}
+}
+
+func TestSplitPositions(t *testing.T) {
+	input := "SELECT 1; SELECT 2"
+	segs := Split(input)
+	if len(segs) != 2 {
+		t.Fatalf("expected 2 segments, got %d", len(segs))
+	}
+	if segs[0].ByteStart != 0 || segs[0].ByteEnd != 8 {
+		t.Errorf("seg 0: ByteStart=%d ByteEnd=%d, want 0..8", segs[0].ByteStart, segs[0].ByteEnd)
+	}
+	if segs[0].Text != "SELECT 1" {
+		t.Errorf("seg 0: Text=%q, want %q", segs[0].Text, "SELECT 1")
+	}
+	if segs[1].ByteStart != 10 || segs[1].ByteEnd != 18 {
+		t.Errorf("seg 1: ByteStart=%d ByteEnd=%d, want 10..18", segs[1].ByteStart, segs[1].ByteEnd)
+	}
+}
diff --git a/cassandra/testdata/cql/examples/alterKeyspace.cql b/cassandra/testdata/cql/examples/alterKeyspace.cql
new file mode 100644
index 00000000..93f7780b
--- /dev/null
+++ b/cassandra/testdata/cql/examples/alterKeyspace.cql
@@ -0,0 +1,6 @@
+ALTER KEYSPACE cycling
+WITH REPLICATION = {
+ 'class' : 'NetworkTopologyStrategy',
+ 'datacenter1' : 3 }
+ AND DURABLE_WRITES = false ;
+
\ No newline at end of file
diff --git a/cassandra/testdata/cql/examples/alterMaterializedView.cql b/cassandra/testdata/cql/examples/alterMaterializedView.cql
new file mode 100644
index 00000000..539e7553
--- /dev/null
+++ b/cassandra/testdata/cql/examples/alterMaterializedView.cql
@@ -0,0 +1,3 @@
+ALTER MATERIALIZED VIEW cycling.cyclist_by_age
+WITH comment = 'A most excellent and useful view'
+AND bloom_filter_fp_chance = 0.02;
diff --git a/cassandra/testdata/cql/examples/alterRole.cql b/cassandra/testdata/cql/examples/alterRole.cql
new file mode 100644
index 00000000..b2018603
--- /dev/null
+++ b/cassandra/testdata/cql/examples/alterRole.cql
@@ -0,0 +1 @@
+ALTER ROLE coach WITH PASSWORD='bestTeam';
diff --git a/cassandra/testdata/cql/examples/alterTable.cql b/cassandra/testdata/cql/examples/alterTable.cql
new file mode 100644
index 00000000..85bbe5fb
--- /dev/null
+++
b/cassandra/testdata/cql/examples/alterTable.cql @@ -0,0 +1,5 @@ +ALTER TABLE cycling_comments +WITH compression = { + 'sstable_compression' : 'DeflateCompressor', + 'chunk_length_kb' : 64 }; + diff --git a/cassandra/testdata/cql/examples/alterType.cql b/cassandra/testdata/cql/examples/alterType.cql new file mode 100644 index 00000000..ae99ca5e --- /dev/null +++ b/cassandra/testdata/cql/examples/alterType.cql @@ -0,0 +1,4 @@ +ALTER TYPE cycling.fullname +RENAME middlename TO middle +AND lastname to last +AND firstname to first; diff --git a/cassandra/testdata/cql/examples/alterUser.cql b/cassandra/testdata/cql/examples/alterUser.cql new file mode 100644 index 00000000..c2e8eaf0 --- /dev/null +++ b/cassandra/testdata/cql/examples/alterUser.cql @@ -0,0 +1 @@ +ALTER USER moss WITH PASSWORD 'bestReceiver'; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/applyBatch.cql b/cassandra/testdata/cql/examples/applyBatch.cql new file mode 100644 index 00000000..189b6014 --- /dev/null +++ b/cassandra/testdata/cql/examples/applyBatch.cql @@ -0,0 +1 @@ +APPLY BATCH; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createAggregate.cql b/cassandra/testdata/cql/examples/createAggregate.cql new file mode 100644 index 00000000..31ee5d43 --- /dev/null +++ b/cassandra/testdata/cql/examples/createAggregate.cql @@ -0,0 +1,5 @@ +CREATE AGGREGATE cycling.average(int) +SFUNC avgState +STYPE tuple +FINALFUNC avgFinal +INITCOND (0,0); \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createFunction.cql b/cassandra/testdata/cql/examples/createFunction.cql new file mode 100644 index 00000000..7ace8ea9 --- /dev/null +++ b/cassandra/testdata/cql/examples/createFunction.cql @@ -0,0 +1,16 @@ +CREATE OR REPLACE FUNCTION cycling.avgFinal ( state tuple ) +CALLED ON NULL INPUT +RETURNS double +LANGUAGE java AS + $$ double r = 0; + if (state.getInt(0) == 0) return null; + r = state.getLong(1); + r/= state.getInt(0); + return 
Double.valueOf(r); $$ +; + +CREATE OR REPLACE FUNCTION setMin(input set) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS ' + int min = Integer.MAX_VALUE; + for (Object i : input) { min = Math.min(min, (Integer) i); } + return min; + '; diff --git a/cassandra/testdata/cql/examples/createIndex.cql b/cassandra/testdata/cql/examples/createIndex.cql new file mode 100644 index 00000000..eaf745b5 --- /dev/null +++ b/cassandra/testdata/cql/examples/createIndex.cql @@ -0,0 +1,4 @@ + CREATE INDEX user_state + ON myschema.users (state); + +CREATE INDEX ON myschema.users (zip); \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createIndexSAI.cql b/cassandra/testdata/cql/examples/createIndexSAI.cql new file mode 100644 index 00000000..c35f61f3 --- /dev/null +++ b/cassandra/testdata/cql/examples/createIndexSAI.cql @@ -0,0 +1,15 @@ +-- Storage-Attached Index (SAI) examples + +-- Basic SAI index +CREATE INDEX ON products (name) USING 'sai'; + +-- Custom SAI index with explicit name +CREATE CUSTOM INDEX product_name_idx ON products (name) USING 'StorageAttachedIndex'; + +-- Vector index for similarity search +CREATE CUSTOM INDEX embedding_idx ON products (embedding) + USING 'StorageAttachedIndex' + WITH {'similarity_function': 'cosine'}; + +-- Collection indexes with SAI (entire collection) +CREATE INDEX ON mytable (mymap) USING 'sai'; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createKeyspace.cql b/cassandra/testdata/cql/examples/createKeyspace.cql new file mode 100644 index 00000000..759cba23 --- /dev/null +++ b/cassandra/testdata/cql/examples/createKeyspace.cql @@ -0,0 +1,6 @@ +CREATE KEYSPACE cycling + WITH REPLICATION = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }; + \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createMaterializedView.cql b/cassandra/testdata/cql/examples/createMaterializedView.cql new file mode 100644 index 00000000..b4d16200 --- /dev/null +++ 
b/cassandra/testdata/cql/examples/createMaterializedView.cql @@ -0,0 +1,7 @@ +CREATE MATERIALIZED VIEW cycling.cyclist_by_age +AS SELECT age, name, country +FROM cycling.cyclist_mv +WHERE age IS NOT NULL AND cid IS NOT NULL +PRIMARY KEY (age, cid) +WITH caching = { 'keys' : 'ALL', 'rows_per_partition' : '100' } + AND comment = 'Based on table cyclist' ; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createRole.cql b/cassandra/testdata/cql/examples/createRole.cql new file mode 100644 index 00000000..590783c1 --- /dev/null +++ b/cassandra/testdata/cql/examples/createRole.cql @@ -0,0 +1,3 @@ +CREATE ROLE coach +WITH PASSWORD = 'All4One2day!' +AND LOGIN = true; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createTable.cql b/cassandra/testdata/cql/examples/createTable.cql new file mode 100644 index 00000000..ab98b7cf --- /dev/null +++ b/cassandra/testdata/cql/examples/createTable.cql @@ -0,0 +1,5 @@ +CREATE TABLE cycling.race_winners ( + race_name text, + race_position int, + cyclist_name FROZEN, + PRIMARY KEY (race_name, race_position)); diff --git a/cassandra/testdata/cql/examples/createTableVector.cql b/cassandra/testdata/cql/examples/createTableVector.cql new file mode 100644 index 00000000..1ab31dcb --- /dev/null +++ b/cassandra/testdata/cql/examples/createTableVector.cql @@ -0,0 +1,16 @@ +-- Create table with VECTOR data type (Cassandra 5.0+) +CREATE TABLE products ( + id uuid PRIMARY KEY, + name text, + description text, + embedding VECTOR, + price decimal +); + +-- Create table with DURATION type +CREATE TABLE events ( + id uuid PRIMARY KEY, + name text, + event_duration DURATION, + created_at timestamp +); \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createTrigger.cql b/cassandra/testdata/cql/examples/createTrigger.cql new file mode 100644 index 00000000..af59c5f2 --- /dev/null +++ b/cassandra/testdata/cql/examples/createTrigger.cql @@ -0,0 +1 @@ +DROP TRIGGER trigger_name ON 
table_name; diff --git a/cassandra/testdata/cql/examples/createType.cql b/cassandra/testdata/cql/examples/createType.cql new file mode 100644 index 00000000..01379259 --- /dev/null +++ b/cassandra/testdata/cql/examples/createType.cql @@ -0,0 +1,6 @@ +CREATE TYPE cycling.basic_info ( + birthday timestamp, + nationality text, + weight text, + height text +); \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/createUser.cql b/cassandra/testdata/cql/examples/createUser.cql new file mode 100644 index 00000000..b6c6af3c --- /dev/null +++ b/cassandra/testdata/cql/examples/createUser.cql @@ -0,0 +1 @@ +CREATE USER newuser WITH PASSWORD 'password'; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/delete.cql b/cassandra/testdata/cql/examples/delete.cql new file mode 100644 index 00000000..8345ac58 --- /dev/null +++ b/cassandra/testdata/cql/examples/delete.cql @@ -0,0 +1,8 @@ +DELETE firstname, lastname FROM cycling.cyclist_name +WHERE id = e7ae5cf3-d358-4d99-b900-85902fda9bb0; +DELETE FROM cycling.cyclist_name +WHERE id =e7ae5cf3-d358-4d99-b900-85902fda9bb0 +if firstname='Alex' and lastname='Smith'; +DELETE id FROM cyclist_id +WHERE lastname = 'WELTEN' and firstname = 'Bram' +IF EXISTS; diff --git a/cassandra/testdata/cql/examples/dropAggregate.cql b/cassandra/testdata/cql/examples/dropAggregate.cql new file mode 100644 index 00000000..0392a206 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropAggregate.cql @@ -0,0 +1 @@ +DROP AGGREGATE IF EXISTS cycling.avgState; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/dropFunction.cql b/cassandra/testdata/cql/examples/dropFunction.cql new file mode 100644 index 00000000..93605fc1 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropFunction.cql @@ -0,0 +1 @@ +DROP FUNCTION IF EXISTS cycling.fLog; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/dropIndex.cql b/cassandra/testdata/cql/examples/dropIndex.cql new file mode 100644 
index 00000000..7a44cada --- /dev/null +++ b/cassandra/testdata/cql/examples/dropIndex.cql @@ -0,0 +1 @@ +DROP INDEX cycling.ryear; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/dropKeyspace.cql b/cassandra/testdata/cql/examples/dropKeyspace.cql new file mode 100644 index 00000000..e86c66c5 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropKeyspace.cql @@ -0,0 +1 @@ +DROP KEYSPACE cycling; diff --git a/cassandra/testdata/cql/examples/dropMaterializedView.cql b/cassandra/testdata/cql/examples/dropMaterializedView.cql new file mode 100644 index 00000000..9ad58b08 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropMaterializedView.cql @@ -0,0 +1 @@ +DROP MATERIALIZED VIEW cycling.cyclist_by_age; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/dropRole.cql b/cassandra/testdata/cql/examples/dropRole.cql new file mode 100644 index 00000000..9c48b1d0 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropRole.cql @@ -0,0 +1 @@ +DROP ROLE IF EXISTS team_manager; diff --git a/cassandra/testdata/cql/examples/dropTable.cql b/cassandra/testdata/cql/examples/dropTable.cql new file mode 100644 index 00000000..49059e64 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropTable.cql @@ -0,0 +1 @@ +DROP TABLE cycling.cyclist_name; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/dropTrigger.cql b/cassandra/testdata/cql/examples/dropTrigger.cql new file mode 100644 index 00000000..df059573 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropTrigger.cql @@ -0,0 +1 @@ +DROP TRIGGER trigger_name ON ks.table_name; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/dropType.cql b/cassandra/testdata/cql/examples/dropType.cql new file mode 100644 index 00000000..931272c8 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropType.cql @@ -0,0 +1 @@ +DROP TYPE cycling.basic_info ; diff --git a/cassandra/testdata/cql/examples/dropUser.cql 
b/cassandra/testdata/cql/examples/dropUser.cql new file mode 100644 index 00000000..26a82e69 --- /dev/null +++ b/cassandra/testdata/cql/examples/dropUser.cql @@ -0,0 +1 @@ +DROP USER IF EXISTS boone; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/grant.cql b/cassandra/testdata/cql/examples/grant.cql new file mode 100644 index 00000000..9f1f0154 --- /dev/null +++ b/cassandra/testdata/cql/examples/grant.cql @@ -0,0 +1,10 @@ +GRANT SELECT ON ALL KEYSPACES TO coach; +GRANT MODIFY ON KEYSPACE field TO manager; +GRANT ALTER ON KEYSPACE cycling TO coach; +GRANT ALL PERMISSIONS ON cycling.name TO coach; +GRANT ALL ON KEYSPACE cycling TO cycling_admin; + + + + + diff --git a/cassandra/testdata/cql/examples/insert.cql b/cassandra/testdata/cql/examples/insert.cql new file mode 100644 index 00000000..5bb20a9f --- /dev/null +++ b/cassandra/testdata/cql/examples/insert.cql @@ -0,0 +1,25 @@ +INSERT INTO cycling.cyclist_name (id, lastname, firstname) + VALUES (6ab09bec-e68e-48d9-a5f8-97e6fb4c9b47, 'KRUIKSWIJK','Steven') + USING TTL 86400 AND TIMESTAMP 123456789; + + + INSERT INTO cycling.cyclist_categories (id,lastname,categories) + VALUES( + '6ab09bec-e68e-48d9-a5f8-97e6fb4c9b47', + 'KRUIJSWIJK', + {'GC', 'Time-trial', 'Sprint'}); + +INSERT INTO cycling.cyclist_categories JSON + '{"category": "", "points":780, "id": "6ab09bec-e68e-48d9-a5f8-97e6fb4c9b47"}'; + +INSERT INTO cycling.cyclist_teams (id,lastname,teams) + VALUES(5b6962dd-3f90-4c93-8f61-eabfa4a803e2,'VOS',$$Women's Tour of New Zealand$$); + +INSERT INTO cycling.route(race_id,race_name,lat_long) VALUES (500, 'Name', ('Champagne', (46.833,6.65))); + +INSERT INTO "students"("id", "address", "name", "[age]", "age", "colu'mn1", "colu68mn1", "height") values + (740, 'hongkong','alice',null,32,'','',172); + +INSERT INTO test(a,b,c,d) values(1,['listtext1','listtext2'],{'settext1','settext2'},{'mapkey1':'mapvale2','mapkey2':'mapvalue2'}); + +INSERT INTO PERSON (id, name) VALUES (uuid(), 'fullname'); 
diff --git a/cassandra/testdata/cql/examples/insertFunctions.cql b/cassandra/testdata/cql/examples/insertFunctions.cql new file mode 100644 index 00000000..6629edb9 --- /dev/null +++ b/cassandra/testdata/cql/examples/insertFunctions.cql @@ -0,0 +1,13 @@ +-- INSERT statements with CQL functions + +-- UUID generation function +INSERT INTO events (id, created_at) VALUES (uuid(), now()); + +-- Current time UUID and timestamp functions +INSERT INTO events (id, created_at) VALUES (currentTimeuuid(), currentTimestamp()); + +-- Date/time component functions +INSERT INTO logs (id, log_date, log_time) VALUES (uuid(), currentDate(), currentTime()); + +-- fromJson function for complex data +INSERT INTO users (id, profile) VALUES (uuid(), fromJson('{"location": "NYC", "interests": ["music", "sports"]}')); \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/insertJson.cql b/cassandra/testdata/cql/examples/insertJson.cql new file mode 100644 index 00000000..d5a64c1e --- /dev/null +++ b/cassandra/testdata/cql/examples/insertJson.cql @@ -0,0 +1,10 @@ +-- INSERT statements with JSON support + +-- Basic JSON insert +INSERT INTO users JSON '{"id": "123e4567-e89b-12d3-a456-426614174000", "name": "John Doe", "age": 30}'; + +-- JSON insert with DEFAULT UNSET clause +INSERT INTO users JSON '{"id": "123e4567-e89b-12d3-a456-426614174000", "name": "Jane Doe"}' DEFAULT UNSET; + +-- fromJson function in VALUES clause +INSERT INTO users (id, profile) VALUES (uuid(), fromJson('{"location": "NYC", "interests": ["music", "sports"]}')); \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/listPermissions.cql b/cassandra/testdata/cql/examples/listPermissions.cql new file mode 100644 index 00000000..b0269b35 --- /dev/null +++ b/cassandra/testdata/cql/examples/listPermissions.cql @@ -0,0 +1,5 @@ +LIST ALL +OF coach; +LIST ALL; +LIST ALL +ON cyclist.name; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/listRoles.cql 
b/cassandra/testdata/cql/examples/listRoles.cql new file mode 100644 index 00000000..0e0f7129 --- /dev/null +++ b/cassandra/testdata/cql/examples/listRoles.cql @@ -0,0 +1,3 @@ +LIST ROLES; +LIST ROLES +OF manager; diff --git a/cassandra/testdata/cql/examples/revoke.cql b/cassandra/testdata/cql/examples/revoke.cql new file mode 100644 index 00000000..fdd3b4b8 --- /dev/null +++ b/cassandra/testdata/cql/examples/revoke.cql @@ -0,0 +1,6 @@ +REVOKE SELECT +ON cycling.name +FROM manager; +REVOKE ALTER +ON ALL ROLES +FROM coach; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/select.cql b/cassandra/testdata/cql/examples/select.cql new file mode 100644 index 00000000..11d24922 --- /dev/null +++ b/cassandra/testdata/cql/examples/select.cql @@ -0,0 +1,41 @@ +SELECT event_id, + dateOf(created_at) AS creation_date, + blobAsText(content) AS content + FROM timeline; + + SELECT COUNT(*) +FROM system.IndexInfo; + +SELECT lastname +FROM cycling.cyclist_name +LIMIT 50000; + +SELECT id, lastname, teams +FROM cycling.cyclist_career_teams +WHERE id=5b6962dd-3f90-4c93-8f61-eabfa4a803e2; + + +SELECT * FROM cycling.cyclist_category; + +SELECT * FROM cycling.cyclist_category WHERE category = 'SPRINT'; + +SELECT category, points, lastname FROM cycling.cyclist_category; + +SELECT * From cycling.cyclist_name LIMIT 3; + +SELECT * FROM cycling.cyclist_cat_pts WHERE category = 'GC' ORDER BY points ASC; + +SELECT race_name, point_id, lat_long AS CITY_LATITUDE_LONGITUDE FROM cycling.route; + +SELECT * FROM cycling.upcoming_calendar WHERE year = 2015 AND month = 06; + +select json name, checkin_id, time_stamp from checkin; + +select name, checkin_id, toJson(time_stamp) from checkin; + +SELECT * FROM cycling.calendar WHERE race_id IN (100, 101, 102) AND (race_start_date, race_end_date) IN (('2015-01-01','2015-02-02'), ('2016-01-01','2016-02-02')); + +SELECT * FROM cycling.calendar WHERE race_id IN (100, 101, 102) AND (race_start, race_end) >= ('2015-01-01', '2015-02-02'); + 
+SELECT * FROM cycling.race_times WHERE race_name = '17th Santos Tour Down Under' and race_time >= '19:15:19' AND race_time <= '19:15:39'; + diff --git a/cassandra/testdata/cql/examples/selectFunctions.cql b/cassandra/testdata/cql/examples/selectFunctions.cql new file mode 100644 index 00000000..0d452332 --- /dev/null +++ b/cassandra/testdata/cql/examples/selectFunctions.cql @@ -0,0 +1,16 @@ +-- SELECT statements with CQL functions + +-- toJson function for JSON output +SELECT id, toJson(profile) FROM users WHERE age > 25; + +-- JSON format selection (all columns) +SELECT JSON * FROM users WHERE id = 123e4567-e89b-12d3-a456-426614174000; + +-- JSON format selection (specific columns) +SELECT JSON id, name, age FROM users LIMIT 10; + +-- dateOf function to extract timestamp from timeuuid +SELECT id, dateOf(created_id) FROM events WHERE id = 550e8400-e29b-41d4-a716-446655440000; + +-- unixTimestampOf function to get UNIX timestamp +SELECT id, unixTimestampOf(created_id) FROM events LIMIT 5; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/selectJson.cql b/cassandra/testdata/cql/examples/selectJson.cql new file mode 100644 index 00000000..9dd97894 --- /dev/null +++ b/cassandra/testdata/cql/examples/selectJson.cql @@ -0,0 +1,10 @@ +-- SELECT statements with JSON output + +-- Select all columns as JSON +SELECT JSON * FROM users WHERE id = 123e4567-e89b-12d3-a456-426614174000; + +-- Select specific columns as JSON +SELECT JSON id, name, age FROM users LIMIT 10; + +-- toJson function for specific column +SELECT id, toJson(profile) FROM users WHERE age > 25; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/selectVectorANN.cql b/cassandra/testdata/cql/examples/selectVectorANN.cql new file mode 100644 index 00000000..3c94e826 --- /dev/null +++ b/cassandra/testdata/cql/examples/selectVectorANN.cql @@ -0,0 +1,17 @@ +-- Vector similarity search using ANN (Approximate Nearest Neighbor) + +-- Basic vector search with ORDER BY ANN 
+SELECT * FROM products +WHERE category = 'electronics' +ORDER BY embedding ANN OF [0.1, 0.2, 0.3, 0.4, 0.5] +LIMIT 10; + +-- Vector search without WHERE clause +SELECT id, name, price FROM products +ORDER BY embedding ANN OF [0.25, 0.5, 0.75] +LIMIT 5; + +-- Combined with other ORDER BY (if supported) +SELECT * FROM products +WHERE price < 1000 +ORDER BY embedding ANN OF [0.1, 0.2, 0.3] LIMIT 20; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/tableOptions.cql b/cassandra/testdata/cql/examples/tableOptions.cql new file mode 100644 index 00000000..6c4024ec --- /dev/null +++ b/cassandra/testdata/cql/examples/tableOptions.cql @@ -0,0 +1,9 @@ +CREATE TABLE alphabet ( + a int, + b int, + c int, + PRIMARY KEY ((a), b, c) +) +WITH COMPACT STORAGE +AND CLUSTERING ORDER BY (b ASC, c DESC) +AND default_time_to_live = 300; diff --git a/cassandra/testdata/cql/examples/truncate.cql b/cassandra/testdata/cql/examples/truncate.cql new file mode 100644 index 00000000..5cd74667 --- /dev/null +++ b/cassandra/testdata/cql/examples/truncate.cql @@ -0,0 +1,2 @@ +TRUNCATE cycling.user_activity; +TRUNCATE TABLE cycling.user_activity; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/update.cql b/cassandra/testdata/cql/examples/update.cql new file mode 100644 index 00000000..ff06f88f --- /dev/null +++ b/cassandra/testdata/cql/examples/update.cql @@ -0,0 +1,29 @@ + UPDATE cycling.cyclist_name + SET comments ='Rides hard, gets along with others, a real winner' + WHERE id = fb372533-eb95-4bb4-8685-6ef61e994caa IF EXISTS; + + UPDATE cycling.cyclists + SET firstname = 'Marianne', + lastname = 'VOS' + WHERE id = 88b8fd18-b1ed-4e96-bf79-4280797cba80; + UPDATE cycling.cyclists + SET firstname = 'Anna', lastname = 'VAN DER BREGGEN' WHERE id = e7cd5752-bc0d-4157-a80f-7523add8dbcd; + + + UPDATE cycling.upcoming_calendar + SET events = ['Tour de France'] + events WHERE year=2015 AND month=06; + + + + UPDATE users + SET state = 'TX' + WHERE user_uuid + 
IN (88b8fd18-b1ed-4e96-bf79-4280797cba80, + 06a8913c-c0d6-477c-937d-6c1b69a95d43, + bc108776-7cb5-477f-917d-869c12dfffa8); + + UPDATE cyclist.cyclist_career_teams SET teams = teams + {'Team DSB - Ballast Nedam'} WHERE id = 88b8fd18-b1ed-4e96-bf79-4280797cba80; + + UPDATE cyclist.cyclist_career_teams SET teams = teams - {'WOMBATS'} WHERE id = 88b8fd18-b1ed-4e96-bf79-4280797cba80; + + UPDATE cyclist.cyclist_career_teams SET teams = {} WHERE id = 88b8fd18-b1ed-4e96-bf79-4280797cba80; \ No newline at end of file diff --git a/cassandra/testdata/cql/examples/use.cql b/cassandra/testdata/cql/examples/use.cql new file mode 100644 index 00000000..8c066c96 --- /dev/null +++ b/cassandra/testdata/cql/examples/use.cql @@ -0,0 +1,3 @@ +USE key_name; +USE PortfolioDemo; +USE "Excalibur"; diff --git a/cassandra/walk_coverage_test.go b/cassandra/walk_coverage_test.go new file mode 100644 index 00000000..10d7d3b8 --- /dev/null +++ b/cassandra/walk_coverage_test.go @@ -0,0 +1,331 @@ +package cassandra + +import ( + "reflect" + "testing" + + "github.com/bytebase/omni/cassandra/ast" + "github.com/bytebase/omni/cassandra/parser" +) + +func TestWalkCoversAllChildren(t *testing.T) { + tests := []struct { + name string + sql string + }{ + {"CREATE TABLE complex", "CREATE TABLE t (id int, name text, age int, PRIMARY KEY ((id, name), age)) WITH CLUSTERING ORDER BY (age DESC) AND comment = 'test'"}, + {"CREATE MV", "CREATE MATERIALIZED VIEW mv AS SELECT col1, col2 FROM t WHERE col1 IS NOT NULL AND col2 IS NOT NULL PRIMARY KEY (col1, col2)"}, + {"CREATE ROLE WITH", "CREATE ROLE myrole WITH PASSWORD = 'secret' AND LOGIN = true AND SUPERUSER = false"}, + {"SELECT complex", "SELECT DISTINCT name AS n FROM ks.users WHERE id = 1 ORDER BY name ASC LIMIT 10 ALLOW FILTERING"}, + {"UPDATE with IF", "UPDATE users USING TTL 3600 SET name = 'Bob' WHERE id = 2 IF name = 'old'"}, + {"INSERT JSON", "INSERT INTO users JSON '{\"id\": 1}' DEFAULT UNSET IF NOT EXISTS USING TTL 86400"}, + {"DELETE complex", 
"DELETE name FROM ks.users WHERE id = 2 IF EXISTS"}, + {"BATCH", "BEGIN UNLOGGED BATCH USING TIMESTAMP 12345 INSERT INTO t (id) VALUES (1); DELETE FROM t WHERE id = 2; APPLY BATCH"}, + {"GRANT", "GRANT SELECT ON TABLE users TO reader"}, + {"CREATE KEYSPACE", "CREATE KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND DURABLE_WRITES = true"}, + {"CREATE INDEX", "CREATE INDEX idx ON users (name)"}, + {"CREATE TYPE", "CREATE TYPE address (street text, city text, zip int)"}, + {"ALTER TYPE RENAME", "ALTER TYPE address RENAME street TO road AND city TO town"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list, err := parser.Parse(tt.sql) + if err != nil { + t.Fatal(err) + } + + reflectTypes := make(map[string]bool) + for _, item := range list.Items { + raw := item.(*ast.RawStmt) + collectNodeTypes(reflect.ValueOf(raw.Stmt), reflectTypes) + } + + walkTypes := make(map[string]bool) + for _, item := range list.Items { + raw := item.(*ast.RawStmt) + ast.Inspect(raw.Stmt, func(n ast.Node) bool { + if n == nil { + return false + } + walkTypes[reflect.TypeOf(n).Elem().Name()] = true + return true + }) + } + + for typeName := range reflectTypes { + if !walkTypes[typeName] { + t.Errorf("reflection found Node type %s but ast.Walk did not visit it", typeName) + } + } + }) + } +} + +var nodeIface = reflect.TypeOf((*ast.Node)(nil)).Elem() + +func collectNodeTypes(v reflect.Value, types map[string]bool) { + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return + } + if v.Type().Implements(nodeIface) { + types[v.Type().Elem().Name()] = true + } + collectNodeTypes(v.Elem(), types) + case reflect.Interface: + if v.IsNil() { + return + } + elem := v.Elem() + if elem.Type().Implements(nodeIface) { + typeName := elem.Type().Elem().Name() + types[typeName] = true + } + collectNodeTypes(elem, types) + case reflect.Struct: + t := v.Type() + for i := range t.NumField() { + f := t.Field(i) + if !f.IsExported() || f.Name 
== "Loc" { + continue + } + collectNodeTypes(v.Field(i), types) + } + case reflect.Slice: + for i := range v.Len() { + collectNodeTypes(v.Index(i), types) + } + } +} + +func TestWalkChildrenCoverage(t *testing.T) { + nodeTypes := collectAllNodeStructTypes() + + walkCases := make(map[string]bool) + for _, sql := range []string{ + "SELECT * FROM users", + "SELECT name AS n FROM ks.users WHERE id = 1 ORDER BY name ASC LIMIT 10 ALLOW FILTERING", + "SELECT count(*) FROM users WHERE id IN (1, 2, 3)", + "SELECT * FROM users WHERE tags CONTAINS 'admin'", + "SELECT * FROM users WHERE tags CONTAINS KEY 'role'", + "SELECT * FROM users WHERE (a, b) = (1, 2)", + "SELECT * FROM users WHERE (a, b) IN ((1, 2), (3, 4))", + "INSERT INTO t (id) VALUES (1)", + "INSERT INTO t (id, name) VALUES (1, 'Alice') USING TTL 86400 AND TIMESTAMP 12345", + "INSERT INTO t (id, f, b, n) VALUES (1, 3.14, true, null)", + "INSERT INTO t (id, u) VALUES (1, 550e8400-e29b-41d4-a716-446655440000)", + "INSERT INTO t (id, h) VALUES (1, 0xDEADBEEF)", + "INSERT INTO t (id, m) VALUES (1, {'key': 'val'})", + "INSERT INTO t (id, s) VALUES (1, {'a', 'b'})", + "INSERT INTO t (id, l) VALUES (1, ['x', 'y'])", + "INSERT INTO t (id, t2) VALUES (1, (1, 'a', true))", + "UPDATE t SET x = 1 WHERE id = 1", + "UPDATE t SET x = 1 WHERE id = 1 IF x = 0", + "UPDATE t SET m['key'] = 'val' WHERE id = 1", + "DELETE FROM t WHERE id = 1", + "DELETE FROM t WHERE id = 1 IF EXISTS", + "BEGIN BATCH INSERT INTO t (id) VALUES (1); APPLY BATCH", + "TRUNCATE t", + "USE ks", + "SELECT token(id) FROM t WHERE token(id) > token(1)", + "CREATE KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND DURABLE_WRITES = true", + "ALTER KEYSPACE ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'}", + "DROP KEYSPACE ks", + "CREATE TABLE t (id int, name text, age int, PRIMARY KEY ((id, name), age)) WITH CLUSTERING ORDER BY (age DESC) AND comment = 'test'", + "ALTER TABLE t ADD col text", + 
"DROP TABLE t", + "CREATE INDEX ON t (col)", + "DROP INDEX idx", + "CREATE TYPE mytype (f1 text)", + "ALTER TYPE mytype ADD f2 int", + "ALTER TYPE mytype RENAME f1 TO field1 AND f2 TO field2", + "DROP TYPE mytype", + "CREATE MATERIALIZED VIEW mv AS SELECT * FROM t WHERE id IS NOT NULL PRIMARY KEY (id)", + "ALTER MATERIALIZED VIEW mv WITH comment = 'test'", + "DROP MATERIALIZED VIEW mv", + "CREATE FUNCTION ks.f(input text) CALLED ON NULL INPUT RETURNS text LANGUAGE java AS $$return input;$$", + "DROP FUNCTION IF EXISTS ks.f", + "CREATE AGGREGATE ks.agg(int) SFUNC plus STYPE int FINALFUNC fin INITCOND 0", + "DROP AGGREGATE IF EXISTS ks.agg", + "CREATE TRIGGER tr ON t USING 'org.example.Trigger'", + "DROP TRIGGER tr ON t", + "GRANT SELECT ON TABLE t TO r", + "REVOKE ALL ON ALL KEYSPACES FROM r", + "LIST ALL PERMISSIONS", + "LIST ROLES", + "CREATE ROLE r", + "ALTER ROLE r WITH PASSWORD = 'x'", + "DROP ROLE r", + "CREATE USER u WITH PASSWORD 'x'", + "ALTER USER u WITH PASSWORD 'y'", + "DROP USER u", + "SELECT ks.* FROM ks.t", + "SELECT * FROM t ORDER BY v ANN OF [1.0, 2.0, 3.0] LIMIT 5", + "SELECT * FROM t GROUP BY a, b", + "SELECT * FROM t PER PARTITION LIMIT 5 LIMIT 100", + "GRANT myrole TO user1", + "REVOKE myrole FROM user1", + "BEGIN COUNTER BATCH UPDATE t SET c = c + 1 WHERE id = 1; APPLY BATCH", + "CREATE INDEX ON t (KEYS(m))", + "CREATE INDEX ON t (VALUES(m))", + "CREATE INDEX ON t (ENTRIES(m))", + "CREATE AGGREGATE ks.agg2(int) SFUNC plus STYPE int", + "ALTER KEYSPACE IF EXISTS ks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': '1'}", + "ALTER TABLE IF EXISTS t ADD IF NOT EXISTS col text", + "ALTER TABLE t DROP IF EXISTS col", + "ALTER TABLE t RENAME IF EXISTS a TO b AND c TO d", + "ALTER TYPE IF EXISTS mytype ADD IF NOT EXISTS f2 int", + "ALTER TYPE IF EXISTS mytype RENAME IF EXISTS f1 TO field1", + "ALTER MATERIALIZED VIEW IF EXISTS mv WITH comment = 'test'", + "ALTER ROLE IF EXISTS r WITH PASSWORD = 'x'", + "ALTER USER IF EXISTS u WITH 
PASSWORD 'y'", + "SELECT CAST(col AS int) FROM t", + "INSERT INTO t (id) VALUES (?)", + "SELECT * FROM t WHERE id = :myid", + "INSERT INTO t (id) VALUES (NaN)", + "INSERT INTO t (id) VALUES (Infinity)", + "INSERT INTO t (id) VALUES (-Infinity)", + "DROP FUNCTION IF EXISTS ks.f(int, text)", + "DROP AGGREGATE IF EXISTS ks.agg(int)", + "UPDATE t SET x = 1 WHERE id = 1 IF x IN (1, 2, 3)", + "UPDATE t SET x = 1 WHERE id = 1 IF tags CONTAINS 'admin'", + "UPDATE t SET x = 1 WHERE id = 1 IF tags CONTAINS KEY 'role'", + "INSERT INTO t JSON '{\"id\": 1}' DEFAULT NULL", + "CREATE ROLE r WITH HASHED PASSWORD = 'hash'", + "CREATE ROLE r WITH ACCESS TO DATACENTERS {'dc1', 'dc2'}", + "CREATE ROLE r WITH ACCESS TO ALL DATACENTERS", + "GRANT SELECT ON MBEAN 'org.example:type=Foo' TO r", + "GRANT SELECT ON ALL MBEANS TO r", + "GRANT EXECUTE ON FUNCTION ks.f(int, text) TO r", + "UPDATE t SET addr.street = 'Main St' WHERE id = 1", + "DELETE addr.street FROM t WHERE id = 1", + } { + list, err := parser.Parse(sql) + if err != nil { + continue + } + for _, item := range list.Items { + raw := item.(*ast.RawStmt) + ast.Inspect(raw.Stmt, func(n ast.Node) bool { + if n == nil { + return false + } + walkCases[reflect.TypeOf(n).Elem().Name()] = true + return true + }) + } + } + + var missing []string + for _, nt := range nodeTypes { + if nt == "List" || nt == "RawStmt" { + continue + } + if !walkCases[nt] { + missing = append(missing, nt) + } + } + if len(missing) > 0 { + t.Errorf("AST node types not reached by Walk in any test SQL: %v", missing) + t.Log("Add test SQL that exercises these node types, or add cases to walkChildren") + } +} + +func collectAllNodeStructTypes() []string { + nodeType := reflect.TypeOf((*ast.Node)(nil)).Elem() + candidates := []reflect.Type{ + reflect.TypeOf(ast.Identifier{}), + reflect.TypeOf(ast.QualifiedName{}), + reflect.TypeOf(ast.StringLit{}), + reflect.TypeOf(ast.IntegerLit{}), + reflect.TypeOf(ast.FloatLit{}), + reflect.TypeOf(ast.BoolLit{}), + 
reflect.TypeOf(ast.NullLit{}), + reflect.TypeOf(ast.UUIDLit{}), + reflect.TypeOf(ast.HexLit{}), + reflect.TypeOf(ast.CodeBlock{}), + reflect.TypeOf(ast.StarExpr{}), + reflect.TypeOf(ast.CastExpr{}), + reflect.TypeOf(ast.BindMarker{}), + reflect.TypeOf(ast.MapLit{}), + reflect.TypeOf(ast.SetLit{}), + reflect.TypeOf(ast.ListLit{}), + reflect.TypeOf(ast.TupleLit{}), + reflect.TypeOf(ast.VectorLit{}), + reflect.TypeOf(ast.FunctionCall{}), + reflect.TypeOf(ast.BinaryExpr{}), + reflect.TypeOf(ast.InExpr{}), + reflect.TypeOf(ast.ContainsExpr{}), + reflect.TypeOf(ast.TupleCompareExpr{}), + reflect.TypeOf(ast.TupleInExpr{}), + reflect.TypeOf(ast.IndexAccess{}), + reflect.TypeOf(ast.DotAccess{}), + reflect.TypeOf(ast.DataType{}), + reflect.TypeOf(ast.ColumnDef{}), + reflect.TypeOf(ast.PrimaryKeyDef{}), + reflect.TypeOf(ast.ClusteringOrder{}), + reflect.TypeOf(ast.TableOption{}), + reflect.TypeOf(ast.OptionHash{}), + reflect.TypeOf(ast.OptionHashItem{}), + reflect.TypeOf(ast.SelectElement{}), + reflect.TypeOf(ast.AssignmentElement{}), + reflect.TypeOf(ast.IfCondition{}), + reflect.TypeOf(ast.UsingClause{}), + reflect.TypeOf(ast.OrderByElement{}), + reflect.TypeOf(ast.SelectStmt{}), + reflect.TypeOf(ast.InsertStmt{}), + reflect.TypeOf(ast.UpdateStmt{}), + reflect.TypeOf(ast.DeleteStmt{}), + reflect.TypeOf(ast.BatchStmt{}), + reflect.TypeOf(ast.TruncateStmt{}), + reflect.TypeOf(ast.UseStmt{}), + reflect.TypeOf(ast.CreateKeyspaceStmt{}), + reflect.TypeOf(ast.AlterKeyspaceStmt{}), + reflect.TypeOf(ast.DropKeyspaceStmt{}), + reflect.TypeOf(ast.CreateTableStmt{}), + reflect.TypeOf(ast.AlterTableStmt{}), + reflect.TypeOf(ast.DropTableStmt{}), + reflect.TypeOf(ast.CreateIndexStmt{}), + reflect.TypeOf(ast.DropIndexStmt{}), + reflect.TypeOf(ast.CreateTypeStmt{}), + reflect.TypeOf(ast.AlterTypeStmt{}), + reflect.TypeOf(ast.AlterTableRenameItem{}), + reflect.TypeOf(ast.AlterTypeRenameItem{}), + reflect.TypeOf(ast.DropTypeStmt{}), + reflect.TypeOf(ast.CreateMVStmt{}), + 
reflect.TypeOf(ast.AlterMVStmt{}), + reflect.TypeOf(ast.DropMVStmt{}), + reflect.TypeOf(ast.CreateFunctionStmt{}), + reflect.TypeOf(ast.FunctionParam{}), + reflect.TypeOf(ast.DropFunctionStmt{}), + reflect.TypeOf(ast.CreateAggregateStmt{}), + reflect.TypeOf(ast.DropAggregateStmt{}), + reflect.TypeOf(ast.CreateTriggerStmt{}), + reflect.TypeOf(ast.DropTriggerStmt{}), + reflect.TypeOf(ast.CreateRoleStmt{}), + reflect.TypeOf(ast.RoleOption{}), + reflect.TypeOf(ast.AlterRoleStmt{}), + reflect.TypeOf(ast.DropRoleStmt{}), + reflect.TypeOf(ast.CreateUserStmt{}), + reflect.TypeOf(ast.AlterUserStmt{}), + reflect.TypeOf(ast.DropUserStmt{}), + reflect.TypeOf(ast.GrantStmt{}), + reflect.TypeOf(ast.GrantRoleStmt{}), + reflect.TypeOf(ast.RevokeStmt{}), + reflect.TypeOf(ast.RevokeRoleStmt{}), + reflect.TypeOf(ast.Resource{}), + reflect.TypeOf(ast.ListPermissionsStmt{}), + reflect.TypeOf(ast.ListRolesStmt{}), + reflect.TypeOf(ast.List{}), + reflect.TypeOf(ast.RawStmt{}), + } + var result []string + for _, c := range candidates { + ptr := reflect.PointerTo(c) + if ptr.Implements(nodeType) { + result = append(result, c.Name()) + } + } + return result +}