From a190a53968eadc45b06777d81ca519927e58c0d7 Mon Sep 17 00:00:00 2001 From: simeonreusch Date: Fri, 16 May 2025 14:30:47 +0200 Subject: [PATCH 1/7] start implementing custom user defined functions --- src/queryparser/common/common.py | 520 +++++++++++------- src/queryparser/postgresql/PostgreSQLLexer.g4 | 10 + .../postgresql/PostgreSQLParser.g4 | 2 +- 3 files changed, 345 insertions(+), 187 deletions(-) diff --git a/src/queryparser/common/common.py b/src/queryparser/common/common.py index 0564b61..0aed211 100644 --- a/src/queryparser/common/common.py +++ b/src/queryparser/common/common.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # All listeners that are with minor modifications shared between PostgreSQL # and MySQL. -from __future__ import (absolute_import, print_function) +from __future__ import absolute_import, print_function import logging import re @@ -31,7 +31,7 @@ def parse_alias(alias, quote_char): def process_column_name(column_name_listener, walker, ctx, quote_char): - ''' + """ A helper function that strips the quote characters from the column names. The returned list includes: @@ -44,7 +44,7 @@ def process_column_name(column_name_listener, walker, ctx, quote_char): column_name_listener object :param walker: - antlr walker object + antlr walker object :param ctx: antlr context to walk through @@ -52,7 +52,7 @@ def process_column_name(column_name_listener, walker, ctx, quote_char): :param quote_char: which quote character are we expecting? - ''' + """ cn = [] column_name_listener.column_name = [] walker.walk(column_name_listener, ctx) @@ -60,33 +60,33 @@ def process_column_name(column_name_listener, walker, ctx, quote_char): for i in column_name_listener.column_name: cni = [None, None, None, i] if i.schema_name(): - cni[0] = i.schema_name().getText().replace(quote_char, '') + cni[0] = i.schema_name().getText().replace(quote_char, "") if i.table_name(): - cni[1] = i.table_name().getText().replace(quote_char, '') + cni[1] = i.table_name().getText().replace(quote_char, "") if i.column_name(): - cni[2] = i.column_name().getText().replace(quote_char, '') + cni[2] = i.column_name().getText().replace(quote_char, "") cn.append(cni) else: try: ctx.ASTERISK() ts = ctx.table_spec() - cn = [[None, None, '*', None]] + cn = [[None, None, "*", None]] if ts.schema_name(): - cn[0][0] = ts.schema_name().getText().replace(quote_char, '') + cn[0][0] = ts.schema_name().getText().replace(quote_char, "") if ts.table_name(): - cn[0][1] = ts.table_name().getText().replace(quote_char, '') + cn[0][1] = ts.table_name().getText().replace(quote_char, "") except AttributeError: cn = [[None, None, None, None]] return cn def get_column_name_listener(base): - class ColumnNameListener(base): """ Get all column names. """ + def __init__(self): self.column_name = [] self.column_as_array = [] @@ -105,12 +105,12 @@ def enterColumn_spec(self, ctx): def get_table_name_listener(base, quote_char): - class TableNameListener(base): """ Get table names. """ + def __init__(self): self.table_names = [] self.table_aliases = [] @@ -126,9 +126,7 @@ def enterAlias(self, ctx): def get_schema_name_listener(base, quote_char): - class SchemaNameListener(base): - def __init__(self, replace_schema_name): self.replace_schema_name = replace_schema_name @@ -136,14 +134,18 @@ def enterSchema_name(self, ctx): ttype = ctx.start.type sn = ctx.getTokens(ttype)[0].getSymbol().text try: - nsn = self.replace_schema_name[sn.replace(quote_char, '')] + nsn = self.replace_schema_name[sn.replace(quote_char, "")] try: - nsn = unicode(nsn, 'utf-8') + nsn = unicode(nsn, "utf-8") except NameError: pass - nsn = re.sub(r'(|{})(?!{})[\S]*[^{}](|{})'.format( - quote_char, quote_char, quote_char, quote_char), - r'\1{}\2'.format(nsn), sn) + nsn = re.sub( + r"(|{})(?!{})[\S]*[^{}](|{})".format( + quote_char, quote_char, quote_char, quote_char + ), + r"\1{}\2".format(nsn), + sn, + ) ctx.getTokens(ttype)[0].getSymbol().text = nsn except KeyError: pass @@ -152,42 +154,46 @@ def enterSchema_name(self, ctx): def get_remove_subqueries_listener(base, base_parser): - class RemoveSubqueriesListener(base): """ Remove nested select_expressions. """ + def __init__(self, depth): self.depth = depth def enterSelect_expression(self, ctx): parent = ctx.parentCtx.parentCtx - if isinstance(parent, base_parser.SubqueryContext) and \ - ctx.depth() > self.depth: + if ( + isinstance(parent, base_parser.SubqueryContext) + and ctx.depth() > self.depth + ): # we need to remove all Select_expression instances, not # just the last one so we loop over until we get all of them # out - seinstances = [isinstance(i, - base_parser.Select_expressionContext) - for i in ctx.parentCtx.children] + seinstances = [ + isinstance(i, base_parser.Select_expressionContext) + for i in ctx.parentCtx.children + ] while True in seinstances: ctx.parentCtx.removeLastChild() - seinstances = [isinstance(i, - base_parser.Select_expressionContext) - for i in ctx.parentCtx.children] + seinstances = [ + isinstance(i, base_parser.Select_expressionContext) + for i in ctx.parentCtx.children + ] return RemoveSubqueriesListener def get_query_listener(base, base_parser, quote_char): - class QueryListener(base): """ Extract all select_expressions. """ + def __init__(self): self.select_expressions = [] self.select_list = None @@ -196,7 +202,7 @@ def __init__(self): def enterSelect_statement(self, ctx): if ctx.UNION_SYM(): - self.keywords.append('union') + self.keywords.append("union") def enterSelect_expression(self, ctx): # we need to keep track of unions as they act as subqueries @@ -219,12 +225,12 @@ def enterSelect_list(self, ctx): def get_column_keyword_function_listener(base, quote_char): - class ColumnKeywordFunctionListener(base): """ Extract columns, keywords and functions. """ + def __init__(self): self.tables = [] self.columns = [] @@ -232,8 +238,7 @@ def __init__(self): self.keywords = [] self.functions = [] self.column_name_listener = get_column_name_listener(base)() - self.table_name_listener = get_table_name_listener( - base, quote_char)() + self.table_name_listener = get_table_name_listener(base, quote_char)() self.walker = antlr4.ParseTreeWalker() self.data = [] @@ -247,8 +252,9 @@ def _process_alias(self, ctx): return alias def _extract_column(self, ctx, append=True, join_columns=False): - cn = process_column_name(self.column_name_listener, self.walker, - ctx, quote_char) + cn = process_column_name( + self.column_name_listener, self.walker, ctx, quote_char + ) alias = self._process_alias(ctx) if len(cn) > 1: @@ -288,23 +294,28 @@ def enterTable_atom(self, ctx): if ts: tn = [None, None] if ts.schema_name(): - tn[0] = ts.schema_name().getText().replace(quote_char, '') + tn[0] = ts.schema_name().getText().replace(quote_char, "") if ts.table_name(): - tn[1] = ts.table_name().getText().replace(quote_char, '') + tn[1] = ts.table_name().getText().replace(quote_char, "") self.tables.append((alias, tn, ctx.depth())) - logging.info((ctx.depth(), ctx.__class__.__name__, - [tn, alias])) + logging.info((ctx.depth(), ctx.__class__.__name__, [tn, alias])) self.data.append([ctx.depth(), ctx, [tn, alias]]) def enterDisplayed_column(self, ctx): - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) self._extract_column(ctx) if ctx.ASTERISK(): - self.keywords.append('*') + self.keywords.append("*") def enterSelect_expression(self, ctx): logging.info((ctx.depth(), ctx.__class__.__name__)) @@ -312,12 +323,12 @@ def enterSelect_expression(self, ctx): def enterSelect_list(self, ctx): if ctx.ASTERISK(): - logging.info((ctx.depth(), ctx.__class__.__name__, - [[None, None, '*'], None])) - self.data.append([ctx.depth(), ctx, [[[None, None, '*'], - None]]]) - self.columns.append(('*', None)) - self.keywords.append('*') + logging.info( + (ctx.depth(), ctx.__class__.__name__, [[None, None, "*"], None]) + ) + self.data.append([ctx.depth(), ctx, [[[None, None, "*"], None]]]) + self.columns.append(("*", None)) + self.keywords.append("*") def enterFunctionList(self, ctx): self.functions.append(ctx.getText()) @@ -326,75 +337,105 @@ def enterGroup_functions(self, ctx): self.functions.append(ctx.getText()) def enterGroupby_clause(self, ctx): - self.keywords.append('group by') + self.keywords.append("group by") col = self._extract_column(ctx, append=False) if col[1][0][0][2] not in self.column_aliases: self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterWhere_clause(self, ctx): - self.keywords.append('where') + self.keywords.append("where") self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterHaving_clause(self, ctx): - self.keywords.append('having') + self.keywords.append("having") self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterOrderby_clause(self, ctx): - self.keywords.append('order by') + self.keywords.append("order by") col = self._extract_column(ctx, append=False) if col[1][0][0][2] not in self.column_aliases: self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterLimit_clause(self, ctx): - self.keywords.append('limit') + self.keywords.append("limit") def enterJoin_condition(self, ctx): - self.keywords.append('join') + self.keywords.append("join") self._extract_column(ctx, join_columns=ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterSpoint(self, ctx): - self.functions.append('spoint') + self.functions.append("spoint") def enterScircle(self, ctx): - self.functions.append('scircle') + self.functions.append("scircle") def enterSline(self, ctx): - self.functions.append('sline') + self.functions.append("sline") def enterSellipse(self, ctx): - self.functions.append('sellipse') + self.functions.append("sellipse") def enterSbox(self, ctx): - self.functions.append('sbox') + self.functions.append("sbox") def enterSpoly(self, ctx): - self.functions.append('spoly') + self.functions.append("spoly") def enterSpath(self, ctx): - self.functions.append('spath') + self.functions.append("spath") def enterStrans(self, ctx): - self.functions.append('strans') + self.functions.append("strans") return ColumnKeywordFunctionListener @@ -437,8 +478,16 @@ class SQLQueryProcessor(object): other types of listeners can be added. """ - def __init__(self, base_lexer, base_parser, base_parser_listener, - quote_char, query=None, base_sphere_listener=None): + + def __init__( + self, + base_lexer, + base_parser, + base_parser_listener, + quote_char, + query=None, + base_sphere_listener=None, + ): self.lexer = base_lexer self.parser = base_parser self.parser_listener = base_parser_listener @@ -495,12 +544,12 @@ def _extract_instances(self, column_keyword_function_listener): if isinstance(i[1], self.parser.Select_listContext): if len(i) == 3: - select_list_columns.append([[i[2][0][0] + [i[1]], - i[2][0][1]]]) + select_list_columns.append([[i[2][0][0] + [i[1]], i[2][0][1]]]) ctx_stack.append(i) - if isinstance(i[1], self.parser.Where_clauseContext) or\ - isinstance(i[1], self.parser.Having_clauseContext): + if isinstance(i[1], self.parser.Where_clauseContext) or isinstance( + i[1], self.parser.Having_clauseContext + ): if len(i[2]) > 1: for j in i[2]: other_columns.append([j]) @@ -514,15 +563,23 @@ def _extract_instances(self, column_keyword_function_listener): if i[1].USING_SYM(): for ctx in ctx_stack[::-1]: - if not isinstance(ctx[1], - self.parser.Table_atomContext): + if not isinstance(ctx[1], self.parser.Table_atomContext): break for ju in join_using: if ju[0][1] is None: - other_columns.append([[[ctx[2][0][0], - ctx[2][0][1], - ju[0][2], - ctx[1]], None]]) + other_columns.append( + [ + [ + [ + ctx[2][0][0], + ctx[2][0][1], + ju[0][2], + ctx[1], + ], + None, + ] + ] + ) elif i[1].ON(): if len(i[2]) > 1: for j in i[2]: @@ -546,9 +603,16 @@ def _extract_instances(self, column_keyword_function_listener): go_columns.append(i[2]) ctx_stack.append(i) - return select_list_columns, select_list_tables,\ - select_list_table_references, other_columns, go_columns, join,\ - join_using, column_aliases + return ( + select_list_columns, + select_list_tables, + select_list_table_references, + other_columns, + go_columns, + join, + join_using, + column_aliases, + ) def _get_budget_column(self, c, tab, ref): cname = c[0][2] @@ -559,28 +623,35 @@ def _get_budget_column(self, c, tab, ref): column_found = False for bc in ref: - if bc[0][2] == '*': - t = [[bc[0][0], bc[0][1]], 'None'] + if bc[0][2] == "*": + t = [[bc[0][0], bc[0][1]], "None"] column_found = True break elif bc[1] and c[0][2] == bc[1]: - t = [[bc[0][0], bc[0][1]], 'None'] + t = [[bc[0][0], bc[0][1]], "None"] cname = bc[0][2] if c[1] is None: calias = c[0][2] column_found = True break elif c[0][2] == bc[0][2] and bc[1] is None: - t = [[bc[0][0], bc[0][1]], 'None'] + t = [[bc[0][0], bc[0][1]], "None"] column_found = True break return cname, cctx, calias, column_found, t - def _extract_columns(self, columns, select_list_tables, ref_dict, join, - budget, column_aliases, touched_columns=None, - subquery_contents=None): - + def _extract_columns( + self, + columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases, + touched_columns=None, + subquery_contents=None, + ): # Here we store all columns that might have references somewhere # higher up in the tree structure. We'll revisit them later. missing_columns = [] @@ -595,11 +666,11 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, calias = c[1] # if * is selected we don't care too much - if c[0][0] is None and c[0][1] is None and c[0][2] == '*'\ - and not join: + if c[0][0] is None and c[0][1] is None and c[0][2] == "*" and not join: for slt in select_list_tables: - extra_columns.append([[slt[0][0][0], slt[0][0][1], cname, - c[0][3]], calias]) + extra_columns.append( + [[slt[0][0][0], slt[0][0][1], cname, c[0][3]], calias] + ) remove_column_idxs.append(i) continue @@ -612,19 +683,21 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, try: tab = select_list_tables[0][0] if tab[0][0] is None: - raise QueryError('Missing schema specification.') + raise QueryError("Missing schema specification.") # We have to check if we also have a join on the same level # and we are actually touching a column from the joined table - if join and c[0][2] != '*' and\ - (tab[1] != c[0][1] or - (tab[1] is None and c[0][1] is None)): - cname, cctx, calias, column_found, tab =\ - self._get_budget_column(c, tab, budget[-1][2]) + if ( + join + and c[0][2] != "*" + and (tab[1] != c[0][1] or (tab[1] is None and c[0][1] is None)) + ): + cname, cctx, calias, column_found, tab = self._get_budget_column( + c, tab, budget[-1][2] + ) # raise an ambiguous column if column_found and c[0][1] is None: - raise QueryError("Column '%s' is possibly ambiguous." - % c[0][2]) + raise QueryError("Column '%s' is possibly ambiguous." % c[0][2]) except IndexError: pass @@ -636,14 +709,18 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, if isinstance(ref[0], int): # ref is a budget column - cname, cctx, calias, column_found, tab =\ - self._get_budget_column(c, tab, ref[2]) + cname, cctx, calias, column_found, tab = self._get_budget_column( + c, tab, ref[2] + ) ref_cols = [j[0][2] for j in ref[2]] - if not column_found and c[0][1] is not None\ - and c[0][1] != tab[0][1] and '*' not in ref_cols: - raise QueryError("Unknown column '%s.%s'." % (c[0][1], - c[0][2])) + if ( + not column_found + and c[0][1] is not None + and c[0][1] != tab[0][1] + and "*" not in ref_cols + ): + raise QueryError("Unknown column '%s.%s'." % (c[0][1], c[0][2])) else: # ref is a table @@ -662,12 +739,12 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, if subquery_contents is not None: try: contents = subquery_contents[c[0][1]] - cname, cctx, calias, column_found, tab =\ + cname, cctx, calias, column_found, tab = ( self._get_budget_column(c, tab, contents) + ) except KeyError: - tabs = [j[0][0][:2] for j in - subquery_contents.values()] + tabs = [j[0][0][:2] for j in subquery_contents.values()] tabs += [j[0][0] for j in select_list_tables] column_found = False for t in tabs: @@ -686,18 +763,24 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, continue else: if tab[0][1] == c[0][1]: - columns[i] = [[tab[0][0], tab[0][1], - c[0][2], c[0][3]], c[1]] + columns[i] = [ + [tab[0][0], tab[0][1], c[0][2], c[0][3]], + c[1], + ] else: - missing_columns.append(c) columns[i] = c if touched_columns is not None: touched_columns.append(c) continue - elif c[0][2] is not None and c[0][2] != '*' and c[0][1] is \ - None and len(ref_dict.keys()) > 1 and not join: + elif ( + c[0][2] is not None + and c[0][2] != "*" + and c[0][1] is None + and len(ref_dict.keys()) > 1 + and not join + ): raise QueryError("Column '%s' is ambiguous." % c[0][2]) elif len(budget) and tab[0][0] is None and tab[0][1] is None: @@ -705,18 +788,21 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, column_found = False if isinstance(ref[0], int): - cname, cctx, calias, column_found, tab =\ - self._get_budget_column(c, tab, ref[2]) + cname, cctx, calias, column_found, tab = ( + self._get_budget_column(c, tab, ref[2]) + ) # We allow None.None columns because they are produced # by count(*) - if not column_found and c[0][2] is not None\ - and c[0][2] not in column_aliases: + if ( + not column_found + and c[0][2] is not None + and c[0][2] not in column_aliases + ): raise QueryError("Unknown column '%s'." % c[0][2]) if touched_columns is not None: - touched_columns.append([[tab[0][0], tab[0][1], cname, cctx], - calias]) + touched_columns.append([[tab[0][0], tab[0][1], cname, cctx], calias]) else: columns[i] = [[tab[0][0], tab[0][1], cname, c[0][3]], calias] @@ -726,7 +812,12 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, columns.extend(extra_columns) return missing_columns - def process_query(self, replace_schema_name=None, indexed_objects=None): + def process_query( + self, + replace_schema_name=None, + replace_function_names=None, + indexed_objects=None, + ): """ Parses and processes the query. After a successful run it fills up columns, keywords, functions and syntax_errors lists. @@ -737,6 +828,14 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): :param indexed_objects: Deprecated """ + self.replaced_functions = [] + print(f"replace_function_names: {replace_function_names}") + if replace_function_names: + for i, function_name in enumerate(replace_function_names): + if function_name in self.query: + self.replaced_functions.append(function_name) + self.set_query(self.query.replace(function_name, f"UDF_{i}")) + # Antlr objects inpt = antlr4.InputStream(self.query) lexer = self.lexer(inpt) @@ -752,12 +851,14 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): if replace_schema_name is not None: schema_name_listener = get_schema_name_listener( - self.parser_listener, self.quote_char)(replace_schema_name) + self.parser_listener, self.quote_char + )(replace_schema_name) self.walker.walk(schema_name_listener, tree) self._query = stream.getText() - query_listener = get_query_listener(self.parser_listener, - self.parser, self.quote_char)() + query_listener = get_query_listener( + self.parser_listener, self.parser, self.quote_char + )() subquery_aliases = [None] keywords = [] functions = [] @@ -784,10 +885,11 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): # Iterate through subqueries starting with the lowest level for ccc, ctx in enumerate(query_listener.select_expressions[::-1]): remove_subquieries_listener = get_remove_subqueries_listener( - self.parser_listener, self.parser)(ctx.depth()) - column_keyword_function_listener = \ - get_column_keyword_function_listener( - self.parser_listener, self.quote_char)() + self.parser_listener, self.parser + )(ctx.depth()) + column_keyword_function_listener = get_column_keyword_function_listener( + self.parser_listener, self.quote_char + )() # Remove nested subqueries from select_expressions self.walker.walk(remove_subquieries_listener, ctx) @@ -809,10 +911,16 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): # We get the columns from the select list along with all # other touched columns and any possible join conditions column_aliases_from_previous = [i for i in column_aliases] - select_list_columns, select_list_tables,\ - select_list_table_references, other_columns, go_columns, join,\ - join_using, column_aliases =\ - self._extract_instances(column_keyword_function_listener) + ( + select_list_columns, + select_list_tables, + select_list_table_references, + other_columns, + go_columns, + join, + join_using, + column_aliases, + ) = self._extract_instances(column_keyword_function_listener) tables.extend([i[0] for i in select_list_tables]) @@ -837,9 +945,14 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): for table in select_list_tables: ref_dict[table[0][0][1]] = table - mc = self._extract_columns(select_list_columns, select_list_tables, - ref_dict, join, budget, - column_aliases_from_previous) + mc = self._extract_columns( + select_list_columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + ) missing_columns.extend([[i] for i in mc]) touched_columns.extend(select_list_columns) @@ -851,10 +964,15 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): if col[0][0][2] not in aliases: other_columns.append(col) - mc = self._extract_columns(other_columns, select_list_tables, - ref_dict, join, budget, - column_aliases_from_previous, - touched_columns) + mc = self._extract_columns( + other_columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + touched_columns, + ) missing_columns.extend([[i] for i in mc]) @@ -863,8 +981,9 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): join_columns.append(budget.pop(-1)) if len(join_using) == 1: for tab in select_list_tables: - touched_columns.append([[tab[0][0][0], tab[0][0][1], - join_using[0][0][2]], None]) + touched_columns.append( + [[tab[0][0][0], tab[0][0][1], join_using[0][0][2]], None] + ) bp = [] for b in budget[::-1]: if b[0] > current_depth: @@ -876,26 +995,38 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): subquery_contents[subquery_alias] = current_columns if len(missing_columns): - mc = self._extract_columns(missing_columns, select_list_tables, - ref_dict, join, budget, - column_aliases_from_previous, - touched_columns, subquery_contents) + mc = self._extract_columns( + missing_columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + touched_columns, + subquery_contents, + ) if len(mc): - unref_cols = "', '".join(['.'.join([j for j in i[0][:3] if j]) - for i in mc]) + unref_cols = "', '".join( + [".".join([j for j in i[0][:3] if j]) for i in mc] + ) raise QueryError("Unreferenced column(s): '%s'." % unref_cols) touched_columns = set([tuple(i[0]) for i in touched_columns]) # extract display_columns display_columns = [] - mc = self._extract_columns([[i] for i in budget[-1][2]], - select_list_tables, ref_dict, join, budget, - column_aliases_from_previous, - display_columns, subquery_contents) - - display_columns = [[i[1] if i[1] else i[0][2], i[0]] - for i in display_columns] + mc = self._extract_columns( + [[i] for i in budget[-1][2]], + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + display_columns, + subquery_contents, + ) + + display_columns = [[i[1] if i[1] else i[0][2], i[0]] for i in display_columns] # Let's get rid of all columns that are already covered by # db.tab.*. Figure out a better way to do it and replace the code @@ -903,28 +1034,45 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): asterisk_columns = [] del_columns = [] for col in touched_columns: - if col[2] == '*': + if col[2] == "*": asterisk_columns.append(col) for acol in asterisk_columns: for col in touched_columns: - if acol[0] == col[0] and acol[1] == col[1] and \ - acol[2] != col[2]: + if acol[0] == col[0] and acol[1] == col[1] and acol[2] != col[2]: del_columns.append(col) columns = list(set(touched_columns).difference(del_columns)) self.columns = list(set([self._strip_column(i) for i in columns])) self.keywords = list(set(keywords)) self.functions = list(set(functions)) - self.display_columns = [(i[0].lstrip('"').rstrip('"'), - list(self._strip_column(i[1]))) - for i in display_columns] - - self.tables = list(set([tuple([i[0][0].lstrip('"').rstrip('"') - if i[0][0] is not None else i[0][0], - i[0][1].lstrip('"').rstrip('"') - if i[0][1] is not None else i[0][1]]) - for i in tables])) + self.display_columns = [ + (i[0].lstrip('"').rstrip('"'), list(self._strip_column(i[1]))) + for i in display_columns + ] + + self.tables = list( + set( + [ + tuple( + [ + i[0][0].lstrip('"').rstrip('"') + if i[0][0] is not None + else i[0][0], + i[0][1].lstrip('"').rstrip('"') + if i[0][1] is not None + else i[0][1], + ] + ) + for i in tables + ] + ) + ) + + print(f"replaced_functions: {self.replaced_functions}") + if self.replaced_functions: + for i, function_name in enumerate(self.replaced_functions): + self.set_query(self.query.replace(f"UDF_{i}", function_name)) @property def query(self): @@ -935,7 +1083,7 @@ def query(self): return self._query def _strip_query(self, query): - return query.lstrip('\n').rstrip().rstrip(';') + ';' + return query.lstrip("\n").rstrip().rstrip(";") + ";" def _strip_column(self, col): scol = [None, None, None] diff --git a/src/queryparser/postgresql/PostgreSQLLexer.g4 b/src/queryparser/postgresql/PostgreSQLLexer.g4 index 2532ca5..904c543 100644 --- a/src/queryparser/postgresql/PostgreSQLLexer.g4 +++ b/src/queryparser/postgresql/PostgreSQLLexer.g4 @@ -170,6 +170,16 @@ TIME_SYM : T_ I_ M_ E_ ; TIMESTAMP : T_ I_ M_ E_ S_ T_ A_ M_ P_ ; TRUE_SYM : T_ R_ U_ E_ ; TRUNCATE : T_ R_ U_ N_ C_ A_ T_ E_ ; +UDF_0 : U_ D_ F_ '_' '0' ; +UDF_1 : U_ D_ F_ '_' '1' ; +UDF_2 : U_ D_ F_ '_' '2' ; +UDF_3 : U_ D_ F_ '_' '3' ; +UDF_4 : U_ D_ F_ '_' '4' ; +UDF_5 : U_ D_ F_ '_' '5' ; +UDF_6 : U_ D_ F_ '_' '6' ; +UDF_7 : U_ D_ F_ '_' '7' ; +UDF_8 : U_ D_ F_ '_' '8' ; +UDF_9 : U_ D_ F_ '_' '9' ; UNION_SYM : U_ N_ I_ O_ N_ ; UNSIGNED_SYM : U_ N_ S_ I_ G_ N_ E_ D_ ; UPDATE : U_ P_ D_ A_ T_ E_ ; diff --git a/src/queryparser/postgresql/PostgreSQLParser.g4 b/src/queryparser/postgresql/PostgreSQLParser.g4 index 209248c..b3aadc5 100644 --- a/src/queryparser/postgresql/PostgreSQLParser.g4 +++ b/src/queryparser/postgresql/PostgreSQLParser.g4 @@ -63,7 +63,7 @@ array_functions: ARRAY_LENGTH ; custom_functions: - GAIA_HEALPIX_INDEX | PDIST ; + GAIA_HEALPIX_INDEX | PDIST | UDF_0 | UDF_1 | UDF_2 | UDF_3 | UDF_4 | UDF_5 | UDF_6 | UDF_7 | UDF_8 | UDF_9 ; pg_sphere_functions: AREA ; From bfb67cd14c312918ffc7c18c358784e1eb741730 Mon Sep 17 00:00:00 2001 From: simeonreusch Date: Fri, 16 May 2025 15:23:32 +0200 Subject: [PATCH 2/7] add a slightly smarter way to check if a udf function name occurs in the query --- src/queryparser/common/common.py | 37 ++++++++++++++----- .../postgresql/postgresqlprocessor.py | 10 ++--- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/src/queryparser/common/common.py b/src/queryparser/common/common.py index 0aed211..680e44e 100644 --- a/src/queryparser/common/common.py +++ b/src/queryparser/common/common.py @@ -812,6 +812,21 @@ def _extract_columns( columns.extend(extra_columns) return missing_columns + @staticmethod + def _match_and_replace_function_name(query, function_name, i): + """ + This very roughly checks if the function name is present in the query. + We check for a space, the function name, and an opening parenthesis. + """ + pattern = r"\s" + re.escape(function_name) + r"\(" + match = re.search(pattern, query) + if match: + start, end = match.span() + # Replace the matched function name with UDF_{i} + query = query[: start + 1] + f"UDF_{i}" + query[end - 1 :] + + return match, query + def process_query( self, replace_schema_name=None, @@ -828,13 +843,16 @@ def process_query( :param indexed_objects: Deprecated """ - self.replaced_functions = [] - print(f"replace_function_names: {replace_function_names}") + self.replaced_functions = {} + if replace_function_names: for i, function_name in enumerate(replace_function_names): - if function_name in self.query: - self.replaced_functions.append(function_name) - self.set_query(self.query.replace(function_name, f"UDF_{i}")) + match, query = self._match_and_replace_function_name( + self.query, function_name, i + ) + if match: + self.replaced_functions[i] = function_name + self.set_query(query) # Antlr objects inpt = antlr4.InputStream(self.query) @@ -1069,10 +1087,11 @@ def process_query( ) ) - print(f"replaced_functions: {self.replaced_functions}") - if self.replaced_functions: - for i, function_name in enumerate(self.replaced_functions): - self.set_query(self.query.replace(f"UDF_{i}", function_name)) + if len(self.replaced_functions) > 0: + for i, function_name in self.replaced_functions.items(): + self._query = self.query.replace(f"UDF_{i}", function_name) + self.functions.remove(f"UDF_{i}") + self.functions.append(function_name) @property def query(self): diff --git a/src/queryparser/postgresql/postgresqlprocessor.py b/src/queryparser/postgresql/postgresqlprocessor.py index a15b7f6..e400745 100644 --- a/src/queryparser/postgresql/postgresqlprocessor.py +++ b/src/queryparser/postgresql/postgresqlprocessor.py @@ -6,18 +6,18 @@ """ -from __future__ import (absolute_import, print_function) +from __future__ import absolute_import, print_function __all__ = ["PostgreSQLQueryProcessor"] +from ..common import SQLQueryProcessor from .PostgreSQLLexer import PostgreSQLLexer from .PostgreSQLParser import PostgreSQLParser from .PostgreSQLParserListener import PostgreSQLParserListener -from ..common import SQLQueryProcessor - class PostgreSQLQueryProcessor(SQLQueryProcessor): def __init__(self, query=None): - super().__init__(PostgreSQLLexer, PostgreSQLParser, - PostgreSQLParserListener, '"', query) + super().__init__( + PostgreSQLLexer, PostgreSQLParser, PostgreSQLParserListener, '"', query + ) From ec3c7a8ca9553a3f3869e3b71da809e19a85e5bf Mon Sep 17 00:00:00 2001 From: simeonreusch Date: Fri, 16 May 2025 16:27:08 +0200 Subject: [PATCH 3/7] add test for udf --- src/queryparser/testing/tests.yaml | 73 +++++++++++++++++++++++++++--- src/queryparser/testing/utils.py | 48 ++++++++++++++------ 2 files changed, 102 insertions(+), 19 deletions(-) diff --git a/src/queryparser/testing/tests.yaml b/src/queryparser/testing/tests.yaml index f33812e..fb093f6 100644 --- a/src/queryparser/testing/tests.yaml +++ b/src/queryparser/testing/tests.yaml @@ -16,6 +16,7 @@ common_tests: - - ['col1: db.tab.a'] - ['db.tab'] + - - - SELECT t.a FROM db.tab1 as t, db.tab2; @@ -24,6 +25,7 @@ common_tests: - - ['a: db.tab1.a'] - ['db.tab1', 'db.tab2'] + - - - SELECT COUNT(*), a*2, b, 100 FROM db.tab; @@ -32,6 +34,7 @@ common_tests: - ['COUNT'] - ['a: db.tab.a', 'b: db.tab.b'] - ['db.tab'] + - - - SELECT (((((((1+2)*3)/4)^5)%6)&7)>>8) FROM db.tab; @@ -40,6 +43,7 @@ common_tests: - - - ['db.tab'] + - - - SELECT ABS(a),AVG(b) FROM db.tab; @@ -48,6 +52,7 @@ common_tests: - ['AVG', 'ABS'] - ['a: db.tab.a', 'b: db.tab.b'] - ['db.tab'] + - - - SELECT AVG(((((b & a) << 1) + 1) / a) ^ 4.5) FROM db.tab; @@ -56,6 +61,7 @@ common_tests: - ['AVG'] - - ['db.tab'] + - - - SELECT A.a,B.* FROM db.tab1 A,db.tab2 AS B LIMIT 10; @@ -64,6 +70,7 @@ common_tests: - - ['a: db.tab1.a', '*: db.tab2.*'] - ['db.tab1', 'db.tab2'] + - - - SELECT fofid, x, y, z, vx, vy, vz @@ -76,6 +83,7 @@ common_tests: - - ['fofid: MDR1.FOF.fofid', 'x: MDR1.FOF.x', 'y: MDR1.FOF.y', 'z: MDR1.FOF.z', 'vx: MDR1.FOF.vx', 'vy: MDR1.FOF.vy', 'vz: MDR1.FOF.vz'] - ['MDR1.FOF'] + - - - SELECT article, dealer, price @@ -86,6 +94,7 @@ common_tests: - ['MAX'] - ['article: world.shop.article', 'dealer: world.shop.dealer', 'price: world.shop.price'] - ['world.shop', 'universe.shop'] + - - - SELECT dealer, price @@ -99,6 +108,7 @@ common_tests: - ['MAX'] - ['price: db.shop.price', 'dealer: db.shop.dealer'] - ['db.shop', 'db.warehouse'] + - - - SELECT A.*, B.* @@ -110,6 +120,7 @@ common_tests: - - ['*: db1.table1.*', '*: db2.table1.*'] - ['db1.table1', 'db2.table1'] + - - - SELECT * FROM mmm.products @@ -120,6 +131,7 @@ common_tests: - - ['*: mmm.products.*'] - ['mmm.products'] + - - - SELECT t.table_name AS tname, t.description AS tdesc, @@ -153,6 +165,7 @@ common_tests: 'jcol: tap_schema.cols.column_name', 'kcol: tap_schema.cols.column_name'] - ['tap_schema.tabs', 'tap_schema.cols'] + - - - SELECT t1.a FROM d.tab t1 @@ -162,6 +175,7 @@ common_tests: - ['a: foo.tab.a'] - ['foo.tab'] - 'd': 'foo' + - - - SELECT DISTINCT t.table_name @@ -175,6 +189,7 @@ common_tests: - - ['table_name: tap_schema.tabs.table_name'] - ['tap_schema.tabs', 'tap_schema.cols'] + - - - SELECT s.* FROM db.person p INNER JOIN db.shirt s @@ -186,6 +201,7 @@ common_tests: - - ['*: db.shirt.*'] - ['db.shirt', 'db.person'] + - - - SELECT x, y, z, mass @@ -200,6 +216,7 @@ common_tests: - ['x: MDR1.FOF.x', 'y: MDR1.FOF.y', 'z: MDR1.FOF.z', 'mass: MDR1.FOF.mass'] - ['MDR1.FOF'] + - - - SELECT h.Mvir, h.spin, g.diskMassStellar, @@ -219,6 +236,7 @@ common_tests: 'diskMassStellar: MDPL2.Galacticus.diskMassStellar', 'spin: MDPL2.Rockstar.spin'] - ['MDPL2.Rockstar', 'MDPL2.Galacticus'] + - - - SELECT bdmId, Rbin, mass, dens @@ -247,6 +265,7 @@ common_tests: - ['bdmId: Bolshoi.BDMVProf.bdmId', 'Rbin: Bolshoi.BDMVProf.Rbin', 'mass: Bolshoi.BDMVProf.mass', 'dens: Bolshoi.BDMVProf.dens'] - ['Bolshoi.BDMVProf', 'Bolshoi.BDMV'] + - - - SELECT t.RAVE_OBS_ID AS c1, t.HEALPix AS c2, @@ -274,6 +293,7 @@ common_tests: 'c4: RAVEPUB_DR5.RAVE_ON.TEFF'] - ['RAVEPUB_DR5.RAVE_DR5', 'RAVEPUB_DR5.RAVE_Gravity_SC', 'RAVEPUB_DR5.RAVE_ON'] + - - - SELECT db.tab.a FROM db.tab; @@ -282,6 +302,7 @@ common_tests: - - ['a: db.tab.a'] - ['db.tab'] + - - - SELECT COUNT(*) AS n, id, mra, mlem AS qqq, blem @@ -316,6 +337,7 @@ common_tests: - ['n: None.None.None', 'id: db.bar.id', 'mra: db.tab.ra', 'qqq: db.bar.mlem', 'blem: None.None.blem'] - ['db.tab', 'db.bar', 'db.gaia'] + - - - SELECT @@ -361,6 +383,7 @@ common_tests: 'n: None.None.None'] - ['gaiadr1.tgas_source', 'gaiadr1.tmass_best_neighbour', 'gaiadr1.tmass_original_valid'] + - - - SELECT ra, sub.qqq, t1.bar @@ -379,6 +402,7 @@ common_tests: - - - ['db.tab', 'db.blem'] + - - - SELECT t1.a, t2.b, t3.c, t4.z @@ -392,6 +416,7 @@ common_tests: 'd': 'foo' 'db2': 'bar' 'foo': 'bas' + - - - SELECT *, AVG(par) as apar FROM db.tab; @@ -400,6 +425,7 @@ common_tests: - ['AVG'] - ['*: db.tab.*', 'apar: db.tab.par'] - ['db.tab'] + - - - SELECT q.ra, q.de, tab2.par @@ -414,6 +440,7 @@ common_tests: - ['MAX'] - ['ra: db.tab.ra', 'de: db.tab.de', 'par: db.tab2.par'] - ['db.tab', 'db.tab2', 'db.undef'] + - - - SELECT a, b @@ -427,6 +454,7 @@ common_tests: - - ['a: db.tab1.a', 'b: db.tab1.b'] - ['db.tab1', 'db.tab2'] + - - - SELECT a FROM db.tab HAVING b > 0 @@ -435,6 +463,7 @@ common_tests: - - ['a: db.tab.a'] - ['db.tab'] + - - - SELECT a FROM db.tab WHERE EXISTS ( @@ -445,7 +474,8 @@ common_tests: - - ['a: db.tab.a'] - ['db.tab', 'db.foo'] - + - + - - SELECT * FROM ( @@ -458,6 +488,7 @@ common_tests: - - - ['db.a', 'db.b', 'db.c', 'db.d', 'db.x', 'db.y'] + - - - SELECT * @@ -471,6 +502,7 @@ common_tests: - - - ['db.a', 'db.b', 'db.c', 'db.d', 'db.x', 'db.y'] + - - - SELECT A.*, B.* @@ -482,6 +514,7 @@ common_tests: - - ['*: db1.table1.*', '*: db2.table1.*'] - ['db1.table1', 'db2.table1'] + - common_translation_tests: @@ -493,6 +526,7 @@ common_translation_tests: - - - + - mysql_tests: @@ -508,6 +542,7 @@ mysql_tests: - - ['fi@1: db.test_table.fi@1', 'fi2: db.test_table.fi2'] - ['db.test_table', 'bd.test_table'] + - - - SELECT `fi@1`, fi2 @@ -521,6 +556,7 @@ mysql_tests: - - ['fi@1: db.test_table.fi@1', 'fi2: db.test_table.fi2'] - ['db.test_table', 'bd.test_table'] + - - - SELECT log10(mass)/sqrt(x) AS logM @@ -530,6 +566,7 @@ mysql_tests: - ['log10', 'sqrt'] - - ['MDR1.FOF'] + - - - SELECT log10(ABS(x)) AS log_x @@ -539,6 +576,7 @@ mysql_tests: - ['log10', 'ABS'] - ['log_x: MDR1.FOF.x'] - ['MDR1.FOF'] + - - - SELECT DEGREES(sdist(spoint(RADIANS(0.0), RADIANS(0.0)), @@ -550,6 +588,7 @@ mysql_tests: - ['DEGREES', 'RADIANS', 'sdist', 'spoint'] - - ['db.VII/233/xsc'] + - - - SELECT Data FROM db.Users @@ -559,6 +598,7 @@ mysql_tests: - - ['Data: db.Users.Data'] - ['db.Users'] + - - - SELECT CONVERT(ra, DECIMAL(12,9)) as ra2, ra as ra1 @@ -570,6 +610,7 @@ mysql_tests: - - ['ra1: GDR1.gaia_source.ra', 'ra2: GDR1.gaia_source.ra'] - ['GDR1.gaia_source'] + - - - SELECT DEGREES(sdist(spoint(RADIANS(ra), RADIANS(dec)), @@ -587,6 +628,7 @@ mysql_tests: 'DEGREES'] - - ['GDR1.gaia_source'] + - - - SELECT x, y, z, mass @@ -598,6 +640,7 @@ mysql_tests: - ['x: MDR1.FOF.x', 'y: MDR1.FOF.y', 'z: MDR1.FOF.z', 'mass: MDR1.FOF.mass'] - ['MDR1.FOF'] + - - - SELECT bdmId, Rbin, mass, dens FROM Bolshoi.BDMVProf @@ -618,6 +661,7 @@ mysql_tests: - ['bdmId: Bolshoi.BDMVProf.bdmId', 'Rbin: Bolshoi.BDMVProf.Rbin', 'mass: Bolshoi.BDMVProf.mass', 'dens: Bolshoi.BDMVProf.dens'] - ['Bolshoi.BDMVProf', 'Bolshoi.BDMV'] + - postgresql_tests: @@ -628,6 +672,7 @@ postgresql_tests: - ['pdist'] - - + - - - SELECT DISTINCT ON ("source"."tycho2_id") "tycho2_id", "source"."tycho2_dist" @@ -637,6 +682,7 @@ postgresql_tests: - - - + - - - SELECT ra, dec FROM gdr1.gaia_source @@ -647,6 +693,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['ra: gdr1.gaia_source.ra', 'dec: gdr1.gaia_source.dec'] - ['gdr1.gaia_source'] + - - - SELECT ra, dec FROM gdr1.gaia_source @@ -657,6 +704,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['ra: gdr1.gaia_source.ra', 'dec: gdr1.gaia_source.dec'] - ['gdr1.gaia_source'] + - - - SELECT * FROM gdr2.vari_cepheid AS v @@ -668,6 +716,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['*: gdr2.vari_cepheid.*'] - ['gdr2.gaia_source', 'gdr2.vari_cepheid'] + - - - SELECT curves.observation_time, @@ -696,6 +745,7 @@ postgresql_tests: 'observation_time: gdr1.phot_variable_time_series_gfov.observation_time', 'phase: gdr1.rrlyrae.p1'] - ['gdr1.phot_variable_time_series_gfov', 'gdr1.rrlyrae'] + - - - SELECT a @@ -706,6 +756,7 @@ postgresql_tests: - - ['a: db.tab.a'] - ['db.tab'] + - - - SELECT arr[1:3] FROM db.phot; @@ -714,6 +765,7 @@ postgresql_tests: - - ['arr: db.phot.arr'] - ['db.phot'] + - - - SELECT arr[1:3][1][2][3][4] FROM db.phot; @@ -722,6 +774,7 @@ postgresql_tests: - - ['arr: db.phot.arr'] - ['db.phot'] + - - - SELECT ra, dec FROM gdr1.gaia_source @@ -732,6 +785,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['ra: gdr1.gaia_source.ra', 'dec: gdr1.gaia_source.dec'] - ['gdr1.gaia_source'] + - - - SELECT q2.c / q1.c FROM ( @@ -749,6 +803,7 @@ postgresql_tests: - ['COUNT'] - - ['gdr1.tgas_source'] + - - - SELECT * FROM gdr2.vari_cepheid AS v @@ -760,6 +815,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['*: gdr2.vari_cepheid.*'] - ['gdr2.gaia_source', 'gdr2.vari_cepheid'] + - - - SELECT ra FROM gdr2.gaia_source AS gaia @@ -770,14 +826,19 @@ postgresql_tests: - ['RADIANS', 'spoint', 'scircle'] - ['ra: gdr2.gaia_source.ra'] - ['gdr2.gaia_source'] + - + - + - "SELECT specuid, ra, dec FROM dr1.spectrum WHERE QMOST_SPEC_IS_IN_SURVEY(specuid, '04');" + - ['dr1.spectrum.specuid', 'dr1.spectrum.ra', 'dr1.spectrum.dec'] + - ["where"] + - ["QMOST_SPEC_IS_IN_SURVEY"] + - ["specuid: dr1.spectrum.specuid", "ra: dr1.spectrum.ra", "dec: dr1.spectrum.dec"] + - ["dr1.spectrum"] + - + - ["QMOST_SPEC_IS_IN_SURVEY"] -# Each test below consists of: -# -# - ADQL query string -# - translated query string - adql_mysql_tests: - - SELECT POINT('icrs', 10, 10) AS "p" FROM "db".tab diff --git a/src/queryparser/testing/utils.py b/src/queryparser/testing/utils.py index c89ad7f..28daeee 100644 --- a/src/queryparser/testing/utils.py +++ b/src/queryparser/testing/utils.py @@ -6,12 +6,28 @@ def _test_parsing(query_processor, test, translate=False): - if len(test) == 6: - query, columns, keywords, functions, display_columns, tables = test + if len(test) == 7: + ( + query, + columns, + keywords, + functions, + display_columns, + tables, + replace_function_names, + ) = test replace_schema_name = None - elif len(test) == 7: - query, columns, keywords, functions, display_columns, tables,\ - replace_schema_name = test + elif len(test) == 8: + ( + query, + columns, + keywords, + functions, + display_columns, + tables, + replace_schema_name, + replace_function_names, + ) = test if translate: adt = ADQLQueryTranslator() @@ -28,13 +44,20 @@ def _test_parsing(query_processor, test, translate=False): qp.set_query(query) qp.process_query(replace_schema_name=replace_schema_name) - qp_columns = ['.'.join([str(j) for j in i[:3]]) for i in qp.columns - if i[0] is not None and i[1] is not None] - qp_display_columns = ['%s: %s' % (str(i[0]), - '.'.join([str(j) for j in i[1]])) - for i in qp.display_columns] - qp_tables = ['.'.join([str(j) for j in i]) for i in qp.tables - if i[0] is not None and i[1] is not None] + qp_columns = [ + ".".join([str(j) for j in i[:3]]) + for i in qp.columns + if i[0] is not None and i[1] is not None + ] + qp_display_columns = [ + "%s: %s" % (str(i[0]), ".".join([str(j) for j in i[1]])) + for i in qp.display_columns + ] + qp_tables = [ + ".".join([str(j) for j in i]) + for i in qp.tables + if i[0] is not None and i[1] is not None + ] if columns is not None: assert set(columns) == set(qp_columns) @@ -50,4 +73,3 @@ def _test_parsing(query_processor, test, translate=False): if tables is not None: assert set(tables) == set(qp_tables) - From cb5364b9d4ab496344db04c0e5d517b9ab2b2f18 Mon Sep 17 00:00:00 2001 From: simeonreusch Date: Fri, 16 May 2025 16:41:09 +0200 Subject: [PATCH 4/7] add ruff config, fix test --- .ruff.toml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .ruff.toml diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 0000000..489184b --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,3 @@ +[format] +# Like Black, use double quotes for strings. +quote-style = "single" \ No newline at end of file From 08d2437f178382f9f5a81085e029e2b9fb092532 Mon Sep 17 00:00:00 2001 From: simeonreusch Date: Fri, 16 May 2025 16:41:25 +0200 Subject: [PATCH 5/7] fix test --- src/queryparser/testing/utils.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/queryparser/testing/utils.py b/src/queryparser/testing/utils.py index 28daeee..012e40a 100644 --- a/src/queryparser/testing/utils.py +++ b/src/queryparser/testing/utils.py @@ -37,24 +37,27 @@ def _test_parsing(query_processor, test, translate=False): elif query_processor == PostgreSQLQueryProcessor: query = adt.to_postgresql() - if replace_schema_name is None: - qp = query_processor(query) - else: - qp = query_processor() - qp.set_query(query) - qp.process_query(replace_schema_name=replace_schema_name) + if replace_function_names is None: + replace_function_names = [] + + qp = query_processor() + qp.set_query(query) + qp.process_query( + replace_schema_name=replace_schema_name, + replace_function_names=replace_function_names, + ) qp_columns = [ - ".".join([str(j) for j in i[:3]]) + '.'.join([str(j) for j in i[:3]]) for i in qp.columns if i[0] is not None and i[1] is not None ] qp_display_columns = [ - "%s: %s" % (str(i[0]), ".".join([str(j) for j in i[1]])) + '%s: %s' % (str(i[0]), '.'.join([str(j) for j in i[1]])) for i in qp.display_columns ] qp_tables = [ - ".".join([str(j) for j in i]) + '.'.join([str(j) for j in i]) for i in qp.tables if i[0] is not None and i[1] is not None ] From 84665c0314947bb4655b088a171b53c990c01177 Mon Sep 17 00:00:00 2001 From: simeonreusch Date: Mon, 19 May 2025 10:31:13 +0200 Subject: [PATCH 6/7] fix quotest --- src/queryparser/common/common.py | 96 ++++++++++++++++---------------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/src/queryparser/common/common.py b/src/queryparser/common/common.py index 680e44e..3c5476c 100644 --- a/src/queryparser/common/common.py +++ b/src/queryparser/common/common.py @@ -60,21 +60,21 @@ def process_column_name(column_name_listener, walker, ctx, quote_char): for i in column_name_listener.column_name: cni = [None, None, None, i] if i.schema_name(): - cni[0] = i.schema_name().getText().replace(quote_char, "") + cni[0] = i.schema_name().getText().replace(quote_char, '') if i.table_name(): - cni[1] = i.table_name().getText().replace(quote_char, "") + cni[1] = i.table_name().getText().replace(quote_char, '') if i.column_name(): - cni[2] = i.column_name().getText().replace(quote_char, "") + cni[2] = i.column_name().getText().replace(quote_char, '') cn.append(cni) else: try: ctx.ASTERISK() ts = ctx.table_spec() - cn = [[None, None, "*", None]] + cn = [[None, None, '*', None]] if ts.schema_name(): - cn[0][0] = ts.schema_name().getText().replace(quote_char, "") + cn[0][0] = ts.schema_name().getText().replace(quote_char, '') if ts.table_name(): - cn[0][1] = ts.table_name().getText().replace(quote_char, "") + cn[0][1] = ts.table_name().getText().replace(quote_char, '') except AttributeError: cn = [[None, None, None, None]] return cn @@ -134,16 +134,16 @@ def enterSchema_name(self, ctx): ttype = ctx.start.type sn = ctx.getTokens(ttype)[0].getSymbol().text try: - nsn = self.replace_schema_name[sn.replace(quote_char, "")] + nsn = self.replace_schema_name[sn.replace(quote_char, '')] try: - nsn = unicode(nsn, "utf-8") + nsn = unicode(nsn, 'utf-8') except NameError: pass nsn = re.sub( - r"(|{})(?!{})[\S]*[^{}](|{})".format( + r'(|{})(?!{})[\S]*[^{}](|{})'.format( quote_char, quote_char, quote_char, quote_char ), - r"\1{}\2".format(nsn), + r'\1{}\2'.format(nsn), sn, ) ctx.getTokens(ttype)[0].getSymbol().text = nsn @@ -202,7 +202,7 @@ def __init__(self): def enterSelect_statement(self, ctx): if ctx.UNION_SYM(): - self.keywords.append("union") + self.keywords.append('union') def enterSelect_expression(self, ctx): # we need to keep track of unions as they act as subqueries @@ -294,9 +294,9 @@ def enterTable_atom(self, ctx): if ts: tn = [None, None] if ts.schema_name(): - tn[0] = ts.schema_name().getText().replace(quote_char, "") + tn[0] = ts.schema_name().getText().replace(quote_char, '') if ts.table_name(): - tn[1] = ts.table_name().getText().replace(quote_char, "") + tn[1] = ts.table_name().getText().replace(quote_char, '') self.tables.append((alias, tn, ctx.depth())) logging.info((ctx.depth(), ctx.__class__.__name__, [tn, alias])) @@ -315,7 +315,7 @@ def enterDisplayed_column(self, ctx): ) self._extract_column(ctx) if ctx.ASTERISK(): - self.keywords.append("*") + self.keywords.append('*') def enterSelect_expression(self, ctx): logging.info((ctx.depth(), ctx.__class__.__name__)) @@ -324,11 +324,11 @@ def enterSelect_expression(self, ctx): def enterSelect_list(self, ctx): if ctx.ASTERISK(): logging.info( - (ctx.depth(), ctx.__class__.__name__, [[None, None, "*"], None]) + (ctx.depth(), ctx.__class__.__name__, [[None, None, '*'], None]) ) - self.data.append([ctx.depth(), ctx, [[[None, None, "*"], None]]]) - self.columns.append(("*", None)) - self.keywords.append("*") + self.data.append([ctx.depth(), ctx, [[[None, None, '*'], None]]]) + self.columns.append(('*', None)) + self.keywords.append('*') def enterFunctionList(self, ctx): self.functions.append(ctx.getText()) @@ -337,7 +337,7 @@ def enterGroup_functions(self, ctx): self.functions.append(ctx.getText()) def enterGroupby_clause(self, ctx): - self.keywords.append("group by") + self.keywords.append('group by') col = self._extract_column(ctx, append=False) if col[1][0][0][2] not in self.column_aliases: self._extract_column(ctx) @@ -353,7 +353,7 @@ def enterGroupby_clause(self, ctx): ) def enterWhere_clause(self, ctx): - self.keywords.append("where") + self.keywords.append('where') self._extract_column(ctx) logging.info( ( @@ -367,7 +367,7 @@ def enterWhere_clause(self, ctx): ) def enterHaving_clause(self, ctx): - self.keywords.append("having") + self.keywords.append('having') self._extract_column(ctx) logging.info( ( @@ -381,7 +381,7 @@ def enterHaving_clause(self, ctx): ) def enterOrderby_clause(self, ctx): - self.keywords.append("order by") + self.keywords.append('order by') col = self._extract_column(ctx, append=False) if col[1][0][0][2] not in self.column_aliases: self._extract_column(ctx) @@ -397,10 +397,10 @@ def enterOrderby_clause(self, ctx): ) def enterLimit_clause(self, ctx): - self.keywords.append("limit") + self.keywords.append('limit') def enterJoin_condition(self, ctx): - self.keywords.append("join") + self.keywords.append('join') self._extract_column(ctx, join_columns=ctx) logging.info( ( @@ -414,28 +414,28 @@ def enterJoin_condition(self, ctx): ) def enterSpoint(self, ctx): - self.functions.append("spoint") + self.functions.append('spoint') def enterScircle(self, ctx): - self.functions.append("scircle") + self.functions.append('scircle') def enterSline(self, ctx): - self.functions.append("sline") + self.functions.append('sline') def enterSellipse(self, ctx): - self.functions.append("sellipse") + self.functions.append('sellipse') def enterSbox(self, ctx): - self.functions.append("sbox") + self.functions.append('sbox') def enterSpoly(self, ctx): - self.functions.append("spoly") + self.functions.append('spoly') def enterSpath(self, ctx): - self.functions.append("spath") + self.functions.append('spath') def enterStrans(self, ctx): - self.functions.append("strans") + self.functions.append('strans') return ColumnKeywordFunctionListener @@ -623,19 +623,19 @@ def _get_budget_column(self, c, tab, ref): column_found = False for bc in ref: - if bc[0][2] == "*": - t = [[bc[0][0], bc[0][1]], "None"] + if bc[0][2] == '*': + t = [[bc[0][0], bc[0][1]], 'None'] column_found = True break elif bc[1] and c[0][2] == bc[1]: - t = [[bc[0][0], bc[0][1]], "None"] + t = [[bc[0][0], bc[0][1]], 'None'] cname = bc[0][2] if c[1] is None: calias = c[0][2] column_found = True break elif c[0][2] == bc[0][2] and bc[1] is None: - t = [[bc[0][0], bc[0][1]], "None"] + t = [[bc[0][0], bc[0][1]], 'None'] column_found = True break @@ -666,7 +666,7 @@ def _extract_columns( calias = c[1] # if * is selected we don't care too much - if c[0][0] is None and c[0][1] is None and c[0][2] == "*" and not join: + if c[0][0] is None and c[0][1] is None and c[0][2] == '*' and not join: for slt in select_list_tables: extra_columns.append( [[slt[0][0][0], slt[0][0][1], cname, c[0][3]], calias] @@ -683,13 +683,13 @@ def _extract_columns( try: tab = select_list_tables[0][0] if tab[0][0] is None: - raise QueryError("Missing schema specification.") + raise QueryError('Missing schema specification.') # We have to check if we also have a join on the same level # and we are actually touching a column from the joined table if ( join - and c[0][2] != "*" + and c[0][2] != '*' and (tab[1] != c[0][1] or (tab[1] is None and c[0][1] is None)) ): cname, cctx, calias, column_found, tab = self._get_budget_column( @@ -718,7 +718,7 @@ def _extract_columns( not column_found and c[0][1] is not None and c[0][1] != tab[0][1] - and "*" not in ref_cols + and '*' not in ref_cols ): raise QueryError("Unknown column '%s.%s'." % (c[0][1], c[0][2])) @@ -776,7 +776,7 @@ def _extract_columns( elif ( c[0][2] is not None - and c[0][2] != "*" + and c[0][2] != '*' and c[0][1] is None and len(ref_dict.keys()) > 1 and not join @@ -818,12 +818,12 @@ def _match_and_replace_function_name(query, function_name, i): This very roughly checks if the function name is present in the query. We check for a space, the function name, and an opening parenthesis. """ - pattern = r"\s" + re.escape(function_name) + r"\(" + pattern = r'\s' + re.escape(function_name) + r'\(' match = re.search(pattern, query) if match: start, end = match.span() # Replace the matched function name with UDF_{i} - query = query[: start + 1] + f"UDF_{i}" + query[end - 1 :] + query = query[: start + 1] + f'UDF_{i}' + query[end - 1 :] return match, query @@ -1025,7 +1025,7 @@ def process_query( ) if len(mc): unref_cols = "', '".join( - [".".join([j for j in i[0][:3] if j]) for i in mc] + ['.'.join([j for j in i[0][:3] if j]) for i in mc] ) raise QueryError("Unreferenced column(s): '%s'." % unref_cols) @@ -1052,7 +1052,7 @@ def process_query( asterisk_columns = [] del_columns = [] for col in touched_columns: - if col[2] == "*": + if col[2] == '*': asterisk_columns.append(col) for acol in asterisk_columns: @@ -1089,8 +1089,8 @@ def process_query( if len(self.replaced_functions) > 0: for i, function_name in self.replaced_functions.items(): - self._query = self.query.replace(f"UDF_{i}", function_name) - self.functions.remove(f"UDF_{i}") + self._query = self.query.replace(f'UDF_{i}', function_name) + self.functions.remove(f'UDF_{i}') self.functions.append(function_name) @property @@ -1102,7 +1102,7 @@ def query(self): return self._query def _strip_query(self, query): - return query.lstrip("\n").rstrip().rstrip(";") + ";" + return query.lstrip('\n').rstrip().rstrip(';') + ';' def _strip_column(self, col): scol = [None, None, None] From 4e0b83275feeacb11af6ccdb6a18eff2e6ddc8f8 Mon Sep 17 00:00:00 2001 From: simeonreusch Date: Mon, 19 May 2025 10:40:41 +0200 Subject: [PATCH 7/7] raise ValueError if too many udfs are passed --- src/queryparser/common/common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/queryparser/common/common.py b/src/queryparser/common/common.py index 3c5476c..c2343b8 100644 --- a/src/queryparser/common/common.py +++ b/src/queryparser/common/common.py @@ -846,6 +846,10 @@ def process_query( self.replaced_functions = {} if replace_function_names: + if (n := len(replace_function_names)) > 10: + raise ValueError( + f'Too many function names to replace (you passed {n}). Maximum: 10' + ) for i, function_name in enumerate(replace_function_names): match, query = self._match_and_replace_function_name( self.query, function_name, i