Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 92 additions & 36 deletions ibm_db_sa/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from sqlalchemy import types as sa_types
from sqlalchemy import schema as sa_schema
from sqlalchemy import util
from sqlalchemy import exc
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql import compiler
from sqlalchemy.sql import operators
from sqlalchemy.engine import default
Expand Down Expand Up @@ -396,42 +398,87 @@ def visit_mod_binary(self, binary, operator, **kw):
return "mod(%s, %s)" % (self.process(binary.left),
self.process(binary.right))

def limit_clause(self, select, **kwargs):
limit = select._limit
offset = select._offset or 0

if limit is not None:
if offset > 0:
return f" LIMIT {limit} OFFSET {offset}"
else:
return f" LIMIT {limit}"
return ""
def literalBindsFlagFrom_kw(self, kw=None):
"""Return True if literal_binds is requested in compile kwargs."""
if not kw or not isinstance(kw, dict):
return False
if kw.get("literal_binds"):
return True
ck = kw.get("compile_kwargs")
if isinstance(ck, dict) and ck.get("literal_binds"):
return True
return False

def limit_clause(self, select, **kw):
text = ""
limit_clause = select._limit_clause
offset_clause = select._offset_clause
literal_binds = self.literalBindsFlagFrom_kw(kw)

def visit_select(self, select, **kwargs):
limit, offset = select._limit, select._offset
sql_ori = compiler.SQLCompiler.visit_select(self, select, **kwargs)
def _render_clause(clause):
if clause is None:
return None
if select._simple_int_clause(clause):
return self.process(clause.render_literal_execute(), **kw)
if literal_binds:
if hasattr(clause, "render_literal_execute"):
try:
return self.process(clause.render_literal_execute(), **kw)
except Exception:
pass
try:
return self.process(clause, literal_binds=True, **kw)
except Exception:
pass
try:
if isinstance(clause, BindParameter):
val = getattr(clause, "value", None)
if val is not None:
if isinstance(val, str):
return f"'{val}'"
return str(val)
except Exception:
pass
try:
return self.process(clause, **kw)
except Exception as e:
raise exc.CompileError(
"dialect 'ibm_db_sa' cannot render LIMIT/OFFSET for this clause; "
"ensure the clause is a simple integer or is processable by the compiler."
) from e

limit_text = _render_clause(limit_clause)
if limit_text is not None:
text += " LIMIT %s" % limit_text
offset_text = _render_clause(offset_clause)
if offset_text is not None:
text += " OFFSET %s" % offset_text
return text

if ('LIMIT' in sql_ori.upper()) or ('FETCH FIRST' in sql_ori.upper()):
def visit_select(self, select, **kw):
sql_ori = compiler.SQLCompiler.visit_select(self, select, **kw)
if ("LIMIT" in sql_ori.upper()) or ("FETCH FIRST" in sql_ori.upper()):
return sql_ori

if limit is not None:
sql = re.sub(r'FETCH FIRST \d+ ROWS ONLY', '', sql_ori, flags=re.IGNORECASE).strip()
limit_offset_clause = self.limit_clause(select, **kwargs)
sql += limit_offset_clause
return sql

if offset is not None:
limit_clause_obj = select._limit_clause
offset_clause_obj = select._offset_clause
if limit_clause_obj is not None:
limit_offset_clause = self.limit_clause(select, **kw)
if limit_offset_clause:
return sql_ori + limit_offset_clause
if offset_clause_obj is not None:
__rownum = 'Z.__ROWNUM'
sql_split = re.split(r"[\s+]FROM ", sql_ori, 1)
sql_work = re.sub(r'FETCH FIRST \d+ ROWS ONLY', '', sql_ori, flags=re.IGNORECASE).strip()
sql_work = re.sub(r'\s+OFFSET\s+(?:\d+|__\[POSTCOMPILE_[^\]]+\]|:[A-Za-z0-9_]+|\?)\s*$', '', sql_work,
flags=re.IGNORECASE)
sql_split = re.split(r"[\s+]FROM ", sql_work, 1)
if len(sql_split) < 2:
return sql_ori
sql_sec = " \nFROM %s " % (sql_split[1])

dummyVal = "Z.__db2_"
sql_pri = ""

sql_sel = "SELECT "
if select._distinct:
sql_sel = "SELECT DISTINCT "

