From 27ddb972c3306f9d338b690009492ce92038e9fb Mon Sep 17 00:00:00 2001 From: Balram Choudhary Date: Mon, 15 Dec 2025 14:09:42 +0530 Subject: [PATCH] Make Db2 dialect statement-cache compatible Signed-off-by: Balram Choudhary --- ibm_db_sa/base.py | 128 +++++++++++++++++++++++++++++----------- ibm_db_sa/ibm_db.py | 2 +- ibm_db_sa/pyodbc.py | 2 +- ibm_db_sa/reflection.py | 1 + pyproject.toml | 2 +- setup.py | 4 +- 6 files changed, 98 insertions(+), 41 deletions(-) diff --git a/ibm_db_sa/base.py b/ibm_db_sa/base.py index 53f5879..53c9bd7 100644 --- a/ibm_db_sa/base.py +++ b/ibm_db_sa/base.py @@ -24,6 +24,8 @@ from sqlalchemy import types as sa_types from sqlalchemy import schema as sa_schema from sqlalchemy import util +from sqlalchemy import exc +from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql import compiler from sqlalchemy.sql import operators from sqlalchemy.engine import default @@ -396,42 +398,87 @@ def visit_mod_binary(self, binary, operator, **kw): return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) - def limit_clause(self, select, **kwargs): - limit = select._limit - offset = select._offset or 0 - - if limit is not None: - if offset > 0: - return f" LIMIT {limit} OFFSET {offset}" - else: - return f" LIMIT {limit}" - return "" + def literalBindsFlagFrom_kw(self, kw=None): + """Return True if literal_binds is requested in compile kwargs.""" + if not kw or not isinstance(kw, dict): + return False + if kw.get("literal_binds"): + return True + ck = kw.get("compile_kwargs") + if isinstance(ck, dict) and ck.get("literal_binds"): + return True + return False + + def limit_clause(self, select, **kw): + text = "" + limit_clause = select._limit_clause + offset_clause = select._offset_clause + literal_binds = self.literalBindsFlagFrom_kw(kw) - def visit_select(self, select, **kwargs): - limit, offset = select._limit, select._offset - sql_ori = compiler.SQLCompiler.visit_select(self, select, **kwargs) + def _render_clause(clause): + if clause is None: + return None + if select._simple_int_clause(clause): + return self.process(clause.render_literal_execute(), **kw) + if literal_binds: + if hasattr(clause, "render_literal_execute"): + try: + return self.process(clause.render_literal_execute(), **kw) + except Exception: + pass + try: + return self.process(clause, literal_binds=True, **kw) + except Exception: + pass + try: + if isinstance(clause, BindParameter): + val = getattr(clause, "value", None) + if val is not None: + if isinstance(val, str): + return f"'{val}'" + return str(val) + except Exception: + pass + try: + return self.process(clause, **kw) + except Exception as e: + raise exc.CompileError( + "dialect 'ibm_db_sa' cannot render LIMIT/OFFSET for this clause; " + "ensure the clause is a simple integer or is processable by the compiler." + ) from e + + limit_text = _render_clause(limit_clause) + if limit_text is not None: + text += " LIMIT %s" % limit_text + offset_text = _render_clause(offset_clause) + if offset_text is not None: + text += " OFFSET %s" % offset_text + return text - if ('LIMIT' in sql_ori.upper()) or ('FETCH FIRST' in sql_ori.upper()): + def visit_select(self, select, **kw): + sql_ori = compiler.SQLCompiler.visit_select(self, select, **kw) + if ("LIMIT" in sql_ori.upper()) or ("FETCH FIRST" in sql_ori.upper()): return sql_ori - - if limit is not None: - sql = re.sub(r'FETCH FIRST \d+ ROWS ONLY', '', sql_ori, flags=re.IGNORECASE).strip() - limit_offset_clause = self.limit_clause(select, **kwargs) - sql += limit_offset_clause - return sql - - if offset is not None: + limit_clause_obj = select._limit_clause + offset_clause_obj = select._offset_clause + if limit_clause_obj is not None: + limit_offset_clause = self.limit_clause(select, **kw) + if limit_offset_clause: + return sql_ori + limit_offset_clause + if offset_clause_obj is not None: __rownum = 'Z.__ROWNUM' - sql_split = re.split(r"[\s+]FROM ", sql_ori, 1) + sql_work = re.sub(r'FETCH FIRST \d+ ROWS ONLY', '', sql_ori, flags=re.IGNORECASE).strip() + sql_work = re.sub(r'\s+OFFSET\s+(?:\d+|__\[POSTCOMPILE_[^\]]+\]|:[A-Za-z0-9_]+|\?)\s*$', '', sql_work, + flags=re.IGNORECASE) + sql_split = re.split(r"[\s+]FROM ", sql_work, 1) + if len(sql_split) < 2: + return sql_ori sql_sec = " \nFROM %s " % (sql_split[1]) - dummyVal = "Z.__db2_" sql_pri = "" - sql_sel = "SELECT " if select._distinct: sql_sel = "SELECT DISTINCT " - sql_select_token = sql_split[0].split(",") i = 0 while i < len(sql_select_token): @@ -440,32 +487,41 @@ def visit_select(self, select, **kwargs): sql_pri = f'{sql_pri} {sql_select_token[i]},{sql_select_token[i + 1]},{sql_select_token[i + 2]},{sql_select_token[i + 3]} AS "{dummyVal}{i + 1}",' i += 4 continue - if sql_select_token[i].count(" AS ") == 1: temp_col_alias = sql_select_token[i].split(" AS ") sql_pri = f'{sql_pri} {sql_select_token[i]},' sql_sel = f'{sql_sel} {temp_col_alias[1]},' i += 1 continue - sql_pri = f'{sql_pri} {sql_select_token[i]} AS "{dummyVal}{i + 1}",' sql_sel = f'{sql_sel} "{dummyVal}{i + 1}",' i += 1 - sql_pri = sql_pri.rstrip(",") sql_pri = f"{sql_pri}{sql_sec}" sql_sel = sql_sel.rstrip(",") sql = f'{sql_sel}, ( ROW_NUMBER() OVER() ) AS "{__rownum}" FROM ( {sql_pri} ) AS M' sql = f'{sql_sel} FROM ( {sql} ) Z WHERE' - if offset != 0: - sql = f'{sql} "{__rownum}" > {offset}' - if offset != 0 and limit is not None: + def _process_clause_text(clause): + if clause is None: + return None + if select._simple_int_clause(clause): + return self.process(clause.render_literal_execute(), **kw) + else: + return self.process(clause, **kw) + + offset_text = _process_clause_text(offset_clause_obj) + limit_text = _process_clause_text(limit_clause_obj) + if offset_text is not None: + sql = f'{sql} "{__rownum}" > {offset_text}' + if offset_text is not None and limit_text is not None: sql = f'{sql} AND ' - if limit is not None: - sql = f'{sql} "{__rownum}" <= {offset + limit}' + if limit_text is not None: + if offset_text is not None: + sql = f'{sql} "{__rownum}" <= ({offset_text} + {limit_text})' + else: + sql = f'{sql} "{__rownum}" <= {limit_text}' return f"( {sql} )" - return sql_ori def visit_sequence(self, sequence, **kw): @@ -753,7 +809,7 @@ class DB2Dialect(default.DefaultDialect): supports_sane_multi_rowcount = True supports_native_decimal = False supports_native_boolean = False - supports_statement_cache = False + supports_statement_cache = True preexecute_sequences = False supports_alter = True supports_sequences = True diff --git a/ibm_db_sa/ibm_db.py b/ibm_db_sa/ibm_db.py index d2c4adf..7352cc2 100644 --- a/ibm_db_sa/ibm_db.py +++ b/ibm_db_sa/ibm_db.py @@ -99,7 +99,7 @@ def get_result_proxy(self): class DB2Dialect_ibm_db(DB2Dialect): driver = 'ibm_db_sa' supports_unicode_statements = True - supports_statement_cache = False + supports_statement_cache = True supports_sane_rowcount = True supports_sane_multi_rowcount = False supports_native_decimal = False diff --git a/ibm_db_sa/pyodbc.py b/ibm_db_sa/pyodbc.py index a77f5b4..4988dfd 100644 --- a/ibm_db_sa/pyodbc.py +++ b/ibm_db_sa/pyodbc.py @@ -30,7 +30,7 @@ class DB2Dialect_pyodbc(PyODBCConnector, DB2Dialect): supports_unicode_statements = True supports_char_length = True supports_native_decimal = False - supports_statement_cache = False + supports_statement_cache = True execution_ctx_cls = DB2ExecutionContext_pyodbc diff --git a/ibm_db_sa/reflection.py b/ibm_db_sa/reflection.py index 059e5de..396ab56 100644 --- a/ibm_db_sa/reflection.py +++ b/ibm_db_sa/reflection.py @@ -29,6 +29,7 @@ class CoerceUnicode(sa_types.TypeDecorator): impl = sa_types.Unicode + cache_ok = True def process_bind_param(self, value, dialect): if isinstance(value, str): diff --git a/pyproject.toml b/pyproject.toml index 1b68d94..20de75d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools>=42", "wheel"] +requires = ["setuptools>=42", "wheel", "packaging>=20.0"] build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.py b/setup.py index d44a02c..b954e07 100644 --- a/setup.py +++ b/setup.py @@ -11,9 +11,9 @@ readme = os.path.join(os.path.dirname(__file__), 'README.md') if 'USE_PYODBC' in os.environ and os.environ['USE_PYODBC'] == '1': - require = ['sqlalchemy>=0.7.3'] + require = ['sqlalchemy>=0.7.3', 'packaging>=20.0'] else: - require = ['sqlalchemy>=0.7.3','ibm_db>=2.0.0'] + require = ['sqlalchemy>=0.7.3','ibm_db>=2.0.0', 'packaging>=20.0'] setup(