diff --git a/mindsdb_sql_parser/__about__.py b/mindsdb_sql_parser/__about__.py index 58de06c..dd10ca9 100644 --- a/mindsdb_sql_parser/__about__.py +++ b/mindsdb_sql_parser/__about__.py @@ -1,6 +1,6 @@ __title__ = 'mindsdb_sql_parser' __package_name__ = 'mindsdb_sql_parser' -__version__ = '0.13.7' +__version__ = '0.13.8' __description__ = "Mindsdb SQL parser" __email__ = "jorge@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/mindsdb_sql_parser/__init__.py b/mindsdb_sql_parser/__init__.py index 03ecb86..5d60f92 100644 --- a/mindsdb_sql_parser/__init__.py +++ b/mindsdb_sql_parser/__init__.py @@ -25,6 +25,11 @@ def process(self) -> str: # show error location msgs = self.error_location() + if self.bad_token is not None and self.bad_token.value == ';': + # unexpected semicolon in the middle of the query, it might be delimiter of statements + msgs.append('Only a single sql statement is expected. Got multiple instead') + return '\n'.join(msgs) + # suggestion suggestions = self.make_suggestion() @@ -171,11 +176,25 @@ def parse_sql(sql, dialect=None): from mindsdb_sql_parser.parser import MindsDBParser lexer, parser = MindsDBLexer(), MindsDBParser() - # remove ending semicolon and spaces - sql = re.sub(r'[\s;]+$', '', sql) + def semicolon_checker(generator): + """ + Repeat the same elements from generator except trailing SEMICOLON tokens. + They are kept in buffer till any other token appear + """ + + buffer = [] + for token in generator: + if token.type == 'SEMICOLON': + buffer.append(token) + continue + elif len(buffer) > 0: + for buf_token in buffer: + yield buf_token + buffer = [] + yield token tokens = lexer.tokenize(sql) - ast = parser.parse(tokens) + ast = parser.parse(semicolon_checker(tokens)) if ast is None: diff --git a/mindsdb_sql_parser/parser.py b/mindsdb_sql_parser/parser.py index f503f42..f8414ab 100644 --- a/mindsdb_sql_parser/parser.py +++ b/mindsdb_sql_parser/parser.py @@ -43,6 +43,7 @@ class MindsDBParser(Parser): log = ParserLogger() tokens = MindsDBLexer.tokens + start = "query" precedence = ( ('left', OR), diff --git a/tests/test_base_sql/test_base_sql.py b/tests/test_base_sql/test_base_sql.py index 7a3816c..3241a6a 100644 --- a/tests/test_base_sql/test_base_sql.py +++ b/tests/test_base_sql/test_base_sql.py @@ -1,5 +1,9 @@ from textwrap import dedent + +import pytest + from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.exceptions import ParsingException from mindsdb_sql_parser.ast import * @@ -86,3 +90,45 @@ def test_quotes_identifier(self): assert str(ast).lower() == str(expected_ast).lower() assert ast.to_tree() == expected_ast.to_tree() + + def test_multy_statement(self): + sql = """ + select 1; + select 2 + """ + + with pytest.raises(ParsingException) as excinfo: + parse_sql(sql) + + assert "Only a single sql statement is expected" in str(excinfo.value) + + def test_trailing_semicolon(self): + query = parse_sql("select 1;") + assert query == Select(targets=[Constant(1)]) + + def test_comment_after_semicolon(self): + sql = """ + select 1; -- my query + """ + + query = parse_sql(sql) + assert query == Select(targets=[Constant(1)]) + + def test_comment_symbols_in_string(self): + expected_query = Select(targets=[Constant('--x')]) + + query = parse_sql("select '--x'") + assert query == expected_query + + query = parse_sql('select "--x"') + assert query == expected_query + + # multiline + expected_query = Select(targets=[Constant('/* x */')]) + + query = parse_sql("select '/* x */'") + assert query == expected_query + + query = parse_sql('select "/* x */"') + assert query == expected_query +