sql_select_token = sql_split[0].split(",")
i = 0
while i < len(sql_select_token):
Expand All @@ -440,32 +487,41 @@ def visit_select(self, select, **kwargs):
sql_pri = f'{sql_pri} {sql_select_token[i]},{sql_select_token[i + 1]},{sql_select_token[i + 2]},{sql_select_token[i + 3]} AS "{dummyVal}{i + 1}",'
i += 4
continue

if sql_select_token[i].count(" AS ") == 1:
temp_col_alias = sql_select_token[i].split(" AS ")
sql_pri = f'{sql_pri} {sql_select_token[i]},'
sql_sel = f'{sql_sel} {temp_col_alias[1]},'
i += 1
continue

sql_pri = f'{sql_pri} {sql_select_token[i]} AS "{dummyVal}{i + 1}",'
sql_sel = f'{sql_sel} "{dummyVal}{i + 1}",'
i += 1

sql_pri = sql_pri.rstrip(",")
sql_pri = f"{sql_pri}{sql_sec}"
sql_sel = sql_sel.rstrip(",")
sql = f'{sql_sel}, ( ROW_NUMBER() OVER() ) AS "{__rownum}" FROM ( {sql_pri} ) AS M'
sql = f'{sql_sel} FROM ( {sql} ) Z WHERE'

if offset != 0:
sql = f'{sql} "{__rownum}" > {offset}'
if offset != 0 and limit is not None:
def _process_clause_text(clause):
if clause is None:
return None
if select._simple_int_clause(clause):
return self.process(clause.render_literal_execute(), **kw)
else:
return self.process(clause, **kw)

offset_text = _process_clause_text(offset_clause_obj)
limit_text = _process_clause_text(limit_clause_obj)
if offset_text is not None:
sql = f'{sql} "{__rownum}" > {offset_text}'
if offset_text is not None and limit_text is not None:
sql = f'{sql} AND '
if limit is not None:
sql = f'{sql} "{__rownum}" <= {offset + limit}'
if limit_text is not None:
if offset_text is not None:
sql = f'{sql} "{__rownum}" <= ({offset_text} + {limit_text})'
else:
sql = f'{sql} "{__rownum}" <= {limit_text}'
return f"( {sql} )"

return sql_ori

def visit_sequence(self, sequence, **kw):
Expand Down Expand Up @@ -753,7 +809,7 @@ class DB2Dialect(default.DefaultDialect):
supports_sane_multi_rowcount = True
supports_native_decimal = False
supports_native_boolean = False
supports_statement_cache = False
supports_statement_cache = True
preexecute_sequences = False
supports_alter = True
supports_sequences = True
Expand Down
2 changes: 1 addition & 1 deletion ibm_db_sa/ibm_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_result_proxy(self):
class DB2Dialect_ibm_db(DB2Dialect):
driver = 'ibm_db_sa'
supports_unicode_statements = True
supports_statement_cache = False
supports_statement_cache = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_native_decimal = False
Expand Down
2 changes: 1 addition & 1 deletion ibm_db_sa/pyodbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DB2Dialect_pyodbc(PyODBCConnector, DB2Dialect):
supports_unicode_statements = True
supports_char_length = True
supports_native_decimal = False
supports_statement_cache = False
supports_statement_cache = True

execution_ctx_cls = DB2ExecutionContext_pyodbc

Expand Down
1 change: 1 addition & 0 deletions ibm_db_sa/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

class CoerceUnicode(sa_types.TypeDecorator):
impl = sa_types.Unicode
cache_ok = True

def process_bind_param(self, value, dialect):
if isinstance(value, str):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
requires = ["setuptools>=42", "wheel", "packaging>=20.0"]
build-backend = "setuptools.build_meta"
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

readme = os.path.join(os.path.dirname(__file__), 'README.md')
if 'USE_PYODBC' in os.environ and os.environ['USE_PYODBC'] == '1':
require = ['sqlalchemy>=0.7.3']
require = ['sqlalchemy>=0.7.3', 'packaging>=20.0']
else:
require = ['sqlalchemy>=0.7.3','ibm_db>=2.0.0']
require = ['sqlalchemy>=0.7.3','ibm_db>=2.0.0', 'packaging>=20.0']


setup(
Expand Down