diff --git a/pyproject.toml b/pyproject.toml index d1f590c3..30c372e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataops-testgen" -version = "4.16.3" +version = "4.20.4" description = "DataKitchen's Data Quality DataOps TestGen" authors = [ { "name" = "DataKitchen, Inc.", "email" = "info@datakitchen.io" }, diff --git a/testgen/__main__.py b/testgen/__main__.py index 74541d76..6e0d8a9c 100644 --- a/testgen/__main__.py +++ b/testgen/__main__.py @@ -41,8 +41,10 @@ get_tg_schema, version_service, ) +from testgen.common.models import with_database_session +from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.test_run import TestRun from testgen.scheduler import register_scheduler_job, run_scheduler -from testgen.ui.queries import profiling_run_queries, test_run_queries from testgen.utils import plugins LOG = logging.getLogger("testgen") @@ -72,9 +74,9 @@ def invoke(self, ctx: Context): cls=CliGroup, help=f""" {VERSION_DATA.edition} {VERSION_DATA.current or ""} - + {f"New version available! {VERSION_DATA.latest}" if VERSION_DATA.latest != VERSION_DATA.current else ""} - + Schema revision: {get_schema_revision()} """ ) @@ -625,11 +627,16 @@ def run_ui(): use_ssl = os.path.isfile(settings.SSL_CERT_FILE) and os.path.isfile(settings.SSL_KEY_FILE) patch_streamlit.patch(force=True) - try: - profiling_run_queries.cancel_all_running() - test_run_queries.cancel_all_running() - except Exception: - LOG.warning("Failed to cancel 'Running' profiling/test runs") + + @with_database_session + def cancel_all_running(): + try: + ProfilingRun.cancel_all_running() + TestRun.cancel_all_running() + except Exception: + LOG.warning("Failed to cancel 'Running' profiling/test runs") + + cancel_all_running() try: app_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui/app.py") diff --git a/testgen/commands/queries/execute_cat_tests_query.py b/testgen/commands/queries/execute_cat_tests_query.py index f3c94ff8..0a22ee8c 100644 --- a/testgen/commands/queries/execute_cat_tests_query.py +++ b/testgen/commands/queries/execute_cat_tests_query.py @@ -1,11 +1,19 @@ -import typing +from typing import ClassVar, TypedDict from testgen.commands.queries.rollup_scores_query import CRollupScoresSQL from testgen.common import date_service, read_template_sql_file -from testgen.common.database import database_service +from testgen.common.database.database_service import get_flavor_service, replace_params from testgen.common.read_file import replace_templated_functions +class CATTestParams(TypedDict): + schema_name: str + table_name: str + cat_sequence: int + test_measures: str + test_conditions: str + + class CCATExecutionSQL: project_code = "" flavor = "" @@ -16,11 +24,9 @@ class CCATExecutionSQL: table_groups_id = "" max_query_chars = "" exception_message = "" - - # Test Set Parameters target_schema = "" target_table = "" - dctTestParms: typing.ClassVar = {} + cat_test_params: ClassVar[CATTestParams] = {} _rollup_scores_sql: CRollupScoresSQL = None @@ -29,7 +35,7 @@ def __init__(self, strProjectCode, strTestSuiteId, strTestSuite, strSQLFlavor, m self.test_suite_id = strTestSuiteId self.test_suite = strTestSuite self.project_code = strProjectCode - flavor_service = database_service.get_flavor_service(strSQLFlavor) + flavor_service = get_flavor_service(strSQLFlavor) self.concat_operator = flavor_service.get_concat_operator() self.flavor = strSQLFlavor self.max_query_chars = max_query_chars @@ -41,83 +47,78 @@ 
def _get_rollup_scores_sql(self) -> CRollupScoresSQL: self._rollup_scores_sql = CRollupScoresSQL(self.test_run_id, self.table_groups_id) return self._rollup_scores_sql - - def _ReplaceParms(self, strInputString): - strInputString = strInputString.replace("{MAX_QUERY_CHARS}", str(self.max_query_chars)) - strInputString = strInputString.replace("{TEST_RUN_ID}", self.test_run_id) - strInputString = strInputString.replace("{PROJECT_CODE}", self.project_code) - strInputString = strInputString.replace("{TEST_SUITE}", self.test_suite) - strInputString = strInputString.replace("{TEST_SUITE_ID}", self.test_suite_id) - strInputString = strInputString.replace("{TABLE_GROUPS_ID}", self.table_groups_id) - - strInputString = strInputString.replace("{SQL_FLAVOR}", self.flavor) - strInputString = strInputString.replace("{ID_SEPARATOR}", "`" if self.flavor == "databricks" else '"') - strInputString = strInputString.replace("{CONCAT_OPERATOR}", self.concat_operator) - - strInputString = strInputString.replace("{SCHEMA_NAME}", self.target_schema) - strInputString = strInputString.replace("{TABLE_NAME}", self.target_table) - - strInputString = strInputString.replace("{RUN_DATE}", self.run_date) - strInputString = strInputString.replace("{NOW_DATE}", "GETDATE()") - strInputString = strInputString.replace("{START_TIME}", self.today) - strInputString = strInputString.replace( - "{NOW}", date_service.get_now_as_string_with_offset(self.minutes_offset) - ) - strInputString = strInputString.replace("{EXCEPTION_MESSAGE}", self.exception_message.strip()) - - for parm, value in self.dctTestParms.items(): - strInputString = strInputString.replace("{" + parm.upper() + "}", str(value)) - - strInputString = strInputString.replace("{RUN_DATE}", self.run_date) - - strInputString = replace_templated_functions(strInputString, self.flavor) - - if self.flavor != "databricks": + + def _get_query(self, template_file_name: str, sub_directory: str | None = "exec_cat_tests", no_bind: bool = False) -> tuple[str, dict | None]: + query = read_template_sql_file(template_file_name, sub_directory) + params = { + "MAX_QUERY_CHARS": self.max_query_chars, + "TEST_RUN_ID": self.test_run_id, + "PROJECT_CODE": self.project_code, + "TEST_SUITE": self.test_suite, + "TEST_SUITE_ID": self.test_suite_id, + "TABLE_GROUPS_ID": self.table_groups_id, + "SQL_FLAVOR": self.flavor, + "ID_SEPARATOR": "`" if self.flavor == "databricks" else '"', + "CONCAT_OPERATOR": self.concat_operator, + "SCHEMA_NAME": self.target_schema, + "TABLE_NAME": self.target_table, + "NOW_DATE": "GETDATE()", + "START_TIME": self.today, + "NOW_TIMESTAMP": date_service.get_now_as_string_with_offset(self.minutes_offset), + "EXCEPTION_MESSAGE": self.exception_message.strip(), + **{key.upper(): value for key, value in self.cat_test_params.items()}, + # This has to be replaced at the end + "RUN_DATE": self.run_date, + } + query = replace_params(query, params) + query = replace_templated_functions(query, self.flavor) + + if no_bind and self.flavor != "databricks": # Adding escape character where ':' is referenced - strInputString = strInputString.replace(":", "\\:") + query = query.replace(":", "\\:") - return strInputString + return query, None if no_bind else params - def GetDistinctTablesSQL(self): - # Runs on DK DB - strQ = self._ReplaceParms(read_template_sql_file("ex_cat_get_distinct_tables.sql", "exec_cat_tests")) - return strQ + def GetDistinctTablesSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_cat_get_distinct_tables.sql") - def 
GetAggregateTableTestSQL(self): - # Runs on DK DB - strQ = self._ReplaceParms(read_template_sql_file("ex_cat_build_agg_table_tests.sql", "exec_cat_tests")) - return strQ + def GetAggregateTableTestSQL(self) -> tuple[str, None]: + # Runs on App database + return self._get_query("ex_cat_build_agg_table_tests.sql", no_bind=True) - def GetAggregateTestParmsSQL(self): - # Runs on DK DB - strQ = self._ReplaceParms(read_template_sql_file("ex_cat_retrieve_agg_test_parms.sql", "exec_cat_tests")) - return strQ + def GetAggregateTestParmsSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_cat_retrieve_agg_test_parms.sql") - def PrepCATQuerySQL(self): - strQ = self._ReplaceParms(read_template_sql_file("ex_cat_test_query.sql", "exec_cat_tests")) - return strQ + def PrepCATQuerySQL(self) -> tuple[str, None]: + # Runs on Target database + return self._get_query("ex_cat_test_query.sql", no_bind=True) - def GetCATResultsParseSQL(self): - strQ = self._ReplaceParms(read_template_sql_file("ex_cat_results_parse.sql", "exec_cat_tests")) - return strQ + def GetCATResultsParseSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_cat_results_parse.sql") - def FinalizeTestResultsSQL(self): - strQ = self._ReplaceParms(read_template_sql_file("ex_finalize_test_run_results.sql", "execution")) - return strQ + def FinalizeTestResultsSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_finalize_test_run_results.sql", "execution") - def PushTestRunStatusUpdateSQL(self): - strQ = self._ReplaceParms(read_template_sql_file("ex_update_test_record_in_testrun_table.sql", "execution")) - return strQ + def PushTestRunStatusUpdateSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_update_test_record_in_testrun_table.sql", "execution") - def FinalizeTestSuiteUpdateSQL(self): - strQ = self._ReplaceParms(read_template_sql_file("ex_update_test_suite.sql", "execution")) - return strQ + def FinalizeTestSuiteUpdateSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_update_test_suite.sql", "execution") - def CalcPrevalenceTestResultsSQL(self): - return self._ReplaceParms(read_template_sql_file("ex_calc_prevalence_test_results.sql", "execution")) + def CalcPrevalenceTestResultsSQL(self) -> tuple[str, None]: + # Runs on App database + return self._get_query("ex_calc_prevalence_test_results.sql", "execution", no_bind=True) - def TestScoringRollupRunSQL(self): + def TestScoringRollupRunSQL(self) -> tuple[str, dict]: + # Runs on App database return self._get_rollup_scores_sql().GetRollupScoresTestRunQuery() - def TestScoringRollupTableGroupSQL(self): + def TestScoringRollupTableGroupSQL(self) -> tuple[str, dict]: + # Runs on App database return self._get_rollup_scores_sql().GetRollupScoresTestTableGroupQuery() diff --git a/testgen/commands/queries/execute_tests_query.py b/testgen/commands/queries/execute_tests_query.py index 00b9d4d4..23a3e492 100644 --- a/testgen/commands/queries/execute_tests_query.py +++ b/testgen/commands/queries/execute_tests_query.py @@ -1,6 +1,41 @@ -import typing +from typing import ClassVar, TypedDict from testgen.common import AddQuotesToIdentifierCSV, CleanSQL, ConcatColumnList, date_service, read_template_sql_file +from testgen.common.database.database_service import replace_params + + +class TestParams(TypedDict): + test_type: str + test_definition_id: str + test_description: str + test_action: str + schema_name: str + table_name: str + column_name: str + 
skip_errors: str + baseline_ct: str + baseline_unique_ct: str + baseline_value: str + baseline_value_ct: str + threshold_value: str + baseline_sum: str + baseline_avg: str + baseline_sd: str + lower_tolerance: str + upper_tolerance: str + subset_condition: str + groupby_names: str + having_condition: str + window_date_column: str + window_days: str + match_schema_name: str + match_table_name: str + match_column_names: str + match_subset_condition: str + match_groupby_names: str + match_having_condition: str + custom_query: str + template_name: str class CTestExecutionSQL: @@ -12,12 +47,9 @@ class CTestExecutionSQL: test_run_id = "" exception_message = "" process_id = "" + test_params: ClassVar[TestParams] = {} - # Test Group Parameters - dctTestParms: typing.ClassVar = {} - sum_columns = "" - match_sum_columns = "" - multi_column_error_condition = "" + _use_clean = False def __init__(self, strProjectCode, strFlavor, strTestSuiteId, strTestSuite, minutes_offset=0): self.project_code = strProjectCode @@ -27,9 +59,8 @@ def __init__(self, strProjectCode, strFlavor, strTestSuiteId, strTestSuite, minu self.today = date_service.get_now_as_string_with_offset(minutes_offset) self.minutes_offset = minutes_offset - def _AssembleDisplayParameters(self): - - lst_parms = [ + def _get_input_parameters(self): + param_keys = [ "column_name", "skip_errors", "baseline_ct", @@ -53,125 +84,83 @@ def _AssembleDisplayParameters(self): "match_groupby_names", "match_having_condition", ] - str_parms = "; ".join(f"{key}={self.dctTestParms[key]}" - for key in lst_parms - if key.lower() in self.dctTestParms and self.dctTestParms[key] not in [None, ""]) - str_parms = str_parms.replace("'", "`") - return str_parms - - def _ReplaceParms(self, strInputString: str): - strInputString = strInputString.replace("{PROJECT_CODE}", self.project_code) - strInputString = strInputString.replace("{TEST_SUITE_ID}", self.test_suite_id) - strInputString = strInputString.replace("{TEST_SUITE}", self.test_suite) - strInputString = strInputString.replace("{SQL_FLAVOR}", self.flavor) - strInputString = strInputString.replace("{TEST_RUN_ID}", self.test_run_id) - strInputString = strInputString.replace("{INPUT_PARAMETERS}", self._AssembleDisplayParameters()) - - strInputString = strInputString.replace("{RUN_DATE}", self.run_date) - strInputString = strInputString.replace("{EXCEPTION_MESSAGE}", self.exception_message) - strInputString = strInputString.replace("{START_TIME}", self.today) - strInputString = strInputString.replace("{PROCESS_ID}", str(self.process_id)) - strInputString = strInputString.replace("{VARCHAR_TYPE}", "STRING" if self.flavor == "databricks" else "VARCHAR") - strInputString = strInputString.replace( - "{NOW}", date_service.get_now_as_string_with_offset(self.minutes_offset) + input_parameters = "; ".join( + f"{key}={self.test_params[key]}" + for key in param_keys + if key.lower() in self.test_params and self.test_params[key] not in [None, ""] ) - - column_designators = [ - "COLUMN_NAME", - # "COLUMN_NAMES", - # "COL_NAME", - # "COL_NAMES", - # "MATCH_COLUMN_NAMES", - # "MATCH_GROUPBY_NAMES", - # "MATCH_SUM_COLUMNS", - ] - - for parm, value in self.dctTestParms.items(): - if value: - if parm.upper() in column_designators: - strInputString = strInputString.replace("{" + parm.upper() + "}", AddQuotesToIdentifierCSV(value)) - else: - strInputString = strInputString.replace("{" + parm.upper() + "}", value) - else: - strInputString = strInputString.replace("{" + parm.upper() + "}", "") - if parm == "column_name": - # Shows 
contents without double-quotes for display and aggregate expressions - strInputString = strInputString.replace("{COLUMN_NAME_NO_QUOTES}", value if value else "") - # Concatenates column list into single expression for relative entropy - str_value = ConcatColumnList(value, "") - strInputString = strInputString.replace("{CONCAT_COLUMNS}", str_value if str_value else "") - if parm == "match_groupby_names": - # Concatenates column list into single expression for relative entropy - str_value = ConcatColumnList(value, "") - strInputString = strInputString.replace("{CONCAT_MATCH_GROUPBY}", str_value if str_value else "") - if parm == "subset_condition": - strInputString = strInputString.replace("{SUBSET_DISPLAY}", value.replace("'", "''") if value else "") - - if self.flavor != "databricks": + return input_parameters.replace("'", "`") + + def _get_query( + self, template_file_name: str, sub_directory: str | None = "execution", no_bind: bool = False + ) -> tuple[str, dict | None]: + query = read_template_sql_file(template_file_name, sub_directory) + params = { + "PROJECT_CODE": self.project_code, + "TEST_SUITE_ID": self.test_suite_id, + "TEST_SUITE": self.test_suite, + "SQL_FLAVOR": self.flavor, + "TEST_RUN_ID": self.test_run_id, + "INPUT_PARAMETERS": self._get_input_parameters(), + "RUN_DATE": self.run_date, + "EXCEPTION_MESSAGE": self.exception_message, + "START_TIME": self.today, + "PROCESS_ID": self.process_id, + "VARCHAR_TYPE": "STRING" if self.flavor == "databricks" else "VARCHAR", + "NOW_TIMESTAMP": date_service.get_now_as_string_with_offset(self.minutes_offset), + **{key.upper(): value or "" for key, value in self.test_params.items()}, + } + + if self.test_params: + column_name = self.test_params["column_name"] + params["COLUMN_NAME"] = AddQuotesToIdentifierCSV(column_name) if column_name else "" + # Shows contents without double-quotes for display and aggregate expressions + params["COLUMN_NAME_NO_QUOTES"] = column_name or "" + # Concatenates column list into single expression for relative entropy + params["CONCAT_COLUMNS"] = ConcatColumnList(column_name, "") if column_name else "" + + match_groupby_names = self.test_params["match_groupby_names"] + # Concatenates column list into single expression for relative entropy + params["CONCAT_MATCH_GROUPBY"] = ( + ConcatColumnList(match_groupby_names, "") if match_groupby_names else "" + ) + + subset_condition = self.test_params["subset_condition"] + params["SUBSET_DISPLAY"] = subset_condition.replace("'", "''") if subset_condition else "" + + query = replace_params(query, params) + + if no_bind and self.flavor != "databricks": # Adding escape character where ':' is referenced - strInputString = strInputString.replace(":", "\\:") - - return strInputString - - def ClearTestParms(self): - # Test Set Parameters - pass - - def GetTestsNonCAT(self, booClean): - # Runs on DK DB - strQ = self._ReplaceParms(read_template_sql_file("ex_get_tests_non_cat.sql", "execution")) - if booClean: - strQ = CleanSQL(strQ) - - return strQ - - def PushTestRunStatusUpdateSQL(self): - # Runs on DK DB - strQ = self._ReplaceParms(read_template_sql_file("ex_update_test_record_in_testrun_table.sql", "execution")) - - return strQ - - def _GetTestQueryFromTemplate(self, strTemplateFile: str): - # Runs on Project DB - if strTemplateFile.endswith("_generic.sql"): - template_flavor = "generic" + query = query.replace(":", "\\:") + + return query, None if no_bind else params + + def GetTestsNonCAT(self) -> tuple[str, dict]: + # Runs on App database + query, params = 
self._get_query("ex_get_tests_non_cat.sql") + if self._use_clean: + query = CleanSQL(query) + return query, params + + def AddTestRecordtoTestRunTable(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_write_test_record_to_testrun_table.sql") + + def PushTestRunStatusUpdateSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_update_test_record_in_testrun_table.sql") + + def GetTestQuery(self) -> tuple[str, None]: + # Runs on Target database + if template_name := self.test_params["template_name"]: + template_flavor = "generic" if template_name.endswith("_generic.sql") else self.flavor + query, params = self._get_query(template_name, f"flavors/{template_flavor}/exec_query_tests", no_bind=True) + # Final replace to cover parm within CUSTOM_QUERY parm + query = replace_params(query, {"DATA_SCHEMA": self.test_params["schema_name"]}) + + if self._use_clean: + query = CleanSQL(query) + return query, params else: - template_flavor = self.flavor - strQ = self._ReplaceParms( - read_template_sql_file(strTemplateFile, f"flavors/{template_flavor}/exec_query_tests") - ) - return strQ - - def _ConstructAggregateMatchParms(self): - # Prepares column list for SQL to compare sums of each column - - # Split each comma separated column name into individual list items - cols = [s.strip() for s in self.dctTestParms["column_name"].split(",")] - _ = [s.strip() for s in self.dctTestParms["match_column_names"].split(",")] - - # Surround all column names with SUM() to generate proper SQL syntax - self.list_sum_columns = ["SUM(" + i + ") as " + i for i in cols] - self.sum_columns = ", ".join(self.list_sum_columns) - - self.list_match_sum_columns = ["SUM(" + i + ") as " + i for i in cols] - self.match_sum_columns = ", ".join(self.list_match_sum_columns) - - # Suffix all column names with '< 0' to generate proper SQL WHERE/HAVING clause syntax - self.list_multi_column_error_condition = [i + " < 0" for i in cols] - self.multi_column_error_condition = " or ".join(self.list_multi_column_error_condition) - - - def GetTestQuery(self, booClean: bool): - strTestType = self.dctTestParms["test_type"] - strTemplate = self.dctTestParms["template_name"] - - if strTemplate == "": - raise ValueError(f"No query template assigned to test_type {strTestType}") - - strQ = self._GetTestQueryFromTemplate(strTemplate) - # Final replace to cover parm within CUSTOM_QUERY parm - strQ = strQ.replace("{DATA_SCHEMA}", self.dctTestParms["schema_name"]) - - if booClean: - strQ = CleanSQL(strQ) - return strQ + raise ValueError(f"No query template assigned to test_type {self.test_params["test_type"]}") diff --git a/testgen/commands/queries/generate_tests_query.py b/testgen/commands/queries/generate_tests_query.py index 460ff73c..5f0b1ce2 100644 --- a/testgen/commands/queries/generate_tests_query.py +++ b/testgen/commands/queries/generate_tests_query.py @@ -1,10 +1,17 @@ import logging -import typing +from typing import ClassVar, TypedDict -from testgen.common import CleanSQL, date_service, get_template_files, read_template_sql_file +from testgen.common import CleanSQL, date_service, read_template_sql_file +from testgen.common.database.database_service import get_queries_for_command, replace_params LOG = logging.getLogger("testgen") +class GenTestParams(TypedDict): + test_type: str + selection_criteria: str + default_parm_columns: str + default_parm_values: str + class CDeriveTestsSQL: run_date = "" @@ -17,79 +24,59 @@ class CDeriveTestsSQL: generation_set = "" as_of_date = "" sql_flavor = 
"" - dctTestParms: typing.ClassVar = {} + gen_test_params: ClassVar[GenTestParams] = {} + + _use_clean = False def __init__(self): today = date_service.get_now_as_string() self.run_date = today self.as_of_date = today - self.dctTestParms = {} - - def ClearTestParms(self): - # Test Set Parameters - self.dctTestParms = {} - - def ReplaceParms(self, strInputString): - for parm, value in self.dctTestParms.items(): - strInputString = strInputString.replace("{" + parm.upper() + "}", value) - - strInputString = strInputString.replace("{PROJECT_CODE}", self.project_code) - strInputString = strInputString.replace("{SQL_FLAVOR}", self.sql_flavor) - strInputString = strInputString.replace("{CONNECTION_ID}", self.connection_id) - strInputString = strInputString.replace("{TABLE_GROUPS_ID}", self.table_groups_id) - strInputString = strInputString.replace("{RUN_DATE}", self.run_date) - strInputString = strInputString.replace("{TEST_SUITE}", self.test_suite) - strInputString = strInputString.replace("{TEST_SUITE_ID}", self.test_suite_id) - strInputString = strInputString.replace("{GENERATION_SET}", self.generation_set) - strInputString = strInputString.replace("{AS_OF_DATE}", self.as_of_date) - strInputString = strInputString.replace("{DATA_SCHEMA}", self.data_schema) - strInputString = strInputString.replace("{ID_SEPARATOR}", "`" if self.sql_flavor == "databricks" else '"') - - return strInputString - - def GetInsertTestSuiteSQL(self, booClean): - strQuery = self.ReplaceParms(read_template_sql_file("gen_insert_test_suite.sql", "generation")) - if booClean: - strQuery = CleanSQL(strQuery) - - return strQuery - - def GetTestTypesSQL(self, booClean): - strQuery = self.ReplaceParms(read_template_sql_file("gen_standard_test_type_list.sql", "generation")) - if booClean: - strQuery = CleanSQL(strQuery) - - return strQuery - - def GetTestDerivationQueriesAsList(self, template_directory, booClean): - # This assumes the queries run in no particular order, - # and will order them alphabetically by file name - lstQueries = sorted( - get_template_files(mask=r"^.*sql$", sub_directory=template_directory), key=lambda key: str(key) - ) - lstTemplate = [] - - for script in lstQueries: - query = script.read_text("utf-8") - template = self.ReplaceParms(query) - lstTemplate.append(template) - - if booClean: - lstTemplate = [CleanSQL(q) for q in lstTemplate] - - if len(lstQueries) == 0: - LOG.warning("No funny CAT test generation templates were found") - - return lstTemplate - - def GetTestQueriesFromGenericFile(self, booClean: bool): - strQuery = self.ReplaceParms(read_template_sql_file("gen_standard_tests.sql", "generation")) - if booClean: - strQuery = CleanSQL(strQuery) - return strQuery - def GetDeleteOldTestsQuery(self, booClean: bool): - strQuery = self.ReplaceParms(read_template_sql_file("gen_delete_old_tests.sql", "generation")) - if booClean: - strQuery = CleanSQL(strQuery) - return strQuery + def _get_params(self) -> dict: + return { + **{key.upper(): value for key, value in self.gen_test_params.items()}, + "PROJECT_CODE": self.project_code, + "SQL_FLAVOR": self.sql_flavor, + "CONNECTION_ID": self.connection_id, + "TABLE_GROUPS_ID": self.table_groups_id, + "RUN_DATE": self.run_date, + "TEST_SUITE": self.test_suite, + "TEST_SUITE_ID": self.test_suite_id, + "GENERATION_SET": self.generation_set, + "AS_OF_DATE": self.as_of_date, + "DATA_SCHEMA": self.data_schema, + "ID_SEPARATOR": "`" if self.sql_flavor == "databricks" else '"', + } + + def _get_query(self, template_file_name: str, sub_directory: str | None = 
"generation") -> tuple[str, dict]: + query = read_template_sql_file(template_file_name, sub_directory) + params = self._get_params() + query = replace_params(query, params) + if self._use_clean: + query = CleanSQL(query) + return query, params + + def GetInsertTestSuiteSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("gen_insert_test_suite.sql") + + def GetTestTypesSQL(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("gen_standard_test_type_list.sql") + + def GetTestDerivationQueriesAsList(self, template_directory: str) -> list[tuple[str, dict]]: + # Runs on App database + params = self._get_params() + queries = get_queries_for_command(template_directory, params) + if self._use_clean: + queries = [ CleanSQL(query) for query in queries ] + return [ (query, params) for query in queries ] + + def GetTestQueriesFromGenericFile(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("gen_standard_tests.sql") + + def GetDeleteOldTestsQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("gen_delete_old_tests.sql") diff --git a/testgen/commands/queries/profiling_query.py b/testgen/commands/queries/profiling_query.py index 155cc98a..d71334dc 100644 --- a/testgen/commands/queries/profiling_query.py +++ b/testgen/commands/queries/profiling_query.py @@ -3,12 +3,11 @@ from testgen.commands.queries.refresh_data_chars_query import CRefreshDataCharsSQL from testgen.commands.queries.rollup_scores_query import CRollupScoresSQL from testgen.common import date_service, read_template_sql_file, read_template_yaml_file +from testgen.common.database.database_service import replace_params from testgen.common.read_file import replace_templated_functions class CProfilingSQL: - template_path = "" - dctTemplates: typing.ClassVar = {} dctSnippetTemplate: typing.ClassVar = {} project_code = "" @@ -23,8 +22,6 @@ class CProfilingSQL: col_gen_type = "" col_type = "" col_ordinal_position = "0" - - col_max_char_length = 0 col_is_decimal = "" col_top_freq_update = "" @@ -32,12 +29,10 @@ class CProfilingSQL: parm_table_include_mask = None parm_table_exclude_mask = None parm_do_patterns = "Y" - parm_max_pattern_length = 30 + parm_max_pattern_length = 25 parm_do_freqs = "Y" - parm_max_freq_length = 30 - parm_vldb_flag = "N" parm_do_sample = "N" - parm_sample_size = "" + parm_sample_size = 0 profile_run_id = "" profile_id_column_mask = "" profile_sk_column_mask = "" @@ -48,6 +43,7 @@ class CProfilingSQL: sampling_table = "" sample_ratio = "" + sample_percent_calc = "" process_id = None @@ -64,14 +60,6 @@ def __init__(self, strProjectCode, flavor): self.project_code = strProjectCode # Defaults self.run_date = date_service.get_now_as_string() - self.col_ordinal_position = "0" - self.col_max_char_length = 0 - self.parm_do_patterns = "Y" - self.parm_max_pattern_length = 25 - self.parm_do_freqs = "Y" - self.parm_max_freq_length = 25 - self.parm_vldb_flag = "N" - self.parm_do_sample = "N" self.today = date_service.get_now_as_string() def _get_data_chars_sql(self) -> CRefreshDataCharsSQL: @@ -96,163 +84,165 @@ def _get_rollup_scores_sql(self) -> CRollupScoresSQL: return self._rollup_scores_sql - def ReplaceParms(self, strInputString): - strInputString = strInputString.replace("{PROJECT_CODE}", self.project_code) - strInputString = strInputString.replace("{CONNECTION_ID}", self.connection_id) - strInputString = strInputString.replace("{TABLE_GROUPS_ID}", self.table_groups_id) - strInputString = 
strInputString.replace("{RUN_DATE}", self.run_date) - strInputString = strInputString.replace("{DATA_SCHEMA}", self.data_schema) - strInputString = strInputString.replace("{DATA_TABLE}", self.data_table) - strInputString = strInputString.replace("{COL_NAME}", self.col_name) - strInputString = strInputString.replace("{COL_NAME_SANITIZED}", self.col_name.replace("'", "''")) - strInputString = strInputString.replace("{COL_GEN_TYPE}", self.col_gen_type) - strInputString = strInputString.replace("{COL_TYPE}", self.col_type or "") - strInputString = strInputString.replace("{COL_POS}", str(self.col_ordinal_position)) - strInputString = strInputString.replace("{TOP_FREQ}", self.col_top_freq_update) - strInputString = strInputString.replace("{PROFILE_RUN_ID}", self.profile_run_id) - strInputString = strInputString.replace("{PROFILE_ID_COLUMN_MASK}", self.profile_id_column_mask) - strInputString = strInputString.replace("{PROFILE_SK_COLUMN_MASK}", self.profile_sk_column_mask) - strInputString = strInputString.replace("{START_TIME}", self.today) - strInputString = strInputString.replace("{NOW}", date_service.get_now_as_string()) - strInputString = strInputString.replace("{EXCEPTION_MESSAGE}", self.exception_message) - strInputString = strInputString.replace("{SAMPLING_TABLE}", self.sampling_table) - strInputString = strInputString.replace("{SAMPLE_SIZE}", str(self.parm_sample_size)) - strInputString = strInputString.replace("{PROFILE_USE_SAMPLING}", self.profile_use_sampling) - strInputString = strInputString.replace("{PROFILE_SAMPLE_PERCENT}", self.profile_sample_percent) - strInputString = strInputString.replace("{PROFILE_SAMPLE_MIN_COUNT}", str(self.profile_sample_min_count)) - strInputString = strInputString.replace("{PROFILE_SAMPLE_RATIO}", str(self.sample_ratio)) - strInputString = strInputString.replace("{PARM_MAX_PATTERN_LENGTH}", str(self.parm_max_pattern_length)) - strInputString = strInputString.replace("{CONTINGENCY_COLUMNS}", self.contingency_columns) - strInputString = strInputString.replace("{CONTINGENCY_MAX_VALUES}", self.contingency_max_values) - strInputString = strInputString.replace("{PROCESS_ID}", str(self.process_id)) - strInputString = strInputString.replace("{SQL_FLAVOR}", self.flavor) - strInputString = replace_templated_functions(strInputString, self.flavor) - - return strInputString - - def GetSecondProfilingColumnsQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("secondary_profiling_columns.sql", sub_directory="profiling")) - return strQ - - def GetSecondProfilingUpdateQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("secondary_profiling_update.sql", sub_directory="profiling")) - return strQ - - def GetSecondProfilingStageDeleteQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("secondary_profiling_delete.sql", sub_directory="profiling")) - return strQ - - def GetDataTypeSuggestionUpdateQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("datatype_suggestions.sql", sub_directory="profiling")) - return strQ - - def GetFunctionalDataTypeUpdateQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("functional_datatype.sql", sub_directory="profiling")) - return strQ - - def GetFunctionalTableTypeStageQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("functional_tabletype_stage.sql", sub_directory="profiling")) - return strQ 
- - def GetFunctionalTableTypeUpdateQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("functional_tabletype_update.sql", sub_directory="profiling")) - return strQ - - def GetPIIFlagUpdateQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("pii_flag.sql", sub_directory="profiling")) - return strQ - - def GetAnomalyStatsRefreshQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("refresh_anomalies.sql", sub_directory="profiling")) - return strQ - - def GetAnomalyScoringRollupRunQuery(self): - # Runs on DK Postgres Server + def _get_params(self) -> dict: + return { + "PROJECT_CODE": self.project_code, + "CONNECTION_ID": self.connection_id, + "TABLE_GROUPS_ID": self.table_groups_id, + "RUN_DATE": self.run_date, + "DATA_SCHEMA": self.data_schema, + "DATA_TABLE": self.data_table, + "COL_NAME": self.col_name, + "COL_NAME_SANITIZED": self.col_name.replace("'", "''"), + "COL_GEN_TYPE": self.col_gen_type, + "COL_TYPE": self.col_type or "", + "COL_POS": self.col_ordinal_position, + "TOP_FREQ": self.col_top_freq_update, + "PROFILE_RUN_ID": self.profile_run_id, + "PROFILE_ID_COLUMN_MASK": self.profile_id_column_mask, + "PROFILE_SK_COLUMN_MASK": self.profile_sk_column_mask, + "START_TIME": self.today, + "NOW_TIMESTAMP": date_service.get_now_as_string(), + "EXCEPTION_MESSAGE": self.exception_message, + "SAMPLING_TABLE": self.sampling_table, + "SAMPLE_SIZE": int(self.parm_sample_size), + "PROFILE_USE_SAMPLING": self.profile_use_sampling, + "PROFILE_SAMPLE_PERCENT": self.profile_sample_percent, + "PROFILE_SAMPLE_MIN_COUNT": self.profile_sample_min_count, + "PROFILE_SAMPLE_RATIO": self.sample_ratio, + "SAMPLE_PERCENT_CALC": self.sample_percent_calc, + "PARM_MAX_PATTERN_LENGTH": self.parm_max_pattern_length, + "CONTINGENCY_COLUMNS": self.contingency_columns, + "CONTINGENCY_MAX_VALUES": self.contingency_max_values, + "PROCESS_ID": self.process_id, + "SQL_FLAVOR": self.flavor, + } + + def _get_query( + self, + template_file_name: str, + sub_directory: str | None = "profiling", + extra_params: dict | None = None, + ) -> tuple[str | None, dict]: + query = read_template_sql_file(template_file_name, sub_directory) + params = {} + + if query: + if extra_params: + params.update(extra_params) + params.update(self._get_params()) + + query = replace_params(query, params) + query = replace_templated_functions(query, self.flavor) + + return query, params + + def GetSecondProfilingColumnsQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("secondary_profiling_columns.sql") + + def GetSecondProfilingUpdateQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("secondary_profiling_update.sql") + + def GetSecondProfilingStageDeleteQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("secondary_profiling_delete.sql") + + def GetDataTypeSuggestionUpdateQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("datatype_suggestions.sql") + + def GetFunctionalDataTypeUpdateQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("functional_datatype.sql") + + def GetFunctionalTableTypeStageQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("functional_tabletype_stage.sql") + + def GetFunctionalTableTypeUpdateQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("functional_tabletype_update.sql") + + def 
GetPIIFlagUpdateQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("pii_flag.sql") + + def GetAnomalyStatsRefreshQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("refresh_anomalies.sql") + + def GetAnomalyScoringRollupRunQuery(self) -> tuple[str, dict]: + # Runs on App database return self._get_rollup_scores_sql().GetRollupScoresProfileRunQuery() - def GetAnomalyScoringRollupTableGroupQuery(self): - # Runs on DK Postgres Server + def GetAnomalyScoringRollupTableGroupQuery(self) -> tuple[str, dict]: + # Runs on App database return self._get_rollup_scores_sql().GetRollupScoresProfileTableGroupQuery() - def GetAnomalyTestTypesQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms(read_template_sql_file("profile_anomaly_types_get.sql", sub_directory="profiling")) - return strQ + def GetAnomalyTestTypesQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("profile_anomaly_types_get.sql") - def GetAnomalyTestQuery(self, dct_test_type): - # Runs on DK Postgres Server - strQ = None + def GetAnomalyTestQuery(self, test_type: dict) -> tuple[str, dict] | None: + # Runs on App database + extra_params = { + "ANOMALY_ID": test_type["id"], + "DETAIL_EXPRESSION": test_type["detail_expression"], + "ANOMALY_CRITERIA": test_type["anomaly_criteria"], + } - match dct_test_type["data_object"]: + match test_type["data_object"]: case "Column": - strQ = read_template_sql_file("profile_anomalies_screen_column.sql", sub_directory="profiling") + query, params = self._get_query("profile_anomalies_screen_column.sql", extra_params=extra_params) case "Multi-Col": - strQ = read_template_sql_file("profile_anomalies_screen_multi_column.sql", sub_directory="profiling") + query, params = self._get_query("profile_anomalies_screen_multi_column.sql", extra_params=extra_params) case "Dates": - strQ = read_template_sql_file("profile_anomalies_screen_table_dates.sql", sub_directory="profiling") + query, params = self._get_query("profile_anomalies_screen_table_dates.sql", extra_params=extra_params) case "Table": - strQ = read_template_sql_file("profile_anomalies_screen_table.sql", sub_directory="profiling") + query, params = self._get_query("profile_anomalies_screen_table.sql", extra_params=extra_params) case "Variant": - strQ = read_template_sql_file("profile_anomalies_screen_variants.sql", sub_directory="profiling") - - if strQ: - strQ = strQ.replace("{ANOMALY_ID}", dct_test_type["id"]) - strQ = strQ.replace("{DETAIL_EXPRESSION}", dct_test_type["detail_expression"]) - strQ = strQ.replace("{ANOMALY_CRITERIA}", dct_test_type["anomaly_criteria"]) - strQ = self.ReplaceParms(strQ) - - return strQ - - def GetAnomalyScoringQuery(self, dct_test_type): - # Runs on DK Postgres Server - strQ = read_template_sql_file("profile_anomaly_scoring.sql", sub_directory="profiling") - if strQ: - strQ = strQ.replace("{PROFILE_RUN_ID}", self.profile_run_id) - strQ = strQ.replace("{ANOMALY_ID}", dct_test_type["id"]) - strQ = strQ.replace("{PREV_FORMULA}", dct_test_type["dq_score_prevalence_formula"]) - strQ = strQ.replace("{RISK}", dct_test_type["dq_score_risk_factor"]) - return strQ - - def GetDataCharsRefreshQuery(self): - # Runs on DK Postgres Server + query, params = self._get_query("profile_anomalies_screen_variants.sql", extra_params=extra_params) + case _: + return None + + return query, params + + def GetAnomalyScoringQuery(self, test_type: dict) -> tuple[str, dict]: + # Runs on App database + query = 
read_template_sql_file("profile_anomaly_scoring.sql", sub_directory="profiling") + params = { + "PROFILE_RUN_ID": self.profile_run_id, + "ANOMALY_ID": test_type["id"], + "PREV_FORMULA": test_type["dq_score_prevalence_formula"], + "RISK": test_type["dq_score_risk_factor"], + } + query = replace_params(query, params) + return query, params + + def GetDataCharsRefreshQuery(self) -> tuple[str, dict]: + # Runs on App database return self._get_data_chars_sql().GetDataCharsUpdateQuery() - def GetCDEFlaggerQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms( - read_template_sql_file("cde_flagger_query.sql", sub_directory="profiling") - ) - return strQ - - def GetProfileRunInfoRecordsQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms( - read_template_sql_file("project_profile_run_record_insert.sql", sub_directory="profiling") - ) - return strQ - - def GetProfileRunInfoRecordUpdateQuery(self): - # Runs on DK Postgres Server - strQ = self.ReplaceParms( - read_template_sql_file("project_profile_run_record_update.sql", sub_directory="profiling") - ) - return strQ - - def GetDDFQuery(self): - # Runs on Project DB + def GetCDEFlaggerQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("cde_flagger_query.sql") + + def GetProfileRunInfoRecordsQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("project_profile_run_record_insert.sql") + + def GetProfileRunInfoRecordUpdateQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("project_profile_run_record_update.sql") + + def GetDDFQuery(self) -> tuple[str, dict]: + # Runs on Target database return self._get_data_chars_sql().GetDDFQuery() - def GetProfilingQuery(self): - # Runs on Project DB + def GetProfilingQuery(self) -> tuple[str, dict]: + # Runs on Target database if not self.dctSnippetTemplate: self.dctSnippetTemplate = read_template_yaml_file( f"project_profiling_query_{self.flavor}.yaml", sub_directory=f"flavors/{self.flavor}/profiling" @@ -333,48 +323,38 @@ def GetProfilingQuery(self): strQ += dctSnippetTemplate["strTemplate98_else"] if self.col_gen_type == "N": - strQ += dctSnippetTemplate["strTemplate99_N"] + if self.parm_do_sample == "Y": + strQ += dctSnippetTemplate["strTemplate99_N_sampling"] + else: + strQ += dctSnippetTemplate["strTemplate99_N"] else: strQ += dctSnippetTemplate["strTemplate99_else"] if self.parm_do_sample == "Y": strQ += dctSnippetTemplate["strTemplate100_sampling"] - strQ = self.ReplaceParms(strQ) + params = self._get_params() + query = replace_params(strQ, params) + query = replace_templated_functions(query, self.flavor) - return strQ + return query, params - def GetSecondProfilingQuery(self): - # Runs on Project DB - strQ = self.ReplaceParms( - read_template_sql_file( - f"project_secondary_profiling_query_{self.flavor}.sql", sub_directory=f"flavors/{self.flavor}/profiling" - ) - ) - return strQ - - def GetTableSampleCount(self): - # Runs on Project DB - strQ = self.ReplaceParms( - read_template_sql_file("project_get_table_sample_count.sql", sub_directory="profiling") - ) - return strQ - - def GetContingencyColumns(self): - # Runs on Project DB - strQ = self.ReplaceParms(read_template_sql_file("contingency_columns.sql", sub_directory="profiling")) - return strQ - - def GetContingencyCounts(self): - # Runs on Project DB - strQ = self.ReplaceParms( - read_template_sql_file("contingency_counts.sql", sub_directory="flavors/generic/profiling") - ) - return strQ - - def UpdateProfileResultsToEst(self): - # 
Runs on Project DB - strQ = self.ReplaceParms( - read_template_sql_file("project_update_profile_results_to_estimates.sql", sub_directory="profiling") - ) - return strQ + def GetSecondProfilingQuery(self) -> tuple[str, dict]: + # Runs on Target database + return self._get_query(f"project_secondary_profiling_query_{self.flavor}.sql", f"flavors/{self.flavor}/profiling") + + def GetTableSampleCount(self) -> tuple[str, dict]: + # Runs on Target database + return self._get_query(f"project_get_table_sample_count_{self.flavor}.sql", f"flavors/{self.flavor}/profiling") + + def GetContingencyColumns(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("contingency_columns.sql") + + def GetContingencyCounts(self) -> tuple[str, dict]: + # Runs on Target database + return self._get_query("contingency_counts.sql", "flavors/generic/profiling") + + def UpdateProfileResultsToEst(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("project_update_profile_results_to_estimates.sql") diff --git a/testgen/commands/queries/refresh_data_chars_query.py b/testgen/commands/queries/refresh_data_chars_query.py index 694eeefb..414616f0 100644 --- a/testgen/commands/queries/refresh_data_chars_query.py +++ b/testgen/commands/queries/refresh_data_chars_query.py @@ -1,4 +1,6 @@ from testgen.common import read_template_sql_file +from testgen.common.database.database_service import replace_params +from testgen.common.database.flavor.flavor_service import SQLFlavor from testgen.utils import chunk_queries @@ -7,7 +9,7 @@ class CRefreshDataCharsSQL: source_table: str project_code: str - sql_flavor: str + sql_flavor: SQLFlavor table_group_schema: str table_group_id: str @@ -30,66 +32,70 @@ def __init__(self, params: dict, run_date: str, source_table: str): self.profiling_include_mask = params["profiling_include_mask"] self.profiling_exclude_mask = params["profiling_exclude_mask"] - def _replace_params(self, sql_query: str) -> str: - sql_query = sql_query.replace("{PROJECT_CODE}", self.project_code) - sql_query = sql_query.replace("{DATA_SCHEMA}", self.table_group_schema) - sql_query = sql_query.replace("{TABLE_GROUPS_ID}", self.table_group_id) - sql_query = sql_query.replace("{RUN_DATE}", self.run_date) - sql_query = sql_query.replace("{SOURCE_TABLE}", self.source_table) - return sql_query + def _get_query(self, template_file_name: str, sub_directory: str | None = "data_chars") -> tuple[str, dict]: + query = read_template_sql_file(template_file_name, sub_directory) + params = { + "PROJECT_CODE": self.project_code, + "DATA_SCHEMA": self.table_group_schema, + "TABLE_GROUPS_ID": self.table_group_id, + "RUN_DATE": self.run_date, + "SOURCE_TABLE": self.source_table, + } + query = replace_params(query, params) + return query, params def _get_mask_query(self, mask: str, is_include: bool) -> str: - sub_query = "" - if mask: - sub_query += " AND (" if is_include else " AND NOT (" - is_first = True - escape = "" - if self.sql_flavor.startswith("mssql"): - escaped_underscore = "[_]" - elif self.sql_flavor == "snowflake": - escaped_underscore = "\\\\_" - escape = "ESCAPE '\\\\'" - elif self.sql_flavor == "redshift": - escaped_underscore = "\\\\_" - else: - escaped_underscore = "\\_" - for item in mask.split(","): - if not is_first: - sub_query += " OR " - item = item.strip().replace("_", escaped_underscore) - sub_query += f"(c.table_name LIKE '{item}' {escape})" - is_first = False - sub_query += ")" - return sub_query + escape = "" + if self.sql_flavor.startswith("mssql"): + 
escaped_underscore = "[_]" + elif self.sql_flavor == "snowflake": + escaped_underscore = "\\\\_" + escape = "ESCAPE '\\\\'" + elif self.sql_flavor == "redshift": + escaped_underscore = "\\\\_" + else: + escaped_underscore = "\\_" - def GetDDFQuery(self) -> str: - # Runs on Project DB - sql_query = self._replace_params( - read_template_sql_file( - f"schema_ddf_query_{self.sql_flavor}.sql", sub_directory=f"flavors/{self.sql_flavor}/data_chars" + table_names = [ item.strip().replace("_", escaped_underscore) for item in mask.split(",") ] + sub_query = f""" + AND {"NOT" if not is_include else ""} ( + {" OR ".join([ f"(c.table_name LIKE '{item}' {escape})" for item in table_names ])} ) - ) + """ + + return sub_query + + def GetDDFQuery(self) -> tuple[str, dict]: + # Runs on Target database + query, params = self._get_query(f"schema_ddf_query_{self.sql_flavor}.sql", f"flavors/{self.sql_flavor}/data_chars") table_criteria = "" if self.profiling_table_set: table_criteria += f" AND c.table_name IN ({self.profiling_table_set})" - table_criteria += self._get_mask_query(self.profiling_include_mask, is_include=True) - table_criteria += self._get_mask_query(self.profiling_exclude_mask, is_include=False) - sql_query = sql_query.replace("{TABLE_CRITERIA}", table_criteria) - return sql_query + if self.profiling_include_mask: + table_criteria += self._get_mask_query(self.profiling_include_mask, is_include=True) + + if self.profiling_exclude_mask: + table_criteria += self._get_mask_query(self.profiling_exclude_mask, is_include=False) - def GetRecordCountQueries(self, schema_tables: list[str]) -> list[str]: + query = query.replace("{TABLE_CRITERIA}", table_criteria) + + return query, params + + def GetRecordCountQueries(self, schema_tables: list[str]) -> list[tuple[str, None]]: + # Runs on Target database count_queries = [ f"SELECT '{item}', COUNT(*) FROM {item}" for item in schema_tables ] - return chunk_queries(count_queries, " UNION ALL ", self.max_query_chars) - - def GetDataCharsUpdateQuery(self) -> str: - # Runs on DK Postgres Server - return self._replace_params(read_template_sql_file("data_chars_update.sql", sub_directory="data_chars")) - - def GetStagingDeleteQuery(self) -> str: - # Runs on DK Postgres Server - return self._replace_params(read_template_sql_file("data_chars_staging_delete.sql", sub_directory="data_chars")) + chunked_queries = chunk_queries(count_queries, " UNION ALL ", self.max_query_chars) + return [ (query, None) for query in chunked_queries ] + + def GetDataCharsUpdateQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("data_chars_update.sql") + + def GetStagingDeleteQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("data_chars_staging_delete.sql") diff --git a/testgen/commands/queries/rollup_scores_query.py b/testgen/commands/queries/rollup_scores_query.py index b23a945f..dde0d556 100644 --- a/testgen/commands/queries/rollup_scores_query.py +++ b/testgen/commands/queries/rollup_scores_query.py @@ -1,32 +1,38 @@ +from uuid import UUID + from testgen.common import read_template_sql_file +from testgen.common.database.database_service import replace_params class CRollupScoresSQL: run_id: str table_group_id: str - def __init__(self, run_id: str, table_group_id: str | None = None): + def __init__(self, run_id: str, table_group_id: str | UUID | None = None): self.run_id = run_id - self.table_group_id = table_group_id + self.table_group_id = str(table_group_id) - def _replace_params(self, sql_query: str) -> str: - sql_query = 
sql_query.replace("{RUN_ID}", self.run_id) - if self.table_group_id: - sql_query = sql_query.replace("{TABLE_GROUPS_ID}", self.table_group_id) - return sql_query + def _get_query(self, template_file_name: str, sub_directory: str | None = "rollup_scores") -> tuple[str, dict]: + query = read_template_sql_file(template_file_name, sub_directory) + params = { + "RUN_ID": self.run_id, + "TABLE_GROUPS_ID": self.table_group_id or "" + } + query = replace_params(query, params) + return query, params - def GetRollupScoresProfileRunQuery(self): - # Runs on DK Postgres Server - return self._replace_params(read_template_sql_file("rollup_scores_profile_run.sql", sub_directory="rollup_scores")) + def GetRollupScoresProfileRunQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("rollup_scores_profile_run.sql") - def GetRollupScoresProfileTableGroupQuery(self): - # Runs on DK Postgres Server - return self._replace_params(read_template_sql_file("rollup_scores_profile_table_group.sql", sub_directory="rollup_scores")) + def GetRollupScoresProfileTableGroupQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("rollup_scores_profile_table_group.sql") - def GetRollupScoresTestRunQuery(self): - # Runs on DK Postgres Server - return self._replace_params(read_template_sql_file("rollup_scores_test_run.sql", sub_directory="rollup_scores")) + def GetRollupScoresTestRunQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("rollup_scores_test_run.sql") - def GetRollupScoresTestTableGroupQuery(self): - # Runs on DK Postgres Server - return self._replace_params(read_template_sql_file("rollup_scores_test_table_group.sql", sub_directory="rollup_scores")) + def GetRollupScoresTestTableGroupQuery(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("rollup_scores_test_table_group.sql") diff --git a/testgen/commands/queries/test_parameter_validation_query.py b/testgen/commands/queries/test_parameter_validation_query.py index d34e2102..ec8cf408 100644 --- a/testgen/commands/queries/test_parameter_validation_query.py +++ b/testgen/commands/queries/test_parameter_validation_query.py @@ -1,90 +1,70 @@ import typing from testgen.common import CleanSQL, date_service, read_template_sql_file +from testgen.common.database.database_service import replace_params class CTestParamValidationSQL: flavor = "" run_date = "" test_run_id = "" - project_code = "" - test_suite = "" - test_schemas = "" + test_schemas: str = "" message = "" - test_ids = [] # noqa + test_ids: typing.ClassVar = [] exception_message = "" flag_val = "" - # Test Set Parameters - dctTestParms: typing.ClassVar = {} + _use_clean = False def __init__(self, strFlavor, strTestSuiteId): self.flavor = strFlavor self.test_suite_id = strTestSuiteId self.today = date_service.get_now_as_string() - def _ReplaceParms(self, strInputString): - strInputString = strInputString.replace("{TEST_SUITE_ID}", self.test_suite_id) - strInputString = strInputString.replace("{RUN_DATE}", self.run_date) - strInputString = strInputString.replace("{TEST_RUN_ID}", self.test_run_id) - strInputString = strInputString.replace("{FLAG}", self.flag_val) - strInputString = strInputString.replace("{TEST_SCHEMAS}", self.test_schemas) - strInputString = strInputString.replace("{EXCEPTION_MESSAGE}", self.exception_message) - strInputString = strInputString.replace("{MESSAGE}", self.message) - strInputString = strInputString.replace("{CAT_TEST_IDS}", ", ".join(map(str, self.test_ids))) - strInputString = 
strInputString.replace("{START_TIME}", self.today) - strInputString = strInputString.replace("{NOW}", date_service.get_now_as_string()) - - for parm, value in self.dctTestParms.items(): - strInputString = strInputString.replace("{" + parm.upper() + "}", value) - - return strInputString - - def ClearTestParms(self): - # Test Set Parameters - pass - - def GetTestValidationColumns(self, booClean): - # Runs on DK DB - strQ = self._ReplaceParms(read_template_sql_file("ex_get_test_column_list_tg.sql", "validate_tests")) - if booClean: - strQ = CleanSQL(strQ) - - return strQ - - def GetProjectTestValidationColumns(self): - # Runs on Project DB - strQ = self._ReplaceParms( - read_template_sql_file("ex_get_project_column_list_generic.sql", "flavors/generic/validate_tests") - ) - - return strQ - - def PrepFlagTestsWithFailedValidation(self): - # Runs on Project DB - strQ = self._ReplaceParms(read_template_sql_file("ex_prep_flag_tests_test_definitions.sql", "validate_tests")) - - return strQ - - def FlagTestsWithFailedValidation(self): - # Runs on Project DB - strQ = self._ReplaceParms(read_template_sql_file("ex_flag_tests_test_definitions.sql", "validate_tests")) - - return strQ - - def DisableTestsWithFailedValidation(self): - # Runs on Project DB - strQ = self._ReplaceParms(read_template_sql_file("ex_disable_tests_test_definitions.sql", "validate_tests")) - - return strQ - - def ReportTestValidationErrors(self): - # Runs on Project DB - strQ = self._ReplaceParms(read_template_sql_file("ex_write_test_val_errors.sql", "validate_tests")) - - return strQ - - def PushTestRunStatusUpdateSQL(self): - strQ = self._ReplaceParms(read_template_sql_file("ex_update_test_record_in_testrun_table.sql", "execution")) - - return strQ + def _get_query(self, template_file_name: str, sub_directory: str | None = "validate_tests") -> tuple[str, dict]: + query = read_template_sql_file(template_file_name, sub_directory) + params = { + "TEST_SUITE_ID": self.test_suite_id, + "RUN_DATE": self.run_date, + "TEST_RUN_ID": self.test_run_id, + "FLAG": self.flag_val, + "TEST_SCHEMAS": self.test_schemas, + "EXCEPTION_MESSAGE": self.exception_message, + "MESSAGE": self.message, + "CAT_TEST_IDS": tuple(self.test_ids or []), + "START_TIME": self.today, + "NOW_TIMESTAMP": date_service.get_now_as_string(), + } + query = replace_params(query, params) + return query, params + + def GetTestValidationColumns(self) -> tuple[str, dict]: + # Runs on App database + query, params = self._get_query("ex_get_test_column_list_tg.sql") + if self._use_clean: + query = CleanSQL(query) + return query, params + + def GetProjectTestValidationColumns(self) -> tuple[str, dict]: + # Runs on Target database + return self._get_query("ex_get_project_column_list_generic.sql", "flavors/generic/validate_tests") + + def PrepFlagTestsWithFailedValidation(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_prep_flag_tests_test_definitions.sql") + + def FlagTestsWithFailedValidation(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_flag_tests_test_definitions.sql") + + def DisableTestsWithFailedValidation(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_disable_tests_test_definitions.sql") + + def ReportTestValidationErrors(self) -> tuple[str, dict]: + # Runs on App database + return self._get_query("ex_write_test_val_errors.sql") + + def PushTestRunStatusUpdateSQL(self) -> tuple[str, dict]: + # Runs on App database + return 
self._get_query("ex_update_test_record_in_testrun_table.sql", "execution") diff --git a/testgen/commands/run_execute_cat_tests.py b/testgen/commands/run_execute_cat_tests.py index 090adb60..15d30a14 100644 --- a/testgen/commands/run_execute_cat_tests.py +++ b/testgen/commands/run_execute_cat_tests.py @@ -1,70 +1,26 @@ import logging from datetime import UTC, datetime +from progress.spinner import Spinner + from testgen import settings from testgen.commands.queries.execute_cat_tests_query import CCATExecutionSQL from testgen.commands.run_refresh_score_cards_results import run_refresh_score_cards_results from testgen.common import ( - RetrieveDBResultsToDictList, - RunActionQueryList, - RunThreadedRetrievalQueryList, - WriteListToDB, date_service, + execute_db_queries, + fetch_dict_from_db, + fetch_from_db_threaded, + write_to_app_db, ) +from testgen.common.get_pipeline_parms import TestExecutionParams from testgen.common.mixpanel_service import MixpanelService LOG = logging.getLogger("testgen") -def RetrieveTargetTables(clsCATExecute): - # Gets distinct list of tables to be tested, to aggregate tests by table, from dk db - strQuery = clsCATExecute.GetDistinctTablesSQL() - lstTables = RetrieveDBResultsToDictList("DKTG", strQuery) - - if len(lstTables) == 0: - LOG.info("0 tables in the list for CAT test execution.") - - return lstTables - - -def AggregateTableTests(clsCATExecute): - # Writes records of aggregated tests per table and sequence number - # (to prevent table queries from getting too large) to dk db. - strQuery = clsCATExecute.GetAggregateTableTestSQL() - lstQueries = [strQuery] - RunActionQueryList("DKTG", lstQueries) - - -def RetrieveTestParms(clsCATExecute): - # Retrieves records of aggregated tests to run as queries from dk db - strQuery = clsCATExecute.GetAggregateTestParmsSQL() - lstResults = RetrieveDBResultsToDictList("DKTG", strQuery) - - return lstResults - - -def PrepCATQueries(clsCATExecute, lstCATParms): - # Prepares CAT Queries and populates query list - LOG.info("CurrentStep: Preparing CAT Queries") - lstQueries = [] - for dctCATQuery in lstCATParms: - clsCATExecute.target_schema = dctCATQuery["schema_name"] - clsCATExecute.target_table = dctCATQuery["table_name"] - clsCATExecute.dctTestParms = dctCATQuery - strQuery = clsCATExecute.PrepCATQuerySQL() - lstQueries.append(strQuery) - - return lstQueries - - -def ParseCATResults(clsCATExecute): - # Parses aggregate results to individual test_result records at dk db - strQuery = clsCATExecute.GetCATResultsParseSQL() - RunActionQueryList("DKTG", [strQuery]) - - def FinalizeTestRun(clsCATExecute: CCATExecutionSQL, username: str | None = None): - _, row_counts = RunActionQueryList(("DKTG"), [ + _, row_counts = execute_db_queries([ clsCATExecute.FinalizeTestResultsSQL(), clsCATExecute.PushTestRunStatusUpdateSQL(), clsCATExecute.FinalizeTestSuiteUpdateSQL(), @@ -72,7 +28,7 @@ def FinalizeTestRun(clsCATExecute: CCATExecutionSQL, username: str | None = None end_time = datetime.now(UTC) try: - RunActionQueryList(("DKTG"), [ + execute_db_queries([ clsCATExecute.CalcPrevalenceTestResultsSQL(), clsCATExecute.TestScoringRollupRunSQL(), clsCATExecute.TestScoringRollupTableGroupSQL(), @@ -98,17 +54,25 @@ def FinalizeTestRun(clsCATExecute: CCATExecutionSQL, username: str | None = None def run_cat_test_queries( - dctParms, strTestRunID, strTestTime, strProjectCode, strTestSuite, error_msg, username=None, minutes_offset=0, spinner=None + params: TestExecutionParams, + test_run_id: str, + test_time: str, + project_code: str, + test_suite: 
str, + error_msg: str, + username: str | None = None, + minutes_offset: int = 0, + spinner: Spinner | None = None ): - booErrors = False + has_errors = False LOG.info("CurrentStep: Initializing CAT Query Generator") clsCATExecute = CCATExecutionSQL( - strProjectCode, dctParms["test_suite_id"], strTestSuite, dctParms["sql_flavor"], dctParms["max_query_chars"], minutes_offset + project_code, params["test_suite_id"], test_suite, params["sql_flavor"], params["max_query_chars"], minutes_offset ) - clsCATExecute.test_run_id = strTestRunID - clsCATExecute.run_date = strTestTime - clsCATExecute.table_groups_id = dctParms["table_groups_id"] + clsCATExecute.test_run_id = test_run_id + clsCATExecute.run_date = test_time + clsCATExecute.table_groups_id = params["table_groups_id"] clsCATExecute.exception_message += error_msg # START TEST EXECUTION @@ -121,7 +85,8 @@ def run_cat_test_queries( try: # Retrieve distinct target tables from metadata LOG.info("CurrentStep: Retrieving Target Tables") - lstTables = RetrieveTargetTables(clsCATExecute) + # Gets distinct list of tables to be tested, to aggregate tests by table, from dk db + lstTables = fetch_dict_from_db(*clsCATExecute.GetDistinctTablesSQL()) LOG.info("Test Tables Identified: %s", len(lstTables)) if lstTables: @@ -129,27 +94,39 @@ def run_cat_test_queries( for dctTable in lstTables: clsCATExecute.target_schema = dctTable["schema_name"] clsCATExecute.target_table = dctTable["table_name"] - AggregateTableTests(clsCATExecute) + # Writes records of aggregated tests per table and sequence number + # (to prevent table queries from getting too large) to dk db. + execute_db_queries([clsCATExecute.GetAggregateTableTestSQL()]) LOG.info("CurrentStep: Retrieving CAT Tests to Run") - lstCATParms = RetrieveTestParms(clsCATExecute) + # Retrieves records of aggregated tests to run as queries from dk db + lstCATParms = fetch_dict_from_db(*clsCATExecute.GetAggregateTestParmsSQL()) + + lstCATQueries = [] + # Prepares CAT Queries and populates query list + LOG.info("CurrentStep: Preparing CAT Queries") + for dctCATQuery in lstCATParms: + clsCATExecute.target_schema = dctCATQuery["schema_name"] + clsCATExecute.target_table = dctCATQuery["table_name"] + clsCATExecute.cat_test_params = dctCATQuery + lstCATQueries.append(clsCATExecute.PrepCATQuerySQL()) - lstCATQueries = PrepCATQueries(clsCATExecute, lstCATParms) if lstCATQueries: LOG.info("CurrentStep: Performing CAT Tests") - lstAllResults, lstResultColumnNames, intErrors = RunThreadedRetrievalQueryList( - "PROJECT", lstCATQueries, dctParms["max_threads"], spinner + lstAllResults, lstResultColumnNames, intErrors = fetch_from_db_threaded( + lstCATQueries, use_target_db=True, max_threads=params["max_threads"], spinner=spinner ) if lstAllResults: LOG.info("CurrentStep: Saving CAT Results") # Write aggregate result records to aggregate result table at dk db - WriteListToDB("DKTG", lstAllResults, lstResultColumnNames, "working_agg_cat_results") + write_to_app_db(lstAllResults, lstResultColumnNames, "working_agg_cat_results") LOG.info("CurrentStep: Parsing CAT Results") - ParseCATResults(clsCATExecute) + # Parses aggregate results to individual test_result records at dk db + execute_db_queries([clsCATExecute.GetCATResultsParseSQL()]) LOG.info("Test results successfully parsed.") if intErrors > 0: - booErrors = True + has_errors = True cat_error_msg = f"Errors were encountered executing aggregate tests. ({intErrors} errors occurred.) Please check log." 
LOG.warning(cat_error_msg) clsCATExecute.exception_message += cat_error_msg @@ -157,14 +134,14 @@ def run_cat_test_queries( LOG.info("No valid tests were available to perform") except Exception as e: - booErrors = True + has_errors = True sqlsplit = e.args[0].split("[SQL", 1) errorline = sqlsplit[0].replace("'", "''") if len(sqlsplit) > 0 else "unknown error" clsCATExecute.exception_message += f"{type(e).__name__}: {errorline}" raise else: - return booErrors + return has_errors finally: LOG.info("Finalizing test run") diff --git a/testgen/commands/run_execute_tests.py b/testgen/commands/run_execute_tests.py index 511ec1fd..e5ff2beb 100644 --- a/testgen/commands/run_execute_tests.py +++ b/testgen/commands/run_execute_tests.py @@ -9,15 +9,18 @@ from testgen import settings from testgen.commands.queries.execute_tests_query import CTestExecutionSQL from testgen.common import ( - AssignConnectParms, - RetrieveDBResultsToDictList, - RetrieveTestExecParms, - RunActionQueryList, - RunThreadedRetrievalQueryList, - WriteListToDB, date_service, + execute_db_queries, + fetch_dict_from_db, + fetch_from_db_threaded, + get_test_execution_params, + set_target_db_params, + write_to_app_db, ) -from testgen.common.database.database_service import ExecuteDBQuery, empty_cache +from testgen.common.database.database_service import empty_cache +from testgen.common.get_pipeline_parms import TestExecutionParams +from testgen.common.models import with_database_session +from testgen.common.models.connection import Connection from testgen.ui.session import session from .run_execute_cat_tests import run_cat_test_queries @@ -27,34 +30,47 @@ LOG = logging.getLogger("testgen") -def add_test_run_record(test_run_id, test_suite_id, test_time, process_id): - query = f""" +def add_test_run_record(test_run_id: str, test_suite_id: str, test_time: str, process_id: int): + execute_db_queries([( + """ INSERT INTO test_runs(id, test_suite_id, test_starttime, process_id) - (SELECT '{test_run_id}':: UUID as id, - '{test_suite_id}' as test_suite_id, - '{test_time}' as test_starttime, - '{process_id}' as process_id); - """ - ExecuteDBQuery("DKTG", query) - - -def run_test_queries(dctParms, strTestRunID, strTestTime, strProjectCode, strTestSuite, minutes_offset=0, spinner=None): - booErrors = False + (SELECT :test_run_id as id, + :test_suite_id as test_suite_id, + :test_time as test_starttime, + :process_id as process_id); + """, + { + "test_run_id": test_run_id, + "test_suite_id": test_suite_id, + "test_time": test_time, + "process_id": process_id, + } + )]) + + +def run_test_queries( + params: TestExecutionParams, + test_run_id: str, + test_time: str, + project_code: str, + test_suite: str, + minutes_offset: int = 0, + spinner: Spinner | None = None, +): + has_errors = False error_msg = "" LOG.info("CurrentStep: Initializing Query Generator") - clsExecute = CTestExecutionSQL(strProjectCode, dctParms["sql_flavor"], dctParms["test_suite_id"], strTestSuite, minutes_offset) - clsExecute.run_date = strTestTime - clsExecute.test_run_id = strTestRunID + clsExecute = CTestExecutionSQL(project_code, params["sql_flavor"], params["test_suite_id"], test_suite, minutes_offset) + clsExecute.run_date = test_time + clsExecute.test_run_id = test_run_id clsExecute.process_id = process_service.get_current_process_id() - booClean = False try: # Retrieve non-CAT Queries LOG.info("CurrentStep: Retrieve Non-CAT Queries") - strQuery = clsExecute.GetTestsNonCAT(booClean) - lstTestSet = RetrieveDBResultsToDictList("DKTG", strQuery) + lstTestSet = 
fetch_dict_from_db(*clsExecute.GetTestsNonCAT()) if len(lstTestSet) == 0: LOG.debug("0 non-CAT Queries retrieved.") @@ -63,25 +79,23 @@ def run_test_queries(dctParms, strTestRunID, strTestTime, strProjectCode, strTes LOG.info("CurrentStep: Preparing Non-CAT Tests") lstTestQueries = [] for dctTest in lstTestSet: - # Set Test Parms - clsExecute.ClearTestParms() - clsExecute.dctTestParms = dctTest - lstTestQueries.append(clsExecute.GetTestQuery(booClean)) + clsExecute.test_params = dctTest + lstTestQueries.append(clsExecute.GetTestQuery()) if spinner: spinner.next() # Execute list, returning test results LOG.info("CurrentStep: Executing Non-CAT Test Queries") - lstTestResults, colResultNames, intErrors = RunThreadedRetrievalQueryList( - "PROJECT", lstTestQueries, dctParms["max_threads"], spinner + lstTestResults, colResultNames, intErrors = fetch_from_db_threaded( + lstTestQueries, use_target_db=True, max_threads=params["max_threads"], spinner=spinner ) # Copy test results to DK DB LOG.info("CurrentStep: Saving Non-CAT Test Results") if lstTestResults: - WriteListToDB("DKTG", lstTestResults, colResultNames, "test_results") + write_to_app_db(lstTestResults, colResultNames, "test_results") if intErrors > 0: - booErrors = True + has_errors = True error_msg = ( f"Errors were encountered executing Referential Tests. ({intErrors} errors occurred.) " "Please check log. " @@ -95,12 +109,11 @@ def run_test_queries(dctParms, strTestRunID, strTestTime, strProjectCode, strTes errorline = sqlsplit[0].replace("'", "''") if len(sqlsplit) > 0 else "unknown error" clsExecute.exception_message = f"{type(e).__name__}: {errorline}" LOG.info("Updating the test run record with exception message") - lstTestRunQuery = [clsExecute.PushTestRunStatusUpdateSQL()] - RunActionQueryList("DKTG", lstTestRunQuery) + execute_db_queries([clsExecute.PushTestRunStatusUpdateSQL()]) raise else: - return booErrors, error_msg + return has_errors, error_msg def run_execution_steps_in_background(project_code, test_suite): @@ -119,12 +132,13 @@ def run_execution_steps_in_background(project_code, test_suite): subprocess.Popen(script) # NOQA S603 +@with_database_session def run_execution_steps( project_code: str, test_suite: str, username: str | None = None, - minutes_offset: int=0, - spinner: Spinner=None, + minutes_offset: int = 0, + spinner: Spinner | None = None, ) -> str: # Initialize required parms for all steps has_errors = False @@ -137,31 +151,19 @@ def run_execution_steps( spinner.next() LOG.info("CurrentStep: Retrieving TestExec Parameters") - test_exec_params = RetrieveTestExecParms(project_code, test_suite) + test_exec_params = get_test_execution_params(project_code, test_suite) # Add a record in Test Run table for the new Test Run add_test_run_record( test_run_id, test_exec_params["test_suite_id"], test_time, process_service.get_current_process_id() ) - LOG.info("CurrentStep: Assigning Connection Parms") - AssignConnectParms( - test_exec_params["project_code"], - test_exec_params["connection_id"], - test_exec_params["project_host"], - test_exec_params["project_port"], - test_exec_params["project_db"], - test_exec_params["table_group_schema"], - test_exec_params["project_user"], - test_exec_params["sql_flavor"], - test_exec_params["url"], - test_exec_params["connect_by_url"], - test_exec_params["connect_by_key"], - test_exec_params["private_key"], - test_exec_params["private_key_passphrase"], - test_exec_params["http_path"], - "PROJECT", - ) + LOG.info("CurrentStep: Assigning Connection Parameters") + connection = 
Connection.get_by_table_group(test_exec_params["table_groups_id"]) + set_target_db_params(connection.__dict__) + test_exec_params["sql_flavor"] = connection.sql_flavor + test_exec_params["max_query_chars"] = connection.max_query_chars + test_exec_params["max_threads"] = connection.max_threads try: LOG.info("CurrentStep: Execute Step - Data Characteristics Refresh") diff --git a/testgen/commands/run_generate_tests.py b/testgen/commands/run_generate_tests.py index 01fd8fe0..c163fbab 100644 --- a/testgen/commands/run_generate_tests.py +++ b/testgen/commands/run_generate_tests.py @@ -2,112 +2,84 @@ from testgen import settings from testgen.commands.queries.generate_tests_query import CDeriveTestsSQL -from testgen.common import AssignConnectParms, RetrieveDBResultsToDictList, RetrieveTestGenParms, RunActionQueryList +from testgen.common import execute_db_queries, fetch_dict_from_db, get_test_generation_params, set_target_db_params from testgen.common.mixpanel_service import MixpanelService +from testgen.common.models import with_database_session +from testgen.common.models.connection import Connection LOG = logging.getLogger("testgen") -def run_test_gen_queries(strTableGroupsID, strTestSuite, strGenerationSet=None): - if strTableGroupsID is None: +@with_database_session +def run_test_gen_queries(table_group_id: str, test_suite: str, generation_set: str | None = None): + if table_group_id is None: raise ValueError("Table Group ID was not specified") - clsTests = CDeriveTestsSQL() + LOG.info("CurrentStep: Assigning Connection Parameters") + connection = Connection.get_by_table_group(table_group_id) + set_target_db_params(connection.__dict__) - # Set General Parms - booClean = False + clsTests = CDeriveTestsSQL() - LOG.info("CurrentStep: Retrieving General Parameters for Test Suite %s", strTestSuite) - dctParms = RetrieveTestGenParms(strTableGroupsID, strTestSuite) + LOG.info(f"CurrentStep: Retrieving General Parameters for Test Suite {test_suite}") + params = get_test_generation_params(table_group_id, test_suite) - # Set Project Connection Parms from retrieved parms - LOG.info("CurrentStep: Assigning Connection Parameters") - AssignConnectParms( - dctParms["project_code"], - dctParms["connection_id"], - dctParms["project_host"], - dctParms["project_port"], - dctParms["project_db"], - dctParms["table_group_schema"], - dctParms["project_user"], - dctParms["sql_flavor"], - dctParms["url"], - dctParms["connect_by_url"], - dctParms["connect_by_key"], - dctParms["private_key"], - dctParms["private_key_passphrase"], - dctParms["http_path"], - "PROJECT", - ) # Set static parms - clsTests.project_code = dctParms["project_code"] - clsTests.test_suite = strTestSuite - clsTests.generation_set = strGenerationSet if strGenerationSet is not None else "" - clsTests.test_suite_id = dctParms["test_suite_id"] if dctParms["test_suite_id"] else "" - clsTests.connection_id = str(dctParms["connection_id"]) - clsTests.table_groups_id = strTableGroupsID - clsTests.sql_flavor = dctParms["sql_flavor"] - clsTests.data_schema = dctParms["table_group_schema"] - if dctParms["profiling_as_of_date"] is not None: - clsTests.as_of_date = dctParms["profiling_as_of_date"].strftime("%Y-%m-%d %H:%M:%S") - - if dctParms["test_suite_id"]: - clsTests.test_suite_id = dctParms["test_suite_id"] + clsTests.project_code = params["project_code"] + clsTests.test_suite = test_suite + clsTests.generation_set = generation_set if generation_set is not None else "" + clsTests.test_suite_id = params["test_suite_id"] if params["test_suite_id"] else "" 
+ clsTests.connection_id = str(connection.connection_id) + clsTests.table_groups_id = table_group_id + clsTests.sql_flavor = connection.sql_flavor + clsTests.data_schema = params["table_group_schema"] + if params["profiling_as_of_date"] is not None: + clsTests.as_of_date = params["profiling_as_of_date"].strftime("%Y-%m-%d %H:%M:%S") + + if params["test_suite_id"]: + clsTests.test_suite_id = params["test_suite_id"] else: LOG.info("CurrentStep: Creating new Test Suite") - strQuery = clsTests.GetInsertTestSuiteSQL(booClean) - if strQuery: - insert_ids, _ = RunActionQueryList("DKTG", [strQuery]) - clsTests.test_suite_id = insert_ids[0] - else: - raise ValueError("Test Suite not found and could not be created") + insert_ids, _ = execute_db_queries([clsTests.GetInsertTestSuiteSQL()]) + clsTests.test_suite_id = insert_ids[0] LOG.info("CurrentStep: Compiling Test Gen Queries") - lstFunnyTemplateQueries = clsTests.GetTestDerivationQueriesAsList("gen_funny_cat_tests", booClean) - lstQueryTemplateQueries = clsTests.GetTestDerivationQueriesAsList("gen_query_tests", booClean) + lstFunnyTemplateQueries = clsTests.GetTestDerivationQueriesAsList("gen_funny_cat_tests") + lstQueryTemplateQueries = clsTests.GetTestDerivationQueriesAsList("gen_query_tests") lstGenericTemplateQueries = [] # Delete old Tests - strDeleteQuery = clsTests.GetDeleteOldTestsQuery(booClean) + deleteQuery = clsTests.GetDeleteOldTestsQuery() # Retrieve test_types as parms from list of dictionaries: test_type, selection_criteria, default_parm_columns, # default_parm_values - strQuery = clsTests.GetTestTypesSQL(booClean) - - # Execute Query - if strQuery: - lstTestTypes = RetrieveDBResultsToDictList("DKTG", strQuery) - - if lstTestTypes is None: - raise ValueError("Test Type Parameters not found") - elif ( - lstTestTypes[0]["test_type"] == "" - or lstTestTypes[0]["selection_criteria"] == "" - or lstTestTypes[0]["default_parm_columns"] == "" - or lstTestTypes[0]["default_parm_values"] == "" - ): - raise ValueError("Test Type parameters not correctly set") - else: - raise ValueError("Test Type Queries were not generated") + lstTestTypes = fetch_dict_from_db(*clsTests.GetTestTypesSQL()) + + if lstTestTypes is None: + raise ValueError("Test Type Parameters not found") + elif ( + lstTestTypes[0]["test_type"] == "" + or lstTestTypes[0]["selection_criteria"] == "" + or lstTestTypes[0]["default_parm_columns"] == "" + or lstTestTypes[0]["default_parm_values"] == "" + ): + raise ValueError("Test Type parameters not correctly set") + lstGenericTemplateQueries = [] for dctTestParms in lstTestTypes: - clsTests.ClearTestParms() - clsTests.dctTestParms = dctTestParms - strQuery = clsTests.GetTestQueriesFromGenericFile(booClean) - - if strQuery: - lstGenericTemplateQueries.append(strQuery) + clsTests.gen_test_params = dctTestParms + lstGenericTemplateQueries.append(clsTests.GetTestQueriesFromGenericFile()) LOG.info("TestGen CAT Queries were compiled") # Make sure delete, then generic templates run before the funny templates - lstQueries = [strDeleteQuery, *lstGenericTemplateQueries, *lstFunnyTemplateQueries, *lstQueryTemplateQueries] + lstQueries = [deleteQuery, *lstGenericTemplateQueries, *lstFunnyTemplateQueries, *lstQueryTemplateQueries] if lstQueries: LOG.info("Running Test Generation Template Queries") - RunActionQueryList("DKTG", lstQueries) + execute_db_queries(lstQueries) message = "Test generation completed successfully." else: message = "No TestGen Queries were compiled." 
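Note on the calling convention used throughout these hunks: the SQL-generator methods now return a (sql, params) tuple instead of a pre-interpolated string, the fetch helpers unpack that tuple, and execute_db_queries() takes a list of such tuples (params may be None), with use_target_db=True routing a statement to the target database rather than the app database. Below is a minimal sketch of that convention, not part of the patch; the generator object, the DELETE statement, and run_id are placeholders, and only names that appear in the hunks above are assumed to exist:

    from testgen.common import execute_db_queries, fetch_dict_from_db

    def example(sql_generator, run_id: str):
        # Generator methods return (sql, params); unpack directly into the fetch
        # helper, which runs against the app database by default.
        rows = fetch_dict_from_db(*sql_generator.GetDistinctTablesSQL())

        # Action queries are passed as (sql, params) tuples; params may be None
        # when the statement needs no bind values.
        execute_db_queries([
            ("DELETE FROM working_agg_cat_results WHERE test_run_id = :test_run_id",
             {"test_run_id": run_id}),
        ])
        return rows
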
diff --git a/testgen/commands/run_get_entities.py b/testgen/commands/run_get_entities.py index efda24e4..aa0ad04d 100644 --- a/testgen/commands/run_get_entities.py +++ b/testgen/commands/run_get_entities.py @@ -1,110 +1,83 @@ -import logging +from testgen.common import fetch_list_from_db, read_template_sql_file -from testgen.common import RetrieveDBResultsToList, read_template_sql_file -LOG = logging.getLogger("testgen") - - -def run_list_profiles(table_groups_id): - sql_template = read_template_sql_file("get_profile_list.sql", "get_entities") - - sql_template = sql_template.replace("{TABLE_GROUPS_ID}", table_groups_id) - - return RetrieveDBResultsToList("DKTG", sql_template) +def run_list_profiles(table_group_id: str): + return fetch_list_from_db( + read_template_sql_file("get_profile_list.sql", "get_entities"), + {"TABLE_GROUP_ID": table_group_id}, + ) def run_list_test_types(): sql_template = read_template_sql_file("list_test_types.sql", "get_entities") - return RetrieveDBResultsToList("DKTG", sql_template) + return fetch_list_from_db(sql_template) def run_list_projects(): sql_template = read_template_sql_file("get_project_list.sql", "get_entities") - return RetrieveDBResultsToList("DKTG", sql_template) + return fetch_list_from_db(sql_template) def run_list_connections(): sql_template = read_template_sql_file("get_connections_list.sql", "get_entities") - return RetrieveDBResultsToList("DKTG", sql_template) + return fetch_list_from_db(sql_template) def run_table_group_list(project_code): sql_template = read_template_sql_file("get_table_group_list.sql", "get_entities") - sql_template = sql_template.replace("{PROJECT_CODE}", project_code) - return RetrieveDBResultsToList("DKTG", sql_template) + return fetch_list_from_db(sql_template, {"PROJECT_CODE": project_code}) def run_list_test_suites(project_code): sql_template = read_template_sql_file("get_test_suite_list.sql", "get_entities") - sql_template = sql_template.replace("{PROJECT_CODE}", project_code) - return RetrieveDBResultsToList("DKTG", sql_template) - - -def run_get_test_suite(project_code, test_suite): - sql_template = read_template_sql_file("get_test_suite.sql", "get_entities") - - sql_template = sql_template.replace("{PROJECT_CODE}", project_code) - sql_template = sql_template.replace("{TEST_SUITE}", test_suite) - - return RetrieveDBResultsToList("DKTG", sql_template) - + return fetch_list_from_db(sql_template, {"PROJECT_CODE": project_code}) -def run_profile_info(profile_run, table_name=None): - if not table_name: - table_name = "%" # if no table_name, we select all the tables - sql_template = read_template_sql_file("get_profile_info.sql", "get_entities") - sql_template = sql_template.replace("{PROFILE_RUN}", str(profile_run)) - sql_template = sql_template.replace("{TABLE_NAME}", table_name) +def run_get_test_suite(project_code: str, test_suite: str): + return fetch_list_from_db( + read_template_sql_file("get_test_suite.sql", "get_entities"), + {"PROJECT_CODE": project_code, "TEST_SUITE": test_suite}, + ) - return RetrieveDBResultsToList("DKTG", sql_template) +def run_profile_info(profiling_run_id: str, table_name: str | None = None): + return fetch_list_from_db( + read_template_sql_file("get_profile_info.sql", "get_entities"), + # if no table_name, we select all the tables + {"PROFILING_RUN_ID": profiling_run_id, "TABLE_NAME": table_name or "%"}, + ) -def run_profile_screen(profile_run, table_name=None): - if not table_name: - table_name = "%" - sql_template = read_template_sql_file("get_profile_screen.sql", "get_entities") - 
sql_template = sql_template.replace("{PROFILE_RUN}", profile_run) - sql_template = sql_template.replace("{TABLE_NAME}", table_name) +def run_profile_screen(profiling_run_id: str, table_name: str | None = None): + return fetch_list_from_db( + read_template_sql_file("get_profile_screen.sql", "get_entities"), + # if no table_name, we select all the tables + {"PROFILING_RUN_ID": profiling_run_id, "TABLE_NAME": table_name or "%"}, + ) - return RetrieveDBResultsToList("DKTG", sql_template) +def run_list_test_generation(project_code: str, test_suite: str): + return fetch_list_from_db( + read_template_sql_file("get_test_generation_list.sql", "get_entities"), + {"PROJECT_CODE": project_code, "TEST_SUITE": test_suite}, + ) -def run_list_test_generation(project_code, test_suite): - sql_template = read_template_sql_file("get_test_generation_list.sql", "get_entities") - sql_template = sql_template.replace("{PROJECT_CODE}", project_code) - sql_template = sql_template.replace("{TEST_SUITE}", test_suite) +def run_test_info(project_code: str, test_suite: str): + return fetch_list_from_db( + read_template_sql_file("get_test_info.sql", "get_entities"), + {"PROJECT_CODE": project_code, "TEST_SUITE": test_suite}, + ) - return RetrieveDBResultsToList("DKTG", sql_template) +def run_list_test_runs(project_code: str, test_suite: str): + return fetch_list_from_db( + read_template_sql_file("get_test_run_list.sql", "get_entities"), + {"PROJECT_CODE": project_code, "TEST_SUITE": test_suite}, + ) -def run_test_info(project_code, test_suite): - sql_template = read_template_sql_file("get_test_info.sql", "get_entities") - sql_template = sql_template.replace("{PROJECT_CODE}", project_code) - sql_template = sql_template.replace("{TEST_SUITE}", test_suite) - - return RetrieveDBResultsToList("DKTG", sql_template) - - -def run_list_test_runs(project_code, test_suite): - sql_template = read_template_sql_file("get_test_run_list.sql", "get_entities") - - sql_template = sql_template.replace("{PROJECT_CODE}", project_code) - sql_template = sql_template.replace("{TEST_SUITE}", test_suite) - - return RetrieveDBResultsToList("DKTG", sql_template) - - -def run_get_results(test_run_id, booErrorsOnly): +def run_get_results(test_run_id: str, errors_only: bool): sql_template = read_template_sql_file("get_test_results_for_run_cli.sql", "get_entities") - - sql_template = sql_template.replace("{TEST_RUN_ID}", test_run_id) - if booErrorsOnly: - sql_template = sql_template.replace("{ERRORS_ONLY}", "AND result_code = 0") - else: - sql_template = sql_template.replace("{ERRORS_ONLY}", "") - - return RetrieveDBResultsToList("DKTG", sql_template) + sql_template = sql_template.replace("{ERRORS_ONLY}", "AND result_code = 0" if errors_only else "") + return fetch_list_from_db(sql_template, {"TEST_RUN_ID": test_run_id}) diff --git a/testgen/commands/run_launch_db_config.py b/testgen/commands/run_launch_db_config.py index bdfa6ab1..01daa060 100644 --- a/testgen/commands/run_launch_db_config.py +++ b/testgen/commands/run_launch_db_config.py @@ -2,12 +2,13 @@ import os from testgen import settings -from testgen.common import CreateDatabaseIfNotExists, RunActionQueryList, date_service +from testgen.common import create_database, date_service, execute_db_queries from testgen.common.credentials import get_tg_db, get_tg_schema from testgen.common.database.database_service import get_queries_for_command from testgen.common.encrypt import EncryptText, encrypt_ui_password from testgen.common.models import with_database_session from testgen.common.models.scores 
import ScoreDefinition +from testgen.common.models.table_group import TableGroup from testgen.common.read_file import get_template_files LOG = logging.getLogger("testgen") @@ -74,19 +75,20 @@ def _get_params_mapping() -> dict: def run_launch_db_config(delete_db: bool) -> None: params_mapping = _get_params_mapping() - CreateDatabaseIfNotExists(get_tg_db(), params_mapping, delete_db) + create_database(get_tg_db(), params_mapping, drop_existing=delete_db, drop_users_and_roles=True) queries = get_queries_for_command("dbsetup", params_mapping) - RunActionQueryList( - "DKTG", - queries, - "S", + execute_db_queries( + [(query, None) for query in queries], user_override=params_mapping["TESTGEN_ADMIN_USER"], - pwd_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + password_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + user_type="schema_admin", ) - ScoreDefinition.from_table_group({ - "project_code": settings.PROJECT_KEY, - "table_groups_name": settings.DEFAULT_TABLE_GROUPS_NAME, - }).save() + ScoreDefinition.from_table_group( + TableGroup( + project_code=settings.PROJECT_KEY, + table_groups_name=settings.DEFAULT_TABLE_GROUPS_NAME, + ) + ).save() \ No newline at end of file diff --git a/testgen/commands/run_observability_exporter.py b/testgen/commands/run_observability_exporter.py index 1efcc1e8..b8f966b9 100644 --- a/testgen/commands/run_observability_exporter.py +++ b/testgen/commands/run_observability_exporter.py @@ -12,10 +12,10 @@ from testgen import settings from testgen.common import date_service, read_template_sql_file from testgen.common.database.database_service import ( - ExecuteDBQuery, - RetrieveDBResultsToDictList, - RetrieveDBResultsToList, + execute_db_queries, + fetch_dict_from_db, ) +from testgen.common.models.test_suite import TestSuite LOG = logging.getLogger("testgen") @@ -77,21 +77,17 @@ def _get_api_endpoint(api_url: str | None, event_type: str) -> str: def collect_event_data(test_suite_id): try: - event_data_query = ( - read_template_sql_file("get_event_data.sql", "observability") - .replace("{TEST_SUITE_ID}", test_suite_id) + event_data_query_result = fetch_dict_from_db( + read_template_sql_file("get_event_data.sql", "observability"), + {"TEST_SUITE_ID": test_suite_id}, ) - - event_data_query_result = RetrieveDBResultsToDictList("DKTG", event_data_query) if not event_data_query_result: LOG.error( f"Could not get event data for exporting to Observability. Test suite '{test_suite_id}'. EXITING!" ) sys.exit(1) if len(event_data_query_result) == 0: - LOG.error( - f"Event data query is empty. Test suite '{test_suite_id}'. Exiting export to Observability!" - ) + LOG.error(f"Event data query is empty. Test suite '{test_suite_id}'. Exiting export to Observability!") sys.exit(1) event = event_data_query_result[0] @@ -99,9 +95,7 @@ def collect_event_data(test_suite_id): api_key = event.observability_api_key api_url = event.observability_api_url except Exception: - LOG.exception( - f"Error collecting event data for exporting to Observability. Test suite '{test_suite_id}'" - ) + LOG.exception(f"Error collecting event data for exporting to Observability. 
Test suite '{test_suite_id}'") sys.exit(2) else: return event_data, api_url, api_key @@ -208,12 +202,10 @@ def _get_processed_profiling_table_set(profiling_table_set): def collect_test_results(test_suite_id, max_qty_events): try: - query = ( - read_template_sql_file("get_test_results.sql", "observability") - .replace("{TEST_SUITE_ID}", test_suite_id) - .replace("{MAX_QTY_EVENTS}", str(max_qty_events)) + query_results = fetch_dict_from_db( + read_template_sql_file("get_test_results.sql", "observability"), + {"TEST_SUITE_ID": test_suite_id, "MAX_QTY_EVENTS": max_qty_events}, ) - query_results = RetrieveDBResultsToDictList("DKTG", query) collected = [] updated_ids = [] except Exception: @@ -289,14 +281,12 @@ def mark_exported_results(test_suite_id, ids): if len(ids) == 0: return - result_ids = ", ".join(ids) query = ( - read_template_sql_file("update_test_results_exported_to_observability.sql", "observability") - .replace("{TEST_SUITE_ID}", test_suite_id) - .replace("{RESULT_IDS}", result_ids) + read_template_sql_file("update_test_results_exported_to_observability.sql", "observability"), + {"TEST_SUITE_ID": test_suite_id, "TEST_RESULT_IDS": ids}, ) try: - ExecuteDBQuery("DKTG", query) + execute_db_queries([query]) except Exception: LOG.exception("Error marking exported results.") LOG.error( # noqa: TRY400 @@ -321,11 +311,11 @@ def export_test_results(test_suite_id): def run_observability_exporter(project_code, test_suite): LOG.info("CurrentStep: Observability Export - Test Results") - result = RetrieveDBResultsToList( - "DKTG", - f"SELECT id::VARCHAR FROM test_suites WHERE test_suite = '{test_suite}' AND project_code = '{project_code}'" + test_suites = TestSuite.select_minimal_where( + TestSuite.project_code == project_code, + TestSuite.test_suite == test_suite, ) - qty_of_exported_events = export_test_results(result[0][0][0]) + qty_of_exported_events = export_test_results(test_suites[0].id) click.echo(f"{qty_of_exported_events} events have been exported.") diff --git a/testgen/commands/run_profiling_bridge.py b/testgen/commands/run_profiling_bridge.py index 3782dbff..cdffbfaa 100644 --- a/testgen/commands/run_profiling_bridge.py +++ b/testgen/commands/run_profiling_bridge.py @@ -5,56 +5,31 @@ from datetime import UTC, datetime import pandas as pd +from progress.spinner import Spinner import testgen.common.process_service as process_service from testgen import settings from testgen.commands.queries.profiling_query import CProfilingSQL from testgen.commands.run_refresh_score_cards_results import run_refresh_score_cards_results from testgen.common import ( - AssignConnectParms, - QuoteCSVItems, - RetrieveDBResultsToDictList, - RetrieveProfilingParms, - RunActionQueryList, - RunThreadedRetrievalQueryList, - WriteListToDB, date_service, + execute_db_queries, + fetch_dict_from_db, + fetch_from_db_threaded, + get_profiling_params, + quote_csv_items, + set_target_db_params, + write_to_app_db, ) from testgen.common.database.database_service import empty_cache from testgen.common.mixpanel_service import MixpanelService +from testgen.common.models import with_database_session +from testgen.common.models.connection import Connection from testgen.ui.session import session -booClean = True LOG = logging.getLogger("testgen") -def InitializeProfilingSQL(strProject, strSQLFlavor): - return CProfilingSQL(strProject, strSQLFlavor) - - -def CompileAnomalyTestQueries(clsProfiling, lst_tests): - # Get queries for each test - lst_queries = [] - for dct_test_type in lst_tests: - str_query = 
clsProfiling.GetAnomalyTestQuery(dct_test_type) - if str_query: - lst_queries.append(str_query) - - return lst_queries - - -def CompileAnomalyScoringQueries(clsProfiling, lst_tests): - # Get queries for each test - lst_queries = [] - for dct_test_type in lst_tests: - if dct_test_type["dq_score_prevalence_formula"]: - str_query = clsProfiling.GetAnomalyScoringQuery(dct_test_type) - if str_query: - lst_queries.append(str_query) - - return lst_queries - - def save_contingency_rules(df_merged, threshold_ratio): # Prep rows to save lst_rules = [] @@ -117,8 +92,7 @@ def save_contingency_rules(df_merged, threshold_ratio): ] ) - WriteListToDB( - "DKTG", + write_to_app_db( lst_rules, [ "profile_run_id", @@ -137,7 +111,7 @@ def save_contingency_rules(df_merged, threshold_ratio): ) -def RunPairwiseContingencyCheck(clsProfiling, threshold_ratio): +def RunPairwiseContingencyCheck(clsProfiling: CProfilingSQL, threshold_ratio: float): # Goal: identify pairs of values that represent IF X=A THEN Y=B rules # Define the threshold percent -- should be high @@ -149,8 +123,7 @@ def RunPairwiseContingencyCheck(clsProfiling, threshold_ratio): # Retrieve columns to include in list from profiing results clsProfiling.contingency_max_values = str_max_values - str_query = clsProfiling.GetContingencyColumns() - lst_tables = RetrieveDBResultsToDictList("DKTG", str_query) + lst_tables = fetch_dict_from_db(*clsProfiling.GetContingencyColumns()) # Retrieve record counts per column combination df_merged = None @@ -159,9 +132,8 @@ def RunPairwiseContingencyCheck(clsProfiling, threshold_ratio): df_merged = None clsProfiling.data_schema = dct_table["schema_name"] clsProfiling.data_table = dct_table["table_name"] - clsProfiling.contingency_columns = QuoteCSVItems(dct_table["contingency_columns"]) - str_query = clsProfiling.GetContingencyCounts() - lst_counts = RetrieveDBResultsToDictList("PROJECT", str_query) + clsProfiling.contingency_columns = quote_csv_items(dct_table["contingency_columns"]) + lst_counts = fetch_dict_from_db(*clsProfiling.GetContingencyCounts(), use_target_db=True) if lst_counts: df = pd.DataFrame(lst_counts) # Get list of columns @@ -244,73 +216,50 @@ def run_profiling_in_background(table_group_id): background_thread.start() else: LOG.info(msg) - script = ["testgen", "run-profile", "-tg", table_group_id] + script = ["testgen", "run-profile", "-tg", str(table_group_id)] subprocess.Popen(script) # NOQA S603 -def run_profiling_queries(strTableGroupsID, username=None, spinner=None): - if strTableGroupsID is None: +@with_database_session +def run_profiling_queries(table_group_id: str, username: str | None = None, spinner: Spinner | None = None): + if table_group_id is None: raise ValueError("Table Group ID was not specified") has_errors = False + # Set Project Connection Parms in common.db_bridgers from retrieved parms + LOG.info("CurrentStep: Assigning Connection Parameters") + connection = Connection.get_by_table_group(table_group_id) + set_target_db_params(connection.__dict__) + LOG.info("CurrentStep: Retrieving Parameters") # Generate UUID for Profile Run ID profiling_run_id = str(uuid.uuid4()) - dctParms = RetrieveProfilingParms(strTableGroupsID) + params = get_profiling_params(table_group_id) LOG.info("CurrentStep: Initializing Query Generator") - clsProfiling = InitializeProfilingSQL(dctParms["project_code"], dctParms["sql_flavor"]) - - # Set Project Connection Parms in common.db_bridgers from retrieved parms - LOG.info("CurrentStep: Assigning Connection Parms") - AssignConnectParms( - 
dctParms["project_code"], - dctParms["connection_id"], - dctParms["project_host"], - dctParms["project_port"], - dctParms["project_db"], - dctParms["table_group_schema"], - dctParms["project_user"], - dctParms["sql_flavor"], - dctParms["url"], - dctParms["connect_by_url"], - dctParms["connect_by_key"], - dctParms["private_key"], - dctParms["private_key_passphrase"], - dctParms["http_path"], - "PROJECT", - ) + clsProfiling = CProfilingSQL(params["project_code"], connection.sql_flavor) # Set General Parms - clsProfiling.table_groups_id = strTableGroupsID - clsProfiling.connection_id = dctParms["connection_id"] - clsProfiling.parm_do_sample = "N" - clsProfiling.parm_sample_size = 0 - clsProfiling.parm_vldb_flag = "N" - clsProfiling.parm_do_freqs = "Y" - clsProfiling.parm_max_freq_length = 25 - clsProfiling.parm_do_patterns = "Y" - clsProfiling.parm_max_pattern_length = 25 + clsProfiling.table_groups_id = table_group_id + clsProfiling.connection_id = connection.connection_id clsProfiling.profile_run_id = profiling_run_id - clsProfiling.data_schema = dctParms["table_group_schema"] - clsProfiling.parm_table_set = dctParms["profiling_table_set"] - clsProfiling.parm_table_include_mask = dctParms["profiling_include_mask"] - clsProfiling.parm_table_exclude_mask = dctParms["profiling_exclude_mask"] - clsProfiling.profile_id_column_mask = dctParms["profile_id_column_mask"] - clsProfiling.profile_sk_column_mask = dctParms["profile_sk_column_mask"] - clsProfiling.profile_use_sampling = dctParms["profile_use_sampling"] - clsProfiling.profile_flag_cdes = dctParms["profile_flag_cdes"] - clsProfiling.profile_sample_percent = dctParms["profile_sample_percent"] - clsProfiling.profile_sample_min_count = dctParms["profile_sample_min_count"] + clsProfiling.data_schema = params["table_group_schema"] + clsProfiling.parm_table_set = params["profiling_table_set"] + clsProfiling.parm_table_include_mask = params["profiling_include_mask"] + clsProfiling.parm_table_exclude_mask = params["profiling_exclude_mask"] + clsProfiling.profile_id_column_mask = params["profile_id_column_mask"] + clsProfiling.profile_sk_column_mask = params["profile_sk_column_mask"] + clsProfiling.profile_use_sampling = params["profile_use_sampling"] + clsProfiling.profile_flag_cdes = params["profile_flag_cdes"] + clsProfiling.profile_sample_percent = params["profile_sample_percent"] + clsProfiling.profile_sample_min_count = params["profile_sample_min_count"] clsProfiling.process_id = process_service.get_current_process_id() # Add a record in profiling_runs table for the new profile - strProfileRunQuery = clsProfiling.GetProfileRunInfoRecordsQuery() - lstProfileRunQuery = [strProfileRunQuery] - RunActionQueryList("DKTG", lstProfileRunQuery) + execute_db_queries([clsProfiling.GetProfileRunInfoRecordsQuery()]) if spinner: spinner.next() @@ -320,8 +269,7 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): # Retrieve Column Metadata LOG.info("CurrentStep: Getting DDF from project") - strQuery = clsProfiling.GetDDFQuery() - lstResult = RetrieveDBResultsToDictList("PROJECT", strQuery) + lstResult = fetch_dict_from_db(*clsProfiling.GetDDFQuery(), use_target_db=True) column_count = len(lstResult) if lstResult: @@ -341,13 +289,12 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): lstQueries = [] for parm_sampling_table in distinct_tables_list: clsProfiling.sampling_table = parm_sampling_table - strQuery = clsProfiling.GetTableSampleCount() - lstQueries.append(strQuery) + 
lstQueries.append(clsProfiling.GetTableSampleCount()) - lstSampleTables, _, intErrors = RunThreadedRetrievalQueryList( - "PROJECT", lstQueries, dctParms["max_threads"], spinner + lstSampleTables, _, intErrors = fetch_from_db_threaded( + lstQueries, use_target_db=True, max_threads=connection.max_threads, spinner=spinner ) - dctSampleTables = {x[0]: [x[1], x[2]] for x in lstSampleTables} + dctSampleTables = {x[0]: [x[1], x[2], x[3]] for x in lstSampleTables} if intErrors > 0: has_errors = True LOG.warning( @@ -368,7 +315,6 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): clsProfiling.profile_run_id = profiling_run_id clsProfiling.col_is_decimal = dctColumnRecord["is_decimal"] clsProfiling.col_ordinal_position = dctColumnRecord["ordinal_position"] - clsProfiling.col_max_char_length = dctColumnRecord["character_maximum_length"] clsProfiling.col_gen_type = dctColumnRecord["general_type"] clsProfiling.parm_do_sample = "N" @@ -380,20 +326,23 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): clsProfiling.sample_ratio = dctSampleTables[ clsProfiling.data_schema + "." + clsProfiling.data_table ][1] + clsProfiling.sample_percent_calc = dctSampleTables[ + clsProfiling.data_schema + "." + clsProfiling.data_table + ][2] clsProfiling.parm_do_sample = clsProfiling.profile_use_sampling else: clsProfiling.parm_sample_size = 0 clsProfiling.sample_ratio = "" + clsProfiling.sample_percent_calc = "" - strQuery = clsProfiling.GetProfilingQuery() - lstQueries.append(strQuery) + lstQueries.append(clsProfiling.GetProfilingQuery()) # Run Profiling Queries and save results LOG.info("CurrentStep: Profiling Round 1") LOG.debug("Running %s profiling queries", len(lstQueries)) - lstProfiles, colProfileNames, intErrors = RunThreadedRetrievalQueryList( - "PROJECT", lstQueries, dctParms["max_threads"], spinner + lstProfiles, colProfileNames, intErrors = fetch_from_db_threaded( + lstQueries, use_target_db=True, max_threads=connection.max_threads, spinner=spinner ) if intErrors > 0: has_errors = True @@ -401,7 +350,7 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): f"Errors were encountered executing profiling queries. ({intErrors} errors occurred.) Please check log." 
) LOG.info("CurrentStep: Saving Round 1 profiling results to Metadata") - WriteListToDB("DKTG", lstProfiles, colProfileNames, "profile_results") + write_to_app_db(lstProfiles, colProfileNames, "profile_results") if clsProfiling.profile_use_sampling == "Y": lstQueries = [] @@ -409,17 +358,15 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): if value[0] > -1: clsProfiling.sampling_table = table_name clsProfiling.sample_ratio = value[1] - strQuery = clsProfiling.UpdateProfileResultsToEst() - lstQueries.append(strQuery) + lstQueries.append(clsProfiling.UpdateProfileResultsToEst()) - RunActionQueryList("DKTG", lstQueries) + execute_db_queries(lstQueries) if clsProfiling.parm_do_freqs == "Y": lstUpdates = [] # Get secondary profiling columns LOG.info("CurrentStep: Selecting columns for frequency analysis") - strQuery = clsProfiling.GetSecondProfilingColumnsQuery() - lstResult = RetrieveDBResultsToDictList("DKTG", strQuery) + lstResult = fetch_dict_from_db(*clsProfiling.GetSecondProfilingColumnsQuery()) if lstResult: # Assemble secondary profiling queries @@ -431,12 +378,11 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): clsProfiling.data_table = dctColumnRecord["table_name"] clsProfiling.col_name = dctColumnRecord["column_name"] - strQuery = clsProfiling.GetSecondProfilingQuery() - lstQueries.append(strQuery) + lstQueries.append(clsProfiling.GetSecondProfilingQuery()) # Run secondary profiling queries LOG.info("CurrentStep: Retrieving %s frequency results from project", len(lstQueries)) - lstUpdates, colProfileNames, intErrors = RunThreadedRetrievalQueryList( - "PROJECT", lstQueries, dctParms["max_threads"], spinner + lstUpdates, colProfileNames, intErrors = fetch_from_db_threaded( + lstQueries, use_target_db=True, max_threads=connection.max_threads, spinner=spinner ) if intErrors > 0: has_errors = True @@ -447,7 +393,7 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): if lstUpdates: # Copy secondary results to DQ staging LOG.info("CurrentStep: Writing frequency results to Staging") - WriteListToDB("DKTG", lstUpdates, colProfileNames, "stg_secondary_profile_updates") + write_to_app_db(lstUpdates, colProfileNames, "stg_secondary_profile_updates") LOG.info("CurrentStep: Generating profiling update queries") @@ -456,41 +402,40 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): if lstUpdates: # Run single update query, then delete from staging - strQuery = clsProfiling.GetSecondProfilingUpdateQuery() - lstQueries.append(strQuery) - strQuery = clsProfiling.GetSecondProfilingStageDeleteQuery() - lstQueries.append(strQuery) - strQuery = clsProfiling.GetDataTypeSuggestionUpdateQuery() - lstQueries.append(strQuery) - strQuery = clsProfiling.GetFunctionalDataTypeUpdateQuery() - lstQueries.append(strQuery) - strQuery = clsProfiling.GetFunctionalTableTypeStageQuery() - lstQueries.append(strQuery) - strQuery = clsProfiling.GetFunctionalTableTypeUpdateQuery() - lstQueries.append(strQuery) - strQuery = clsProfiling.GetPIIFlagUpdateQuery() - lstQueries.append(strQuery) - - strQuery = clsProfiling.GetAnomalyTestTypesQuery() - lstAnomalyTypes = RetrieveDBResultsToDictList("DKTG", strQuery) - lstQueries.extend(CompileAnomalyTestQueries(clsProfiling, lstAnomalyTypes)) - lstQueries.extend(CompileAnomalyScoringQueries(clsProfiling, lstAnomalyTypes)) - strQuery = clsProfiling.GetAnomalyStatsRefreshQuery() - lstQueries.append(strQuery) + lstQueries.extend([ + clsProfiling.GetSecondProfilingUpdateQuery(), + 
clsProfiling.GetSecondProfilingStageDeleteQuery(), + ]) + lstQueries.extend([ + clsProfiling.GetDataTypeSuggestionUpdateQuery(), + clsProfiling.GetFunctionalDataTypeUpdateQuery(), + clsProfiling.GetFunctionalTableTypeStageQuery(), + clsProfiling.GetFunctionalTableTypeUpdateQuery(), + clsProfiling.GetPIIFlagUpdateQuery(), + ]) + + lstAnomalyTypes = fetch_dict_from_db(*clsProfiling.GetAnomalyTestTypesQuery()) + lstQueries.extend([ + query for test_type in lstAnomalyTypes if (query := clsProfiling.GetAnomalyTestQuery(test_type)) + ]) + lstQueries.extend([ + clsProfiling.GetAnomalyScoringQuery(test_type) + for test_type in lstAnomalyTypes + if test_type["dq_score_prevalence_formula"] + ]) + lstQueries.append(clsProfiling.GetAnomalyStatsRefreshQuery()) # Always runs last - strQuery = clsProfiling.GetDataCharsRefreshQuery() - lstQueries.append(strQuery) + lstQueries.append(clsProfiling.GetDataCharsRefreshQuery()) if clsProfiling.profile_flag_cdes: - strQuery = clsProfiling.GetCDEFlaggerQuery() - lstQueries.append(strQuery) + lstQueries.append(clsProfiling.GetCDEFlaggerQuery()) LOG.info("CurrentStep: Running profiling update queries") - RunActionQueryList("DKTG", lstQueries) + execute_db_queries(lstQueries) - if dctParms["profile_do_pair_rules"] == "Y": + if params["profile_do_pair_rules"] == "Y": LOG.info("CurrentStep: Compiling pairwise contingency rules") - RunPairwiseContingencyCheck(clsProfiling, dctParms["profile_pair_rule_pct"]) + RunPairwiseContingencyCheck(clsProfiling, params["profile_pair_rule_pct"]) else: LOG.info("No columns were selected to profile.") except Exception as e: @@ -501,17 +446,15 @@ def run_profiling_queries(strTableGroupsID, username=None, spinner=None): raise finally: LOG.info("Updating the profiling run record") - RunActionQueryList("DKTG", [ - clsProfiling.GetProfileRunInfoRecordUpdateQuery(), - ]) + execute_db_queries([clsProfiling.GetProfileRunInfoRecordUpdateQuery()]) end_time = datetime.now(UTC) - RunActionQueryList("DKTG", [ + execute_db_queries([ clsProfiling.GetAnomalyScoringRollupRunQuery(), clsProfiling.GetAnomalyScoringRollupTableGroupQuery(), ]) run_refresh_score_cards_results( - project_code=dctParms["project_code"], + project_code=params["project_code"], add_history_entry=True, refresh_date=date_service.parse_now(clsProfiling.run_date), ) diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py index 1ab68f40..659109a0 100644 --- a/testgen/commands/run_quick_start.py +++ b/testgen/commands/run_quick_start.py @@ -3,14 +3,15 @@ import click from testgen import settings -from testgen.commands.run_get_entities import run_table_group_list from testgen.commands.run_launch_db_config import run_launch_db_config +from testgen.common.credentials import get_tg_schema from testgen.common.database.database_service import ( - AssignConnectParms, - CreateDatabaseIfNotExists, - RunActionQueryList, + create_database, + execute_db_queries, replace_params, + set_target_db_params, ) +from testgen.common.database.flavor.flavor_service import ConnectionParams from testgen.common.read_file import read_template_sql_file LOG = logging.getLogger("testgen") @@ -69,29 +70,23 @@ def _get_max_productid_seq(iteration: int): def _prepare_connection_to_target_database(params_mapping): - AssignConnectParms( - params_mapping["PROJECT_KEY"], - None, - params_mapping["PROJECT_DB_HOST"], - params_mapping["PROJECT_DB_PORT"], - params_mapping["PROJECT_DB"], - params_mapping["PROJECT_SCHEMA"], - params_mapping["TESTGEN_ADMIN_USER"], - params_mapping["SQL_FLAVOR"], 
- None, - None, - False, - None, - None, - None, - "PROJECT", - ) + connection_params: ConnectionParams = { + "sql_flavor": params_mapping["SQL_FLAVOR"], + "project_host": params_mapping["PROJECT_DB_HOST"], + "project_port": params_mapping["PROJECT_DB_PORT"], + "project_db": params_mapping["PROJECT_DB"], + "project_user": params_mapping["TESTGEN_ADMIN_USER"], + "table_group_schema": params_mapping["PROJECT_SCHEMA"], + "project_pw_encrypted": params_mapping["TESTGEN_ADMIN_PASSWORD"], + } + set_target_db_params(connection_params) def _get_params_mapping(iteration: int = 0) -> dict: return { "TESTGEN_ADMIN_USER": settings.DATABASE_ADMIN_USER, "TESTGEN_ADMIN_PASSWORD": settings.DATABASE_ADMIN_PASSWORD, + "SCHEMA_NAME": get_tg_schema(), "PROJECT_DB": settings.PROJECT_DATABASE_NAME, "PROJECT_SCHEMA": settings.PROJECT_DATABASE_SCHEMA, "PROJECT_KEY": settings.PROJECT_KEY, @@ -114,7 +109,7 @@ def run_quick_start(delete_target_db: bool) -> None: # Create DB target_db_name = params_mapping["PROJECT_DB"] click.echo(f"Creating target db : {target_db_name}") - CreateDatabaseIfNotExists(target_db_name, params_mapping, delete_target_db, drop_users_and_roles=False) + create_database(target_db_name, params_mapping, drop_existing=delete_target_db) # run setup command = "testgen setup-system-db --delete-db --yes" @@ -124,22 +119,14 @@ def run_quick_start(delete_target_db: bool) -> None: # Schema and Populate target db click.echo(f"Populating target db : {target_db_name}") - queries = [ - replace_params(read_template_sql_file("recreate_target_data_schema.sql", "quick_start"), params_mapping), - replace_params(read_template_sql_file("populate_target_data.sql", "quick_start"), params_mapping), - ] - RunActionQueryList( - "PROJECT", - queries, - user_override=params_mapping["TESTGEN_ADMIN_USER"], - pwd_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + execute_db_queries( + [ + (replace_params(read_template_sql_file("recreate_target_data_schema.sql", "quick_start"), params_mapping), params_mapping), + (replace_params(read_template_sql_file("populate_target_data.sql", "quick_start"), params_mapping), params_mapping), + ], + use_target_db=True, ) - # Get table group id - project_key = params_mapping["PROJECT_KEY"] - rows, _ = run_table_group_list(project_key) - connection_id = str(rows[0][2]) - def run_quick_start_increment(iteration): params_mapping = _get_params_mapping(iteration) @@ -148,12 +135,30 @@ def run_quick_start_increment(iteration): target_db_name = params_mapping["PROJECT_DB"] LOG.info(f"Incremental population of target db : {target_db_name}") - queries = [ - replace_params(read_template_sql_file("update_target_data.sql", "quick_start"), params_mapping), - ] - RunActionQueryList( - "PROJECT", - queries, - user_override=params_mapping["TESTGEN_ADMIN_USER"], - pwd_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + execute_db_queries( + [ + (replace_params(read_template_sql_file("update_target_data.sql", "quick_start"), params_mapping), params_mapping), + (replace_params(read_template_sql_file(f"update_target_data_iter{iteration}.sql", "quick_start"), params_mapping), params_mapping), + ], + use_target_db=True, + ) + setup_cat_tests(iteration) + + +def setup_cat_tests(iteration): + if iteration == 0: + return + elif iteration == 1: + sql_file = "add_cat_tests.sql" + elif iteration >=1: + sql_file = "update_cat_tests.sql" + + params_mapping = _get_params_mapping(iteration) + query = replace_params(read_template_sql_file(sql_file, "quick_start"), params_mapping) + + execute_db_queries( + [ + (query, 
params_mapping), + ], + use_target_db=False, ) diff --git a/testgen/commands/run_refresh_data_chars.py b/testgen/commands/run_refresh_data_chars.py index 5bf2a230..78489445 100644 --- a/testgen/commands/run_refresh_data_chars.py +++ b/testgen/commands/run_refresh_data_chars.py @@ -4,23 +4,23 @@ from testgen.commands.queries.refresh_data_chars_query import CRefreshDataCharsSQL from testgen.common.database.database_service import ( - RetrieveDBResultsToDictList, - RunActionQueryList, - RunThreadedRetrievalQueryList, - WriteListToDB, + execute_db_queries, + fetch_dict_from_db, + fetch_from_db_threaded, + write_to_app_db, ) +from testgen.common.get_pipeline_parms import TestExecutionParams LOG = logging.getLogger("testgen") STAGING_TABLE = "stg_data_chars_updates" -def run_refresh_data_chars_queries(params: dict, run_date: str, spinner: Spinner=None): +def run_refresh_data_chars_queries(params: TestExecutionParams, run_date: str, spinner: Spinner=None): LOG.info("CurrentStep: Initializing Data Characteristics Refresh") sql_generator = CRefreshDataCharsSQL(params, run_date, STAGING_TABLE) LOG.info("CurrentStep: Getting DDF for table group") - ddf_query = sql_generator.GetDDFQuery() - ddf_results = RetrieveDBResultsToDictList("PROJECT", ddf_query) + ddf_results = fetch_dict_from_db(*sql_generator.GetDDFQuery(), use_target_db=True) distinct_tables = { f"{item['table_schema']}.{item['table_name']}" @@ -30,8 +30,8 @@ def run_refresh_data_chars_queries(params: dict, run_date: str, spinner: Spinner count_queries = sql_generator.GetRecordCountQueries(distinct_tables) LOG.info("CurrentStep: Getting record counts for table group") - count_results, _, error_count = RunThreadedRetrievalQueryList( - "PROJECT", count_queries, params["max_threads"], spinner + count_results, _, error_count = fetch_from_db_threaded( + count_queries, use_target_db=True, max_threads=params["max_threads"], spinner=spinner ) if error_count: LOG.warning(f"{error_count} errors were encountered while retrieving record counts.") @@ -69,10 +69,10 @@ def run_refresh_data_chars_queries(params: dict, run_date: str, spinner: Spinner ] LOG.info("CurrentStep: Writing data characteristics to staging") - WriteListToDB("DKTG", staging_records, staging_columns, STAGING_TABLE) + write_to_app_db(staging_records, staging_columns, STAGING_TABLE) LOG.info("CurrentStep: Refreshing data characteristics and deleting staging") - RunActionQueryList("DKTG", [ + execute_db_queries([ sql_generator.GetDataCharsUpdateQuery(), sql_generator.GetStagingDeleteQuery(), ]) diff --git a/testgen/commands/run_rollup_scores.py b/testgen/commands/run_rollup_scores.py index c3ee9646..e835571e 100644 --- a/testgen/commands/run_rollup_scores.py +++ b/testgen/commands/run_rollup_scores.py @@ -2,7 +2,7 @@ from testgen.commands.queries.rollup_scores_query import CRollupScoresSQL from testgen.commands.run_refresh_score_cards_results import run_refresh_score_cards_results -from testgen.common.database.database_service import RunActionQueryList +from testgen.common.database.database_service import execute_db_queries LOG = logging.getLogger("testgen") @@ -16,7 +16,7 @@ def run_profile_rollup_scoring_queries(project_code: str, run_id: str, table_gro queries.append(sql_generator.GetRollupScoresProfileTableGroupQuery()) LOG.info("CurrentStep: Rolling up profiling scores") - RunActionQueryList("DKTG", queries) + execute_db_queries(queries) run_refresh_score_cards_results(project_code=project_code) @@ -29,5 +29,5 @@ def run_test_rollup_scoring_queries(project_code: str, run_id: str, 
table_group_ queries.append(sql_generator.GetRollupScoresTestTableGroupQuery()) LOG.info("CurrentStep: Rolling up testing scores") - RunActionQueryList("DKTG", queries) + execute_db_queries(queries) run_refresh_score_cards_results(project_code=project_code) diff --git a/testgen/commands/run_test_parameter_validation.py b/testgen/commands/run_test_parameter_validation.py index 71668bcd..db6ba728 100644 --- a/testgen/commands/run_test_parameter_validation.py +++ b/testgen/commands/run_test_parameter_validation.py @@ -4,37 +4,37 @@ from testgen.commands.queries.test_parameter_validation_query import CTestParamValidationSQL from testgen.common import ( - RetrieveDBResultsToDictList, - RetrieveDBResultsToList, - RunActionQueryList, + execute_db_queries, + fetch_dict_from_db, + fetch_list_from_db, ) +from testgen.common.get_pipeline_parms import TestExecutionParams LOG = logging.getLogger("testgen") def run_parameter_validation_queries( - dctParms, test_run_id="", test_time="", strTestSuite="" + params: TestExecutionParams, + test_run_id: str = "", + test_time: str = "", + test_suite: str = "", ): - LOG.info("CurrentStep: Initializing Test Parameter Validation") - clsExecute = CTestParamValidationSQL(dctParms["sql_flavor"], dctParms["test_suite_id"]) + clsExecute = CTestParamValidationSQL(params["sql_flavor"], params["test_suite_id"]) clsExecute.run_date = test_time clsExecute.test_run_id = test_run_id LOG.info("CurrentStep: Validation Class successfully initialized") - booClean = False - # Retrieve Test Column list LOG.info("CurrentStep: Retrieve Test Columns for Validation") - strColumnList = clsExecute.GetTestValidationColumns(booClean) - test_columns, _ = RetrieveDBResultsToList("DKTG", strColumnList) + test_columns, _ = fetch_list_from_db(*clsExecute.GetTestValidationColumns()) invalid_tests = [ test_ids for col, test_ids in test_columns if not col ] invalid_tests = { item for sublist in invalid_tests for item in sublist } test_columns = [ item for item in test_columns if item[0] ] if not test_columns: - LOG.warning(f"No test columns are present to validate in Test Suite {strTestSuite}") + LOG.warning(f"No test columns are present to validate in Test Suite {test_suite}") missing_columns = [] missing_tables = set() else: @@ -46,10 +46,7 @@ def run_parameter_validation_queries( # Retrieve Current Project Column list LOG.info("CurrentStep: Retrieve Current Columns for Validation") clsExecute.test_schemas = strSchemas - strProjectColumnList = clsExecute.GetProjectTestValidationColumns() - if "where table_schema in ()" in strProjectColumnList: - raise ValueError("No schema specified in Validation Columns check") - lstProjectTestColumns = RetrieveDBResultsToDictList("PROJECT", strProjectColumnList) + lstProjectTestColumns = fetch_dict_from_db(*clsExecute.GetProjectTestValidationColumns(), use_target_db=True) if len(lstProjectTestColumns) == 0: LOG.info("Current Test Column list is empty") @@ -91,36 +88,30 @@ def run_parameter_validation_queries( clsExecute.flag_val = "D" clsExecute.test_ids = list(set(chain(*tests_missing_tables.values(), *tests_missing_columns.values(), invalid_tests))) - strPrepFlagTests = clsExecute.PrepFlagTestsWithFailedValidation() - RunActionQueryList("DKTG", [strPrepFlagTests]) + execute_db_queries([clsExecute.PrepFlagTestsWithFailedValidation()]) for column_name, test_ids in tests_missing_columns.items(): clsExecute.message = f"Missing column: {column_name}" clsExecute.test_ids = test_ids - strFlagTests = clsExecute.FlagTestsWithFailedValidation() - 
RunActionQueryList("DKTG", [strFlagTests]) + execute_db_queries([clsExecute.FlagTestsWithFailedValidation()]) for table_name, test_ids in tests_missing_tables.items(): clsExecute.message = f"Missing table: {table_name}" clsExecute.test_ids = test_ids - strFlagTests = clsExecute.FlagTestsWithFailedValidation() - RunActionQueryList("DKTG", [strFlagTests]) + execute_db_queries([clsExecute.FlagTestsWithFailedValidation()]) if invalid_tests: clsExecute.message = "Invalid test: schema, table, or column not defined" clsExecute.test_ids = invalid_tests - strFlagTests = clsExecute.FlagTestsWithFailedValidation() - RunActionQueryList("DKTG", [strFlagTests]) + execute_db_queries([clsExecute.FlagTestsWithFailedValidation()]) # Copy test results to DK DB, using temporary flagged D value to identify LOG.info("CurrentStep: Saving error results for invalid tests") - strReportValErrors = clsExecute.ReportTestValidationErrors() - RunActionQueryList("DKTG", [strReportValErrors]) + execute_db_queries([clsExecute.ReportTestValidationErrors()]) # Set to Inactive those test_definitions tests that are flagged D: set to N LOG.info("CurrentStep: Disabling Tests That Failed Validation") - strDisableTests = clsExecute.DisableTestsWithFailedValidation() - RunActionQueryList("DKTG", [strDisableTests]) + execute_db_queries([clsExecute.DisableTestsWithFailedValidation()]) LOG.info("Validation Complete: Tests referencing missing tables or columns have been deactivated.") else: diff --git a/testgen/commands/run_upgrade_db_config.py b/testgen/commands/run_upgrade_db_config.py index be82d7df..5d532120 100644 --- a/testgen/commands/run_upgrade_db_config.py +++ b/testgen/commands/run_upgrade_db_config.py @@ -1,7 +1,7 @@ import logging from testgen import settings -from testgen.common import RetrieveSingleResultValue, RunActionQueryList, read_template_sql_file +from testgen.common import execute_db_queries, fetch_dict_from_db, read_template_sql_file from testgen.common.credentials import get_tg_schema from testgen.common.database.database_service import replace_params from testgen.common.read_file import get_template_files @@ -22,9 +22,8 @@ def _get_revision_prefix(params_mapping): strQuery = read_template_sql_file("get_tg_revision.sql", "dbupgrade_helpers") strQuery = replace_params(strQuery, params_mapping) - intNextRevision = RetrieveSingleResultValue("DKTG", strQuery) - - return intNextRevision + result = fetch_dict_from_db(strQuery) + return result[0]["revision"] def _get_next_revision_prefix(params_mapping): @@ -51,7 +50,7 @@ def _get_upgrade_template_directory(): return "dbupgrade" -def _get_upgrade_scripts(sub_directory: str, params_mapping: dict, mask: str = r"^.*sql$", min_val: str = ""): +def _get_upgrade_scripts(sub_directory: str, params_mapping: dict, mask: str = r"^.*sql$", min_val: str = "") -> tuple[list[tuple[str, dict]], str]: files = sorted(get_template_files(mask=mask, sub_directory=sub_directory), key=lambda key: str(key)) max_prefix = "" @@ -60,7 +59,7 @@ def _get_upgrade_scripts(sub_directory: str, params_mapping: dict, mask: str = r if file.name > min_val: template = file.read_text("utf-8") query = replace_params(template, params_mapping) - queries.append(query) + queries.append((query, None)) max_prefix = file.name[0:4] if len(queries) == 0: @@ -69,14 +68,13 @@ def _get_upgrade_scripts(sub_directory: str, params_mapping: dict, mask: str = r return queries, max_prefix -def _execute_upgrade_scripts(params_mapping, lstScripts): +def _execute_upgrade_scripts(params_mapping: dict, lstScripts: 
list[tuple[str, dict]]): # Run scripts using admin credentials - RunActionQueryList( - "DKTG", + execute_db_queries( lstScripts, - "S", user_override=params_mapping["TESTGEN_ADMIN_USER"], - pwd_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + password_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + user_type="schema_admin", ) return True @@ -92,12 +90,11 @@ def _refresh_static_metadata(params_mapping): strQueryRights = read_template_sql_file("075_grant_role_rights.sql", "dbsetup") strQueryRights = replace_params(strQueryRights, params_mapping) - RunActionQueryList( - "DKTG", - [strQueryMetadata, strQueryViews, strQueryRights], - "S", + execute_db_queries( + [(strQueryMetadata, None), (strQueryViews, None), (strQueryRights, None)], user_override=params_mapping["TESTGEN_ADMIN_USER"], - pwd_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + password_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + user_type="schema_admin", ) @@ -107,38 +104,34 @@ def _update_revision_number(params_mapping, latest_prefix_applied): strQuery = strQuery.replace("{DB_REVISION}", str(int(latest_prefix_applied))) strQuery = replace_params(strQuery, params_mapping) - RunActionQueryList( - "DKTG", - [strQuery], - "S", + execute_db_queries( + [(strQuery, None)], user_override=params_mapping["TESTGEN_ADMIN_USER"], - pwd_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + password_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], + user_type="schema_admin", ) def run_upgrade_db_config() -> bool: - LOG.info("Running run_upgrade_db_config") - + LOG.info("Upgrading system version") params_mapping = _get_params_mapping() + current_revision = _get_revision_prefix(params_mapping) - # Look for prefix one higher than last revision extant in db - strNextPrefix = _format_revision_prefix(_get_next_revision_prefix(params_mapping)) - # Retrieve template upgrade directory name + next_revision = _format_revision_prefix(_get_next_revision_prefix(params_mapping)) upgrade_dir = _get_upgrade_template_directory() - # Retrieve and execute upgrade scripts, if any - lstQueries, max_prefix = _get_upgrade_scripts(upgrade_dir, params_mapping, min_val=strNextPrefix) - LOG.info(f"Updating db config qty of queries: {len(lstQueries)}. New prefix: {max_prefix}. Queries: {lstQueries}") - if len(lstQueries) > 0: - has_been_upgraded = _execute_upgrade_scripts(params_mapping, lstQueries) + queries, max_revision = _get_upgrade_scripts(upgrade_dir, params_mapping, min_val=next_revision) + LOG.info(f"Current revision: {current_revision}. Latest revision: {max_revision or current_revision}. Upgrade scripts: {len(queries)}") + if len(queries) > 0: + has_been_upgraded = _execute_upgrade_scripts(params_mapping, queries) else: has_been_upgraded = False + LOG.info("Refreshing static metadata") _refresh_static_metadata(params_mapping) if has_been_upgraded: - # Update revision number to max prefix found in update scripts - _update_revision_number(params_mapping, max_prefix) + _update_revision_number(params_mapping, max_revision) LOG.info("Application data was successfully upgraded, and static metadata was refreshed.") else: LOG.info("Database upgrade was not required. 
Static metadata was refreshed.") diff --git a/testgen/common/database/database_service.py b/testgen/common/database/database_service.py index 643217ec..75016501 100644 --- a/testgen/common/database/database_service.py +++ b/testgen/common/database/database_service.py @@ -5,10 +5,17 @@ import queue as qu import threading from contextlib import suppress +from dataclasses import dataclass, field +from typing import Any, Literal, TypedDict from urllib.parse import quote_plus +import psycopg2.sql +from progress.spinner import Spinner from sqlalchemy import create_engine, text +from sqlalchemy.engine import LegacyRow, RowMapping +from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError +from sqlalchemy.pool.base import _ConnectionFairy from testgen import settings from testgen.common.credentials import ( @@ -20,623 +27,390 @@ get_tg_username, ) from testgen.common.database import FilteredStringIO -from testgen.common.encrypt import DecryptText +from testgen.common.database.flavor.flavor_service import ConnectionParams, FlavorService, SQLFlavor from testgen.common.read_file import get_template_files LOG = logging.getLogger("testgen") +# "normal": Log into database/schema for normal stuff +# "database_admin": Log into postgres/public to create database via override user/password +# "schema_admin": Log into database/public to create schema and run scripts via override user/password +UserType = Literal["normal", "database_admin", "schema_admin"] -class CConnectParms: - connectname = "" - projectcode = "" - connectid = "" - hostname = "" - port = "" - dbname = "" - schemaname = "" - username = "" - sql_flavor = "" - url = "" - connect_by_url = "" - connect_by_key = "" - private_key = "" - private_key_passphrase = "" - password = None - http_path = "" - - def __init__(self, connectname): - self.connectname = connectname - +@dataclass +class EngineCache: + app_db: Engine | None = field(default=None) + target_db: Engine | None = field(default=None) # Initialize variables global to this script -clsConnectParms = CConnectParms("NONE") -dctDBEngines = {} +target_db_params: ConnectionParams | None = None +engine_cache = EngineCache() -def QuoteCSVItems(str_csv_row, char_quote='"'): - if str_csv_row: - lst_values = str_csv_row.split(",") +def quote_csv_items(csv_row: str, quote_character: str = '"') -> str: + if csv_row: + values = csv_row.split(",") # Process each value individually, quoting it if not already quoted - str_quoted_values = ",".join( + quoted_values = ",".join( [ ( - f"{char_quote}{value}{char_quote}" - if not (value.startswith(char_quote) and value.endswith(char_quote)) + f"{quote_character}{value}{quote_character}" + if not (value.startswith(quote_character) and value.endswith(quote_character)) else value ) - for value in lst_values + for value in values ] ) - return str_quoted_values - return str_csv_row - - -def empty_cache(): - global dctDBEngines - dctDBEngines = {} - - -def AssignConnectParms( - projectcode, - connectid, - host, - port, - dbname, - schema, - user, - flavor, - url, - connect_by_url, - connect_by_key, - private_key, - private_key_passphrase, - http_path, - connectname="PROJECT", - password=None, -): - global clsConnectParms - - clsConnectParms.connectname = connectname - clsConnectParms.projectcode = projectcode - clsConnectParms.connectid = connectid - clsConnectParms.hostname = host - clsConnectParms.port = port - clsConnectParms.dbname = dbname - clsConnectParms.schemaname = schema - clsConnectParms.username = 
user - clsConnectParms.sql_flavor = flavor - clsConnectParms.password = password - clsConnectParms.url = url - clsConnectParms.connect_by_url = connect_by_url - clsConnectParms.connect_by_key = connect_by_key - clsConnectParms.private_key = private_key - clsConnectParms.private_key_passphrase = private_key_passphrase - clsConnectParms.http_path = http_path - - -def _RetrieveProjectPW(strProjectCode, strConnID): - strSQL = """ SELECT project_pw_encrypted - FROM connections cc - WHERE cc.project_code = '{PROJECT_CODE}' AND cc.connection_id = {CONNECTION_ID}; """ - - # Replace Parameters - strSQL = strSQL.replace("{PROJECT_CODE}", strProjectCode) - strSQL = strSQL.replace("{CONNECTION_ID}", str(strConnID)) - # Execute Query - strPW = RetrieveSingleResultValue("DKTG", strSQL) - # Convert Postgres bytea to Python byte array - strPW = bytes(strPW) if strPW else None - - # Perform Decryption - strPW = DecryptText(strPW) - return strPW - - -def _GetDBPassword(strCredentialSet): - global clsConnectParms - - if strCredentialSet == "PROJECT": - if not clsConnectParms.password: - strPW = _RetrieveProjectPW(clsConnectParms.projectcode, clsConnectParms.connectid) - else: - strPW = clsConnectParms.password - elif strCredentialSet == "DKTG": - strPW = get_tg_password() - else: - raise ValueError('Credential Set "' + strCredentialSet + '" is unknown.') - - if strPW == "": - raise ValueError('Password for Credential Set "' + strCredentialSet + '" is unknown.') - else: - return strPW - - -def get_db_type(sql_flavor): - # This is for connection purposes. sqlalchemy 1.4.46 uses postgresql to connect to redshift database - if sql_flavor == "redshift": - return "postgresql" - else: - return sql_flavor - - -def _GetDBCredentials(strCredentialSet): - global clsConnectParms - - if strCredentialSet == "PROJECT": - # Check for unassigned parms - if clsConnectParms.connectname == "NONE": - raise ValueError("Project Connection Parameters were not set.") - - strConnectflavor = get_db_type(clsConnectParms.sql_flavor) - - # Get project credentials from clsConnectParms - dctCredentials = { - "name": strCredentialSet, - "host": clsConnectParms.hostname, - "port": clsConnectParms.port, - "dbname": clsConnectParms.dbname, - "dbschema": clsConnectParms.schemaname, - "user": clsConnectParms.username, - "flavor": strConnectflavor, - "dbtype": clsConnectParms.sql_flavor, - "url": clsConnectParms.url, - "connect_by_url": clsConnectParms.connect_by_url, - "connect_by_key": clsConnectParms.connect_by_key, - "private_key": clsConnectParms.private_key, - "private_key_passphrase": clsConnectParms.private_key_passphrase, - "http_path": clsConnectParms.http_path, - } - elif strCredentialSet == "DKTG": - # Get credentials from functions in my_dk_credentials.py - dctCredentials = { - "name": strCredentialSet, - "host": get_tg_host(), - "port": get_tg_port(), - "dbname": get_tg_db(), - "dbschema": get_tg_schema(), - "user": get_tg_username(), - "flavor": "postgresql", - "dbtype": "postgresql", - } - else: - raise ValueError("Credentials for " + strCredentialSet + " are not defined.") - - return dctCredentials - - -def get_flavor_service(flavor): - module_path = f"testgen.common.database.flavor.{flavor}_flavor_service" - class_name = f"{flavor.capitalize()}FlavorService" - module = importlib.import_module(module_path) - flavor_class = getattr(module, class_name) - return flavor_class() - - -def _InitDBConnection(strCredentialSet, strRaw="N", strAdmin="N", user_override=None, pwd_override=None): - # Get DB Credentials - dctCredentials = 
_GetDBCredentials(strCredentialSet) - - if strCredentialSet == "DKTG": - con = _InitDBConnection_appdb(dctCredentials, strCredentialSet, strRaw, strAdmin, user_override, pwd_override) - else: - flavor_service = get_flavor_service(dctCredentials["dbtype"]) - flavor_service.init(dctCredentials) - con = _InitDBConnection_target_db(flavor_service, strCredentialSet, strRaw, user_override, pwd_override) - return con - - -def _InitDBConnection_appdb( - dctCredentials, strCredentialSet, strRaw="N", strAdmin="N", user_override=None, pwd_override=None -): - # Get DB Credentials - dctCredentials = _GetDBCredentials(strCredentialSet) - - # Set DB Credential Overrides for Admin connections - # strAdmin = "N": Log into DB/schema for normal stuff - # strAdmin = "D": Log into postgres/public to create DB via override user/password - # strAdmin = "S": Log into DB/public to create schema and run scripts via override user/password - if strAdmin in {"D", "S"}: - dctCredentials["user"] = user_override - dctCredentials["dbschema"] = "public" - if strAdmin == "D": - dctCredentials["dbname"] = "postgres" - - # Get DBEngine using credentials - if strCredentialSet in dctDBEngines and strAdmin == "N": - # Retrieve existing engine from store - dbEngine = dctDBEngines[strCredentialSet] - else: - # Handle Admin overrides or circumstantial password override - if strAdmin in {"D", "S"} or pwd_override is not None: - strPW = pwd_override - else: - strPW = _GetDBPassword(strCredentialSet) - - # Open a new engine with appropriate connection parms - # STANDARD FORMAT: strConnect = 'flavor://username:password@host:port/database' - strConnect = "{}://{}:{}@{}:{}/{}".format( - dctCredentials["flavor"], - dctCredentials["user"], - quote_plus(strPW), - dctCredentials["host"], - dctCredentials["port"], - dctCredentials["dbname"], - ) - try: - # Timeout in seconds: 1 hour = 60 * 60 second = 3600 - dbEngine = create_engine(strConnect, connect_args={"connect_timeout": 3600}) - dctDBEngines[strCredentialSet] = dbEngine - - except SQLAlchemyError as e: - raise ValueError( - f"Failed to create engine (Admin={strAdmin}) \ - for database {dctCredentials['dbname']}" - ) from e + return quoted_values + return csv_row - # Second, create a connection from our engine - try: - if strRaw == "N": - con = dbEngine.connect() - if strAdmin == "N": - strSchemaSQL = f"SET SEARCH_PATH = {dctCredentials['dbschema']};" - con.execute(text(strSchemaSQL)) - else: - con = dbEngine.raw_connection() - strSchemaSQL = "SET SEARCH_PATH = " + dctCredentials["dbschema"] - with con.cursor() as cur: - cur.execute(strSchemaSQL) - con.commit() - except SQLAlchemyError as e: - raise ValueError("Failed to connect to database " + dctCredentials["dbname"]) from e - - return con - - -def _InitDBConnection_target_db(flavor_service, strCredentialSet, strRaw="N", user_override=None, pwd_override=None): - # Get DBEngine using credentials - if strCredentialSet in dctDBEngines: - # Retrieve existing engine from store - dbEngine = dctDBEngines[strCredentialSet] - else: - # Handle user override - if user_override is not None: - flavor_service.override_user(user_override) - # Handle password override - if pwd_override is not None: - strPW = pwd_override - elif not flavor_service.is_connect_by_key(): - strPW = _GetDBPassword(strCredentialSet) - else: - strPW = None - # Open a new engine with appropriate connection parms - is_password_overwritten = pwd_override is not None - strConnect = flavor_service.get_connection_string(strPW, is_password_overwritten) +def empty_cache() -> None: + 
engine_cache.app_db = None + engine_cache.target_db = None - connect_args = flavor_service.get_connect_args(is_password_overwritten) - try: - # Timeout in seconds: 1 hour = 60 * 60 second = 3600 - dbEngine = create_engine(strConnect, connect_args=connect_args) - dctDBEngines[strCredentialSet] = dbEngine - - except SQLAlchemyError as e: - raise ValueError(f"Failed to create engine for database {flavor_service.get_db_name}") from e - - # Second, create a connection from our engine - queries = flavor_service.get_pre_connection_queries() - if strRaw == "N": - connection = dbEngine.connect() - for query in queries: - try: - connection.execute(text(query)) - except Exception: - LOG.warning( - f"failed executing pre connection query: `{query}`", - exc_info=settings.IS_DEBUG, - stack_info=settings.IS_DEBUG, - ) - else: - connection = dbEngine.raw_connection() - with connection.cursor() as cur: - for query in queries: - try: - cur.execute(query) - except Exception: - LOG.warning( - f"failed executing pre connection query: `{query}`", - exc_info=settings.IS_DEBUG, - stack_info=settings.IS_DEBUG, - ) - connection.commit() +def set_target_db_params(connection_params: ConnectionParams) -> None: + global target_db_params + target_db_params = dict(connection_params) - return connection +def get_flavor_service(flavor: SQLFlavor) -> FlavorService: + module_path = f"testgen.common.database.flavor.{flavor}_flavor_service" + class_name = f"{flavor.capitalize()}FlavorService" + module = importlib.import_module(module_path) + flavor_class = getattr(module, class_name) + return flavor_class() -def CreateDatabaseIfNotExists(strDBName: str, params_mapping: dict, delete_db: bool, drop_users_and_roles: bool = True): - LOG.info("CurrentDB Operation: CreateDatabase. Creds: DKTG Admin") - con = _InitDBConnection( - "DKTG", - strAdmin="D", - user_override=params_mapping["TESTGEN_ADMIN_USER"], - pwd_override=params_mapping["TESTGEN_ADMIN_PASSWORD"], +class CreateDatabaseParams(TypedDict): + TESTGEN_ADMIN_USER: str + TESTGEN_ADMIN_PASSWORD: str + TESTGEN_USER: str | None + TESTGEN_REPORT_USER: str | None + +def create_database( + database_name: str, + params: CreateDatabaseParams, + drop_existing: bool = False, + drop_users_and_roles: bool = False, +) -> None: + LOG.info("DB operation: create_database on App database (User type = database_admin)") + + connection = _init_db_connection( + user_override=params["TESTGEN_ADMIN_USER"], + password_override=params["TESTGEN_ADMIN_PASSWORD"], + user_type="database_admin", ) - con.execute("commit") - - # Catch and ignore error if database already exists - with con: - if delete_db: - con.execute( - f"SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = '{strDBName}'" + connection.execute("commit") + + with connection: + if drop_existing: + connection.execute( + text( + "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = :database_name" + ), + {"database_name": database_name}, ) - con.execute("commit") - con.execute(f"DROP DATABASE IF EXISTS {strDBName}") - con.execute("commit") + connection.execute("commit") + connection.execute(f"DROP DATABASE IF EXISTS {database_name}") + connection.execute("commit") if drop_users_and_roles: - con.execute(replace_params("DROP USER IF EXISTS {TESTGEN_USER}", params_mapping)) - con.execute(replace_params("DROP USER IF EXISTS {TESTGEN_REPORT_USER}", params_mapping)) - con.execute("DROP ROLE IF EXISTS testgen_execute_role") - con.execute("DROP ROLE 
IF EXISTS testgen_report_role") - con.execute("commit") + if user := params.get("TESTGEN_USER"): + connection.execute(f"DROP USER IF EXISTS {user}") + if report_user := params.get("TESTGEN_REPORT_USER"): + connection.execute(f"DROP USER IF EXISTS {report_user}") + connection.execute("DROP ROLE IF EXISTS testgen_execute_role") + connection.execute("DROP ROLE IF EXISTS testgen_report_role") + connection.execute("commit") with suppress(ProgrammingError): - con.execute("create database " + strDBName) - con.close() - - -def RunActionQueryList(strCredentialSet, lstQueries, strAdminNDS="N", user_override=None, pwd_override=None): - LOG.info("CurrentDB Operation: RunActionQueryList. Creds: %s", strCredentialSet) - - with _InitDBConnection( - strCredentialSet, strAdmin=strAdminNDS, user_override=user_override, pwd_override=pwd_override - ) as con: - i = 0 - n = len(lstQueries) - insert_ids = [] - row_counts = [] - if n == 0: + connection.execute(f"CREATE DATABASE {database_name}") + connection.close() + + +def execute_db_queries( + queries: list[tuple[str, dict | None]], + use_target_db: bool = False, + user_override: str | None = None, + password_override: str | None = None, + user_type: UserType = "normal", +) -> tuple[list[Any], list[int]]: + LOG.info(f"DB operation: execute_db_queries on {'Target' if use_target_db else 'App'} database (User type = {user_type})") + + with _init_db_connection(use_target_db, user_override, password_override, user_type) as connection: + return_values: list[Any] = [] + row_counts: list[int] = [] + if not queries: LOG.info("No queries to process") - for q in lstQueries: - i += 1 - LOG.debug(f"LastQuery = {q}") - LOG.info(f"(Processing {i} of {n})") - tx = con.begin() - exQ = con.execute(text(q)) - row_counts.append(exQ.rowcount) - if exQ.rowcount == -1: - strMsg = "Action query processed no records." + for index, (query, params) in enumerate(queries): + LOG.debug(f"Query: {query}") + LOG.info(f"Processing {index + 1} of {len(queries)} queries") + transaction = connection.begin() + result = connection.execute(text(query), params) + row_counts.append(result.rowcount) + if result.rowcount == -1: + message = "No records processed" else: - strMsg = str(exQ.rowcount) + " records processed." + message = f"{result.rowcount} records processed" try: - insert_ids.append(exQ.fetchone()[0]) + return_values.append(result.fetchone()[0]) except Exception: - insert_ids.append(None) + return_values.append(None) - tx.commit() - LOG.info(strMsg) + transaction.commit() + LOG.debug(message) - return insert_ids, row_counts + return return_values, row_counts +def fetch_from_db_threaded( + queries: list[tuple[str, dict | None]], + use_target_db: bool = False, + max_threads: int | None = None, + spinner: Spinner | None = None, +) -> tuple[list[LegacyRow], list[str], int]: + LOG.info(f"DB operation: fetch_from_db_threaded on {'Target' if use_target_db else 'App'} database (User type = normal)") -def RunRetrievalQueryList(strCredentialSet, lstQueries): - LOG.info("CurrentDB Operation: RunRetrievalQueryList. 
Creds: %s", strCredentialSet) + result_data = [] + result_columns: list[str] = [] + error_count = 0 - with _InitDBConnection(strCredentialSet) as con: - colNames = None - lstResults = [] - i = 0 - n = len(lstQueries) - if n == 0: - LOG.info("No queries to process") - for q in lstQueries: - i += 1 - LOG.debug("LastQuery = %s", q) - LOG.info("(Processing %s of %s)", i, n) + if not max_threads or max_threads < 1 or max_threads > 10: + max_threads = 4 - exQ = con.execute(text(q)) - lstOneResult = exQ.fetchall() - if not colNames: - colNames = exQ.keys() - strRows = str(exQ.rowcount) - lstResults.extend(lstOneResult) + queue = qu.Queue() + for item in queries: + queue.put(item) - LOG.info("%s records retrieved.", strRows) + threaded_fetch = _ThreadedFetch(use_target_db, threading.Lock()) - return lstResults, colNames + with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor: + try: + futures = [] + while not queue.empty(): + query, params = queue.get() + futures.append(executor.submit(threaded_fetch, query, params)) + for future in futures: + row_data, column_names, has_errors = future.result() + if spinner: + spinner.next() + error_count += 1 if has_errors else 0 + if row_data: + result_data.append(row_data) + result_columns = column_names -class _CThreadedFetch: - def __init__(self, strCredentialSet, count_lock): - self.strCredentialSet = strCredentialSet - self.count_lock = count_lock - self.count = 0 + except Exception: + LOG.exception("Failed to execute threaded queries") - def __call__(self, strQuery): - colNames = None - lstResult = None - booError = False + # Flatten nested lists + result_data = [element for sublist in result_data for element in sublist] + return result_data, result_columns, error_count - with self.count_lock: - self.count += 1 - i = self.count - try: - with _InitDBConnection(self.strCredentialSet) as con: - try: - exQ = con.execute(text(strQuery)) - lstResult = exQ.fetchall() - if not colNames: - colNames = exQ.keys() - LOG.info("(Processed Threaded Query %s on thread %s)", i, threading.current_thread().name) - except Exception: - LOG.exception(f"Failed Query. LastQuery: {strQuery}") - booError = True - except Exception as e: - LOG.info("LastQuery: %s", strQuery) - raise ValueError(f"Failed to execute threaded query: {e}") from e - else: - return lstResult, colNames, booError +def fetch_list_from_db( + query: str, params: dict | None = None, use_target_db: bool = False +) -> tuple[list[LegacyRow], list[str]]: + LOG.info(f"DB operation: fetch_list_from_db on {'Target' if use_target_db else 'App'} database (User type = normal)") + with _init_db_connection(use_target_db) as connection: + LOG.debug(f"Query: {query}") + result = connection.execute(text(query), params) + row_data = result.fetchall() + column_names = result.keys() + LOG.debug(f"{result.rowcount} records retrieved") -def RunThreadedRetrievalQueryList(strCredentialSet, lstQueries, intMaxThreads, spinner): - LOG.info("CurrentDB Operation: RunThreadedRetrievalQueryList. 
Creds: %s", strCredentialSet) + return row_data, column_names - lstResults = [] - colNames = [] - intErrors = 0 - if intMaxThreads is None: - intMaxThreads = 4 - elif intMaxThreads < 1 or intMaxThreads > 10: - intMaxThreads = 4 +def fetch_dict_from_db( + query: str, params: dict | None = None, use_target_db: bool = False +) -> list[RowMapping]: + LOG.info(f"DB operation: fetch_dict_from_db on {'Target' if use_target_db else 'App'} database (User type = normal)") - qq = qu.Queue() + with _init_db_connection(use_target_db) as connection: + LOG.debug(f"Query: {query}") + result = connection.execute(text(query), params) + LOG.debug(f"{result.rowcount} records retrieved") + # Creates list of dictionaries so records are addressible by column name + return [row._mapping for row in result] - for query in lstQueries: - qq.put(query) - # Initialize count and lock - count_lock = threading.Lock() +def write_to_app_db(data: list[LegacyRow], column_names: list[str], table_name: str) -> None: + LOG.info("DB operation: write_to_app_db on App database (User type = normal)") - clsThreadedFetch = _CThreadedFetch(strCredentialSet, count_lock) + # use_raw is required to make use of the copy_expert method for fast batch ingestion + connection = _init_db_connection(use_raw=True) + cursor = connection.cursor() - with concurrent.futures.ThreadPoolExecutor(max_workers=intMaxThreads) as executor: - try: - futures = [] - while not qq.empty(): - query = qq.get() - futures.append(executor.submit(clsThreadedFetch, query)) + # Write List to CSV in memory + buffer = FilteredStringIO(["\x00"]) + writer = csv.writer(buffer, quoting=csv.QUOTE_MINIMAL) + writer.writerows(data) + buffer.seek(0) - for future in futures: - lstOneResult, colName, booError = future.result() - if spinner: - spinner.next() - intErrors += 1 if booError else 0 - if lstOneResult: - lstResults.append(lstOneResult) - colNames = colName - - except Exception: - LOG.exception("Failed to execute threaded queries") + # List should have same column names as destination table, though not all columns in table are required + query = psycopg2.sql.SQL("COPY {table_name} ({column_names}) FROM STDIN WITH (FORMAT CSV)").format( + table_name=psycopg2.sql.Identifier(table_name), + column_names=psycopg2.sql.SQL(", ").join([psycopg2.sql.Identifier(column) for column in column_names]), + ) + LOG.debug(f"Query: {query}") + cursor.copy_expert(query, buffer) + connection.commit() + connection.close() - lstResults = [element for sublist in lstResults for element in sublist] - return lstResults, colNames, intErrors +def replace_params(query: str, params: dict[str, Any]) -> str: + for key, value in params.items(): + query = query.replace(f"{{{key}}}", "" if value is None else str(value)) + return query -def RetrieveDBResultsToList(strCredentialSet, strRunSQL): - LOG.info("CurrentDB Operation: RetrieveDBResultsToList. 
Creds: %s", strCredentialSet) +def get_queries_for_command( + sub_directory: str, params: dict[str, Any], mask: str = r"^.*sql$", path: str | None = None +) -> list[str]: + files = sorted(get_template_files(mask=mask, sub_directory=sub_directory, path=path), key=lambda key: str(key)) - with _InitDBConnection(strCredentialSet) as con: - exQ = con.execute(text(strRunSQL)) - lstResults = exQ.fetchall() - colNames = exQ.keys() + queries = [] + for file in files: + query = file.read_text("utf-8") + template = replace_params(query, params) - LOG.debug("Last Query='%s'", strRunSQL) - LOG.debug("%s records retrieved.", exQ.rowcount) + queries.append(template) - return lstResults, colNames + if len(queries) == 0: + LOG.warning(f"No sql files were found for the mask {mask} in subdirectory {sub_directory}") + return queries -def RetrieveDBResultsToDictList(strCredentialSet, strRunSQL): - LOG.info("CurrentDB Operation: RetrieveDBResultsToDictList. Creds: %s", strCredentialSet) - LOG.info("(Processing Query)") - with _InitDBConnection(strCredentialSet) as con: - LOG.debug("Last Query='%s'", strRunSQL) - exQ = con.execute(text(strRunSQL)) +def _init_db_connection( + use_target_db: bool = False, + user_override: str | None = None, + password_override: str | None = None, + user_type: UserType = "normal", + use_raw: bool = False, +) -> Connection: + if use_target_db: + return _init_target_db_connection() + return _init_app_db_connection(user_override, password_override, user_type, use_raw) + + +def _init_app_db_connection( + user_override: str | None = None, + password_override: str | None = None, + user_type: UserType = "normal", + use_raw: bool = False, +) -> Connection | _ConnectionFairy: + database_name = "postgres" if user_type == "database_admin" else get_tg_db() + is_admin = user_type == "database_admin" or user_type == "schema_admin" + + engine = None + if user_type == "normal": + engine = engine_cache.app_db + + if not engine: + user = user_override if is_admin else get_tg_username() + password = password_override if (is_admin or password_override is not None) else get_tg_password() + + # STANDARD FORMAT: flavor://username:password@host:port/database + connection_string = ( + f"postgresql://{user}:{quote_plus(password)}@{get_tg_host()}:{get_tg_port()}/{database_name}" + ) + try: + engine: Engine = create_engine(connection_string, connect_args={"connect_timeout": 3600}) + engine_cache.app_db = engine - # Creates list of dictionaries so records are addressible by column name - lstResults = [row._mapping for row in exQ] - LOG.debug("%s records retrieved.", exQ.rowcount) + except SQLAlchemyError as e: + raise ValueError(f"Failed to create engine for App database '{database_name}' (User type = {user_type})") from e - return lstResults + try: + schema_name = "public" if is_admin else get_tg_schema() + if use_raw: + connection: _ConnectionFairy = engine.raw_connection() + with connection.cursor() as cursor: + cursor.execute( + "SET SEARCH_PATH = %(schema_name)s", + {"schema_name": schema_name}, + ) + connection.commit() + else: + connection: Connection = engine.connect() + if user_type == "normal": + connection.execute( + text("SET SEARCH_PATH = :schema_name;"), + {"schema_name": schema_name}, + ) + except SQLAlchemyError as e: + raise ValueError(f"Failed to connect to App database '{database_name}'") from e + return connection -def ExecuteDBQuery(strCredentialSet, strRunSQL): - LOG.info("CurrentDB Operation: ExecuteDBQuery. 
Creds: %s", strCredentialSet) - LOG.info("(Processing Query)") - with _InitDBConnection(strCredentialSet) as con: - LOG.debug("Last Query='%s'", strRunSQL) - con.execute(text(strRunSQL)) - con.execute("commit") - LOG.debug("Query ran.") +def _init_target_db_connection() -> Connection: + if not target_db_params: + raise ValueError("Target database connection parameters were not set") + flavor_service = get_flavor_service(target_db_params["sql_flavor"]) + flavor_service.init(target_db_params) -def RetrieveSingleResultValue(strCredentialSet, strRunSQL): - LOG.debug("CurrentDB Operation: RetrieveSingleResultValue. Creds: %s", strCredentialSet) + engine = engine_cache.target_db + if not engine: + connection_string = flavor_service.get_connection_string() + connect_args = flavor_service.get_connect_args() - with _InitDBConnection(strCredentialSet) as con: - LOG.debug("Last Query='%s'", strRunSQL) - lstResult = con.execute(text(strRunSQL)).fetchone() - if lstResult: - LOG.debug("Single result retrieved.") - valReturn = lstResult[0] - return valReturn - else: - LOG.debug("Single result NOT retrieved.") + try: + engine: Engine = create_engine(connection_string, connect_args=connect_args) + engine_cache.target_db = engine + except SQLAlchemyError as e: + raise ValueError(f"Failed to create engine for Target database '{flavor_service.dbname}' (User type = normal)") from e -def WriteListToDB(strCredentialSet, lstData, lstColumns, strDBTable): - LOG.info("CurrentDB Operation: WriteListToDB. Creds: %s", strCredentialSet) - LOG.debug("(Processing ingestion query: %s records)", lstData) + connection: Connection = engine.connect() - # List should have same column names as destination table, though not all columns in table are required - # Use COPY for DKTG database, otherwise executemany() - con = _InitDBConnection(strCredentialSet, "Y") - cur = con.cursor() - if strCredentialSet == "DKTG": - # Write List to CSV in memory - sio = FilteredStringIO(["\x00"]) - writer = csv.writer(sio, quoting=csv.QUOTE_MINIMAL) - writer.writerows(lstData) - sio.seek(0) - - # Get list of column names for COPY statement - strColumnNames = ", ".join(lstColumns) - strCopySQL = f"COPY {strDBTable} ({strColumnNames}) FROM STDIN WITH (FORMAT CSV)" - LOG.debug("Last Query='%s'", strCopySQL) - - cur.copy_expert(strCopySQL, sio) - con.commit() - else: - # Get list of column names and column names formatted as parms - strColumnNames = ", ".join(lstColumns) - lstColumnParms = [":" + column_name for column_name in lstColumns] - strColumnParms = ", ".join(lstColumnParms) - - # Prep data as list of dictionaries - lstRowDicts = [dict(row) for row in lstData] - - strInsertSQL = "INSERT INTO " + strDBTable + "(" + strColumnNames + ")" + " VALUES (" + strColumnParms + ")" - LOG.debug("Last Query='%s'", strInsertSQL) - - exQ = con.execute(text(strInsertSQL), lstRowDicts) - con.commit() - LOG.debug("%s records saved", exQ.rowcount) - con.close() - - -def replace_params(query: str, params_mapping: dict) -> str: - for key, value in params_mapping.items(): - query = query.replace(f"{{{key}}}", str(value)) - return query + for query, params in flavor_service.get_pre_connection_queries(): + try: + connection.execute(text(query), params) + except Exception: + LOG.warning( + f"Failed to execute preconnection query on Target database: {query}", + exc_info=settings.IS_DEBUG, + stack_info=settings.IS_DEBUG, + ) + return connection -def get_queries_for_command(sub_directory: str, params_mapping: dict, mask: str = r"^.*sql$", path: str | None = None) -> 
list[str]: - files = sorted(get_template_files(mask=mask, sub_directory=sub_directory, path=path), key=lambda key: str(key)) - queries = [] - for file in files: - query = file.read_text("utf-8") - template = replace_params(query, params_mapping) +class _ThreadedFetch: + def __init__(self, use_target_db: bool, count_lock: threading.Lock): + self.use_target_db = use_target_db + self.count_lock = count_lock + self.count = 0 - queries.append(template) + def __call__(self, query: str, params: dict | None = None) -> tuple[list[LegacyRow], list[str], bool]: + LOG.debug(f"Query: {query}") + column_names: list[str] = [] + row_data: list = None + has_errors = False - if len(queries) == 0: - LOG.warning(f"No sql files were found for the mask {mask} in subdirectory {sub_directory}") + with self.count_lock: + self.count += 1 + i = self.count - return queries + try: + with _init_db_connection(self.use_target_db) as connection: + try: + result = connection.execute(text(query), params) + LOG.debug(f"{result.rowcount} records retrieved") + row_data = result.fetchall() + if not column_names: + column_names = result.keys() + LOG.info(f"Processed threaded query {i} on thread {threading.current_thread().name}") + except Exception: + LOG.exception(f"Failed to execute threaded query: {query}") + has_errors = True + except Exception as e: + raise ValueError(f"Failed to execute threaded query: {e}") from e + else: + return row_data, list(column_names), has_errors diff --git a/testgen/common/database/flavor/databricks_flavor_service.py b/testgen/common/database/flavor/databricks_flavor_service.py index a31367f5..9ef750a9 100644 --- a/testgen/common/database/flavor/databricks_flavor_service.py +++ b/testgen/common/database/flavor/databricks_flavor_service.py @@ -4,13 +4,12 @@ class DatabricksFlavorService(FlavorService): - def get_connection_string_head(self, strPW): - strConnect = f"{self.flavor}://{self.username}:{quote_plus(strPW)}@" - return strConnect - def get_connection_string_from_fields(self, strPW, is_password_overwritten: bool = False): # NOQA ARG002 - strConnect = ( - f"{self.flavor}://{self.username}:{quote_plus(strPW)}@{self.host}:{self.port}/{self.dbname}" + def get_connection_string_head(self): + return f"{self.flavor}://{self.username}:{quote_plus(self.password)}@" + + def get_connection_string_from_fields(self): + return ( + f"{self.flavor}://{self.username}:{quote_plus(self.password)}@{self.host}:{self.port}/{self.dbname}" f"?http_path={self.http_path}" ) - return strConnect diff --git a/testgen/common/database/flavor/flavor_service.py b/testgen/common/database/flavor/flavor_service.py index 7b7f7246..986cb64f 100644 --- a/testgen/common/database/flavor/flavor_service.py +++ b/testgen/common/database/flavor/flavor_service.py @@ -1,13 +1,32 @@ from abc import abstractmethod +from typing import Literal, TypedDict from testgen.common.encrypt import DecryptText +SQLFlavor = Literal["redshift", "snowflake", "mssql", "postgresql", "databricks"] + + +class ConnectionParams(TypedDict): + sql_flavor: SQLFlavor + project_host: str + project_port: str + project_user: str + project_db: str + table_group_schema: str + project_pw_encrypted: bytes + url: str + connect_by_url: bool + connect_by_key: bool + private_key: bytes + private_key_passphrase: bytes + http_path: str class FlavorService: url = None connect_by_url = None username = None + password = None host = None port = None dbname = None @@ -19,59 +38,55 @@ class FlavorService: http_path = None catalog = None - def init(self, connection_params: dict): + 
def init(self, connection_params: ConnectionParams): self.url = connection_params.get("url", None) self.connect_by_url = connection_params.get("connect_by_url", False) - self.username = connection_params.get("user") - self.host = connection_params.get("host") - self.port = connection_params.get("port") - self.dbname = connection_params.get("dbname") - self.flavor = connection_params.get("flavor") - self.dbschema = connection_params.get("dbschema", None) + self.username = connection_params.get("project_user") + self.host = connection_params.get("project_host") + self.port = connection_params.get("project_port") + self.dbname = connection_params.get("project_db") + self.flavor = connection_params.get("sql_flavor") + self.dbschema = connection_params.get("table_group_schema", None) self.connect_by_key = connection_params.get("connect_by_key", False) self.http_path = connection_params.get("http_path", None) self.catalog = connection_params.get("catalog", None) + password = connection_params.get("project_pw_encrypted", None) + if isinstance(password, memoryview) or isinstance(password, bytes): + password = DecryptText(password) + self.password = password + private_key = connection_params.get("private_key", None) - if isinstance(private_key, memoryview): + if isinstance(private_key, memoryview) or isinstance(private_key, bytes): private_key = DecryptText(private_key) self.private_key = private_key private_key_passphrase = connection_params.get("private_key_passphrase", None) - if isinstance(private_key_passphrase, memoryview): + if isinstance(private_key_passphrase, memoryview) or isinstance(private_key_passphrase, bytes): private_key_passphrase = DecryptText(private_key_passphrase) self.private_key_passphrase = private_key_passphrase - def override_user(self, user_override: str): - self.username = user_override - - def get_db_name(self) -> str: - return self.dbname - - def is_connect_by_key(self) -> str: - return self.connect_by_key - - def get_pre_connection_queries(self) -> list[str]: + def get_pre_connection_queries(self) -> list[tuple[str, dict | None]]: return [] - - def get_connect_args(self, _is_password_overwritten: bool = False) -> dict: + + def get_connect_args(self) -> dict: return {"connect_timeout": 3600} def get_concat_operator(self) -> str: return "||" - def get_connection_string(self, strPW, is_password_overwritten: bool = False): + def get_connection_string(self) -> str: if self.connect_by_url: - header = self.get_connection_string_head(strPW) + header = self.get_connection_string_head() url = header + self.url return url else: - return self.get_connection_string_from_fields(strPW, is_password_overwritten) + return self.get_connection_string_from_fields() @abstractmethod - def get_connection_string_from_fields(self, strPW, is_password_overwritten: bool = False): + def get_connection_string_from_fields(self) -> str: raise NotImplementedError("Subclasses must implement this method") @abstractmethod - def get_connection_string_head(self, strPW): + def get_connection_string_head(self) -> str: raise NotImplementedError("Subclasses must implement this method") diff --git a/testgen/common/database/flavor/mssql_flavor_service.py b/testgen/common/database/flavor/mssql_flavor_service.py index d472f3cd..7cdc23fe 100644 --- a/testgen/common/database/flavor/mssql_flavor_service.py +++ b/testgen/common/database/flavor/mssql_flavor_service.py @@ -4,20 +4,13 @@ from testgen.common.database.flavor.flavor_service import FlavorService -class MssqlFlavorService(FlavorService): - def 
get_connection_string_head(self, strPW): - username = self.username - password = quote_plus(strPW) - - strConnect = f"mssql+pyodbc://{username}:{password}@" - - return strConnect - - def get_connection_string_from_fields(self, strPW, is_password_overwritten: bool = False): # NOQA ARG002 - password = quote_plus(strPW) +class MssqlFlavorService(FlavorService): + def get_connection_string_head(self): + return f"mssql+pyodbc://{self.username}:{quote_plus(self.password)}@" + def get_connection_string_from_fields(self): strConnect = ( - f"mssql+pyodbc://{self.username}:{password}@{self.host}:{self.port}/{self.dbname}?driver=ODBC+Driver+18+for+SQL+Server" + f"mssql+pyodbc://{self.username}:{quote_plus(self.password)}@{self.host}:{self.port}/{self.dbname}?driver=ODBC+Driver+18+for+SQL+Server" ) if "synapse" in self.host: @@ -25,14 +18,14 @@ def get_connection_string_from_fields(self, strPW, is_password_overwritten: bool return strConnect - def get_pre_connection_queries(self): # ARG002 + def get_pre_connection_queries(self): return [ - "SET ANSI_DEFAULTS ON;", - "SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;", + ("SET ANSI_DEFAULTS ON;", None), + ("SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;", None), ] - def get_connect_args(self, is_password_overwritten: bool = False): - connect_args = super().get_connect_args(is_password_overwritten) + def get_connect_args(self): + connect_args = super().get_connect_args() if settings.SKIP_DATABASE_CERTIFICATE_VERIFICATION: connect_args["TrustServerCertificate"] = "yes" return connect_args diff --git a/testgen/common/database/flavor/redshift_flavor_service.py b/testgen/common/database/flavor/redshift_flavor_service.py index 1d29e3f2..ba17105e 100644 --- a/testgen/common/database/flavor/redshift_flavor_service.py +++ b/testgen/common/database/flavor/redshift_flavor_service.py @@ -4,16 +4,14 @@ class RedshiftFlavorService(FlavorService): - def get_connection_string_head(self, strPW): - strConnect = f"{self.flavor}://{self.username}:{quote_plus(strPW)}@" - return strConnect + def init(self, connection_params: dict): + super().init(connection_params) + # This is for connection purposes. 
sqlalchemy 1.4.46 uses postgresql to connect to redshift database
+        self.flavor = "postgresql"
 
-    def get_connection_string_from_fields(self, strPW, is_password_overwritten: bool = False):  # NOQA ARG002
-        # STANDARD FORMAT: strConnect = 'flavor://username:password@host:port/database'
-        strConnect = f"{self.flavor}://{self.username}:{quote_plus(strPW)}@{self.host}:{self.port}/{self.dbname}"
-        return strConnect
+    def get_connection_string_head(self):
+        return f"{self.flavor}://{self.username}:{quote_plus(self.password)}@"
 
-    def get_pre_connection_queries(self):
-        return [
-            "SET SEARCH_PATH = '" + self.dbschema + "'",
-        ]
+    def get_connection_string_from_fields(self):
+        # STANDARD FORMAT: strConnect = 'flavor://username:password@host:port/database'
+        return f"{self.flavor}://{self.username}:{quote_plus(self.password)}@{self.host}:{self.port}/{self.dbname}"
diff --git a/testgen/common/database/flavor/snowflake_flavor_service.py b/testgen/common/database/flavor/snowflake_flavor_service.py
index c6d213b1..c1636d7f 100644
--- a/testgen/common/database/flavor/snowflake_flavor_service.py
+++ b/testgen/common/database/flavor/snowflake_flavor_service.py
@@ -8,10 +8,8 @@
 
 class SnowflakeFlavorService(FlavorService):
-    def get_connect_args(self, is_password_overwritten: bool = False):
-        connect_args = super().get_connect_args(is_password_overwritten)
-
-        if self.connect_by_key and not is_password_overwritten:
+    def get_connect_args(self):
+        if self.connect_by_key:
             # https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#key-pair-authentication-support
             private_key_passphrase = self.private_key_passphrase.encode() if self.private_key_passphrase else None
             private_key = serialization.load_pem_private_key(
@@ -26,17 +24,16 @@ def get_connect_args(self, is_password_overwritten: bool = False):
                 encryption_algorithm=serialization.NoEncryption(),
             )
 
-        connect_args.update({"private_key": private_key_bytes})
-        return connect_args
+            return {"private_key": private_key_bytes}
+        return {}
 
-    def get_connection_string_head(self, strPW):
-        if self.connect_by_key and not strPW:
-            strConnect = f"snowflake://{self.username}@"
+    def get_connection_string_head(self):
+        if self.connect_by_key:
+            return f"snowflake://{self.username}@"
         else:
-            strConnect = f"snowflake://{self.username}:{quote_plus(strPW)}@"
-        return strConnect
+            return f"snowflake://{self.username}:{quote_plus(self.password)}@"
 
-    def get_connection_string_from_fields(self, strPW, is_password_overwritten: bool = False):
+    def get_connection_string_from_fields(self):
         # SNOWFLAKE FORMAT: strConnect = 'flavor://username:password@host/database'
         #   optionally + '/[schema]' + '?warehouse=xxx'
         # NOTE: Snowflake host should NOT include ".snowflakecomputing.com"
@@ -56,14 +53,13 @@ def get_raw_host_name(host):
         if self.port != "443":
             host += ":" + self.port
 
-        if self.connect_by_key and not is_password_overwritten:
-            strConnect = f"snowflake://{self.username}@{host}/{self.dbname}/{self.dbschema}"
+        if self.connect_by_key:
+            return f"snowflake://{self.username}@{host}/{self.dbname}/{self.dbschema}"
         else:
-            strConnect = f"snowflake://{self.username}:{quote_plus(strPW)}@{host}/{self.dbname}/{self.dbschema}"
-        return strConnect
+            return f"snowflake://{self.username}:{quote_plus(self.password)}@{host}/{self.dbname}/{self.dbschema}"
 
-    def get_pre_connection_queries(self):  # ARG002
+    def get_pre_connection_queries(self):
         return [
-            "ALTER SESSION SET MULTI_STATEMENT_COUNT = 0;",
-            "ALTER SESSION SET WEEK_START = 7;",
+            ("ALTER SESSION SET MULTI_STATEMENT_COUNT = 0;", None),
+            ("ALTER SESSION SET WEEK_START = 7;", None),
         ]
diff --git a/testgen/common/database/flavor/trino_flavor_service.py b/testgen/common/database/flavor/trino_flavor_service.py
index 12db762b..ce1133cc 100644
--- a/testgen/common/database/flavor/trino_flavor_service.py
+++ b/testgen/common/database/flavor/trino_flavor_service.py
@@ -4,15 +4,14 @@
 
 class TrinoFlavorService(FlavorService):
-    def get_connection_string_head(self, strPW):
-        strConnect = f"{self.flavor}://{self.username}:{quote_plus(strPW)}@"
-        return strConnect
+    def get_connection_string_head(self):
+        return f"{self.flavor}://{self.username}:{quote_plus(self.password)}@"
 
-    def get_connection_string_from_fields(self, strPW, is_password_overwritten: bool = False):  # NOQA ARG002
+    def get_connection_string_from_fields(self):
         # STANDARD FORMAT: strConnect = 'flavor://username:password@host:port/catalog'
-        return f"{self.flavor}://{self.username}:{quote_plus(strPW)}@{self.host}:{self.port}/{self.catalog}"
+        return f"{self.flavor}://{self.username}:{quote_plus(self.password)}@{self.host}:{self.port}/{self.catalog}"
 
     def get_pre_connection_queries(self):
         return [
-            "USE " + self.catalog + "." + self.dbschema,
+            (f"USE {self.catalog}.{self.dbschema}", None),
         ]
diff --git a/testgen/common/get_pipeline_parms.py b/testgen/common/get_pipeline_parms.py
index d8d2e213..b651ee1d 100644
--- a/testgen/common/get_pipeline_parms.py
+++ b/testgen/common/get_pipeline_parms.py
@@ -1,45 +1,71 @@
-from testgen.common.database.database_service import RetrieveDBResultsToDictList
+from typing import TypedDict
+
+from testgen.common.database.database_service import fetch_dict_from_db
 from testgen.common.read_file import read_template_sql_file
 
 
-def RetrieveProfilingParms(strTableGroupsID):
-    strSQL = read_template_sql_file("parms_profiling.sql", "parms")
-    # Replace Parameters
-    strSQL = strSQL.replace("{TABLE_GROUPS_ID}", strTableGroupsID)
+class BaseParams(TypedDict):
+    project_code: str
+    connection_id: str
+
+class ProfilingParams(BaseParams):
+    table_groups_id: str
+    profiling_table_set: str
+    profiling_include_mask: str
+    profiling_exclude_mask: str
+    profile_id_column_mask: str
+    profile_sk_column_mask: str
+    profile_use_sampling: str
+    profile_flag_cdes: bool
+    profile_sample_percent: str
+    profile_sample_min_count: int
+    profile_do_pair_rules: str
+    profile_pair_rule_pct: int
+
 
-    # Execute Query
-    lstParms = RetrieveDBResultsToDictList("DKTG", strSQL)
+class TestGenerationParams(BaseParams):
+    export_to_observability: str
+    test_suite_id: str
+    profiling_as_of_date: str
 
-    if lstParms is None:
-        raise ValueError("Project Connection Parameters not found")
-    return lstParms[0]
+class TestExecutionParams(BaseParams):
+    test_suite_id: str
+    table_groups_id: str
+    profiling_table_set: str
+    profiling_include_mask: str
+    profiling_exclude_mask: str
+    sql_flavor: str
+    max_threads: int
+    max_query_chars: int
 
-def RetrieveTestGenParms(strTableGroupsID, strTestSuite):
-    strSQL = read_template_sql_file("parms_test_gen.sql", "parms")
-    # Replace Parameters
-    strSQL = strSQL.replace("{TABLE_GROUPS_ID}", strTableGroupsID)
-    strSQL = strSQL.replace("{TEST_SUITE}", strTestSuite)
-    # Execute Query
-    lstParms = RetrieveDBResultsToDictList("DKTG", strSQL)
-    if len(lstParms) == 0:
-        raise ValueError("SQL retrieved 0 records")
-    return lstParms[0]
+def get_profiling_params(table_group_id: str) -> ProfilingParams:
+    results = fetch_dict_from_db(
+        read_template_sql_file("parms_profiling.sql", "parms"),
+        {"TABLE_GROUP_ID": table_group_id},
+    )
+    if not results:
+ raise ValueError("Connection parameters not found for profiling.") + return ProfilingParams(results[0]) -def RetrieveTestExecParms(strProjectCode, strTestSuite): - strSQL = read_template_sql_file("parms_test_execution.sql", "parms") - # Replace Parameters - strSQL = strSQL.replace("{PROJECT_CODE}", strProjectCode) - strSQL = strSQL.replace("{TEST_SUITE}", strTestSuite) +def get_test_generation_params(table_group_id: str, test_suite: str) -> TestGenerationParams: + results = fetch_dict_from_db( + read_template_sql_file("parms_test_gen.sql", "parms"), + {"TABLE_GROUP_ID": table_group_id, "TEST_SUITE": test_suite}, + ) + if not results: + raise ValueError("Connection parameters not found for test generation.") + return TestGenerationParams(results[0]) - # Execute Query - lstParms = RetrieveDBResultsToDictList("DKTG", strSQL) - if len(lstParms) == 0: - raise ValueError("Test Execution parameters could not be retrieved") - elif len(lstParms) > 1: - raise ValueError("Test Execution parameters returned too many records") - return lstParms[0] +def get_test_execution_params(project_code: str, test_suite: str) -> TestExecutionParams: + results = fetch_dict_from_db( + read_template_sql_file("parms_test_execution.sql", "parms"), + {"PROJECT_CODE": project_code, "TEST_SUITE": test_suite} + ) + if not results: + raise ValueError("Connection parameters not found for test execution.") + return TestExecutionParams(results[0]) diff --git a/testgen/common/mixpanel_service.py b/testgen/common/mixpanel_service.py index fd3908f8..b534cf69 100644 --- a/testgen/common/mixpanel_service.py +++ b/testgen/common/mixpanel_service.py @@ -9,12 +9,10 @@ from urllib.parse import urlencode from urllib.request import Request, urlopen -import streamlit as st - -import testgen.ui.services.database_service as db from testgen import settings from testgen.common.models import with_database_session from testgen.common.models.settings import PersistedSetting, SettingNotFound +from testgen.ui.services.database_service import fetch_one_from_db from testgen.ui.session import session from testgen.utils.singleton import Singleton @@ -92,14 +90,14 @@ def send_mp_request(self, endpoint, payload): except Exception: LOG.exception("Failed to send analytics data") + @with_database_session def get_usage(self): - schema: str = st.session_state["dbschema"] - query = f""" + query = """ SELECT - (SELECT COUNT(*) FROM {schema}.auth_users) AS user_count, - (SELECT COUNT(*) FROM {schema}.projects) AS project_count, - (SELECT COUNT(*) FROM {schema}.connections) AS connection_count, - (SELECT COUNT(*) FROM {schema}.table_groups) AS table_group_count, - (SELECT COUNT(*) FROM {schema}.test_suites) AS test_suite_count; + (SELECT COUNT(*) FROM auth_users) AS user_count, + (SELECT COUNT(*) FROM projects) AS project_count, + (SELECT COUNT(*) FROM connections) AS connection_count, + (SELECT COUNT(*) FROM table_groups) AS table_group_count, + (SELECT COUNT(*) FROM test_suites) AS test_suite_count; """ - return db.retrieve_data(query).iloc[0].to_dict() + return fetch_one_from_db(query) diff --git a/testgen/common/models/connection.py b/testgen/common/models/connection.py new file mode 100644 index 00000000..5f8cd65e --- /dev/null +++ b/testgen/common/models/connection.py @@ -0,0 +1,126 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Literal, Self +from uuid import UUID, uuid4 + +import streamlit as st +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + ForeignKey, + Identity, + Integer, + String, + 
asc, + select, +) +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import InstrumentedAttribute + +from testgen.common.database.flavor.flavor_service import SQLFlavor +from testgen.common.models import get_current_session +from testgen.common.models.custom_types import EncryptedBytea +from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.table_group import TableGroup +from testgen.utils import is_uuid4 + +SQLFlavorCode = Literal["redshift", "snowflake", "mssql", "azure_mssql", "synapse_mssql", "postgresql", "databricks"] + + +@dataclass +class ConnectionMinimal(EntityMinimal): + project_code: str + connection_id: int + sql_flavor_code: SQLFlavorCode + connection_name: str + + +class Connection(Entity): + __tablename__ = "connections" + + id: UUID = Column(postgresql.UUID(as_uuid=True), default=uuid4) + project_code: str = Column(String, ForeignKey("projects.project_code")) + connection_id: int = Column(BigInteger, Identity(always=True), primary_key=True) + sql_flavor: SQLFlavor = Column(String) + sql_flavor_code: SQLFlavorCode = Column(String) + project_host: str = Column(String) + project_port: str = Column(String) + project_user: str = Column(String) + project_db: str = Column(String) + connection_name: str = Column(String) + project_pw_encrypted: str = Column(EncryptedBytea) + max_threads: int = Column(Integer, default=4) + max_query_chars: int = Column(Integer) + url: str = Column(String, default="") + connect_by_url: bool = Column(Boolean, default=False) + connect_by_key: bool = Column(Boolean, default=False) + private_key: str = Column(EncryptedBytea) + private_key_passphrase: str = Column(EncryptedBytea) + http_path: str = Column(String) + + _get_by = "connection_id" + _default_order_by = (asc(connection_name),) + _minimal_columns = ConnectionMinimal.__annotations__.keys() + + @classmethod + @st.cache_data(show_spinner=False) + def get_minimal(cls, identifier: int) -> ConnectionMinimal | None: + result = cls._get_columns(identifier, cls._minimal_columns) + return ConnectionMinimal(**result) if result else None + + @classmethod + @st.cache_data(show_spinner=False) + def get_by_table_group(cls, table_group_id: str | UUID) -> Self | None: + if not is_uuid4(table_group_id): + return None + + query = select(cls).join(TableGroup).where(TableGroup.id == table_group_id) + return get_current_session().scalars(query).first() + + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def select_minimal_where( + cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by + ) -> Iterable[ConnectionMinimal]: + results = cls._select_columns_where(cls._minimal_columns, *clauses, order_by=order_by) + return [ConnectionMinimal(**row) for row in results] + + @classmethod + def has_running_process(cls, ids: list[str]) -> bool: + table_groups = TableGroup.select_minimal_where(TableGroup.connection_id.in_(ids)) + if table_groups: + return TableGroup.has_running_process([item.id for item in table_groups]) + return False + + @classmethod + def is_in_use(cls, ids: list[str]) -> bool: + table_groups = TableGroup.select_minimal_where(TableGroup.connection_id.in_(ids)) + return len(table_groups) > 0 + + @classmethod + def cascade_delete(cls, ids: list[str]) -> bool: + table_groups = TableGroup.select_minimal_where(TableGroup.connection_id.in_(ids)) + if table_groups: + TableGroup.cascade_delete([item.id for item in table_groups]) + cls.delete_where(cls.connection_id.in_(ids)) + + @classmethod 
+ def clear_cache(cls) -> bool: + super().clear_cache() + cls.get_minimal.clear() + cls.get_by_table_group.clear() + cls.select_minimal_where.clear() + + def save(self) -> None: + if self.connect_by_url and self.url: + url_sections = self.url.split("/") + if url_sections: + host_port = url_sections[0] + host_port_sections = host_port.split(":") + self.project_host = host_port_sections[0] if host_port_sections else host_port + self.project_port = "".join(host_port_sections[1:]) if host_port_sections else "" + if len(url_sections) > 1: + self.project_db = url_sections[1] + + super().save() diff --git a/testgen/common/models/custom_types.py b/testgen/common/models/custom_types.py new file mode 100644 index 00000000..b4a34276 --- /dev/null +++ b/testgen/common/models/custom_types.py @@ -0,0 +1,56 @@ +from datetime import UTC, datetime + +from sqlalchemy import Integer, String, TypeDecorator +from sqlalchemy.dialects import postgresql + +from testgen.common.encrypt import DecryptText, EncryptText + + +class NullIfEmptyString(TypeDecorator): + impl = String + cache_ok = True + + def process_bind_param(self, value: str, _dialect) -> str | None: + return None if value == "" else value + + +class YNString(TypeDecorator): + impl = String + cache_ok = True + + def process_bind_param(self, value: bool | str | None, _dialect) -> str | None: + if isinstance(value, bool): + return "Y" if value else "N" + return value + + def process_result_value(self, value: str | None, _dialect) -> bool | None: + if isinstance(value, str): + return value == "Y" + return value + + +class ZeroIfEmptyInteger(TypeDecorator): + impl = Integer + cache_ok = True + + def process_bind_param(self, value: str | int, _dialect) -> int: + return value or 0 + + +class UpdateTimestamp(TypeDecorator): + impl = postgresql.TIMESTAMP + cache_ok = True + + def process_bind_param(self, _value, _dialect) -> datetime: + return datetime.now(UTC) + + +class EncryptedBytea(TypeDecorator): + impl = postgresql.BYTEA + cache_ok = True + + def process_bind_param(self, value: str, _dialect) -> bytes: + return EncryptText(value).encode("UTF-8") if value is not None else value + + def process_result_value(self, value: bytes, _dialect) -> str: + return DecryptText(value) if value is not None else value diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py new file mode 100644 index 00000000..a6175606 --- /dev/null +++ b/testgen/common/models/entity.py @@ -0,0 +1,165 @@ +from collections.abc import Iterable +from dataclasses import asdict, dataclass +from typing import Any, Self +from uuid import UUID + +import streamlit as st +from sqlalchemy import delete, select +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.elements import BinaryExpression + +from testgen.common.models import Base, get_current_session +from testgen.utils import is_uuid4, make_json_safe + +ENTITY_HASH_FUNCS = { + BinaryExpression: lambda x: str(x.compile(compile_kwargs={"literal_binds": True})), + tuple: lambda x: [str(y) for y in x], +} + + +@dataclass +class EntityMinimal: + @classmethod + def columns(cls) -> list[str]: + return list(cls.__annotations__.keys()) + + def to_dict(self, json_safe: bool = False) -> dict[str, Any]: + result = asdict(self) + if json_safe: + return {key: make_json_safe(value) for key, value in result.items()} + return result + + +class Entity(Base): + __abstract__ = True + + _get_by: str = "id" + _default_order_by: tuple[str | InstrumentedAttribute] = ("id",) + + 
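A hypothetical usage sketch of the `Entity` contract introduced above (not part of the patch), assuming a configured database session and Streamlit cache; the model name, columns, and values below are illustrative only. A subclass declares its lookup column and default ordering, then inherits the cached `get`/`select_where` helpers and the `clear_cache` invalidation hook:

    # Hypothetical model, used only to illustrate the Entity base-class contract.
    from sqlalchemy import Column, String, asc

    from testgen.common.models.entity import Entity

    class Widget(Entity):
        __tablename__ = "widgets"

        id: str = Column(String, primary_key=True)
        name: str = Column(String)

        _get_by = "name"                    # Widget.get() looks records up by name
        _default_order_by = (asc(name),)    # default ordering for select_where()

    widget = Widget.get("main")                             # cached via st.cache_data
    widgets = Widget.select_where(Widget.name.like("w%"))   # cached, ordered by name
    Widget.clear_cache()                                     # drop cached results after external changes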
@classmethod + @st.cache_data(show_spinner=False) + def get(cls, identifier: str | int | UUID) -> Self | None: + get_by_column = getattr(cls, cls._get_by) + if isinstance(get_by_column.property.columns[0].type, postgresql.UUID) and not is_uuid4(identifier): + return None + + query = select(cls).where(get_by_column == identifier) + return get_current_session().scalars(query).first() + + @classmethod + def get_minimal(cls, identifier: str | int | UUID) -> Any: + raise NotImplementedError + + @classmethod + def _get_columns( + cls, + identifier: str | int | UUID, + columns: list[str | InstrumentedAttribute], + join_target: Self | None = None, + join_clause: BinaryExpression | None = None, + ) -> Self | None: + get_by_column = getattr(cls, cls._get_by) + if isinstance(get_by_column.property.columns[0].type, postgresql.UUID) and not is_uuid4(identifier): + return None + + if join_target: + select_columns = [ + getattr(cls, col, None) or getattr(join_target, col) if isinstance(col, str) else col for col in columns + ] + query = select(select_columns).join(join_target, join_clause) + else: + select_columns = [getattr(cls, col) if isinstance(col, str) else col for col in columns] + query = select(select_columns) + + query = query.where(get_by_column == identifier) + return get_current_session().execute(query).first() + + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def select_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute] | None = None) -> Iterable[Self]: + order_by = order_by or cls._default_order_by + query = select(cls).where(*clauses).order_by(*order_by) + return get_current_session().scalars(query).all() + + @classmethod + def select_minimal_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute]) -> Iterable[Any]: + raise NotImplementedError + + @classmethod + def _select_columns_where( + cls, + columns: list[str | InstrumentedAttribute], + *clauses, + join_target: Self | None = None, + join_clause: BinaryExpression | None = None, + order_by: tuple[str | InstrumentedAttribute] | None = None, + ) -> Self | None: + if join_target: + select_columns = [ + getattr(cls, col, None) or getattr(join_target, col) if isinstance(col, str) else col for col in columns + ] + query = select(select_columns).join(join_target, join_clause) + else: + select_columns = [getattr(cls, col) if isinstance(col, str) else col for col in columns] + query = select(select_columns) + + order_by = order_by or cls._default_order_by + query = query.where(*clauses).order_by(*order_by) + return get_current_session().execute(query).all() + + @classmethod + def has_running_process(cls, ids: list[str]) -> bool: + raise NotImplementedError + + @classmethod + def delete_where(cls, *clauses) -> None: + query = delete(cls).where(*clauses) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + # We clear all because cached data like Project.select_summary will be affected + st.cache_data.clear() + + @classmethod + def is_in_use(cls, ids: list[str]) -> bool: + raise NotImplementedError + + @classmethod + def cascade_delete(cls, ids: list[str]) -> None: + raise NotImplementedError + + @classmethod + def clear_cache(cls) -> None: + cls.get.clear() + cls.select_where.clear() + + @classmethod + def columns(cls) -> list[str]: + return list(cls.__annotations__.keys()) + + def save(self) -> None: + is_new = self.id is None + db_session = get_current_session() + db_session.add(self) + db_session.flush([self]) + db_session.commit() + 
db_session.refresh(self, ["id"]) + if is_new: + # We clear all because cached data like Project.select_summary will be affected + st.cache_data.clear() + else: + self.__class__.clear_cache() + + def delete(self) -> None: + db_session = get_current_session() + db_session.add(self) + db_session.delete(self) + db_session.commit() + self.__class__.clear_cache() + + def to_dict(self, json_safe: bool = False): + result = {col.name: getattr(self, col.name) for col in self.__table__.columns} + if json_safe: + return {key: make_json_safe(value) for key, value in result.items()} + return result diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py new file mode 100644 index 00000000..f7c68bb0 --- /dev/null +++ b/testgen/common/models/profiling_run.py @@ -0,0 +1,259 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Literal, NamedTuple +from uuid import UUID + +import streamlit as st +from sqlalchemy import BigInteger, Column, Float, Integer, String, desc, func, select, text, update +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.expression import case + +from testgen.common.models import get_current_session +from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.table_group import TableGroup +from testgen.utils import is_uuid4 + +ProfilingRunStatus = Literal["Running", "Complete", "Error", "Cancelled"] + + +@dataclass +class ProfilingRunMinimal(EntityMinimal): + id: UUID + project_code: str + table_groups_id: UUID + table_groups_name: str + profiling_starttime: datetime + dq_score_profiling: float + is_latest_run: bool + + +@dataclass +class ProfilingRunSummary(EntityMinimal): + profiling_run_id: UUID + start_time: datetime + table_groups_name: str + status: ProfilingRunStatus + process_id: int + duration: str + log_message: str + schema_name: str + table_ct: int + column_ct: int + anomaly_ct: int + anomalies_definite_ct: int + anomalies_likely_ct: int + anomalies_possible_ct: int + anomalies_dismissed_ct: int + dq_score_profiling: float + + +class LatestProfilingRun(NamedTuple): + id: str + run_time: datetime + + +class ProfilingRun(Entity): + __tablename__ = "profiling_runs" + + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True) + project_code: str = Column(String, nullable=False) + connection_id: str = Column(BigInteger, nullable=False) + table_groups_id: UUID = Column(postgresql.UUID(as_uuid=True), nullable=False) + profiling_starttime: datetime = Column(postgresql.TIMESTAMP) + profiling_endtime: datetime = Column(postgresql.TIMESTAMP) + status: ProfilingRunStatus = Column(String, default="Running") + log_message: str = Column(String) + table_ct: int = Column(BigInteger) + column_ct: int = Column(BigInteger) + anomaly_ct: int = Column(BigInteger) + anomaly_table_ct: int = Column(BigInteger) + anomaly_column_ct: int = Column(BigInteger) + dq_affected_data_points: int = Column(BigInteger) + dq_total_data_points: int = Column(BigInteger) + dq_score_profiling: float = Column(Float) + process_id: int = Column(Integer) + + _default_order_by = (desc(profiling_starttime),) + _minimal_columns = ( + id, + project_code, + table_groups_id, + TableGroup.table_groups_name, + profiling_starttime, + dq_score_profiling, + case( + (id == TableGroup.last_complete_profile_run_id, True), + else_=False, + ).label("is_latest_run"), + ) + + @classmethod + 
@st.cache_data(show_spinner=False) + def get_minimal(cls, run_id: str | UUID) -> ProfilingRunMinimal | None: + if not is_uuid4(run_id): + return None + + query = ( + select(cls._minimal_columns).join(TableGroup, cls.table_groups_id == TableGroup.id).where(cls.id == run_id) + ) + result = get_current_session().execute(query).first() + return ProfilingRunMinimal(**result) if result else None + + @classmethod + def get_latest_run(cls, project_code: str) -> LatestProfilingRun | None: + query = ( + select(ProfilingRun.id, ProfilingRun.profiling_starttime) + .where(ProfilingRun.project_code == project_code, ProfilingRun.status == "Complete") + .order_by(desc(ProfilingRun.profiling_starttime)) + .limit(1) + ) + result = get_current_session().execute(query).first() + if result: + return LatestProfilingRun(str(result["id"]), result["profiling_starttime"]) + return None + + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def select_minimal_where( + cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by + ) -> Iterable[ProfilingRunMinimal]: + query = ( + select(cls._minimal_columns) + .join(TableGroup, cls.table_groups_id == TableGroup.id) + .where(*clauses) + .order_by(*order_by) + ) + results = get_current_session().execute(query).all() + return [ProfilingRunMinimal(**row) for row in results] + + @classmethod + @st.cache_data(show_spinner=False) + def select_summary( + cls, project_code: str, table_group_id: str | UUID | None = None, profiling_run_ids: list[str] | None = None + ) -> Iterable[ProfilingRunSummary]: + if (table_group_id and not is_uuid4(table_group_id)) or ( + profiling_run_ids and not all(is_uuid4(run_id) for run_id in profiling_run_ids) + ): + return [] + + query = f""" + WITH profile_anomalies AS ( + SELECT profile_anomaly_results.profile_run_id, + SUM( + CASE + WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') = 'Confirmed' + AND profile_anomaly_types.issue_likelihood = 'Definite' THEN 1 + ELSE 0 + END + ) AS definite_ct, + SUM( + CASE + WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') = 'Confirmed' + AND profile_anomaly_types.issue_likelihood = 'Likely' THEN 1 + ELSE 0 + END + ) AS likely_ct, + SUM( + CASE + WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') = 'Confirmed' + AND profile_anomaly_types.issue_likelihood = 'Possible' THEN 1 + ELSE 0 + END + ) AS possible_ct, + SUM( + CASE + WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') + AND profile_anomaly_types.issue_likelihood <> 'Potential PII' THEN 1 + ELSE 0 + END + ) AS dismissed_ct + FROM profile_anomaly_results + LEFT JOIN profile_anomaly_types ON ( + profile_anomaly_types.id = profile_anomaly_results.anomaly_id + ) + GROUP BY profile_anomaly_results.profile_run_id + ) + SELECT v_profiling_runs.profiling_run_id, + v_profiling_runs.start_time, + v_profiling_runs.table_groups_name, + v_profiling_runs.status, + v_profiling_runs.process_id, + v_profiling_runs.duration, + v_profiling_runs.log_message, + v_profiling_runs.schema_name, + v_profiling_runs.table_ct, + v_profiling_runs.column_ct, + v_profiling_runs.anomaly_ct, + profile_anomalies.definite_ct AS anomalies_definite_ct, + profile_anomalies.likely_ct AS anomalies_likely_ct, + profile_anomalies.possible_ct AS anomalies_possible_ct, + profile_anomalies.dismissed_ct AS anomalies_dismissed_ct, + v_profiling_runs.dq_score_profiling + FROM v_profiling_runs + LEFT JOIN profile_anomalies ON (v_profiling_runs.profiling_run_id = 
profile_anomalies.profile_run_id) + WHERE project_code = :project_code + {"AND v_profiling_runs.table_groups_id = :table_group_id" if table_group_id else ""} + {"AND v_profiling_runs.profiling_run_id IN :profiling_run_ids" if profiling_run_ids else ""} + ORDER BY start_time DESC; + """ + params = { + "project_code": project_code, + "table_group_id": table_group_id, + "profiling_run_ids": tuple(profiling_run_ids or []), + } + db_session = get_current_session() + results = db_session.execute(text(query), params).mappings().all() + return [ProfilingRunSummary(**row) for row in results] + + @classmethod + def has_running_process(cls, ids: list[str]) -> bool: + query = select(func.count(cls.id)).where(cls.id.in_(ids), cls.status == "Running") + process_count = get_current_session().execute(query).scalar() + return process_count > 0 + + @classmethod + def cancel_all_running(cls) -> None: + query = ( + update(cls).where(cls.status == "Running").values(status="Cancelled", profiling_endtime=datetime.now(UTC)) + ) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + cls.clear_cache() + + @classmethod + def update_status(cls, run_id: str | UUID, status: ProfilingRunStatus) -> None: + query = update(cls).where(cls.id == run_id).values(status=status) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + cls.clear_cache() + + @classmethod + def cascade_delete(cls, ids: list[str]) -> None: + query = """ + DELETE FROM profile_pair_rules + WHERE profile_run_id IN :profiling_run_ids; + + DELETE FROM profile_anomaly_results + WHERE profile_run_id IN :profiling_run_ids; + + DELETE FROM profile_results + WHERE profile_run_id IN :profiling_run_ids; + """ + db_session = get_current_session() + db_session.execute(text(query), {"profiling_run_ids": tuple(ids)}) + db_session.commit() + cls.delete_where(cls.id.in_(ids)) + + @classmethod + def clear_cache(cls) -> bool: + super().clear_cache() + cls.get_minimal.clear() + cls.select_minimal_where.clear() + cls.select_summary.clear() + + def save(self) -> None: + raise NotImplementedError diff --git a/testgen/common/models/project.py b/testgen/common/models/project.py new file mode 100644 index 00000000..e1e32a7b --- /dev/null +++ b/testgen/common/models/project.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass +from uuid import UUID, uuid4 + +import streamlit as st +from sqlalchemy import Column, String, asc, text +from sqlalchemy.dialects import postgresql + +from testgen.common.models import get_current_session +from testgen.common.models.connection import Connection +from testgen.common.models.custom_types import NullIfEmptyString +from testgen.common.models.entity import Entity, EntityMinimal + + +@dataclass +class ProjectSummary(EntityMinimal): + project_code: str + connection_count: int + default_connection_id: int + table_group_count: int + profiling_run_count: int + test_suite_count: int + test_definition_count: int + test_run_count: int + can_export_to_observability: bool + + +class Project(Entity): + __tablename__ = "projects" + + id: UUID = Column(postgresql.UUID(as_uuid=True), default=uuid4) + project_code: str = Column(String, primary_key=True, nullable=False) + project_name: str = Column(String) + observability_api_url: str = Column(NullIfEmptyString) + observability_api_key: str = Column(NullIfEmptyString) + + _get_by = "project_code" + _default_order_by = (asc(project_name),) + + @classmethod + @st.cache_data(show_spinner=False) + def get_summary(cls, project_code: str) -> 
ProjectSummary | None: + query = """ + SELECT + ( + SELECT COUNT(*) AS count FROM connections WHERE connections.project_code = :project_code + ) AS connection_count, + ( + SELECT connection_id FROM connections WHERE connections.project_code = :project_code LIMIT 1 + ) AS default_connection_id, + ( + SELECT COUNT(*) FROM table_groups WHERE table_groups.project_code = :project_code + ) AS table_group_count, + ( + SELECT COUNT(*) + FROM profiling_runs + LEFT JOIN table_groups ON profiling_runs.table_groups_id = table_groups.id + WHERE table_groups.project_code = :project_code + ) AS profiling_run_count, + ( + SELECT COUNT(*) FROM test_suites WHERE test_suites.project_code = :project_code + ) AS test_suite_count, + ( + SELECT COUNT(*) + FROM test_definitions + LEFT JOIN test_suites ON test_definitions.test_suite_id = test_suites.id + WHERE test_suites.project_code = :project_code + ) AS test_definition_count, + ( + SELECT COUNT(*) + FROM test_runs + LEFT JOIN test_suites ON test_runs.test_suite_id = test_suites.id + WHERE test_suites.project_code = :project_code + ) AS test_run_count, + ( + SELECT COALESCE(observability_api_key, '') <> '' + AND COALESCE(observability_api_url, '') <> '' + FROM projects + WHERE project_code = :project_code + ) AS can_export_to_observability; + """ + + db_session = get_current_session() + result = db_session.execute(text(query), {"project_code": project_code}).first() + return ProjectSummary(**result, project_code=project_code) if result else None + + @classmethod + def is_in_use(cls, ids: list[str]) -> bool: + connections = Connection.select_minimal_where(Connection.project_code.in_(ids)) + return len(connections) > 0 + + @classmethod + def cascade_delete(cls, ids: list[str]) -> bool: + connections = Connection.select_minimal_where(Connection.project_code.in_(ids)) + if connections: + Connection.cascade_delete([item.connection_id for item in connections]) + cls.delete_where(cls.project_code.in_(ids)) + + @classmethod + def clear_cache(cls) -> bool: + super().clear_cache() + cls.get_summary.clear() diff --git a/testgen/common/models/scheduler.py b/testgen/common/models/scheduler.py index 99b7ed33..12a55ffc 100644 --- a/testgen/common/models/scheduler.py +++ b/testgen/common/models/scheduler.py @@ -1,11 +1,11 @@ -import uuid from collections.abc import Iterable from datetime import datetime from typing import Any, Self +from uuid import UUID, uuid4 from cron_converter import Cron from sqlalchemy import Column, String, select -from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from testgen.common.models import Base, get_current_session @@ -14,12 +14,12 @@ class JobSchedule(Base): __tablename__ = "job_schedules" - id: UUID = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) project_code: str = Column(String) key: str = Column(String, nullable=False) - args: list[Any] = Column(JSONB, nullable=False, default=[]) - kwargs: dict[str, Any] = Column(JSONB, nullable=False, default={}) + args: list[Any] = Column(postgresql.JSONB, nullable=False, default=[]) + kwargs: dict[str, Any] = Column(postgresql.JSONB, nullable=False, default={}) cron_expr: str = Column(String, nullable=False) cron_tz: str = Column(String, nullable=False) diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index 91dcb144..c6db830b 100644 --- a/testgen/common/models/scores.py +++ 
b/testgen/common/models/scores.py @@ -1,17 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from testgen.common.models.table_group import TableGroup + import enum -import uuid from collections.abc import Iterable from datetime import UTC, datetime from itertools import groupby from typing import Literal, Self, TypedDict +from uuid import UUID, uuid4 -import pandas as pd from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String, select, text -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects import postgresql from sqlalchemy.orm import relationship from testgen.common import read_template_sql_file -from testgen.common.models import Base, engine, get_current_session +from testgen.common.models import Base, get_current_session from testgen.utils import is_uuid4 SCORE_CATEGORIES = [ @@ -44,6 +50,7 @@ "transform_level", "data_product", ] +ScoreTypes = Literal["score", "cde_score"] class ScoreCategory(enum.Enum): @@ -62,33 +69,33 @@ class ScoreCategory(enum.Enum): class ScoreDefinition(Base): __tablename__ = "score_definitions" - id: str = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) project_code: str = Column(String) name: str = Column(String, nullable=False) total_score: bool = Column(Boolean, default=True, nullable=False) cde_score: bool = Column(Boolean, default=False, nullable=False) category: ScoreCategory | None = Column(Enum(ScoreCategory), nullable=True) - criteria: "ScoreDefinitionCriteria" = relationship( + criteria: ScoreDefinitionCriteria = relationship( "ScoreDefinitionCriteria", cascade="all, delete-orphan", lazy="joined", uselist=False, single_parent=True, ) - results: Iterable["ScoreDefinitionResult"] = relationship( + results: Iterable[ScoreDefinitionResult] = relationship( "ScoreDefinitionResult", cascade="all, delete-orphan", order_by="ScoreDefinitionResult.category", lazy="joined", ) - breakdown: Iterable["ScoreDefinitionBreakdownItem"] = relationship( + breakdown: Iterable[ScoreDefinitionBreakdownItem] = relationship( "ScoreDefinitionBreakdownItem", cascade="all, delete-orphan", order_by="ScoreDefinitionBreakdownItem.impact.desc()", lazy="joined", ) - history: Iterable["ScoreDefinitionResultHistoryEntry"] = relationship( + history: Iterable[ScoreDefinitionResultHistoryEntry] = relationship( "ScoreDefinitionResultHistoryEntry", order_by="ScoreDefinitionResultHistoryEntry.last_run_time.asc()", cascade="all, delete-orphan", @@ -97,23 +104,23 @@ class ScoreDefinition(Base): ) @classmethod - def from_table_group(cls, table_group: dict) -> Self: + def from_table_group(cls, table_group: TableGroup) -> Self: definition = cls() - definition.project_code = table_group["project_code"] - definition.name = table_group["table_groups_name"] + definition.project_code = table_group.project_code + definition.name = table_group.table_groups_name definition.total_score = True definition.cde_score = True definition.category = ScoreCategory.dq_dimension definition.criteria = ScoreDefinitionCriteria( operand="AND", filters=[ - ScoreDefinitionFilter(field="table_groups_name", value=table_group["table_groups_name"]), + ScoreDefinitionFilter(field="table_groups_name", value=table_group.table_groups_name), ], ) return definition @classmethod - def get(cls, id_: str) -> "Self | None": + def get(cls, id_: str) -> Self | None: if not is_uuid4(id_): return None @@ -129,7 +136,7 @@ def all( 
project_code: str | None = None, name_filter: str | None = None, sorted_by: str | None = "name", - ) -> "Iterable[Self]": + ) -> Iterable[Self]: definitions = [] db_session = get_current_session() query = select(ScoreDefinition) @@ -154,7 +161,7 @@ def delete(self) -> None: db_session.delete(self) db_session.commit() - def as_score_card(self) -> "ScoreCard": + def as_score_card(self) -> ScoreCard: """ Executes and combines two raw queries to build a fresh score card from this definition. @@ -216,7 +223,7 @@ def as_score_card(self) -> "ScoreCard": "definition": self, } - def as_cached_score_card(self) -> "ScoreCard": + def as_cached_score_card(self) -> ScoreCard: """Reads the cached values to build a scorecard""" root_keys: list[str] = ["score", "profiling_score", "testing_score", "cde_score"] score_card: ScoreCard = { @@ -303,9 +310,9 @@ def get_score_card_breakdown( .replace("{records_count_filters}", records_count_filters) .replace("{non_null_columns}", ", ".join(non_null_columns)) ) - results = pd.read_sql_query(query, engine) + results = get_current_session().execute(query).mappings().all() - return [row.to_dict() for _, row in results.iterrows()] + return [dict(row) for row in results] def get_score_card_issues( self, @@ -338,17 +345,17 @@ def get_score_card_issues( dq_dimension_filter = "" if group_by == "dq_dimension": - dq_dimension_filter = f" AND dq_dimension = '{value_}'" + dq_dimension_filter = " AND dq_dimension = :value" query = ( read_template_sql_file(query_template_file, sub_directory="score_cards") .replace("{filters}", filters) .replace("{group_by}", group_by) - .replace("{value}", value_) .replace("{dq_dimension_filter}", dq_dimension_filter) ) - results = pd.read_sql_query(query, engine) - return [row.to_dict() for _, row in results.iterrows()] + params = {"value": value_} + results = get_current_session().execute(text(query), params).mappings().all() + return [dict(row) for row in results] def recalculate_scores_history(self) -> None: """ @@ -362,9 +369,9 @@ def recalculate_scores_history(self) -> None: query = ( read_template_sql_file(template, sub_directory="score_cards") .replace("{filters}", " AND ".join(self._get_raw_query_filters())) - .replace("{definition_id}", str(self.id)) ) - overall_scores = get_current_session().execute(query).mappings().all() + params = {"definition_id": self.id} + overall_scores = get_current_session().execute(text(query), params).mappings().all() current_history: dict[tuple[datetime, str, str], ScoreDefinitionResultHistoryEntry] = {} renewed_history: dict[tuple[datetime, str, str], float] = {} @@ -439,11 +446,11 @@ class ScoreDefinitionCriteria(Base): __tablename__ = "score_definition_criteria" - id: str = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - definition_id: str = Column(UUID(as_uuid=True), ForeignKey("score_definitions.id", ondelete="CASCADE")) + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) + definition_id: str = Column(postgresql.UUID(as_uuid=True), ForeignKey("score_definitions.id", ondelete="CASCADE")) operand: Literal["AND", "OR"] = Column(String, nullable=False, default="AND") group_by_field: bool = Column(Boolean, nullable=False, default=True) - filters: list["ScoreDefinitionFilter"] = relationship( + filters: list[ScoreDefinitionFilter] = relationship( "ScoreDefinitionFilter", cascade="all, delete-orphan", lazy="joined", @@ -485,7 +492,7 @@ def has_filters(self) -> bool: return len(self.filters) > 0 @classmethod - def from_filters(cls, filters: list[dict], 
group_by_field: bool = True) -> "ScoreDefinitionCriteria": + def from_filters(cls, filters: list[dict], group_by_field: bool = True) -> ScoreDefinitionCriteria: chained_filters: list[ScoreDefinitionFilter] = [] for filter_ in filters: root_filter = current_filter = ScoreDefinitionFilter( @@ -507,22 +514,22 @@ def from_filters(cls, filters: list[dict], group_by_field: bool = True) -> "Scor class ScoreDefinitionFilter(Base): __tablename__ = "score_definition_filters" - id: str = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - criteria_id = Column( - UUID(as_uuid=True), + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) + criteria_id: UUID = Column( + postgresql.UUID(as_uuid=True), ForeignKey("score_definition_criteria.id", ondelete="CASCADE"), nullable=True, default=None, ) field: str = Column(String, nullable=False) value: str = Column(String, nullable=False) - next_filter_id = Column( - UUID(as_uuid=True), + next_filter_id: UUID = Column( + postgresql.UUID(as_uuid=True), ForeignKey("score_definition_filters.id", ondelete="CASCADE"), nullable=True, default=None, ) - next_filter: "ScoreDefinitionFilter" = relationship( + next_filter: ScoreDefinitionFilter = relationship( "ScoreDefinitionFilter", cascade="all, delete-orphan", lazy="joined", @@ -545,8 +552,8 @@ def get_as_sql(self, prefix: str | None = None, operand: Literal["AND", "OR"] = class ScoreDefinitionResult(Base): __tablename__ = "score_definition_results" - definition_id: str = Column( - UUID(as_uuid=True), + definition_id: UUID = Column( + postgresql.UUID(as_uuid=True), ForeignKey("score_definitions.id", ondelete="CASCADE"), primary_key=True, ) @@ -557,9 +564,9 @@ class ScoreDefinitionResult(Base): class ScoreDefinitionBreakdownItem(Base): __tablename__ = "score_definition_results_breakdown" - id: str = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - definition_id: str = Column( - UUID(as_uuid=True), + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) + definition_id: UUID = Column( + postgresql.UUID(as_uuid=True), ForeignKey("score_definitions.id", ondelete="CASCADE"), ) category: str = Column(String, nullable=False) @@ -588,8 +595,8 @@ def filter( *, definition_id: str, category: Categories, - score_type: Literal["score", "cde_score"], - ) -> "Iterable[Self]": + score_type: ScoreTypes, + ) -> Iterable[Self]: items = [] db_session = get_current_session() query = select(ScoreDefinitionBreakdownItem).where( @@ -616,8 +623,8 @@ def to_dict(self) -> dict: class ScoreDefinitionResultHistoryEntry(Base): __tablename__ = "score_definition_results_history" - definition_id: str = Column( - UUID(as_uuid=True), + definition_id: UUID = Column( + postgresql.UUID(as_uuid=True), ForeignKey("score_definitions.id", ondelete="CASCADE"), primary_key=True, ) @@ -637,14 +644,14 @@ def add_as_cutoff(self): add_latest_runs.sql """ # ruff: noqa: RUF027 - query = ( - read_template_sql_file("add_latest_runs.sql", sub_directory="score_cards") - .replace("{project_code}", self.definition.project_code) - .replace("{definition_id}", str(self.definition_id)) - .replace("{score_history_cutoff_time}", self.last_run_time.isoformat()) - ) + query = read_template_sql_file("add_latest_runs.sql", sub_directory="score_cards") + params = { + "project_code": self.definition.project_code, + "definition_id": self.definition_id, + "score_history_cutoff_time": self.last_run_time.isoformat(), + } session = get_current_session() - session.execute(query) + 
session.execute(text(query), params) class ScoreCard(TypedDict): @@ -655,8 +662,8 @@ class ScoreCard(TypedDict): cde_score: float profiling_score: float testing_score: float - categories: list["CategoryScore"] - history: list["HistoryEntry"] + categories: list[CategoryScore] + history: list[HistoryEntry] definition: ScoreDefinition | None diff --git a/testgen/common/models/table_group.py b/testgen/common/models/table_group.py new file mode 100644 index 00000000..1ed09223 --- /dev/null +++ b/testgen/common/models/table_group.py @@ -0,0 +1,279 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime +from uuid import UUID, uuid4 + +import streamlit as st +from sqlalchemy import BigInteger, Boolean, Column, Float, ForeignKey, Integer, String, asc, text, update +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import InstrumentedAttribute + +from testgen.common.models import get_current_session +from testgen.common.models.custom_types import NullIfEmptyString, YNString +from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.scores import ScoreDefinition +from testgen.common.models.test_suite import TestSuite + + +@dataclass +class TableGroupMinimal(EntityMinimal): + id: UUID + project_code: str + connection_id: int + table_groups_name: str + table_group_schema: str + profiling_table_set: str + profiling_include_mask: str + profiling_exclude_mask: str + profile_use_sampling: bool + profiling_delay_days: str + + +@dataclass +class TableGroupSummary(EntityMinimal): + id: UUID + table_groups_name: str + dq_score_profiling: float + dq_score_testing: float + latest_profile_id: UUID + latest_profile_start: datetime + latest_profile_table_ct: int + latest_profile_column_ct: int + latest_anomalies_ct: int + latest_anomalies_definite_ct: int + latest_anomalies_likely_ct: int + latest_anomalies_possible_ct: int + latest_anomalies_dismissed_ct: int + + +class TableGroup(Entity): + __tablename__ = "table_groups" + + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) + project_code: str = Column(String, ForeignKey("projects.project_code")) + connection_id: int = Column(BigInteger, ForeignKey("connections.connection_id")) + table_groups_name: str = Column(String) + table_group_schema: str = Column(String) + profiling_table_set: str = Column(NullIfEmptyString) + profiling_include_mask: str = Column(NullIfEmptyString) + profiling_exclude_mask: str = Column(NullIfEmptyString) + profile_id_column_mask: str = Column(String, default="%id") + profile_sk_column_mask: str = Column(String, default="%_sk") + profile_use_sampling: bool = Column(YNString, default="N") + profile_sample_percent: str = Column(String, default="30") + profile_sample_min_count: int = Column(BigInteger, default=100000) + profiling_delay_days: str = Column(String, default="0") + profile_flag_cdes: bool = Column(Boolean, default=True) + profile_do_pair_rules: bool = Column(YNString, default="N") + profile_pair_rule_pct: int = Column(Integer, default=95) + include_in_dashboard: bool = Column(Boolean, default=True) + description: str = Column(NullIfEmptyString) + data_source: str = Column(NullIfEmptyString) + source_system: str = Column(NullIfEmptyString) + source_process: str = Column(NullIfEmptyString) + data_location: str = Column(NullIfEmptyString) + business_domain: str = Column(NullIfEmptyString) + stakeholder_group: str = Column(NullIfEmptyString) + transform_level: str = 
Column(NullIfEmptyString) + data_product: str = Column(NullIfEmptyString) + last_complete_profile_run_id: UUID = Column(postgresql.UUID(as_uuid=True)) + dq_score_profiling: float = Column(Float) + dq_score_testing: float = Column(Float) + + _default_order_by = (asc(table_groups_name),) + _minimal_columns = TableGroupMinimal.__annotations__.keys() + _update_exclude_columns = ( + id, + project_code, + connection_id, + profile_do_pair_rules, + profile_pair_rule_pct, + last_complete_profile_run_id, + dq_score_profiling, + dq_score_testing, + ) + + @classmethod + @st.cache_data(show_spinner=False) + def get_minimal(cls, id_: str | UUID) -> TableGroupMinimal | None: + result = cls._get_columns(id_, cls._minimal_columns) + return TableGroupMinimal(**result) if result else None + + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def select_minimal_where( + cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by + ) -> Iterable[TableGroupMinimal]: + results = cls._select_columns_where(cls._minimal_columns, *clauses, order_by=order_by) + return [TableGroupMinimal(**row) for row in results] + + @classmethod + @st.cache_data(show_spinner=False) + def select_summary(cls, project_code: str, for_dashboard: bool = False) -> Iterable[TableGroupSummary]: + query = f""" + WITH latest_profile AS ( + SELECT latest_run.table_groups_id, + latest_run.id, + latest_run.profiling_starttime, + latest_run.table_ct, + latest_run.column_ct, + latest_run.anomaly_ct, + SUM( + CASE + WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') = 'Confirmed' + AND anomaly_types.issue_likelihood = 'Definite' THEN 1 + ELSE 0 + END + ) AS definite_ct, + SUM( + CASE + WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') = 'Confirmed' + AND anomaly_types.issue_likelihood = 'Likely' THEN 1 + ELSE 0 + END + ) AS likely_ct, + SUM( + CASE + WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') = 'Confirmed' + AND anomaly_types.issue_likelihood = 'Possible' THEN 1 + ELSE 0 + END + ) AS possible_ct, + SUM( + CASE + WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') + AND anomaly_types.issue_likelihood <> 'Potential PII' THEN 1 + ELSE 0 + END + ) AS dismissed_ct + FROM table_groups groups + LEFT JOIN profiling_runs latest_run ON ( + groups.last_complete_profile_run_id = latest_run.id + ) + LEFT JOIN profile_anomaly_results latest_anomalies ON ( + latest_run.id = latest_anomalies.profile_run_id + ) + LEFT JOIN profile_anomaly_types anomaly_types ON ( + anomaly_types.id = latest_anomalies.anomaly_id + ) + GROUP BY latest_run.id + ) + SELECT groups.id, + groups.table_groups_name, + groups.dq_score_profiling, + groups.dq_score_testing, + latest_profile.id AS latest_profile_id, + latest_profile.profiling_starttime AS latest_profile_start, + latest_profile.table_ct AS latest_profile_table_ct, + latest_profile.column_ct AS latest_profile_column_ct, + latest_profile.anomaly_ct AS latest_anomalies_ct, + latest_profile.definite_ct AS latest_anomalies_definite_ct, + latest_profile.likely_ct AS latest_anomalies_likely_ct, + latest_profile.possible_ct AS latest_anomalies_possible_ct, + latest_profile.dismissed_ct AS latest_anomalies_dismissed_ct + FROM table_groups AS groups + LEFT JOIN latest_profile ON (groups.id = latest_profile.table_groups_id) + WHERE groups.project_code = :project_code + {"AND groups.include_in_dashboard IS TRUE" if for_dashboard else ""}; + """ + params = {"project_code": project_code} + db_session = get_current_session() + 
results = db_session.execute(text(query), params).mappings().all() + return [TableGroupSummary(**row) for row in results] + + @classmethod + def has_running_process(cls, ids: list[str]) -> bool | None: + query = """ + SELECT DISTINCT profiling_runs.id + FROM profiling_runs + INNER JOIN table_groups + ON table_groups.id = profiling_runs.table_groups_id + WHERE table_groups.id IN :table_group_ids + AND profiling_runs.status = 'Running'; + """ + params = {"table_group_ids": tuple(ids)} + process_count = get_current_session().execute(text(query), params).rowcount + if process_count: + return True + + test_suites = TestSuite.select_minimal_where(TestSuite.table_groups_id.in_(ids)) + if test_suites: + return TestSuite.has_running_process([item.id for item in test_suites]) + + return False + + @classmethod + def is_in_use(cls, ids: list[str]) -> bool: + test_suites = TestSuite.select_minimal_where(TestSuite.table_groups_id.in_(ids)) + if test_suites: + return True + + query = "SELECT id FROM profiling_runs WHERE table_groups_id IN :table_group_ids;" + params = {"table_group_ids": tuple(ids)} + dependency_count = get_current_session().execute(text(query), params).rowcount + return dependency_count > 0 + + @classmethod + def cascade_delete(cls, ids: list[str]) -> None: + test_suites = TestSuite.select_minimal_where(TestSuite.table_groups_id.in_(ids)) + if test_suites: + TestSuite.cascade_delete([item.id for item in test_suites]) + + query = """ + DELETE FROM profile_pair_rules ppr + USING profiling_runs pr, table_groups tg + WHERE pr.id = ppr.profile_run_id AND tg.id = pr.table_groups_id AND tg.id IN :table_group_ids; + + DELETE FROM profile_anomaly_results par + USING table_groups tg + WHERE tg.id = par.table_groups_id AND tg.id IN :table_group_ids; + + DELETE FROM profile_results pr + USING table_groups tg + WHERE tg.id = pr.table_groups_id AND tg.id IN :table_group_ids; + + DELETE FROM profiling_runs pr + USING table_groups tg + WHERE tg.id = pr.table_groups_id AND tg.id IN :table_group_ids; + + DELETE FROM data_table_chars dtc + USING table_groups tg + WHERE tg.id = dtc.table_groups_id AND tg.id IN :table_group_ids; + + DELETE FROM data_column_chars dcs + USING table_groups tg + WHERE tg.id = dcs.table_groups_id AND tg.id IN :table_group_ids; + + DELETE FROM job_schedules + WHERE (kwargs->>'table_group_id')::UUID IN :table_group_ids; + """ + params = {"table_group_ids": tuple(ids)} + db_session = get_current_session() + db_session.execute(text(query), params) + db_session.commit() + cls.delete_where(cls.id.in_(ids)) + + @classmethod + def clear_cache(cls) -> bool: + super().clear_cache() + cls.get_minimal.clear() + cls.select_minimal_where.clear() + cls.select_summary.clear() + + def save(self, add_scorecard_definition: bool = False) -> None: + if self.id: + values = { + column.key: getattr(self, column.key, None) + for column in self.__table__.columns + if column not in self._update_exclude_columns + } + query = update(TableGroup).where(TableGroup.id == self.id).values(**values) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + else: + super().save() + if add_scorecard_definition: + ScoreDefinition.from_table_group(self).save() + + TableGroup.clear_cache() diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py new file mode 100644 index 00000000..936e6b65 --- /dev/null +++ b/testgen/common/models/test_definition.py @@ -0,0 +1,381 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from 
datetime import datetime +from typing import Literal +from uuid import UUID + +import streamlit as st +from sqlalchemy import ( + BigInteger, + Column, + ForeignKey, + Identity, + String, + Text, + TypeDecorator, + asc, + insert, + select, + text, + update, +) +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.expression import case, literal + +from testgen.common.models import get_current_session +from testgen.common.models.custom_types import NullIfEmptyString, UpdateTimestamp, YNString, ZeroIfEmptyInteger +from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.utils import is_uuid4 + +TestRunStatus = Literal["Running", "Complete", "Error", "Cancelled"] + + +@dataclass +class TestDefinitionSummary(EntityMinimal): + id: UUID + table_groups_id: UUID + profile_run_id: UUID + test_type: str + test_suite_id: UUID + test_description: str + schema_name: str + table_name: str + column_name: str + skip_errors: int + baseline_ct: str + baseline_unique_ct: str + baseline_value: str + baseline_value_ct: str + threshold_value: str + baseline_sum: str + baseline_avg: str + baseline_sd: str + lower_tolerance: str + upper_tolerance: str + subset_condition: str + groupby_names: str + having_condition: str + window_date_column: str + window_days: int + match_schema_name: str + match_table_name: str + match_column_names: str + match_subset_condition: str + match_groupby_names: str + match_having_condition: str + custom_query: str + test_active: str + test_definition_status: str + severity: str + lock_refresh: str + last_auto_gen_date: datetime + profiling_as_of_date: datetime + last_manual_update: datetime + export_to_observability: str + test_name_short: str + default_test_description: str + measure_uom: str + measure_uom_description: str + default_parm_columns: str + default_parm_prompts: str + default_parm_help: str + default_severity: str + test_scope: str + usage_notes: str + + +@dataclass +class TestDefinitionMinimal(EntityMinimal): + id: UUID + table_groups_id: UUID + test_type: str + test_suite_id: UUID + schema_name: str + table_name: str + column_name: str + test_active: bool + lock_refresh: bool + test_name_short: str + + +class QueryString(TypeDecorator): + impl = String + cache_ok = True + + def process_bind_param(self, value: str | None, _dialect) -> str | None: + if value and isinstance(value, str): + value = value.strip() + if value.endswith(";"): + value = value[:-1] + return value or None + + +class TestType(Entity): + __tablename__ = "test_types" + + id: str = Column(String) + test_type: str = Column(String, primary_key=True, nullable=False) + test_name_short: str = Column(String) + test_name_long: str = Column(String) + test_description: str = Column(String) + except_message: str = Column(String) + measure_uom: str = Column(String) + measure_uom_description: str = Column(String) + selection_criteria: str = Column(Text) + dq_score_prevalence_formula: str = Column(Text) + dq_score_risk_factor: str = Column(Text) + column_name_prompt: str = Column(Text) + column_name_help: str = Column(Text) + default_parm_columns: str = Column(Text) + default_parm_values: str = Column(Text) + default_parm_prompts: str = Column(Text) + default_parm_help: str = Column(Text) + default_severity: str = Column(String) + run_type: str = Column(String) + test_scope: str = Column(String) + dq_dimension: str = Column(String) + health_dimension: str = Column(String) + threshold_description: str = Column(String) + 
usage_notes: str = Column(String) + active: str = Column(String) + + +class TestDefinition(Entity): + __tablename__ = "test_definitions" + + id: UUID = Column(postgresql.UUID(as_uuid=True)) + cat_test_id: int = Column(BigInteger, Identity(), primary_key=True) + table_groups_id: UUID = Column(postgresql.UUID(as_uuid=True)) + profile_run_id: UUID = Column(postgresql.UUID(as_uuid=True)) + test_type: str = Column(String) + test_suite_id: UUID = Column(postgresql.UUID(as_uuid=True), ForeignKey("test_suites.id"), nullable=False) + test_description: str = Column(NullIfEmptyString) + test_action: str = Column(String) + schema_name: str = Column(String) + table_name: str = Column(NullIfEmptyString) + column_name: str = Column(NullIfEmptyString) + skip_errors: int = Column(ZeroIfEmptyInteger) + baseline_ct: str = Column(NullIfEmptyString) + baseline_unique_ct: str = Column(NullIfEmptyString) + baseline_value: str = Column(NullIfEmptyString) + baseline_value_ct: str = Column(NullIfEmptyString) + threshold_value: str = Column(NullIfEmptyString) + baseline_sum: str = Column(NullIfEmptyString) + baseline_avg: str = Column(NullIfEmptyString) + baseline_sd: str = Column(NullIfEmptyString) + lower_tolerance: str = Column(NullIfEmptyString) + upper_tolerance: str = Column(NullIfEmptyString) + subset_condition: str = Column(NullIfEmptyString) + groupby_names: str = Column(NullIfEmptyString) + having_condition: str = Column(NullIfEmptyString) + window_date_column: str = Column(NullIfEmptyString) + window_days: int = Column(ZeroIfEmptyInteger) + match_schema_name: str = Column(NullIfEmptyString) + match_table_name: str = Column(NullIfEmptyString) + match_column_names: str = Column(NullIfEmptyString) + match_subset_condition: str = Column(NullIfEmptyString) + match_groupby_names: str = Column(NullIfEmptyString) + match_having_condition: str = Column(NullIfEmptyString) + test_mode: str = Column(String) + custom_query: str = Column(QueryString) + test_active: bool = Column(YNString, default="Y") + test_definition_status: str = Column(NullIfEmptyString) + severity: str = Column(NullIfEmptyString) + watch_level: str = Column(String, default="WARN") + check_result: str = Column(String) + lock_refresh: bool = Column(YNString, default="N", nullable=False) + last_auto_gen_date: datetime = Column(postgresql.TIMESTAMP) + profiling_as_of_date: datetime = Column(postgresql.TIMESTAMP) + last_manual_update: datetime = Column(UpdateTimestamp, nullable=False) + export_to_observability: bool = Column(YNString) + + _default_order_by = (asc(schema_name), asc(table_name), asc(column_name), asc(test_type)) + _summary_columns = ( + *[key for key in TestDefinitionSummary.__annotations__.keys() if key != "default_test_description"], + TestType.test_description.label("default_test_description"), + ) + _minimal_columns = TestDefinitionMinimal.__annotations__.keys() + _update_exclude_columns = ( + id, + cat_test_id, + table_groups_id, + profile_run_id, + test_type, + test_suite_id, + test_action, + schema_name, + test_mode, + watch_level, + check_result, + last_auto_gen_date, + profiling_as_of_date, + ) + + @classmethod + @st.cache_data(show_spinner=False) + def get(cls, identifier: str | UUID) -> TestDefinitionSummary | None: + if not is_uuid4(identifier): + return None + + result = cls._get_columns( + identifier, + cls._summary_columns, + join_target=TestType, + join_clause=cls.test_type == TestType.test_type, + ) + return TestDefinitionSummary(**result) if result else None + + @classmethod + @st.cache_data(show_spinner=False, 
hash_funcs=ENTITY_HASH_FUNCS) + def select_where( + cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by + ) -> Iterable[TestDefinitionSummary]: + results = cls._select_columns_where( + cls._summary_columns, + *clauses, + join_target=TestType, + join_clause=cls.test_type == TestType.test_type, + order_by=order_by, + ) + return [TestDefinitionSummary(**row) for row in results] + + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def select_minimal_where( + cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by + ) -> Iterable[TestDefinitionMinimal]: + results = cls._select_columns_where( + cls._minimal_columns, + *clauses, + join_target=TestType, + join_clause=cls.test_type == TestType.test_type, + order_by=order_by, + ) + return [TestDefinitionMinimal(**row) for row in results] + + @classmethod + def set_status_attribute( + cls, + status_type: Literal["test_active", "lock_refresh"], + test_definition_ids: list[str | UUID], + value: bool, + ) -> None: + query = f""" + WITH selected AS ( + SELECT UNNEST(ARRAY [:test_definition_ids]) AS id + ) + UPDATE test_definitions + SET {status_type} = :value + FROM test_definitions td + INNER JOIN selected ON (td.id = selected.id::UUID) + WHERE td.id = test_definitions.id; + """ + params = { + "test_definition_ids": test_definition_ids, + "value": YNString().process_bind_param(value, None), + } + + db_session = get_current_session() + db_session.execute(text(query), params) + db_session.commit() + cls.clear_cache() + + @classmethod + def move( + cls, + test_definition_ids: list[str | UUID], + target_table_group_id: str | UUID, + target_test_suite_id: str | UUID, + target_table_name: str | None = None, + target_column_name: str | None = None, + ) -> None: + query = f""" + WITH selected AS ( + SELECT UNNEST(ARRAY [:test_definition_ids]) AS id + ) + UPDATE test_definitions + SET + {"table_name = :target_table_name," if target_table_name else ""} + {"column_name = :target_column_name," if target_column_name else ""} + table_groups_id = :target_table_group, + test_suite_id = :target_test_suite + FROM test_definitions td + INNER JOIN selected ON (td.id = selected.id::UUID) + WHERE td.id = test_definitions.id; + """ + params = { + "test_definition_ids": test_definition_ids, + "target_table_group": target_table_group_id, + "target_test_suite": target_test_suite_id, + "target_table_name": target_table_name, + "target_column_name": target_column_name, + } + + db_session = get_current_session() + db_session.execute(text(query), params) + db_session.commit() + cls.clear_cache() + + @classmethod + def copy( + cls, + test_definition_ids: list[str | UUID], + target_table_group_id: str | UUID, + target_test_suite_id: str | UUID, + target_table_name: str | None = None, + target_column_name: str | None = None, + ) -> None: + id_columns = (cls.id, cls.cat_test_id) + modified_columns = [cls.table_groups_id, cls.profile_run_id, cls.test_suite_id] + + select_columns = [ + literal(target_table_group_id).label("table_groups_id"), + case( + (cls.table_groups_id == target_table_group_id, cls.profile_run_id), + else_=None, + ).label("profile_run_id"), + literal(target_test_suite_id).label("test_suite_id"), + ] + + if target_table_name: + modified_columns.append(cls.table_name) + select_columns.append(literal(target_table_name).label("table_name")) + + if target_column_name: + modified_columns.append(cls.column_name) + select_columns.append(literal(target_column_name).label("column_name")) + + 
other_columns = [ + column for column in cls.__table__.columns if column not in modified_columns and column not in id_columns + ] + select_columns.extend(other_columns) + + query = insert(cls).from_select( + [*modified_columns, *other_columns], select(select_columns).where(cls.id.in_(test_definition_ids)) + ) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + cls.clear_cache() + + @classmethod + def clear_cache(cls) -> bool: + super().clear_cache() + cls.select_minimal_where.clear() + + def save(self) -> None: + if self.id: + values = { + column.key: getattr(self, column.key, None) + for column in self.__table__.columns + if column not in self._update_exclude_columns + } + query = update(TestDefinition).where(TestDefinition.id == self.id).values(**values) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + else: + super().save() + + TestDefinition.clear_cache() diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py new file mode 100644 index 00000000..47aa1584 --- /dev/null +++ b/testgen/common/models/test_run.py @@ -0,0 +1,257 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Literal, NamedTuple +from uuid import UUID + +import streamlit as st +from sqlalchemy import BigInteger, Column, Float, ForeignKey, Integer, String, Text, desc, func, select, text, update +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql.expression import case + +from testgen.common.models import get_current_session +from testgen.common.models.entity import Entity, EntityMinimal +from testgen.common.models.test_suite import TestSuite +from testgen.utils import is_uuid4 + +TestRunStatus = Literal["Running", "Complete", "Error", "Cancelled"] + + +@dataclass +class TestRunMinimal(EntityMinimal): + id: UUID + project_code: str + table_groups_id: UUID + test_suite_id: UUID + test_suite: str + test_starttime: datetime + dq_score_test_run: float + is_latest_run: bool + + +@dataclass +class TestRunSummary(EntityMinimal): + test_run_id: UUID + test_starttime: datetime + table_groups_name: str + test_suite: str + status: TestRunStatus + duration: str + process_id: int + log_message: str + test_ct: int + passed_ct: int + warning_ct: int + failed_ct: int + error_ct: int + dismissed_ct: int + dq_score_testing: float + + +class LatestTestRun(NamedTuple): + id: str + run_time: datetime + + +class TestRun(Entity): + __tablename__ = "test_runs" + + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, nullable=False) + test_suite_id: UUID = Column(postgresql.UUID(as_uuid=True), ForeignKey("test_suites.id"), nullable=False) + test_starttime: datetime = Column(postgresql.TIMESTAMP) + test_endtime: datetime = Column(postgresql.TIMESTAMP) + status: TestRunStatus = Column(String, default="Running") + log_message: str = Column(Text) + duration: str = Column(String) + test_ct: int = Column(Integer) + passed_ct: int = Column(Integer) + failed_ct: int = Column(Integer) + warning_ct: int = Column(Integer) + error_ct: int = Column(Integer) + table_ct: int = Column(Integer) + column_ct: int = Column(Integer) + column_failed_ct: int = Column(Integer) + column_warning_ct: int = Column(Integer) + dq_affected_data_points: int = Column(BigInteger) + dq_total_data_points: int = Column(BigInteger) + dq_score_test_run: float = Column(Float) + process_id: int = Column(Integer) + + _default_order_by = (desc(test_starttime),) + _minimal_columns = 
( + id, + TestSuite.project_code, + TestSuite.table_groups_id, + TestSuite.id.label("test_suite_id"), + TestSuite.test_suite, + test_starttime, + dq_score_test_run, + case( + (id == TestSuite.last_complete_test_run_id, True), + else_=False, + ).label("is_latest_run"), + ) + + @classmethod + @st.cache_data(show_spinner=False) + def get_minimal(cls, run_id: str | UUID) -> TestRunMinimal | None: + if not is_uuid4(run_id): + return None + + query = select(cls._minimal_columns).join(TestSuite).where(cls.id == run_id) + result = get_current_session().execute(query).first() + return TestRunMinimal(**result) if result else None + + @classmethod + def get_latest_run(cls, project_code: str) -> LatestTestRun | None: + query = ( + select(TestRun.id, TestRun.test_starttime) + .join(TestSuite) + .where(TestSuite.project_code == project_code, TestRun.status == "Complete") + .order_by(desc(TestRun.test_starttime)) + .limit(1) + ) + result = get_current_session().execute(query).first() + if result: + return LatestTestRun(str(result["id"]), result["test_starttime"]) + return None + + @classmethod + @st.cache_data(show_spinner=False) + def select_summary( + cls, + project_code: str, + table_group_id: str | None = None, + test_suite_id: str | None = None, + test_run_ids: list[str] | None = None, + ) -> Iterable[TestRunSummary]: + if ( + (table_group_id and not is_uuid4(table_group_id)) + or (test_suite_id and not is_uuid4(test_suite_id)) + or (test_run_ids and not all(is_uuid4(run_id) for run_id in test_run_ids)) + ): + return [] + + query = f""" + WITH run_results AS ( + SELECT test_run_id, + SUM( + CASE + WHEN COALESCE(disposition, 'Confirmed') = 'Confirmed' + AND result_status = 'Passed' THEN 1 + ELSE 0 + END + ) AS passed_ct, + SUM( + CASE + WHEN COALESCE(disposition, 'Confirmed') = 'Confirmed' + AND result_status = 'Warning' THEN 1 + ELSE 0 + END + ) AS warning_ct, + SUM( + CASE + WHEN COALESCE(disposition, 'Confirmed') = 'Confirmed' + AND result_status = 'Failed' THEN 1 + ELSE 0 + END + ) AS failed_ct, + SUM( + CASE + WHEN COALESCE(disposition, 'Confirmed') = 'Confirmed' + AND result_status = 'Error' THEN 1 + ELSE 0 + END + ) AS error_ct, + SUM( + CASE + WHEN COALESCE(disposition, 'Confirmed') IN ('Dismissed', 'Inactive') THEN 1 + ELSE 0 + END + ) AS dismissed_ct + FROM test_results + GROUP BY test_run_id + ) + SELECT test_runs.id AS test_run_id, + test_runs.test_starttime, + table_groups.table_groups_name, + test_suites.test_suite, + test_runs.status, + test_runs.duration, + test_runs.process_id, + test_runs.log_message, + test_runs.test_ct, + run_results.passed_ct, + run_results.warning_ct, + run_results.failed_ct, + run_results.error_ct, + run_results.dismissed_ct, + test_runs.dq_score_test_run AS dq_score_testing + FROM test_runs + LEFT JOIN run_results ON (test_runs.id = run_results.test_run_id) + INNER JOIN test_suites ON (test_runs.test_suite_id = test_suites.id) + INNER JOIN table_groups ON (test_suites.table_groups_id = table_groups.id) + INNER JOIN projects ON (test_suites.project_code = projects.project_code) + WHERE test_suites.project_code = :project_code + {"AND test_suites.table_groups_id = :table_group_id" if table_group_id else ""} + {" AND test_suites.id = :test_suite_id" if test_suite_id else ""} + {" AND test_runs.id IN :test_run_ids" if test_run_ids else ""} + ORDER BY test_runs.test_starttime DESC; + """ + params = { + "project_code": project_code, + "table_group_id": table_group_id, + "test_suite_id": test_suite_id, + "test_run_ids": tuple(test_run_ids or []), + } + db_session = 
get_current_session() + results = db_session.execute(text(query), params).mappings().all() + return [TestRunSummary(**row) for row in results] + + @classmethod + def has_running_process(cls, ids: list[str]) -> bool: + query = select(func.count(cls.id)).where(cls.id.in_(ids), cls.status == "Running") + process_count = get_current_session().execute(query).scalar() + return process_count > 0 + + @classmethod + def cancel_all_running(cls) -> None: + query = update(cls).where(cls.status == "Running").values(status="Cancelled", test_endtime=datetime.now(UTC)) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + cls.clear_cache() + + @classmethod + def update_status(cls, run_id: str | UUID, status: TestRunStatus) -> None: + query = update(cls).where(cls.id == run_id).values(status=status) + db_session = get_current_session() + db_session.execute(query) + db_session.commit() + cls.clear_cache() + + @classmethod + def cascade_delete(cls, ids: list[str]) -> None: + query = """ + DELETE FROM working_agg_cat_results + WHERE test_run_id IN :test_run_ids; + + DELETE FROM working_agg_cat_tests + WHERE test_run_id IN :test_run_ids; + + DELETE FROM test_results + WHERE test_run_id IN :test_run_ids; + """ + db_session = get_current_session() + db_session.execute(text(query), {"test_run_ids": tuple(ids)}) + db_session.commit() + cls.delete_where(cls.id.in_(ids)) + + @classmethod + def clear_cache(cls) -> bool: + super().clear_cache() + cls.get_minimal.clear() + cls.select_summary.clear() + + def save(self) -> None: + raise NotImplementedError diff --git a/testgen/common/models/test_suite.py b/testgen/common/models/test_suite.py new file mode 100644 index 00000000..02eccdec --- /dev/null +++ b/testgen/common/models/test_suite.py @@ -0,0 +1,246 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime +from uuid import UUID, uuid4 + +import streamlit as st +from sqlalchemy import BigInteger, Boolean, Column, ForeignKey, String, asc, text +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import InstrumentedAttribute + +from testgen.common.models import get_current_session +from testgen.common.models.custom_types import NullIfEmptyString, YNString +from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.utils import is_uuid4 + + +@dataclass +class TestSuiteMinimal(EntityMinimal): + id: UUID + project_code: str + test_suite: str + connection_id: int + table_groups_id: UUID + export_to_observability: str + + +@dataclass +class TestSuiteSummary(EntityMinimal): + id: UUID + project_code: str + test_suite: str + connection_name: str + table_groups_id: UUID + table_groups_name: str + test_suite_description: str + export_to_observability: bool + test_ct: int + last_complete_profile_run_id: UUID + latest_run_id: UUID + latest_run_start: datetime + last_run_test_ct: int + last_run_passed_ct: int + last_run_warning_ct: int + last_run_failed_ct: int + last_run_error_ct: int + last_run_dismissed_ct: int + + +class TestSuite(Entity): + __tablename__ = "test_suites" + + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) + project_code: str = Column(String) + test_suite: str = Column(String) + connection_id: int = Column(BigInteger, ForeignKey("connections.connection_id")) + table_groups_id: UUID = Column(postgresql.UUID(as_uuid=True)) + test_suite_description: str = Column(NullIfEmptyString) + test_action: str = Column(String) + severity: str = 
Column(NullIfEmptyString) + export_to_observability: bool = Column(YNString, default="Y") + test_suite_schema: str = Column(NullIfEmptyString) + component_key: str = Column(NullIfEmptyString) + component_type: str = Column(NullIfEmptyString) + component_name: str = Column(NullIfEmptyString) + last_complete_test_run_id: UUID = Column(postgresql.UUID(as_uuid=True)) + dq_score_exclude: bool = Column(Boolean, default=False) + + _default_order_by = (asc(test_suite),) + _minimal_columns = TestSuiteMinimal.__annotations__.keys() + + @classmethod + @st.cache_data(show_spinner=False) + def get_minimal(cls, identifier: int) -> TestSuiteMinimal | None: + result = cls._get_columns(identifier, cls._minimal_columns) + return TestSuiteMinimal(**result) if result else None + + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def select_minimal_where( + cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by + ) -> Iterable[TestSuiteMinimal]: + results = cls._select_columns_where(cls._minimal_columns, *clauses, order_by=order_by) + return [TestSuiteMinimal(**row) for row in results] + + @classmethod + @st.cache_data(show_spinner=False) + def select_summary(cls, project_code: str, table_group_id: str | UUID | None = None) -> Iterable[TestSuiteSummary]: + if table_group_id and not is_uuid4(table_group_id): + return [] + + query = f""" + WITH last_run AS ( + SELECT test_runs.test_suite_id, + test_runs.id, + test_runs.test_starttime, + test_runs.test_ct, + SUM( + CASE + WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' + AND test_results.result_status = 'Passed' THEN 1 + ELSE 0 + END + ) AS passed_ct, + SUM( + CASE + WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' + AND test_results.result_status = 'Warning' THEN 1 + ELSE 0 + END + ) AS warning_ct, + SUM( + CASE + WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' + AND test_results.result_status = 'Failed' THEN 1 + ELSE 0 + END + ) AS failed_ct, + SUM( + CASE + WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' + AND test_results.result_status = 'Error' THEN 1 + ELSE 0 + END + ) AS error_ct, + SUM( + CASE + WHEN COALESCE(test_results.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') THEN 1 + ELSE 0 + END + ) AS dismissed_ct + FROM test_suites + LEFT JOIN test_runs ON ( + test_suites.last_complete_test_run_id = test_runs.id + ) + LEFT JOIN test_results ON ( + test_runs.id = test_results.test_run_id + ) + GROUP BY test_runs.id + ), + test_defs AS ( + SELECT test_suite_id, + COUNT(*) AS count + FROM test_definitions + GROUP BY test_suite_id + ) + SELECT + suites.id, + suites.project_code, + suites.test_suite, + connections.connection_name, + suites.table_groups_id, + groups.table_groups_name, + suites.test_suite_description, + CASE WHEN suites.export_to_observability = 'Y' THEN TRUE ELSE FALSE END AS export_to_observability, + test_defs.count AS test_ct, + last_complete_profile_run_id, + last_run.id AS latest_run_id, + last_run.test_starttime AS latest_run_start, + last_run.test_ct AS last_run_test_ct, + last_run.passed_ct AS last_run_passed_ct, + last_run.warning_ct AS last_run_warning_ct, + last_run.failed_ct AS last_run_failed_ct, + last_run.error_ct AS last_run_error_ct, + last_run.dismissed_ct AS last_run_dismissed_ct + FROM test_suites AS suites + LEFT JOIN last_run + ON (suites.id = last_run.test_suite_id) + LEFT JOIN test_defs + ON (suites.id = test_defs.test_suite_id) + LEFT JOIN connections AS connections + ON 
(connections.connection_id = suites.connection_id) + LEFT JOIN table_groups AS groups + ON (groups.id = suites.table_groups_id) + WHERE suites.project_code = :project_code + {"AND suites.table_groups_id = :table_group_id" if table_group_id else ""} + ORDER BY suites.test_suite; + """ + params = {"project_code": project_code, "table_group_id": table_group_id} + db_session = get_current_session() + results = db_session.execute(text(query), params).mappings().all() + return [TestSuiteSummary(**row) for row in results] + + @classmethod + def has_running_process(cls, ids: list[str]) -> bool: + query = """ + SELECT DISTINCT test_suite_id + FROM test_runs + WHERE test_suite_id IN :test_suite_ids + AND status = 'Running'; + """ + params = {"test_suite_ids": tuple(ids)} + process_count = get_current_session().execute(text(query), params).rowcount + return process_count > 0 + + @classmethod + def is_in_use(cls, ids: list[str]) -> bool: + query = """ + SELECT DISTINCT test_suite_id FROM test_definitions WHERE test_suite_id IN :test_suite_ids + UNION + SELECT DISTINCT test_suite_id FROM test_results WHERE test_suite_id IN :test_suite_ids; + """ + params = {"test_suite_ids": tuple(ids)} + dependency_count = get_current_session().execute(text(query), params).rowcount + return dependency_count > 0 + + @classmethod + def cascade_delete(cls, ids: list[str]) -> None: + query = """ + DELETE FROM working_agg_cat_results + WHERE test_run_id IN ( + SELECT id FROM test_runs + WHERE test_suite_id IN :test_suite_ids + ); + + DELETE FROM working_agg_cat_tests + WHERE test_run_id IN ( + SELECT id FROM test_runs + WHERE test_suite_id IN :test_suite_ids + ); + + DELETE FROM test_runs + WHERE test_suite_id IN :test_suite_ids; + + DELETE FROM test_results + WHERE test_suite_id IN :test_suite_ids; + + DELETE FROM test_definitions + WHERE test_suite_id IN :test_suite_ids; + + DELETE FROM job_schedules js + USING test_suites ts + WHERE js.kwargs->>'project_key' = ts.project_code + AND js.kwargs->>'test_suite_key' = ts.test_suite + AND ts.id IN :test_suite_ids; + """ + db_session = get_current_session() + db_session.execute(text(query), {"test_suite_ids": tuple(ids)}) + db_session.commit() + cls.delete_where(cls.id.in_(ids)) + + @classmethod + def clear_cache(cls) -> bool: + super().clear_cache() + cls.get_minimal.clear() + cls.select_minimal_where.clear() + cls.select_summary.clear() diff --git a/testgen/common/models/user.py b/testgen/common/models/user.py new file mode 100644 index 00000000..bcba4599 --- /dev/null +++ b/testgen/common/models/user.py @@ -0,0 +1,23 @@ +from typing import Literal +from uuid import UUID, uuid4 + +from sqlalchemy import Column, String, asc +from sqlalchemy.dialects import postgresql + +from testgen.common.models.custom_types import NullIfEmptyString +from testgen.common.models.entity import Entity + +RoleType = Literal["admin", "data_quality", "analyst", "business", "catalog"] + + +class User(Entity): + __tablename__ = "auth_users" + + id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) + username: str = Column(String) + email: str = Column(NullIfEmptyString) + name: str = Column(NullIfEmptyString) + password: str = Column(String) + role: RoleType = Column(String) + + _default_order_by = (asc(username),) diff --git a/testgen/template/data_chars/data_chars_staging_delete.sql b/testgen/template/data_chars/data_chars_staging_delete.sql index d81bcf14..292d722c 100644 --- a/testgen/template/data_chars/data_chars_staging_delete.sql +++ 
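
Note: for orientation on the model classes introduced above (TestRun, TestSuite, User), the sketch below shows how a caller might consume them. It is illustrative only — the wrapper function and printing are invented, and it assumes an active SQLAlchemy session has already been opened by the caller; the select_summary signatures and the summary-dataclass fields are the ones defined in this patch.

    # Hypothetical caller -- not part of this patch.
    from testgen.common.models.test_run import TestRun
    from testgen.common.models.test_suite import TestSuite

    def print_project_overview(project_code: str, table_group_id: str | None = None) -> None:
        # One row per test suite, with counts taken from its last complete run.
        for suite in TestSuite.select_summary(project_code, table_group_id):
            print(suite.test_suite, suite.last_run_passed_ct, suite.last_run_failed_ct)

        # Runs for the project, newest first; results are cached via st.cache_data.
        for run in TestRun.select_summary(project_code, table_group_id=table_group_id):
            print(run.test_run_id, run.status, run.dq_score_testing)
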
b/testgen/template/data_chars/data_chars_staging_delete.sql @@ -1,4 +1,4 @@ DELETE FROM stg_data_chars_updates -WHERE project_code = '{PROJECT_CODE}' - AND table_groups_id = '{TABLE_GROUPS_ID}' - AND run_date = '{RUN_DATE}'; +WHERE project_code = :PROJECT_CODE + AND table_groups_id = :TABLE_GROUPS_ID + AND run_date = :RUN_DATE; diff --git a/testgen/template/data_chars/data_chars_update.sql b/testgen/template/data_chars/data_chars_update.sql index c35dc933..dcad1454 100644 --- a/testgen/template/data_chars/data_chars_update.sql +++ b/testgen/template/data_chars/data_chars_update.sql @@ -12,7 +12,7 @@ WITH new_chars AS ( MAX(record_ct) AS record_ct, COUNT(*) AS column_ct FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID GROUP BY table_groups_id, schema_name, table_name, @@ -43,7 +43,7 @@ WITH new_chars AS ( MAX(record_ct) AS record_ct, COUNT(*) AS column_ct FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID GROUP BY table_groups_id, schema_name, table_name, @@ -82,7 +82,7 @@ WITH new_chars AS ( schema_name, table_name FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID GROUP BY table_groups_id, schema_name, table_name @@ -91,7 +91,7 @@ last_run AS ( SELECT table_groups_id, MAX(run_date) as last_run_date FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID GROUP BY table_groups_id ) UPDATE data_table_chars @@ -123,7 +123,7 @@ WITH new_chars AS ( functional_data_type, run_date FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID ) UPDATE data_column_chars SET ordinal_position = n.position, @@ -154,7 +154,7 @@ WITH new_chars AS ( functional_data_type, run_date FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID ) INSERT INTO data_column_chars ( table_groups_id, @@ -201,13 +201,13 @@ WITH new_chars AS ( table_name, column_name FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID ), last_run AS ( SELECT table_groups_id, MAX(run_date) as last_run_date FROM {SOURCE_TABLE} - WHERE table_groups_id = '{TABLE_GROUPS_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID GROUP BY table_groups_id ) UPDATE data_column_chars diff --git a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql index 96549aaa..d8079a07 100644 --- a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql +++ b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql @@ -50,8 +50,6 @@ CREATE TABLE projects ( CONSTRAINT projects_project_code_pk PRIMARY KEY, project_name VARCHAR(50), - effective_from_date DATE, - effective_thru_date DATE, observability_api_key TEXT, observability_api_url TEXT DEFAULT '' ); @@ -84,7 +82,9 @@ CREATE TABLE connections ( CREATE TABLE table_groups ( - id UUID DEFAULT gen_random_uuid(), + id UUID DEFAULT gen_random_uuid() + CONSTRAINT pk_tg_id + PRIMARY KEY, project_code VARCHAR(30) CONSTRAINT table_groups_projects_project_code_fk REFERENCES projects, @@ -622,12 +622,14 @@ CREATE TABLE functional_test_results ); CREATE TABLE auth_users ( - id UUID DEFAULT gen_random_uuid(), - username VARCHAR(20), - email VARCHAR(120), + id UUID DEFAULT gen_random_uuid() + CONSTRAINT pk_au_id + PRIMARY KEY, + username VARCHAR(20), + email 
VARCHAR(120), name VARCHAR(120), - password VARCHAR(120), - role VARCHAR(20) + password VARCHAR(120), + role VARCHAR(20) ); ALTER TABLE auth_users diff --git a/testgen/template/dbsetup/040_populate_new_schema_project.sql b/testgen/template/dbsetup/040_populate_new_schema_project.sql index 3e8adfdf..ae4eddf8 100644 --- a/testgen/template/dbsetup/040_populate_new_schema_project.sql +++ b/testgen/template/dbsetup/040_populate_new_schema_project.sql @@ -1,10 +1,9 @@ SET SEARCH_PATH TO {SCHEMA_NAME}; INSERT INTO projects - (project_code, project_name, effective_from_date, observability_api_key, observability_api_url) + (project_code, project_name, observability_api_key, observability_api_url) SELECT '{PROJECT_CODE}' as project_code, '{PROJECT_NAME}' as project_name, - (CURRENT_TIMESTAMP AT TIME ZONE 'UTC')::DATE as effective_from_date, '{OBSERVABILITY_API_KEY}' as observability_api_key, '{OBSERVABILITY_API_URL}' as observability_api_url; diff --git a/testgen/template/dbsetup/050_populate_new_schema_metadata.sql b/testgen/template/dbsetup/050_populate_new_schema_metadata.sql index dc94bf04..d12b0212 100644 --- a/testgen/template/dbsetup/050_populate_new_schema_metadata.sql +++ b/testgen/template/dbsetup/050_populate_new_schema_metadata.sql @@ -317,7 +317,7 @@ VALUES ('1001', 'Alpha_Trunc', 'redshift', 'MAX(LENGTH({COLUMN_NAME}))', '<', ' ('3022', 'Recency', 'mssql', 'DATEDIFF(day, MAX({COLUMN_NAME}), CAST(''{RUN_DATE}''AS DATE))', '>', '{THRESHOLD_VALUE}'), ('3023', 'Required', 'mssql', 'COUNT(*) - COUNT( {COLUMN_NAME} )', '>', '{THRESHOLD_VALUE}'), ('3024', 'Row_Ct', 'mssql', 'COUNT(*)', '<', '{THRESHOLD_VALUE}'), - ('3025', 'Row_Ct_Pct', 'mssql', 'ABS(ROUND(100.0 * CAST((COUNT(*) - {BASELINE_CT} ) AS FLOAT)/ CAST({BASELINE_CT} AS FLOAT, 2)))', '>', '{THRESHOLD_VALUE}'), + ('3025', 'Row_Ct_Pct', 'mssql', 'ABS(ROUND(100.0 * CAST((COUNT(*) - {BASELINE_CT} ) AS FLOAT)/ CAST({BASELINE_CT} AS FLOAT), 2))', '>', '{THRESHOLD_VALUE}'), ('3026', 'Street_Addr_Pattern', 'mssql', 'CAST(100.0*SUM(CASE WHEN UPPER({COLUMN_NAME}) LIKE ''[1-9]% [A-Z]% %'' AND CHARINDEX('' '', {COLUMN_NAME}) BETWEEN 2 AND 6 THEN 1 ELSE 0 END) as FLOAT) /CAST(COUNT({COLUMN_NAME}) AS FLOAT)', '<', '{THRESHOLD_VALUE}'), ('3027', 'US_State', 'mssql', 'SUM(CASE WHEN NULLIF({COLUMN_NAME}, '''') NOT IN (''AL'',''AK'',''AS'',''AZ'',''AR'',''CA'',''CO'',''CT'',''DE'',''DC'',''FM'',''FL'',''GA'',''GU'',''HI'',''ID'',''IL'',''IN'',''IA'',''KS'',''KY'',''LA'',''ME'',''MH'',''MD'',''MA'',''MI'',''MN'',''MS'',''MO'',''MT'',''NE'',''NV'',''NH'',''NJ'',''NM'',''NY'',''NC'',''ND'',''MP'',''OH'',''OK'',''OR'',''PW'',''PA'',''PR'',''RI'',''SC'',''SD'',''TN'',''TX'',''UT'',''VT'',''VI'',''VA'',''WA'',''WV'',''WI'',''WY'',''AE'',''AP'',''AA'') THEN 1 ELSE 0 END)', '>', '{THRESHOLD_VALUE}'), ('3028', 'Unique', 'mssql', 'COUNT(*) - COUNT(DISTINCT {COLUMN_NAME})', '>', '{THRESHOLD_VALUE}'), diff --git a/testgen/template/dbupgrade/0145_incremental_upgrade.sql b/testgen/template/dbupgrade/0145_incremental_upgrade.sql new file mode 100644 index 00000000..20934bf3 --- /dev/null +++ b/testgen/template/dbupgrade/0145_incremental_upgrade.sql @@ -0,0 +1,17 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +ALTER TABLE auth_users + ADD CONSTRAINT pk_au_id + PRIMARY KEY (id); + +ALTER TABLE projects + DROP COLUMN effective_from_date, + DROP COLUMN effective_thru_date; + +ALTER TABLE table_groups + ADD CONSTRAINT pk_tg_id + PRIMARY KEY (id); + +UPDATE connections + SET sql_flavor_code = sql_flavor + WHERE sql_flavor_code IS NULL; diff --git 
a/testgen/template/exec_cat_tests/ex_cat_get_distinct_tables.sql b/testgen/template/exec_cat_tests/ex_cat_get_distinct_tables.sql index e8e85d0b..b5c79617 100644 --- a/testgen/template/exec_cat_tests/ex_cat_get_distinct_tables.sql +++ b/testgen/template/exec_cat_tests/ex_cat_get_distinct_tables.sql @@ -7,5 +7,5 @@ SELECT DISTINCT schema_name, ON (td.table_groups_id = tg.id) INNER JOIN connections c ON (tg.connection_id = c.connection_id) - WHERE td.test_suite_id = '{TEST_SUITE_ID}' + WHERE td.test_suite_id = :TEST_SUITE_ID AND tt.run_type = 'CAT'; diff --git a/testgen/template/exec_cat_tests/ex_cat_results_parse.sql b/testgen/template/exec_cat_tests/ex_cat_results_parse.sql index f499fff4..74f5dce5 100644 --- a/testgen/template/exec_cat_tests/ex_cat_results_parse.sql +++ b/testgen/template/exec_cat_tests/ex_cat_results_parse.sql @@ -22,7 +22,7 @@ WITH seq_digit AS ( AND t.schema_name = r.schema_name AND t.table_name = r.table_name AND t.cat_sequence = r.cat_sequence) - WHERE t.test_run_id = '{TEST_RUN_ID}' + WHERE t.test_run_id = :TEST_RUN_ID AND t.column_names > '' ), parsed_results AS ( @@ -51,15 +51,16 @@ INSERT INTO test_results test_time, starttime, endtime, schema_name, table_name, column_names, skip_errors, input_parameters, result_code, result_measure, test_action, subset_condition, result_query, test_description) -SELECT '{TEST_RUN_ID}' as test_run_id, - r.test_type, r.test_definition_id::UUID, '{TEST_SUITE_ID}'::UUID, r.test_time, r.start_time, r.end_time, +SELECT :TEST_RUN_ID as test_run_id, + r.test_type, r.test_definition_id::UUID, :TEST_SUITE_ID, r.test_time, r.start_time, r.end_time, r.schema_name, r.table_name, r.column_name, 0 as skip_errors, r.test_parms as input_parameters, r.test_result::INT as result_code, r.measure_result as result_measure, r.test_action, NULL as subset_condition, - 'SELECT ' || LEFT(REPLACE(r.condition, '{RUN_' || 'DATE}', '{RUN_DATE}'), LENGTH(REPLACE(r.condition, '{RUN_' || 'DATE}', '{RUN_DATE}')) - LENGTH(' THEN ''0,'' ELSE ''1,'' END')) || ' THEN 0 ELSE 1 END' + 'SELECT ' || LEFT(REPLACE(r.condition, '{RUN_' || 'DATE}', :RUN_DATE), LENGTH(REPLACE(r.condition, '{RUN_' || 'DATE}', :RUN_DATE + )) - LENGTH(' THEN ''0,'' ELSE ''1,'' END')) || ' THEN 0 ELSE 1 END' || ' FROM ' || r.schema_name || '.' 
|| r.table_name as result_query, COALESCE(r.test_description, c.test_description) as test_description FROM parsed_results r diff --git a/testgen/template/exec_cat_tests/ex_cat_retrieve_agg_test_parms.sql b/testgen/template/exec_cat_tests/ex_cat_retrieve_agg_test_parms.sql index 615f5c95..7632fdb5 100644 --- a/testgen/template/exec_cat_tests/ex_cat_retrieve_agg_test_parms.sql +++ b/testgen/template/exec_cat_tests/ex_cat_retrieve_agg_test_parms.sql @@ -2,7 +2,7 @@ SELECT schema_name, table_name, cat_sequence, -- Replace list delimiters with concat operator - REPLACE(test_measures, '++', '{CONCAT_OPERATOR}') as test_measures, - REPLACE(test_conditions, '++', '{CONCAT_OPERATOR}') as test_conditions + REPLACE(test_measures, '++', :CONCAT_OPERATOR) as test_measures, + REPLACE(test_conditions, '++', :CONCAT_OPERATOR) as test_conditions FROM working_agg_cat_tests - WHERE test_run_id = '{TEST_RUN_ID}'; + WHERE test_run_id = :TEST_RUN_ID; diff --git a/testgen/template/execution/ex_finalize_test_run_results.sql b/testgen/template/execution/ex_finalize_test_run_results.sql index e05147a2..b2070b94 100644 --- a/testgen/template/execution/ex_finalize_test_run_results.sql +++ b/testgen/template/execution/ex_finalize_test_run_results.sql @@ -30,5 +30,5 @@ UPDATE test_results INNER JOIN test_suites s ON r.test_suite_id = s.id INNER JOIN test_definitions d ON r.test_definition_id = d.id INNER JOIN test_types tt ON r.test_type = tt.test_type -WHERE r.test_run_id = '{TEST_RUN_ID}' +WHERE r.test_run_id = :TEST_RUN_ID AND test_results.id = r.id; diff --git a/testgen/template/execution/ex_get_tests_non_cat.sql b/testgen/template/execution/ex_get_tests_non_cat.sql index 536fd509..69672e1b 100644 --- a/testgen/template/execution/ex_get_tests_non_cat.sql +++ b/testgen/template/execution/ex_get_tests_non_cat.sql @@ -45,7 +45,7 @@ FROM test_definitions td ON (td.test_type = tt.test_type) LEFT JOIN test_templates tm ON (td.test_type = tm.test_type - AND '{SQL_FLAVOR}' = tm.sql_flavor) -WHERE td.test_suite_id = '{TEST_SUITE_ID}' + AND :SQL_FLAVOR = tm.sql_flavor) +WHERE td.test_suite_id = :TEST_SUITE_ID AND tt.run_type = 'QUERY' AND td.test_active = 'Y'; diff --git a/testgen/template/execution/ex_update_test_record_in_testrun_table.sql b/testgen/template/execution/ex_update_test_record_in_testrun_table.sql index a3eeba78..43ef1146 100644 --- a/testgen/template/execution/ex_update_test_record_in_testrun_table.sql +++ b/testgen/template/execution/ex_update_test_record_in_testrun_table.sql @@ -8,13 +8,13 @@ WITH stats FROM test_runs r INNER JOIN test_results tr ON r.id = tr.test_run_id - WHERE r.id = '{TEST_RUN_ID}'::UUID + WHERE r.id = :TEST_RUN_ID GROUP BY r.id ) UPDATE test_runs - SET status = CASE WHEN length('{EXCEPTION_MESSAGE}') = 0 then 'Complete' else 'Error' end, - test_endtime = '{NOW}', - log_message = '{EXCEPTION_MESSAGE}', - duration = TO_CHAR('{NOW}' - r.test_starttime, 'HH24:MI:SS'), + SET status = CASE WHEN length(:EXCEPTION_MESSAGE) = 0 then 'Complete' else 'Error' end, + test_endtime = :NOW_TIMESTAMP, + log_message = :EXCEPTION_MESSAGE, + duration = TO_CHAR(:NOW_TIMESTAMP - r.test_starttime, 'HH24:MI:SS'), test_ct = s.test_ct, passed_ct = s.passed_ct, failed_ct = s.failed_ct, @@ -23,5 +23,5 @@ UPDATE test_runs FROM test_runs r LEFT JOIN stats s ON r.id = s.test_run_id -WHERE r.id = '{TEST_RUN_ID}'::UUID +WHERE r.id = :TEST_RUN_ID AND r.id = test_runs.id; diff --git a/testgen/template/execution/ex_update_test_suite.sql b/testgen/template/execution/ex_update_test_suite.sql index 68283f14..72505590 100644 
--- a/testgen/template/execution/ex_update_test_suite.sql +++ b/testgen/template/execution/ex_update_test_suite.sql @@ -1,7 +1,7 @@ WITH last_run AS (SELECT test_suite_id, MAX(test_starttime) as max_starttime FROM test_runs - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND status = 'Complete' GROUP BY test_suite_id) UPDATE test_suites diff --git a/testgen/template/flavors/databricks/profiling/project_get_table_sample_count_databricks.sql b/testgen/template/flavors/databricks/profiling/project_get_table_sample_count_databricks.sql new file mode 100644 index 00000000..9a62c3d6 --- /dev/null +++ b/testgen/template/flavors/databricks/profiling/project_get_table_sample_count_databricks.sql @@ -0,0 +1,23 @@ +WITH stats + AS (SELECT COUNT(*)::FLOAT as record_ct, + ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0) as calc_sample_ct, + CAST({PROFILE_SAMPLE_MIN_COUNT} as FLOAT) as min_sample_ct, + CAST(999000 as FLOAT) as max_sample_ct + FROM {SAMPLING_TABLE} ) +SELECT '{SAMPLING_TABLE}' as schema_table, + CASE WHEN record_ct <= min_sample_ct THEN -1 + WHEN calc_sample_ct > max_sample_ct THEN max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN calc_sample_ct + ELSE {PROFILE_SAMPLE_MIN_COUNT} + END as sample_count, + CASE WHEN record_ct <= min_sample_ct THEN 1 + WHEN calc_sample_ct > max_sample_ct THEN record_ct / max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN record_ct / calc_sample_ct + ELSE record_ct / min_sample_ct + END as sample_ratio, + ROUND(CASE WHEN record_ct <= min_sample_ct THEN 100 + WHEN calc_sample_ct > max_sample_ct THEN 100.0 * max_sample_ct / record_ct + WHEN calc_sample_ct > min_sample_ct THEN 100.0 * calc_sample_ct / record_ct + ELSE 100.0 * min_sample_ct / record_ct + END, 4) as sample_percent_calc + FROM stats; diff --git a/testgen/template/flavors/databricks/profiling/project_profiling_query_databricks.yaml b/testgen/template/flavors/databricks/profiling/project_profiling_query_databricks.yaml index d42c6947..d7612bd3 100644 --- a/testgen/template/flavors/databricks/profiling/project_profiling_query_databricks.yaml +++ b/testgen/template/flavors/databricks/profiling/project_profiling_query_databricks.yaml @@ -262,7 +262,7 @@ strTemplate15_ALL: NULL as functional_data_type, strTemplate16_ALL: " '{PROFILE_RUN_ID}' as profile_run_id" -strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} LIMIT {SAMPLE_SIZE}' +strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE ({SAMPLE_PERCENT_CALC} PERCENT)' strTemplate98_else: ' FROM {DATA_SCHEMA}.{DATA_TABLE}' @@ -273,6 +273,13 @@ strTemplate99_N: | PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY `{COL_NAME}`) OVER () AS pct_75 FROM {DATA_SCHEMA}.{DATA_TABLE} LIMIT 1) pctile +strTemplate99_N_sampling: | + , (SELECT + PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY `{COL_NAME}`) OVER () AS pct_25, + PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY `{COL_NAME}`) OVER () AS pct_50, + PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY `{COL_NAME}`) OVER () AS pct_75 + FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE ({SAMPLE_PERCENT_CALC} PERCENT) LIMIT 1 ) pctile + strTemplate99_else: ' ' -strTemplate100_sampling: ' ORDER BY RAND()' +strTemplate100_sampling: ' ' diff --git a/testgen/template/flavors/mssql/profiling/project_get_table_sample_count_mssql.sql b/testgen/template/flavors/mssql/profiling/project_get_table_sample_count_mssql.sql new file mode 100644 index 00000000..b7ccafaf --- /dev/null +++ 
b/testgen/template/flavors/mssql/profiling/project_get_table_sample_count_mssql.sql @@ -0,0 +1,23 @@ +WITH stats + AS (SELECT CAST(COUNT(*) as FLOAT) as record_ct, + ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0, 0) as calc_sample_ct, + CAST({PROFILE_SAMPLE_MIN_COUNT} as FLOAT) as min_sample_ct, + CAST(999000 as FLOAT) as max_sample_ct + FROM {SAMPLING_TABLE} ) +SELECT '{SAMPLING_TABLE}' as schema_table, + CASE WHEN record_ct <= min_sample_ct THEN -1 + WHEN calc_sample_ct > max_sample_ct THEN max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN calc_sample_ct + ELSE {PROFILE_SAMPLE_MIN_COUNT} + END as sample_count, + CASE WHEN record_ct <= min_sample_ct THEN 1 + WHEN calc_sample_ct > max_sample_ct THEN record_ct / max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN record_ct / calc_sample_ct + ELSE record_ct / min_sample_ct + END as sample_ratio, + ROUND(CASE WHEN record_ct <= min_sample_ct THEN 100 + WHEN calc_sample_ct > max_sample_ct THEN 100.0 * max_sample_ct / record_ct + WHEN calc_sample_ct > min_sample_ct THEN 100.0 * calc_sample_ct / record_ct + ELSE 100.0 * min_sample_ct / record_ct + END, 4) as sample_percent_calc + FROM stats; diff --git a/testgen/template/flavors/mssql/profiling/project_profiling_query_mssql.yaml b/testgen/template/flavors/mssql/profiling/project_profiling_query_mssql.yaml index 40a7568e..adf38026 100644 --- a/testgen/template/flavors/mssql/profiling/project_profiling_query_mssql.yaml +++ b/testgen/template/flavors/mssql/profiling/project_profiling_query_mssql.yaml @@ -1,5 +1,5 @@ --- -strTemplate01_sampling: "SELECT TOP {SAMPLE_SIZE} " +strTemplate01_sampling: "SELECT " strTemplate01_else: "SELECT " strTemplate02_all: | {CONNECTION_ID} as connection_id, @@ -259,7 +259,7 @@ strTemplate15_ALL: NULL as functional_data_type, strTemplate16_ALL: " '{PROFILE_RUN_ID}' as profile_run_id" -strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} WITH (NOLOCK)' +strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE ({SAMPLE_PERCENT_CALC} PERCENT) WITH (NOLOCK)' strTemplate98_else: ' FROM {DATA_SCHEMA}.{DATA_TABLE} WITH (NOLOCK)' @@ -270,6 +270,13 @@ strTemplate99_N: | PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_75 FROM {DATA_SCHEMA}.{DATA_TABLE} WITH (NOLOCK)) pctile +strTemplate99_N_sampling: | + , (SELECT TOP 1 + PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_25, + PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_50, + PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_75 + FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE ({SAMPLE_PERCENT_CALC} PERCENT) WITH (NOLOCK)) pctile + strTemplate99_else: ' ' -strTemplate100_sampling: ' ORDER BY RAND()' +strTemplate100_sampling: ' ' diff --git a/testgen/template/flavors/postgresql/profiling/project_get_table_sample_count_postgresql.sql b/testgen/template/flavors/postgresql/profiling/project_get_table_sample_count_postgresql.sql new file mode 100644 index 00000000..6939bae9 --- /dev/null +++ b/testgen/template/flavors/postgresql/profiling/project_get_table_sample_count_postgresql.sql @@ -0,0 +1,23 @@ +WITH stats + AS (SELECT COUNT(*)::FLOAT as record_ct, + ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0) as calc_sample_ct, + CAST({PROFILE_SAMPLE_MIN_COUNT} as FLOAT) as min_sample_ct, + CAST(999000 as FLOAT) as max_sample_ct + FROM {SAMPLING_TABLE} ) +SELECT '{SAMPLING_TABLE}' as schema_table, + CASE WHEN record_ct <= min_sample_ct THEN -1 + 
WHEN calc_sample_ct > max_sample_ct THEN max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN calc_sample_ct + ELSE {PROFILE_SAMPLE_MIN_COUNT} + END as sample_count, + CASE WHEN record_ct <= min_sample_ct THEN 1 + WHEN calc_sample_ct > max_sample_ct THEN record_ct / max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN record_ct / calc_sample_ct + ELSE record_ct / min_sample_ct + END as sample_ratio, + ROUND(CASE WHEN record_ct <= min_sample_ct THEN 100 + WHEN calc_sample_ct > max_sample_ct THEN 100.0 * max_sample_ct / record_ct + WHEN calc_sample_ct > min_sample_ct THEN 100.0 * calc_sample_ct / record_ct + ELSE 100.0 * min_sample_ct / record_ct + END::NUMERIC, 4) as sample_percent_calc + FROM stats; diff --git a/testgen/template/flavors/postgresql/profiling/project_profiling_query_postgresql.yaml b/testgen/template/flavors/postgresql/profiling/project_profiling_query_postgresql.yaml index a12c2d50..7fe97764 100644 --- a/testgen/template/flavors/postgresql/profiling/project_profiling_query_postgresql.yaml +++ b/testgen/template/flavors/postgresql/profiling/project_profiling_query_postgresql.yaml @@ -237,7 +237,7 @@ strTemplate15_ALL: NULL as functional_data_type, strTemplate16_ALL: " '{PROFILE_RUN_ID}' as profile_run_id" -strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} ' +strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE BERNOULLI ({SAMPLE_PERCENT_CALC}) REPEATABLE (64)' strTemplate98_else: ' FROM {DATA_SCHEMA}.{DATA_TABLE} ' @@ -248,6 +248,13 @@ strTemplate99_N: | PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") AS pct_75 FROM {DATA_SCHEMA}.{DATA_TABLE} LIMIT 1) pctile +strTemplate99_N_sampling: | + , (SELECT + PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{COL_NAME}") AS pct_25, + PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{COL_NAME}") AS pct_50, + PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") AS pct_75 + FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE BERNOULLI ({SAMPLE_PERCENT_CALC}) REPEATABLE (64) LIMIT 1) pctile + strTemplate99_else: ' ' -strTemplate100_sampling: 'WHERE RAND() <= 1.0 / {PROFILE_SAMPLE_RATIO}' +strTemplate100_sampling: ' ' diff --git a/testgen/template/flavors/redshift/profiling/project_get_table_sample_count_redshift.sql b/testgen/template/flavors/redshift/profiling/project_get_table_sample_count_redshift.sql new file mode 100644 index 00000000..9a62c3d6 --- /dev/null +++ b/testgen/template/flavors/redshift/profiling/project_get_table_sample_count_redshift.sql @@ -0,0 +1,23 @@ +WITH stats + AS (SELECT COUNT(*)::FLOAT as record_ct, + ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0) as calc_sample_ct, + CAST({PROFILE_SAMPLE_MIN_COUNT} as FLOAT) as min_sample_ct, + CAST(999000 as FLOAT) as max_sample_ct + FROM {SAMPLING_TABLE} ) +SELECT '{SAMPLING_TABLE}' as schema_table, + CASE WHEN record_ct <= min_sample_ct THEN -1 + WHEN calc_sample_ct > max_sample_ct THEN max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN calc_sample_ct + ELSE {PROFILE_SAMPLE_MIN_COUNT} + END as sample_count, + CASE WHEN record_ct <= min_sample_ct THEN 1 + WHEN calc_sample_ct > max_sample_ct THEN record_ct / max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN record_ct / calc_sample_ct + ELSE record_ct / min_sample_ct + END as sample_ratio, + ROUND(CASE WHEN record_ct <= min_sample_ct THEN 100 + WHEN calc_sample_ct > max_sample_ct THEN 100.0 * max_sample_ct / record_ct + WHEN calc_sample_ct > min_sample_ct THEN 100.0 * calc_sample_ct / record_ct + ELSE 100.0 * min_sample_ct / record_ct + END, 4) 
as sample_percent_calc + FROM stats; diff --git a/testgen/template/flavors/redshift/profiling/project_profiling_query_redshift.yaml b/testgen/template/flavors/redshift/profiling/project_profiling_query_redshift.yaml index b64d4d54..0edbe8a7 100644 --- a/testgen/template/flavors/redshift/profiling/project_profiling_query_redshift.yaml +++ b/testgen/template/flavors/redshift/profiling/project_profiling_query_redshift.yaml @@ -201,6 +201,13 @@ strTemplate99_N: | PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_75 FROM {DATA_SCHEMA}.{DATA_TABLE} LIMIT 1) pctile +strTemplate99_N_sampling: | + , (SELECT + PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_25, + PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_50, + PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_75 + FROM {DATA_SCHEMA}.{DATA_TABLE} LIMIT 1) pctile + strTemplate99_else: ' ' strTemplate100_sampling: 'WHERE RAND() <= 1.0 / {PROFILE_SAMPLE_RATIO}' diff --git a/testgen/template/flavors/snowflake/profiling/project_get_table_sample_count_snowflake.sql b/testgen/template/flavors/snowflake/profiling/project_get_table_sample_count_snowflake.sql new file mode 100644 index 00000000..9a62c3d6 --- /dev/null +++ b/testgen/template/flavors/snowflake/profiling/project_get_table_sample_count_snowflake.sql @@ -0,0 +1,23 @@ +WITH stats + AS (SELECT COUNT(*)::FLOAT as record_ct, + ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0) as calc_sample_ct, + CAST({PROFILE_SAMPLE_MIN_COUNT} as FLOAT) as min_sample_ct, + CAST(999000 as FLOAT) as max_sample_ct + FROM {SAMPLING_TABLE} ) +SELECT '{SAMPLING_TABLE}' as schema_table, + CASE WHEN record_ct <= min_sample_ct THEN -1 + WHEN calc_sample_ct > max_sample_ct THEN max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN calc_sample_ct + ELSE {PROFILE_SAMPLE_MIN_COUNT} + END as sample_count, + CASE WHEN record_ct <= min_sample_ct THEN 1 + WHEN calc_sample_ct > max_sample_ct THEN record_ct / max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN record_ct / calc_sample_ct + ELSE record_ct / min_sample_ct + END as sample_ratio, + ROUND(CASE WHEN record_ct <= min_sample_ct THEN 100 + WHEN calc_sample_ct > max_sample_ct THEN 100.0 * max_sample_ct / record_ct + WHEN calc_sample_ct > min_sample_ct THEN 100.0 * calc_sample_ct / record_ct + ELSE 100.0 * min_sample_ct / record_ct + END, 4) as sample_percent_calc + FROM stats; diff --git a/testgen/template/flavors/snowflake/profiling/project_profiling_query_snowflake.yaml b/testgen/template/flavors/snowflake/profiling/project_profiling_query_snowflake.yaml index 1caffc0a..ce9f3066 100644 --- a/testgen/template/flavors/snowflake/profiling/project_profiling_query_snowflake.yaml +++ b/testgen/template/flavors/snowflake/profiling/project_profiling_query_snowflake.yaml @@ -206,6 +206,15 @@ strTemplate99_N: | PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_50, PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_75 FROM {DATA_SCHEMA}.{DATA_TABLE} LIMIT 1) pctile + +strTemplate99_N_sampling: | + , + (SELECT + PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_25, + PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_50, + PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY "{COL_NAME}") OVER () AS pct_75 + FROM {DATA_SCHEMA}.{DATA_TABLE} SAMPLE ({SAMPLE_SIZE} rows) LIMIT 1 ) pctile + strTemplate99_else: ; strTemplate100_sampling: ' ' diff --git 
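
Note: the new project_get_table_sample_count_*.sql templates (databricks, mssql, postgresql, redshift, snowflake above, trino below) all encode the same sizing rule: sample roughly PROFILE_SAMPLE_PERCENT of the table, never fewer than PROFILE_SAMPLE_MIN_COUNT rows, never more than 999,000, and skip sampling entirely (-1) when the table is already at or below the minimum. The Python restatement below is only a reviewer's aid — the function name and return shape are invented — but each branch mirrors the CASE expressions in the SQL.

    def table_sample_plan(record_ct: float, sample_percent: float, min_count: float,
                          max_count: float = 999_000.0) -> tuple[float, float, float]:
        # Returns (sample_count, sample_ratio, sample_percent_calc), matching the
        # three columns produced by the template.
        calc = round(sample_percent * record_ct / 100.0)
        if record_ct <= min_count:
            return -1, 1.0, 100.0          # small table: profile it in full
        if calc > max_count:
            target = max_count             # clamp to the 999,000-row ceiling
        elif calc > min_count:
            target = calc                  # percentage-based sample is big enough
        else:
            target = min_count             # fall back to the configured minimum
        return target, record_ct / target, round(100.0 * target / record_ct, 4)

    # Example: 2,000,000 rows at 60% would be 1.2M rows, so it is clamped to
    # 999,000 and profiled with roughly a 49.95 percent TABLESAMPLE.
    print(table_sample_plan(2_000_000, 60, 100_000))   # (999000.0, ~2.002, 49.95)
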
a/testgen/template/flavors/trino/profiling/project_get_table_sample_count_trino.sql b/testgen/template/flavors/trino/profiling/project_get_table_sample_count_trino.sql new file mode 100644 index 00000000..23f5a4bf --- /dev/null +++ b/testgen/template/flavors/trino/profiling/project_get_table_sample_count_trino.sql @@ -0,0 +1,23 @@ +WITH stats + AS (SELECT COUNT(*)::REAL as record_ct, + ROUND(CAST({PROFILE_SAMPLE_PERCENT} as REAL) * CAST(COUNT(*) as REAL) / 100.0) as calc_sample_ct, + CAST({PROFILE_SAMPLE_MIN_COUNT} as REAL) as min_sample_ct, + CAST(999000 as REAL) as max_sample_ct + FROM {SAMPLING_TABLE} ) +SELECT '{SAMPLING_TABLE}' as schema_table, + CASE WHEN record_ct <= min_sample_ct THEN -1 + WHEN calc_sample_ct > max_sample_ct THEN max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN calc_sample_ct + ELSE {PROFILE_SAMPLE_MIN_COUNT} + END as sample_count, + CASE WHEN record_ct <= min_sample_ct THEN 1 + WHEN calc_sample_ct > max_sample_ct THEN record_ct / max_sample_ct + WHEN calc_sample_ct > min_sample_ct THEN record_ct / calc_sample_ct + ELSE record_ct / min_sample_ct + END as sample_ratio, + ROUND(CASE WHEN record_ct <= min_sample_ct THEN 100 + WHEN calc_sample_ct > max_sample_ct THEN 100.0 * max_sample_ct / record_ct + WHEN calc_sample_ct > min_sample_ct THEN 100.0 * calc_sample_ct / record_ct + ELSE 100.0 * min_sample_ct / record_ct + END, 4) as sample_percent_calc + FROM stats; diff --git a/testgen/template/flavors/trino/profiling/project_profiling_query_trino.yaml b/testgen/template/flavors/trino/profiling/project_profiling_query_trino.yaml index c1355afc..284605b4 100644 --- a/testgen/template/flavors/trino/profiling/project_profiling_query_trino.yaml +++ b/testgen/template/flavors/trino/profiling/project_profiling_query_trino.yaml @@ -233,7 +233,7 @@ strTemplate15_ALL: NULL as functional_data_type, strTemplate16_ALL: " '{PROFILE_RUN_ID}' as profile_run_id" -strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} ' +strTemplate98_sampling: ' FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE SYSTEM ({SAMPLE_PERCENT_CALC})' strTemplate98_else: ' FROM {DATA_SCHEMA}.{DATA_TABLE}' @@ -244,6 +244,13 @@ strTemplate99_N: | APPROX_PERCENTILE("{COL_NAME}", 0.75) AS pct_75 FROM {DATA_SCHEMA}.{DATA_TABLE} LIMIT 1) pctile +strTemplate99_N_sampling: | + , (SELECT + APPROX_PERCENTILE("{COL_NAME}", 0.25) AS pct_25, + APPROX_PERCENTILE("{COL_NAME}", 0.50) AS pct_50, + APPROX_PERCENTILE("{COL_NAME}", 0.75) AS pct_75 + FROM {DATA_SCHEMA}.{DATA_TABLE} TABLESAMPLE SYSTEM ({SAMPLE_PERCENT_CALC}) ) pctile + strTemplate99_else: ' ' -strTemplate100_sampling: 'WHERE RAND() <= 1.0 / {PROFILE_SAMPLE_RATIO}' +strTemplate100_sampling: ' ' diff --git a/testgen/template/gen_funny_cat_tests/gen_test_constant.sql b/testgen/template/gen_funny_cat_tests/gen_test_constant.sql index 4270d713..98181fb6 100644 --- a/testgen/template/gen_funny_cat_tests/gen_test_constant.sql +++ b/testgen/template/gen_funny_cat_tests/gen_test_constant.sql @@ -12,10 +12,10 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date INNER JOIN test_suites ts ON p.project_code = ts.project_code AND p.connection_id = ts.connection_id - WHERE p.project_code = '{PROJECT_CODE}' - AND r.table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND ts.id = '{TEST_SUITE_ID}' - AND p.run_date::DATE <= '{AS_OF_DATE}' + WHERE p.project_code = :PROJECT_CODE + AND r.table_groups_id = :TABLE_GROUPS_ID + AND ts.id = :TEST_SUITE_ID + AND p.run_date::DATE <= :AS_OF_DATE GROUP BY r.table_groups_id), curprof AS (SELECT p.* FROM last_run lr @@ -24,8 
+24,8 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date AND lr.last_run_date = p.run_date) ), locked AS (SELECT schema_name, table_name, column_name, test_type FROM test_definitions - WHERE table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND test_suite_id = '{TEST_SUITE_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID + AND test_suite_id = :TEST_SUITE_ID AND lock_refresh = 'Y'), all_runs AS ( SELECT DISTINCT p.table_groups_id, p.schema_name, p.run_date, DENSE_RANK() OVER (PARTITION BY p.table_groups_id ORDER BY p.run_date DESC) as run_rank @@ -33,9 +33,9 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date INNER JOIN test_suites ts ON p.connection_id = ts.connection_id AND p.project_code = ts.project_code - WHERE p.table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND ts.id = '{TEST_SUITE_ID}' - AND p.run_date::DATE <= '{AS_OF_DATE}'), + WHERE p.table_groups_id = :TABLE_GROUPS_ID + AND ts.id = :TEST_SUITE_ID + AND p.run_date::DATE <= :AS_OF_DATE), recent_runs AS (SELECT table_groups_id, schema_name, run_date, run_rank FROM all_runs WHERE run_rank <= 5), @@ -73,12 +73,12 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date -- Only constant if more than one profiling result AND COUNT(*) > 1), newtests AS ( SELECT 'Constant'::VARCHAR AS test_type, - '{TEST_SUITE_ID}'::UUID AS test_suite_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, c.profile_run_id, c.schema_name, c.table_name, c.column_name, c.run_date AS last_run_date, - case when general_type='A' then fn_quote_literal_escape(min_text, '{SQL_FLAVOR}')::VARCHAR - when general_type='D' then fn_quote_literal_escape(min_date :: VARCHAR, '{SQL_FLAVOR}')::VARCHAR + case when general_type='A' then fn_quote_literal_escape(min_text, :SQL_FLAVOR)::VARCHAR + when general_type='D' then fn_quote_literal_escape(min_date :: VARCHAR, :SQL_FLAVOR)::VARCHAR when general_type='N' then min_value::VARCHAR when general_type='B' and boolean_true_ct = 0 then 'FALSE'::VARCHAR when general_type='B' and boolean_true_ct > 0 then 'TRUE'::VARCHAR @@ -90,14 +90,14 @@ newtests AS ( SELECT 'Constant'::VARCHAR AS test_type, AND c.column_name = r.column_name) LEFT JOIN generation_sets s ON ('Constant' = s.test_type - AND '{GENERATION_SET}' = s.generation_set) + AND :GENERATION_SET = s.generation_set) WHERE (s.generation_set IS NOT NULL - OR '{GENERATION_SET}' = '') ) -SELECT '{TABLE_GROUPS_ID}'::UUID as table_groups_id, n.profile_run_id, + OR :GENERATION_SET = '') ) +SELECT :TABLE_GROUPS_ID as table_groups_id, n.profile_run_id, n.test_type, n.test_suite_id, n.schema_name, n.table_name, n.column_name, - 0 as skip_errors, '{RUN_DATE}'::TIMESTAMP as auto_gen_date, + 0 as skip_errors, :RUN_DATE ::TIMESTAMP as auto_gen_date, 'Y' as test_active, COALESCE(baseline_value, '') as baseline_value, - '0' as threshold_value, '{AS_OF_DATE}'::TIMESTAMP + '0' as threshold_value, :AS_OF_DATE ::TIMESTAMP FROM newtests n LEFT JOIN locked l ON (n.schema_name = l.schema_name diff --git a/testgen/template/gen_funny_cat_tests/gen_test_distinct_value_ct.sql b/testgen/template/gen_funny_cat_tests/gen_test_distinct_value_ct.sql index ab939339..350e1048 100644 --- a/testgen/template/gen_funny_cat_tests/gen_test_distinct_value_ct.sql +++ b/testgen/template/gen_funny_cat_tests/gen_test_distinct_value_ct.sql @@ -11,10 +11,10 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date INNER JOIN test_suites ts ON p.project_code = ts.project_code AND p.connection_id = ts.connection_id - WHERE p.project_code = 
'{PROJECT_CODE}' - AND r.table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND ts.id = '{TEST_SUITE_ID}' - AND p.run_date::DATE <= '{AS_OF_DATE}' + WHERE p.project_code = :PROJECT_CODE + AND r.table_groups_id = :TABLE_GROUPS_ID + AND ts.id = :TEST_SUITE_ID + AND p.run_date::DATE <= :AS_OF_DATE GROUP BY r.table_groups_id), curprof AS (SELECT p.* FROM last_run lr @@ -23,8 +23,8 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date AND lr.last_run_date = p.run_date) ), locked AS (SELECT schema_name, table_name, column_name, test_type FROM test_definitions - WHERE table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND test_suite_id = '{TEST_SUITE_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID + AND test_suite_id = :TEST_SUITE_ID AND lock_refresh = 'Y'), all_runs AS ( SELECT DISTINCT p.table_groups_id, p.schema_name, p.run_date, DENSE_RANK() OVER (PARTITION BY p.table_groups_id ORDER BY p.run_date DESC) as run_rank @@ -32,9 +32,9 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date INNER JOIN test_suites ts ON p.connection_id = ts.connection_id AND p.project_code = ts.project_code - WHERE p.table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND ts.id = '{TEST_SUITE_ID}' - AND p.run_date::DATE <= '{AS_OF_DATE}'), + WHERE p.table_groups_id = :TABLE_GROUPS_ID + AND ts.id = :TEST_SUITE_ID + AND p.run_date::DATE <= :AS_OF_DATE), recent_runs AS (SELECT table_groups_id, schema_name, run_date, run_rank FROM all_runs WHERE run_rank <= 5), @@ -69,7 +69,7 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date -- include cases with only single profiling result -- can't yet assume constant OR COUNT(*) = 1)), newtests AS ( SELECT 'Distinct_Value_Ct'::VARCHAR AS test_type, - '{TEST_SUITE_ID}'::UUID AS test_suite_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, c.table_groups_id, c.profile_run_id, c.schema_name, c.table_name, c.column_name, c.run_date AS last_run_date, @@ -81,15 +81,15 @@ newtests AS ( SELECT 'Distinct_Value_Ct'::VARCHAR AS test_type, AND c.column_name = r.column_name) LEFT JOIN generation_sets s ON ('Distinct_Value_Ct' = s.test_type - AND '{GENERATION_SET}' = s.generation_set) + AND :GENERATION_SET = s.generation_set) WHERE (s.generation_set IS NOT NULL - OR '{GENERATION_SET}' = '') ) + OR :GENERATION_SET = '') ) SELECT n.table_groups_id, n.profile_run_id, n.test_type, n.test_suite_id, n.schema_name, n.table_name, n.column_name, 0 as skip_errors, - '{RUN_DATE}'::TIMESTAMP as last_auto_gen_date, 'Y' as test_active, + :RUN_DATE ::TIMESTAMP as last_auto_gen_date, 'Y' as test_active, distinct_value_ct as baseline_value_ct, distinct_value_ct as threshold_value, - '{AS_OF_DATE}'::TIMESTAMP as profiling_as_of_date + :AS_OF_DATE ::TIMESTAMP as profiling_as_of_date FROM newtests n LEFT JOIN locked l ON (n.schema_name = l.schema_name diff --git a/testgen/template/gen_funny_cat_tests/gen_test_row_ct.sql b/testgen/template/gen_funny_cat_tests/gen_test_row_ct.sql index 55b626e0..c1e4578f 100644 --- a/testgen/template/gen_funny_cat_tests/gen_test_row_ct.sql +++ b/testgen/template/gen_funny_cat_tests/gen_test_row_ct.sql @@ -10,10 +10,10 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date INNER JOIN test_suites ts ON p.project_code = ts.project_code AND p.connection_id = ts.connection_id - WHERE p.project_code = '{PROJECT_CODE}' - AND r.table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND ts.id = '{TEST_SUITE_ID}' - AND p.run_date::DATE <= '{AS_OF_DATE}' + WHERE p.project_code = :PROJECT_CODE + AND r.table_groups_id = 
:TABLE_GROUPS_ID + AND ts.id = :TEST_SUITE_ID + AND p.run_date::DATE <= :AS_OF_DATE GROUP BY r.table_groups_id), curprof AS (SELECT p.* FROM last_run lr @@ -22,32 +22,32 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date AND lr.last_run_date = p.run_date) ), locked AS (SELECT schema_name, table_name, column_name, test_type FROM test_definitions - WHERE table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND test_suite_id = '{TEST_SUITE_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID + AND test_suite_id = :TEST_SUITE_ID AND lock_refresh = 'Y'), newtests AS (SELECT table_groups_id, profile_run_id, 'Row_Ct' AS test_type, - '{TEST_SUITE_ID}'::UUID AS test_suite_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, schema_name, table_name, MAX(record_ct) as record_ct FROM curprof c LEFT JOIN generation_sets s ON ('Row_Ct' = s.test_type - AND '{GENERATION_SET}' = s.generation_set) - WHERE schema_name = '{DATA_SCHEMA}' + AND :GENERATION_SET = s.generation_set) + WHERE schema_name = :DATA_SCHEMA AND functional_table_type LIKE '%cumulative%' AND (s.generation_set IS NOT NULL - OR '{GENERATION_SET}' = '') + OR :GENERATION_SET = '') GROUP BY project_code, table_groups_id, profile_run_id, test_type, test_suite_id, schema_name, table_name ) SELECT n.table_groups_id, n.profile_run_id, n.test_type, n.test_suite_id, n.schema_name, n.table_name, 0 as skip_errors, record_ct AS threshold_value, - '{RUN_DATE}'::TIMESTAMP as last_auto_gen_date, + :RUN_DATE ::TIMESTAMP as last_auto_gen_date, 'Y' as test_active, record_ct as baseline_ct, - '{AS_OF_DATE}'::TIMESTAMP as profiling_as_of_date + :AS_OF_DATE ::TIMESTAMP as profiling_as_of_date FROM newtests n LEFT JOIN locked l ON (n.schema_name = l.schema_name diff --git a/testgen/template/gen_funny_cat_tests/gen_test_row_ct_pct.sql b/testgen/template/gen_funny_cat_tests/gen_test_row_ct_pct.sql index d68a4321..a338a2e2 100644 --- a/testgen/template/gen_funny_cat_tests/gen_test_row_ct_pct.sql +++ b/testgen/template/gen_funny_cat_tests/gen_test_row_ct_pct.sql @@ -10,10 +10,10 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date INNER JOIN test_suites ts ON p.project_code = ts.project_code AND p.connection_id = ts.connection_id - WHERE p.project_code = '{PROJECT_CODE}' - AND r.table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND ts.id = '{TEST_SUITE_ID}' - AND p.run_date::DATE <= '{AS_OF_DATE}' + WHERE p.project_code = :PROJECT_CODE + AND r.table_groups_id = :TABLE_GROUPS_ID + AND ts.id = :TEST_SUITE_ID + AND p.run_date::DATE <= :AS_OF_DATE GROUP BY r.table_groups_id), curprof AS (SELECT p.* FROM last_run lr @@ -22,33 +22,33 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date AND lr.last_run_date = p.run_date) ), locked AS (SELECT schema_name, table_name, column_name, test_type FROM test_definitions - WHERE table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND test_suite_id = '{TEST_SUITE_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID + AND test_suite_id = :TEST_SUITE_ID AND lock_refresh = 'Y'), newtests AS ( SELECT table_groups_id, profile_run_id, 'Row_Ct_Pct' AS test_type, - '{TEST_SUITE_ID}'::UUID AS test_suite_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, schema_name, table_name, MAX(record_ct) as record_ct FROM curprof LEFT JOIN generation_sets s ON ('Row_Ct_Pct' = s.test_type - AND '{GENERATION_SET}' = s.generation_set) - WHERE schema_name = '{DATA_SCHEMA}' + AND :GENERATION_SET = s.generation_set) + WHERE schema_name = :DATA_SCHEMA AND functional_table_type NOT ILIKE '%cumulative%' AND (s.generation_set IS 
NOT NULL - OR '{GENERATION_SET}' = '') + OR :GENERATION_SET = '') GROUP BY project_code, table_groups_id, profile_run_id, test_type, test_suite_id, schema_name, table_name HAVING MAX(record_ct) >= 500) SELECT n.table_groups_id, n.profile_run_id, n.test_type, n.test_suite_id, n.schema_name, n.table_name, 0 as skip_errors, - '{RUN_DATE}'::TIMESTAMP as last_auto_gen_date, - '{AS_OF_DATE}'::TIMESTAMP as profiling_as_of_date, + :RUN_DATE ::TIMESTAMP as last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP as profiling_as_of_date, 'Y' as test_active, record_ct as baseline_ct, 0.5 AS threshold_value FROM newtests n diff --git a/testgen/template/generation/gen_delete_old_tests.sql b/testgen/template/generation/gen_delete_old_tests.sql index 94463045..0aeeec7d 100644 --- a/testgen/template/generation/gen_delete_old_tests.sql +++ b/testgen/template/generation/gen_delete_old_tests.sql @@ -1,5 +1,5 @@ DELETE FROM test_definitions - WHERE table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND test_suite_id = '{TEST_SUITE_ID}' + WHERE table_groups_id = :TABLE_GROUPS_ID + AND test_suite_id = :TEST_SUITE_ID AND last_auto_gen_date IS NOT NULL AND COALESCE(lock_refresh, 'N') <> 'Y'; diff --git a/testgen/template/generation/gen_insert_test_suite.sql b/testgen/template/generation/gen_insert_test_suite.sql index d78becb6..c070f65b 100644 --- a/testgen/template/generation/gen_insert_test_suite.sql +++ b/testgen/template/generation/gen_insert_test_suite.sql @@ -1,6 +1,6 @@ INSERT INTO test_suites (project_code, test_suite, connection_id, table_groups_id, test_suite_description, component_type, component_key) -VALUES ('{PROJECT_CODE}', '{TEST_SUITE}', {CONNECTION_ID}, '{TABLE_GROUPS_ID}', '{TEST_SUITE} Test Suite', - 'dataset', '{TEST_SUITE}') +VALUES (:PROJECT_CODE, :TEST_SUITE, :CONNECTION_ID, :TABLE_GROUPS_ID, :TEST_SUITE || ' Test Suite', + 'dataset', :TEST_SUITE) RETURNING id::VARCHAR; diff --git a/testgen/template/generation/gen_standard_test_type_list.sql b/testgen/template/generation/gen_standard_test_type_list.sql index 11b7d17f..9f041c9f 100644 --- a/testgen/template/generation/gen_standard_test_type_list.sql +++ b/testgen/template/generation/gen_standard_test_type_list.sql @@ -5,9 +5,9 @@ SELECT t.test_type, FROM test_types t LEFT JOIN generation_sets s ON (t.test_type = s.test_type - AND '{GENERATION_SET}' = s.generation_set) + AND :GENERATION_SET = s.generation_set) WHERE t.active = 'Y' AND t.selection_criteria <> 'TEMPLATE' -- Also excludes NULL AND (s.generation_set IS NOT NULL - OR '{GENERATION_SET}' = '') + OR :GENERATION_SET = '') ORDER BY test_type; diff --git a/testgen/template/generation/gen_standard_tests.sql b/testgen/template/generation/gen_standard_tests.sql index c8b9a61b..2053ba54 100644 --- a/testgen/template/generation/gen_standard_tests.sql +++ b/testgen/template/generation/gen_standard_tests.sql @@ -10,10 +10,10 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date INNER JOIN test_suites ts ON p.project_code = ts.project_code AND p.connection_id = ts.connection_id - WHERE p.project_code = '{PROJECT_CODE}' - AND r.table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND ts.id = '{TEST_SUITE_ID}' - AND p.run_date::DATE <= '{AS_OF_DATE}' + WHERE p.project_code = :PROJECT_CODE + AND r.table_groups_id = :TABLE_GROUPS_ID + AND ts.id = :TEST_SUITE_ID + AND p.run_date::DATE <= :AS_OF_DATE GROUP BY r.table_groups_id), curprof AS (SELECT p.*, datediff('MM', p.min_date, p.max_date) as min_max_months, datediff('week', '1800-01-05'::DATE, p.max_date) - datediff('week', '1800-01-05'::DATE, 
p.min_date) as min_max_weeks FROM last_run lr @@ -22,21 +22,21 @@ WITH last_run AS (SELECT r.table_groups_id, MAX(run_date) AS last_run_date AND lr.last_run_date = p.run_date) ), locked AS (SELECT schema_name, table_name, column_name FROM test_definitions - WHERE table_groups_id = '{TABLE_GROUPS_ID}'::UUID - AND test_suite_id = '{TEST_SUITE_ID}' - AND test_type = '{TEST_TYPE}' + WHERE table_groups_id = :TABLE_GROUPS_ID + AND test_suite_id = :TEST_SUITE_ID + AND test_type = :TEST_TYPE AND lock_refresh = 'Y'), newtests AS (SELECT * FROM curprof - WHERE schema_name = '{DATA_SCHEMA}' + WHERE schema_name = :DATA_SCHEMA AND {SELECTION_CRITERIA} ) -SELECT '{TABLE_GROUPS_ID}'::UUID as table_groups_id, +SELECT :TABLE_GROUPS_ID as table_groups_id, n.profile_run_id, - '{TEST_TYPE}' AS test_type, - '{TEST_SUITE_ID}' AS test_suite_id, + :TEST_TYPE AS test_type, + :TEST_SUITE_ID AS test_suite_id, n.schema_name, n.table_name, n.column_name, - 0 as skip_errors, 'Y' as test_active, '{RUN_DATE}'::TIMESTAMP as last_auto_gen_date, - '{AS_OF_DATE}'::TIMESTAMP as profiling_as_of_date, + 0 as skip_errors, 'Y' as test_active, :RUN_DATE ::TIMESTAMP as last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP as profiling_as_of_date, {DEFAULT_PARM_VALUES} FROM newtests n LEFT JOIN locked l diff --git a/testgen/template/get_entities/get_profile_info.sql b/testgen/template/get_entities/get_profile_info.sql index e1d2148c..6079044a 100644 --- a/testgen/template/get_entities/get_profile_info.sql +++ b/testgen/template/get_entities/get_profile_info.sql @@ -12,6 +12,6 @@ SELECT profile_run_id, column_type, functional_data_type FROM profile_results - WHERE table_name ILIKE '{TABLE_NAME}' - AND profile_run_id = '{PROFILE_RUN}'::UUID + WHERE table_name ILIKE :TABLE_NAME + AND profile_run_id = :PROFILING_RUN_ID ORDER BY table_name, position; \ No newline at end of file diff --git a/testgen/template/get_entities/get_profile_list.sql b/testgen/template/get_entities/get_profile_list.sql index 92c93484..088bc745 100644 --- a/testgen/template/get_entities/get_profile_list.sql +++ b/testgen/template/get_entities/get_profile_list.sql @@ -11,7 +11,7 @@ SELECT p.id as profile_run_id, FROM profiling_runs p INNER JOIN profile_results r ON (p.id = r.profile_run_id) - WHERE p.table_groups_id = '{TABLE_GROUPS_ID}'::UUID + WHERE p.table_groups_id = :TABLE_GROUP_ID GROUP BY p.id, p.project_code, p.connection_id, schema_name, p.table_groups_id, profiling_starttime, status ORDER BY profiling_starttime DESC; diff --git a/testgen/template/get_entities/get_profile_screen.sql b/testgen/template/get_entities/get_profile_screen.sql index cdd42a09..914bd637 100644 --- a/testgen/template/get_entities/get_profile_screen.sql +++ b/testgen/template/get_entities/get_profile_screen.sql @@ -1,8 +1,8 @@ WITH profiling as ( SELECT * FROM profile_results - WHERE profile_run_id = '{PROFILE_RUN}'::UUID - AND table_name ILIKE '{TABLE_NAME}' ), + WHERE profile_run_id = :PROFILING_RUN_ID + AND table_name ILIKE :TABLE_NAME ), profile_date as (SELECT MAX(run_date) as run_date FROM profiling), mults AS ( SELECT p.project_code, diff --git a/testgen/template/get_entities/get_project_list.sql b/testgen/template/get_entities/get_project_list.sql index 8c3f153f..e86784da 100644 --- a/testgen/template/get_entities/get_project_list.sql +++ b/testgen/template/get_entities/get_project_list.sql @@ -3,4 +3,4 @@ SELECT id, project_name, observability_api_key FROM projects -ORDER BY effective_from_date desc; +ORDER BY project_name; diff --git 
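
Note: the test-generation templates above (gen_test_constant, gen_test_distinct_value_ct, gen_test_row_ct, gen_test_row_ct_pct, gen_delete_old_tests, gen_insert_test_suite, gen_standard_tests, ...) now take :NAME bind parameters instead of '{NAME}' string substitution; where a Postgres cast is still wanted they write ':TEST_SUITE_ID ::UUID' with a space, presumably so the parameter token is not run together with the cast operator. The snippet below is a hypothetical harness, not TestGen code — the engine URL and literal values are placeholders — showing how such a template executes with SQLAlchemy bound parameters, using gen_delete_old_tests.sql from above as the statement.

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+psycopg2://user:pass@localhost/testgen")  # placeholder DSN

    delete_old_tests = """
    DELETE FROM test_definitions
     WHERE table_groups_id = :TABLE_GROUPS_ID
       AND test_suite_id = :TEST_SUITE_ID
       AND last_auto_gen_date IS NOT NULL
       AND COALESCE(lock_refresh, 'N') <> 'Y';
    """

    with engine.begin() as conn:
        # Values travel as bound parameters; nothing is interpolated into the SQL text.
        conn.execute(text(delete_old_tests), {
            "TABLE_GROUPS_ID": "00000000-0000-0000-0000-000000000001",  # placeholder UUID
            "TEST_SUITE_ID": "00000000-0000-0000-0000-000000000002",    # placeholder UUID
        })
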
a/testgen/template/get_entities/get_table_group_list.sql b/testgen/template/get_entities/get_table_group_list.sql index bcfae6e3..ea57e0a3 100644 --- a/testgen/template/get_entities/get_table_group_list.sql +++ b/testgen/template/get_entities/get_table_group_list.sql @@ -7,4 +7,4 @@ SELECT profiling_include_mask as include_mask, profiling_exclude_mask as exclude_mask FROM table_groups -where project_code = '{PROJECT_CODE}'; +where project_code = :PROJECT_CODE; diff --git a/testgen/template/get_entities/get_test_generation_list.sql b/testgen/template/get_entities/get_test_generation_list.sql index 2efa5587..95600b7e 100644 --- a/testgen/template/get_entities/get_test_generation_list.sql +++ b/testgen/template/get_entities/get_test_generation_list.sql @@ -12,8 +12,8 @@ Optional: n/a*/ COUNT(*) as tests FROM test_definitions td JOIN test_suites ts ON td.test_suite_id = ts.id - WHERE ts.project_code = '{PROJECT_CODE}' - AND ts.test_suite = '{TEST_SUITE}' + WHERE ts.project_code = :PROJECT_CODE + AND ts.test_suite = :TEST_SUITE AND td.last_auto_gen_date IS NOT NULL GROUP BY ts.id, td.last_auto_gen_date, td.profiling_as_of_date, td.lock_refresh ORDER BY td.last_auto_gen_date desc; diff --git a/testgen/template/get_entities/get_test_info.sql b/testgen/template/get_entities/get_test_info.sql index feb0cfb8..b941cc23 100644 --- a/testgen/template/get_entities/get_test_info.sql +++ b/testgen/template/get_entities/get_test_info.sql @@ -38,8 +38,8 @@ Optional: last_auto_run_date (==test-gen-run-id==), schema-name, table-name, col FROM test_definitions td INNER JOIN test_types tt ON td.test_type = tt.test_type INNER JOIN test_suites ts ON td.test_suite_id = ts.id - WHERE ts.project_code = '{PROJECT_CODE}' - AND ts.test_suite = '{TEST_SUITE}' + WHERE ts.project_code = :PROJECT_CODE + AND ts.test_suite = :TEST_SUITE ORDER BY td.schema_name, td.table_name, td.column_name, diff --git a/testgen/template/get_entities/get_test_results_for_run_cli.sql b/testgen/template/get_entities/get_test_results_for_run_cli.sql index 240cb014..dd96a337 100644 --- a/testgen/template/get_entities/get_test_results_for_run_cli.sql +++ b/testgen/template/get_entities/get_test_results_for_run_cli.sql @@ -13,6 +13,6 @@ SELECT ts.test_suite as test_suite_key, FROM test_results r INNER JOIN test_types tt ON r.test_type = tt.test_type INNER JOIN test_suites ts ON r.test_suite_id = ts.id - WHERE test_run_id = '{TEST_RUN_ID}'::UUID + WHERE test_run_id = :TEST_RUN_ID {ERRORS_ONLY} ORDER BY r.schema_name, r.table_name, r.column_names, r.test_type; diff --git a/testgen/template/get_entities/get_test_run_list.sql b/testgen/template/get_entities/get_test_run_list.sql index bc25ccc1..14079499 100644 --- a/testgen/template/get_entities/get_test_run_list.sql +++ b/testgen/template/get_entities/get_test_run_list.sql @@ -15,8 +15,8 @@ Optional: table-name, column-name, from-date, thru-date*/ FROM test_runs tr INNER JOIN test_results r ON tr.id = r.test_run_id INNER JOIN test_suites ts ON tr.test_suite_id = ts.id - WHERE ts.project_code = '{PROJECT_CODE}' - AND ts.test_suite = '{TEST_SUITE}' + WHERE ts.project_code = :PROJECT_CODE + AND ts.test_suite = :TEST_SUITE GROUP BY tr.id, ts.project_code, ts.test_suite, diff --git a/testgen/template/get_entities/get_test_suite.sql b/testgen/template/get_entities/get_test_suite.sql index 97241154..b602768d 100644 --- a/testgen/template/get_entities/get_test_suite.sql +++ b/testgen/template/get_entities/get_test_suite.sql @@ -9,5 +9,5 @@ SELECT component_key, component_type FROM test_suites -WHERE 
project_code = '{PROJECT_CODE}' -AND test_suite = '{TEST_SUITE}'; +WHERE project_code = :PROJECT_CODE +AND test_suite = :TEST_SUITE; diff --git a/testgen/template/get_entities/get_test_suite_list.sql b/testgen/template/get_entities/get_test_suite_list.sql index 294eb654..4ba63e1f 100644 --- a/testgen/template/get_entities/get_test_suite_list.sql +++ b/testgen/template/get_entities/get_test_suite_list.sql @@ -7,5 +7,5 @@ FROM test_suites ts LEFT JOIN test_runs tr ON tr.test_suite_id = ts.id - WHERE ts.project_code = '{PROJECT_CODE}' + WHERE ts.project_code = :PROJECT_CODE ORDER BY ts.test_suite; diff --git a/testgen/template/observability/get_event_data.sql b/testgen/template/observability/get_event_data.sql index d3f531e9..704fdc2f 100644 --- a/testgen/template/observability/get_event_data.sql +++ b/testgen/template/observability/get_event_data.sql @@ -19,4 +19,4 @@ from test_suites ts join connections c on c.connection_id = ts.connection_id join projects pr on pr.project_code = ts.project_code join table_groups tg on tg.id = ts.table_groups_id -where ts.id = '{TEST_SUITE_ID}' +where ts.id = :TEST_SUITE_ID; diff --git a/testgen/template/observability/get_test_results.sql b/testgen/template/observability/get_test_results.sql index afc29ec9..85ab567a 100644 --- a/testgen/template/observability/get_test_results.sql +++ b/testgen/template/observability/get_test_results.sql @@ -35,6 +35,6 @@ SELECT measure_uom, measure_uom_description FROM v_queued_observability_results -where test_suite_id = '{TEST_SUITE_ID}' +where test_suite_id = :TEST_SUITE_ID order by start_time asc -limit {MAX_QTY_EVENTS} +limit :MAX_QTY_EVENTS; diff --git a/testgen/template/observability/update_test_results_exported_to_observability.sql b/testgen/template/observability/update_test_results_exported_to_observability.sql index 2bb1e15f..b1b5fd0b 100644 --- a/testgen/template/observability/update_test_results_exported_to_observability.sql +++ b/testgen/template/observability/update_test_results_exported_to_observability.sql @@ -2,10 +2,10 @@ Output: updates exported results */ with selects - as ( SELECT UNNEST(ARRAY[{RESULT_IDS}]) AS selected_id ) + as ( SELECT UNNEST(ARRAY[:TEST_RESULT_IDS]) AS selected_id ) update test_results set observability_status = 'Sent' from test_results r - INNER JOIN selects s ON (r.result_id = s.selected_id) + INNER JOIN selects s ON (r.result_id = s.selected_id::BIGINT) where r.id = test_results.id and r.observability_status = 'Queued' - and r.test_suite_id = '{TEST_SUITE_ID}' + and r.test_suite_id = :TEST_SUITE_ID; diff --git a/testgen/template/parms/parms_profiling.sql b/testgen/template/parms/parms_profiling.sql index fb786ebc..5eec7a4c 100644 --- a/testgen/template/parms/parms_profiling.sql +++ b/testgen/template/parms/parms_profiling.sql @@ -1,16 +1,4 @@ -SELECT cc.project_code, - cc.connection_id::VARCHAR(50) as connection_id, - cc.sql_flavor, - cc.url, - cc.connect_by_url, - cc.connect_by_key, - cc.private_key, - cc.private_key_passphrase, - cc.project_host, - cc.project_port, - cc.project_user, - cc.project_db, - cc.http_path, +SELECT tg.project_code, tg.id::VARCHAR(50) as table_groups_id, tg.table_group_schema, CASE @@ -26,10 +14,6 @@ SELECT cc.project_code, tg.profile_sample_percent, tg.profile_sample_min_count, tg.profile_do_pair_rules, - tg.profile_pair_rule_pct, - cc.max_threads + tg.profile_pair_rule_pct FROM table_groups tg - INNER JOIN connections cc - on cc.project_code = tg.project_code - and cc.connection_id = tg.connection_id - WHERE tg.id = '{TABLE_GROUPS_ID}'::UUID; + WHERE 
tg.id = :TABLE_GROUP_ID; diff --git a/testgen/template/parms/parms_test_execution.sql b/testgen/template/parms/parms_test_execution.sql index 85fe7fe0..f81b0c2f 100644 --- a/testgen/template/parms/parms_test_execution.sql +++ b/testgen/template/parms/parms_test_execution.sql @@ -1,5 +1,4 @@ SELECT ts.project_code, - ts.connection_id::VARCHAR, ts.id::VARCHAR as test_suite_id, ts.table_groups_id::VARCHAR, tg.table_group_schema, @@ -8,22 +7,8 @@ SELECT ts.project_code, ELSE fn_format_csv_quotes(tg.profiling_table_set) END as profiling_table_set, tg.profiling_include_mask, - tg.profiling_exclude_mask, - cc.sql_flavor, - cc.project_host, - cc.project_port, - cc.project_user, - cc.project_db, - cc.connect_by_key, - cc.private_key, - cc.private_key_passphrase, - cc.max_threads, - cc.max_query_chars, - cc.url, - cc.connect_by_url, - cc.http_path + tg.profiling_exclude_mask FROM test_suites ts - JOIN connections cc ON (ts.connection_id = cc.connection_id) JOIN table_groups tg ON (ts.table_groups_id = tg.id) - WHERE ts.project_code = '{PROJECT_CODE}' - AND ts.test_suite = '{TEST_SUITE}'; + WHERE ts.project_code = :PROJECT_CODE + AND ts.test_suite = :TEST_SUITE; diff --git a/testgen/template/parms/parms_test_gen.sql b/testgen/template/parms/parms_test_gen.sql index b9730edd..ebc717d1 100644 --- a/testgen/template/parms/parms_test_gen.sql +++ b/testgen/template/parms/parms_test_gen.sql @@ -1,22 +1,9 @@ SELECT tg.project_code, - tg.connection_id, - cc.sql_flavor, - cc.project_host, - cc.project_port, - cc.project_user, - cc.connect_by_key, - cc.private_key, - cc.private_key_passphrase, - cc.project_db, tg.table_group_schema, ts.export_to_observability, ts.id::VARCHAR as test_suite_id, - cc.url, - cc.connect_by_url, - cc.http_path, CURRENT_TIMESTAMP AT TIME ZONE 'UTC' - CAST(tg.profiling_delay_days AS integer) * INTERVAL '1 day' as profiling_as_of_date FROM table_groups tg -INNER JOIN connections cc ON tg.connection_id = cc.connection_id - LEFT JOIN test_suites ts ON tg.connection_id = ts.connection_id AND ts.test_suite = '{TEST_SUITE}' - WHERE tg.id = '{TABLE_GROUPS_ID}'; + LEFT JOIN test_suites ts ON tg.connection_id = ts.connection_id AND ts.test_suite = :TEST_SUITE + WHERE tg.id = :TABLE_GROUP_ID; diff --git a/testgen/template/profiling/cde_flagger_query.sql b/testgen/template/profiling/cde_flagger_query.sql index d1feb638..a23e69ec 100644 --- a/testgen/template/profiling/cde_flagger_query.sql +++ b/testgen/template/profiling/cde_flagger_query.sql @@ -1,6 +1,6 @@ UPDATE data_column_chars SET critical_data_element = FALSE - WHERE table_groups_id = '{TABLE_GROUPS_ID}'; + WHERE table_groups_id = :TABLE_GROUPS_ID; WITH cde_selects AS ( SELECT table_groups_id, table_name, column_name @@ -9,7 +9,7 @@ WITH cde_selects -- ROUND(100.0 * (value_ct - COALESCE(zero_length_ct, 0.0) - COALESCE(filled_value_ct, 0.0))::DEC(15, 3) / -- NULLIF(record_ct::DEC(15, 3), 0), 0) AS pct_records_populated FROM profile_results p - WHERE p.profile_run_id = '{PROFILE_RUN_ID}' + WHERE p.profile_run_id = :PROFILE_RUN_ID AND ROUND(100.0 * (value_ct - COALESCE(zero_length_ct, 0.0) - COALESCE(filled_value_ct, 0.0))::DEC(15, 3) / NULLIF(record_ct::DEC(15, 3), 0), 0) > 75 AND ((p.functional_table_type ILIKE '%Entity' diff --git a/testgen/template/profiling/contingency_columns.sql b/testgen/template/profiling/contingency_columns.sql index 045f967c..799e34fa 100644 --- a/testgen/template/profiling/contingency_columns.sql +++ b/testgen/template/profiling/contingency_columns.sql @@ -1,7 +1,7 @@ -- All codes / categories with few distinct 
values SELECT schema_name, table_name, STRING_AGG(column_name, ',' ORDER BY column_name) as contingency_columns FROM profile_results p - WHERE profile_run_id = '{PROFILE_RUN_ID}'::UUID + WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IN ('Code', 'Category') - AND distinct_value_ct BETWEEN 2 AND {CONTINGENCY_MAX_VALUES} + AND distinct_value_ct BETWEEN 2 AND :CONTINGENCY_MAX_VALUES GROUP BY schema_name, table_name; diff --git a/testgen/template/profiling/datatype_suggestions.sql b/testgen/template/profiling/datatype_suggestions.sql index d0af2a48..6eff3b95 100644 --- a/testgen/template/profiling/datatype_suggestions.sql +++ b/testgen/template/profiling/datatype_suggestions.sql @@ -144,8 +144,8 @@ FROM ( -- pull out declared size if present, else NULL CAST(substring(column_type FROM '\((\d+)\)') AS int) AS current_size FROM profile_results - WHERE project_code = '{PROJECT_CODE}' - AND schema_name = '{DATA_SCHEMA}' - AND run_date = '{RUN_DATE}' + WHERE project_code = :PROJECT_CODE + AND schema_name = :DATA_SCHEMA + AND run_date = :RUN_DATE ) AS base WHERE pr.id = base.id; diff --git a/testgen/template/profiling/functional_datatype.sql b/testgen/template/profiling/functional_datatype.sql index da853219..af610dbe 100644 --- a/testgen/template/profiling/functional_datatype.sql +++ b/testgen/template/profiling/functional_datatype.sql @@ -2,7 +2,7 @@ UPDATE profile_results SET functional_data_type = NULL, functional_table_type = NULL -WHERE profile_run_id = '{PROFILE_RUN_ID}'; +WHERE profile_run_id = :PROFILE_RUN_ID; -- 1. Assign CONSTANT and TBD - this is the first step of elimination @@ -18,7 +18,7 @@ SET functional_data_type = THEN 'TBD (Not enough data)' ELSE functional_data_type END -WHERE profile_run_id = '{PROFILE_RUN_ID}'; +WHERE profile_run_id = :PROFILE_RUN_ID; UPDATE profile_results @@ -28,21 +28,21 @@ SET functional_data_type = -- this tells us how much actual values we have filled in; threshold -> if there is only 1 value and it's 75% of the records -> then it's a constant THEN 'Constant' ELSE functional_data_type END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL; -- 1A. Assign ID's based on masks UPDATE profile_results SET functional_data_type = 'ID-SK' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL - AND column_name ILIKE '{PROFILE_SK_COLUMN_MASK}'; + AND column_name ILIKE :PROFILE_SK_COLUMN_MASK; UPDATE profile_results SET functional_data_type = 'ID' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL - AND column_name ILIKE '{PROFILE_ID_COLUMN_MASK}'; + AND column_name ILIKE :PROFILE_ID_COLUMN_MASK; -- 2. 
Assign DATE /* @@ -107,14 +107,14 @@ SET functional_data_type = THEN 'DateTime Stamp' ELSE functional_data_type END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND (general_type = 'D' OR (value_ct = date_ct + zero_length_ct AND value_ct > 0)); -- Character Date UPDATE profile_results SET functional_data_type = 'Date Stamp' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND distinct_pattern_ct = 1 AND min_text >= '1900' AND max_text <= '2200' @@ -123,7 +123,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' -- Character Timestamp UPDATE profile_results SET functional_data_type = 'DateTime Stamp' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND distinct_pattern_ct = 1 AND (TRIM(SPLIT_PART(top_patterns, '|', 2)) = 'NNNN-NN-NN NN:NN:NN' @@ -132,7 +132,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' -- Process Timestamp UPDATE profile_results SET functional_data_type = 'Process ' || functional_data_type -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND general_type IN ('A', 'D') AND ( column_name ~ '^(last_?|system_?|)(add|create|insert|inrt|update|updt|mod|modif|modf|del|delete|refresh)(.{0,3}d?_?(time|tm|date|day|dt|stamp|timestamp|datestamp))$' OR column_name ~ '^(last_?|)(change|chg|update|updt|mod|modify|modf|modified|refresh|refreshed)$' ); @@ -141,7 +141,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' -- Assign PERIODS: Period Year, Period Qtr, Period Month, Period Week, Period DOW UPDATE profile_results SET functional_data_type = 'Period Year' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND (column_name ILIKE '%year%' OR column_name ILIKE '%yr%') AND ( (min_value >= 1900 @@ -156,7 +156,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period Quarter' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND (column_name ILIKE '%qtr%' or column_name ILIKE '%quarter%') AND ( (min_value = 1 AND max_value = 4 @@ -169,7 +169,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period Year-Mon' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND column_name ILIKE '%mo%' AND min_text >= '1900' AND max_text <= '2200' @@ -182,7 +182,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period Month' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND column_name ILIKE '%mo%' AND ( @@ -194,7 +194,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period Mon-NN' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND min_text ~ '(?i)^(jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)[\s-]?\d{1,2}$' AND max_text ~ '(?i)^(jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)[\s-]?\d{1,2}$' @@ -203,7 +203,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period Week' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND ( column_name ILIKE '%wk%' 
OR column_name ILIKE '%week%' ) AND distinct_value_ct BETWEEN 10 AND 53 @@ -212,7 +212,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period DOW' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND ( column_name ILIKE '%day%' OR column_name ILIKE '%dow%') AND distinct_value_ct = 7 @@ -225,7 +225,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period Month' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND (min_date = DATE_TRUNC('month', min_date)::DATE AND max_date = DATE_TRUNC('month', max_date)::DATE OR @@ -237,7 +237,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Period Week' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND ( EXTRACT(DOW FROM min_date) IN (0, 1, 5, 6) AND EXTRACT(DOW FROM min_date) = EXTRACT(DOW FROM max_date) ) @@ -254,7 +254,7 @@ SET functional_data_type = AND NOT functional_data_type ILIKE 'Period%' THEN 'Period' ELSE functional_data_type END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND LOWER(column_name) IN ('period', 'month', 'week'); -- 3. Assign ADDRESS RELATED FIELDS, PHONE AND EMAIL @@ -286,7 +286,7 @@ SET functional_data_type = THEN 'State' ELSE functional_data_type END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL; -- Update City based on position of State and Zip @@ -303,7 +303,7 @@ INNER JOIN profile_results s AND c.table_name = s.table_name AND c.position + 1 = s.position AND 'State' = s.functional_data_type) - WHERE c.profile_run_id = '{PROFILE_RUN_ID}' + WHERE c.profile_run_id = :PROFILE_RUN_ID AND LOWER(c.column_name) SIMILAR TO '%c(|i)ty%' AND c.functional_data_type NOT IN ('State', 'Zip') AND profile_results.id = c.id; @@ -311,7 +311,7 @@ INNER JOIN profile_results s -- Assign Name UPDATE profile_results SET functional_data_type = 'Person Full Name' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND avg_length <= 20 AND avg_embedded_spaces BETWEEN 0.9 AND 2.0 @@ -321,7 +321,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' -- Assign First Name UPDATE profile_results SET functional_data_type = 'Person Given Name' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND avg_length <= 8 AND avg_embedded_spaces < 0.2 AND (LOWER(column_name) SIMILAR TO '%f(|i)rst(_| |)n(|a)m%%' @@ -331,7 +331,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' -- Assign Last Name UPDATE profile_results SET functional_data_type = 'Person Last Name' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND avg_length BETWEEN 5 and 8 AND avg_embedded_spaces < 0.2 AND (LOWER(column_name) SIMILAR TO '%l(|a)st(_| |)n(|a)m%' @@ -340,7 +340,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Entity Name' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND general_type = 'A' AND column_name ~ '(acct|account|affiliation|branch|business|co|comp|company|corp|corporate|cust|customer|distributor|employer|entity|firm|franchise|hco|org|organization|site|supplier|vendor|hospital|practice|clinic)(_| |)(name|nm)$'; @@ -348,13 +348,13 @@ WHERE profile_run_id = 
'{PROFILE_RUN_ID}' -- Process User: tracks data process UPDATE profile_results SET functional_data_type = 'Process User' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND column_name ~ '^(last_?|)(create|update|modif|delete|refresh)(.*?(by|id|name|nm|user|usr))$'; -- System User: SW system UPDATE profile_results SET functional_data_type = 'System User' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND column_name ~ '(user|usr)_?(name|nm)?$'; @@ -383,7 +383,7 @@ SET functional_data_type = THEN 'Boolean' ELSE functional_data_type END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL; @@ -434,7 +434,7 @@ SET functional_data_type = END ELSE functional_data_type END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL AND general_type='A' AND LOWER(datatype_suggestion) SIMILAR TO '(%varchar%)'; @@ -455,7 +455,7 @@ SET functional_data_type = AND fn_charcount(top_patterns, 'A') > 0 THEN 'Flag' END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL; @@ -502,7 +502,7 @@ SET functional_data_type = ELSE 'UNKNOWN' END -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IS NULL; -- Assign City @@ -514,7 +514,7 @@ UPDATE profile_results ON p.profile_run_id = pn.profile_run_id AND p.table_name = pn.table_name AND p.position = pn.position - 1 - WHERE p.profile_run_id = '{PROFILE_RUN_ID}' + WHERE p.profile_run_id = :PROFILE_RUN_ID AND p.includes_digit_ct::FLOAT/NULLIF(p.value_ct,0)::FLOAT < 0.05 AND p.numeric_ct::FLOAT/NULLIF(p.value_ct,0)::FLOAT < 0.05 AND p.date_ct::FLOAT/NULLIF(p.value_ct,0)::FLOAT < 0.05 @@ -534,21 +534,21 @@ SET functional_data_type = CASE AND ROUND(100.0 * distinct_value_ct::FLOAT/NULLIF(value_ct, 0)) < 75 THEN 'ID-Group' ELSE functional_data_type END - WHERE profile_run_id = '{PROFILE_RUN_ID}' + WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type = 'ID'; -- Assign 'ID-Unique' functional data type to the columns that are identity columns UPDATE profile_results SET functional_data_type = 'ID-Unique' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IN ('ID', 'ID-Secondary') AND record_ct = distinct_value_ct AND record_ct > 50; UPDATE profile_results SET functional_data_type = 'ID-Unique-SK' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type = 'ID-SK' AND record_ct = distinct_value_ct AND record_ct > 50; @@ -560,8 +560,8 @@ SET functional_data_type = 'ID-FK' FROM (Select table_groups_id, table_name, column_name from profile_results where functional_data_type IN ('ID-Unique', 'ID-Unique-SK') - and profile_run_id = '{PROFILE_RUN_ID}') ui -WHERE profile_results.profile_run_id = '{PROFILE_RUN_ID}' + and profile_run_id = :PROFILE_RUN_ID) ui +WHERE profile_results.profile_run_id = :PROFILE_RUN_ID and profile_results.column_name = ui.column_name and profile_results.table_groups_id = ui.table_groups_id and profile_results.table_name <> ui.table_name @@ -571,7 +571,7 @@ WHERE profile_results.profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Measurement Pct' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type IN ('Measurement', 'Measurement Discrete', 'UNKNOWN') 
AND general_type = 'N' AND min_value >= -200 @@ -580,7 +580,7 @@ WHERE profile_run_id = '{PROFILE_RUN_ID}' UPDATE profile_results SET functional_data_type = 'Measurement Pct' -WHERE profile_run_id = '{PROFILE_RUN_ID}' +WHERE profile_run_id = :PROFILE_RUN_ID AND functional_data_type = 'Code' AND distinct_pattern_ct between 1 and 3 AND value_ct = includes_digit_ct diff --git a/testgen/template/profiling/functional_tabletype_stage.sql b/testgen/template/profiling/functional_tabletype_stage.sql index c843a3da..a19576aa 100644 --- a/testgen/template/profiling/functional_tabletype_stage.sql +++ b/testgen/template/profiling/functional_tabletype_stage.sql @@ -9,13 +9,13 @@ WITH tablesrank AS FROM profile_results p INNER JOIN (SELECT DISTINCT schema_name, table_name FROM profile_results - WHERE project_code = '{PROJECT_CODE}' - AND schema_name = '{DATA_SCHEMA}' - AND run_date = '{RUN_DATE}') pt + WHERE project_code = :PROJECT_CODE + AND schema_name = :DATA_SCHEMA + AND run_date = :RUN_DATE) pt ON (p.schema_name = pt.schema_name AND p.table_name = pt.table_name) - WHERE p.project_code = '{PROJECT_CODE}' - AND p.schema_name = '{DATA_SCHEMA}' + WHERE p.project_code = :PROJECT_CODE + AND p.schema_name = :DATA_SCHEMA ORDER BY p.schema_name, p.table_name, p.run_date DESC), tablescount AS (SELECT * @@ -59,6 +59,6 @@ WITH tablesrank AS ORDER BY project_code, schema_name, table_name) INSERT INTO stg_functional_table_updates (project_code, schema_name, run_date, table_name, table_period, table_type) -SELECT project_code, schema_name, '{RUN_DATE}' as run_date, +SELECT project_code, schema_name, :RUN_DATE as run_date, table_name, table_period, table_type FROM tablestat; diff --git a/testgen/template/profiling/functional_tabletype_update.sql b/testgen/template/profiling/functional_tabletype_update.sql index a81cbdcb..3ae8c595 100644 --- a/testgen/template/profiling/functional_tabletype_update.sql +++ b/testgen/template/profiling/functional_tabletype_update.sql @@ -5,4 +5,4 @@ WHERE s.project_code = profile_results.project_code AND s.schema_name = profile_results.schema_name AND s.table_name = profile_results.table_name AND s.run_date = profile_results.run_date - AND s.run_date = '{RUN_DATE}'; + AND s.run_date = :RUN_DATE; diff --git a/testgen/template/profiling/pii_flag.sql b/testgen/template/profiling/pii_flag.sql index 587d187e..c12f2911 100644 --- a/testgen/template/profiling/pii_flag.sql +++ b/testgen/template/profiling/pii_flag.sql @@ -65,7 +65,7 @@ WITH screen END AS pii_flag FROM profile_results p - WHERE profile_run_id = '{PROFILE_RUN_ID}' + WHERE profile_run_id = :PROFILE_RUN_ID AND general_type = 'A' ) UPDATE profile_results SET pii_flag = screen.pii_flag @@ -77,7 +77,7 @@ UPDATE profile_results WITH table_pii_counts AS ( SELECT table_name, COUNT(pii_flag) AS pii_ct FROM profile_results - WHERE profile_run_id = '{PROFILE_RUN_ID}' + WHERE profile_run_id = :PROFILE_RUN_ID GROUP BY table_name ), screen AS ( SELECT id AS profile_results_id, @@ -122,7 +122,7 @@ UPDATE profile_results FROM profile_results p INNER JOIN table_pii_counts t ON (p.table_name = t.table_name) - WHERE p.profile_run_id = '{PROFILE_RUN_ID}' + WHERE p.profile_run_id = :PROFILE_RUN_ID AND p.general_type = 'A' AND p.pii_flag IS NULL AND t.pii_ct > 1 ) diff --git a/testgen/template/profiling/profile_anomalies_screen_column.sql b/testgen/template/profiling/profile_anomalies_screen_column.sql index cb9c4c1e..e7c2b5dc 100644 --- a/testgen/template/profiling/profile_anomalies_screen_column.sql +++ 
b/testgen/template/profiling/profile_anomalies_screen_column.sql @@ -4,7 +4,7 @@ INSERT INTO profile_anomaly_results SELECT p.project_code, p.table_groups_id, p.profile_run_id, - '{ANOMALY_ID}' as anomaly_id, + :ANOMALY_ID as anomaly_id, p.schema_name, p.table_name, p.column_name, @@ -16,7 +16,7 @@ LEFT JOIN v_inactive_anomalies i AND p.schema_name = i.schema_name AND p.table_name = i.table_name AND p.column_name = i.column_name - AND '{ANOMALY_ID}' = i.anomaly_id) - WHERE p.profile_run_id = '{PROFILE_RUN_ID}'::UUID + AND :ANOMALY_ID = i.anomaly_id) + WHERE p.profile_run_id = :PROFILE_RUN_ID AND i.anomaly_id IS NULL AND ({ANOMALY_CRITERIA}); diff --git a/testgen/template/profiling/profile_anomalies_screen_multi_column.sql b/testgen/template/profiling/profile_anomalies_screen_multi_column.sql index 6451eafd..af315502 100644 --- a/testgen/template/profiling/profile_anomalies_screen_multi_column.sql +++ b/testgen/template/profiling/profile_anomalies_screen_multi_column.sql @@ -13,7 +13,7 @@ WITH mults AS ( SELECT p.project_code, STRING_AGG(table_name, ', ' order by table_name) as table_list, MAX(RIGHT(REPEAT('0', 20) || SPLIT_PART(p.top_patterns, '|', 1), 20) || '|' || SPLIT_PART(p.top_patterns, '|', 2) )as very_top_pattern FROM profile_results p - WHERE p.profile_run_id = '{PROFILE_RUN_ID}'::UUID + WHERE p.profile_run_id = :PROFILE_RUN_ID GROUP BY p.project_code, p.table_groups_id, schema_name, p.column_name HAVING COUNT(*) > 1 ), subset AS @@ -21,7 +21,7 @@ WITH mults AS ( SELECT p.project_code, SELECT p.project_code, p.table_groups_id, p.profile_run_id, - '{ANOMALY_ID}' as anomaly_id, + :ANOMALY_ID as anomaly_id, p.schema_name, p.table_name, p.column_name, @@ -41,8 +41,8 @@ WITH mults AS ( SELECT p.project_code, AND p.schema_name = i.schema_name AND p.table_name = i.table_name AND p.column_name = i.column_name - AND '{ANOMALY_ID}' = i.anomaly_id) - WHERE p.profile_run_id = '{PROFILE_RUN_ID}'::UUID + AND :ANOMALY_ID = i.anomaly_id) + WHERE p.profile_run_id = :PROFILE_RUN_ID AND i.anomaly_id IS NULL AND ({ANOMALY_CRITERIA}) ) diff --git a/testgen/template/profiling/profile_anomalies_screen_table.sql b/testgen/template/profiling/profile_anomalies_screen_table.sql index 4877369b..646d2a00 100644 --- a/testgen/template/profiling/profile_anomalies_screen_table.sql +++ b/testgen/template/profiling/profile_anomalies_screen_table.sql @@ -4,7 +4,7 @@ INSERT INTO profile_anomaly_results SELECT p.project_code, p.table_groups_id, p.profile_run_id, - '{ANOMALY_ID}' as anomaly_id, + :ANOMALY_ID as anomaly_id, p.schema_name, p.table_name, '(Table)' as column_name, @@ -15,8 +15,8 @@ LEFT JOIN v_inactive_anomalies i ON (p.table_groups_id = i.table_groups_id AND p.schema_name = i.schema_name AND p.table_name = i.table_name - AND '{ANOMALY_ID}' = i.anomaly_id) - WHERE p.profile_run_id = '{PROFILE_RUN_ID}'::UUID + AND :ANOMALY_ID = i.anomaly_id) + WHERE p.profile_run_id = :PROFILE_RUN_ID GROUP BY p.project_code, p.table_groups_id, p.profile_run_id, p.schema_name, p.table_name HAVING {ANOMALY_CRITERIA}; diff --git a/testgen/template/profiling/profile_anomalies_screen_table_dates.sql b/testgen/template/profiling/profile_anomalies_screen_table_dates.sql index 581f2a3d..f4ba10f6 100644 --- a/testgen/template/profiling/profile_anomalies_screen_table_dates.sql +++ b/testgen/template/profiling/profile_anomalies_screen_table_dates.sql @@ -4,7 +4,7 @@ INSERT INTO profile_anomaly_results SELECT p.project_code, p.table_groups_id, p.profile_run_id, - '{ANOMALY_ID}' as anomaly_id, + :ANOMALY_ID as anomaly_id, 
p.schema_name, p.table_name, CASE @@ -21,8 +21,8 @@ LEFT JOIN v_inactive_anomalies i ON (p.table_groups_id = i.table_groups_id AND p.schema_name = i.schema_name AND p.table_name = i.table_name - AND '{ANOMALY_ID}' = i.anomaly_id) - WHERE p.profile_run_id = '{PROFILE_RUN_ID}'::UUID + AND :ANOMALY_ID = i.anomaly_id) + WHERE p.profile_run_id = :PROFILE_RUN_ID AND i.anomaly_id IS NULL AND p.general_type = 'D' GROUP BY p.project_code, p.table_groups_id, p.profile_run_id, diff --git a/testgen/template/profiling/profile_anomalies_screen_variants.sql b/testgen/template/profiling/profile_anomalies_screen_variants.sql index 266e73ee..f3e603e0 100644 --- a/testgen/template/profiling/profile_anomalies_screen_variants.sql +++ b/testgen/template/profiling/profile_anomalies_screen_variants.sql @@ -20,8 +20,8 @@ WITH all_matches AND p.schema_name = i.schema_name AND p.table_name = i.table_name AND p.column_name = i.column_name - AND '{ANOMALY_ID}' = i.anomaly_id) - WHERE p.profile_run_id = '{PROFILE_RUN_ID}'::UUID + AND :ANOMALY_ID = i.anomaly_id) + WHERE p.profile_run_id = :PROFILE_RUN_ID AND ({ANOMALY_CRITERIA}) AND p.top_freq_values > '' AND i.anomaly_id IS NULL @@ -34,7 +34,7 @@ WITH all_matches p.column_name, p.column_type ) SELECT project_code, table_groups_id, profile_run_id, - '{ANOMALY_ID}' AS anomaly_id, + :ANOMALY_ID AS anomaly_id, schema_name, table_name, column_name, column_type, {DETAIL_EXPRESSION} AS detail FROM all_matches; diff --git a/testgen/template/profiling/profile_anomaly_scoring.sql b/testgen/template/profiling/profile_anomaly_scoring.sql index 9511c125..e4c34b44 100644 --- a/testgen/template/profiling/profile_anomaly_scoring.sql +++ b/testgen/template/profiling/profile_anomaly_scoring.sql @@ -1,10 +1,10 @@ UPDATE profile_anomaly_results r - SET dq_prevalence = ({PREV_FORMULA}) * {RISK} + SET dq_prevalence = ({PREV_FORMULA}) * :RISK FROM profile_anomaly_results r2 INNER JOIN profile_results p ON (r2.profile_run_id = p.profile_run_id AND r2.table_name = p.table_name AND r2.column_name = p.column_name) - WHERE r.profile_run_id = '{PROFILE_RUN_ID}'::UUID - AND r2.anomaly_id = '{ANOMALY_ID}' + WHERE r.profile_run_id = :PROFILE_RUN_ID + AND r2.anomaly_id = :ANOMALY_ID AND r.id = r2.id; \ No newline at end of file diff --git a/testgen/template/profiling/project_get_table_sample_count.sql b/testgen/template/profiling/project_get_table_sample_count.sql deleted file mode 100644 index d80a22f9..00000000 --- a/testgen/template/profiling/project_get_table_sample_count.sql +++ /dev/null @@ -1,22 +0,0 @@ -SELECT '{SAMPLING_TABLE}' as schema_table, - CASE - WHEN count(*) <= {PROFILE_SAMPLE_MIN_COUNT} - THEN -1 - ELSE - CASE - WHEN ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0, 0) > {PROFILE_SAMPLE_MIN_COUNT} - THEN LEAST(999000, ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0, 0)) - ELSE {PROFILE_SAMPLE_MIN_COUNT} - END - END as sample_count, - CASE - WHEN count(*) <= {PROFILE_SAMPLE_MIN_COUNT} - THEN 1 - ELSE (CAST(COUNT(*) as FLOAT) - / CASE - WHEN ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0, 0) > {PROFILE_SAMPLE_MIN_COUNT} - THEN LEAST(999000, ROUND(CAST({PROFILE_SAMPLE_PERCENT} as FLOAT) * CAST(COUNT(*) as FLOAT) / 100.0, 0)) - ELSE {PROFILE_SAMPLE_MIN_COUNT} - END ) - END as sample_ratio -from {SAMPLING_TABLE}; diff --git a/testgen/template/profiling/project_profile_run_record_insert.sql b/testgen/template/profiling/project_profile_run_record_insert.sql index 04b902f3..e1c379fc 100644 --- 
a/testgen/template/profiling/project_profile_run_record_insert.sql +++ b/testgen/template/profiling/project_profile_run_record_insert.sql @@ -1,8 +1,8 @@ INSERT INTO profiling_runs (id, project_code, connection_id, table_groups_id, profiling_starttime, process_id) -(SELECT '{PROFILE_RUN_ID}' :: UUID as id, - '{PROJECT_CODE}' as project_code, - {CONNECTION_ID} as connection_id, - '{TABLE_GROUPS_ID}' :: UUID as table_groups_id, - '{RUN_DATE}' as profiling_starttime, - '{PROCESS_ID}' as process_id +(SELECT :PROFILE_RUN_ID as id, + :PROJECT_CODE as project_code, + :CONNECTION_ID as connection_id, + :TABLE_GROUPS_ID as table_groups_id, + :RUN_DATE as profiling_starttime, + :PROCESS_ID as process_id ); diff --git a/testgen/template/profiling/project_profile_run_record_update.sql b/testgen/template/profiling/project_profile_run_record_update.sql index fc64e011..e6c7b0de 100644 --- a/testgen/template/profiling/project_profile_run_record_update.sql +++ b/testgen/template/profiling/project_profile_run_record_update.sql @@ -1,5 +1,5 @@ UPDATE profiling_runs -SET status = CASE WHEN length('{EXCEPTION_MESSAGE}') = 0 then 'Complete' else 'Error' end, - profiling_endtime = '{NOW}', - log_message = '{EXCEPTION_MESSAGE}' -where id = '{PROFILE_RUN_ID}' :: UUID; +SET status = CASE WHEN length(:EXCEPTION_MESSAGE) = 0 then 'Complete' else 'Error' end, + profiling_endtime = :NOW_TIMESTAMP, + log_message = :EXCEPTION_MESSAGE +where id = :PROFILE_RUN_ID; diff --git a/testgen/template/profiling/project_update_profile_results_to_estimates.sql b/testgen/template/profiling/project_update_profile_results_to_estimates.sql index 640829cf..278302d0 100644 --- a/testgen/template/profiling/project_update_profile_results_to_estimates.sql +++ b/testgen/template/profiling/project_update_profile_results_to_estimates.sql @@ -4,28 +4,28 @@ -- in a random sample. 
update profile_results -set sample_ratio = {PROFILE_SAMPLE_RATIO}, - record_ct = ROUND(record_ct * {PROFILE_SAMPLE_RATIO}, 0), - value_ct = ROUND(value_ct * {PROFILE_SAMPLE_RATIO}, 0), - -- distinct_value_ct = ROUND(record_ct * {PROFILE_SAMPLE_RATIO} *(distinct_value_ct::numeric/record_ct::numeric), 0), - null_value_ct = ROUND(null_value_ct * {PROFILE_SAMPLE_RATIO}, 0), - zero_value_ct = ROUND(zero_value_ct * {PROFILE_SAMPLE_RATIO}, 0), - lead_space_ct = ROUND(lead_space_ct * {PROFILE_SAMPLE_RATIO}, 0), - embedded_space_ct = ROUND(embedded_space_ct * {PROFILE_SAMPLE_RATIO}, 0), - includes_digit_ct = ROUND(includes_digit_ct * {PROFILE_SAMPLE_RATIO}, 0), - filled_value_ct = ROUND(filled_value_ct * {PROFILE_SAMPLE_RATIO}, 0), - numeric_ct = ROUND(numeric_ct * {PROFILE_SAMPLE_RATIO}, 0), - date_ct = ROUND(date_ct * {PROFILE_SAMPLE_RATIO}, 0), - before_1yr_date_ct = ROUND(before_1yr_date_ct * {PROFILE_SAMPLE_RATIO}, 0), - before_5yr_date_ct = ROUND(before_5yr_date_ct * {PROFILE_SAMPLE_RATIO}, 0), - before_20yr_date_ct = ROUND(before_20yr_date_ct * {PROFILE_SAMPLE_RATIO}, 0), - within_1yr_date_ct = ROUND(within_1yr_date_ct * {PROFILE_SAMPLE_RATIO}, 0), - within_1mo_date_ct = ROUND(within_1mo_date_ct * {PROFILE_SAMPLE_RATIO}, 0), - future_date_ct = ROUND(future_date_ct * {PROFILE_SAMPLE_RATIO}, 0), - boolean_true_ct = ROUND(boolean_true_ct * {PROFILE_SAMPLE_RATIO}, 0) -where profile_run_id = '{PROFILE_RUN_ID}' -and schema_name = split_part('{SAMPLING_TABLE}', '.', 1) -and table_name = split_part('{SAMPLING_TABLE}', '.', 2) +set sample_ratio = :PROFILE_SAMPLE_RATIO, + record_ct = ROUND(record_ct * :PROFILE_SAMPLE_RATIO, 0), + value_ct = ROUND(value_ct * :PROFILE_SAMPLE_RATIO, 0), + -- distinct_value_ct = ROUND(record_ct * :PROFILE_SAMPLE_RATIO *(distinct_value_ct::numeric/record_ct::numeric), 0), + null_value_ct = ROUND(null_value_ct * :PROFILE_SAMPLE_RATIO, 0), + zero_value_ct = ROUND(zero_value_ct * :PROFILE_SAMPLE_RATIO, 0), + lead_space_ct = ROUND(lead_space_ct * :PROFILE_SAMPLE_RATIO, 0), + embedded_space_ct = ROUND(embedded_space_ct * :PROFILE_SAMPLE_RATIO, 0), + includes_digit_ct = ROUND(includes_digit_ct * :PROFILE_SAMPLE_RATIO, 0), + filled_value_ct = ROUND(filled_value_ct * :PROFILE_SAMPLE_RATIO, 0), + numeric_ct = ROUND(numeric_ct * :PROFILE_SAMPLE_RATIO, 0), + date_ct = ROUND(date_ct * :PROFILE_SAMPLE_RATIO, 0), + before_1yr_date_ct = ROUND(before_1yr_date_ct * :PROFILE_SAMPLE_RATIO, 0), + before_5yr_date_ct = ROUND(before_5yr_date_ct * :PROFILE_SAMPLE_RATIO, 0), + before_20yr_date_ct = ROUND(before_20yr_date_ct * :PROFILE_SAMPLE_RATIO, 0), + within_1yr_date_ct = ROUND(within_1yr_date_ct * :PROFILE_SAMPLE_RATIO, 0), + within_1mo_date_ct = ROUND(within_1mo_date_ct * :PROFILE_SAMPLE_RATIO, 0), + future_date_ct = ROUND(future_date_ct * :PROFILE_SAMPLE_RATIO, 0), + boolean_true_ct = ROUND(boolean_true_ct * :PROFILE_SAMPLE_RATIO, 0) +where profile_run_id = :PROFILE_RUN_ID +and schema_name = split_part(:SAMPLING_TABLE, '.', 1) +and table_name = split_part(:SAMPLING_TABLE, '.', 2) and sample_ratio IS NULL; diff --git a/testgen/template/profiling/refresh_anomalies.sql b/testgen/template/profiling/refresh_anomalies.sql index 3aca49db..9159fbf5 100644 --- a/testgen/template/profiling/refresh_anomalies.sql +++ b/testgen/template/profiling/refresh_anomalies.sql @@ -6,7 +6,7 @@ WITH anomalies COUNT(DISTINCT schema_name || '.' || table_name) as anomaly_table_ct, COUNT(DISTINCT schema_name || '.' || table_name || '.' 
|| column_name) as anomaly_column_ct FROM profile_anomaly_results - WHERE profile_run_id = '{PROFILE_RUN_ID}'::UUID + WHERE profile_run_id = :PROFILE_RUN_ID GROUP BY profile_run_id ), profiles AS ( SELECT r.id as profile_run_id, @@ -15,7 +15,7 @@ profiles FROM profiling_runs r INNER JOIN profile_results p ON r.id = p.profile_run_id - WHERE r.id = '{PROFILE_RUN_ID}'::UUID + WHERE r.id = :PROFILE_RUN_ID GROUP BY r.id ), stats AS ( SELECT p.profile_run_id, table_ct, column_ct, diff --git a/testgen/template/profiling/secondary_profiling_columns.sql b/testgen/template/profiling/secondary_profiling_columns.sql index 2c56b92c..fb1b8cc4 100644 --- a/testgen/template/profiling/secondary_profiling_columns.sql +++ b/testgen/template/profiling/secondary_profiling_columns.sql @@ -4,7 +4,7 @@ SELECT schema_name, table_name, column_name FROM profile_results p - WHERE p.profile_run_id = '{PROFILE_RUN_ID}' + WHERE p.profile_run_id = :PROFILE_RUN_ID AND p.top_freq_values IS NULL AND p.general_type = 'A' AND p.distinct_value_ct BETWEEN 2 and 70 diff --git a/testgen/template/profiling/secondary_profiling_delete.sql b/testgen/template/profiling/secondary_profiling_delete.sql index 13b6519f..5f7f3253 100644 --- a/testgen/template/profiling/secondary_profiling_delete.sql +++ b/testgen/template/profiling/secondary_profiling_delete.sql @@ -1,4 +1,4 @@ DELETE FROM stg_secondary_profile_updates s - WHERE s.project_code = '{PROJECT_CODE}' - AND s.schema_name = '{DATA_SCHEMA}' - AND s.run_date = '{RUN_DATE}'; + WHERE s.project_code = :PROJECT_CODE + AND s.schema_name = :DATA_SCHEMA + AND s.run_date = :RUN_DATE; diff --git a/testgen/template/profiling/secondary_profiling_update.sql b/testgen/template/profiling/secondary_profiling_update.sql index 342a301a..b16309b9 100644 --- a/testgen/template/profiling/secondary_profiling_update.sql +++ b/testgen/template/profiling/secondary_profiling_update.sql @@ -13,6 +13,6 @@ WHERE p.project_code = profile_results.project_code AND p.run_date = profile_results.run_date AND p.table_name = profile_results.table_name AND p.column_name = profile_results.column_name - AND p.project_code = '{PROJECT_CODE}' - AND p.schema_name = '{DATA_SCHEMA}' - AND p.run_date = '{RUN_DATE}'; + AND p.project_code = :PROJECT_CODE + AND p.schema_name = :DATA_SCHEMA + AND p.run_date = :RUN_DATE; diff --git a/testgen/template/quick_start/add_cat_tests.sql b/testgen/template/quick_start/add_cat_tests.sql new file mode 100644 index 00000000..7dc931cf --- /dev/null +++ b/testgen/template/quick_start/add_cat_tests.sql @@ -0,0 +1,12 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +INSERT INTO test_definitions + (table_groups_id, last_manual_update, schema_name, match_schema_name, test_type, + test_suite_id, table_name, column_name, skip_errors, threshold_value, subset_condition, + groupby_names, having_condition, match_table_name, match_column_names, match_subset_condition, + match_groupby_names, match_having_condition, test_active, severity, watch_level, lock_refresh) +VALUES ('0ea85e17-acbe-47fe-8394-9970725ad37d', '2024-06-07 02:45:27.102847', :PROJECT_SCHEMA, :PROJECT_SCHEMA, + 'Aggregate_Balance', (SELECT id FROM test_suites WHERE test_suite = 'default-suite-1'), + 'f_ebike_sales', 'SUM(total_amount)', 0, '0', 'sale_date <= (DATE_TRUNC(''month'', CURRENT_DATE) - (interval ''3 month'' - interval ''{ITERATION_NUMBER} month'') - interval ''1 day'')', + 'product_id, sale_date_year, sale_date_month', null, 'tmp_f_ebike_sales_last_month', 'SUM(total_amount)', null, 'product_id, sale_date_year, sale_date_month', + null, 
'Y', null, 'WARN', 'N'); diff --git a/testgen/template/quick_start/recreate_target_data_schema.sql b/testgen/template/quick_start/recreate_target_data_schema.sql index 1599f865..bdb9ecac 100644 --- a/testgen/template/quick_start/recreate_target_data_schema.sql +++ b/testgen/template/quick_start/recreate_target_data_schema.sql @@ -153,6 +153,8 @@ DROP TABLE IF EXISTS f_ebike_sales CASCADE; CREATE TABLE f_ebike_sales ( sale_id INTEGER, sale_date DATE, + sale_date_year INTEGER GENERATED ALWAYS AS (EXTRACT(YEAR FROM sale_date)) STORED, + sale_date_month INTEGER GENERATED ALWAYS AS (EXTRACT(MONTH FROM sale_date)) STORED, customer_id INT, supplier_id INT, product_id INT, diff --git a/testgen/template/quick_start/update_cat_tests.sql b/testgen/template/quick_start/update_cat_tests.sql new file mode 100644 index 00000000..29c16856 --- /dev/null +++ b/testgen/template/quick_start/update_cat_tests.sql @@ -0,0 +1,8 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +UPDATE test_definitions +SET subset_condition = CASE WHEN {ITERATION_NUMBER} <> 3 THEN 'sale_date <= (DATE_TRUNC(''month'', CURRENT_DATE) - (interval ''3 month'' - interval ''{ITERATION_NUMBER} month'') - interval ''1 day'')' + ELSE 'sale_date <= (DATE_TRUNC(''month'', CURRENT_DATE) - interval ''1 day'')' + END +WHERE test_type='Aggregate_Balance' AND table_name='f_ebike_sales' AND column_name='SUM(total_amount)' +AND test_suite_id = (SELECT id FROM test_suites WHERE test_suite = 'default-suite-1'); diff --git a/testgen/template/quick_start/update_target_data.sql b/testgen/template/quick_start/update_target_data.sql index 76ff6f2f..1fac85f3 100644 --- a/testgen/template/quick_start/update_target_data.sql +++ b/testgen/template/quick_start/update_target_data.sql @@ -16,8 +16,12 @@ SELECT t.sale_id, FROM tmp_ebike_sales t LEFT JOIN f_ebike_sales fes ON t.sale_id = fes.sale_id WHERE fes.sale_id IS NULL - AND t.sale_date <= '{MAX_DATE}'; + AND t.sale_date <= :MAX_DATE; +DROP TABLE IF EXISTS tmp_f_ebike_sales_last_month; + +CREATE TABLE tmp_f_ebike_sales_last_month AS +SELECT * from f_ebike_sales f WHERE f.sale_date <= (DATE_TRUNC('month', CURRENT_DATE) - (interval '3 month' - interval '{ITERATION_NUMBER} month') - interval '1 day'); TRUNCATE TABLE d_ebike_customers; @@ -48,10 +52,10 @@ SELECT t.customer_id, t.credit_card, t.last_contact - (55 - (SELECT (CURRENT_DATE - MAX(last_contact)) FROM tmp_d_ebike_customers)) FROM tmp_d_ebike_customers t -WHERE t.customer_id <= '{MAX_CUSTOMER_ID_SEQ}'; +WHERE t.customer_id <= :MAX_CUSTOMER_ID_SEQ; UPDATE d_ebike_customers - SET last_contact = CASE WHEN '{ITERATION_NUMBER}' = 1 AND + SET last_contact = CASE WHEN :ITERATION_NUMBER = 1 AND current_date - last_contact <= 60 THEN last_contact - (62 - (current_date - last_contact)) ELSE last_contact END; @@ -75,7 +79,7 @@ SELECT t.product_id, t.price, t.max_discount FROM tmp_d_ebike_products t -WHERE product_id <= '{MAX_PRODUCT_ID_SEQ}'; +WHERE product_id <= :MAX_PRODUCT_ID_SEQ; TRUNCATE TABLE d_ebike_suppliers; @@ -97,4 +101,4 @@ SELECT t.supplier_id, t.key_supplier, t.supply_reliability FROM tmp_d_ebike_suppliers t -WHERE t.supplier_id <= '{MAX_SUPPLIER_ID_SEQ}'; +WHERE t.supplier_id <= :MAX_SUPPLIER_ID_SEQ; diff --git a/testgen/template/quick_start/update_target_data_iter0.sql b/testgen/template/quick_start/update_target_data_iter0.sql new file mode 100644 index 00000000..100a5ffe --- /dev/null +++ b/testgen/template/quick_start/update_target_data_iter0.sql @@ -0,0 +1 @@ +SET SEARCH_PATH TO {PROJECT_SCHEMA}; diff --git 
a/testgen/template/quick_start/update_target_data_iter1.sql b/testgen/template/quick_start/update_target_data_iter1.sql new file mode 100644 index 00000000..100a5ffe --- /dev/null +++ b/testgen/template/quick_start/update_target_data_iter1.sql @@ -0,0 +1 @@ +SET SEARCH_PATH TO {PROJECT_SCHEMA}; diff --git a/testgen/template/quick_start/update_target_data_iter2.sql b/testgen/template/quick_start/update_target_data_iter2.sql new file mode 100644 index 00000000..100a5ffe --- /dev/null +++ b/testgen/template/quick_start/update_target_data_iter2.sql @@ -0,0 +1 @@ +SET SEARCH_PATH TO {PROJECT_SCHEMA}; diff --git a/testgen/template/quick_start/update_target_data_iter3.sql b/testgen/template/quick_start/update_target_data_iter3.sql new file mode 100644 index 00000000..d1dece90 --- /dev/null +++ b/testgen/template/quick_start/update_target_data_iter3.sql @@ -0,0 +1,7 @@ +SET SEARCH_PATH TO {PROJECT_SCHEMA}; + +UPDATE f_ebike_sales +SET total_amount = (sale_price + 100) * quantity_sold, + adjusted_total_amount = (sale_price + 100) * quantity_sold - discount_amount, + sale_price = sale_price + 100 +WHERE product_id = 30027; diff --git a/testgen/template/rollup_scores/rollup_scores_profile_run.sql b/testgen/template/rollup_scores/rollup_scores_profile_run.sql index 572cbc1e..bc7ae926 100644 --- a/testgen/template/rollup_scores/rollup_scores_profile_run.sql +++ b/testgen/template/rollup_scores/rollup_scores_profile_run.sql @@ -3,7 +3,7 @@ UPDATE profiling_runs SET dq_affected_data_points = 0, dq_total_data_points = 0, dq_score_profiling = 1 - WHERE id = '{RUN_ID}'; + WHERE id = :RUN_ID; -- Roll up scoring to profiling run WITH score_detail @@ -17,7 +17,7 @@ WITH score_detail ON (pr.profile_run_id = p.profile_run_id AND pr.column_name = p.column_name AND pr.table_name = p.table_name) - WHERE pr.profile_run_id = '{RUN_ID}' + WHERE pr.profile_run_id = :RUN_ID AND COALESCE(p.disposition, 'Confirmed') = 'Confirmed' GROUP BY 1, 2, 3 ), score_calc diff --git a/testgen/template/rollup_scores/rollup_scores_profile_table_group.sql b/testgen/template/rollup_scores/rollup_scores_profile_table_group.sql index 54460109..4290e384 100644 --- a/testgen/template/rollup_scores/rollup_scores_profile_table_group.sql +++ b/testgen/template/rollup_scores/rollup_scores_profile_table_group.sql @@ -12,7 +12,7 @@ score_calc INNER JOIN last_profile_date lp ON (run.table_groups_id = lp.table_groups_id AND run.profiling_starttime = lp.last_profile_run_date) - WHERE run.table_groups_id = '{TABLE_GROUPS_ID}' ) + WHERE run.table_groups_id = :TABLE_GROUPS_ID ) UPDATE table_groups SET dq_score_profiling = (1.0 - s.sum_affected_data_points::FLOAT / NULLIF(s.sum_data_points::FLOAT, 0)), last_complete_profile_run_id = s.profile_run_id @@ -26,7 +26,7 @@ UPDATE data_column_chars last_complete_profile_run_id = tg.last_complete_profile_run_id FROM table_groups tg WHERE data_column_chars.table_groups_id = tg.id - AND data_column_chars.table_groups_id = '{TABLE_GROUPS_ID}'; + AND data_column_chars.table_groups_id = :TABLE_GROUPS_ID; -- Roll up latest scores to data_column_chars WITH score_detail @@ -47,7 +47,7 @@ WITH score_detail ON (pr.profile_run_id = p.profile_run_id AND pr.column_name = p.column_name AND pr.table_name = p.table_name) - WHERE tg.id = '{TABLE_GROUPS_ID}' + WHERE tg.id = :TABLE_GROUPS_ID AND COALESCE(p.disposition, 'Confirmed') = 'Confirmed' GROUP BY dcc.column_id ) UPDATE data_column_chars @@ -63,7 +63,7 @@ UPDATE data_table_chars last_complete_profile_run_id = tg.last_complete_profile_run_id FROM table_groups tg WHERE 
data_table_chars.table_groups_id = tg.id - AND data_table_chars.table_groups_id = '{TABLE_GROUPS_ID}'; + AND data_table_chars.table_groups_id = :TABLE_GROUPS_ID; -- Roll up latest scores to data_table_chars WITH score_detail @@ -83,7 +83,7 @@ WITH score_detail ON (pr.profile_run_id = p.profile_run_id AND pr.column_name = p.column_name AND pr.table_name = p.table_name) - WHERE tg.id = '{TABLE_GROUPS_ID}' + WHERE tg.id = :TABLE_GROUPS_ID AND COALESCE(p.disposition, 'Confirmed') = 'Confirmed' GROUP BY dcc.column_id, dcc.table_id ), score_calc diff --git a/testgen/template/rollup_scores/rollup_scores_test_run.sql b/testgen/template/rollup_scores/rollup_scores_test_run.sql index 4693f745..a16e860c 100644 --- a/testgen/template/rollup_scores/rollup_scores_test_run.sql +++ b/testgen/template/rollup_scores/rollup_scores_test_run.sql @@ -3,7 +3,7 @@ UPDATE test_runs SET dq_affected_data_points = 0, dq_total_data_points = 0, dq_score_test_run = 1 - WHERE id = '{RUN_ID}'; + WHERE id = :RUN_ID; -- Roll up scoring to test run WITH score_detail @@ -11,7 +11,7 @@ WITH score_detail MAX(r.dq_record_ct) as row_ct, (1.0 - SUM_LN(COALESCE(r.dq_prevalence, 0.0))) * MAX(r.dq_record_ct) as affected_data_points FROM test_results r - WHERE r.test_run_id = '{RUN_ID}' + WHERE r.test_run_id = :RUN_ID AND COALESCE(r.disposition, 'Confirmed') = 'Confirmed' GROUP BY r.test_run_id, r.table_name, r.column_names ), score_calc diff --git a/testgen/template/rollup_scores/rollup_scores_test_table_group.sql b/testgen/template/rollup_scores/rollup_scores_test_table_group.sql index 45fbd50b..7aebeadd 100644 --- a/testgen/template/rollup_scores/rollup_scores_test_table_group.sql +++ b/testgen/template/rollup_scores/rollup_scores_test_table_group.sql @@ -14,7 +14,7 @@ score_calc INNER JOIN last_test_date lp ON (run.test_suite_id = lp.test_suite_id AND run.test_starttime = lp.last_test_run_date) - WHERE ts.table_groups_id = '{TABLE_GROUPS_ID}' + WHERE ts.table_groups_id = :TABLE_GROUPS_ID AND ts.dq_score_exclude = FALSE GROUP BY ts.table_groups_id) UPDATE table_groups @@ -26,7 +26,7 @@ UPDATE table_groups UPDATE data_column_chars SET valid_test_issue_ct = 0, dq_score_testing = 1 - WHERE table_groups_id = '{TABLE_GROUPS_ID}'; + WHERE table_groups_id = :TABLE_GROUPS_ID; -- Roll up latest scores to data_column_chars -- excludes multi-column tests WITH score_calc @@ -44,7 +44,7 @@ WITH score_calc ON (dcc.table_groups_id = ts.table_groups_id AND dcc.table_name = r.table_name AND dcc.column_name = r.column_names) - WHERE dcc.table_groups_id = '{TABLE_GROUPS_ID}' + WHERE dcc.table_groups_id = :TABLE_GROUPS_ID AND COALESCE(ts.dq_score_exclude, FALSE) = FALSE AND COALESCE(r.disposition, 'Confirmed') = 'Confirmed' GROUP BY dcc.column_id ) @@ -57,7 +57,7 @@ UPDATE data_column_chars -- Reset scoring in data_table_chars UPDATE data_table_chars SET dq_score_testing = 1 - WHERE table_groups_id = '{TABLE_GROUPS_ID}'; + WHERE table_groups_id = :TABLE_GROUPS_ID; -- Roll up latest scores to data_table_chars -- includes multi-column tests WITH score_detail @@ -73,7 +73,7 @@ WITH score_detail AND r.test_run_id = ts.last_complete_test_run_id)) ON (dtc.table_groups_id = ts.table_groups_id AND dtc.table_name = r.table_name) - WHERE dtc.table_groups_id = '{TABLE_GROUPS_ID}' + WHERE dtc.table_groups_id = :TABLE_GROUPS_ID AND COALESCE(ts.dq_score_exclude, FALSE) = FALSE AND COALESCE(r.disposition, 'Confirmed') = 'Confirmed' GROUP BY dtc.table_id, r.column_names), diff --git a/testgen/template/score_cards/add_latest_runs.sql 
b/testgen/template/score_cards/add_latest_runs.sql index f06abad0..bf332d8c 100644 --- a/testgen/template/score_cards/add_latest_runs.sql +++ b/testgen/template/score_cards/add_latest_runs.sql @@ -3,12 +3,12 @@ WITH ranked_profiling AS (SELECT project_code, table_groups_id, id as profiling_run_id, ROW_NUMBER() OVER (PARTITION BY table_groups_id ORDER BY profiling_starttime DESC) as rank FROM profiling_runs r - WHERE project_code = '{project_code}' - AND profiling_starttime <= '{score_history_cutoff_time}' + WHERE project_code = :project_code + AND profiling_starttime <= :score_history_cutoff_time AND r.status = 'Complete') INSERT INTO score_history_latest_runs (definition_id, score_history_cutoff_time, table_groups_id, last_profiling_run_id) -SELECT '{definition_id}' as definition_id, '{score_history_cutoff_time}' as score_history_cutoff_time, table_groups_id, profiling_run_id +SELECT :definition_id as definition_id, :score_history_cutoff_time as score_history_cutoff_time, table_groups_id, profiling_run_id FROM ranked_profiling WHERE rank = 1; @@ -20,11 +20,11 @@ WITH ranked_test_runs FROM test_runs r INNER JOIN test_suites s ON (r.test_suite_id = s.id) - WHERE s.project_code = '{project_code}' - AND r.test_starttime <= '{score_history_cutoff_time}' + WHERE s.project_code = :project_code + AND r.test_starttime <= :score_history_cutoff_time AND r.status = 'Complete') INSERT INTO score_history_latest_runs (definition_id, score_history_cutoff_time, test_suite_id, last_test_run_id) -SELECT '{definition_id}' as definition_id, '{score_history_cutoff_time}' as score_history_cutoff_time, test_suite_id, test_run_id +SELECT :definition_id as definition_id, :score_history_cutoff_time as score_history_cutoff_time, test_suite_id, test_run_id FROM ranked_test_runs WHERE rank = 1; diff --git a/testgen/template/score_cards/get_historical_overall_scores_by_column.sql b/testgen/template/score_cards/get_historical_overall_scores_by_column.sql index 0c9b4596..06485458 100644 --- a/testgen/template/score_cards/get_historical_overall_scores_by_column.sql +++ b/testgen/template/score_cards/get_historical_overall_scores_by_column.sql @@ -19,7 +19,7 @@ FROM ( AND history.last_run_time = v_dq_profile_scoring_history_by_column.score_history_cutoff_time ) WHERE {filters} - AND history.definition_id = '{definition_id}' + AND history.definition_id = :definition_id GROUP BY project_code, history.definition_id, history.last_run_time @@ -39,7 +39,7 @@ FULL OUTER JOIN ( AND history.last_run_time = v_dq_test_scoring_history_by_column.score_history_cutoff_time ) WHERE {filters} - AND history.definition_id = '{definition_id}' + AND history.definition_id = :definition_id GROUP BY project_code, history.definition_id, history.last_run_time diff --git a/testgen/template/score_cards/get_score_card_issues_by_column.sql b/testgen/template/score_cards/get_score_card_issues_by_column.sql index 2844f4be..804caac9 100644 --- a/testgen/template/score_cards/get_score_card_issues_by_column.sql +++ b/testgen/template/score_cards/get_score_card_issues_by_column.sql @@ -4,7 +4,7 @@ WITH score_profiling_runs AS ( table_name, column_name FROM v_dq_profile_scoring_latest_by_column - WHERE {filters} AND {group_by} = '{value}' + WHERE {filters} AND {group_by} = :value ), anomalies AS ( SELECT results.id::VARCHAR AS id, @@ -17,7 +17,7 @@ anomalies AS ( EXTRACT( EPOCH FROM runs.profiling_starttime - ) * 1000 AS time, + )::INT AS time, '' AS name, runs.id::text AS run_id, 'hygiene' AS issue_type @@ -37,7 +37,7 @@ score_test_runs AS ( column_name FROM 
v_dq_test_scoring_latest_by_column WHERE {filters} - AND {group_by} = '{value}' + AND {group_by} = :value ), tests AS ( SELECT test_results.id::VARCHAR AS id, @@ -50,7 +50,7 @@ tests AS ( EXTRACT( EPOCH FROM test_time - ) * 1000 AS time, + )::INT AS time, test_suites.test_suite AS name, test_results.test_run_id::text AS run_id, 'test' AS issue_type diff --git a/testgen/template/score_cards/get_score_card_issues_by_dimension.sql b/testgen/template/score_cards/get_score_card_issues_by_dimension.sql index b27243e4..43ea8d24 100644 --- a/testgen/template/score_cards/get_score_card_issues_by_dimension.sql +++ b/testgen/template/score_cards/get_score_card_issues_by_dimension.sql @@ -4,7 +4,7 @@ WITH score_profiling_runs AS ( table_name, column_name FROM v_dq_profile_scoring_latest_by_dimension - WHERE {filters} AND {group_by} = '{value}' + WHERE {filters} AND {group_by} = :value ), anomalies AS ( SELECT results.id::VARCHAR AS id, @@ -17,7 +17,7 @@ anomalies AS ( EXTRACT( EPOCH FROM runs.profiling_starttime - ) * 1000 AS time, + )::INT AS time, '' AS name, runs.id::text AS run_id, 'hygiene' AS issue_type @@ -38,7 +38,7 @@ score_test_runs AS ( column_name FROM v_dq_test_scoring_latest_by_dimension WHERE {filters} - AND {group_by} = '{value}' + AND {group_by} = :value ), tests AS ( SELECT test_results.id::VARCHAR AS id, @@ -51,7 +51,7 @@ tests AS ( EXTRACT( EPOCH FROM test_time - ) * 1000 AS time, + )::INT AS time, test_suites.test_suite AS name, test_results.test_run_id::text AS run_id, 'test' AS issue_type diff --git a/testgen/template/validate_tests/ex_disable_tests_test_definitions.sql b/testgen/template/validate_tests/ex_disable_tests_test_definitions.sql index 40793fa6..67478434 100644 --- a/testgen/template/validate_tests/ex_disable_tests_test_definitions.sql +++ b/testgen/template/validate_tests/ex_disable_tests_test_definitions.sql @@ -1,4 +1,4 @@ UPDATE test_definitions SET test_active = 'N' - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND test_active = 'D'; diff --git a/testgen/template/validate_tests/ex_flag_tests_test_definitions.sql b/testgen/template/validate_tests/ex_flag_tests_test_definitions.sql index e9ebc1fb..5d0b5a58 100644 --- a/testgen/template/validate_tests/ex_flag_tests_test_definitions.sql +++ b/testgen/template/validate_tests/ex_flag_tests_test_definitions.sql @@ -2,6 +2,6 @@ Mark Test inactive for Missing columns/tables with update status */ UPDATE test_definitions -SET test_active = '{FLAG}', - test_definition_status = LEFT('Inactivated {RUN_DATE}: ' || CONCAT_WS('; ', substring(test_definition_status from 34), '{MESSAGE}'), 200) -WHERE cat_test_id IN ({CAT_TEST_IDS}); +SET test_active = :FLAG, + test_definition_status = LEFT('Inactivated ' || :RUN_DATE || ': ' || CONCAT_WS('; ', substring(test_definition_status from 34), :MESSAGE), 200) +WHERE cat_test_id IN :CAT_TEST_IDS; diff --git a/testgen/template/validate_tests/ex_get_test_column_list_tg.sql b/testgen/template/validate_tests/ex_get_test_column_list_tg.sql index 6beec89c..1498ae87 100644 --- a/testgen/template/validate_tests/ex_get_test_column_list_tg.sql +++ b/testgen/template/validate_tests/ex_get_test_column_list_tg.sql @@ -9,7 +9,7 @@ FROM test_definitions d INNER JOIN test_types t ON d.test_type = t.test_type - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND COALESCE(test_active, 'Y') = 'Y' AND t.test_scope = 'column' UNION @@ -21,7 +21,7 @@ FROM test_definitions d INNER JOIN test_types t ON d.test_type = t.test_type - WHERE 
test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND COALESCE(test_active, 'Y') = 'Y' AND t.test_scope = 'referential' AND t.test_type NOT LIKE 'Aggregate_%' @@ -34,7 +34,7 @@ FROM test_definitions d INNER JOIN test_types t ON d.test_type = t.test_type - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND COALESCE(test_active, 'Y') = 'Y' AND t.test_scope IN ('column', 'referential') UNION @@ -46,7 +46,7 @@ FROM test_definitions d INNER JOIN test_types t ON d.test_type = t.test_type - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND COALESCE(test_active, 'Y') = 'Y' AND t.test_scope = 'referential' UNION @@ -58,7 +58,7 @@ FROM test_definitions d INNER JOIN test_types t ON d.test_type = t.test_type - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND COALESCE(test_active, 'Y') = 'Y' AND t.test_scope = 'referential' AND t.test_type NOT LIKE 'Aggregate_%' @@ -71,7 +71,7 @@ FROM test_definitions d INNER JOIN test_types t ON d.test_type = t.test_type - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND COALESCE(test_active, 'Y') = 'Y' AND t.test_scope = 'referential' UNION @@ -82,7 +82,7 @@ FROM test_definitions d INNER JOIN test_types t ON d.test_type = t.test_type - WHERE test_suite_id = '{TEST_SUITE_ID}' + WHERE test_suite_id = :TEST_SUITE_ID AND COALESCE(test_active, 'Y') = 'Y' AND t.test_scope = 'table' ) cols GROUP BY columns; diff --git a/testgen/template/validate_tests/ex_prep_flag_tests_test_definitions.sql b/testgen/template/validate_tests/ex_prep_flag_tests_test_definitions.sql index d5eb6a27..d436a3ca 100644 --- a/testgen/template/validate_tests/ex_prep_flag_tests_test_definitions.sql +++ b/testgen/template/validate_tests/ex_prep_flag_tests_test_definitions.sql @@ -3,4 +3,4 @@ Clean the test definition status before it's set with missing tables / columns i */ UPDATE test_definitions SET test_definition_status = NULL -WHERE cat_test_id IN ({CAT_TEST_IDS}); +WHERE cat_test_id IN :CAT_TEST_IDS; diff --git a/testgen/template/validate_tests/ex_write_test_val_errors.sql b/testgen/template/validate_tests/ex_write_test_val_errors.sql index 639cc3ef..318d76bf 100644 --- a/testgen/template/validate_tests/ex_write_test_val_errors.sql +++ b/testgen/template/validate_tests/ex_write_test_val_errors.sql @@ -12,14 +12,14 @@ INSERT INTO test_results result_status, result_message, result_measure ) - SELECT '{TEST_SUITE_ID}'::UUID, + SELECT :TEST_SUITE_ID, test_type, id, schema_name, table_name, column_name, - '{RUN_DATE}' as test_time, - '{TEST_RUN_ID}' as test_run_id, + :RUN_DATE as test_time, + :TEST_RUN_ID as test_run_id, NULL as input_parameters, NULL as result_code, 'Error' as result_status, @@ -27,4 +27,4 @@ INSERT INTO test_results NULL as result_measure FROM test_definitions WHERE test_active = 'D' - AND test_suite_id = '{TEST_SUITE_ID}'; + AND test_suite_id = :TEST_SUITE_ID; diff --git a/testgen/ui/app.py b/testgen/ui/app.py index de4f8d0e..cc2e6173 100644 --- a/testgen/ui/app.py +++ b/testgen/ui/app.py @@ -6,11 +6,11 @@ from testgen.common import version_service from testgen.common.docker_service import check_basic_configuration from testgen.common.models import with_database_session +from testgen.common.models.project import Project from testgen.ui import bootstrap from testgen.ui.assets import get_asset_path from testgen.ui.components import widgets as testgen -from testgen.ui.services import database_service as db -from testgen.ui.services 
import javascript_service, project_service, user_session_service +from testgen.ui.services import javascript_service, user_session_service from testgen.ui.session import session @@ -37,7 +37,6 @@ def render(log_level: int = logging.INFO): set_locale() - session.dbschema = db.get_schema() session.sidebar_project = ( session.page_args_pending_router and session.page_args_pending_router.get("project_code") ) or st.query_params.get("project_code", session.sidebar_project) @@ -50,7 +49,7 @@ def render(log_level: int = logging.INFO): if session.authentication_status and not session.logging_in: with st.sidebar: testgen.sidebar( - projects=project_service.get_projects(), + projects=Project.select_where(), current_project=session.sidebar_project, menu=application.menu, current_page=session.current_page, diff --git a/testgen/ui/components/frontend/js/components/connection_form.js b/testgen/ui/components/frontend/js/components/connection_form.js index 19a33c88..2c700b32 100644 --- a/testgen/ui/components/frontend/js/components/connection_form.js +++ b/testgen/ui/components/frontend/js/components/connection_form.js @@ -26,7 +26,7 @@ * @property {string} project_port * @property {string} project_db * @property {string} project_user - * @property {string} password + * @property {string} project_pw_encrypted * @property {boolean} connect_by_url * @property {string?} url * @property {boolean} connect_by_key @@ -109,7 +109,7 @@ const ConnectionForm = (props, saveButton) => { project_port: connection?.project_port ?? defaultPort ?? '', project_db: connection?.project_db ?? '', project_user: connection?.project_user ?? '', - password: isEditMode ? '' : (connection?.password ?? ''), + project_pw_encrypted: isEditMode ? '' : (connection?.project_pw_encrypted ?? ''), connect_by_url: connection?.connect_by_url ?? false, connect_by_key: connection?.connect_by_key ?? false, private_key: isEditMode ? '' : (connection?.private_key ?? ''), @@ -346,7 +346,7 @@ const RedshiftForm = ( const connectionPort = van.state(connection.rawVal.project_port || defaultPorts[flavor.flavor]); const connectionDatabase = van.state(connection.rawVal.project_db ?? ''); const connectionUsername = van.state(connection.rawVal.project_user ?? ''); - const connectionPassword = van.state(connection.rawVal?.password ?? ''); + const connectionPassword = van.state(connection.rawVal?.project_pw_encrypted ?? ''); const [prefixPart, sufixPart] = originalURLTemplate.split('@'); const connectionStringPrefix = van.state(`${prefixPart}@`); @@ -376,7 +376,7 @@ const RedshiftForm = ( project_port: connectionPort.val, project_db: connectionDatabase.val, project_user: connectionUsername.val, - password: connectionPassword.val, + project_pw_encrypted: connectionPassword.val, connect_by_url: connectByUrl.val, url: connectByUrl.val ? connectionStringSuffix.val : connectionStringSuffix.rawVal, connect_by_key: false, @@ -486,7 +486,7 @@ const RedshiftForm = ( value: connectionPassword, height: 38, type: 'password', - placeholder: (originalConnection?.connection_id && originalConnection?.password) ? secretsPlaceholder : '', + placeholder: (originalConnection?.connection_id && originalConnection?.project_pw_encrypted) ? secretsPlaceholder : '', onChange: (value, state) => { connectionPassword.val = value; validityPerField['password'] = state.valid; @@ -528,7 +528,7 @@ const DatabricksForm = ( const connectionHttpPath = van.state(connection.rawVal?.http_path ?? ''); const connectionDatabase = van.state(connection.rawVal?.project_db ?? 
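// Note: throughout these connection forms the credential field is renamed from `password` to
// `project_pw_encrypted`, matching the column selected by the (now removed) connection_queries.py
// further below; edit mode still starts the field blank and only shows the secrets placeholder
// when a stored value exists.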
''); const connectionUsername = van.state(connection.rawVal?.project_user ?? ''); - const connectionPassword = van.state(connection.rawVal?.password ?? ''); + const connectionPassword = van.state(connection.rawVal?.project_pw_encrypted ?? ''); const [prefixPart, sufixPart] = originalURLTemplate.split('@'); const connectionStringPrefix = van.state(`${prefixPart}@`); @@ -558,7 +558,7 @@ const DatabricksForm = ( project_port: connectionPort.val, project_db: connectionDatabase.val, project_user: connectionUsername.val, - password: connectionPassword.val, + project_pw_encrypted: connectionPassword.val, http_path: connectionHttpPath.val, connect_by_url: connectByUrl.val, url: connectByUrl.val ? connectionStringSuffix.val : connectionStringSuffix.rawVal, @@ -683,7 +683,7 @@ const DatabricksForm = ( value: connectionPassword, height: 38, type: 'password', - placeholder: (originalConnection?.connection_id && originalConnection?.password) ? secretsPlaceholder : '', + placeholder: (originalConnection?.connection_id && originalConnection?.project_pw_encrypted) ? secretsPlaceholder : '', onChange: (value, state) => { connectionPassword.val = value; validityPerField['password'] = state.valid; @@ -720,7 +720,7 @@ const SnowflakeForm = ( const connectionPort = van.state(connection.rawVal.project_port || defaultPorts[flavor.flavor]); const connectionDatabase = van.state(connection.rawVal.project_db ?? ''); const connectionUsername = van.state(connection.rawVal.project_user ?? ''); - const connectionPassword = van.state(connection.rawVal?.password ?? ''); + const connectionPassword = van.state(connection.rawVal?.project_pw_encrypted ?? ''); const connectionPrivateKey = van.state(connection.rawVal?.private_key ?? ''); const connectionPrivateKeyPassphrase = van.state( clearPrivateKeyPhrase.rawVal @@ -756,7 +756,7 @@ const SnowflakeForm = ( project_port: connectionPort.val, project_db: connectionDatabase.val, project_user: connectionUsername.val, - password: connectionPassword.val, + project_pw_encrypted: connectionPassword.val, connect_by_url: connectByUrl.val, url: connectByUrl.val ? connectionStringSuffix.val : connectionStringSuffix.rawVal, connect_by_key: connectByKey.val, @@ -953,7 +953,7 @@ const SnowflakeForm = ( value: connectionPassword, height: 38, type: 'password', - placeholder: (originalConnection?.connection_id && originalConnection?.password) ? secretsPlaceholder : '', + placeholder: (originalConnection?.connection_id && originalConnection?.project_pw_encrypted) ? 
secretsPlaceholder : '', onChange: (value, state) => { connectionPassword.val = value; validityPerField['password'] = state.valid; @@ -968,7 +968,7 @@ const SnowflakeForm = ( function formatURL(url, host, port, database, httpPath) { return url.replace('', host) .replace('', port) - .replace('', database) + .replace('', database) .replace('', httpPath); } diff --git a/testgen/ui/components/frontend/js/components/input.js b/testgen/ui/components/frontend/js/components/input.js index b5f19b5d..9c0b569f 100644 --- a/testgen/ui/components/frontend/js/components/input.js +++ b/testgen/ui/components/frontend/js/components/input.js @@ -211,6 +211,7 @@ stylesheet.replace(` border: unset; padding: 4px 8px; border-radius: 8px; + outline: none; } .tg-input--field > input::placeholder { @@ -218,9 +219,8 @@ stylesheet.replace(` color: var(--disabled-text-color); } -.tg-input--field:focus, -.tg-input--field:focus-visible { - outline: none; +.tg-input--field:has(input:focus), +.tg-input--field:has(input:focus-visible) { border-color: var(--primary-color); } diff --git a/testgen/ui/components/frontend/js/components/score_issues.js b/testgen/ui/components/frontend/js/components/score_issues.js index db62cbea..d773a1bb 100644 --- a/testgen/ui/components/frontend/js/components/score_issues.js +++ b/testgen/ui/components/frontend/js/components/score_issues.js @@ -32,6 +32,15 @@ import { colorMap, formatTimestamp } from '../display_utils.js'; const { div, i, span } = van.tags; const PAGE_SIZE = 100; const SCROLL_CONTAINER = window.top.document.querySelector('.stMain'); +const statusColors = { + 'Potential PII': colorMap.grey, + Likely: colorMap.orange, + Possible: colorMap.yellow, + Definite: colorMap.red, + Warning: colorMap.yellow, + Failed: colorMap.red, + Passed: colorMap.green, +}; const IssuesTable = ( /** @type Issue[] */ issues, @@ -117,23 +126,24 @@ const IssuesTable = ( ), ), () => Toolbar(filters, issues, category), - div( - { class: 'table-header issues-columns flex-row' }, - Checkbox({ - checked: () => selectedIssues.val.length === PAGE_SIZE, - indeterminate: () => !!selectedIssues.val.length, - onChange: (checked) => { - if (checked) { - selectedIssues.val = displayedIssues.val.map(({ id, issue_type }) => ({ id, issue_type })); - } else { - selectedIssues.val = []; - } - }, - }), - span({ class: category === 'column_name' ? null : 'ml-6' }), - columns.map(c => span({ style: `flex: ${c === 'detail' ? '1 1' : '0 0'} ${ISSUES_COLUMNS_SIZES[c]};` }, ISSUES_COLUMN_LABEL[c])) - ), - () => div( + () => displayedIssues.val.length + ? div( + div( + { class: 'table-header issues-columns flex-row' }, + Checkbox({ + checked: () => selectedIssues.val.length === PAGE_SIZE, + indeterminate: () => !!selectedIssues.val.length, + onChange: (checked) => { + if (checked) { + selectedIssues.val = displayedIssues.val.map(({ id, issue_type }) => ({ id, issue_type })); + } else { + selectedIssues.val = []; + } + }, + }), + span({ class: category === 'column_name' ? null : 'ml-6' }), + columns.map(c => span({ style: `flex: ${c === 'detail' ? 
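// Note: in score_issues.js the table header and paginator move inside a conditional so that an
// explicit "No issues found matching filters" empty state renders when the filtered list is
// empty; statusColors is also hoisted to module scope so the same dot colors can be reused for
// the status filter option labels.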
'1 1' : '0 0'} ${ISSUES_COLUMNS_SIZES[c]};` }, ISSUES_COLUMN_LABEL[c])) + ), displayedIssues.val.map((row) => div( { class: 'table-row flex-row issues-row' }, Checkbox({ @@ -151,18 +161,22 @@ const IssuesTable = ( : ColumnProfilingButton(row.column, row.table, row.table_group_id), columns.map((columnName) => TableCell(row, columnName)), )), + () => Paginator({ + pageIndex, + count: filteredIssues.val.length, + pageSize: PAGE_SIZE, + onChange: (newIndex) => { + if (newIndex !== pageIndex.val) { + pageIndex.val = newIndex; + SCROLL_CONTAINER.scrollTop = 0; + } + }, + }), + ) + : div( + { class: 'mt-7 mb-6 text-secondary', style: 'text-align: center;' }, + 'No issues found matching filters', ), - () => Paginator({ - pageIndex, - count: filteredIssues.val.length, - pageSize: PAGE_SIZE, - onChange: (newIndex) => { - if (newIndex !== pageIndex.val) { - pageIndex.val = newIndex; - SCROLL_CONTAINER.scrollTop = 0; - } - }, - }), ); }; @@ -203,7 +217,10 @@ const Toolbar = ( .sort() .map(value => ({ label: value, value })), status: [ 'Definite', 'Failed', 'Likely', 'Possible', 'Warning', 'Potential PII' ] - .map(value => ({ label: value, value })), + .map(value => ({ + label: div({ class: 'flex-row fx-gap-2' }, dot({}, statusColors[value]), span(value)), + value, + })), }; const displayedFilters = [ 'type', 'status' ]; @@ -275,16 +292,6 @@ const IssueCell = (value, row) => { }; const StatusCell = (value, row) => { - const statusColors = { - 'Potential PII': colorMap.grey, - Likely: colorMap.orange, - Possible: colorMap.yellow, - Definite: colorMap.red, - Warning: colorMap.yellow, - Failed: colorMap.red, - Passed: colorMap.green, - }; - return div( { class: 'flex-row fx-align-flex-center', style: `flex: 0 0 ${ISSUES_COLUMNS_SIZES.status}` }, dot({ class: 'mr-2' }, statusColors[value]), diff --git a/testgen/ui/components/frontend/js/components/sorting_selector.js b/testgen/ui/components/frontend/js/components/sorting_selector.js index 824e118a..0833d3dd 100644 --- a/testgen/ui/components/frontend/js/components/sorting_selector.js +++ b/testgen/ui/components/frontend/js/components/sorting_selector.js @@ -37,13 +37,6 @@ const SortingSelector = (/** @type {Properties} */ props) => { {} ); - const selectedDiv = div( - { - class: 'tg-sort-selector--column-list', - style: `flex-grow: 1`, - }, - ); - const directionIcons = { ASC: `arrow_upward`, DESC: `arrow_downward`, @@ -54,6 +47,7 @@ const SortingSelector = (/** @type {Properties} */ props) => { const directionIcon = van.derive(() => directionIcons[state.val.direction]); return button( { + class: 'flex-row', onclick: () => { state.val = { ...state.val, direction: state.val.direction === "DESC" ? 
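// Note: the sorting selector drops the imperatively managed selectedDiv/optionsDiv elements in
// favor of views derived from componentState, and each selected column gains a "close" dismiss
// button that resets its direction and clears its order.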
"ASC" : "DESC" }; }, @@ -63,12 +57,25 @@ const SortingSelector = (/** @type {Properties} */ props) => { directionIcon, ), span(columnLabel[colId]), + i( + { + class: `material-symbols-rounded clickable dismiss-button`, + style: `margin-left: auto;`, + onclick: (event) => { + event?.preventDefault(); + event?.stopPropagation(); + + componentState[colId].val = { direction: defaultDirection, order: null }; + }, + }, + 'close', + ), ) } const selectColumn = (colId, direction) => { - componentState[colId].val = { direction: direction, order: selectedDiv.childElementCount } - van.add(selectedDiv, activeColumnItem(colId)); + const activeColumnsCount = Object.values(componentState).filter((columnState) => columnState.val.order != null).length; + componentState[colId].val = { direction: direction, order: activeColumnsCount }; } prevComponentState.forEach(([colId, direction]) => selectColumn(colId, direction)); @@ -79,7 +86,6 @@ const SortingSelector = (/** @type {Properties} */ props) => { componentState[colId].val = { direction: defaultDirection, order: null } ) ); - selectedDiv.innerHTML = ``; } const externalComponentState = () => Object.entries(componentState).filter( @@ -112,13 +118,6 @@ const SortingSelector = (/** @type {Properties} */ props) => { ) } - const optionsDiv = div( - { - class: 'tg-sort-selector--column-list', - }, - columns.map(([colLabel, colId]) => van.derive(() => columnItem(colId))), - ) - const resetDisabled = () => Object.entries(componentState).filter( ([colId, colState]) => colState.val.order != null ).length === 0; @@ -133,12 +132,26 @@ const SortingSelector = (/** @type {Properties} */ props) => { }, span("Selected columns") ), - selectedDiv, + () => div( + { + class: 'tg-sort-selector--column-list', + style: `flex-grow: 1`, + }, + Object.entries(componentState) + .filter(([, colState]) => colState.val.order != null) + .sort(([, colState]) => colState.val.order) + .map(([colId,]) => activeColumnItem(colId)) + ), div( { class: `tg-sort-selector--header` }, span("Available columns") ), - optionsDiv, + div( + { + class: 'tg-sort-selector--column-list', + }, + columns.map(([colLabel, colId]) => van.derive(() => columnItem(colId))), + ), div( { class: `tg-sort-selector--footer` }, button( @@ -228,6 +241,13 @@ stylesheet.replace(` color: var(--disabled-text-color) !important; } +.dismiss-button { + margin-left: auto; + color: var(--disabled-text-color); +} +.dismiss-button:hover { + color: var(--button-text-color); +} @media (prefers-color-scheme: dark) { .tg-sort-selector--column-list button:hover { diff --git a/testgen/ui/components/frontend/js/components/table_group_test.js b/testgen/ui/components/frontend/js/components/table_group_test.js index c5dfcaf6..bb226a45 100644 --- a/testgen/ui/components/frontend/js/components/table_group_test.js +++ b/testgen/ui/components/frontend/js/components/table_group_test.js @@ -2,14 +2,20 @@ * @typedef TableGroupPreview * @type {object} * @property {string} schema - * @property {string[]?} tables + * @property {Record?} tables * @property {number?} column_count * @property {boolean?} success * @property {string?} message + * + * @typedef ComponentOptions + * @type {object} + * @property {(() => void)?} onVerifyAcess */ import van from '../van.min.js'; -import { getValue } from '../utils.js'; +import { emitEvent, getValue } from '../utils.js'; import { Alert } from '../components/alert.js'; +import { Icon } from '../components/icon.js'; +import { Button } from '../components/button.js'; const { div, span, strong } = van.tags; @@ -17,9 
+23,10 @@ const { div, span, strong } = van.tags; * * @param {string} schema * @param {TableGroupPreview?} preview + * @param {ComponentOptions} options * @returns {HTMLElement} */ -const TableGroupTest = (schema, preview) => { +const TableGroupTest = (schema, preview, options) => { return div( { class: 'flex-column fx-gap-2' }, div( @@ -34,7 +41,7 @@ const TableGroupTest = (schema, preview) => { div( { class: 'flex-row fx-gap-1' }, strong({}, 'Table Count:'), - () => span({}, getValue(preview)?.tables?.length ?? '--'), + () => span({}, Object.keys(getValue(preview)?.tables ?? {})?.length ?? '--'), ), div( { class: 'flex-row fx-gap-1' }, @@ -42,6 +49,18 @@ const TableGroupTest = (schema, preview) => { () => span({}, getValue(preview)?.column_count ?? '--'), ), ), + options.onVerifyAcess + ? div( + { class: 'flex-row' }, + span({ class: 'fx-flex' }), + Button({ + label: 'Verify Access', + width: 'fit-content', + type: 'stroked', + onclick: options.onVerifyAcess, + }), + ) + : '', ), () => { const tableGroupPreview = getValue(preview); @@ -51,23 +70,50 @@ const TableGroupTest = (schema, preview) => { return ''; } + const tables = tableGroupPreview?.tables ?? {}; + const hasTables = Object.keys(tables).length > 0; + const verifiedAccess = Object.values(tables).some(v => v != null); + const tableAccessWarning = Object.values(tables).some(v => v != null && v === false) + ? tableGroupPreview.message + : ''; + return div( - { class: 'table hoverable p-3' }, - div( - { class: 'table-header' }, - span('Tables'), - ), + {class: 'flex-column fx-gap-2'}, div( - { class: 'flex-column', style: 'max-height: 200px; overflow-y: auto;' }, - tableGroupPreview?.tables?.length - ? tableGroupPreview.tables.map((table) => - div({ class: 'table-row' }, table), - ) - : div( - { class: 'flex-row fx-justify-center', style: 'height: 50px; font-size: 16px;'}, - tableGroupPreview.message ?? 'No tables found.' - ), + { class: 'table hoverable p-3' }, + div( + { class: 'table-header flex-row fx-justify-space-between' }, + span('Tables'), + verifiedAccess + ? span({class: 'flex-row fx-justify-center', style: 'width: 100px;'}, 'Has access?') + : '', + ), + div( + { class: 'flex-column', style: 'max-height: 200px; overflow-y: auto;' }, + hasTables + ? Object.entries(tables).map(([tableName, hasAccess]) => + div( + { class: 'table-row flex-row fx-justify-space-between' }, + span(tableName), + hasAccess != null + ? span( + {class: 'flex-row fx-justify-center', style: 'width: 100px;'}, + hasAccess + ? Icon({classes: 'text-green', size: 20}, 'check_circle') + : Icon({classes: 'text-error', size: 20}, 'dangerous'), + ) + : '', + ), + ) + : div( + { class: 'flex-row fx-justify-center', style: 'height: 50px; font-size: 16px;'}, + tableGroupPreview.message ?? 'No tables found.' + ), + ), ), + tableAccessWarning ? 
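// Note: the preview's `tables` property changes from a string[] to a Record (presumably
// Record<string, boolean | null>; the generic arguments are not shown) whose values mark
// per-table access: null before verification, true/false afterwards. The "Has access?" column
// and the warning Alert appear only after verification, triggered through the new optional
// onVerifyAcess callback.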
+ Alert({type: 'warn', closeable: true, icon: 'warning'}, span(tableAccessWarning)) + : '', ); }, ); diff --git a/testgen/ui/components/frontend/js/data_profiling/column_profiling_history.js b/testgen/ui/components/frontend/js/data_profiling/column_profiling_history.js index 06d3f426..c2ca57cc 100644 --- a/testgen/ui/components/frontend/js/data_profiling/column_profiling_history.js +++ b/testgen/ui/components/frontend/js/data_profiling/column_profiling_history.js @@ -16,6 +16,7 @@ import { Streamlit } from '../streamlit.js'; import { emitEvent, getValue, loadStylesheet } from '../utils.js'; import { formatTimestamp } from '../display_utils.js'; import { ColumnDistributionCard } from './column_distribution.js'; +import { Card } from '../components/card.js'; const { div, span } = van.tags; @@ -24,24 +25,34 @@ const ColumnProfilingHistory = (/** @type Properties */ props) => { Streamlit.setFrameHeight(600); window.testgen.isPage = true; + const selectedRunId = van.state(null); + return div( { class: 'column-history flex-row fx-align-stretch' }, () => div( { class: 'column-history--list' }, getValue(props.profiling_runs).map(({ run_id, run_date }, index) => div( { - class: () => `column-history--item clickable ${getValue(props.selected_item).profile_run_id === run_id ? 'selected' : ''}`, - onclick: () => emitEvent('RunSelected', { payload: run_id }), + class: () => `column-history--item clickable ${selectedRunId.val === run_id ? 'selected' : ''}`, + onclick: () => { + selectedRunId.val = run_id; + emitEvent('RunSelected', { payload: run_id }); + }, }, div(formatTimestamp(run_date)), index === 0 ? span({ class: 'text-caption' }, 'Latest run') : null, )), ), span({class: 'column-history--divider'}), - () => div( - { class: 'column-history--details' }, - ColumnDistributionCard({}, getValue(props.selected_item)), - ), + () => getValue(props.selected_item) + ? div( + { class: 'column-history--details' }, + ColumnDistributionCard({}, getValue(props.selected_item)), + ) + : Card({ + class: 'column-history--empty', + content: 'No data available for column in selected profiling run.', + }), ); } @@ -52,7 +63,7 @@ stylesheet.replace(` } .column-history--list { - flex: 150px 1 1; + flex: 250px 0 1; } .column-history--item { @@ -73,6 +84,7 @@ stylesheet.replace(` .column-history--details { overflow: auto; + flex: auto; } .column-history--divider { @@ -80,6 +92,14 @@ stylesheet.replace(` background-color: var(--grey); margin: 0 10px; } + +.column-history--empty { + flex-grow: 1; + display: flex; + flex-flow: row; + justify-content: center; + align-items: center; +} `); export { ColumnProfilingHistory }; diff --git a/testgen/ui/components/frontend/js/data_profiling/table_create_script.js b/testgen/ui/components/frontend/js/data_profiling/table_create_script.js index 910d38f1..e0e9261f 100644 --- a/testgen/ui/components/frontend/js/data_profiling/table_create_script.js +++ b/testgen/ui/components/frontend/js/data_profiling/table_create_script.js @@ -20,6 +20,9 @@ const TableCreateScriptCard = (/** @type Properties */ _props, /** @type Table * label: 'View Script', icon: 'sdk', width: 'auto', + disabled: !item.column_ct, + tooltip: item.column_ct ? 
null : 'No columns detected in table', + tooltipPosition: 'right', onclick: () => emitEvent('CreateScriptClicked', { payload: item }), }), ), diff --git a/testgen/ui/components/frontend/js/display_utils.js b/testgen/ui/components/frontend/js/display_utils.js index f0315368..e4dcc612 100644 --- a/testgen/ui/components/frontend/js/display_utils.js +++ b/testgen/ui/components/frontend/js/display_utils.js @@ -3,7 +3,7 @@ function formatTimestamp( /** @type boolean */ show_year, ) { if (timestamp) { - const date = new Date(timestamp); + const date = new Date(typeof timestamp === 'number' ? timestamp * 1000 : timestamp); if (!isNaN(date)) { const months = [ 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' ]; const hours = date.getHours(); @@ -30,11 +30,13 @@ function formatDuration(/** @type string */ duration) { return formatted.trim() || '< 1s'; } -function formatNumber(/** @type number | string */ number, /** @type number */ precision = 3) { +function formatNumber(/** @type number | string */ number, /** @type number */ decimals = 3) { if (!['number', 'string'].includes(typeof number) || isNaN(number)) { return '--'; } - return parseFloat(Number(number).toPrecision(precision)).toLocaleString(); + // toFixed - rounds to specified number of decimal places + // toLocaleString - adds commas as necessary + return parseFloat(Number(number).toFixed(decimals)).toLocaleString(); } function capitalize(/** @type string */ text) { diff --git a/testgen/ui/components/frontend/js/pages/data_catalog.js b/testgen/ui/components/frontend/js/pages/data_catalog.js index 246d7a53..2587a4ef 100644 --- a/testgen/ui/components/frontend/js/pages/data_catalog.js +++ b/testgen/ui/components/frontend/js/pages/data_catalog.js @@ -1,13 +1,7 @@ /** * @import { Column, Table } from '../data_profiling/data_profiling_utils.js'; * @import { TreeNode, SelectedNode } from '../components/tree.js'; - * - * @typedef ProjectSummary - * @type {object} - * @property {string} project_code - * @property {number} connections_ct - * @property {number} table_groups_ct - * @property {string} default_connection_id + * @import { ProjectSummary } from '../types.js'; * * @typedef ColumnPath * @type {object} @@ -182,7 +176,7 @@ const DataCatalog = (/** @type Properties */ props) => { const userCanNavigate = getValue(props.permissions)?.can_navigate ?? false; const projectSummary = getValue(props.project_summary); - return projectSummary.table_groups_ct > 0 + return projectSummary.table_group_count > 0 ? 
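// Note: the ProjectSummary typedef previously duplicated per page is now imported from the
// shared types.js added in this change, and its count fields are renamed to the singular form
// (connection_count, table_group_count, profiling_run_count, test_suite_count, test_run_count).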
div( { class: 'flex-column tg-dh' }, div( @@ -216,10 +210,10 @@ const DataCatalog = (/** @type Properties */ props) => { multiSelectToggle: userCanEdit, multiSelectToggleLabel: 'Edit multiple', onMultiSelect: (/** @type string[] | null */ selected) => multiSelectedItems.val = selected, - isNodeHidden: (/** @type TreeNode */ node, /** string */ search) => - !node.label.toLowerCase().includes(search.toLowerCase()) - || (!!node.children && !searchOptions.tableName.val) - || (!node.children && !searchOptions.columnName.val) + isNodeHidden: (/** @type TreeNode */ node, /** string */ search) => search + && (!node.label.toLowerCase().includes(search.toLowerCase()) + || (!!node.children && !searchOptions.tableName.val) + || (!node.children && !searchOptions.columnName.val)) || ![ node.criticalDataElement, false ].includes(filters.criticalDataElement.val) || TAG_KEYS.some(key => ![ node[key], null ].includes(filters[key].val)), onApplySearchOptions: () => { @@ -228,7 +222,7 @@ const DataCatalog = (/** @type Properties */ props) => { // Otherwise, nothing will be matched and the user might not realize why if (!searchOptions.tableName.val && !searchOptions.columnName.val) { searchOptions.tableName.val = true; - searchOptions.columnName.val = true + searchOptions.columnName.val = true; } }, hasActiveFilters: () => filters.criticalDataElement.val || TAG_KEYS.some(key => !!filters[key].val), @@ -728,7 +722,7 @@ const ConditionalEmptyState = ( onclick: () => emitEvent('RunProfilingClicked', {}), }), } - if (projectSummary.connections_ct <= 0) { + if (projectSummary.connection_count <= 0) { args = { label: 'Your project is empty', message: EMPTY_STATE_MESSAGE.connection, @@ -739,7 +733,7 @@ const ConditionalEmptyState = ( disabled: !userCanNavigate, }, }; - } else if (projectSummary.table_groups_ct <= 0) { + } else if (projectSummary.table_group_count <= 0) { args = { label: 'Your project is empty', message: EMPTY_STATE_MESSAGE.tableGroup, diff --git a/testgen/ui/components/frontend/js/pages/profiling_runs.js b/testgen/ui/components/frontend/js/pages/profiling_runs.js index 2b2c3392..93981556 100644 --- a/testgen/ui/components/frontend/js/pages/profiling_runs.js +++ b/testgen/ui/components/frontend/js/pages/profiling_runs.js @@ -33,7 +33,7 @@ import { SummaryBar } from '../components/summary_bar.js'; import { Link } from '../components/link.js'; import { Button } from '../components/button.js'; import { Streamlit } from '../streamlit.js'; -import { emitEvent, getValue, resizeFrameHeightToElement } from '../utils.js'; +import { emitEvent, getValue, resizeFrameHeightToElement, resizeFrameHeightOnDOMChange } from '../utils.js'; import { formatTimestamp, formatDuration } from '../display_utils.js'; import { Checkbox } from '../components/checkbox.js'; @@ -47,7 +47,7 @@ const ProfilingRuns = (/** @type Properties */ props) => { try { items = JSON.parse(props.items?.val); } catch { } - Streamlit.setFrameHeight(100 * items.length); + Streamlit.setFrameHeight(100 * items.length || 150); return items; }); const columns = ['5%', '15%', '20%', '20%', '30%', '10%']; @@ -58,6 +58,7 @@ const ProfilingRuns = (/** @type Properties */ props) => { const tableId = 'profiling-runs-table'; resizeFrameHeightToElement(tableId); + resizeFrameHeightOnDOMChange(tableId); const initializeSelectedStates = (items) => { for (const profilingRun of items) { @@ -73,7 +74,8 @@ const ProfilingRuns = (/** @type Properties */ props) => { initializeSelectedStates(profilingRunItems.val); }); - return div( + return () => 
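// Note: the profiling runs table now requests a minimum frame height (|| 150), renders a
// "No profiling runs found matching filters" empty state when the filtered list is empty, and
// calls resizeFrameHeightOnDOMChange so the Streamlit frame tracks re-rendered rows;
// test_runs.js further below gets the same treatment.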
getValue(profilingRunItems).length + ? div( { class: 'table', id: tableId }, () => { const items = profilingRunItems.val; @@ -146,9 +148,13 @@ const ProfilingRuns = (/** @type Properties */ props) => { 'Profiling Score', ), ), - () => div( + div( profilingRunItems.val.map(item => ProfilingRunItem(item, columns, selectedRuns[item.profiling_run_id], userCanRun, userCanEdit)), ), + ) + : div( + { class: 'pt-7 text-secondary', style: 'text-align: center;' }, + 'No profiling runs found matching filters', ); } diff --git a/testgen/ui/components/frontend/js/pages/project_dashboard.js b/testgen/ui/components/frontend/js/pages/project_dashboard.js index 1ae77b42..d2394b03 100644 --- a/testgen/ui/components/frontend/js/pages/project_dashboard.js +++ b/testgen/ui/components/frontend/js/pages/project_dashboard.js @@ -1,26 +1,7 @@ /** - * @typedef ProjectSummary - * @type {object} - * @property {string} project_code - * @property {number} test_runs_count - * @property {number} profiling_runs_count - * @property {number} connections_count - * @property {string} default_connection_id - * - * @typedef TestSuiteSummary - * @type {object} - * @property {string} id - * @property {string} test_suite - * @property {number} test_ct - * @property {number?} latest_run_start - * @property {string?} latest_run_id - * @property {number} last_run_test_ct - * @property {number} last_run_passed_ct - * @property {number} last_run_warning_ct - * @property {number} last_run_failed_ct - * @property {number} last_run_error_ct - * @property {number} last_run_dismissed_ct - * + * @import { ProjectSummary } from '../types.js'; + * @import { TestSuiteSummary } from '../types.js'; + * * @typedef TableGroupSummary * @type {object} * @property {string} id @@ -48,7 +29,7 @@ * * @typedef Properties * @type {object} - * @property {ProjectSummary} project + * @property {ProjectSummary} project_summary * @property {TableGroupSummary[]} table_groups * @property {SortOption[]} table_groups_sort_options */ @@ -129,11 +110,16 @@ const ProjectDashboard = (/** @type Properties */ props) => { ) : '', () => getValue(tableGroups).length - ? div( - { class: 'flex-column mt-4' }, - getValue(filteredTableGroups).map(tableGroup => TableGroupCard(tableGroup)), - ) - : ConditionalEmptyState(getValue(props.project)), + ? getValue(filteredTableGroups).length + ? div( + { class: 'flex-column mt-4' }, + getValue(filteredTableGroups).map(tableGroup => TableGroupCard(tableGroup)) + ) + : div( + { class: 'mt-7 text-secondary', style: 'text-align: center;' }, + 'No table groups found matching filters', + ) + : ConditionalEmptyState(getValue(props.project_summary)), ); } @@ -173,7 +159,7 @@ const TableGroupLatestProfile = (/** @type TableGroupSummary */ tableGroup) => { ); } - const daysAgo = Math.round((new Date() - new Date(tableGroup.latest_profile_start)) / (1000 * 60 * 60 * 24)); + const daysAgo = Math.round((new Date() - new Date(tableGroup.latest_profile_start * 1000)) / (1000 * 60 * 60 * 24)); return div( div( @@ -283,7 +269,7 @@ const ConditionalEmptyState = (/** @type ProjectSummary */ project) => { }, }; - const args = project.connections_count > 0 ? forTablegroups : forConnections; + const args = project.connection_count > 0 ? 
forTablegroups : forConnections; return EmptyState({ icon: 'home', diff --git a/testgen/ui/components/frontend/js/pages/quality_dashboard.js b/testgen/ui/components/frontend/js/pages/quality_dashboard.js index e502b011..55a5e22f 100644 --- a/testgen/ui/components/frontend/js/pages/quality_dashboard.js +++ b/testgen/ui/components/frontend/js/pages/quality_dashboard.js @@ -1,13 +1,6 @@ /** * @import { Score } from '../components/score_card.js'; - * - * @typedef ProjectSummary - * @type {object} - * @property {string} project_code - * @property {number} connections_count - * @property {string} default_connection_id - * @property {number} table_groups_count - * @property {number} profiling_runs_count + * @import { ProjectSummary } from '../types.js'; * * @typedef Category * @type {object} @@ -66,20 +59,25 @@ const QualityDashboard = (/** @type {Properties} */ props) => { sortedBy, getValue(props.project_summary), ), - () => div( - { class: 'flex-row fx-flex-wrap fx-gap-4' }, - getValue(scores).map(score => ScoreCard( - score, - Link({ - label: 'View details', - right_icon: 'chevron_right', - href: 'quality-dashboard:score-details', - class: 'ml-4', - params: { definition_id: score.id }, - }), - {showHistory: true}, - )) - ), + () => getValue(scores).length + ? div( + { class: 'flex-row fx-flex-wrap fx-gap-4' }, + getValue(scores).map(score => ScoreCard( + score, + Link({ + label: 'View details', + right_icon: 'chevron_right', + href: 'quality-dashboard:score-details', + class: 'ml-4', + params: { definition_id: score.id }, + }), + {showHistory: true}, + )) + ) + : div( + { class: 'mt-7 text-secondary', style: 'text-align: center;' }, + 'No scorecards found matching filters', + ), ) : ConditionalEmptyState(getValue(props.project_summary)), ); }; @@ -91,7 +89,7 @@ const Toolbar = ( /** @type ProjectSummary */ projectSummary ) => { const sortOptions = [ - { label: "Score Name", value: "name" }, + { label: "Scorecard Name", value: "name" }, { label: "Lowest Score", value: "score" }, ]; @@ -103,7 +101,7 @@ const Toolbar = ( style: 'font-size: 14px; margin-right: 16px;', icon: 'search', clearable: true, - placeholder: 'Search scores', + placeholder: 'Search scorecards', value: filterBy, onChange: options?.onsearch, testId: 'scorecards-filter', @@ -153,7 +151,7 @@ const ConditionalEmptyState = (/** @type ProjectSummary */ projectSummary) => { }, }; - if (projectSummary.connections_count <= 0) { + if (projectSummary.connection_count <= 0) { args = { message: EMPTY_STATE_MESSAGE.connection, link: { @@ -162,9 +160,9 @@ const ConditionalEmptyState = (/** @type ProjectSummary */ projectSummary) => { params: { project_code: projectSummary.project_code }, }, }; - } else if (projectSummary.profiling_runs_count <= 0) { + } else if (projectSummary.profiling_run_count <= 0) { args = { - message: projectSummary.table_groups_count ? EMPTY_STATE_MESSAGE.profiling : EMPTY_STATE_MESSAGE.tableGroup, + message: projectSummary.table_group_count ? 
EMPTY_STATE_MESSAGE.profiling : EMPTY_STATE_MESSAGE.tableGroup, link: { label: 'Go to Table Groups', href: 'table-groups', diff --git a/testgen/ui/components/frontend/js/pages/table_group_list.js b/testgen/ui/components/frontend/js/pages/table_group_list.js index 1037c39c..059a5746 100644 --- a/testgen/ui/components/frontend/js/pages/table_group_list.js +++ b/testgen/ui/components/frontend/js/pages/table_group_list.js @@ -1,4 +1,5 @@ /** + * @import { ProjectSummary } from '../types.js'; * @import { TableGroup } from '../components/table_group_form.js'; * @import { Connection } from '../components/connection_form.js'; * @@ -8,7 +9,7 @@ * * @typedef Properties * @type {object} - * @property {string} project_code + * @property {ProjectSummary} project_summary * @property {string?} connection_id * @property {Connection[]} connections * @property {TableGroup[]} table_groups @@ -43,12 +44,13 @@ const TableGroupList = (props) => { resizeFrameHeightOnDOMChange(wrapperId); return div( - { id: wrapperId, style: 'overflow-y: auto;' }, + { id: wrapperId, class: 'tg-tablegroups' }, () => { const permissions = getValue(props.permissions) ?? {can_edit: false}; const connections = getValue(props.connections) ?? []; const connectionId = getValue(props.connection_id); const tableGroups = getValue(props.table_groups) ?? []; + const projectSummary = getValue(props.project_summary); if (connections.length <= 0) { return EmptyState({ @@ -58,16 +60,17 @@ const TableGroupList = (props) => { link: { label: 'Go to Connections', href: 'connections', - params: { project_code: getValue(props.project_code) }, + params: { project_code: projectSummary.project_code }, disabled: !permissions.can_edit, }, }); } - return tableGroups.length > 0 + return projectSummary.table_group_count > 0 ? div( Toolbar(permissions, connections, connectionId), - tableGroups.map((tableGroup) => Card({ + tableGroups.length + ? tableGroups.map((tableGroup) => Card({ testId: 'table-group-card', class: '', title: div( @@ -89,7 +92,7 @@ const TableGroupList = (props) => { Link({ label: 'View test suites', href: 'test-suites', - params: { 'project_code': getValue(props.project_code), 'table_group_id': tableGroup.id }, + params: { 'project_code': projectSummary.project_code, 'table_group_id': tableGroup.id }, right_icon: 'chevron_right', right_icon_size: 20, }), @@ -179,6 +182,10 @@ const TableGroupList = (props) => { ) : undefined, })) + : div( + { class: 'mt-7 text-secondary', style: 'text-align: center;' }, + 'No table groups found matching filters', + ), ) : EmptyState({ icon: 'table_view', @@ -279,6 +286,11 @@ const TruncatedText = ({ max, ...options }, ...children) => { const stylesheet = new CSSStyleSheet(); stylesheet.replace(` +.tg-tablegroups { + overflow-y: auto; + min-height: 400px; +} + .tg-tablegroup--card-title h4 { margin: 0; color: var(--primary-text-color); diff --git a/testgen/ui/components/frontend/js/pages/table_group_wizard.js b/testgen/ui/components/frontend/js/pages/table_group_wizard.js index a88829d3..916506ad 100644 --- a/testgen/ui/components/frontend/js/pages/table_group_wizard.js +++ b/testgen/ui/components/frontend/js/pages/table_group_wizard.js @@ -143,12 +143,22 @@ const TableGroupWizard = (props) => { stepsValidity.testTableGroup.val = false; stepsState.testTableGroup.val = false; - emitEvent('PreviewTableGroupClicked', { payload: tableGroup }); + emitEvent('PreviewTableGroupClicked', { payload: {table_group: tableGroup} }); } return TableGroupTest( tableGroup.table_group_schema ?? 
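// Note: the PreviewTableGroupClicked payload becomes an object ({table_group: ...} with an
// optional verify_access: true) instead of the bare table group, presumably so the handler can
// distinguish a plain schema preview from the explicit access check behind the new
// Verify Access button.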
'--', props.table_group_preview, + { + onVerifyAcess: () => { + emitEvent('PreviewTableGroupClicked', { + payload: { + table_group: stepsState.tableGroup.rawVal, + verify_access: true, + }, + }); + } + } ); }), () => { diff --git a/testgen/ui/components/frontend/js/pages/test_runs.js b/testgen/ui/components/frontend/js/pages/test_runs.js index 0159b0cd..f5376560 100644 --- a/testgen/ui/components/frontend/js/pages/test_runs.js +++ b/testgen/ui/components/frontend/js/pages/test_runs.js @@ -33,7 +33,7 @@ import { SummaryBar } from '../components/summary_bar.js'; import { Link } from '../components/link.js'; import { Button } from '../components/button.js'; import { Streamlit } from '../streamlit.js'; -import { emitEvent, getValue, resizeFrameHeightToElement } from '../utils.js'; +import { emitEvent, getValue, resizeFrameHeightToElement, resizeFrameHeightOnDOMChange } from '../utils.js'; import { formatTimestamp, formatDuration } from '../display_utils.js'; import { Checkbox } from '../components/checkbox.js'; @@ -47,7 +47,7 @@ const TestRuns = (/** @type Properties */ props) => { try { items = JSON.parse(props.items?.val); } catch { } - Streamlit.setFrameHeight(100 * items.length); + Streamlit.setFrameHeight(100 * items.length || 150); return items; }); const columns = ['5%', '28%', '17%', '40%', '10%']; @@ -58,6 +58,7 @@ const TestRuns = (/** @type Properties */ props) => { const tableId = 'test-runs-table'; resizeFrameHeightToElement(tableId); + resizeFrameHeightOnDOMChange(tableId); const initializeSelectedStates = (items) => { for (const testRun of items) { @@ -73,7 +74,8 @@ const TestRuns = (/** @type Properties */ props) => { initializeSelectedStates(testRunItems.val); }); - return div( + return () => getValue(testRunItems).length + ? div( { class: 'table', id: tableId }, () => { const items = testRunItems.val; @@ -142,9 +144,13 @@ const TestRuns = (/** @type Properties */ props) => { 'Testing Score', ), ), - () => div( + div( testRunItems.val.map(item => TestRunItem(item, columns, selectedRuns[item.test_run_id], userCanRun, userCanEdit)), ), + ) + : div( + { class: 'pt-7 text-secondary', style: 'text-align: center;' }, + 'No test runs found matching filters', ); } diff --git a/testgen/ui/components/frontend/js/pages/test_suites.js b/testgen/ui/components/frontend/js/pages/test_suites.js index 4aba36ce..f38faed3 100644 --- a/testgen/ui/components/frontend/js/pages/test_suites.js +++ b/testgen/ui/components/frontend/js/pages/test_suites.js @@ -1,12 +1,6 @@ /** - * @typedef ProjectSummary - * @type {object} - * @property {string} project_code - * @property {number} test_suites_ct - * @property {number} connections_ct - * @property {number} table_groups_ct - * @property {string} default_connection_id - * @property {boolean} can_export_to_observability + * @import { ProjectSummary } from '../types.js'; + * @import { TestSuiteSummary } from '../types.js'; * * @typedef TableGroupOption * @type {object} @@ -14,24 +8,6 @@ * @property {string} name * @property {boolean} selected * - * @typedef TestSuite - * @type {object} - * @property {string} id - * @property {string} connection_name - * @property {string} table_groups_name - * @property {string} test_suite - * @property {string} test_suite_description - * @property {number} test_ct - * @property {string} latest_run_start - * @property {string} latest_run_id - * @property {number} last_run_test_ct - * @property {number} last_run_passed_ct - * @property {number} last_run_warning_ct - * @property {number} last_run_failed_ct - * @property 
{number} last_run_error_ct - * @property {number} last_run_dismissed_ct - * @property {string} last_complete_profile_run_id - * * @typedef Permissions * @type {object} * @property {boolean} can_edit @@ -39,7 +15,7 @@ * @typedef Properties * @type {object} * @property {ProjectSummary} project_summary - * @property {TestSuite} test_suites + * @property {TestSuiteSummary} test_suites * @property {TableGroupOption[]} table_group_filter_options * @property {Permissions} permissions */ @@ -73,7 +49,7 @@ const TestSuites = (/** @type Properties */ props) => { { id: wrapperId, style: 'overflow-y: auto;' }, () => { const projectSummary = getValue(props.project_summary); - return projectSummary.test_suites_ct > 0 + return projectSummary.test_suite_count > 0 ? div( { class: 'tg-test-suites'}, () => div( @@ -112,9 +88,10 @@ const TestSuites = (/** @type Properties */ props) => { : '', ), ), - () => div( + () => getValue(testSuites)?.length + ? div( { class: 'flex-column' }, - getValue(testSuites).map((/** @type TestSuite */ testSuite) => Card({ + getValue(testSuites).map((/** @type TestSuiteSummary */ testSuite) => Card({ border: true, testId: 'test-suite-card', title: () => div( @@ -129,11 +106,13 @@ const TestSuites = (/** @type Properties */ props) => { Button({ type: 'icon', icon: 'output', - tooltip: projectSummary.can_export_to_observability - ? 'Export results to Observability' - : 'Observability export not configured in Project Settings', + tooltip: !projectSummary.can_export_to_observability + ? 'Observability export not configured in Project Settings' + : !testSuite.export_to_observability + ? 'Observability export not configured for test suite' + : 'Export results to Observability', tooltipPosition: 'left', - disabled: !projectSummary.can_export_to_observability, + disabled: !projectSummary.can_export_to_observability || !testSuite.export_to_observability, onclick: () => emitEvent('ExportActionClicked', {payload: testSuite.id}), }), Button({ @@ -217,6 +196,10 @@ const TestSuites = (/** @type Properties */ props) => { ), ), })), + ) + : div( + { class: 'mt-7 text-secondary', style: 'text-align: center;' }, + 'No test suites found matching filters', ), ) : ConditionalEmptyState(projectSummary, userCanEdit); @@ -244,7 +227,7 @@ const ConditionalEmptyState = ( }), }; - if (projectSummary.connections_ct <= 0) { + if (projectSummary.connection_count <= 0) { args = { message: EMPTY_STATE_MESSAGE.connection, link: { @@ -253,7 +236,7 @@ const ConditionalEmptyState = ( params: { project_code: projectSummary.project_code }, }, }; - } else if (projectSummary.table_groups_ct <= 0) { + } else if (projectSummary.table_group_count <= 0) { args = { message: EMPTY_STATE_MESSAGE.tableGroup, link: { diff --git a/testgen/ui/components/frontend/js/types.js b/testgen/ui/components/frontend/js/types.js new file mode 100644 index 00000000..b5066483 --- /dev/null +++ b/testgen/ui/components/frontend/js/types.js @@ -0,0 +1,34 @@ +/** + * @typedef ProjectSummary + * @type {object} + * @property {string} project_code + * @property {number} connection_count + * @property {string} default_connection_id + * @property {number} table_group_count + * @property {number} profiling_run_count + * @property {number} test_suite_count + * @property {number} test_definition_count + * @property {number} test_run_count + * @property {bool} can_export_to_observability + * + * @typedef TestSuiteSummary + * @type {object} + * @property {string} id + * @property {string} project_code + * @property {string} test_suite + * @property 
{string} connection_name + * @property {string} table_groups_id + * @property {string} table_groups_name + * @property {string} test_suite_description + * @property {bool} export_to_observability + * @property {number} test_ct + * @property {string} last_complete_profile_run_id + * @property {string} latest_run_id + * @property {string} latest_run_start + * @property {number} last_run_test_ct + * @property {number} last_run_passed_ct + * @property {number} last_run_warning_ct + * @property {number} last_run_failed_ct + * @property {number} last_run_error_ct + * @property {number} last_run_dismissed_ct + */ diff --git a/testgen/ui/components/widgets/sidebar.py b/testgen/ui/components/widgets/sidebar.py index e1f8002e..4ee5cec8 100644 --- a/testgen/ui/components/widgets/sidebar.py +++ b/testgen/ui/components/widgets/sidebar.py @@ -1,7 +1,8 @@ import logging import time -from typing import Literal +from collections.abc import Iterable +from testgen.common.models.project import Project from testgen.common.version_service import Version from testgen.ui.components.utils.component import component from testgen.ui.navigation.menu import Menu @@ -17,7 +18,7 @@ def sidebar( key: str = SIDEBAR_KEY, - projects: list[dict[Literal["name", "codde"], str]] | None = None, + projects: Iterable[Project] | None = None, current_project: str | None = None, menu: Menu = None, current_page: str | None = None, @@ -39,7 +40,7 @@ def sidebar( component( id_="sidebar", props={ - "projects": projects, + "projects": [ {"code": item.project_code, "name": item.project_name} for item in projects ], "current_project": current_project, "menu": menu.filter_for_current_user().sort_items().unflatten().asdict(), "current_page": current_page, diff --git a/testgen/ui/navigation/page.py b/testgen/ui/navigation/page.py index 1ba4cfff..489a0fc1 100644 --- a/testgen/ui/navigation/page.py +++ b/testgen/ui/navigation/page.py @@ -8,8 +8,8 @@ from streamlit.runtime.state.query_params_proxy import QueryParamsProxy import testgen.ui.navigation.router +from testgen.common.models.project import Project from testgen.ui.navigation.menu import MenuItem -from testgen.ui.services import project_service from testgen.ui.session import session CanActivateGuard = typing.Callable[[], bool | str] @@ -34,7 +34,7 @@ def _navigate(self) -> None: for guard in self.can_activate or []: can_activate = guard() if can_activate != True: - session.sidebar_project = session.sidebar_project or project_service.get_projects()[0]["code"] + session.sidebar_project = session.sidebar_project or Project.select_where()[0].project_code if type(can_activate) == str: return self.router.navigate(to=can_activate, with_args={ "project_code": session.sidebar_project }) diff --git a/testgen/ui/navigation/router.py b/testgen/ui/navigation/router.py index 4aae9eee..c754aab6 100644 --- a/testgen/ui/navigation/router.py +++ b/testgen/ui/navigation/router.py @@ -7,6 +7,7 @@ import testgen.ui.navigation.page from testgen.common.mixpanel_service import MixpanelService +from testgen.common.models.project import Project from testgen.ui.session import session from testgen.utils.singleton import Singleton @@ -114,8 +115,8 @@ def navigate(self, /, to: str, with_args: dict = {}) -> None: # noqa: B006 def navigate_with_warning(self, warning: str, to: str, with_args: dict = {}) -> None: # noqa: B006 st.warning(warning) time.sleep(3) - self.navigate(to, with_args) - + session.sidebar_project = session.sidebar_project or Project.select_where()[0].project_code + self.navigate(to, {"project_code": 
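# Note: navigate_with_warning now mirrors the fallback added in navigation/page.py above: if no
# sidebar project is set it takes the first Project returned by Project.select_where() and always
# forwards a project_code with the navigation args, replacing the removed project_service helper.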
session.sidebar_project, **with_args}) def set_query_params(self, with_args: dict) -> None: params = st.query_params diff --git a/testgen/ui/pdf/hygiene_issue_report.py b/testgen/ui/pdf/hygiene_issue_report.py index 31844a78..de5addb6 100644 --- a/testgen/ui/pdf/hygiene_issue_report.py +++ b/testgen/ui/pdf/hygiene_issue_report.py @@ -21,7 +21,7 @@ get_formatted_datetime, ) from testgen.ui.pdf.templates import DatakitchenTemplate -from testgen.ui.services.hygiene_issues_service import get_source_data +from testgen.ui.queries.source_data_queries import get_hygiene_issue_source_data from testgen.utils import get_base_url SECTION_MIN_AVAILABLE_HEIGHT = 120 @@ -186,7 +186,7 @@ def get_report_content(document, hi_data): yield Paragraph("Suggested Action", style=PARA_STYLE_H1) yield Paragraph(hi_data["suggested_action"], style=PARA_STYLE_TEXT) - sample_data_tuple = get_source_data(hi_data, limit=ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT) + sample_data_tuple = get_hygiene_issue_source_data(hi_data, limit=ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT) yield CondPageBreak(SECTION_MIN_AVAILABLE_HEIGHT) yield Paragraph("Sample Data", PARA_STYLE_H1) diff --git a/testgen/ui/pdf/test_result_report.py b/testgen/ui/pdf/test_result_report.py index f583c71e..c79b9be2 100644 --- a/testgen/ui/pdf/test_result_report.py +++ b/testgen/ui/pdf/test_result_report.py @@ -27,10 +27,8 @@ get_formatted_datetime, ) from testgen.ui.pdf.templates import DatakitchenTemplate -from testgen.ui.services.database_service import get_schema -from testgen.ui.services.test_results_service import ( - do_source_data_lookup, - do_source_data_lookup_custom, +from testgen.ui.queries.source_data_queries import get_test_issue_source_data, get_test_issue_source_data_custom +from testgen.ui.queries.test_result_queries import ( get_test_result_history, ) from testgen.utils import get_base_url @@ -104,7 +102,7 @@ def build_summary_table(document, tr_data): parent=TABLE_STYLE_DEFAULT, ) - test_timestamp = get_formatted_datetime(tr_data["test_time"]) + test_timestamp = get_formatted_datetime(tr_data["test_date"]) summary_table_data = [ ( "Test", @@ -166,7 +164,7 @@ def build_summary_table(document, tr_data): def build_history_table(document, tr_data): - history_data = get_test_result_history(get_schema(), tr_data, limit=15) + history_data = get_test_result_history(tr_data, limit=15) history_table_style = TableStyle( ( @@ -174,7 +172,7 @@ def build_history_table(document, tr_data): ), parent=TABLE_STYLE_DATA) - test_timestamp = pandas.to_datetime(tr_data["test_time"]) + test_timestamp = pandas.to_datetime(tr_data["test_date"]) style_per_status = { status: ParagraphStyle(f"result_{status}", parent=PARA_STYLE_CELL, textColor=color) @@ -243,17 +241,9 @@ def get_report_content(document, tr_data): yield build_history_table(document, tr_data) if tr_data["test_type"] == "CUSTOM": - sample_data_tuple = do_source_data_lookup_custom( - get_schema(), - tr_data, - limit=ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT, - ) + sample_data_tuple = get_test_issue_source_data_custom(tr_data, limit=ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT) else: - sample_data_tuple = do_source_data_lookup( - get_schema(), - tr_data, - limit=ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT, - ) + sample_data_tuple = get_test_issue_source_data(tr_data, limit=ISSUE_REPORT_SOURCE_DATA_LOOKUP_LIMIT) yield CondPageBreak(SECTION_MIN_AVAILABLE_HEIGHT) yield Paragraph("Sample Data", PARA_STYLE_H1) diff --git a/testgen/ui/queries/connection_queries.py b/testgen/ui/queries/connection_queries.py deleted file mode 100644 index 
fc86517f..00000000 --- a/testgen/ui/queries/connection_queries.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import cast - -import pandas as pd -import streamlit as st - -import testgen.ui.services.database_service as db - - -def get_by_id(connection_id): - str_schema = st.session_state["dbschema"] - str_sql = f""" - SELECT id::VARCHAR(50), project_code, connection_id, connection_name, - sql_flavor, COALESCE(sql_flavor_code, sql_flavor) AS sql_flavor_code, - project_host, project_port, project_user, - project_db, project_pw_encrypted, NULL as password, - max_threads, max_query_chars, url, connect_by_url, connect_by_key, private_key, - private_key_passphrase, http_path - FROM {str_schema}.connections - WHERE connection_id = '{connection_id}' - """ - return db.retrieve_data(str_sql) - - -def get_connections(project_code): - str_schema = st.session_state["dbschema"] - str_sql = f""" - SELECT id::VARCHAR(50), project_code, connection_id, connection_name, - sql_flavor, COALESCE(sql_flavor_code, sql_flavor) AS sql_flavor_code, - project_host, project_port, project_user, - project_db, project_pw_encrypted, NULL as password, - max_threads, max_query_chars, connect_by_url, url, connect_by_key, private_key, - private_key_passphrase, http_path - FROM {str_schema}.connections - WHERE project_code = '{project_code}' - ORDER BY connection_id - """ - return db.retrieve_data(str_sql) - - -def get_table_group_names_by_connection(schema: str, connection_ids: list[str]) -> pd.DataFrame: - items = [f"'{item}'" for item in connection_ids] - str_sql = f"""select table_groups_name from {schema}.table_groups where connection_id in ({",".join(items)})""" - return db.retrieve_data(str_sql) - - -def edit_connection(schema, connection, encrypted_password, encrypted_private_key, encrypted_private_key_passphrase): - encrypted_password_value = f"'{encrypted_password}'" if encrypted_password is not None else "null" - encrypted_private_key_value = f"'{encrypted_private_key}'" if encrypted_private_key is not None else "null" - encrypted_passphrase_value = f"'{encrypted_private_key_passphrase}'" if encrypted_private_key_passphrase is not None else "null" - - sql = f""" - UPDATE {schema}.connections - SET - project_code = '{connection["project_code"]}', - sql_flavor = '{connection["sql_flavor"]}', - sql_flavor_code = '{connection["sql_flavor_code"]}', - project_host = '{connection["project_host"]}', - project_port = '{connection["project_port"]}', - project_user = '{connection["project_user"]}', - project_db = '{connection["project_db"]}', - connection_name = '{connection["connection_name"]}', - max_threads = '{connection["max_threads"]}', - max_query_chars = '{connection["max_query_chars"]}', - url = '{connection["url"]}', - connect_by_key = '{connection["connect_by_key"]}', - connect_by_url = '{connection["connect_by_url"]}', - http_path = '{connection["http_path"]}', - project_pw_encrypted = {encrypted_password_value}, - private_key = {encrypted_private_key_value}, - private_key_passphrase = {encrypted_passphrase_value} - WHERE connection_id = '{connection["connection_id"]}'; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def add_connection( - schema: str, - connection: dict, - encrypted_password: str | None, - encrypted_private_key: str | None, - encrypted_private_key_passphrase: str | None, -) -> int: - sql_header = f"""INSERT INTO {schema}.connections - (project_code, sql_flavor, sql_flavor_code, url, connect_by_url, connect_by_key, - project_host, project_port, project_user, project_db, - connection_name, 
http_path, """ - - sql_footer = f""" SELECT - '{connection["project_code"]}' as project_code, - '{connection["sql_flavor"]}' as sql_flavor, - '{connection["sql_flavor_code"]}' as sql_flavor_code, - '{connection["url"]}' as url, - {connection["connect_by_url"]} as connect_by_url, - {connection["connect_by_key"]} as connect_by_key, - '{connection["project_host"]}' as project_host, - '{connection["project_port"]}' as project_port, - '{connection["project_user"]}' as project_user, - '{connection["project_db"]}' as project_db, - '{connection["connection_name"]}' as connection_name, - '{connection["http_path"]}' as http_path, """ - - if encrypted_password: - sql_header += "project_pw_encrypted, " - sql_footer += f""" '{encrypted_password}' as project_pw_encrypted, """ - - if encrypted_private_key: - sql_header += "private_key, " - sql_footer += f""" '{encrypted_private_key}' as private_key, """ - - if encrypted_private_key_passphrase: - sql_header += "private_key_passphrase, " - sql_footer += f""" '{encrypted_private_key_passphrase}' as private_key_passphrase, """ - - sql_header += """max_threads, max_query_chars) """ - - sql_footer += f""" '{connection["max_threads"]}' as max_threads, - '{connection["max_query_chars"]}' as max_query_chars""" - - sql = sql_header + sql_footer + " RETURNING connection_id" - - cursor = db.execute_sql(sql) - st.cache_data.clear() - if cursor and (primary_key := cast(tuple, cursor.fetchone())): - return primary_key[0] - - return 0 - - -def delete_connections(schema, connection_ids): - if connection_ids is None or len(connection_ids) == 0: - raise ValueError("No connection is specified.") - - items = [f"'{item}'" for item in connection_ids] - sql = f"""DELETE FROM {schema}.connections WHERE connection_id in ({",".join(items)})""" - db.execute_sql(sql) - st.cache_data.clear() diff --git a/testgen/ui/queries/profiling_queries.py b/testgen/ui/queries/profiling_queries.py index db755ab9..01d6373a 100644 --- a/testgen/ui/queries/profiling_queries.py +++ b/testgen/ui/queries/profiling_queries.py @@ -1,12 +1,7 @@ -import json -from datetime import datetime -from typing import NamedTuple - import pandas as pd import streamlit as st -import testgen.ui.services.database_service as db -from testgen.common.models import get_current_session +from testgen.ui.services.database_service import fetch_all_from_db, fetch_df_from_db from testgen.utils import is_uuid4 TAG_FIELDS = [ @@ -72,36 +67,9 @@ boolean_true_ct """ -@st.cache_data(show_spinner="Loading data ...") -def get_run_by_id(profile_run_id: str) -> pd.Series: - if not is_uuid4(profile_run_id): - return pd.Series() - - schema: str = st.session_state["dbschema"] - sql = f""" - SELECT profiling_starttime, table_groups_id::VARCHAR, table_groups_name, pr.project_code, pr.dq_score_profiling, - CASE WHEN pr.id = tg.last_complete_profile_run_id THEN true ELSE false END AS is_latest_run - FROM {schema}.profiling_runs pr - INNER JOIN {schema}.table_groups tg - ON pr.table_groups_id = tg.id - WHERE pr.id = '{profile_run_id}' - """ - df = db.retrieve_data(sql) - if not df.empty: - return df.iloc[0] - else: - return pd.Series() - @st.cache_data(show_spinner=False) -def get_profiling_results(profiling_run_id: str, table_name: str | None = None, column_name: str | None = None, sorting_columns = None): - db_session = get_current_session() - params = { - "profiling_run_id": profiling_run_id, - "table_name": table_name if table_name else "%%", - "column_name": column_name if column_name else "%%", - } - +def 
get_profiling_results(profiling_run_id: str, table_name: str | None = None, column_name: str | None = None, sorting_columns = None) -> pd.DataFrame: order_by = "" if sorting_columns is None: order_by = "ORDER BY schema_name, table_name, position" @@ -142,11 +110,13 @@ def get_profiling_results(profiling_run_id: str, table_name: str | None = None, AND column_name ILIKE :column_name {order_by}; """ + params = { + "profiling_run_id": profiling_run_id, + "table_name": table_name or "%%", + "column_name": column_name or "%%", + } - results = db_session.execute(query, params=params) - columns = [column.name for column in results.cursor.description] - - return pd.DataFrame(list(results), columns=columns) + return fetch_df_from_db(query, params) @st.cache_data(show_spinner=False) @@ -160,8 +130,9 @@ def get_table_by_id( if not is_uuid4(table_id): return None - condition = f"WHERE table_id = '{table_id}'" - return get_tables_by_condition(condition, include_tags, include_has_test_runs, include_active_tests, include_scores)[0] + condition = "WHERE table_id = :table_id" + params = {"table_id": table_id} + return get_tables_by_condition(condition, params, include_tags, include_has_test_runs, include_active_tests, include_scores)[0] def get_tables_by_id( @@ -170,12 +141,13 @@ def get_tables_by_id( include_has_test_runs: bool = False, include_active_tests: bool = False, include_scores: bool = False, -) -> list[dict] | None: - condition = f""" +) -> list[dict]: + condition = """ INNER JOIN ( - SELECT UNNEST(ARRAY [{", ".join([ f"'{col}'" for col in table_ids if is_uuid4(col) ])}]) AS id + SELECT UNNEST(ARRAY [:table_ids]) AS id ) selected ON (table_chars.table_id = selected.id::UUID)""" - return get_tables_by_condition(condition, include_tags, include_has_test_runs, include_active_tests, include_scores) + params = {"table_ids": table_ids} + return get_tables_by_condition(condition, params, include_tags, include_has_test_runs, include_active_tests, include_scores) def get_tables_by_table_group( @@ -184,31 +156,32 @@ def get_tables_by_table_group( include_has_test_runs: bool = False, include_active_tests: bool = False, include_scores: bool = False, -) -> list[dict] | None: +) -> list[dict]: if not is_uuid4(table_group_id): return None - condition = f"WHERE table_chars.table_groups_id = '{table_group_id}'" - return get_tables_by_condition(condition, include_tags, include_has_test_runs, include_active_tests, include_scores) + condition = "WHERE table_chars.table_groups_id = :table_group_id" + params = {"table_group_id": table_group_id} + return get_tables_by_condition(condition, params, include_tags, include_has_test_runs, include_active_tests, include_scores) def get_tables_by_condition( filter_condition: str, + filter_params: dict, include_tags: bool = False, include_has_test_runs: bool = False, include_active_tests: bool = False, include_scores: bool = False, -) -> list[dict] | None: - schema: str = st.session_state["dbschema"] +) -> list[dict]: query = f""" - {f""" + {""" WITH active_test_definitions AS ( SELECT test_defs.table_groups_id, test_defs.table_name, COUNT(*) AS count - FROM {schema}.test_definitions test_defs - LEFT JOIN {schema}.data_column_chars ON ( + FROM test_definitions test_defs + LEFT JOIN data_column_chars ON ( test_defs.table_groups_id = data_column_chars.table_groups_id AND test_defs.table_name = data_column_chars.table_name AND test_defs.column_name = data_column_chars.column_name @@ -241,11 +214,11 @@ def get_tables_by_condition( -- Table Groups Tags {", ".join([ 
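Note on the get_profiling_results() refactor above: the optional filters now travel as bind parameters, with `or "%%"` falling back to a wildcard so the ILIKE comparisons match everything when no table/column filter is given. A rough usage sketch (the run id is a made-up placeholder):

    from testgen.ui.queries.profiling_queries import get_profiling_results

    profiling_df = get_profiling_results(
        profiling_run_id="4f2a7e1c-9d3b-4c8a-8f6e-2b1d5a7c9e0f",  # hypothetical profiling run UUID
        table_name=None,          # falls back to the "%%" bind value, so the ILIKE filter matches all tables
        column_name="customer%",  # ILIKE pattern, sent as a bind parameter rather than interpolated
    )
    print(profiling_df.shape)     # pandas DataFrame built by fetch_df_from_db
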
f"table_groups.{tag} AS table_group_{tag}" for tag in TAG_FIELDS if tag != "aggregation_level" ])}, """ if include_tags else ""} - {f""" + {""" -- Has Test Runs EXISTS( SELECT 1 - FROM {schema}.test_results + FROM test_results WHERE table_groups_id = table_chars.table_groups_id AND table_name = table_chars.table_name ) AS has_test_runs, @@ -263,12 +236,12 @@ def get_tables_by_condition( table_chars.last_complete_profile_run_id::VARCHAR AS profile_run_id, profiling_starttime AS profile_run_date, TRUE AS is_latest_profile - FROM {schema}.data_table_chars table_chars - LEFT JOIN {schema}.profiling_runs ON ( + FROM data_table_chars table_chars + LEFT JOIN profiling_runs ON ( table_chars.last_complete_profile_run_id = profiling_runs.id ) - {f""" - LEFT JOIN {schema}.table_groups ON ( + {""" + LEFT JOIN table_groups ON ( table_chars.table_groups_id = table_groups.id ) """ if include_tags else ""} @@ -282,10 +255,8 @@ def get_tables_by_condition( ORDER BY table_name; """ - results = db.retrieve_data(query) - if not results.empty: - # to_json converts datetimes, NaN, etc, to JSON-safe values (Note: to_dict does not) - return json.loads(results.to_json(orient="records")) + results = fetch_all_from_db(query, filter_params) + return [ dict(row) for row in results ] @st.cache_data(show_spinner=False) @@ -299,8 +270,9 @@ def get_column_by_id( if not is_uuid4(column_id): return None - condition = f"WHERE column_chars.column_id = '{column_id}'" - return get_columns_by_condition(condition, include_tags, include_has_test_runs, include_active_tests, include_scores)[0] + condition = "WHERE column_chars.column_id = :column_id" + params = {"column_id": column_id} + return get_columns_by_condition(condition, params, include_tags, include_has_test_runs, include_active_tests, include_scores)[0] @st.cache_data(show_spinner="Loading data ...") @@ -313,13 +285,17 @@ def get_column_by_name( include_active_tests: bool = False, include_scores: bool = False, ) -> dict | None: - - condition = f""" - WHERE column_chars.column_name = '{column_name}' - AND column_chars.table_name = '{table_name}' - AND column_chars.table_groups_id = '{table_group_id}' + condition = """ + WHERE column_chars.column_name = :column_name + AND column_chars.table_name = :table_name + AND column_chars.table_groups_id = :table_group_id """ - return get_columns_by_condition(condition, include_tags, include_has_test_runs, include_active_tests, include_scores)[0] + params = { + "column_name": column_name, + "table_name": table_name, + "table_group_id": table_group_id, + } + return get_columns_by_condition(condition, params, include_tags, include_has_test_runs, include_active_tests, include_scores)[0] def get_columns_by_id( @@ -328,12 +304,13 @@ def get_columns_by_id( include_has_test_runs: bool = False, include_active_tests: bool = False, include_scores: bool = False, -) -> list[dict] | None: - condition = f""" +) -> list[dict]: + condition = """ INNER JOIN ( - SELECT UNNEST(ARRAY [{", ".join([ f"'{col}'" for col in column_ids if is_uuid4(col) ])}]) AS id + SELECT UNNEST(ARRAY [:column_ids]) AS id ) selected ON (column_chars.column_id = selected.id::UUID)""" - return get_columns_by_condition(condition, include_tags, include_has_test_runs, include_active_tests, include_scores) + params = {"column_ids": [ col for col in column_ids if is_uuid4(col) ]} + return get_columns_by_condition(condition, params, include_tags, include_has_test_runs, include_active_tests, include_scores) def get_columns_by_table_group( @@ -342,23 +319,23 @@ def 
get_columns_by_table_group( include_has_test_runs: bool = False, include_active_tests: bool = False, include_scores: bool = False, -) -> list[dict] | None: +) -> list[dict]: if not is_uuid4(table_group_id): return None - condition = f"WHERE column_chars.table_groups_id = '{table_group_id}'" - return get_columns_by_condition(condition, include_tags, include_has_test_runs, include_active_tests, include_scores) + condition = "WHERE column_chars.table_groups_id = :table_group_id" + params = {"table_group_id": table_group_id} + return get_columns_by_condition(condition, params, include_tags, include_has_test_runs, include_active_tests, include_scores) def get_columns_by_condition( filter_condition: str, + filter_params: dict, include_tags: bool = False, include_has_test_runs: bool = False, include_active_tests: bool = False, include_scores: bool = False, -) -> list[dict] | None: - schema: str = st.session_state["dbschema"] - +) -> list[dict]: query = f""" SELECT column_chars.column_id::VARCHAR AS id, @@ -391,21 +368,21 @@ def get_columns_by_condition( column_chars.last_complete_profile_run_id::VARCHAR AS profile_run_id, run_date AS profile_run_date, TRUE AS is_latest_profile, - {f""" + {""" -- Has Test Runs EXISTS( SELECT 1 - FROM {schema}.test_results + FROM test_results WHERE table_groups_id = column_chars.table_groups_id AND table_name = column_chars.table_name AND column_names = column_chars.column_name ) AS has_test_runs, """ if include_has_test_runs else ""} - {f""" + {""" -- Test Definition Count ( SELECT COUNT(*) - FROM {schema}.test_definitions + FROM test_definitions WHERE table_groups_id = column_chars.table_groups_id AND table_name = column_chars.table_name AND column_name = column_chars.column_name @@ -418,16 +395,16 @@ def get_columns_by_condition( column_chars.dq_score_testing, """ if include_scores else ""} {COLUMN_PROFILING_FIELDS} - FROM {schema}.data_column_chars column_chars - {f""" - LEFT JOIN {schema}.data_table_chars table_chars ON ( + FROM data_column_chars column_chars + {""" + LEFT JOIN data_table_chars table_chars ON ( column_chars.table_id = table_chars.table_id ) - LEFT JOIN {schema}.table_groups ON ( + LEFT JOIN table_groups ON ( column_chars.table_groups_id = table_groups.id ) """ if include_tags else ""} - LEFT JOIN {schema}.profile_results ON ( + LEFT JOIN profile_results ON ( column_chars.last_complete_profile_run_id = profile_results.profile_run_id AND column_chars.table_name = profile_results.table_name AND column_chars.column_name = profile_results.column_name @@ -435,11 +412,8 @@ def get_columns_by_condition( {filter_condition} ORDER BY table_name, ordinal_position; """ - - results = db.retrieve_data(query) - if not results.empty: - # to_json converts datetimes, NaN, etc, to JSON-safe values (Note: to_dict does not) - return json.loads(results.to_json(orient="records")) + results = fetch_all_from_db(query, filter_params) + return [ dict(row) for row in results ] @st.cache_data(show_spinner=False) @@ -447,12 +421,6 @@ def get_hygiene_issues(profile_run_id: str, table_name: str, column_name: str | if not profile_run_id: return [] - schema: str = st.session_state["dbschema"] - - column_condition = "" - if column_name: - column_condition = f"AND column_name = '{column_name}'" - query = f""" WITH pii_results AS ( SELECT id, @@ -461,23 +429,23 @@ def get_hygiene_issues(profile_run_id: str, table_name: str, column_name: str | WHEN detail LIKE 'Risk: MODERATE%%' THEN 'Moderate' ELSE null END AS pii_risk - FROM {schema}.profile_anomaly_results + FROM 
profile_anomaly_results ) SELECT column_name, anomaly_name, issue_likelihood, detail, pii_risk - FROM {schema}.profile_anomaly_results anomaly_results - LEFT JOIN {schema}.profile_anomaly_types anomaly_types ON ( + FROM profile_anomaly_results anomaly_results + LEFT JOIN profile_anomaly_types anomaly_types ON ( anomaly_types.id = anomaly_results.anomaly_id ) LEFT JOIN pii_results ON ( anomaly_results.id = pii_results.id ) - WHERE profile_run_id = '{profile_run_id}' - AND table_name = '{table_name}' - {column_condition} + WHERE profile_run_id = :profile_run_id + AND table_name = :table_name + {"AND column_name = :column_name" if column_name else ""} AND COALESCE(disposition, 'Confirmed') = 'Confirmed' ORDER BY CASE issue_likelihood @@ -493,29 +461,10 @@ def get_hygiene_issues(profile_run_id: str, table_name: str, column_name: str | END, column_name; """ - - results = db.retrieve_data(query) - return [row.to_dict() for _, row in results.iterrows()] - - -class LatestProfilingRun(NamedTuple): - id: str - run_time: datetime - - -def get_latest_run_date(project_code: str) -> LatestProfilingRun | None: - session = get_current_session() - result = session.execute( - """ - SELECT id, profiling_starttime - FROM profiling_runs - WHERE project_code = :project_code - AND status = 'Complete' - ORDER BY profiling_starttime DESC - LIMIT 1 - """, - params={"project_code": project_code}, - ) - if result and (latest_run := result.first()): - return LatestProfilingRun(str(latest_run.id), latest_run.profiling_starttime) - return None + params = { + "profile_run_id": profile_run_id, + "table_name": table_name, + "column_name": column_name, + } + results = fetch_all_from_db(query, params) + return [ dict(row) for row in results ] diff --git a/testgen/ui/queries/profiling_run_queries.py b/testgen/ui/queries/profiling_run_queries.py deleted file mode 100644 index ea40f93d..00000000 --- a/testgen/ui/queries/profiling_run_queries.py +++ /dev/null @@ -1,47 +0,0 @@ -import streamlit as st - -import testgen.ui.services.database_service as db -from testgen.common import date_service -from testgen.common.models import get_current_session - - -def update_status(profile_run_id: str, status: str) -> None: - schema: str = st.session_state["dbschema"] - now = date_service.get_now_as_string() - - sql = f""" - UPDATE {schema}.profiling_runs - SET status = '{status}', - profiling_endtime = '{now}' - WHERE id = '{profile_run_id}'::UUID; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def cancel_all_running() -> None: - schema: str = db.get_schema() - db.execute_sql(f""" - UPDATE {schema}.profiling_runs - SET status = 'Cancelled' - WHERE status = 'Running'; - """) - - -def cascade_delete_multiple_profiling_runs(profiling_run_ids: list[str]) -> None: - session = get_current_session() - - if not profiling_run_ids: - raise ValueError("No profiling run is specified.") - - params = {f"id_{idx}": value for idx, value in enumerate(profiling_run_ids)} - param_keys = [f":{slot}" for slot in params.keys()] - - with session.begin(): - session.execute(f"DELETE FROM profile_pair_rules WHERE profile_run_id IN ({', '.join(param_keys)})", params=params) - session.execute(f"DELETE FROM profile_anomaly_results WHERE profile_run_id IN ({', '.join(param_keys)})", params=params) - session.execute(f"DELETE FROM profile_results WHERE profile_run_id IN ({', '.join(param_keys)})", params=params) - session.execute(f"DELETE FROM profiling_runs WHERE id IN ({', '.join(param_keys)})", params=params) - session.commit() - - st.cache_data.clear() diff 
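Note on the refactors above: the f-string queries with interpolated {schema} prefixes and literal values are replaced by fetch_all_from_db / fetch_df_from_db / fetch_one_from_db from testgen.ui.services.database_service, which take the SQL plus a dict of ":name" bind values. Their implementation is not shown in this diff; a minimal sketch of the shape they are assumed to have, reusing the existing get_current_session() helper (and assuming the session already resolves unqualified table names to the app schema, which is why the {schema}. prefixes disappear):

    import pandas as pd

    from testgen.common.models import get_current_session  # existing TestGen session helper


    def fetch_all_from_db_sketch(query: str, params: dict | None = None) -> list:
        """Hypothetical stand-in for database_service.fetch_all_from_db."""
        # Values travel as ":name" bind parameters, never formatted into the SQL string.
        results = get_current_session().execute(query, params=params or {})
        return list(results)


    def fetch_df_from_db_sketch(query: str, params: dict | None = None) -> pd.DataFrame:
        """Hypothetical stand-in for database_service.fetch_df_from_db."""
        results = get_current_session().execute(query, params=params or {})
        columns = [column.name for column in results.cursor.description]
        return pd.DataFrame(list(results), columns=columns)
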
--git a/testgen/ui/queries/project_queries.py b/testgen/ui/queries/project_queries.py deleted file mode 100644 index 342702e9..00000000 --- a/testgen/ui/queries/project_queries.py +++ /dev/null @@ -1,64 +0,0 @@ -import pandas as pd -import streamlit as st - -import testgen.ui.services.database_service as db -import testgen.ui.services.query_service as query_service - - -@st.cache_data(show_spinner=False) -def get_projects(): - schema: str = st.session_state["dbschema"] - return query_service.run_project_lookup_query(schema) - - -@st.cache_data(show_spinner=False) -def get_summary_by_code(project_code: str) -> pd.Series: - schema: str = st.session_state["dbschema"] - sql = f""" - SELECT ( - SELECT COUNT(*) AS count - FROM {schema}.connections - WHERE connections.project_code = '{project_code}' - ) AS connections_ct, - ( - SELECT connection_id - FROM {schema}.connections - WHERE connections.project_code = '{project_code}' - LIMIT 1 - ) AS default_connection_id, - ( - SELECT COUNT(*) - FROM {schema}.table_groups - WHERE table_groups.project_code = '{project_code}' - ) AS table_groups_ct, - ( - SELECT COUNT(*) - FROM {schema}.profiling_runs - LEFT JOIN {schema}.table_groups ON profiling_runs.table_groups_id = table_groups.id - WHERE table_groups.project_code = '{project_code}' - ) AS profiling_runs_ct, - ( - SELECT COUNT(*) - FROM {schema}.test_suites - WHERE test_suites.project_code = '{project_code}' - ) AS test_suites_ct, - ( - SELECT COUNT(*) - FROM {schema}.test_definitions - LEFT JOIN {schema}.test_suites ON test_definitions.test_suite_id = test_suites.id - WHERE test_suites.project_code = '{project_code}' - ) AS test_definitions_ct, - ( - SELECT COUNT(*) - FROM {schema}.test_runs - LEFT JOIN {schema}.test_suites ON test_runs.test_suite_id = test_suites.id - WHERE test_suites.project_code = '{project_code}' - ) AS test_runs_ct, - ( - SELECT COALESCE(observability_api_key, '') <> '' - AND COALESCE(observability_api_url, '') <> '' - FROM {schema}.projects - WHERE project_code = '{project_code}' - ) AS can_export_to_observability; - """ - return db.retrieve_data(sql).iloc[0] diff --git a/testgen/ui/queries/scoring_queries.py b/testgen/ui/queries/scoring_queries.py index f6a72741..f6199b71 100644 --- a/testgen/ui/queries/scoring_queries.py +++ b/testgen/ui/queries/scoring_queries.py @@ -1,10 +1,9 @@ from collections import defaultdict -import pandas as pd import streamlit as st -from testgen.common.models import engine from testgen.common.models.scores import ScoreCard, ScoreCategory, ScoreDefinition, SelectedIssue +from testgen.ui.services.database_service import fetch_all_from_db @st.cache_data(show_spinner="Loading data :gray[:small[(This might take a few minutes)]] ...") @@ -15,17 +14,16 @@ def get_all_score_cards(project_code: str) -> list["ScoreCard"]: ] -def get_score_card_issue_reports(selected_issues: list["SelectedIssue"]): +def get_score_card_issue_reports(selected_issues: list["SelectedIssue"]) -> list[dict]: profile_ids = [] test_ids = [] for issue in selected_issues: id_list = profile_ids if issue["issue_type"] == "hygiene" else test_ids id_list.append(issue["id"]) - schema: str = st.session_state["dbschema"] results = [] if profile_ids: - profile_query = f""" + profile_query = """ SELECT results.id::VARCHAR, 'hygiene' AS issue_type, @@ -55,32 +53,32 @@ def get_score_card_issue_reports(selected_issues: list["SelectedIssue"]): COALESCE(column_chars.transform_level, table_chars.transform_level, groups.transform_level) as transform_level, COALESCE(column_chars.aggregation_level, 
table_chars.aggregation_level) as aggregation_level, COALESCE(column_chars.data_product, table_chars.data_product, groups.data_product) as data_product - FROM {schema}.profile_anomaly_results results - INNER JOIN {schema}.profile_anomaly_types types + FROM profile_anomaly_results results + INNER JOIN profile_anomaly_types types ON results.anomaly_id = types.id - INNER JOIN {schema}.profiling_runs runs + INNER JOIN profiling_runs runs ON results.profile_run_id = runs.id - INNER JOIN {schema}.table_groups groups + INNER JOIN table_groups groups ON results.table_groups_id = groups.id - LEFT JOIN {schema}.data_column_chars column_chars + LEFT JOIN data_column_chars column_chars ON (groups.id = column_chars.table_groups_id AND results.schema_name = column_chars.schema_name AND results.table_name = column_chars.table_name AND results.column_name = column_chars.column_name) - LEFT JOIN {schema}.data_table_chars table_chars + LEFT JOIN data_table_chars table_chars ON column_chars.table_id = table_chars.table_id - WHERE results.id IN ({",".join([f"'{issue_id}'" for issue_id in profile_ids])}); + WHERE results.id IN :profile_ids; """ - profile_results = pd.read_sql_query(profile_query, engine) - results.extend([row.to_dict() for _, row in profile_results.iterrows()]) + profile_results = fetch_all_from_db(profile_query, {"profile_ids": tuple(profile_ids)}) + results.extend([dict(row) for row in profile_results]) if test_ids: - test_query = f""" + test_query = """ SELECT results.id::VARCHAR AS test_result_id, 'test' AS issue_type, results.result_status, - results.test_time, + results.test_time AS test_date, types.test_name_short, types.test_name_long, results.test_description, @@ -104,6 +102,11 @@ def get_score_card_issue_reports(selected_issues: list["SelectedIssue"]): results.auto_gen, results.test_suite_id, results.test_definition_id::VARCHAR as test_definition_id_runtime, + CASE + WHEN results.auto_gen = TRUE + THEN definitions.id + ELSE results.test_definition_id + END::VARCHAR AS test_definition_id_current, results.table_groups_id::VARCHAR, types.id::VARCHAR AS test_type_id, column_chars.description as column_description, @@ -116,24 +119,31 @@ def get_score_card_issue_reports(selected_issues: list["SelectedIssue"]): COALESCE(column_chars.transform_level, table_chars.transform_level, groups.transform_level) as transform_level, COALESCE(column_chars.aggregation_level, table_chars.aggregation_level) as aggregation_level, COALESCE(column_chars.data_product, table_chars.data_product, groups.data_product) as data_product - FROM {schema}.test_results results - INNER JOIN {schema}.test_types types + FROM test_results results + INNER JOIN test_types types ON (results.test_type = types.test_type) - INNER JOIN {schema}.test_suites suites + INNER JOIN test_suites suites ON (results.test_suite_id = suites.id) - INNER JOIN {schema}.table_groups groups + INNER JOIN table_groups groups ON (results.table_groups_id = groups.id) - LEFT JOIN {schema}.data_column_chars column_chars + LEFT JOIN test_definitions definitions + ON (results.test_suite_id = definitions.test_suite_id + AND results.table_name = definitions.table_name + AND COALESCE(results.column_names, 'N/A') = COALESCE(definitions.column_name, 'N/A') + AND results.test_type = definitions.test_type + AND results.auto_gen = TRUE + AND definitions.last_auto_gen_date IS NOT NULL) + LEFT JOIN data_column_chars column_chars ON (groups.id = column_chars.table_groups_id AND results.schema_name = column_chars.schema_name AND results.table_name = 
column_chars.table_name AND results.column_names = column_chars.column_name) - LEFT JOIN {schema}.data_table_chars table_chars + LEFT JOIN data_table_chars table_chars ON column_chars.table_id = table_chars.table_id - WHERE results.id IN ({",".join([f"'{issue_id}'" for issue_id in test_ids])}); + WHERE results.id IN :test_ids; """ - test_results = pd.read_sql_query(test_query, engine) - results.extend([row.to_dict() for _, row in test_results.iterrows()]) + test_results = fetch_all_from_db(test_query, {"test_ids": tuple(test_ids)}) + results.extend([dict(row) for row in test_results]) return results @@ -167,25 +177,25 @@ def get_score_category_values(project_code: str) -> dict[ScoreCategory, list[str UNNEST(array[{', '.join([quote(c) for c in categories])}]) as category, UNNEST(array[{', '.join(categories)}]) AS value FROM v_dq_test_scoring_latest_by_column - WHERE project_code = '{project_code}' + WHERE project_code = :project_code UNION SELECT DISTINCT UNNEST(array[{', '.join([quote(c) for c in categories])}]) as category, UNNEST(array[{', '.join(categories)}]) AS value FROM v_dq_profile_scoring_latest_by_column - WHERE project_code = '{project_code}' + WHERE project_code = :project_code ORDER BY value """ - results = pd.read_sql_query(query, engine) - for _, row in results.iterrows(): - if row["category"] and row["value"]: - values[row["category"]].append(row["value"]) + results = fetch_all_from_db(query, {"project_code": project_code}) + for row in results: + if row.category and row.value: + values[row.category].append(row.value) return values @st.cache_data(show_spinner="Loading data :gray[:small[(This might take a few minutes)]] ...") def get_column_filters(project_code: str) -> list[dict]: - query = f""" + query = """ SELECT data_column_chars.column_id::text AS column_id, data_column_chars.column_name AS name, @@ -195,7 +205,8 @@ def get_column_filters(project_code: str) -> list[dict]: table_groups.table_groups_name AS table_group FROM data_column_chars INNER JOIN table_groups ON (table_groups.id = data_column_chars.table_groups_id) - WHERE table_groups.project_code = '{project_code}' + WHERE table_groups.project_code = :project_code ORDER BY table_name, ordinal_position; """ - return [row.to_dict() for _, row in pd.read_sql_query(query, engine).iterrows()] + results = fetch_all_from_db(query, {"project_code": project_code}) + return [dict(row) for row in results] diff --git a/testgen/ui/queries/source_data_queries.py b/testgen/ui/queries/source_data_queries.py new file mode 100644 index 00000000..ce4cd9b6 --- /dev/null +++ b/testgen/ui/queries/source_data_queries.py @@ -0,0 +1,232 @@ +import logging +from dataclasses import dataclass +from typing import Literal + +import pandas as pd +import streamlit as st + +from testgen.common.clean_sql import ConcatColumnList +from testgen.common.database.database_service import replace_params +from testgen.common.models.connection import Connection, SQLFlavor +from testgen.common.models.test_definition import TestDefinition +from testgen.common.read_file import replace_templated_functions +from testgen.ui.services.database_service import fetch_from_target_db, fetch_one_from_db +from testgen.utils import to_dataframe + +LOG = logging.getLogger("testgen") + + +@st.cache_data(show_spinner=False) +def get_hygiene_issue_source_data( + issue_data: dict, + limit: int | None = None, +) -> tuple[Literal["OK"], None, str, pd.DataFrame] | tuple[Literal["NA", "ND", "ERR"], str, str | None, None]: + def generate_lookup_query(test_id: str, detail_exp: str, 
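Note on the WHERE results.id IN :profile_ids / IN :test_ids binds in the scoring_queries.py hunk above: binding a Python tuple relies on the driver (or the fetch helper) expanding it into a parenthesized list. If that expansion ever needs to be made explicit, SQLAlchemy's expanding bind parameter achieves the same effect; a small sketch with made-up ids:

    from sqlalchemy import bindparam, text

    from testgen.common.models import get_current_session  # existing TestGen session helper

    profile_ids = ["0b1c2d3e-0000-0000-0000-000000000001", "0b1c2d3e-0000-0000-0000-000000000002"]  # hypothetical

    # expanding=True renders one placeholder per list element at execution time.
    query = text(
        "SELECT id::VARCHAR FROM profile_anomaly_results WHERE id::VARCHAR IN :profile_ids"
    ).bindparams(bindparam("profile_ids", expanding=True))
    rows = list(get_current_session().execute(query, {"profile_ids": profile_ids}))
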
column_names: list[str], sql_flavor: SQLFlavor) -> str: + if test_id in {"1019", "1020"}: + start_index = detail_exp.find("Columns: ") + if start_index == -1: + columns = [col.strip() for col in column_names.split(",")] + else: + start_index += len("Columns: ") + column_names_str = detail_exp[start_index:] + columns = [col.strip() for col in column_names_str.split(",")] + quote = "`" if sql_flavor == "databricks" else '"' + queries = [ + f"SELECT '{column}' AS column_name, MAX({quote}{column}{quote}) AS max_date_available FROM {{TARGET_SCHEMA}}.{{TABLE_NAME}}" + for column in columns + ] + sql_query = " UNION ALL ".join(queries) + " ORDER BY max_date_available DESC;" + else: + sql_query = "" + return sql_query + + lookup_query = None + try: + lookup_data = _get_lookup_data(issue_data["table_groups_id"], issue_data["anomaly_id"], "Profile Anomaly") + if not lookup_data: + return "NA", "Source data lookup is not available for this hygiene issue.", None, None + + lookup_query = ( + generate_lookup_query( + issue_data["anomaly_id"], issue_data["detail"], issue_data["column_name"], lookup_data.sql_flavor + ) + if lookup_data.lookup_query == "created_in_ui" + else lookup_data.lookup_query + ) + + if not lookup_query: + return "NA", "Source data lookup is not available for this hygiene issue.", None, None + + params = { + "TARGET_SCHEMA": issue_data["schema_name"], + "TABLE_NAME": issue_data["table_name"], + "COLUMN_NAME": issue_data["column_name"], + "DETAIL_EXPRESSION": issue_data["detail"], + "PROFILE_RUN_DATE": issue_data["profiling_starttime"], + } + lookup_query = replace_params(lookup_query, params) + lookup_query = replace_templated_functions(lookup_query, lookup_data.sql_flavor) + + connection = Connection.get_by_table_group(issue_data["table_groups_id"]) + results = fetch_from_target_db(connection, lookup_query) + + if results: + df = to_dataframe(results) + if limit: + df = df.sample(n=min(len(df), limit)).sort_index() + return "OK", None, lookup_query, df + else: + return ( + "ND", + "Data that violates hygiene issue criteria is not present in the current dataset.", + lookup_query, + None, + ) + except Exception as e: + LOG.exception("Source data lookup for hygiene issue encountered an error.") + return "ERR", f"Source data lookup encountered an error:\n\n{e.args[0]}", lookup_query, None + + +@st.cache_data(show_spinner=False) +def get_test_issue_source_data( + issue_data: dict, + limit: int | None = None, +) -> tuple[Literal["OK"], None, str, pd.DataFrame] | tuple[Literal["NA", "ND", "ERR"], str, str | None, None]: + lookup_query = None + try: + lookup_data = _get_lookup_data(issue_data["table_groups_id"], issue_data["test_type_id"], "Test Results") + + if not lookup_data or not lookup_data.lookup_query: + return "NA", "Source data lookup is not available for this test.", None, None + + test_definition = TestDefinition.get(issue_data["test_definition_id_current"]) + if not test_definition: + return "NA", "Test definition no longer exists.", None, None + + params = { + "TARGET_SCHEMA": issue_data["schema_name"], + "TABLE_NAME": issue_data["table_name"], + "COLUMN_NAME": issue_data["column_names"], + "TEST_DATE": str(issue_data["test_date"]), + "CUSTOM_QUERY": test_definition.custom_query, + "BASELINE_VALUE": test_definition.baseline_value, + "BASELINE_CT": test_definition.baseline_ct, + "BASELINE_AVG": test_definition.baseline_avg, + "BASELINE_SD": test_definition.baseline_sd, + "LOWER_TOLERANCE": test_definition.lower_tolerance, + "UPPER_TOLERANCE": test_definition.upper_tolerance, + 
"THRESHOLD_VALUE": test_definition.threshold_value, + "SUBSET_CONDITION": test_definition.subset_condition or "1=1", + "GROUPBY_NAMES": test_definition.groupby_names, + "HAVING_CONDITION": test_definition.having_condition, + "MATCH_SCHEMA_NAME": test_definition.match_schema_name, + "MATCH_TABLE_NAME": test_definition.match_table_name, + "MATCH_COLUMN_NAMES": test_definition.match_column_names, + "MATCH_SUBSET_CONDITION": test_definition.match_subset_condition or "1=1", + "MATCH_GROUPBY_NAMES": test_definition.match_groupby_names, + "MATCH_HAVING_CONDITION": test_definition.match_having_condition, + "COLUMN_NAME_NO_QUOTES": issue_data["column_names"], + "WINDOW_DATE_COLUMN": test_definition.window_date_column, + "WINDOW_DAYS": test_definition.window_days, + "CONCAT_COLUMNS": ConcatColumnList(issue_data["column_names"], ""), + "CONCAT_MATCH_GROUPBY": ConcatColumnList(test_definition.match_groupby_names, ""), + } + + lookup_query = replace_params(lookup_data.lookup_query, params) + lookup_query = replace_templated_functions(lookup_query, lookup_data.sql_flavor) + + connection = Connection.get_by_table_group(issue_data["table_groups_id"]) + results = fetch_from_target_db(connection, lookup_query) + + if results: + df = to_dataframe(results) + if limit: + df = df.sample(n=min(len(df), limit)).sort_index() + return "OK", None, lookup_query, df + else: + return "ND", "Data that violates test criteria is not present in the current dataset.", lookup_query, None + except Exception as e: + LOG.exception("Source data lookup for test encountered an error.") + return "ERR", f"Source data lookup encountered an error:\n\n{e.args[0]}", lookup_query, None + + +@st.cache_data(show_spinner=False) +def get_test_issue_source_data_custom( + issue_data: dict, + limit: int | None = None, +) -> tuple[Literal["OK"], None, str, pd.DataFrame] | tuple[Literal["NA", "ND", "ERR"], str, str | None, None]: + lookup_query = None + try: + lookup_data = _get_lookup_data_custom(issue_data["test_definition_id_current"]) + + if not lookup_data or not lookup_data.lookup_query: + return "NA", "Source data lookup is not available for this test.", None, None + + params = { + "DATA_SCHEMA": issue_data["schema_name"], + } + lookup_query = replace_params(lookup_data.lookup_query, params) + + connection = Connection.get_by_table_group(issue_data["table_groups_id"]) + results = fetch_from_target_db(connection, lookup_query) + + if results: + df = to_dataframe(results) + if limit: + df = df.sample(n=min(len(df), limit)).sort_index() + return "OK", None, lookup_query, df + else: + return "ND", "Data that violates test criteria is not present in the current dataset.", lookup_query, None + except Exception as e: + LOG.exception("Source data lookup for custom test encountered an error.") + return "ERR", f"Source data lookup encountered an error:\n\n{e.args[0]}", lookup_query, None + + +@dataclass +class LookupData: + lookup_query: str + sql_flavor: SQLFlavor | None + + +def _get_lookup_data( + table_group_id: str, + anomaly_id: str, + error_type: Literal["Profile Anomaly", "Test Results"], +) -> LookupData | None: + result = fetch_one_from_db( + """ + SELECT + t.lookup_query, + c.sql_flavor + FROM target_data_lookups t + INNER JOIN table_groups tg + ON (:table_group_id = tg.id) + INNER JOIN connections c + ON (tg.connection_id = c.connection_id) + AND (t.sql_flavor = c.sql_flavor) + WHERE t.error_type = :error_type + AND t.test_id = :anomaly_id + AND t.lookup_query > ''; + """, + { + "table_group_id": table_group_id, + "error_type": 
error_type, + "anomaly_id": anomaly_id, + }, + ) + return LookupData(**result) if result else None + + +def _get_lookup_data_custom( + test_definition_id: str, +) -> LookupData | None: + result = fetch_one_from_db( + """ + SELECT + d.custom_query as lookup_query + FROM test_definitions d + WHERE d.id = :test_definition_id; + """, + {"test_definition_id": test_definition_id}, + ) + return LookupData(**result) if result else None diff --git a/testgen/ui/queries/table_group_queries.py b/testgen/ui/queries/table_group_queries.py index 040a1c6d..eac18e06 100644 --- a/testgen/ui/queries/table_group_queries.py +++ b/testgen/ui/queries/table_group_queries.py @@ -1,249 +1,100 @@ -import uuid - -import streamlit as st - -import testgen.ui.services.database_service as db - - -def _get_select_statement(schema): - return f""" - WITH table_groups AS ( - SELECT table_groups.*, connections.connection_name, connections.sql_flavor, - COALESCE(connections.sql_flavor_code, connections.sql_flavor) AS sql_flavor_code - FROM {schema}.table_groups - INNER JOIN {schema}.connections ON connections.connection_id = table_groups.connection_id - ) - SELECT id::VARCHAR(50), project_code, connection_id, connection_name, sql_flavor, sql_flavor_code, - table_groups_name, table_group_schema, - profiling_include_mask, profiling_exclude_mask, - profiling_table_set, - profile_id_column_mask, profile_sk_column_mask, - description, data_source, source_system, source_process, data_location, - business_domain, stakeholder_group, transform_level, data_product, - CASE WHEN profile_use_sampling = 'Y' THEN true ELSE false END AS profile_use_sampling, - profile_sample_percent, profile_sample_min_count, - profiling_delay_days, profile_flag_cdes, include_in_dashboard - FROM table_groups - """ - - -@st.cache_data(show_spinner=False) -def get_by_id(schema, table_group_id): - sql = _get_select_statement(schema) - sql += f"""WHERE id = '{table_group_id}' - ORDER BY table_groups_name - """ - return db.retrieve_data(sql) - - -@st.cache_data(show_spinner=False) -def get_by_name(project_code: str, table_group_name: str) -> dict | None: - schema: str = st.session_state["dbschema"] - sql = _get_select_statement(schema) - sql += f"""WHERE project_code = '{project_code}' AND table_groups_name = '{table_group_name}';""" - results = db.retrieve_data(sql) - if results.empty: - return None - return results.iloc[0].to_dict() - - -def get_test_suite_ids_by_table_group_names(schema, table_group_names): - names_str = ", ".join([f"'{item}'" for item in table_group_names]) - sql = f""" - SELECT ts.id::VARCHAR - FROM {schema}.test_suites ts - INNER JOIN {schema}.table_groups tg ON tg.id = ts.table_groups_id - WHERE tg.table_groups_name in ({names_str}) - """ - return db.retrieve_data(sql) - - -def get_table_group_dependencies(schema, table_group_names): - if table_group_names is None or len(table_group_names) == 0: - raise ValueError("No Table Group is specified.") - - table_group_items = [f"'{item}'" for item in table_group_names] - sql = f"""select ppr.profile_run_id from {schema}.profile_pair_rules ppr - INNER JOIN {schema}.profiling_runs pr ON pr.id = ppr.profile_run_id - INNER JOIN {schema}.table_groups tg ON tg.id = pr.table_groups_id - where tg.table_groups_name in ({",".join(table_group_items)}) - union - select par.table_groups_id from {schema}.profile_anomaly_results par INNER JOIN {schema}.table_groups tg ON tg.id = par.table_groups_id where tg.table_groups_name in ({",".join(table_group_items)}) - union - select pr.table_groups_id from 
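Note on the new source_data_queries.py module above: every lookup follows the same pipeline -- load the flavor-specific template from target_data_lookups (or the test definition's custom_query), fill its {PLACEHOLDER} tokens with replace_params, run the result against the monitored database with fetch_from_target_db, then down-sample the rows for display. A condensed sketch of that pipeline under those assumptions (the helper name and its narrow parameter set are illustrative only):

    import pandas as pd

    from testgen.common.database.database_service import replace_params
    from testgen.common.models.connection import Connection
    from testgen.ui.services.database_service import fetch_from_target_db
    from testgen.utils import to_dataframe


    def preview_issue_rows(issue_data: dict, lookup_query: str, limit: int = 10) -> pd.DataFrame:
        """Hypothetical condensed version of the lookup flow above."""
        # 1. Fill the {PLACEHOLDER} tokens the stored lookup template expects.
        query = replace_params(lookup_query, {
            "TARGET_SCHEMA": issue_data["schema_name"],
            "TABLE_NAME": issue_data["table_name"],
            "COLUMN_NAME": issue_data["column_name"],
        })
        # 2. Run it against the monitored (target) database, not the app database.
        connection = Connection.get_by_table_group(issue_data["table_groups_id"])
        rows = fetch_from_target_db(connection, query)
        if not rows:
            return pd.DataFrame()
        # 3. Down-sample for display, mirroring the limit handling above.
        df = to_dataframe(rows)
        return df.sample(n=min(len(df), limit)).sort_index()
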
{schema}.profile_results pr INNER JOIN {schema}.table_groups tg ON tg.id = pr.table_groups_id where tg.table_groups_name in ({",".join(table_group_items)}) - union - select pr.table_groups_id from {schema}.profiling_runs pr INNER JOIN {schema}.table_groups tg ON tg.id = pr.table_groups_id where tg.table_groups_name in ({",".join(table_group_items)}) - union - select dtc.table_groups_id from {schema}.data_table_chars dtc INNER JOIN {schema}.table_groups tg ON tg.id = dtc.table_groups_id where tg.table_groups_name in ({",".join(table_group_items)}) - union - select dcs.table_groups_id from {schema}.data_column_chars dcs INNER JOIN {schema}.table_groups tg ON tg.id = dcs.table_groups_id where tg.table_groups_name in ({",".join(table_group_items)});""" - return db.retrieve_data(sql) - - -def get_table_group_usage(schema, table_group_names): - items = [f"'{item}'" for item in table_group_names] - sql = f"""select distinct pr.id from {schema}.profiling_runs pr -INNER JOIN {schema}.table_groups tg ON tg.id = pr.table_groups_id -where tg.table_groups_name in ({",".join(items)}) and pr.status = 'Running'""" - return db.retrieve_data(sql) - - -@st.cache_data(show_spinner=False) -def get_all(schema, project_code): - sql = _get_select_statement(schema) - sql += f"""WHERE project_code = '{project_code}' - ORDER BY table_groups_name - """ - return db.retrieve_data(sql) - - -@st.cache_data(show_spinner=False) -def get_by_connection(schema, project_code, connection_id): - sql = _get_select_statement(schema) - sql += f"""WHERE project_code = '{project_code}' - AND connection_id = '{connection_id}' - ORDER BY table_groups_name - """ - return db.retrieve_data(sql) - - -def edit(schema, table_group): - sql = f"""UPDATE {schema}.table_groups - SET - table_groups_name='{table_group["table_groups_name"]}', - table_group_schema='{table_group["table_group_schema"]}', - profiling_table_set=NULLIF('{table_group["profiling_table_set"]}', ''), - profiling_include_mask='{table_group["profiling_include_mask"]}', - profiling_exclude_mask='{table_group["profiling_exclude_mask"]}', - profile_id_column_mask='{table_group["profile_id_column_mask"]}', - profile_sk_column_mask='{table_group["profile_sk_column_mask"]}', - profile_use_sampling='{'Y' if table_group["profile_use_sampling"] else 'N'}', - profile_sample_percent='{table_group["profile_sample_percent"]}', - profile_sample_min_count={int(table_group["profile_sample_min_count"])}, - profiling_delay_days='{table_group["profiling_delay_days"]}', - profile_flag_cdes={table_group["profile_flag_cdes"]}, - include_in_dashboard={table_group["include_in_dashboard"]}, - description='{table_group["description"]}', - data_source=NULLIF('{table_group["data_source"]}', ''), - source_system=NULLIF('{table_group["source_system"]}', ''), - source_process=NULLIF('{table_group["source_process"]}', ''), - data_location=NULLIF('{table_group["data_location"]}', ''), - business_domain=NULLIF('{table_group["business_domain"]}', ''), - stakeholder_group=NULLIF('{table_group["stakeholder_group"]}', ''), - transform_level=NULLIF('{table_group["transform_level"]}', ''), - data_product=NULLIF('{table_group["data_product"]}', '') - WHERE - id = '{table_group["id"]}' - ; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def add(schema, table_group) -> str: - new_table_group_id = str(uuid.uuid4()) - sql = f"""INSERT INTO {schema}.table_groups - (id, - project_code, - connection_id, - table_groups_name, - table_group_schema, - profiling_table_set, - profiling_include_mask, - 
profiling_exclude_mask, - profile_id_column_mask, - profile_sk_column_mask, - profile_use_sampling, - profile_sample_percent, - profile_sample_min_count, - profiling_delay_days, - profile_flag_cdes, - include_in_dashboard, - description, - data_source, - source_system, - source_process, - data_location, - business_domain, - stakeholder_group, - transform_level, - data_product) - SELECT - '{new_table_group_id}', - '{table_group["project_code"]}', - '{table_group["connection_id"]}', - '{table_group["table_groups_name"]}', - '{table_group["table_group_schema"]}', - NULLIF('{table_group["profiling_table_set"]}', ''), - '{table_group["profiling_include_mask"]}', - '{table_group["profiling_exclude_mask"]}', - '{table_group["profile_id_column_mask"]}'::character varying(2000), - '{table_group["profile_sk_column_mask"]}'::character varying, - '{'Y' if table_group["profile_use_sampling"] else 'N' }'::character varying, - '{table_group["profile_sample_percent"]}'::character varying, - {table_group["profile_sample_min_count"]}, - '{table_group["profiling_delay_days"]}'::character varying, - {table_group["profile_flag_cdes"]}, - {table_group["include_in_dashboard"]}, - '{table_group["description"]}', - NULLIF('{table_group["data_source"]}', ''), - NULLIF('{table_group["source_system"]}', ''), - NULLIF('{table_group["source_process"]}', ''), - NULLIF('{table_group["data_location"]}', ''), - NULLIF('{table_group["business_domain"]}', ''), - NULLIF('{table_group["stakeholder_group"]}', ''), - NULLIF('{table_group["transform_level"]}', ''), - NULLIF('{table_group["data_product"]}', '') - ;""" - db.execute_sql(sql) - st.cache_data.clear() - return new_table_group_id - - -def delete(schema, table_group_ids): - if table_group_ids is None or len(table_group_ids) == 0: - raise ValueError("No table group is specified.") - - items = [f"'{item}'" for item in table_group_ids] - sql = f"""DELETE FROM {schema}.table_groups WHERE id in ({",".join(items)})""" - db.execute_sql(sql) - st.cache_data.clear() - - -def cascade_delete(schema, table_group_names): - if table_group_names is None or len(table_group_names) == 0: - raise ValueError("No Table Group is specified.") - - table_group_items = [f"'{item}'" for item in table_group_names] - sql = f"""delete from {schema}.profile_pair_rules ppr -USING {schema}.profiling_runs pr, {schema}.table_groups tg -WHERE -pr.id = ppr.profile_run_id -AND tg.id = pr.table_groups_id -AND tg.table_groups_name in ({",".join(table_group_items)}); -delete from {schema}.profile_anomaly_results par USING {schema}.table_groups tg where tg.id = par.table_groups_id and tg.table_groups_name in ({",".join(table_group_items)}); -delete from {schema}.profile_results pr USING {schema}.table_groups tg where tg.id = pr.table_groups_id and tg.table_groups_name in ({",".join(table_group_items)}); -delete from {schema}.profiling_runs pr USING {schema}.table_groups tg where tg.id = pr.table_groups_id and tg.table_groups_name in ({",".join(table_group_items)}); -delete from {schema}.data_table_chars dtc USING {schema}.table_groups tg where tg.id = dtc.table_groups_id and tg.table_groups_name in ({",".join(table_group_items)}); -delete from {schema}.data_column_chars dcs USING {schema}.table_groups tg where tg.id = dcs.table_groups_id and tg.table_groups_name in ({",".join(table_group_items)}); -delete from {schema}.table_groups where table_groups_name in ({",".join(table_group_items)});""" - db.execute_sql(sql) - st.cache_data.clear() - - -def get_test_suite_ids_by_table_group_id(schema, table_group_id: str) 
-> list[str]: - sql = f""" - SELECT ts.id::VARCHAR - FROM {schema}.test_suites ts - WHERE ts.table_groups_id = '{table_group_id}' - """ - return db.retrieve_data(sql) - - -def get_profiling_run_ids_by_table_group_id(schema, table_group_id: str) -> list[str]: - sql = f""" - SELECT pr.id::VARCHAR - FROM {schema}.profiling_runs pr - WHERE pr.table_groups_id = '{table_group_id}' - """ - return db.retrieve_data(sql) +from typing import TypedDict + +from sqlalchemy.engine import Row + +from testgen.commands.queries.profiling_query import CProfilingSQL +from testgen.common.models.connection import Connection +from testgen.common.models.table_group import TableGroup +from testgen.ui.services.database_service import fetch_from_target_db + + +class TableGroupPreview(TypedDict): + schema: str + tables: dict[str, bool] + column_count: int + success: bool + message: str | None + + +def get_table_group_preview( + table_group: TableGroup, + connection: Connection | None = None, + verify_table_access: bool = False, +) -> TableGroupPreview: + table_group_preview: TableGroupPreview = { + "schema": table_group.table_group_schema, + "tables": {}, + "column_count": 0, + "success": True, + "message": None, + } + if connection or table_group.connection_id: + try: + connection = connection or Connection.get(table_group.connection_id) + + table_group_results = _fetch_table_group_columns(connection, table_group) + + for column in table_group_results: + table_group_preview["schema"] = column["table_schema"] + table_group_preview["tables"][column["table_name"]] = None + table_group_preview["column_count"] += 1 + + if len(table_group_results) <= 0: + table_group_preview["success"] = False + table_group_preview["message"] = ( + "No tables found matching the criteria. Please check the Table Group configuration" + " or the database permissions." + ) + + if verify_table_access: + for table_name in table_group_preview["tables"].keys(): + try: + results = fetch_from_target_db( + connection, + ( + f"SELECT 1 FROM {table_group_preview['schema']}.{table_name} LIMIT 1" + if connection.sql_flavor != "mssql" + else f"SELECT TOP 1 * FROM {table_group_preview['schema']}.{table_name}" + ), + ) + except Exception as error: + table_group_preview["tables"][table_name] = False + else: + table_group_preview["tables"][table_name] = results is not None and len(results) > 0 + + if not all(table_group_preview["tables"].values()): + table_group_preview["message"] = ( + "Some tables were not accessible. Please check the database permissions." + ) + except Exception as error: + table_group_preview["success"] = False + table_group_preview["message"] = error.args[0] + else: + table_group_preview["success"] = False + table_group_preview["message"] = "No connection selected. Please select a connection to preview the Table Group."
+ return table_group_preview + + +def _fetch_table_group_columns(connection: Connection, table_group: TableGroup) -> list[Row]: + profiling_table_set = table_group.profiling_table_set + + sql_generator = CProfilingSQL(table_group.project_code, connection.sql_flavor) + + sql_generator.table_groups_id = table_group.id + sql_generator.connection_id = str(table_group.connection_id) + sql_generator.profile_run_id = "" + sql_generator.data_schema = table_group.table_group_schema + sql_generator.parm_table_set = ( + ",".join([f"'{item.strip()}'" for item in profiling_table_set.split(",")]) + if profiling_table_set + else profiling_table_set + ) + sql_generator.parm_table_include_mask = table_group.profiling_include_mask + sql_generator.parm_table_exclude_mask = table_group.profiling_exclude_mask + sql_generator.profile_id_column_mask = table_group.profile_id_column_mask + sql_generator.profile_sk_column_mask = table_group.profile_sk_column_mask + sql_generator.profile_use_sampling = "Y" if table_group.profile_use_sampling else "N" + sql_generator.profile_sample_percent = table_group.profile_sample_percent + sql_generator.profile_sample_min_count = table_group.profile_sample_min_count + + return fetch_from_target_db(connection, *sql_generator.GetDDFQuery()) diff --git a/testgen/ui/queries/test_definition_queries.py b/testgen/ui/queries/test_definition_queries.py deleted file mode 100644 index 2294a688..00000000 --- a/testgen/ui/queries/test_definition_queries.py +++ /dev/null @@ -1,433 +0,0 @@ -import pandas as pd -import streamlit as st - -import testgen.ui.services.database_service as db -from testgen.common.models import get_current_session, with_database_session - - -def update_attribute(schema, test_definition_ids, attribute, value): - sql = f""" - WITH selected as ( - SELECT UNNEST(ARRAY [{", ".join([ f"'{item}'" for item in test_definition_ids ])}]) AS id - ) - UPDATE {schema}.test_definitions - SET {attribute}='{value}' - FROM {schema}.test_definitions td - INNER JOIN selected ON (td.id = selected.id::UUID) - WHERE td.id = test_definitions.id; - """ - db.execute_sql_raw(sql) - st.cache_data.clear() - - -@st.cache_data(show_spinner=False) -@with_database_session -def get_test_definitions(_, project_code, test_suite, table_name, column_name, test_type, test_definition_ids: list[str] | None): - db_session = get_current_session() - params = {} - order_by = "ORDER BY d.schema_name, d.table_name, d.column_name, d.test_type" - filters = "" - - if project_code: - filters += " AND s.project_code = :project_code" - params["project_code"] = project_code - - if test_suite: - filters += " AND s.test_suite = :test_suite" - params["test_suite"] = test_suite - - if test_definition_ids: - test_definition_params = {f"test_definition_id_{idx}": status for idx, status in enumerate(test_definition_ids)} - filters += f" AND d.id IN ({', '.join([f':{p}' for p in test_definition_params.keys()])})" - params.update(test_definition_params) - - if table_name: - filters += " AND d.table_name = :table_name" - params["table_name"] = table_name - - if column_name: - filters += " AND d.column_name ILIKE :column_name" - params["column_name"] = column_name - - if test_type: - filters += " AND d.test_type = :test_type" - params["test_type"] = test_type - - sql = f""" - SELECT - d.schema_name, d.table_name, d.column_name, t.test_name_short, t.test_name_long, - d.id::VARCHAR(50), - s.project_code, d.table_groups_id::VARCHAR(50), s.test_suite, d.test_suite_id::VARCHAR, - d.test_type, d.cat_test_id::VARCHAR(50), - d.test_active, 
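Note on get_table_group_preview() added above: it returns the TableGroupPreview TypedDict, so callers get plain dict access with type-checker support for the schema/tables/column_count/success/message keys. A rough usage sketch from a hypothetical settings page (the TableGroup.get accessor and the id are assumptions for illustration):

    import streamlit as st

    from testgen.common.models.table_group import TableGroup
    from testgen.ui.queries.table_group_queries import get_table_group_preview

    table_group = TableGroup.get("2d9b7f5c-1e4a-4b6d-9c3f-8a0e6d2b4f1c")  # assumed accessor, made-up id
    preview = get_table_group_preview(table_group, verify_table_access=True)

    if preview["success"]:
        st.write(f"{len(preview['tables'])} tables / {preview['column_count']} columns in schema {preview['schema']}")
    else:
        st.error(preview["message"])
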
- CASE WHEN d.test_active = 'Y' THEN 'Yes' ELSE 'No' END as test_active_display, - d.lock_refresh, - CASE WHEN d.lock_refresh = 'Y' THEN 'Yes' ELSE 'No' END as lock_refresh_display, - t.test_scope, - d.test_description, - d.profiling_as_of_date, - d.last_manual_update, - d.severity, COALESCE(d.severity, s.severity, t.default_severity) as urgency, - d.export_to_observability as export_to_observability_raw, - CASE - WHEN d.export_to_observability = 'Y' THEN 'Yes' - WHEN d.export_to_observability = 'N' THEN 'No' - WHEN d.export_to_observability IS NULL AND s.export_to_observability = 'Y' THEN 'Inherited (Yes)' - ELSE 'Inherited (No)' - END as export_to_observability, - -- test_action, - d.threshold_value, COALESCE(t.measure_uom_description, t.measure_uom) as export_uom, - d.baseline_ct, d.baseline_unique_ct, d.baseline_value, - d.baseline_value_ct, d.baseline_sum, d.baseline_avg, d.baseline_sd, - d.lower_tolerance, d.upper_tolerance, - d.subset_condition, - d.groupby_names, d.having_condition, d.window_date_column, d.window_days, - d.match_schema_name, d.match_table_name, d.match_column_names, - d.match_subset_condition, d.match_groupby_names, d.match_having_condition, - d.skip_errors, d.custom_query, - COALESCE(d.test_description, t.test_description) as final_test_description, - t.default_parm_columns, t.selection_criteria, - d.profile_run_id::VARCHAR(50), d.test_action, d.test_definition_status, - d.watch_level, d.check_result, d.last_auto_gen_date, - d.test_mode - FROM test_definitions d - INNER JOIN test_types t ON (d.test_type = t.test_type) - INNER JOIN test_suites s ON (d.test_suite_id = s.id) - WHERE True - {filters} - {order_by} - """ - - results = db_session.execute(sql, params=params) - columns = [column.name for column in results.cursor.description] - - return pd.DataFrame(list(results), columns=columns) - - -def update(schema, test_definition): - sql = f"""UPDATE {schema}.test_definitions - SET - cat_test_id = {test_definition["cat_test_id"]}, - --last_auto_gen_date = NULLIF('test_definition["last_auto_gen_date"]', ''), - --profiling_as_of_date = NULLIF('test_definition["profiling_as_of_date"]', ''), - last_manual_update = CURRENT_TIMESTAMP AT TIME ZONE 'UTC', - skip_errors = {test_definition["skip_errors"]}, - custom_query = NULLIF($${test_definition["custom_query"]}$$, ''), - test_definition_status = NULLIF('{test_definition["test_definition_status"]}', ''), - export_to_observability = NULLIF('{test_definition["export_to_observability"]}', ''), - column_name = NULLIF($${test_definition["column_name"]}$$, ''), - watch_level = NULLIF('{test_definition["watch_level"]}', ''), - table_groups_id = '{test_definition["table_groups_id"]}'::UUID, - """ - - if test_definition["profile_run_id"]: - sql += f"profile_run_id = '{test_definition['profile_run_id']}'::UUID,\n" - if test_definition["test_suite_id"]: - sql += f"test_suite_id = '{test_definition['test_suite_id']}'::UUID,\n" - - sql += f""" test_type = NULLIF('{test_definition["test_type"]}', ''), - test_description = NULLIF($${test_definition["test_description"]}$$, ''), - test_action = NULLIF('{test_definition["test_action"]}', ''), - test_mode = NULLIF('{test_definition["test_mode"]}', ''), - lock_refresh = NULLIF('{test_definition["lock_refresh"]}', ''), - schema_name = NULLIF('{test_definition["schema_name"]}', ''), - table_name = NULLIF('{test_definition["table_name"]}', ''), - test_active = NULLIF('{test_definition["test_active"]}', ''), - severity = NULLIF('{test_definition["severity"]}', ''), - check_result = 
NULLIF('{test_definition["check_result"]}', ''), - baseline_ct = NULLIF('{test_definition["baseline_ct"]}', ''), - baseline_unique_ct = NULLIF('{test_definition["baseline_unique_ct"]}', ''), - baseline_value = NULLIF($${test_definition["baseline_value"]}$$, ''), - baseline_value_ct = NULLIF('{test_definition["baseline_value_ct"]}', ''), - threshold_value = NULLIF($${test_definition["threshold_value"]}$$, ''), - baseline_sum = NULLIF('{test_definition["baseline_sum"]}', ''), - baseline_avg = NULLIF('{test_definition["baseline_avg"]}', ''), - baseline_sd = NULLIF('{test_definition["baseline_sd"]}', ''), - lower_tolerance = NULLIF('{test_definition["lower_tolerance"]}', ''), - upper_tolerance = NULLIF('{test_definition["upper_tolerance"]}', ''), - subset_condition = NULLIF($${test_definition["subset_condition"]}$$, ''), - groupby_names = NULLIF($${test_definition["groupby_names"]}$$, ''), - having_condition = NULLIF($${test_definition["having_condition"]}$$, ''), - window_date_column = NULLIF('{test_definition["window_date_column"]}', ''), - match_schema_name = NULLIF('{test_definition["match_schema_name"]}', ''), - match_table_name = NULLIF('{test_definition["match_table_name"]}', ''), - match_column_names = NULLIF($${test_definition["match_column_names"]}$$, ''), - match_subset_condition = NULLIF($${test_definition["match_subset_condition"]}$$, ''), - match_groupby_names = NULLIF($${test_definition["match_groupby_names"]}$$, ''), - match_having_condition = NULLIF($${test_definition["match_having_condition"]}$$, ''), - window_days = COALESCE({test_definition["window_days"]}, 0) - where - id = '{test_definition["id"]}' - ; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def add(schema, test_definition): - sql = f"""INSERT INTO {schema}.test_definitions - ( - --cat_test_id, - --last_auto_gen_date, - --profiling_as_of_date, - last_manual_update, - skip_errors, - custom_query, - test_definition_status, - export_to_observability, - column_name, - watch_level, - table_groups_id, - profile_run_id, - test_type, - test_suite_id, - test_description, - test_action, - test_mode, - lock_refresh, - schema_name, - table_name, - test_active, - severity, - check_result, - baseline_ct, - baseline_unique_ct, - baseline_value, - baseline_value_ct, - threshold_value, - baseline_sum, - baseline_avg, - baseline_sd, - lower_tolerance, - upper_tolerance, - subset_condition, - groupby_names, - having_condition, - window_date_column, - match_schema_name, - match_table_name, - match_column_names, - match_subset_condition, - match_groupby_names, - match_having_condition, - window_days - ) - SELECT - --{test_definition["cat_test_id"]} as cat_test_id, - --NULLIF('test_definition["last_auto_gen_date"]', '') as last_auto_gen_date, - --NULLIF('test_definition["profiling_as_of_date"]', '') as profiling_as_of_date, - CURRENT_TIMESTAMP AT TIME ZONE 'UTC' as last_manual_update, - {test_definition["skip_errors"]} as skip_errors, - NULLIF($${test_definition["custom_query"]}$$, '') as custom_query, - NULLIF('{test_definition["test_definition_status"]}', '') as test_definition_status, - NULLIF('{test_definition["export_to_observability"]}', '') as export_to_observability, - NULLIF('{test_definition["column_name"]}', '') as column_name, - NULLIF('{test_definition["watch_level"]}', '') as watch_level, - '{test_definition["table_groups_id"]}'::UUID as table_groups_id, - NULL AS profile_run_id, - NULLIF('{test_definition["test_type"]}', '') as test_type, - '{test_definition["test_suite_id"]}'::UUID as test_suite_id, - 
NULLIF('{test_definition["test_description"]}', '') as test_description, - NULLIF('{test_definition["test_action"]}', '') as test_action, - NULLIF('{test_definition["test_mode"]}', '') as test_mode, - NULLIF('{test_definition["lock_refresh"]}', '') as lock_refresh, - NULLIF('{test_definition["schema_name"]}', '') as schema_name, - NULLIF('{test_definition["table_name"]}', '') as table_name, - NULLIF('{test_definition["test_active"]}', '') as test_active, - NULLIF('{test_definition["severity"]}', '') as severity, - NULLIF('{test_definition["check_result"]}', '') as check_result, - NULLIF('{test_definition["baseline_ct"]}', '') as baseline_ct, - NULLIF('{test_definition["baseline_unique_ct"]}', '') as baseline_unique_ct, - NULLIF($${test_definition["baseline_value"]}$$, '') as baseline_value, - NULLIF($${test_definition["baseline_value_ct"]}$$, '') as baseline_value_ct, - NULLIF($${test_definition["threshold_value"]}$$, '') as threshold_value, - NULLIF($${test_definition["baseline_sum"]}$$, '') as baseline_sum, - NULLIF('{test_definition["baseline_avg"]}', '') as baseline_avg, - NULLIF('{test_definition["baseline_sd"]}', '') as baseline_sd, - NULLIF('{test_definition["lower_tolerance"]}', '') as lower_tolerance, - NULLIF('{test_definition["upper_tolerance"]}', '') as upper_tolerance, - NULLIF($${test_definition["subset_condition"]}$$, '') as subset_condition, - NULLIF($${test_definition["groupby_names"]}$$, '') as groupby_names, - NULLIF($${test_definition["having_condition"]}$$, '') as having_condition, - NULLIF('{test_definition["window_date_column"]}', '') as window_date_column, - NULLIF('{test_definition["match_schema_name"]}', '') as match_schema_name, - NULLIF('{test_definition["match_table_name"]}', '') as match_table_name, - NULLIF($${test_definition["match_column_names"]}$$, '') as match_column_names, - NULLIF($${test_definition["match_subset_condition"]}$$, '') as match_subset_condition, - NULLIF($${test_definition["match_groupby_names"]}$$, '') as match_groupby_names, - NULLIF($${test_definition["match_having_condition"]}$$, '') as match_having_condition, - COALESCE({test_definition["window_days"]}, 0) as window_days - ; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def delete(schema, test_definition_ids): - if test_definition_ids is None or len(test_definition_ids) == 0: - raise ValueError("No Test Definition is specified.") - - items = [f"'{item}'" for item in test_definition_ids] - sql = f"""DELETE FROM {schema}.test_definitions WHERE id in ({",".join(items)})""" - db.execute_sql(sql) - st.cache_data.clear() - - -def cascade_delete(schema, test_suite_ids): - if not test_suite_ids: - raise ValueError("No Test Suite is specified.") - - ids_str = ", ".join([f"'{item}'" for item in test_suite_ids]) - sql = f""" - DELETE FROM {schema}.test_definitions WHERE test_suite_id in ({ids_str}) - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def move(schema, test_definitions, target_table_group, target_test_suite, target_table_column=None): - if target_table_column is not None: - update_target_table_column = f""" - column_name = '{target_table_column['column_name']}', - table_name = '{target_table_column['table_name']}', - """ - else: - update_target_table_column = "" - sql = f""" - WITH selected as ( - SELECT UNNEST(ARRAY [{", ".join([ f"'{td['id']}'" for td in test_definitions ])}]) AS id - ) - UPDATE {schema}.test_definitions - SET - {update_target_table_column} - table_groups_id = '{target_table_group}'::UUID, - test_suite_id = '{target_test_suite}'::UUID - FROM 
{schema}.test_definitions td - INNER JOIN selected ON (td.id = selected.id::UUID) - WHERE td.id = test_definitions.id; - """ - db.execute_sql_raw(sql) - st.cache_data.clear() - - -def copy(schema, test_definitions, target_table_group, target_test_suite, target_table_column=None): - if target_table_column is not None: - update_target_column = f"'{target_table_column['column_name']}' as column_name" - update_target_table = f"'{target_table_column['table_name']}' as table_name" - else: - update_target_column = "td.column_name" - update_target_table = "td.table_name" - test_definition_ids = [f"'{td['id']}'" for td in test_definitions] - sql = f""" - INSERT INTO {schema}.test_definitions - ( - profiling_as_of_date, - last_manual_update, - skip_errors, - custom_query, - test_definition_status, - export_to_observability, - column_name, - watch_level, - table_groups_id, - profile_run_id, - test_type, - test_suite_id, - test_description, - test_action, - test_mode, - lock_refresh, - last_auto_gen_date, - schema_name, - table_name, - test_active, - severity, - check_result, - baseline_ct, - baseline_unique_ct, - baseline_value, - baseline_value_ct, - threshold_value, - baseline_sum, - baseline_avg, - baseline_sd, - lower_tolerance, - upper_tolerance, - subset_condition, - groupby_names, - having_condition, - window_date_column, - match_schema_name, - match_table_name, - match_column_names, - match_subset_condition, - match_groupby_names, - match_having_condition, - window_days - ) - SELECT - td.profiling_as_of_date, - td.last_manual_update, - td.skip_errors, - td.custom_query, - td.test_definition_status, - td.export_to_observability, - {update_target_column}, - td.watch_level, - '{target_table_group}'::UUID AS table_groups_id, - CASE WHEN td.table_groups_id = '{target_table_group}' THEN td.profile_run_id ELSE NULL END AS profile_run_id, - td.test_type, - '{target_test_suite}'::UUID AS test_suite_id, - td.test_description, - td.test_action, - td.test_mode, - td.lock_refresh, - td.last_auto_gen_date, - td.schema_name, - {update_target_table}, - td.test_active, - td.severity, - td.check_result, - td.baseline_ct, - td.baseline_unique_ct, - td.baseline_value, - td.baseline_value_ct, - td.threshold_value, - td.baseline_sum, - td.baseline_avg, - td.baseline_sd, - td.lower_tolerance, - td.upper_tolerance, - td.subset_condition, - td.groupby_names, - td.having_condition, - td.window_date_column, - td.match_schema_name, - td.match_table_name, - td.match_column_names, - td.match_subset_condition, - td.match_groupby_names, - td.match_having_condition, - td.window_days - FROM {schema}.test_definitions as td - WHERE - td.id in ({",".join(test_definition_ids)}) - ; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def get_test_definitions_collision(schema, test_definitions, target_table_group, target_test_suite): - test_definition_keys = [f"('{td['table_name']}', '{td['column_name']}', '{td['test_type']}')" for td in test_definitions] - test_definitions_keys_str = f"({", ".join(test_definition_keys)})" - sql = f""" - SELECT table_name, column_name, test_type, lock_refresh - FROM {schema}.test_definitions - WHERE table_groups_id = '{target_table_group}' - AND test_suite_id = '{target_test_suite}' - AND last_auto_gen_date IS NOT NULL - AND (table_name, column_name, test_type) in {test_definitions_keys_str}; - """ - return db.retrieve_data(sql) - diff --git a/testgen/ui/queries/test_result_queries.py b/testgen/ui/queries/test_result_queries.py new file mode 100644 index 00000000..4636bb4c --- /dev/null +++ 
b/testgen/ui/queries/test_result_queries.py @@ -0,0 +1,162 @@ +from typing import Literal + +import pandas as pd +import streamlit as st + +from testgen.ui.services.database_service import fetch_df_from_db + + +@st.cache_data(show_spinner="Loading data ...") +def get_test_results( + run_id: str, + test_statuses: list[str] | None = None, + test_type_id: str | None = None, + table_name: str | None = None, + column_name: str | None = None, + action: Literal["Confirmed", "Dismissed", "Muted", "No Action"] | None = None, + sorting_columns: list[str] | None = None, +) -> pd.DataFrame: + query = f""" + WITH run_results + AS (SELECT * + FROM test_results r + WHERE + r.test_run_id = :run_id + {"AND r.result_status IN :test_statuses" if test_statuses else ""} + {"AND r.test_type = :test_type_id" if test_type_id else ""} + {"AND r.table_name = :table_name" if table_name else ""} + {"AND r.column_names ILIKE :column_name" if column_name else ""} + {"AND r.disposition IS NULL" if action == "No Action" else "AND r.disposition = :disposition" if action else ""} + ) + SELECT r.table_name, + p.project_name, ts.test_suite, tg.table_groups_name, cn.connection_name, cn.project_host, cn.sql_flavor, + tt.dq_dimension, tt.test_scope, + r.schema_name, r.column_names, r.test_time::DATE as test_date, r.test_type, tt.id as test_type_id, + tt.test_name_short, tt.test_name_long, r.test_description, tt.measure_uom, tt.measure_uom_description, + c.test_operator, r.threshold_value::NUMERIC(16, 5), r.result_measure::NUMERIC(16, 5), r.result_status, + CASE + WHEN r.result_code <> 1 THEN r.disposition + ELSE 'Passed' + END as disposition, + NULL::VARCHAR(1) as action, + r.input_parameters, r.result_message, CASE WHEN result_code <> 1 THEN r.severity END as severity, + r.result_code as passed_ct, + (1 - r.result_code)::INTEGER as exception_ct, + CASE + WHEN result_status = 'Warning' + AND result_message NOT ILIKE 'Inactivated%%' THEN 1 + END::INTEGER as warning_ct, + CASE + WHEN result_status = 'Failed' + AND result_message NOT ILIKE 'Inactivated%%' THEN 1 + END::INTEGER as failed_ct, + CASE + WHEN result_message ILIKE 'Inactivated%%' THEN 1 + END as execution_error_ct, + p.project_code, r.table_groups_id::VARCHAR, + r.id::VARCHAR as test_result_id, r.test_run_id::VARCHAR, + c.id::VARCHAR as connection_id, r.test_suite_id::VARCHAR, + r.test_definition_id::VARCHAR as test_definition_id_runtime, + CASE + WHEN r.auto_gen = TRUE THEN d.id + ELSE r.test_definition_id + END::VARCHAR as test_definition_id_current, + r.auto_gen, + + -- These are used in the PDF report + tt.threshold_description, tt.usage_notes, r.test_time, + dcc.description as column_description, + COALESCE(dcc.critical_data_element, dtc.critical_data_element) as critical_data_element, + COALESCE(dcc.data_source, dtc.data_source, tg.data_source) as data_source, + COALESCE(dcc.source_system, dtc.source_system, tg.source_system) as source_system, + COALESCE(dcc.source_process, dtc.source_process, tg.source_process) as source_process, + COALESCE(dcc.business_domain, dtc.business_domain, tg.business_domain) as business_domain, + COALESCE(dcc.stakeholder_group, dtc.stakeholder_group, tg.stakeholder_group) as stakeholder_group, + COALESCE(dcc.transform_level, dtc.transform_level, tg.transform_level) as transform_level, + COALESCE(dcc.aggregation_level, dtc.aggregation_level) as aggregation_level, + COALESCE(dcc.data_product, dtc.data_product, tg.data_product) as data_product + FROM run_results r + INNER JOIN test_types tt + ON (r.test_type = tt.test_type) + LEFT JOIN 
test_definitions d + ON (r.test_suite_id = d.test_suite_id + AND r.table_name = d.table_name + AND COALESCE(r.column_names, 'N/A') = COALESCE(d.column_name, 'N/A') + AND r.test_type = d.test_type + AND r.auto_gen = TRUE + AND d.last_auto_gen_date IS NOT NULL) + INNER JOIN test_suites ts + ON r.test_suite_id = ts.id + INNER JOIN projects p + ON (ts.project_code = p.project_code) + INNER JOIN table_groups tg + ON (ts.table_groups_id = tg.id) + INNER JOIN connections cn + ON (tg.connection_id = cn.connection_id) + LEFT JOIN cat_test_conditions c + ON (cn.sql_flavor = c.sql_flavor + AND r.test_type = c.test_type) + LEFT JOIN data_column_chars dcc + ON (tg.id = dcc.table_groups_id + AND r.schema_name = dcc.schema_name + AND r.table_name = dcc.table_name + AND r.column_names = dcc.column_name) + LEFT JOIN data_table_chars dtc + ON dcc.table_id = dtc.table_id + {f"ORDER BY {', '.join(' '.join(col) for col in sorting_columns)}" if sorting_columns else ""}; + """ + params = { + "run_id": run_id, + "test_statuses": tuple(test_statuses or []), + "test_type_id": test_type_id, + "table_name": table_name, + "column_name": column_name, + "disposition": { + "Muted": "Inactive", + }.get(action, action), + } + + df = fetch_df_from_db(query, params) + df["test_date"] = pd.to_datetime(df["test_date"]) + return df + + +@st.cache_data(show_spinner=False) +def get_test_result_history(tr_data, limit: int | None = None): + query = f""" + SELECT test_date, + test_type, + test_name_short, + test_name_long, + measure_uom, + test_operator, + threshold_value::NUMERIC, + result_measure::NUMERIC, + result_status + FROM v_test_results + WHERE {f""" + test_suite_id = :test_suite_id + AND table_name = :table_name + AND column_names {"= :column_names" if tr_data["column_names"] else "IS NULL"} + AND test_type = :test_type + AND auto_gen = TRUE + """ if tr_data["auto_gen"] else """ + test_definition_id_runtime = :test_definition_id_runtime + """} + ORDER BY test_date DESC + {'LIMIT ' + str(limit) if limit else ''}; + """ + params = { + "test_suite_id": tr_data["test_suite_id"], + "table_name": tr_data["table_name"], + "column_names": tr_data["column_names"], + "test_type": tr_data["test_type"], + "test_definition_id_runtime": tr_data["test_definition_id_runtime"], + } + + df = fetch_df_from_db(query, params) + df["test_date"] = pd.to_datetime(df["test_date"]) + df["threshold_value"] = pd.to_numeric(df["threshold_value"]) + df["result_measure"] = pd.to_numeric(df["result_measure"]) + + return df diff --git a/testgen/ui/queries/test_run_queries.py b/testgen/ui/queries/test_run_queries.py deleted file mode 100644 index 9259f3ed..00000000 --- a/testgen/ui/queries/test_run_queries.py +++ /dev/null @@ -1,134 +0,0 @@ -from datetime import datetime -from typing import NamedTuple - -import streamlit as st - -import testgen.common.date_service as date_service -import testgen.ui.services.database_service as db -from testgen.common.models import get_current_session - - -def is_running(test_run_id: str | tuple[str]) -> bool: - session = get_current_session() - - test_run_ids: tuple[str] = tuple(test_run_id) - if isinstance(test_run_id, str): - test_run_ids = (test_run_id,) - - query = """ - SELECT id - FROM test_runs - WHERE id::text IN :test_run_ids - AND status = 'Running' - """ - result = session.execute(query, params={"test_run_ids": test_run_ids}) - return result and len(result.all()) > 0 - - -def cascade_delete(test_suite_ids: list[str]) -> None: - if not test_suite_ids: - raise ValueError("No Test Suite is specified.") - - schema: 
str = st.session_state["dbschema"] - ids_str = ", ".join([f"'{item}'" for item in test_suite_ids]) - sql = f""" - DELETE - FROM {schema}.working_agg_cat_results - WHERE test_run_id in (select id from {schema}.test_runs where test_suite_id in ({ids_str})); - DELETE - FROM {schema}.working_agg_cat_tests - WHERE test_run_id in (select id from {schema}.test_runs where test_suite_id in ({ids_str})); - DELETE FROM {schema}.test_runs WHERE test_suite_id in ({ids_str}); - DELETE FROM {schema}.test_results WHERE test_suite_id in ({ids_str}); - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def cascade_delete_test_run(test_run_id: str) -> None: - if not test_run_id: - raise ValueError("No Test Run is specified.") - - schema: str = st.session_state["dbschema"] - sql = f""" - DELETE - FROM {schema}.working_agg_cat_results - WHERE test_run_id = '{test_run_id}'; - DELETE - FROM {schema}.working_agg_cat_tests - WHERE test_run_id = '{test_run_id}'; - DELETE FROM {schema}.test_runs WHERE id = '{test_run_id}'; - DELETE FROM {schema}.test_results WHERE test_run_id = '{test_run_id}'; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def cascade_delete_multiple_test_runs(test_run_ids: list[str]) -> None: - if not test_run_ids: - raise ValueError("No Test Run is specified.") - - test_run_ids_str = ", ".join([f"'{run_id}'" for run_id in test_run_ids]) - schema: str = st.session_state["dbschema"] - sql = f""" - DELETE - FROM {schema}.working_agg_cat_results - WHERE test_run_id IN ({test_run_ids_str}); - DELETE - FROM {schema}.working_agg_cat_tests - WHERE test_run_id IN ({test_run_ids_str}); - DELETE FROM {schema}.test_runs WHERE id IN ({test_run_ids_str}); - DELETE FROM {schema}.test_results WHERE test_run_id IN ({test_run_ids_str}); - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def update_status(test_run_id: str, status: str) -> None: - if not all([test_run_id, status]): - raise ValueError("Missing query parameters.") - - schema: str = st.session_state["dbschema"] - now = date_service.get_now_as_string() - - sql = f""" - UPDATE {schema}.test_runs - SET status = '{status}', - test_endtime = '{now}' - WHERE id = '{test_run_id}'::UUID; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def cancel_all_running() -> None: - schema: str = db.get_schema() - db.execute_sql(f""" - UPDATE {schema}.test_runs - SET status = 'Cancelled' - WHERE status = 'Running'; - """) - - -class LatestTestRun(NamedTuple): - id: str - run_time: datetime - - -def get_latest_run_date(project_code: str) -> LatestTestRun | None: - session = get_current_session() - result = session.execute( - """ - SELECT runs.id, test_starttime - FROM test_runs AS runs - INNER JOIN test_suites AS suite ON (suite.id = runs.test_suite_id) - WHERE project_code = :project_code - AND status = 'Complete' - ORDER BY test_starttime DESC - LIMIT 1 - """, - params={"project_code": project_code}, - ) - if result and (latest_run := result.first()): - return LatestTestRun(str(latest_run.id), latest_run.test_starttime) - return None diff --git a/testgen/ui/queries/test_suite_queries.py b/testgen/ui/queries/test_suite_queries.py deleted file mode 100644 index 790a178c..00000000 --- a/testgen/ui/queries/test_suite_queries.py +++ /dev/null @@ -1,255 +0,0 @@ -import pandas as pd -import streamlit as st - -import testgen.ui.services.database_service as db - - -@st.cache_data(show_spinner=False) -def get_by_project(schema, project_code, table_group_id=None): - sql = f""" - WITH last_gen_date AS ( - SELECT test_suite_id, - MAX(last_auto_gen_date) as 
auto_gen_date - FROM {schema}.test_definitions - GROUP BY test_suite_id - ), - last_run AS ( - SELECT test_runs.test_suite_id, - test_runs.id, - test_runs.test_starttime, - test_runs.test_ct, - SUM( - CASE - WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' - AND test_results.result_status = 'Passed' THEN 1 - ELSE 0 - END - ) as passed_ct, - SUM( - CASE - WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' - AND test_results.result_status = 'Warning' THEN 1 - ELSE 0 - END - ) as warning_ct, - SUM( - CASE - WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' - AND test_results.result_status = 'Failed' THEN 1 - ELSE 0 - END - ) as failed_ct, - SUM( - CASE - WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' - AND test_results.result_status = 'Error' THEN 1 - ELSE 0 - END - ) as error_ct, - SUM( - CASE - WHEN COALESCE(test_results.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') THEN 1 - ELSE 0 - END - ) as dismissed_ct - FROM {schema}.test_suites - LEFT JOIN {schema}.test_runs ON ( - test_suites.last_complete_test_run_id = test_runs.id - ) - LEFT JOIN {schema}.test_results ON ( - test_runs.id = test_results.test_run_id - ) - GROUP BY test_runs.id - ), - test_defs AS ( - SELECT test_suite_id, - COUNT(*) as count - FROM {schema}.test_definitions - GROUP BY test_suite_id - ) - SELECT - suites.id::VARCHAR(50), - suites.project_code, - suites.test_suite, - suites.connection_id::VARCHAR(50), - connections.connection_name, - suites.table_groups_id::VARCHAR(50), - groups.table_groups_name, - suites.test_suite_description, - suites.test_action, - CASE WHEN suites.severity IS NULL THEN 'Inherit' ELSE suites.severity END, - suites.export_to_observability, - suites.dq_score_exclude, - suites.test_suite_schema, - suites.component_key, - suites.component_type, - suites.component_name, - test_defs.count as test_ct, - last_gen_date.auto_gen_date as latest_auto_gen_date, - last_complete_profile_run_id, - last_run.id as latest_run_id, - last_run.test_starttime as latest_run_start, - last_run.test_ct as last_run_test_ct, - last_run.passed_ct as last_run_passed_ct, - last_run.warning_ct as last_run_warning_ct, - last_run.failed_ct as last_run_failed_ct, - last_run.error_ct as last_run_error_ct, - last_run.dismissed_ct as last_run_dismissed_ct - FROM {schema}.test_suites as suites - LEFT JOIN last_gen_date - ON (suites.id = last_gen_date.test_suite_id) - LEFT JOIN last_run - ON (suites.id = last_run.test_suite_id) - LEFT JOIN test_defs - ON (suites.id = test_defs.test_suite_id) - LEFT JOIN {schema}.connections AS connections - ON (connections.connection_id = suites.connection_id) - LEFT JOIN {schema}.table_groups as groups - ON (groups.id = suites.table_groups_id) - WHERE suites.project_code = '{project_code}' - """ - - if table_group_id: - sql += f""" - AND suites.table_groups_id = '{table_group_id}' - """ - - sql += """ - ORDER BY suites.test_suite; - """ - - return db.retrieve_data(sql) - - -@st.cache_data(show_spinner=False) -def get_by_id(schema: str, test_suite_id: str) -> pd.DataFrame: - sql = f""" - SELECT - suites.id::VARCHAR(50), - suites.project_code, - suites.test_suite, - suites.connection_id::VARCHAR(50), - suites.table_groups_id::VARCHAR(50), - suites.test_suite_description, - suites.test_action, - CASE WHEN suites.severity IS NULL THEN 'Inherit' ELSE suites.severity END, - suites.export_to_observability, - suites.dq_score_exclude, - suites.test_suite_schema, - suites.component_key, - suites.component_type, - 
suites.component_name - FROM {schema}.test_suites as suites - WHERE suites.id = '{test_suite_id}'; - """ - return db.retrieve_data(sql) - - -def edit(schema, test_suite): - sql = f"""UPDATE {schema}.test_suites - SET - test_suite='{test_suite["test_suite"]}', - test_suite_description='{test_suite["test_suite_description"]}', - test_action=NULLIF('{test_suite["test_action"]}', ''), - severity=NULLIF('{test_suite["severity"]}', 'Inherit'), - export_to_observability='{'Y' if test_suite["export_to_observability"] else 'N'}', - dq_score_exclude={test_suite["dq_score_exclude"]}, - test_suite_schema=NULLIF('{test_suite["test_suite_schema"]}', ''), - component_key=NULLIF('{test_suite["component_key"]}', ''), - component_type=NULLIF('{test_suite["component_type"]}', ''), - component_name=NULLIF('{test_suite["component_name"]}', '') - where - id = '{test_suite["id"]}'; - """ - db.execute_sql(sql) - st.cache_data.clear() - - -def add(schema, test_suite): - sql = f"""INSERT INTO {schema}.test_suites - (id, - project_code, test_suite, connection_id, table_groups_id, test_suite_description, test_action, - severity, export_to_observability, dq_score_exclude, test_suite_schema, component_key, component_type, - component_name) - SELECT - gen_random_uuid(), - '{test_suite["project_code"]}', - '{test_suite["test_suite"]}', - '{test_suite["connection_id"]}', - '{test_suite["table_groups_id"]}', - NULLIF('{test_suite["test_suite_description"]}', ''), - NULLIF('{test_suite["test_action"]}', ''), - NULLIF('{test_suite["severity"]}', 'Inherit'), - '{'Y' if test_suite["export_to_observability"] else 'N' }'::character varying, - {test_suite["dq_score_exclude"]}, - NULLIF('{test_suite["test_suite_schema"]}', ''), - NULLIF('{test_suite["component_key"]}', ''), - NULLIF('{test_suite["component_type"]}', ''), - NULLIF('{test_suite["component_name"]}', '') - ;""" - db.execute_sql(sql) - st.cache_data.clear() - - -def delete(schema, test_suite_ids: list[str]): - if not test_suite_ids: - raise ValueError("No table group is specified.") - - ids_str = ",".join([f"'{item}'" for item in test_suite_ids]) - sql = f"""DELETE FROM {schema}.test_suites WHERE id in ({ids_str})""" - db.execute_sql(sql) - st.cache_data.clear() - - -def get_test_suite_dependencies(schema: str, test_suite_ids: list[str]) -> pd.DataFrame: - ids_str = ", ".join([f"'{item}'" for item in test_suite_ids]) - sql = f""" - SELECT DISTINCT test_suite_id FROM {schema}.test_definitions WHERE test_suite_id in ({ids_str}) - UNION - SELECT DISTINCT test_suite_id FROM {schema}.test_results WHERE test_suite_id in ({ids_str}); - """ - return db.retrieve_data(sql) - - -def get_test_suite_usage(schema: str, test_suite_ids: list[str]) -> pd.DataFrame: - ids_str = ", ".join([f"'{item}'" for item in test_suite_ids]) - sql = f""" - SELECT DISTINCT test_suite_id FROM {schema}.test_runs WHERE test_suite_id in ({ids_str}) AND status = 'Running' - """ - return db.retrieve_data(sql) - - -def get_test_suite_refresh_check(schema, test_suite_id): - sql = f""" - SELECT COUNT(*) as test_ct, - SUM(CASE WHEN COALESCE(d.lock_refresh, 'N') = 'N' THEN 1 ELSE 0 END) as unlocked_test_ct, - SUM(CASE WHEN COALESCE(d.lock_refresh, 'N') = 'N' AND d.last_manual_update IS NOT NULL THEN 1 ELSE 0 END) as unlocked_edits_ct - FROM {schema}.test_definitions d - INNER JOIN {schema}.test_types t - ON (d.test_type = t.test_type) - WHERE d.test_suite_id = '{test_suite_id}' - AND t.run_type = 'CAT' - AND t.selection_criteria IS NOT NULL; -""" - return db.retrieve_data_list(sql)[0] - - -def 
get_generation_sets(schema): - sql = f""" - SELECT DISTINCT generation_set - FROM {schema}.generation_sets - ORDER BY generation_set; -""" - return db.retrieve_data(sql) - - -def lock_edited_tests(schema, test_suite_id): - sql = f""" - UPDATE {schema}.test_definitions - SET lock_refresh = 'Y' - WHERE test_suite_id = '{test_suite_id}' - AND last_manual_update IS NOT NULL - AND lock_refresh = 'N'; -""" - db.execute_sql(sql) - return True diff --git a/testgen/ui/queries/user_queries.py b/testgen/ui/queries/user_queries.py deleted file mode 100644 index d245dbeb..00000000 --- a/testgen/ui/queries/user_queries.py +++ /dev/null @@ -1,57 +0,0 @@ -import streamlit as st - -import testgen.ui.services.database_service as db -from testgen.common.encrypt import encrypt_ui_password - - -@st.cache_data(show_spinner=False) -def get_users(include_password: bool=False): - schema: str = st.session_state["dbschema"] - sql = f""" - SELECT - id::VARCHAR(50), - username, email, "name", - {"password," if include_password else ""} - role - FROM {schema}.auth_users - """ - return db.retrieve_data(sql) - - -def delete_users(user_ids): - if user_ids is None or len(user_ids) == 0: - raise ValueError("No user is specified.") - - schema: str = st.session_state["dbschema"] - items = [f"'{item}'" for item in user_ids] - sql = f"""DELETE FROM {schema}.auth_users WHERE id in ({",".join(items)})""" - db.execute_sql(sql) - st.cache_data.clear() - - -def add_user(user): - schema: str = st.session_state["dbschema"] - encrypted_password = encrypt_ui_password(user["password"]) - sql = f"""INSERT INTO {schema}.auth_users - (username, email, name, password, role) -SELECT - '{user["username"]}' as username, - '{user["email"]}' as email, - '{user["name"]}' as name, - '{encrypted_password}' as password, - '{user["role"]}' as role;""" - db.execute_sql(sql) - st.cache_data.clear() - - -def edit_user(user): - schema: str = st.session_state["dbschema"] - sql = f"""UPDATE {schema}.auth_users SET - username = '{user["username"]}', - email = '{user["email"]}', - name = '{user["name"]}', - {f"password = '{encrypt_ui_password(user["password"])}'," if user["password"] else ""} - role = '{user["role"]}' - WHERE id = '{user["user_id"]}';""" - db.execute_sql(sql) - st.cache_data.clear() diff --git a/testgen/ui/services/connection_service.py b/testgen/ui/services/connection_service.py deleted file mode 100644 index 215fc7d4..00000000 --- a/testgen/ui/services/connection_service.py +++ /dev/null @@ -1,212 +0,0 @@ -import streamlit as st - -import testgen.ui.queries.connection_queries as connection_queries -import testgen.ui.services.table_group_service as table_group_service -from testgen.commands.run_profiling_bridge import InitializeProfilingSQL -from testgen.common.database.database_service import ( - AssignConnectParms, - empty_cache, - get_db_type, - get_flavor_service, -) -from testgen.common.encrypt import DecryptText, EncryptText - - -def get_by_id(connection_id: str, hide_passwords: bool = True) -> dict | None: - connections_df = connection_queries.get_by_id(connection_id) - decrypt_connections(connections_df, hide_passwords) - connections_list = connections_df.to_dict(orient="records") - if len(connections_list): - return connections_list[0] - - -def get_connections(project_code, hide_passwords: bool = False): - connections = connection_queries.get_connections(project_code) - decrypt_connections(connections, hide_passwords) - return connections - - -def decrypt_connections(connections, hide_passwords: bool = False): - for index, 
connection in connections.iterrows(): - if hide_passwords: - password = "***" # noqa S105 - private_key = "***" # S105 - private_key_passphrase = "***" # noqa S105 - else: - password = DecryptText(connection["project_pw_encrypted"]) if connection["project_pw_encrypted"] else None - private_key = DecryptText(connection["private_key"]) if connection["private_key"] else None - private_key_passphrase = DecryptText(connection["private_key_passphrase"]) if connection["private_key_passphrase"] else "" - connections.at[index, "password"] = password - connections.at[index, "private_key"] = private_key - connections.at[index, "private_key_passphrase"] = private_key_passphrase - - -def encrypt_credentials(connection): - encrypted_password = EncryptText(connection["password"]) if connection["password"] else None - encrypted_private_key = EncryptText(connection["private_key"]) if connection["private_key"] else None - encrypted_private_key_passphrase = EncryptText(connection["private_key_passphrase"]) if connection["private_key_passphrase"] else None - return encrypted_password, encrypted_private_key, encrypted_private_key_passphrase - - -def edit_connection(connection): - empty_cache() - schema = st.session_state["dbschema"] - connection = pre_save_connection_process(connection) - encrypted_password, encrypted_private_key, encrypted_private_key_passphrase = encrypt_credentials(connection) - connection_queries.edit_connection(schema, connection, encrypted_password, encrypted_private_key, encrypted_private_key_passphrase) - - -def add_connection(connection) -> int: - empty_cache() - schema = st.session_state["dbschema"] - connection = pre_save_connection_process(connection) - encrypted_password, encrypted_private_key, encrypted_private_key_passphrase = encrypt_credentials(connection) - return connection_queries.add_connection( - schema, - connection, - encrypted_password, - encrypted_private_key, - encrypted_private_key_passphrase, - ) - - -def pre_save_connection_process(connection): - if connection["connect_by_url"]: - url = connection["url"] - if url: - url_sections = url.split("/") - if len(url_sections) > 0: - host_port = url_sections[0] - host_port_sections = host_port.split(":") - if len(host_port_sections) > 0: - connection["project_host"] = host_port_sections[0] - connection["project_port"] = "".join(host_port_sections[1:]) - else: - connection["project_host"] = host_port - connection["project_port"] = "" - if len(url_sections) > 1: - connection["project_db"] = url_sections[1] - return connection - - -def delete_connections(connection_ids): - empty_cache() - schema = st.session_state["dbschema"] - return connection_queries.delete_connections(schema, connection_ids) - - -def cascade_delete(connection_ids, dry_run=False): - schema = st.session_state["dbschema"] - can_be_deleted = True - table_group_names = get_table_group_names_by_connection(connection_ids) - connection_has_dependencies = table_group_names is not None and len(table_group_names) > 0 - if connection_has_dependencies: - can_be_deleted = False - if not dry_run: - if connection_has_dependencies: - table_group_service.cascade_delete(table_group_names) - connection_queries.delete_connections(schema, connection_ids) - return can_be_deleted - - -def are_connections_in_use(connection_ids): - table_group_names = get_table_group_names_by_connection(connection_ids) - table_groups_in_use = table_group_service.are_table_groups_in_use(table_group_names) - return table_groups_in_use - - -def get_table_group_names_by_connection(connection_ids): - 
if not connection_ids: - return [] - schema = st.session_state["dbschema"] - table_group_names = connection_queries.get_table_group_names_by_connection(schema, connection_ids) - return table_group_names.to_dict()["table_groups_name"].values() - - -def init_profiling_sql(project_code, connection, table_group_schema=None): - # get connection data - empty_cache() - connection_id = str(connection["connection_id"]) if connection["connection_id"] else None - sql_flavor = connection["sql_flavor"] - url = connection["url"] - connect_by_url = connection["connect_by_url"] - connect_by_key = connection["connect_by_key"] - private_key = connection["private_key"] - private_key_passphrase = connection["private_key_passphrase"] - project_host = connection["project_host"] - project_port = connection["project_port"] - project_db = connection["project_db"] - project_user = connection["project_user"] - password = connection["password"] - http_path = connection["http_path"] - - # prepare the profiling query - clsProfiling = InitializeProfilingSQL(project_code, sql_flavor) - - AssignConnectParms( - project_code, - connection_id, - project_host, - project_port, - project_db, - table_group_schema, - project_user, - sql_flavor, - url, - connect_by_url, - connect_by_key, - private_key, - private_key_passphrase, - http_path, - connectname="PROJECT", - password=password, - ) - - return clsProfiling - - -def form_overwritten_connection_url(connection) -> str: - flavor = connection["sql_flavor"] - - connection_credentials = { - "flavor": flavor, - "user": "", - "host": connection["project_host"], - "port": connection["project_port"], - "dbname": connection["project_db"], - "url": None, - "connect_by_url": None, - "connect_by_key": connection.get("connect_by_key"), - "private_key": None, - "private_key_passphrase": "", - "dbschema": "", - } - - db_type = get_db_type(flavor) - flavor_service = get_flavor_service(db_type) - flavor_service.init(connection_credentials) - connection_string = flavor_service.get_connection_string("") - - return connection_string - - -def get_connection_string(flavor: str) -> str: - db_type = get_db_type(flavor) - flavor_service = get_flavor_service(db_type) - flavor_service.init({ - "flavor": flavor, - "user": "", - "host": "", - "port": "", - "dbname": "", - "url": None, - "connect_by_url": None, - "connect_by_key": False, - "private_key": None, - "private_key_passphrase": "", - "dbschema": "", - "http_path": "", - }) - return flavor_service.get_connection_string( - "" - ).replace("%3E", ">").replace("%3C", "<") diff --git a/testgen/ui/services/database_service.py b/testgen/ui/services/database_service.py index c779948f..a094bc84 100644 --- a/testgen/ui/services/database_service.py +++ b/testgen/ui/services/database_service.py @@ -1,283 +1,65 @@ -from urllib.parse import quote_plus - -import pandas as pd -from sqlalchemy import create_engine, text -from sqlalchemy.engine.cursor import CursorResult - -from testgen.common.credentials import ( - get_tg_db, - get_tg_host, - get_tg_password, - get_tg_port, - get_tg_schema, - get_tg_username, -) -from testgen.common.database.database_service import get_flavor_service -from testgen.common.encrypt import DecryptText - -""" - Shared database access and utility functions -""" - - -def get_schema(): - return get_tg_schema() - - -def _start_engine(): - # TestGen database - dbhost = get_tg_host() - dbport = get_tg_port() - dbname = get_tg_db() - # User Information - dbuser = get_tg_username() - dbpw = get_tg_password() - - conn_str = "postgresql://" + 
dbuser + ":" + quote_plus(dbpw) + "@" + dbhost + ":" + dbport + "/" + dbname - return create_engine(conn_str) - - -def _make_connection(): - engine = _start_engine() - return engine - - -def make_header_db_friendly(str_header): - return str_header.replace(" ", "_").lower() - - -def make_value_db_friendly(value): - if value is None or pd.isna(value): - newval = "NULL" - else: - newval = str(value) if isinstance(value, int | float) else f"'{value}'" - return newval - - -def retrieve_data(str_sql): - tg_engine = _start_engine() - # Retrieve data from Postgres - return pd.read_sql_query(str_sql, tg_engine) - - -def retrieve_data_list(str_sql): - tg_engine = _start_engine() - # Retrieve data from Postgres - with tg_engine.connect() as con: - return con.execute(text(str_sql)).fetchall() - - -def retrieve_single_result(str_sql): - tg_engine = _start_engine() - with tg_engine.connect() as con: - lstResult = con.execute(text(str_sql)).fetchone() - if lstResult: - return lstResult[0] - - -def execute_sql(str_sql) -> CursorResult | None: - if str_sql > "": - tg_engine = _start_engine() - return tg_engine.execute(text(str_sql)) +from __future__ import annotations +from typing import TYPE_CHECKING -def execute_sql_raw(str_sql): - # For special cases where SQLAlchemy can't handle query syntax - if str_sql > "": - tg_engine = _start_engine() - con = tg_engine.raw_connection() - with con.cursor() as cur: - cur.execute(str_sql) - con.commit() - - -def _get_df_edits(df_original: pd.DataFrame, df_edited: pd.DataFrame, lst_id_columns: list) -> tuple: - # Rows in df_edited that exist in df_original but have had any column changed - # based on composite ID columns - - # Merge the two dataframes based on the composite ID columns - merged_df = df_edited.merge(df_original, on=lst_id_columns, how="outer", indicator=True, suffixes=("", "_original")) - # Filter the merged dataframe to only keep rows that are changed - # Step 1: Filter rows that exist in both dataframes - both_rows = merged_df[merged_df["_merge"] == "both"] - - # Step 2: Identify changed rows - def has_changes(row): - for col in df_original.columns: - # Skip the ID columns - if col in lst_id_columns: - continue - if row[col] != row[col + "_original"]: - return True - return False - - changed_rows_mask = both_rows.apply(has_changes, axis=1) - - # Step 3: Combine the filters - changed_rows = both_rows[changed_rows_mask] - - # All rows in df_edited that are newly created and don't exist in df_original - new_rows = merged_df[merged_df["_merge"] == "left_only"].drop( - columns=["_merge"] + [col + "_original" for col in df_original.columns if col not in lst_id_columns] - ) - - # All rows in df_original that have been deleted from df_edited - deleted_rows = merged_df[merged_df["_merge"] == "right_only"][df_original.columns] - - return changed_rows, new_rows, deleted_rows - - -def _gen_df_update_sql( - changed_rows: pd.DataFrame, table_name: str, lst_id_columns: list, no_update_columns: list -) -> list: - # Generate a list of SQL UPDATE statements based on the changed rows. 
- - # Extract the original column names by removing the "_original" suffix - original_columns = [col.replace("_original", "") for col in changed_rows.columns if col.endswith("_original")] - # Drop columns we aren't updating from list - update_columns = [col for col in original_columns if col not in no_update_columns] - - # Generate SQL UPDATE statements - sql_statements = [] - for _, row in changed_rows.iterrows(): - set_statements = [] - for col in update_columns: - # If the value is different from the original value - if row[col] != row[col + "_original"]: - value = make_value_db_friendly(row[col]) - set_statements.append(f"{col} = {value}") - - # Handle composite keys for the WHERE clause - where_statements = [] - for col in lst_id_columns: - value = make_value_db_friendly(row[col]) - # value = f"'{row[col]}'" if isinstance(row[col], str) else row[col] - where_statements.append(f"{col} = {value}") - - update_statement = f"UPDATE {get_schema()}.{table_name} SET {', '.join(set_statements)} WHERE {' AND '.join(where_statements)};" - sql_statements.append(update_statement) - - return sql_statements - - -def _gen_df_delete_sql(deleted_rows: pd.DataFrame, table_name: str, lst_id_columns: list) -> list: - # Generate a list of SQL DELETE statements based on the deleted rows. +import pandas as pd - # Generate SQL DELETE statements - sql_statements = [] - for _, row in deleted_rows.iterrows(): - # Handle composite keys for the WHERE clause - where_statements = [] - for col in lst_id_columns: - value = make_value_db_friendly(row[col]) - # value = f"'{row[col]}'" if isinstance(row[col], str) else row[col] - where_statements.append(f"{col} = {value}") +from testgen.utils import to_dataframe - delete_statement = f"DELETE FROM {get_schema()}.{table_name} WHERE {' AND '.join(where_statements)};" - sql_statements.append(delete_statement) +if TYPE_CHECKING: + from testgen.common.models.connection import Connection - return sql_statements +from typing import Any +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Row, RowMapping +from sqlalchemy.engine.cursor import CursorResult -def _gen_insert_sql( - new_rows: pd.DataFrame, - table_name: str, - lst_id_columns: list, - no_update_columns: list, - dct_hard_default_columns: dict, -) -> str: - # Generate a SQL INSERT statement for the new rows, ensuring strings are properly quoted. 
+from testgen.common.database.database_service import get_flavor_service +from testgen.common.models import get_current_session - # Remove the id column as it will be generated by the server - if lst_id_columns: - new_rows = new_rows.drop(columns=lst_id_columns) - if no_update_columns: - # Remove columns we aren't updating - new_rows = new_rows.drop(columns=no_update_columns) - if dct_hard_default_columns: - # Add and default all columns - new_rows = new_rows.assign(**dct_hard_default_columns) - # Generate column names and values for the INSERT statement - columns = ", ".join(new_rows.columns) +def execute_db_query(query: str, params: dict | None = None) -> Any: + db_session = get_current_session() + cursor: CursorResult = db_session.execute(text(query), params) + try: + result = cursor.fetchone()[0] + except: + result = None + db_session.commit() + return result - # Ensure strings are quoted - values = [] - for _, row in new_rows.iterrows(): - row_values = [] - for val in row: - row_values.append(make_value_db_friendly(val)) - # if isinstance(val, str): - # row_values.append(f"'{val}'") - # else: - # row_values.append(str(val)) - values.append(f"({', '.join(row_values)})") - if values: - values_str = ", ".join(values) - # Construct the SQL INSERT statement - sql_statement = f"INSERT INTO {get_schema()}.{table_name} ({columns}) VALUES {values_str};" - return sql_statement +def fetch_all_from_db(query: str, params: dict | None = None) -> list[RowMapping]: + db_session = get_current_session() + cursor: CursorResult = db_session.execute(text(query), params) + return cursor.mappings().all() -def apply_df_edits(df_original, df_edited, str_table, lst_id_columns, no_update_columns, dct_hard_default_columns): - booStatus = False - df_changed, df_new, df_deleted = _get_df_edits(df_original, df_edited, lst_id_columns) +# Only use this for old parts of the app that still use dataframes +# Prefer to use fetch_all_from_db instead and avoid usage of pandas +def fetch_df_from_db(query: str, params: dict | None = None) -> pd.DataFrame: + db_session = get_current_session() + cursor: CursorResult = db_session.execute(text(query), params) + results = cursor.mappings().all() + columns = cursor.keys() + return to_dataframe(results, columns) - # Generate SQL UPDATE statements - lst_update_SQL = _gen_df_update_sql(df_changed, str_table, lst_id_columns, no_update_columns) - if lst_update_SQL: - for str_sql in lst_update_SQL: - execute_sql(str_sql) - booStatus = True - # Generate SQL DELETE statements - lst_delete_SQL = _gen_df_delete_sql(df_deleted, str_table, lst_id_columns) - if lst_delete_SQL: - for str_sql in lst_delete_SQL: - execute_sql(str_sql) - booStatus = True - # Generate SQL INSERT statements - str_insert_sql = _gen_insert_sql(df_new, str_table, lst_id_columns, no_update_columns, dct_hard_default_columns) - if str_insert_sql: - execute_sql(str_insert_sql) - booStatus = True - return booStatus +def fetch_one_from_db(query: str, params: dict | None = None) -> RowMapping | None: + db_session = get_current_session() + cursor: CursorResult = db_session.execute(text(query), params) + result = cursor.first() + return result._mapping if result else None -def _start_target_db_engine(flavor, host, port, db_name, user, password, url, connect_by_url, connect_by_key, private_key, private_key_passphrase, http_path): - connection_params = { - "flavor": flavor if flavor != "redshift" else "postgresql", - "user": user, - "host": host, - "port": port, - "dbname": db_name, - "url": url, - "connect_by_url": 
connect_by_url, - "connect_by_key": connect_by_key, - "private_key": private_key, - "private_key_passphrase": private_key_passphrase, - "http_path": http_path, - "dbschema": None, - } - flavor_service = get_flavor_service(flavor) - flavor_service.init(connection_params) - connection_string = flavor_service.get_connection_string(password) +def fetch_from_target_db(connection: Connection, query: str, params: dict | None = None) -> list[Row]: + flavor_service = get_flavor_service(connection.sql_flavor) + flavor_service.init(connection.to_dict()) + connection_string = flavor_service.get_connection_string() connect_args = flavor_service.get_connect_args() - return create_engine(connection_string, connect_args=connect_args) - - -def retrieve_target_db_data(flavor, host, port, db_name, user, password, url, connect_by_url, connect_by_key, private_key, private_key_passphrase, http_path, sql_query, decrypt=False): - if decrypt: - password = DecryptText(password) - db_engine = _start_target_db_engine(flavor, host, port, db_name, user, password, url, connect_by_url, connect_by_key, private_key, private_key_passphrase, http_path) - with db_engine.connect() as connection: - query_result = connection.execute(text(sql_query)) - return query_result.fetchall() - + engine = create_engine(connection_string, connect_args=connect_args) -def retrieve_target_db_df(flavor, host, port, db_name, user, password, sql_query, url, connect_by_url, connect_by_key, private_key, private_key_passphrase, http_path): - if password: - password = DecryptText(password) - db_engine = _start_target_db_engine(flavor, host, port, db_name, user, password, url, connect_by_url, connect_by_key, private_key, private_key_passphrase, http_path) - return pd.read_sql_query(text(sql_query), db_engine) + with engine.connect() as connection: + cursor: CursorResult = connection.execute(text(query), params) + return cursor.fetchall() diff --git a/testgen/ui/services/hygiene_issues_service.py b/testgen/ui/services/hygiene_issues_service.py deleted file mode 100644 index 3fba8539..00000000 --- a/testgen/ui/services/hygiene_issues_service.py +++ /dev/null @@ -1,93 +0,0 @@ -import streamlit as st - -from testgen.common.read_file import replace_templated_functions -from testgen.ui.services import database_service as db - - -def get_source_data(hi_data, limit: int | None = None): - str_schema = st.session_state["dbschema"] - # Define the query - str_sql = f""" - SELECT t.lookup_query, tg.table_group_schema, - c.sql_flavor, c.project_host, c.project_port, c.project_db, c.project_user, c.project_pw_encrypted, - c.url, c.connect_by_url, c.connect_by_key, c.private_key, c.private_key_passphrase, c.http_path - FROM {str_schema}.target_data_lookups t - INNER JOIN {str_schema}.table_groups tg - ON ('{hi_data["table_groups_id"]}'::UUID = tg.id) - INNER JOIN {str_schema}.connections c - ON (tg.connection_id = c.connection_id) - AND (t.sql_flavor = c.sql_flavor) - WHERE t.error_type = 'Profile Anomaly' - AND t.test_id = '{hi_data["anomaly_id"]}' - AND t.lookup_query > ''; - """ - - def get_lookup_query(test_id, detail_exp, column_names, sql_flavor): - if test_id in {"1019", "1020"}: - start_index = detail_exp.find("Columns: ") - if start_index == -1: - columns = [col.strip() for col in column_names.split(",")] - else: - start_index += len("Columns: ") - column_names_str = detail_exp[start_index:] - columns = [col.strip() for col in column_names_str.split(",")] - quote = "`" if sql_flavor == "databricks" else '"' - queries = [ - f"SELECT '{column}' AS column_name, 
MAX({quote}{column}{quote}) AS max_date_available FROM {{TARGET_SCHEMA}}.{{TABLE_NAME}}" - for column in columns - ] - sql_query = " UNION ALL ".join(queries) + " ORDER BY max_date_available DESC;" - else: - sql_query = "" - return sql_query - - def replace_parms(str_query): - str_query = ( - get_lookup_query(hi_data["anomaly_id"], hi_data["detail"], hi_data["column_name"], lst_query[0]["sql_flavor"]) - if lst_query[0]["lookup_query"] == "created_in_ui" - else lst_query[0]["lookup_query"] - ) - str_query = str_query.replace("{TARGET_SCHEMA}", lst_query[0]["table_group_schema"]) - str_query = str_query.replace("{TABLE_NAME}", hi_data["table_name"]) - str_query = str_query.replace("{COLUMN_NAME}", hi_data["column_name"]) - str_query = str_query.replace("{DETAIL_EXPRESSION}", hi_data["detail"]) - str_query = str_query.replace("{PROFILE_RUN_DATE}", hi_data["profiling_starttime"]) - str_query = replace_templated_functions(str_query, lst_query[0]["sql_flavor"]) - - if str_query is None or str_query == "": - raise ValueError("Lookup query is not defined for this Anomaly Type.") - return str_query - - try: - # Retrieve SQL for customer lookup - lst_query = db.retrieve_data_list(str_sql) - - # Retrieve and return data as df - if lst_query: - str_sql = replace_parms(str_sql) - df = db.retrieve_target_db_df( - lst_query[0]["sql_flavor"], - lst_query[0]["project_host"], - lst_query[0]["project_port"], - lst_query[0]["project_db"], - lst_query[0]["project_user"], - lst_query[0]["project_pw_encrypted"], - str_sql, - lst_query[0]["url"], - lst_query[0]["connect_by_url"], - lst_query[0]["connect_by_key"], - lst_query[0]["private_key"], - lst_query[0]["private_key_passphrase"], - lst_query[0]["http_path"], - ) - if df.empty: - return "ND", "Data that violates Hygiene Issue criteria is not present in the current dataset.", str_sql, None - else: - if limit: - df = df.sample(n=min(len(df), limit)).sort_index() - return "OK", None, str_sql, df - else: - return "NA", "Source data lookup is not available for this Issue.", None, None - - except Exception as e: - return "ERR", f"Source data lookup query caused an error:\n\n{e.args[0]}", None, None diff --git a/testgen/ui/services/project_service.py b/testgen/ui/services/project_service.py deleted file mode 100644 index 2f4443cf..00000000 --- a/testgen/ui/services/project_service.py +++ /dev/null @@ -1,42 +0,0 @@ -import streamlit as st - -from testgen.ui.queries import project_queries -from testgen.ui.services import database_service, query_service -from testgen.ui.session import session - - -@st.cache_data(show_spinner=False) -def get_projects(): - projects = project_queries.get_projects() - projects = [ - {"code": project["project_code"], "name": project["project_name"]} for project in projects.to_dict("records") - ] - - return projects - - -def set_sidebar_project(project_code: str) -> None: - if project_code != session.sidebar_project: - session.sidebar_project = project_code - st.rerun() - - -@st.cache_data(show_spinner=False) -def get_project_by_code(code: str): - if not code: - return None - return query_service.get_project_by_code(session.dbschema, code) - - -def edit_project(project: dict): - schema = st.session_state["dbschema"] - query = f""" - UPDATE {schema}.projects - SET - project_name = '{project["project_name"]}', - observability_api_url = '{project["observability_api_url"]}', - observability_api_key = '{project["observability_api_key"]}' - WHERE id = '{project["id"]}'; - """ - database_service.execute_sql(query) - st.cache_data.clear() diff --git 
a/testgen/ui/services/query_service.py b/testgen/ui/services/query_service.py
deleted file mode 100644
index 12c3d7dc..00000000
--- a/testgen/ui/services/query_service.py
+++ /dev/null
@@ -1,257 +0,0 @@
-import pandas as pd
-
-import testgen.ui.services.database_service as db
-
-"""
-Shared queries for standard lookups
- - should be called by cached functions within page
-"""
-
-
-def run_project_lookup_query(str_schema):
-    str_sql = f"""
-        SELECT
-            id::VARCHAR(50),
-            project_code,
-            project_name,
-            effective_from_date,
-            observability_api_url,
-            observability_api_key
-        FROM {str_schema}.projects
-        ORDER BY project_name
-    """
-    return db.retrieve_data(str_sql)
-
-
-def get_project_by_code(schema: str, project_code: str):
-    str_sql = f"""
-        SELECT
-            id::VARCHAR(50),
-            project_code,
-            project_name,
-            effective_from_date,
-            observability_api_url,
-            observability_api_key
-        FROM {schema}.projects
-        WHERE project_code = {db.make_value_db_friendly(project_code)};
-    """
-    results = db.retrieve_data(str_sql)
-    if results.size <= 0:
-        return None
-    return results.iloc[0]
-
-
-def run_test_type_lookup_query(str_schema, str_test_type=None, boo_show_referential=True, boo_show_table=True,
-                               boo_show_column=True, boo_show_custom=True):
-    if str_test_type:
-        str_criteria = f" AND tt.test_type = '{str_test_type}'"
-    else:
-        str_criteria = ""
-
-    if (boo_show_referential and boo_show_table and boo_show_column and boo_show_custom) == False:
-        str_scopes = ""
-        str_scopes += "'referential'," if boo_show_referential else ""
-        str_scopes += "'table'," if boo_show_table else ""
-        str_scopes += "'column'," if boo_show_column else ""
-        str_scopes += "'custom'," if boo_show_custom else ""
-        if str_scopes > "":
-            str_criteria += f"AND tt.test_scope in ({str_scopes[:-1]})"
-
-    str_sql = f"""
-        SELECT tt.id, tt.test_type, tt.id as cat_test_id,
-            tt.test_name_short, tt.test_name_long, tt.test_description,
-            tt.measure_uom, COALESCE(tt.measure_uom_description, '') as measure_uom_description,
-            tt.default_parm_columns, tt.default_severity,
-            tt.run_type, tt.test_scope, tt.dq_dimension, tt.threshold_description,
-            tt.column_name_prompt, tt.column_name_help,
-            tt.default_parm_prompts, tt.default_parm_help, tt.usage_notes,
-            CASE tt.test_scope WHEN 'referential' THEN '⧉ ' WHEN 'custom' THEN '⛭ ' WHEN 'table' THEN '⊞ ' WHEN 'column' THEN '≣ 
' END - || tt.test_name_short || ': ' || lower(tt.test_name_long) - || CASE WHEN tt.selection_criteria > '' THEN ' [auto-generated]' ELSE '' END as select_name - FROM {str_schema}.test_types tt - WHERE tt.active = 'Y' {str_criteria} - ORDER BY CASE tt.test_scope WHEN 'referential' THEN 1 WHEN 'custom' THEN 2 WHEN 'table' THEN 3 WHEN 'column' THEN 4 ELSE 5 END, - tt.test_name_short; - """ - return db.retrieve_data(str_sql) - - -def run_connections_lookup_query(str_schema, str_project_code): - str_sql = f""" - SELECT c.id::VARCHAR(50), c.connection_id, c.connection_name - FROM {str_schema}.connections c - WHERE c.project_code = '{str_project_code}' - ORDER BY connection_name - """ - return db.retrieve_data(str_sql) - - -def run_table_groups_lookup_query(schema: str, project_code: str, connection_id: str | None = None, table_group_id: str | None = None) -> pd.DataFrame: - sql = f""" - SELECT tg.id::VARCHAR(50), tg.table_groups_name, tg.connection_id, tg.table_group_schema - FROM {schema}.table_groups tg - """ - - if connection_id: - sql += f""" - inner join {schema}.connections c on c.connection_id = tg.connection_id - """ - - sql += f""" - WHERE tg.project_code = '{project_code}' - """ - - if table_group_id: - sql += f""" - AND tg.id = '{table_group_id}'::UUID - """ - - if connection_id: - sql += f""" - AND c.id = '{connection_id}'::UUID - """ - - sql += """ - ORDER BY table_groups_name - """ - return db.retrieve_data(sql) - - -def run_test_suite_lookup_by_tgroup_query(str_schema, str_table_groups_id, test_suite_name=None): - str_sql = f""" - SELECT id::VARCHAR(50), test_suite, test_suite_schema, severity, export_to_observability - FROM {str_schema}.test_suites - WHERE table_groups_id = '{str_table_groups_id}' - """ - - if test_suite_name: - str_sql += f""" - AND test_suite = '{test_suite_name}' - """ - - str_sql += """ - ORDER BY test_suite - """ - - return db.retrieve_data(str_sql) - - -def run_test_suite_lookup_by_project_query(str_schema, str_project): - str_sql = f""" - SELECT s.id::VARCHAR(50), s.test_suite, s.test_suite_schema, - s.test_suite - || CASE - WHEN tg.table_groups_name IS NULL THEN '' - ELSE '(' || tg.table_groups_name || ')' - END as test_suite_with_tg, - s.test_suite_description - FROM {str_schema}.test_suites s - LEFT JOIN {str_schema}.table_groups tg - ON (s.table_groups_id = tg.id) - WHERE s.project_code = '{str_project}' - ORDER BY s.test_suite - """ - return db.retrieve_data(str_sql) - - -def run_test_run_lookup_by_date(str_schema, str_project_code, str_run_date): - str_sql = f""" - SELECT - r.id::VARCHAR(50), - r.test_starttime::VARCHAR || ' - ' || s.test_suite as test_run_desc - FROM {str_schema}.test_runs r - LEFT JOIN {str_schema}.test_suites s ON r.test_suite_id = s.id) - WHERE - s.project_code = '{str_project_code}' - AND r.test_starttime::DATE = '{str_run_date}' - ORDER BY r.test_starttime DESC - """ - return db.retrieve_data(str_sql) - - -def update_anomaly_disposition(selected, str_schema, str_new_status): - def finalize_small_update(status, ids): - return f"""UPDATE {str_schema}.profile_anomaly_results - SET disposition = NULLIF('{status}', 'No Decision') - WHERE id IN ({ids});""" - - def finalize_big_update(status, ids): - return f"""WITH selects - as ( SELECT UNNEST(ARRAY [{ids}]) AS selected_id ) - UPDATE {str_schema}.profile_anomaly_results - SET disposition = NULLIF('{status}', 'No Decision') - FROM {str_schema}.profile_anomaly_results r - INNER JOIN selects s - ON (r.id = s.selected_id) - WHERE r.id = profile_anomaly_results.id;""" - - lst_ids = 
[row["id"] for row in selected if "id" in row] - lst_updates = [] - str_ids = "" - - if len(lst_ids) > 0: - for my_id in lst_ids: - str_ids += f" '{my_id}'::UUID," - str_ids = str_ids.rstrip(",") - if len(lst_ids) > 4: - lst_updates.append(finalize_big_update(str_new_status, str_ids)) - else: - lst_updates.append(finalize_small_update(str_new_status, str_ids)) - - for q in lst_updates: - db.execute_sql_raw(q) - - return True - - -def update_result_disposition(selected, str_schema, str_new_status): - active_yn = "N" if str_new_status == "Inactive" else "Y" - - def finalize_small_update(status, ids): - return f"""UPDATE {str_schema}.test_results - SET disposition = NULLIF('{status}', 'No Decision') - WHERE id IN ({ids}) and result_status != 'Passed';""" - - def finalize_big_update(status, ids): - return f"""WITH selects - as ( SELECT UNNEST(ARRAY [{ids}]) AS selected_id ) - UPDATE {str_schema}.test_results - SET disposition = NULLIF('{status}', 'No Decision') - FROM {str_schema}.test_results r - INNER JOIN selects s - ON (r.id = s.selected_id) - WHERE r.id = test_results.id and result_status != 'Passed';""" - - def finalize_test_update(ids): - str_lock_test = ", lock_refresh = 'N'" if active_yn == "Y" else ", lock_refresh = 'Y'" - return f"""WITH selects - as ( SELECT UNNEST(ARRAY [{ids}]) AS selected_id ) - UPDATE {str_schema}.test_definitions - SET test_active = '{active_yn}', - last_manual_update = CURRENT_TIMESTAMP AT TIME ZONE 'UTC' {str_lock_test} - FROM {str_schema}.test_definitions d - INNER JOIN {str_schema}.test_results r - ON (d.id = r.test_definition_id) - INNER JOIN selects s - ON (r.id = s.selected_id) - WHERE d.id = test_definitions.id""" - - lst_ids = [row["test_result_id"] for row in selected if "test_result_id" in row] - lst_updates = [] - str_ids = "" - - for my_id in lst_ids: - str_ids += f" '{my_id}'::UUID," - str_ids = str_ids.rstrip(",") - - if len(lst_ids) > 0: - if len(lst_ids) > 4: - lst_updates.append(finalize_big_update(str_new_status, str_ids)) - else: - lst_updates.append(finalize_small_update(str_new_status, str_ids)) - lst_updates.append(finalize_test_update(str_ids)) - - for q in lst_updates: - db.execute_sql_raw(q) - - return True diff --git a/testgen/ui/services/table_group_service.py b/testgen/ui/services/table_group_service.py deleted file mode 100644 index 66097f51..00000000 --- a/testgen/ui/services/table_group_service.py +++ /dev/null @@ -1,180 +0,0 @@ -import streamlit as st - -import testgen.ui.queries.table_group_queries as table_group_queries -import testgen.ui.services.connection_service as connection_service -import testgen.ui.services.test_suite_service as test_suite_service -from testgen.common.database.database_service import RetrieveDBResultsToDictList -from testgen.common.models.scores import ScoreDefinition - - -def get_by_id(table_group_id: str): - schema = st.session_state["dbschema"] - return table_group_queries.get_by_id(schema, table_group_id).iloc[0] - - -def get_by_connection(project_code, connection_id): - schema = st.session_state["dbschema"] - return table_group_queries.get_by_connection(schema, project_code, connection_id) - - -def get_all(project_code: str): - schema = st.session_state["dbschema"] - return table_group_queries.get_all(schema, project_code) - - -def edit(table_group): - schema = st.session_state["dbschema"] - table_group_queries.edit(schema, table_group) - - -def add(table_group: dict) -> str: - schema = st.session_state["dbschema"] - table_group_id = table_group_queries.add(schema, table_group) - if 
table_group.get("add_scorecard_definition", True): - ScoreDefinition.from_table_group(table_group).save() - return table_group_id - - -def cascade_delete(table_group_names, dry_run=False): - schema = st.session_state["dbschema"] - test_suite_ids = get_test_suite_ids_by_table_group_names(table_group_names) - - can_be_deleted = not any( - ( - table_group_has_dependencies(table_group_names), - test_suite_service.has_test_suite_dependencies(test_suite_ids), - ) - ) - - if not dry_run: - test_suite_service.cascade_delete(test_suite_ids) - table_group_queries.cascade_delete(schema, table_group_names) - return can_be_deleted - - -def table_group_has_dependencies(table_group_names): - if not table_group_names: - return False - schema = st.session_state["dbschema"] - return not table_group_queries.get_table_group_dependencies(schema, table_group_names).empty - - -def are_table_groups_in_use(table_group_names: list[str]): - if not table_group_names: - return False - - schema = st.session_state["dbschema"] - - test_suite_ids = get_test_suite_ids_by_table_group_names(table_group_names) - test_suites_in_use = test_suite_service.are_test_suites_in_use(test_suite_ids) - - table_groups_in_use_result = table_group_queries.get_table_group_usage(schema, table_group_names) - table_groups_in_use = not table_groups_in_use_result.empty - - return test_suites_in_use or table_groups_in_use - - -def is_table_group_used(table_group_id: str) -> bool: - schema = st.session_state["dbschema"] - test_suite_ids = table_group_queries.get_test_suite_ids_by_table_group_id(schema, table_group_id) - proling_run_ids = table_group_queries.get_profiling_run_ids_by_table_group_id(schema, table_group_id) - - return len(test_suite_ids) + len(proling_run_ids) > 0 - - -def get_test_suite_ids_by_table_group_names(table_group_names): - if not table_group_names: - return [] - schema = st.session_state["dbschema"] - result = table_group_queries.get_test_suite_ids_by_table_group_names(schema, table_group_names) - return result.to_dict()["id"].values() - - -def get_table_group_preview(project_code: str, connection: dict | None, table_group: dict) -> dict: - table_group_preview = { - "schema": table_group["table_group_schema"], - "tables": set(), - "column_count": 0, - "success": True, - "message": None, - } - if connection: - try: - table_group_results = test_table_group(table_group, connection, project_code) - - for column in table_group_results: - table_group_preview["schema"] = column["table_schema"] - table_group_preview["tables"].add(column["table_name"]) - table_group_preview["column_count"] += 1 - - if len(table_group_results) <= 0: - table_group_preview["success"] = False - table_group_preview["message"] = ( - "No tables found matching the criteria. Please check the Table Group configuration" - " or the database permissions." - ) - except Exception as error: - table_group_preview["success"] = False - table_group_preview["message"] = error.args[0] - else: - table_group_preview["success"] = False - table_group_preview["message"] = "No connection selected. Please select a connection to preview the Table Group." 
- - table_group_preview["tables"] = list(table_group_preview["tables"]) - return table_group_preview - - -def test_table_group(table_group, connection, project_code): - connection_id = str(connection["connection_id"]) - - # get table group data - table_group_schema = table_group["table_group_schema"] - table_group_id = table_group["id"] - profiling_table_set = table_group["profiling_table_set"] - profiling_include_mask = table_group["profiling_include_mask"] - profiling_exclude_mask = table_group["profiling_exclude_mask"] - profile_id_column_mask = table_group["profile_id_column_mask"] - profile_sk_column_mask = table_group["profile_sk_column_mask"] - profile_use_sampling = "Y" if table_group["profile_use_sampling"] else "N" - profile_sample_percent = table_group["profile_sample_percent"] - profile_sample_min_count = table_group["profile_sample_min_count"] - - clsProfiling = connection_service.init_profiling_sql(project_code, connection, table_group_schema) - - # Set General Parms - clsProfiling.table_groups_id = table_group_id - clsProfiling.connection_id = connection_id - clsProfiling.parm_do_sample = "N" - clsProfiling.parm_sample_size = 0 - clsProfiling.parm_vldb_flag = "N" - clsProfiling.parm_do_freqs = "Y" - clsProfiling.parm_max_freq_length = 25 - clsProfiling.parm_do_patterns = "Y" - clsProfiling.parm_max_pattern_length = 25 - clsProfiling.profile_run_id = "" - clsProfiling.data_schema = table_group_schema - clsProfiling.parm_table_set = get_profiling_table_set_with_quotes(profiling_table_set) - clsProfiling.parm_table_include_mask = profiling_include_mask - clsProfiling.parm_table_exclude_mask = profiling_exclude_mask - clsProfiling.profile_id_column_mask = profile_id_column_mask - clsProfiling.profile_sk_column_mask = profile_sk_column_mask - clsProfiling.profile_use_sampling = profile_use_sampling - clsProfiling.profile_sample_percent = profile_sample_percent - clsProfiling.profile_sample_min_count = profile_sample_min_count - - query = clsProfiling.GetDDFQuery() - table_group_results = RetrieveDBResultsToDictList("PROJECT", query) - - return table_group_results - - -def get_profiling_table_set_with_quotes(profiling_table_set): - if not profiling_table_set: - return profiling_table_set - - aux_list = [] - split = profiling_table_set.split(",") - for item in split: - aux_list.append(f"'{item.strip()}'") - profiling_table_set = ",".join(aux_list) - return profiling_table_set diff --git a/testgen/ui/services/test_definition_service.py b/testgen/ui/services/test_definition_service.py deleted file mode 100644 index 2d18af89..00000000 --- a/testgen/ui/services/test_definition_service.py +++ /dev/null @@ -1,168 +0,0 @@ -import streamlit as st - -import testgen.ui.queries.test_definition_queries as test_definition_queries -import testgen.ui.services.connection_service as connection_service -import testgen.ui.services.database_service as database_service -import testgen.ui.services.table_group_service as table_group_service -from testgen.ui.queries import test_run_queries - - -def update_attribute(test_definition_ids, attribute, value): - schema = st.session_state["dbschema"] - raw_value = "Y" if value else "N" - test_definition_queries.update_attribute(schema, test_definition_ids, attribute, raw_value) - - -def get_test_definitions( - project_code=None, test_suite=None, table_name=None, column_name=None, test_type=None, test_definition_ids=None, -): - schema = st.session_state["dbschema"] - return test_definition_queries.get_test_definitions( - schema, project_code, test_suite, 
table_name, column_name, test_type, test_definition_ids, - ) - - -def get_test_definition(db_schema, test_def_id): - str_sql = f""" - SELECT - d.id::VARCHAR, - tg.table_group_schema as schema_name, - ts.test_suite as test_suite_name, - d.export_to_observability as export_to_observability, - ts.export_to_observability as default_export_to_observability, - tt.test_name_short as test_name, - tt.test_name_long as full_name, - tt.test_description as description, - d.test_definition_status as status, - tt.usage_notes, - d.table_name, - d.column_name, - d.baseline_value, d.baseline_ct, d.baseline_unique_ct, d.baseline_value_ct, - d.baseline_avg, d.baseline_sd, d.threshold_value, d.baseline_sum, - d.lower_tolerance, d.upper_tolerance, - d.subset_condition, d.groupby_names, d.having_condition, d.match_schema_name, - d.match_table_name, d.match_column_names, d.match_subset_condition, - d.match_groupby_names, d.match_having_condition, - d.window_date_column, d.window_days::VARCHAR as window_days, - d.custom_query, d.test_mode, - d.severity, tt.default_severity, - d.test_active, d.lock_refresh, d.last_manual_update, - tt.default_parm_prompts, tt.default_parm_columns, tt.default_parm_help - FROM {db_schema}.test_definitions d - INNER JOIN {db_schema}.test_types tt - ON (d.test_type = tt.test_type) - INNER JOIN {db_schema}.test_suites ts - ON (ts.id = d.test_suite_id) - INNER JOIN {db_schema}.table_groups tg - ON (tg.id = d.table_groups_id) - WHERE d.id = '{test_def_id}'; - """ - return database_service.retrieve_data(str_sql) - - -def delete(test_definition_ids): - schema = st.session_state["dbschema"] - test_definition_queries.delete(schema, test_definition_ids) - - -def cascade_delete(test_suite_ids: list[str]): - schema = st.session_state["dbschema"] - test_run_queries.cascade_delete(test_suite_ids) - test_definition_queries.cascade_delete(schema, test_suite_ids) - - -def add(test_definition): - schema = st.session_state["dbschema"] - prepare_to_persist(test_definition) - test_definition_queries.add(schema, test_definition) - - -def update(test_definition): - schema = st.session_state["dbschema"] - prepare_to_persist(test_definition) - return test_definition_queries.update(schema, test_definition) - - -def prepare_to_persist(test_definition): - # severity - if test_definition["severity"] and test_definition["severity"].startswith("Inherited"): - test_definition["severity"] = None - - test_definition["export_to_observability"] = prepare_boolean_for_update( - test_definition["export_to_observability_raw"] - ) - test_definition["lock_refresh"] = prepare_boolean_for_update(test_definition["lock_refresh"]) - test_definition["test_active"] = prepare_boolean_for_update(test_definition["test_active"]) - - if test_definition["custom_query"] is not None: - test_definition["custom_query"] = test_definition["custom_query"].strip() - if test_definition["custom_query"].endswith(";"): - test_definition["custom_query"] = test_definition["custom_query"][:-1] - - empty_if_null(test_definition) - - -def empty_if_null(test_definition): - for k, v in test_definition.items(): - if v is None: - test_definition[k] = "" - - -def prepare_boolean_for_update(value): - if "Yes" == value or "Y" == value or value is True: - return "Y" - elif "No" == value or "N" == value or value is False: - return "N" - else: - return None - - -def validate_test(test_definition): - schema = test_definition["schema_name"] - table_name = test_definition["table_name"] - - if test_definition["test_type"] == "Condition_Flag": - condition = 
test_definition["custom_query"] - sql_query = f"""SELECT COALESCE(CAST(SUM(CASE WHEN {condition} THEN 1 ELSE 0 END) AS VARCHAR(1000) ) || '|' ,'|') FROM {schema}.{table_name}""" - else: - sql_query = test_definition["custom_query"] - sql_query = sql_query.replace("{DATA_SCHEMA}", schema) - - table_group_id = test_definition["table_groups_id"] - table_group = table_group_service.get_by_id(table_group_id) - - connection_id = table_group["connection_id"] - - connection = connection_service.get_by_id(connection_id, hide_passwords=False) - - database_service.retrieve_target_db_data( - connection["sql_flavor"], - connection["project_host"], - connection["project_port"], - connection["project_db"], - connection["project_user"], - connection["password"], - connection["url"], - connection["connect_by_url"], - connection["connect_by_key"], - connection["private_key"], - connection["private_key_passphrase"], - connection["http_path"], - sql_query, - ) - - -def move(test_definitions, target_table_group, target_test_suite, target_table_column=None): - schema = st.session_state["dbschema"] - test_definition_queries.move(schema, test_definitions, target_table_group, target_test_suite, target_table_column) - - - -def copy(test_definitions, target_table_group, target_test_suite, target_table_column=None): - schema = st.session_state["dbschema"] - test_definition_queries.copy(schema, test_definitions, target_table_group, target_test_suite, target_table_column) - - -def get_test_definitions_collision(test_definitions, target_table_group, target_test_suite): - schema = st.session_state["dbschema"] - return test_definition_queries.get_test_definitions_collision(schema, test_definitions, target_table_group, target_test_suite) diff --git a/testgen/ui/services/test_results_service.py b/testgen/ui/services/test_results_service.py deleted file mode 100644 index 2cdf327d..00000000 --- a/testgen/ui/services/test_results_service.py +++ /dev/null @@ -1,332 +0,0 @@ -import pandas as pd - -from testgen.common import ConcatColumnList -from testgen.common.models import get_current_session, with_database_session -from testgen.common.read_file import replace_templated_functions -from testgen.ui.services import database_service as db -from testgen.ui.services.string_service import empty_if_null -from testgen.ui.services.test_definition_service import get_test_definition - - -@with_database_session -def get_test_results( - _: str, - run_id: str, - test_status: str | list[str] | None = None, - test_type_id: str | None = None, - table_name: str | None = None, - column_name: str | None = None, - sorting_columns: list[str] | None = None, -) -> pd.DataFrame: - # First visible row first, so multi-select checkbox will render - db_session = get_current_session() - params = {"run_id": run_id} - - order_by = "ORDER BY " + (", ".join(" ".join(col) for col in sorting_columns)) if sorting_columns else "" - filters = "" - if test_status: - if isinstance(test_status, str): - test_status = [status.strip() for status in test_status.split(",")] - test_status_params = {f"test_status_{idx}": status for idx, status in enumerate(test_status)} - - filters += f" AND r.result_status IN ({', '.join([f':{p}' for p in test_status_params.keys()])})" - params.update(test_status_params) - if test_type_id: - filters += " AND r.test_type = :test_type_id" - params["test_type_id"] = test_type_id - if table_name: - filters += " AND r.table_name = :table_name" - params["table_name"] = table_name - if column_name: - filters += " AND r.column_names ILIKE 
:column_name" - params["column_name"] = column_name - - sql = f""" - WITH run_results AS ( - SELECT * - FROM test_results r - WHERE r.test_run_id = :run_id - {filters} - ) - SELECT r.table_name, - p.project_name, ts.test_suite, tg.table_groups_name, cn.connection_name, cn.project_host, cn.sql_flavor, - tt.dq_dimension, tt.test_scope, - r.schema_name, r.column_names, r.test_time::DATE as test_date, r.test_type, tt.id as test_type_id, - tt.test_name_short, tt.test_name_long, r.test_description, tt.measure_uom, tt.measure_uom_description, - c.test_operator, r.threshold_value::NUMERIC(16, 5), r.result_measure::NUMERIC(16, 5), r.result_status, - CASE - WHEN r.result_code <> 1 THEN r.disposition - ELSE 'Passed' - END as disposition, - NULL::VARCHAR(1) as action, - r.input_parameters, r.result_message, CASE WHEN result_code <> 1 THEN r.severity END as severity, - r.result_code as passed_ct, - (1 - r.result_code)::INTEGER as exception_ct, - CASE - WHEN result_status = 'Warning' - AND result_message NOT ILIKE 'Inactivated%%' THEN 1 - END::INTEGER as warning_ct, - CASE - WHEN result_status = 'Failed' - AND result_message NOT ILIKE 'Inactivated%%' THEN 1 - END::INTEGER as failed_ct, - CASE - WHEN result_message ILIKE 'Inactivated%%' THEN 1 - END as execution_error_ct, - p.project_code, r.table_groups_id::VARCHAR, - r.id::VARCHAR as test_result_id, r.test_run_id::VARCHAR, - c.id::VARCHAR as connection_id, r.test_suite_id::VARCHAR, - r.test_definition_id::VARCHAR as test_definition_id_runtime, - CASE - WHEN r.auto_gen = TRUE THEN d.id - ELSE r.test_definition_id - END::VARCHAR as test_definition_id_current, - r.auto_gen, - - -- These are used in the PDF report - tt.threshold_description, tt.usage_notes, r.test_time, - dcc.description as column_description, - COALESCE(dcc.critical_data_element, dtc.critical_data_element) as critical_data_element, - COALESCE(dcc.data_source, dtc.data_source, tg.data_source) as data_source, - COALESCE(dcc.source_system, dtc.source_system, tg.source_system) as source_system, - COALESCE(dcc.source_process, dtc.source_process, tg.source_process) as source_process, - COALESCE(dcc.business_domain, dtc.business_domain, tg.business_domain) as business_domain, - COALESCE(dcc.stakeholder_group, dtc.stakeholder_group, tg.stakeholder_group) as stakeholder_group, - COALESCE(dcc.transform_level, dtc.transform_level, tg.transform_level) as transform_level, - COALESCE(dcc.aggregation_level, dtc.aggregation_level) as aggregation_level, - COALESCE(dcc.data_product, dtc.data_product, tg.data_product) as data_product - - FROM run_results r - INNER JOIN test_types tt - ON (r.test_type = tt.test_type) - LEFT JOIN test_definitions rd - ON (r.test_definition_id = rd.id) - LEFT JOIN test_definitions d - ON (r.test_suite_id = d.test_suite_id - AND r.table_name = d.table_name - AND COALESCE(r.column_names, 'N/A') = COALESCE(d.column_name, 'N/A') - AND r.test_type = d.test_type - AND r.auto_gen = TRUE - AND d.last_auto_gen_date IS NOT NULL) - INNER JOIN test_suites ts - ON r.test_suite_id = ts.id - INNER JOIN projects p - ON (ts.project_code = p.project_code) - INNER JOIN table_groups tg - ON (ts.table_groups_id = tg.id) - INNER JOIN connections cn - ON (tg.connection_id = cn.connection_id) - LEFT JOIN cat_test_conditions c - ON (cn.sql_flavor = c.sql_flavor - AND r.test_type = c.test_type) - LEFT JOIN data_column_chars dcc - ON (tg.id = dcc.table_groups_id - AND r.schema_name = dcc.schema_name - AND r.table_name = dcc.table_name - AND r.column_names = dcc.column_name) - LEFT JOIN data_table_chars 
dtc - ON dcc.table_id = dtc.table_id - {order_by} - """ - - results = db_session.execute(sql, params=params) - columns = [column.name for column in results.cursor.description] - - df = pd.DataFrame(list(results), columns=columns) - df["test_date"] = pd.to_datetime(df["test_date"]) - - return df - - -def get_test_result_history(db_schema, tr_data, limit: int | None = None): - if tr_data["auto_gen"]: - if tr_data["column_names"]: - col_name_cond = f"column_names = '{tr_data["column_names"]}'" - else: - col_name_cond = "column_names IS NULL" - - str_where = f""" - WHERE test_suite_id = '{tr_data["test_suite_id"]}' - AND table_name = '{tr_data["table_name"]}' - AND {col_name_cond} - AND test_type = '{tr_data["test_type"]}' - AND auto_gen = TRUE - """ - else: - str_where = f""" - WHERE test_definition_id_runtime = '{tr_data["test_definition_id_runtime"]}' - """ - - str_sql = f""" - SELECT test_date, test_type, - test_name_short, test_name_long, measure_uom, test_operator, - threshold_value::NUMERIC, result_measure, result_status - FROM {db_schema}.v_test_results {str_where} - ORDER BY test_date DESC - {'LIMIT ' + str(limit) if limit else ''}; - """ - - df = db.retrieve_data(str_sql) - # Clean Up - df["test_date"] = pd.to_datetime(df["test_date"]) - - return df - - -def do_source_data_lookup_custom(db_schema, tr_data, limit: int | None = None): - # Define the query - str_sql = f""" - SELECT d.custom_query as lookup_query, tg.table_group_schema, - c.sql_flavor, c.project_host, c.project_port, c.project_db, c.project_user, c.project_pw_encrypted, - c.url, c.connect_by_url, c.connect_by_key, c.private_key, c.private_key_passphrase, c.http_path - FROM {db_schema}.test_definitions d - INNER JOIN {db_schema}.table_groups tg - ON ('{tr_data["table_groups_id"]}'::UUID = tg.id) - INNER JOIN {db_schema}.connections c - ON (tg.connection_id = c.connection_id) - WHERE d.id = '{tr_data["test_definition_id_current"]}'; - """ - - try: - # Retrieve SQL for customer lookup - lst_query = db.retrieve_data_list(str_sql) - - # Retrieve and return data as df - if lst_query: - str_sql = lst_query[0]["lookup_query"] - str_sql = str_sql.replace("{DATA_SCHEMA}", empty_if_null(lst_query[0]["table_group_schema"])) - df = db.retrieve_target_db_df( - lst_query[0]["sql_flavor"], - lst_query[0]["project_host"], - lst_query[0]["project_port"], - lst_query[0]["project_db"], - lst_query[0]["project_user"], - lst_query[0]["project_pw_encrypted"], - str_sql, - lst_query[0]["url"], - lst_query[0]["connect_by_url"], - lst_query[0]["connect_by_key"], - lst_query[0]["private_key"], - lst_query[0]["private_key_passphrase"], - lst_query[0]["http_path"], - ) - if df.empty: - return "ND", "Data that violates Test criteria is not present in the current dataset.", str_sql, None - else: - if limit: - df = df.sample(n=min(len(df), limit)).sort_index() - return "OK", None, str_sql, df - else: - return "NA", "Source data lookup is not available for this test.", None, None - - except Exception as e: - return "ERR", f"Source data lookup query caused an error:\n\n{e.args[0]}", str_sql, None - - -def do_source_data_lookup(db_schema, tr_data, sql_only=False, limit: int | None = None): - # Define the query - str_sql = f""" - SELECT t.lookup_query, tg.table_group_schema, - c.sql_flavor, c.project_host, c.project_port, c.project_db, c.project_user, c.project_pw_encrypted, - c.url, c.connect_by_url, - c.connect_by_key, c.private_key, c.private_key_passphrase, - c.http_path - FROM {db_schema}.target_data_lookups t - INNER JOIN {db_schema}.table_groups tg 
- ON ('{tr_data["table_groups_id"]}'::UUID = tg.id) - INNER JOIN {db_schema}.connections c - ON (tg.connection_id = c.connection_id) - AND (t.sql_flavor = c.sql_flavor) - WHERE t.error_type = 'Test Results' - AND t.test_id = '{tr_data["test_type_id"]}' - AND t.lookup_query > ''; - """ - - def replace_parms(df_test, str_query): - if df_test.empty: - raise ValueError("This test definition is no longer present.") - - str_query = str_query.replace("{TARGET_SCHEMA}", empty_if_null(lst_query[0]["table_group_schema"])) - str_query = str_query.replace("{TABLE_NAME}", empty_if_null(tr_data["table_name"])) - str_query = str_query.replace("{COLUMN_NAME}", empty_if_null(tr_data["column_names"])) - str_query = str_query.replace("{TEST_DATE}", str(empty_if_null(tr_data["test_date"]))) - - str_query = str_query.replace("{CUSTOM_QUERY}", empty_if_null(df_test.at[0, "custom_query"])) - str_query = str_query.replace("{BASELINE_VALUE}", empty_if_null(df_test.at[0, "baseline_value"])) - str_query = str_query.replace("{BASELINE_CT}", empty_if_null(df_test.at[0, "baseline_ct"])) - str_query = str_query.replace("{BASELINE_AVG}", empty_if_null(df_test.at[0, "baseline_avg"])) - str_query = str_query.replace("{BASELINE_SD}", empty_if_null(df_test.at[0, "baseline_sd"])) - str_query = str_query.replace("{THRESHOLD_VALUE}", empty_if_null(df_test.at[0, "threshold_value"])) - str_query = str_query.replace("{LOWER_TOLERANCE}", empty_if_null(df_test.at[0, "lower_tolerance"])) - str_query = str_query.replace("{UPPER_TOLERANCE}", empty_if_null(df_test.at[0, "upper_tolerance"])) - - str_substitute = empty_if_null(df_test.at[0, "subset_condition"]) - str_substitute = "1=1" if str_substitute == "" else str_substitute - str_query = str_query.replace("{SUBSET_CONDITION}", str_substitute) - - str_query = str_query.replace("{GROUPBY_NAMES}", empty_if_null(df_test.at[0, "groupby_names"])) - str_query = str_query.replace("{HAVING_CONDITION}", empty_if_null(df_test.at[0, "having_condition"])) - str_query = str_query.replace("{MATCH_SCHEMA_NAME}", empty_if_null(df_test.at[0, "match_schema_name"])) - str_query = str_query.replace("{MATCH_TABLE_NAME}", empty_if_null(df_test.at[0, "match_table_name"])) - str_query = str_query.replace("{MATCH_COLUMN_NAMES}", empty_if_null(df_test.at[0, "match_column_names"])) - - str_substitute = empty_if_null(df_test.at[0, "match_subset_condition"]) - str_substitute = "1=1" if str_substitute == "" else str_substitute - str_query = str_query.replace("{MATCH_SUBSET_CONDITION}", str_substitute) - - str_query = str_query.replace("{MATCH_GROUPBY_NAMES}", empty_if_null(df_test.at[0, "match_groupby_names"])) - str_query = str_query.replace("{MATCH_HAVING_CONDITION}", empty_if_null(df_test.at[0, "match_having_condition"])) - str_query = str_query.replace("{COLUMN_NAME_NO_QUOTES}", empty_if_null(tr_data["column_names"])) - - str_query = str_query.replace("{WINDOW_DATE_COLUMN}", empty_if_null(df_test.at[0, "window_date_column"])) - str_query = str_query.replace("{WINDOW_DAYS}", empty_if_null(df_test.at[0, "window_days"])) - - str_substitute = ConcatColumnList(tr_data["column_names"], "") - str_query = str_query.replace("{CONCAT_COLUMNS}", str_substitute) - str_substitute = ConcatColumnList(df_test.at[0, "match_groupby_names"], "") - str_query = str_query.replace("{CONCAT_MATCH_GROUPBY}", str_substitute) - - str_query = replace_templated_functions(str_query, lst_query[0]["sql_flavor"]) - - if str_query is None or str_query == "": - raise ValueError("Lookup query is not defined for this Test Type.") - return 
str_query - - try: - # Retrieve SQL for customer lookup - lst_query = db.retrieve_data_list(str_sql) - - if sql_only: - return lst_query, replace_parms, None - - # Retrieve and return data as df - if lst_query: - df_test = get_test_definition(db_schema, tr_data["test_definition_id_current"]) - - str_sql = replace_parms(df_test, lst_query[0]["lookup_query"]) - df = db.retrieve_target_db_df( - lst_query[0]["sql_flavor"], - lst_query[0]["project_host"], - lst_query[0]["project_port"], - lst_query[0]["project_db"], - lst_query[0]["project_user"], - lst_query[0]["project_pw_encrypted"], - str_sql, - lst_query[0]["url"], - lst_query[0]["connect_by_url"], - lst_query[0]["connect_by_key"], - lst_query[0]["private_key"], - lst_query[0]["private_key_passphrase"], - lst_query[0]["http_path"], - ) - if df.empty: - return "ND", "Data that violates Test criteria is not present in the current dataset.", str_sql, None - else: - if limit: - df = df.sample(n=min(len(df), limit)).sort_index() - return "OK", None, str_sql, df - else: - return "NA", "A source data lookup for this Test is not available.", None, None - - except Exception as e: - return "ERR", f"Source data lookup query caused:\n\n{e.args[0]}", str_sql, None diff --git a/testgen/ui/services/test_suite_service.py b/testgen/ui/services/test_suite_service.py deleted file mode 100644 index 9a6326a1..00000000 --- a/testgen/ui/services/test_suite_service.py +++ /dev/null @@ -1,91 +0,0 @@ -import pandas as pd -import streamlit as st - -import testgen.ui.queries.test_suite_queries as test_suite_queries -import testgen.ui.services.test_definition_service as test_definition_service -from testgen.utils import is_uuid4 - - -def get_by_project(project_code, table_group_id=None): - schema = st.session_state["dbschema"] - return test_suite_queries.get_by_project(schema, project_code, table_group_id) - - -def get_by_id(test_suite_id: str) -> pd.Series: - if not is_uuid4(test_suite_id): - return pd.Series() - - schema = st.session_state["dbschema"] - df = test_suite_queries.get_by_id(schema, test_suite_id) - if not df.empty: - return df.iloc[0] - else: - return pd.Series() - - -def edit(test_suite): - schema = st.session_state["dbschema"] - test_suite_queries.edit(schema, test_suite) - - -def add(test_suite): - schema = st.session_state["dbschema"] - test_suite_queries.add(schema, test_suite) - - -def cascade_delete(test_suite_ids, dry_run=False): - if not test_suite_ids: - return True - schema = st.session_state["dbschema"] - can_be_deleted = not has_test_suite_dependencies(test_suite_ids) - if not dry_run: - test_definition_service.cascade_delete(test_suite_ids) - test_suite_queries.delete(schema, test_suite_ids) - return can_be_deleted - - -def has_test_suite_dependencies(test_suite_ids: list[str]): - schema = st.session_state["dbschema"] - if not test_suite_ids: - return False - return not test_suite_queries.get_test_suite_dependencies(schema, test_suite_ids).empty - - -def are_test_suites_in_use(test_suite_ids: list[str]): - if not test_suite_ids: - return False - schema = st.session_state["dbschema"] - usage_result = test_suite_queries.get_test_suite_usage(schema, test_suite_ids) - return not usage_result.empty - - -def get_test_suite_refresh_warning(test_suite_id): - if not test_suite_id: - return False - schema = st.session_state["dbschema"] - row_result = test_suite_queries.get_test_suite_refresh_check(schema, test_suite_id) - - test_ct = None - unlocked_test_ct = None - unlocked_edits_ct = None - if row_result: - test_ct = row_result["test_ct"] - 
unlocked_test_ct = row_result["unlocked_test_ct"] - unlocked_edits_ct = row_result["unlocked_edits_ct"] - - return test_ct, unlocked_test_ct, unlocked_edits_ct - - -def get_generation_set_choices(): - schema = st.session_state["dbschema"] - dfSets = test_suite_queries.get_generation_sets(schema) - if dfSets.empty: - return None - else: - return dfSets["generation_set"].to_list() - - -def lock_edited_tests(test_suite_id): - schema = st.session_state["dbschema"] - tests_locked = test_suite_queries.lock_edited_tests(schema, test_suite_id) - return tests_locked diff --git a/testgen/ui/services/user_session_service.py b/testgen/ui/services/user_session_service.py index 463f454a..8bbb1000 100644 --- a/testgen/ui/services/user_session_service.py +++ b/testgen/ui/services/user_session_service.py @@ -1,18 +1,15 @@ import base64 import datetime import logging -import typing import extra_streamlit_components as stx import jwt import streamlit as st from testgen import settings -from testgen.ui.queries import user_queries +from testgen.common.models.user import RoleType, User from testgen.ui.session import session -RoleType = typing.Literal["admin", "data_quality", "analyst", "business", "catalog"] - AUTH_TOKEN_COOKIE_NAME = "dk_cookie_name" # noqa: S105 AUTH_TOKEN_EXPIRATION_DAYS = 1 DISABLED_ACTION_TEXT = "You do not have permissions to perform this action. Contact your administrator." @@ -72,11 +69,10 @@ def end_user_session() -> None: def get_auth_data(): - auth_data = user_queries.get_users(include_password=True) - + users = User.select_where() usernames = {} - for item in auth_data.itertuples(): + for item in users: usernames[item.username.lower()] = { "email": item.email, "name": item.name, diff --git a/testgen/ui/session.py b/testgen/ui/session.py index f64cff4e..cb8d028a 100644 --- a/testgen/ui/session.py +++ b/testgen/ui/session.py @@ -27,8 +27,6 @@ class TestgenSession(Singleton): page_args_pending_router: dict current_page: str - dbschema: str - name: str username: str authentication_status: bool @@ -61,6 +59,11 @@ def __delattr__(self, key: str) -> None: if key in state: del state[key] + def set_sidebar_project(self, project_code: str) -> None: + if project_code != self.sidebar_project: + self.sidebar_project = project_code + st.rerun() + def temp_value(session_key: str, *, default: T | None = None) -> tuple[TempValueGetter[T | None], TempValueSetter[T]]: scoped_session_key = f"tg-session:tmp-value:{session_key}" diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py index 77b059e2..d8af0332 100644 --- a/testgen/ui/views/connections.py +++ b/testgen/ui/views/connections.py @@ -5,6 +5,9 @@ import streamlit as st +from testgen.common.database.flavor.flavor_service import ConnectionParams +from testgen.ui.queries import table_group_queries + try: from pyodbc import Error as PyODBCError except ImportError: @@ -13,15 +16,16 @@ import testgen.ui.services.database_service as db from testgen.commands.run_profiling_bridge import run_profiling_in_background -from testgen.common.database.database_service import empty_cache +from testgen.common.database.database_service import empty_cache, get_flavor_service from testgen.common.models import with_database_session +from testgen.common.models.connection import Connection, ConnectionMinimal +from testgen.common.models.table_group import TableGroup from testgen.ui.assets import get_asset_data_url from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page 
import Page -from testgen.ui.services import connection_service, table_group_service, user_session_service +from testgen.ui.services import user_session_service from testgen.ui.session import session, temp_value -from testgen.utils import format_field LOG = logging.getLogger("testgen") PAGE_TITLE = "Connection" @@ -49,10 +53,10 @@ def render(self, project_code: str, **_kwargs) -> None: "connect-your-database", ) - dataframe = connection_service.get_connections(project_code) - connection = dataframe.iloc[0] + connections = Connection.select_where(Connection.project_code == project_code) + connection: Connection = connections[0] has_table_groups = ( - len(connection_service.get_table_group_names_by_connection([connection["connection_id"]]) or []) > 0 + len(TableGroup.select_minimal_where(TableGroup.connection_id == connection.connection_id) or []) > 0 ) user_is_admin = user_session_service.user_is_admin() should_check_status, set_check_status = temp_value( @@ -77,7 +81,7 @@ def on_save_connection_clicked(updated_connection): updated_connection["url"] = url_parts[1] if updated_connection.get("connect_by_key"): - updated_connection["password"] = "" + updated_connection["project_pw_encrypted"] = "" if is_pristine(updated_connection["private_key_passphrase"]): del updated_connection["private_key_passphrase"] else: @@ -92,11 +96,11 @@ def on_save_connection_clicked(updated_connection): else: updated_connection["private_key"] = base64.b64decode(updated_connection["private_key"]).decode() - if is_pristine(updated_connection.get("password")): - del updated_connection["password"] + if is_pristine(updated_connection.get("project_pw_encrypted")): + del updated_connection["project_pw_encrypted"] - if updated_connection.get("password") == CLEAR_SENTINEL: - updated_connection["password"] = "" + if updated_connection.get("project_pw_encrypted") == CLEAR_SENTINEL: + updated_connection["project_pw_encrypted"] = "" updated_connection["sql_flavor"] = self._get_sql_flavor_from_value(updated_connection["sql_flavor_code"]).flavor @@ -104,13 +108,13 @@ def on_save_connection_clicked(updated_connection): set_updated_connection(updated_connection) def on_test_connection_clicked(updated_connection: dict) -> None: - password = updated_connection.get("password") + password = updated_connection.get("project_pw_encrypted") private_key = updated_connection.get("private_key") private_key_passphrase = updated_connection.get("private_key_passphrase") is_pristine = lambda value: value in ["", "***"] if is_pristine(password): - del updated_connection["password"] + del updated_connection["project_pw_encrypted"] if is_pristine(private_key): del updated_connection["private_key"] @@ -128,11 +132,13 @@ def on_test_connection_clicked(updated_connection: dict) -> None: set_updated_connection(updated_connection) results = None - connection = {**connection.to_dict(), **get_updated_connection()} + for key, value in get_updated_connection().items(): + setattr(connection, key, value) + if should_save(): success = True try: - connection_service.edit_connection(connection) + connection.save() message = "Changes have been saved successfully." 
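For context on this hunk: edited fields are now applied directly to the Connection model and persisted with connection.save(), and the old password key is replaced by project_pw_encrypted. A sketch of that merge-and-save pattern under the same assumptions (the apply_and_save helper is illustrative; the secret handling mirrors the "***" placeholder logic shown in this hunk):

from testgen.common.models.connection import Connection


def apply_and_save(connection: Connection, updated_fields: dict) -> None:
    secret_fields = ("project_pw_encrypted", "private_key", "private_key_passphrase")
    is_pristine = lambda value: value in ("", "***")
    for key, value in updated_fields.items():
        # Untouched secrets come back as the "***" placeholder; skip them so they are not overwritten.
        if key in secret_fields and is_pristine(value):
            continue
        setattr(connection, key, value)
    connection.save()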
except Exception as error: message = "Error creating connection" @@ -159,7 +165,7 @@ def on_test_connection_clicked(updated_connection: dict) -> None: on_change_handlers={ "TestConnectionClicked": on_test_connection_clicked, "SaveConnectionClicked": on_save_connection_clicked, - "SetupTableGroupClicked": lambda _: self.setup_data_configuration(project_code, connection["connection_id"]), + "SetupTableGroupClicked": lambda _: self.setup_data_configuration(project_code, connection.connection_id), }, ) @@ -169,31 +175,17 @@ def _get_sql_flavor_from_value(self, value: str) -> "ConnectionFlavor | None": return match[0] return None - def _format_connection(self, connection: dict, should_test: bool = False) -> dict: + def _format_connection(self, connection: Connection, should_test: bool = False) -> dict: formatted_connection = format_connection(connection) if should_test: formatted_connection["status"] = asdict(self.test_connection(connection)) return formatted_connection - def test_connection(self, connection: dict) -> "ConnectionStatus": + def test_connection(self, connection: Connection) -> "ConnectionStatus": empty_cache() try: sql_query = "select 1;" - results = db.retrieve_target_db_data( - connection["sql_flavor"], - connection["project_host"], - connection["project_port"], - connection["project_db"], - connection["project_user"], - connection["password"], - connection["url"], - connection["connect_by_url"], - connection["connect_by_key"], - connection.get("private_key", ""), - connection.get("private_key_passphrase", ""), - connection.get("http_path", ""), - sql_query, - ) + results = db.fetch_from_target_db(connection, sql_query) connection_successful = len(results) == 1 and results[0][0] == 1 if not connection_successful: @@ -242,9 +234,13 @@ def on_save_table_group_clicked(payload: dict) -> None: def on_go_to_profiling_runs(params: dict) -> None: set_navigation_params({ **params, "project_code": project_code }) - def on_preview_table_group(table_group: dict) -> None: + def on_preview_table_group(payload: dict) -> None: + table_group = payload["table_group"] + verify_table_access = payload.get("verify_access") or False + set_new_table_group(table_group) mark_for_preview(True) + mark_for_access_preview(verify_table_access) get_navigation_params, set_navigation_params = temp_value( "connections:new_table_group:go_to_profiling_run", @@ -263,49 +259,52 @@ def on_preview_table_group(table_group: dict) -> None: ) results = None - table_group = get_new_table_group() + table_group_data = get_new_table_group() should_run_profiling = get_run_profiling() should_preview, mark_for_preview = temp_value( f"connections:{connection_id}:tg_preview", default=False, ) + should_verify_access, mark_for_access_preview = temp_value( + f"connections:{connection_id}:tg_preview_access", + default=False, + ) is_table_group_verified, set_table_group_verified = temp_value( f"connections:{connection_id}:tg_verified", default=False, ) + table_group = TableGroup( + **table_group_data or {}, + project_code = project_code, + connection_id = connection_id, + ) + table_group_preview = None if should_preview(): - connection = connection_service.get_by_id(connection_id, hide_passwords=False) - table_group_preview = table_group_service.get_table_group_preview( - project_code, - connection, - {"id": "temp", **table_group}, + table_group_preview = table_group_queries.get_table_group_preview( + table_group, + verify_table_access=should_verify_access(), ) - if table_group: + if table_group_data: success = True message = None - 
table_group_id = None if is_table_group_verified(): try: - table_group_id = table_group_service.add({ - **table_group, - "project_code": project_code, - "connection_id": connection_id, - }) + table_group.save() if should_run_profiling: try: - run_profiling_in_background(table_group_id) - message = f"Profiling run started for table group {table_group['table_groups_name']}." + run_profiling_in_background(table_group.id) + message = f"Profiling run started for table group {table_group.table_groups_name}." except Exception as error: message = "Profiling run encountered errors" success = False LOG.exception(message) else: - LOG.info("Table group %s created", table_group_id) + LOG.info("Table group %s created", table_group.id) st.rerun() except Exception as error: message = "Error creating table group" @@ -315,7 +314,7 @@ def on_preview_table_group(table_group: dict) -> None: results = { "success": success, "message": message, - "table_group_id": table_group_id, + "table_group_id": table_group.id, } else: results = { @@ -363,37 +362,14 @@ def is_open_ssl_error(error: Exception): ) -def format_connection(connection: dict) -> dict: - fields = [ - "project_code", - "connection_id", - "connection_name", - "sql_flavor", - "sql_flavor_code", - "project_host", - "project_port", - "project_db", - "project_user", - "password", - "max_threads", - "max_query_chars", - "connect_by_url", - "connect_by_key", - "private_key", - "private_key_passphrase", - "http_path", - "url", - ] - formatted_connection = {} - - for fieldname in fields: - formatted_connection[fieldname] = format_field(connection[fieldname]) +def format_connection(connection: Connection | ConnectionMinimal) -> dict: + formatted_connection = connection.to_dict(json_safe=True) - if formatted_connection["password"]: - formatted_connection["password"] = "***" # noqa S105 - if formatted_connection["private_key"]: + if formatted_connection.get("project_pw_encrypted"): + formatted_connection["project_pw_encrypted"] = "***" + if formatted_connection.get("private_key"): formatted_connection["private_key"] = "***" # S105 - if formatted_connection["private_key_passphrase"]: + if formatted_connection.get("private_key_passphrase"): formatted_connection["private_key_passphrase"] = "***" # noqa S105 flavors = [f for f in FLAVOR_OPTIONS if f.value == formatted_connection["sql_flavor_code"]] @@ -403,6 +379,22 @@ def format_connection(connection: dict) -> dict: return formatted_connection +def get_connection_string(flavor: str) -> str: + connection_params: ConnectionParams = { + "sql_flavor": flavor, + "project_host": "", + "project_port": "", + "project_user": "", + "project_db": "", + "project_pw_encrypted": "", + "http_path": "", + "table_group_schema": "", + } + flavor_service = get_flavor_service(flavor) + flavor_service.init(connection_params) + return flavor_service.get_connection_string().replace("%3E", ">").replace("%3C", "<") + + @dataclass(frozen=True, slots=True, kw_only=True) class ConnectionFlavor: value: str @@ -418,48 +410,48 @@ class ConnectionFlavor: value="redshift", flavor="redshift", icon=get_asset_data_url("flavors/redshift.svg"), - connection_string=connection_service.get_connection_string("redshift"), + connection_string=get_connection_string("redshift"), ), ConnectionFlavor( label="Azure SQL Database", value="azure_mssql", flavor="mssql", icon=get_asset_data_url("flavors/azure_sql.svg"), - connection_string=connection_service.get_connection_string("mssql"), + connection_string=get_connection_string("mssql"), ), ConnectionFlavor( 
label="Azure Synapse Analytics", value="synapse_mssql", flavor="mssql", icon=get_asset_data_url("flavors/azure_synapse_table.svg"), - connection_string=connection_service.get_connection_string("mssql"), + connection_string=get_connection_string("mssql"), ), ConnectionFlavor( label="Microsoft SQL Server", value="mssql", flavor="mssql", icon=get_asset_data_url("flavors/mssql.svg"), - connection_string=connection_service.get_connection_string("mssql"), + connection_string=get_connection_string("mssql"), ), ConnectionFlavor( label="PostgreSQL", value="postgresql", flavor="postgresql", icon=get_asset_data_url("flavors/postgresql.svg"), - connection_string=connection_service.get_connection_string("postgresql"), + connection_string=get_connection_string("postgresql"), ), ConnectionFlavor( label="Snowflake", value="snowflake", flavor="snowflake", icon=get_asset_data_url("flavors/snowflake.svg"), - connection_string=connection_service.get_connection_string("snowflake"), + connection_string=get_connection_string("snowflake"), ), ConnectionFlavor( label="Databricks", value="databricks", flavor="databricks", icon=get_asset_data_url("flavors/databricks.svg"), - connection_string=connection_service.get_connection_string("databricks"), + connection_string=get_connection_string("databricks"), ), ] diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 4abb021c..decb83dd 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -9,8 +9,9 @@ import streamlit as st from streamlit.delta_generator import DeltaGenerator -import testgen.ui.services.database_service as db -import testgen.ui.services.query_service as dq +from testgen.common.models import with_database_session +from testgen.common.models.project import Project +from testgen.common.models.table_group import TableGroup, TableGroupMinimal from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets import testgen_component from testgen.ui.components.widgets.download_dialog import ( @@ -22,7 +23,6 @@ from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router -from testgen.ui.queries import project_queries from testgen.ui.queries.profiling_queries import ( TAG_FIELDS, get_column_by_id, @@ -34,12 +34,13 @@ get_tables_by_table_group, ) from testgen.ui.services import user_session_service +from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db from testgen.ui.session import session, temp_value from testgen.ui.views.dialogs.column_history_dialog import column_history_dialog from testgen.ui.views.dialogs.data_preview_dialog import data_preview_dialog from testgen.ui.views.dialogs.run_profiling_dialog import run_profiling_dialog from testgen.ui.views.dialogs.table_create_script_dialog import table_create_script_dialog -from testgen.utils import format_field, friendly_score, is_uuid4, score +from testgen.utils import friendly_score, is_uuid4, make_json_safe, score PAGE_ICON = "dataset" PAGE_TITLE = "Data Catalog" @@ -68,44 +69,39 @@ def render(self, project_code: str, table_group_id: str | None = None, selected: # (something to do with displaying the extra cache spinner next to the custom component) # Enclosing the loading logic in a Streamlit container also fixes it - project_summary = project_queries.get_summary_by_code(project_code) + project_summary = Project.get_summary(project_code) user_can_navigate = not user_session_service.user_has_catalog_role() - 
table_groups = get_table_group_options(project_code) + table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) - if not table_group_id or table_group_id not in table_groups["id"].values: - table_group_id = table_groups.iloc[0]["id"] if not table_groups.empty else None + if not table_group_id or table_group_id not in [ str(item.id) for item in table_groups ]: + table_group_id = str(table_groups[0].id) if table_groups else None on_table_group_selected(table_group_id) - columns, selected_item, selected_table_group = pd.DataFrame(), None, None + columns, selected_item, selected_table_group = [], None, None if table_group_id: - selected_table_group = table_groups.loc[table_groups["id"] == table_group_id].iloc[0] + selected_table_group = next(item for item in table_groups if str(item.id) == table_group_id) columns = get_table_group_columns(table_group_id) selected_item = get_selected_item(selected, table_group_id) if selected_item: selected_item["project_code"] = project_code - selected_item["connection_id"] = format_field(selected_table_group["connection_id"]) + selected_item["connection_id"] = str(selected_table_group.connection_id) else: on_item_selected(None) testgen_component( "data_catalog", props={ - "project_summary": { - "project_code": project_code, - "connections_ct": format_field(project_summary["connections_ct"]), - "table_groups_ct": format_field(project_summary["table_groups_ct"]), - "default_connection_id": format_field(project_summary["default_connection_id"]), - }, + "project_summary": project_summary.to_dict(json_safe=True), "table_group_filter_options": [ { - "value": format_field(table_group["id"]), - "label": format_field(table_group["table_groups_name"]), - "selected": str(table_group_id) == str(table_group["id"]), - } for _, table_group in table_groups.iterrows() + "value": str(table_group.id), + "label": table_group.table_groups_name, + "selected": table_group_id == str(table_group.id), + } for table_group in table_groups ], - "columns": columns.to_json(orient="records") if not columns.empty else None, - "selected_item": json.dumps(selected_item), + "columns": json.dumps(make_json_safe(columns)) if columns else None, + "selected_item": json.dumps(make_json_safe(selected_item)) if selected_item else None, "tag_values": get_tag_values(), "last_saved_timestamp": st.session_state.get("data_catalog:last_saved_timestamp"), "permissions": { @@ -161,7 +157,7 @@ class ExportItem(typing.TypedDict): id: str type: typing.Literal["table", "column"] -def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: dict, items: list[ExportItem] | None) -> None: +def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: TableGroupMinimal, items: list[ExportItem] | None) -> None: if items: table_data = get_tables_by_id( table_ids=[ item["id"] for item in items if item["type"] == "table" ], @@ -175,12 +171,12 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: di ) else: table_data = get_tables_by_table_group( - table_group["id"], + table_group.id, include_tags=True, include_active_tests=True, ) column_data = get_columns_by_table_group( - table_group["id"], + table_group.id, include_tags=True, include_active_tests=True, ) @@ -197,7 +193,7 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: di for key in ["min_date", "max_date", "add_date", "last_mod_date", "drop_date"]: data[key] = data[key].apply( - lambda val: datetime.fromtimestamp(val / 1000).strftime("%b %-d %Y, 
%-I:%M %p") if not pd.isna(val) else None + lambda val: val.strftime("%b %-d %Y, %-I:%M %p") if not pd.isna(val) else None ) for key in ["data_source", "source_system", "source_process", "business_domain", "stakeholder_group", "transform_level", "aggregation_level", "data_product"]: @@ -295,13 +291,14 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: di return get_excel_file_data( data, "Data Catalog Columns", - details={"Table group": table_group["table_groups_name"]}, + details={"Table group": table_group.table_groups_name}, columns=file_columns, update_progress=update_progress, ) @st.dialog(title="Remove Table from Catalog") +@with_database_session def remove_table_dialog(item: dict) -> None: remove_clicked, set_remove_clicked = temp_value("data-catalog:confirm-remove-table-val") st.html(f"Are you sure you want to remove the table {item['table_name']} from the data catalog?") @@ -318,15 +315,14 @@ def remove_table_dialog(item: dict) -> None: ) if remove_clicked(): - schema = st.session_state["dbschema"] - db.execute_sql(f""" - DELETE FROM {schema}.data_column_chars - WHERE table_id = '{item["id"]}'; - """) - db.execute_sql(f""" - DELETE FROM {schema}.data_table_chars - WHERE table_id = '{item["id"]}'; - """) + execute_db_query( + "DELETE FROM data_column_chars WHERE table_id = :table_id;", + {"table_id": item["id"]}, + ) + execute_db_query( + "DELETE FROM data_table_chars WHERE table_id = :table_id;", + {"table_id": item["id"]}, + ) st.success("Table has been removed.") time.sleep(1) @@ -339,51 +335,48 @@ def remove_table_dialog(item: dict) -> None: def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DATA_TYPE: attributes = ["description"] attributes.extend(TAG_FIELDS) - cde_value_map = { - True: "TRUE", - False: "FALSE", - None: "NULL", - } tags = payload["tags"] - set_attributes = [ f"{key} = NULLIF('{tags.get(key) or ''}', '')" for key in attributes if key in tags ] + set_attributes = [ f"{key} = NULLIF(:{key}, '')" for key in attributes if key in tags ] + params = { key: tags.get(key) or "" for key in attributes if key in tags } if "critical_data_element" in tags: - set_attributes.append(f"critical_data_element = {cde_value_map[tags.get('critical_data_element')]}") - - tables = [] - columns = [] - for item in payload["items"]: - id_list = tables if item["type"] == "table" else columns - id_list.append(item["id"]) + set_attributes.append("critical_data_element = :critical_data_element") + params.update({"critical_data_element": tags.get("critical_data_element")}) - schema = st.session_state["dbschema"] + params["table_ids"] = [ item["id"] for item in payload["items"] if item["type"] == "table" ] + params["column_ids"] = [ item["id"] for item in payload["items"] if item["type"] == "column" ] with spinner_container: with st.spinner("Saving tags"): - if tables: - db.execute_sql_raw(f""" - WITH selected as ( - SELECT UNNEST(ARRAY [{", ".join([ f"'{item}'" for item in tables ])}]) AS table_id + if params["table_ids"]: + execute_db_query( + f""" + WITH selected as ( + SELECT UNNEST(ARRAY [:table_ids]) AS table_id + ) + UPDATE data_table_chars + SET {', '.join(set_attributes)} + FROM data_table_chars dtc + INNER JOIN selected ON (dtc.table_id = selected.table_id::UUID) + WHERE dtc.table_id = data_table_chars.table_id; + """, + params, ) - UPDATE {schema}.data_table_chars - SET {', '.join(set_attributes)} - FROM {schema}.data_table_chars dtc - INNER JOIN selected ON (dtc.table_id = selected.table_id::UUID) - WHERE dtc.table_id = 
data_table_chars.table_id; - """) - - - if columns: - db.execute_sql_raw(f""" - WITH selected as ( - SELECT UNNEST(ARRAY [{", ".join([ f"'{item}'" for item in columns ])}]) AS column_id + + if params["column_ids"]: + execute_db_query( + f""" + WITH selected as ( + SELECT UNNEST(ARRAY [:column_ids]) AS column_id + ) + UPDATE data_column_chars + SET {', '.join(set_attributes)} + FROM data_column_chars dcc + INNER JOIN selected ON (dcc.column_id = selected.column_id::UUID) + WHERE dcc.column_id = data_column_chars.column_id; + """, + params, ) - UPDATE {schema}.data_column_chars - SET {', '.join(set_attributes)} - FROM {schema}.data_column_chars dcc - INNER JOIN selected ON (dcc.column_id = selected.column_id::UUID) - WHERE dcc.column_id = data_column_chars.column_id; - """) for func in [ get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values ]: func.clear() @@ -392,18 +385,11 @@ def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DA @st.cache_data(show_spinner=False) -def get_table_group_options(project_code): - schema = st.session_state["dbschema"] - return dq.run_table_groups_lookup_query(schema, project_code) - - -@st.cache_data(show_spinner=False) -def get_table_group_columns(table_group_id: str) -> pd.DataFrame: +def get_table_group_columns(table_group_id: str) -> list[dict]: if not is_uuid4(table_group_id): - return pd.DataFrame() + return [] - schema = st.session_state["dbschema"] - sql = f""" + query = f""" SELECT CONCAT('column_', column_chars.column_id) AS column_id, CONCAT('table_', table_chars.table_id) AS table_id, column_chars.column_name, @@ -421,23 +407,26 @@ def get_table_group_columns(table_group_id: str) -> pd.DataFrame: table_chars.critical_data_element AS table_critical_data_element, {", ".join([ f"column_chars.{tag}" for tag in TAG_FIELDS ])}, {", ".join([ f"table_chars.{tag} AS table_{tag}" for tag in TAG_FIELDS ])} - FROM {schema}.data_column_chars column_chars - LEFT JOIN {schema}.data_table_chars table_chars ON ( + FROM data_column_chars column_chars + LEFT JOIN data_table_chars table_chars ON ( column_chars.table_id = table_chars.table_id ) - LEFT JOIN {schema}.profile_results ON ( + LEFT JOIN profile_results ON ( column_chars.last_complete_profile_run_id = profile_results.profile_run_id AND column_chars.table_name = profile_results.table_name AND column_chars.column_name = profile_results.column_name ) - WHERE column_chars.table_groups_id = '{table_group_id}' + WHERE column_chars.table_groups_id = :table_group_id ORDER BY table_name, ordinal_position; """ - return db.retrieve_data(sql) + params = {"table_group_id": table_group_id} + + results = fetch_all_from_db(query, params) + return [ dict(row) for row in results ] def get_selected_item(selected: str, table_group_id: str) -> dict | None: - if not selected or not is_uuid4(table_group_id): + if not selected or "_" not in selected or not is_uuid4(table_group_id): return None item_type, item_id = selected.split("_", 2) @@ -460,14 +449,8 @@ def get_selected_item(selected: str, table_group_id: str) -> dict | None: @st.cache_data(show_spinner=False) -def get_latest_test_issues(table_group_id: str, table_name: str, column_name: str | None = None) -> dict | None: - schema = st.session_state["dbschema"] - - column_condition = "" - if column_name: - column_condition = f"AND column_names = '{column_name}'" - - sql = f""" +def get_latest_test_issues(table_group_id: str, table_name: str, column_name: str | None = None) -> list[dict]: + query = f""" SELECT 
test_results.id::VARCHAR(50), column_names AS column_name, test_name_short AS test_name, @@ -475,20 +458,20 @@ def get_latest_test_issues(table_group_id: str, table_name: str, column_name: st result_message, test_suite, test_results.test_run_id::VARCHAR(50), - EXTRACT(EPOCH FROM test_starttime) * 1000 AS test_run_date - FROM {schema}.test_suites - LEFT JOIN {schema}.test_runs ON ( + EXTRACT(EPOCH FROM test_starttime)::INT AS test_run_date + FROM test_suites + LEFT JOIN test_runs ON ( test_suites.last_complete_test_run_id = test_runs.id ) - LEFT JOIN {schema}.test_results ON ( + LEFT JOIN test_results ON ( test_runs.id = test_results.test_run_id ) - LEFT JOIN {schema}.test_types ON ( + LEFT JOIN test_types ON ( test_results.test_type = test_types.test_type ) - WHERE test_suites.table_groups_id = '{table_group_id}' - AND table_name = '{table_name}' - {column_condition} + WHERE test_suites.table_groups_id = :table_group_id + AND table_name = :table_name + {"AND column_names = :column_name" if column_name else ""} AND result_status <> 'Passed' AND COALESCE(disposition, 'Confirmed') = 'Confirmed' ORDER BY @@ -499,65 +482,67 @@ def get_latest_test_issues(table_group_id: str, table_name: str, column_name: st END, column_name; """ + params = { + "table_group_id": table_group_id, + "table_name": table_name, + "column_name": column_name, + } - df = db.retrieve_data(sql) - return [row.to_dict() for _, row in df.iterrows()] + results = fetch_all_from_db(query, params) + return [ dict(row) for row in results ] @st.cache_data(show_spinner=False) -def get_related_test_suites(table_group_id: str, table_name: str, column_name: str | None = None) -> dict | None: - schema = st.session_state["dbschema"] - - column_condition = "" - if column_name: - column_condition = f"AND column_name = '{column_name}'" - - sql = f""" +def get_related_test_suites(table_group_id: str, table_name: str, column_name: str | None = None) -> list[dict]: + query = f""" SELECT test_suites.id::VARCHAR, test_suite AS name, COUNT(*) AS test_count - FROM {schema}.test_definitions - LEFT JOIN {schema}.test_suites ON ( + FROM test_definitions + LEFT JOIN test_suites ON ( test_definitions.test_suite_id = test_suites.id ) - WHERE test_suites.table_groups_id = '{table_group_id}' - AND table_name = '{table_name}' - {column_condition} + WHERE test_suites.table_groups_id = :table_group_id + AND table_name = :table_name + {"AND column_name = :column_name" if column_name else ""} GROUP BY test_suites.id ORDER BY test_suite; """ + params = { + "table_group_id": table_group_id, + "table_name": table_name, + "column_name": column_name, + } - df = db.retrieve_data(sql) - return [row.to_dict() for _, row in df.iterrows()] + results = fetch_all_from_db(query, params) + return [ dict(row) for row in results ] @st.cache_data(show_spinner=False) def get_tag_values() -> dict[str, list[str]]: - schema = st.session_state["dbschema"] - quote = lambda v: f"'{v}'" - sql = f""" + query = f""" SELECT DISTINCT UNNEST(array[{', '.join([quote(t) for t in TAG_FIELDS])}]) as tag, UNNEST(array[{', '.join(TAG_FIELDS)}]) AS value - FROM {schema}.data_column_chars + FROM data_column_chars UNION SELECT DISTINCT UNNEST(array[{', '.join([quote(t) for t in TAG_FIELDS])}]) as tag, UNNEST(array[{', '.join(TAG_FIELDS)}]) AS value - FROM {schema}.data_table_chars + FROM data_table_chars UNION SELECT DISTINCT UNNEST(array[{', '.join([quote(t) for t in TAG_FIELDS if t != 'aggregation_level'])}]) as tag, UNNEST(array[{', '.join([ t for t in TAG_FIELDS if t != 'aggregation_level'])}]) 
AS value - FROM {schema}.table_groups - ORDER BY value + FROM table_groups + ORDER BY value; """ - df = db.retrieve_data(sql) + results = fetch_all_from_db(query) values = defaultdict(list) - for _, row in df.iterrows(): - if row["tag"] and row["value"]: - values[row["tag"]].append(row["value"]) + for row in results: + if row.tag and row.value: + values[row.tag].append(row.value) return values diff --git a/testgen/ui/views/dialogs/column_history_dialog.py b/testgen/ui/views/dialogs/column_history_dialog.py index 6a224004..24915163 100644 --- a/testgen/ui/views/dialogs/column_history_dialog.py +++ b/testgen/ui/views/dialogs/column_history_dialog.py @@ -1,13 +1,13 @@ -import json - -import pandas as pd import streamlit as st +from sqlalchemy.sql.expression import func -import testgen.ui.services.database_service as db +from testgen.common.models import with_database_session +from testgen.common.models.profiling_run import ProfilingRun from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets import testgen_component from testgen.ui.queries.profiling_queries import COLUMN_PROFILING_FIELDS -from testgen.utils import format_field +from testgen.ui.services.database_service import fetch_one_from_db +from testgen.utils import make_json_safe def column_history_dialog(*args) -> None: @@ -16,6 +16,7 @@ def column_history_dialog(*args) -> None: @st.dialog(title="Column History") +@with_database_session def _column_history_dialog( table_group_id: str, schema_name: str, @@ -31,8 +32,12 @@ def _column_history_dialog( with loading_column: with st.spinner("Loading data ..."): - profiling_runs = get_profiling_runs(table_group_id, add_date) - run_id = st.session_state.get("column_history_dialog:run_id") or profiling_runs.iloc[0]["id"] + profiling_runs = ProfilingRun.select_minimal_where( + ProfilingRun.table_groups_id == table_group_id, + ProfilingRun.profiling_starttime >= func.to_timestamp(add_date), + ) + profiling_runs = [run.to_dict(json_safe=True) for run in profiling_runs] + run_id = st.session_state.get("column_history_dialog:run_id") or profiling_runs[0]["id"] selected_item = get_run_column(run_id, schema_name, table_name, column_name) testgen_component( @@ -40,11 +45,11 @@ def _column_history_dialog( props={ "profiling_runs": [ { - "run_id": format_field(run["id"]), - "run_date": format_field(run["profiling_starttime"]), - } for _, run in profiling_runs.iterrows() + "run_id": run["id"], + "run_date": run["profiling_starttime"], + } for run in profiling_runs ], - "selected_item": selected_item, + "selected_item": make_json_safe(selected_item), }, on_change_handlers={ "RunSelected": on_run_selected, @@ -56,39 +61,24 @@ def on_run_selected(run_id: str) -> None: st.session_state["column_history_dialog:run_id"] = run_id -@st.cache_data(show_spinner=False) -def get_profiling_runs( - table_group_id: str, - after_date: int, -) -> pd.DataFrame: - schema: str = st.session_state["dbschema"] - query = f""" - SELECT - id::VARCHAR, - profiling_starttime - FROM {schema}.profiling_runs - WHERE table_groups_id = '{table_group_id}' - AND profiling_starttime >= TO_TIMESTAMP({after_date / 1000}) - ORDER BY profiling_starttime DESC; - """ - return db.retrieve_data(query) - - @st.cache_data(show_spinner=False) def get_run_column(run_id: str, schema_name: str, table_name: str, column_name: str) -> dict: - schema: str = st.session_state["dbschema"] query = f""" SELECT profile_run_id::VARCHAR, general_type, {COLUMN_PROFILING_FIELDS} - FROM {schema}.profile_results - WHERE profile_run_id = 
'{run_id}' - AND schema_name = '{schema_name}' - AND table_name = '{table_name}' - AND column_name = '{column_name}'; + FROM profile_results + WHERE profile_run_id = :run_id + AND schema_name = :schema_name + AND table_name = :table_name + AND column_name = :column_name; """ - results = db.retrieve_data(query) - if not results.empty: - # to_json converts datetimes, NaN, etc, to JSON-safe values (Note: to_dict does not) - return json.loads(results.to_json(orient="records"))[0] + params = { + "run_id": run_id, + "schema_name": schema_name, + "table_name": table_name, + "column_name": column_name, + } + result = fetch_one_from_db(query, params) + return dict(result) if result else None diff --git a/testgen/ui/views/dialogs/data_preview_dialog.py b/testgen/ui/views/dialogs/data_preview_dialog.py index 9d5beaea..12a7648f 100644 --- a/testgen/ui/views/dialogs/data_preview_dialog.py +++ b/testgen/ui/views/dialogs/data_preview_dialog.py @@ -1,8 +1,10 @@ import pandas as pd import streamlit as st -import testgen.ui.services.database_service as db +from testgen.common.models.connection import Connection from testgen.ui.components import widgets as testgen +from testgen.ui.services.database_service import fetch_from_target_db +from testgen.utils import to_dataframe @st.dialog(title="Data Preview") @@ -40,31 +42,10 @@ def get_preview_data( table_name: str, column_name: str | None = None, ) -> pd.DataFrame: - tg_schema = st.session_state["dbschema"] - connection_query=f""" - SELECT - c.sql_flavor, - c.project_host, - c.project_port, - c.project_db, - c.project_user, - c.project_pw_encrypted, - c.url, - c.connect_by_url, - c.connect_by_key, - c.private_key, - c.private_key_passphrase, - c.http_path - FROM {tg_schema}.table_groups tg - INNER JOIN {tg_schema}.connections c ON ( - tg.connection_id = c.connection_id - ) - WHERE tg.id = '{table_group_id}'; - """ - connection_df = db.retrieve_data(connection_query).iloc[0] + connection = Connection.get_by_table_group(table_group_id) - if not connection_df.empty: - use_top = connection_df["sql_flavor"] == "mssql" + if connection: + use_top = connection.sql_flavor == "mssql" query = f""" SELECT DISTINCT {"TOP 100" if use_top else ""} @@ -74,24 +55,11 @@ def get_preview_data( """ try: - df = db.retrieve_target_db_df( - connection_df["sql_flavor"], - connection_df["project_host"], - connection_df["project_port"], - connection_df["project_db"], - connection_df["project_user"], - connection_df["project_pw_encrypted"], - query, - connection_df["url"], - connection_df["connect_by_url"], - connection_df["connect_by_key"], - connection_df["private_key"], - connection_df["private_key_passphrase"], - connection_df["http_path"], - ) + results = fetch_from_target_db(connection, query) except: return pd.DataFrame() else: + df = to_dataframe(results) df.index = df.index + 1 df.fillna("", inplace=True) return df diff --git a/testgen/ui/views/dialogs/generate_tests_dialog.py b/testgen/ui/views/dialogs/generate_tests_dialog.py index 89013108..c0c159b1 100644 --- a/testgen/ui/views/dialogs/generate_tests_dialog.py +++ b/testgen/ui/views/dialogs/generate_tests_dialog.py @@ -1,23 +1,25 @@ import time -import pandas as pd import streamlit as st -import testgen.ui.services.test_suite_service as test_suite_service from testgen.commands.run_generate_tests import run_test_gen_queries +from testgen.common.models import with_database_session +from testgen.common.models.test_suite import TestSuiteMinimal from testgen.ui.components import widgets as testgen +from 
testgen.ui.services.database_service import execute_db_query, fetch_all_from_db, fetch_one_from_db ALL_TYPES_LABEL = "All Test Types" @st.dialog(title="Generate Tests") -def generate_tests_dialog(test_suite: pd.Series) -> None: - test_suite_id = test_suite["id"] - test_suite_name = test_suite["test_suite"] - table_group_id = test_suite["table_groups_id"] +@with_database_session +def generate_tests_dialog(test_suite: TestSuiteMinimal) -> None: + test_suite_id = test_suite.id + test_suite_name = test_suite.test_suite + table_group_id = test_suite.table_groups_id selected_set = "" - generation_sets = test_suite_service.get_generation_set_choices() + generation_sets = get_generation_set_choices() if generation_sets: generation_sets.insert(0, ALL_TYPES_LABEL) @@ -27,7 +29,7 @@ def generate_tests_dialog(test_suite: pd.Series) -> None: if selected_set == ALL_TYPES_LABEL: selected_set = "" - test_ct, unlocked_test_ct, unlocked_edits_ct = test_suite_service.get_test_suite_refresh_warning(test_suite_id) + test_ct, unlocked_test_ct, unlocked_edits_ct = get_test_suite_refresh_warning(test_suite_id) if test_ct: unlocked_message = "" if unlocked_edits_ct > 0: @@ -45,8 +47,8 @@ def generate_tests_dialog(test_suite: pd.Series) -> None: st.warning(warning_message, icon=":material/warning:") if unlocked_edits_ct > 0: if st.button("Lock Edited Tests"): - if test_suite_service.lock_edited_tests(test_suite_id): - st.info("Edited tests have been successfully locked.") + lock_edited_tests(test_suite_id) + st.info("Edited tests have been successfully locked.") with st.container(): st.markdown(f"Execute test generation for the test suite **{test_suite_name}**?") @@ -79,3 +81,50 @@ def generate_tests_dialog(test_suite: pd.Series) -> None: time.sleep(1) st.cache_data.clear() st.rerun() + + +def get_test_suite_refresh_warning(test_suite_id: str) -> tuple[int, int, int]: + result = fetch_one_from_db( + """ + SELECT + COUNT(*) AS test_ct, + SUM(CASE WHEN COALESCE(td.lock_refresh, 'N') = 'N' THEN 1 ELSE 0 END) AS unlocked_test_ct, + SUM(CASE WHEN COALESCE(td.lock_refresh, 'N') = 'N' AND td.last_manual_update IS NOT NULL THEN 1 ELSE 0 END) AS unlocked_edits_ct + FROM test_definitions td + INNER JOIN test_types tt + ON (td.test_type = tt.test_type) + WHERE td.test_suite_id = :test_suite_id + AND tt.run_type = 'CAT' + AND tt.selection_criteria IS NOT NULL; + """, + {"test_suite_id": test_suite_id}, + ) + + if result: + return result.test_ct, result.unlocked_test_ct, result.unlocked_edits_ct + + return None, None, None + + +def get_generation_set_choices() -> list[str]: + results = fetch_all_from_db( + """ + SELECT DISTINCT generation_set + FROM generation_sets + ORDER BY generation_set; + """ + ) + return [ row.generation_set for row in results ] + + +def lock_edited_tests(test_suite_id: str) -> None: + execute_db_query( + """ + UPDATE test_definitions + SET lock_refresh = 'Y' + WHERE test_suite_id = :test_suite_id + AND last_manual_update IS NOT NULL + AND lock_refresh = 'N'; + """, + {"test_suite_id": test_suite_id} + ) diff --git a/testgen/ui/views/dialogs/manage_schedules.py b/testgen/ui/views/dialogs/manage_schedules.py index cc96ecc5..3260eca4 100644 --- a/testgen/ui/views/dialogs/manage_schedules.py +++ b/testgen/ui/views/dialogs/manage_schedules.py @@ -8,7 +8,7 @@ import streamlit as st from sqlalchemy.exc import IntegrityError -from testgen.common.models import Session +from testgen.common.models import Session, with_database_session from testgen.common.models.scheduler import JobSchedule from 
testgen.ui.components import widgets as testgen from testgen.ui.components.widgets import tz_select @@ -33,6 +33,7 @@ def get_arg_value(self, job): def arg_value_input(self) -> tuple[bool, list[Any], dict[str, Any]]: raise NotImplementedError + @with_database_session def open(self, project_code: str) -> None: st.session_state["schedule_form_success"] = None st.session_state["schedule_cron_expr"] = "" @@ -130,7 +131,7 @@ def add_schedule_form(self): # We postpone the validation status update when the previous rerun had a failed # attempt to insert a schedule. This prevents the error message of being overridden if st.session_state.get("schedule_form_success", None) is None: - st.success( + st.info( f"**Next runs:** {' | '.join(sample)} ({cron_tz.replace('_', ' ')})", icon=":material/check:", ) diff --git a/testgen/ui/views/dialogs/profiling_results_dialog.py b/testgen/ui/views/dialogs/profiling_results_dialog.py index 26950f50..3b824a5b 100644 --- a/testgen/ui/views/dialogs/profiling_results_dialog.py +++ b/testgen/ui/views/dialogs/profiling_results_dialog.py @@ -3,7 +3,9 @@ import streamlit as st import testgen.ui.queries.profiling_queries as profiling_queries +from testgen.common.models import with_database_session from testgen.ui.components.widgets.testgen_component import testgen_component +from testgen.utils import make_json_safe def view_profiling_button(column_name: str, table_name: str, table_groups_id: str): @@ -17,11 +19,12 @@ def view_profiling_button(column_name: str, table_name: str, table_groups_id: st @st.dialog(title="Column Profiling Results") +@with_database_session def profiling_results_dialog(column_name: str, table_name: str, table_groups_id: str): column = profiling_queries.get_column_by_name(column_name, table_name, table_groups_id) if column: testgen_component( "column_profiling_results", - props={ "column": json.dumps(column) }, + props={ "column": json.dumps(make_json_safe(column)) }, ) diff --git a/testgen/ui/views/dialogs/run_profiling_dialog.py b/testgen/ui/views/dialogs/run_profiling_dialog.py index 1b6cf22f..3d5b6d6e 100644 --- a/testgen/ui/views/dialogs/run_profiling_dialog.py +++ b/testgen/ui/views/dialogs/run_profiling_dialog.py @@ -1,24 +1,27 @@ import time -import pandas as pd import streamlit as st -import testgen.ui.services.query_service as dq from testgen.commands.run_profiling_bridge import run_profiling_in_background +from testgen.common.models import with_database_session +from testgen.common.models.table_group import TableGroup, TableGroupMinimal from testgen.ui.components import widgets as testgen from testgen.ui.session import session +from testgen.utils import to_dataframe LINK_KEY = "run_profiling_dialog:keys:go-to-runs" LINK_HREF = "profiling-runs" @st.dialog(title="Run Profiling") -def run_profiling_dialog(project_code: str, table_group: pd.Series | None = None, default_table_group_id: str | None = None) -> None: - if table_group is not None and not table_group.empty: - table_group_id: str = table_group["id"] - table_group_name: str = table_group["table_groups_name"] +@with_database_session +def run_profiling_dialog(project_code: str, table_group: TableGroupMinimal | None = None, default_table_group_id: str | None = None) -> None: + if table_group: + table_group_id: str = str(table_group.id) + table_group_name: str = table_group.table_groups_name else: - table_groups_df = get_table_group_options(project_code) + table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups_df = 
to_dataframe(table_groups, TableGroupMinimal.columns()) table_group_id: str = testgen.select( label="Table Group", options=table_groups_df, @@ -79,9 +82,3 @@ def run_profiling_dialog(project_code: str, table_group: pd.Series | None = None time.sleep(2) st.cache_data.clear() st.rerun() - - -@st.cache_data(show_spinner=False) -def get_table_group_options(project_code: str) -> pd.DataFrame: - schema: str = st.session_state["dbschema"] - return dq.run_table_groups_lookup_query(schema, project_code) diff --git a/testgen/ui/views/dialogs/run_tests_dialog.py b/testgen/ui/views/dialogs/run_tests_dialog.py index 93c89dbb..40e3a05e 100644 --- a/testgen/ui/views/dialogs/run_tests_dialog.py +++ b/testgen/ui/views/dialogs/run_tests_dialog.py @@ -1,24 +1,27 @@ import time -import pandas as pd import streamlit as st -import testgen.ui.services.database_service as db from testgen.commands.run_execute_tests import run_execution_steps_in_background +from testgen.common.models import with_database_session +from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal from testgen.ui.components import widgets as testgen from testgen.ui.session import session +from testgen.utils import to_dataframe LINK_KEY = "run_tests_dialog:keys:go-to-runs" LINK_HREF = "test-runs" @st.dialog(title="Run Tests") -def run_tests_dialog(project_code: str, test_suite: pd.Series | None = None, default_test_suite_id: str | None = None) -> None: - if test_suite is not None and not test_suite.empty: - test_suite_id: str = test_suite["id"] - test_suite_name: str = test_suite["test_suite"] +@with_database_session +def run_tests_dialog(project_code: str, test_suite: TestSuiteMinimal | None = None, default_test_suite_id: str | None = None) -> None: + if test_suite: + test_suite_id: str = str(test_suite.id) + test_suite_name: str = test_suite.test_suite else: - test_suites_df = get_test_suite_options(project_code) + test_suites = TestSuite.select_minimal_where(TestSuite.project_code == project_code) + test_suites_df = to_dataframe(test_suites, TestSuiteMinimal.columns()) test_suite_id: str = testgen.select( label="Test Suite", options=test_suites_df, @@ -29,7 +32,7 @@ def run_tests_dialog(project_code: str, test_suite: pd.Series | None = None, def placeholder="Select test suite to run", ) if test_suite_id: - test_suite_name: str = test_suites_df.loc[test_suites_df["id"] == test_suite_id, "test_suite"].iloc[0] + test_suite_name: str = next(item.test_suite for item in test_suites if item.id == test_suite_id) testgen.whitespace(1) if test_suite_id: @@ -83,16 +86,3 @@ def run_tests_dialog(project_code: str, test_suite: pd.Series | None = None, def time.sleep(2) st.cache_data.clear() st.rerun() - - -@st.cache_data(show_spinner=False) -def get_test_suite_options(project_code: str) -> pd.DataFrame: - schema: str = st.session_state["dbschema"] - sql = f""" - SELECT test_suites.id::VARCHAR(50), - test_suites.test_suite - FROM {schema}.test_suites - WHERE test_suites.project_code = '{project_code}' - ORDER BY test_suites.test_suite - """ - return db.retrieve_data(sql) diff --git a/testgen/ui/views/dialogs/table_create_script_dialog.py b/testgen/ui/views/dialogs/table_create_script_dialog.py index 59716632..992568ba 100644 --- a/testgen/ui/views/dialogs/table_create_script_dialog.py +++ b/testgen/ui/views/dialogs/table_create_script_dialog.py @@ -1,39 +1,36 @@ -import pandas as pd import streamlit as st from testgen.ui.components import widgets as testgen @st.dialog(title="Table CREATE Script with Suggested Data Types") -def 
table_create_script_dialog(table_name: str, data: pd.DataFrame) -> None: - testgen.caption( - f"Table: {table_name}" - ) +def table_create_script_dialog(table_name: str, data: list[dict]) -> None: + testgen.caption(f"Table: {table_name}") st.code(generate_create_script(table_name, data), "sql") -def generate_create_script(table_name: str, data: pd.DataFrame) -> str: - df = data[data["table_name"] == table_name][["schema_name", "table_name", "column_name", "column_type", "datatype_suggestion"]] - df = df.copy().reset_index(drop=True) - df.fillna("", inplace=True) - - df["comment"] = df.apply( - lambda row: f"-- WAS {row['column_type']}" - if isinstance(row["column_type"], str) - and isinstance(row["datatype_suggestion"], str) - and row["column_type"].lower() != row["datatype_suggestion"].lower() - else "", - axis=1, - ) - max_len_name = df.apply(lambda row: len(row["column_name"]), axis=1).max() + 3 - max_len_type = df.apply(lambda row: len(row["datatype_suggestion"]), axis=1).max() + 3 - - header = f"CREATE TABLE {df.at[0, 'schema_name']}.{df.at[0, 'table_name']} ( " - col_defs = df.apply( - lambda row: f" {row['column_name']:<{max_len_name}} {row['datatype_suggestion']:<{max_len_type}}, {row['comment']}", - axis=1, - ).tolist() - footer = ");" - col_defs[-1] = col_defs[-1].replace(", --", " --") - - return "\n".join([header, *list(col_defs), footer]) +def generate_create_script(table_name: str, data: list[dict]) -> str | None: + table_data = [col for col in data if col["table_name"] == table_name] + if not table_data: + return None + + max_name = max(len(col["column_name"]) for col in table_data) + 3 + max_type = max(len(col["datatype_suggestion"] or "") for col in table_data) + 3 + + col_defs = [] + for index, col in enumerate(table_data): + comment = ( + f"-- WAS {col['column_type']}" + if isinstance(col["column_type"], str) + and isinstance(col["datatype_suggestion"], str) + and col["column_type"].lower() != col["datatype_suggestion"].lower() + else "" + ) + col_type = col["datatype_suggestion"] or col["column_type"] or "" + separator = " " if index == len(table_data) - 1 else "," + col_defs.append(f"{col['column_name']:<{max_name}} {(col_type):<{max_type}}{separator} {comment}") + + return f""" +CREATE TABLE {table_data[0]['schema_name']}.{table_data[0]['table_name']} ( + {"\n ".join(col_defs)} +);""" diff --git a/testgen/ui/views/hygiene_issues.py b/testgen/ui/views/hygiene_issues.py index dedea1f4..a46f2185 100644 --- a/testgen/ui/views/hygiene_issues.py +++ b/testgen/ui/views/hygiene_issues.py @@ -5,14 +5,12 @@ import pandas as pd import streamlit as st -import testgen.ui.queries.profiling_queries as profiling_queries -import testgen.ui.services.database_service as db import testgen.ui.services.form_service as fm -import testgen.ui.services.query_service as dq from testgen.commands.run_rollup_scores import run_profile_rollup_scoring_queries from testgen.common import date_service from testgen.common.mixpanel_service import MixpanelService -from testgen.common.models import get_current_session +from testgen.common.models import with_database_session +from testgen.common.models.profiling_run import ProfilingRun from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -24,8 +22,13 @@ from testgen.ui.components.widgets.page import css_class, flex_row_end from testgen.ui.navigation.page import Page from testgen.ui.pdf.hygiene_issue_report import create_report -from testgen.ui.services import project_service, 
user_session_service -from testgen.ui.services.hygiene_issues_service import get_source_data as get_source_data_uncached +from testgen.ui.queries.source_data_queries import get_hygiene_issue_source_data +from testgen.ui.services import user_session_service +from testgen.ui.services.database_service import ( + execute_db_query, + fetch_df_from_db, + fetch_one_from_db, +) from testgen.ui.session import session from testgen.ui.views.dialogs.profiling_results_dialog import view_profiling_button from testgen.utils import friendly_score @@ -42,40 +45,48 @@ class HygieneIssuesPage(Page): def render( self, run_id: str, - issue_class: str | None = None, + likelihood: str | None = None, issue_type: str | None = None, table_name: str | None = None, column_name: str | None = None, action: str | None = None, **_kwargs, ) -> None: - run_df = profiling_queries.get_run_by_id(run_id) - if run_df.empty: + run = ProfilingRun.get_minimal(run_id) + if not run: self.router.navigate_with_warning( f"Profiling run with ID '{run_id}' does not exist. Redirecting to list of Profiling Runs ...", "profiling-runs", ) return - run_date = date_service.get_timezoned_timestamp(st.session_state, run_df["profiling_starttime"]) - project_service.set_sidebar_project(run_df["project_code"]) + run_date = date_service.get_timezoned_timestamp(st.session_state, run.profiling_starttime) + session.set_sidebar_project(run.project_code) testgen.page_header( "Hygiene Issues", "view-hygiene-issues", breadcrumbs=[ - { "label": "Profiling Runs", "path": "profiling-runs", "params": { "project_code": run_df["project_code"] } }, - { "label": f"{run_df['table_groups_name']} | {run_date}" }, + { "label": "Profiling Runs", "path": "profiling-runs", "params": { "project_code": run.project_code } }, + { "label": f"{run.table_groups_name} | {run_date}" }, ], ) others_summary_column, pii_summary_column, score_column, actions_column, export_button_column = st.columns([.2, .2, .15, .3, .15], vertical_alignment="bottom") - (table_filter_column, column_filter_column, issue_type_filter_column, liklihood_filter_column, action_filter_column, sort_column) = ( - st.columns([.15, .2, .2, .2, .15, .1], vertical_alignment="bottom") + (liklihood_filter_column, table_filter_column, column_filter_column, issue_type_filter_column, action_filter_column, sort_column) = ( + st.columns([.2, .15, .2, .2, .15, .1], vertical_alignment="bottom") ) testgen.flex_row_end(actions_column) testgen.flex_row_end(export_button_column) + with liklihood_filter_column: + likelihood = testgen.select( + options=["Definite", "Likely", "Possible", "Potential PII"], + default_value=likelihood, + bind_to_query="likelihood", + label="Likelihood", + ) + run_columns_df = get_profiling_run_columns(run_id) with table_filter_column: table_name = testgen.select( @@ -122,29 +133,22 @@ def render( ) issue_type_id = testgen.select( options=issue_type_options, - default_value=None if issue_class == "Potential PII" else issue_type, + default_value=None if likelihood == "Potential PII" else issue_type, value_column="anomaly_id", display_column="anomaly_name", bind_to_query="issue_type", label="Issue Type", - disabled=issue_class == "Potential PII", - ) - - with liklihood_filter_column: - issue_class = testgen.select( - options=["Definite", "Likely", "Possible", "Potential PII"], - default_value=issue_class, - bind_to_query="issue_class", - label="Likelihood", + disabled=likelihood == "Potential PII", ) with action_filter_column: action = testgen.select( - options=["✓ Confirmed", "✘ Dismissed", "🔇 Muted", 
"â†Šī¸Ž No Action"], + options=["✓ Confirmed", "✘ Dismissed", "🔇 Muted", "â†Šī¸Ž No Action"], default_value=action, bind_to_query="action", label="Action", ) + action = action.split(" ", 1)[1] if action else None with sort_column: sortable_columns = ( @@ -164,7 +168,7 @@ def render( with st.container(): with st.spinner("Loading data ..."): # Get hygiene issue list - df_pa = get_profiling_anomalies(run_id, issue_class, issue_type_id, table_name, column_name, action, sorting_columns) + df_pa = get_profiling_anomalies(run_id, likelihood, issue_type_id, table_name, column_name, action, sorting_columns) # Retrieve disposition action (cache refreshed) df_action = get_anomaly_disposition(run_id) @@ -193,9 +197,6 @@ def render( width=400, ) - with score_column: - render_score(run_df["project_code"], run_id) - lst_show_columns = [ "table_name", "column_name", @@ -234,7 +235,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: download_dialog( dialog_title="Download Excel Report", file_content_func=get_excel_report_data, - args=(run_df["table_groups_name"], run_date, run_id, data), + args=(run.table_groups_name, run_date, run_id, data), ) with popover_container.container(key="tg--export-popover"): @@ -261,7 +262,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: with buttons_column: col1, col2, col3 = st.columns([.3, .3, .3]) - + with col1: view_profiling_button( selected_row["column_name"], selected_row["table_name"], selected_row["table_groups_id"] @@ -351,19 +352,25 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: lst_cached_functions=cached_functions, ) + # Needs to be after all data loading/updating + # Otherwise the database session is lost for any queries after the fragment -_- + with score_column: + render_score(run.project_code, run_id) + # Help Links st.markdown( "[Help on Hygiene Issues](https://docs.datakitchen.io/article/dataops-testgen-help/data-hygiene-issues)" ) @st.fragment +@with_database_session def render_score(project_code: str, run_id: str): - run_df = profiling_queries.get_run_by_id(run_id) + run = ProfilingRun.get_minimal(run_id) testgen.flex_row_center() with st.container(): testgen.caption("Score", "text-align: center;") testgen.text( - friendly_score(run_df["dq_score_profiling"]) or "--", + friendly_score(run.dq_score_profiling) or "--", "font-size: 28px;", ) @@ -374,12 +381,12 @@ def render_score(project_code: str, run_id: str): style="color: var(--secondary-text-color);", icon="autorenew", icon_size=22, - tooltip=f"Recalculate scores for run {'and table group' if run_df["is_latest_run"] else ''}", + tooltip=f"Recalculate scores for run {'and table group' if run.is_latest_run else ''}", on_click=partial( refresh_score, project_code, run_id, - run_df["table_groups_id"] if run_df["is_latest_run"] else None, + run.table_groups_id if run.is_latest_run else None, ), ) @@ -391,15 +398,14 @@ def refresh_score(project_code: str, run_id: str, table_group_id: str | None) -> @st.cache_data(show_spinner=False) def get_profiling_run_columns(profiling_run_id: str) -> pd.DataFrame: - schema: str = st.session_state["dbschema"] - sql = f""" + query = """ SELECT r.table_name table_name, r.column_name column_name, r.anomaly_id anomaly_id, t.anomaly_name anomaly_name - FROM {schema}.profile_anomaly_results r - LEFT JOIN {schema}.profile_anomaly_types t on t.id = r.anomaly_id - WHERE r.profile_run_id = '{profiling_run_id}' + FROM profile_anomaly_results r + LEFT JOIN profile_anomaly_types t on t.id = r.anomaly_id + WHERE 
r.profile_run_id = :profiling_run_id ORDER BY r.table_name, r.column_name; """ - return db.retrieve_data(sql) + return fetch_df_from_db(query, {"profiling_run_id": profiling_run_id}) @st.cache_data(show_spinner=False) @@ -409,40 +415,10 @@ def get_profiling_anomalies( issue_type_id: str | None = None, table_name: str | None = None, column_name: str | None = None, - action: str | None = None, + action: typing.Literal["Confirmed", "Dismissed", "Muted", "No Action"] | None = None, sorting_columns: list[str] | None = None, -): - db_session = get_current_session() - criteria = "" - order_by = "" - params = {"profile_run_id": profile_run_id} - - if likelihood: - criteria += " AND t.issue_likelihood = :likelihood" - params["likelihood"] = likelihood - if issue_type_id: - criteria += " AND t.id = :issue_type_id" - params["issue_type_id"] = issue_type_id - if table_name: - criteria += " AND r.table_name = :table_name" - params["table_name"] = table_name - if column_name: - criteria += " AND r.column_name ILIKE :column_name" - params["column_name"] = column_name - if action: - action = action.split(" ", 1)[1] - if action == "No Action": - criteria += " AND r.disposition IS NULL" - else: - action_disposition_converter = {"Muted": "Inactive"} - criteria += " AND r.disposition = :disposition_name" - params["disposition_name"] = action_disposition_converter.get(action, action) - - if sorting_columns: - order_by = "ORDER BY " + (", ".join(" ".join(col) for col in sorting_columns)) - - # Define the query -- first visible column must be first, because will hold the multi-select box - str_sql = f""" +) -> pd.DataFrame: + query = f""" SELECT r.table_name, r.column_name, @@ -465,14 +441,8 @@ def get_profiling_anomalies( WHEN t.issue_likelihood = 'Likely' THEN 3 WHEN t.issue_likelihood = 'Definite' THEN 4 END AS likelihood_order, - t.anomaly_description, - r.detail, - t.suggested_action, - r.anomaly_id, - r.table_groups_id::VARCHAR, - r.id::VARCHAR, - p.profiling_starttime, - r.profile_run_id::VARCHAR, + t.anomaly_description, r.detail, t.suggested_action, + r.anomaly_id, r.table_groups_id::VARCHAR, r.id::VARCHAR, p.profiling_starttime, r.profile_run_id::VARCHAR, tg.table_groups_name, -- These are used in the PDF report @@ -487,7 +457,6 @@ def get_profiling_anomalies( COALESCE(dcc.transform_level, dtc.transform_level, tg.transform_level) as transform_level, COALESCE(dcc.aggregation_level, dtc.aggregation_level) as aggregation_level, COALESCE(dcc.data_product, dtc.data_product, tg.data_product) as data_product - FROM profile_anomaly_results r INNER JOIN profile_anomaly_types t ON r.anomaly_id = t.id @@ -497,20 +466,30 @@ def get_profiling_anomalies( ON r.table_groups_id = tg.id LEFT JOIN data_column_chars dcc ON (tg.id = dcc.table_groups_id - AND r.schema_name = dcc.schema_name - AND r.table_name = dcc.table_name - AND r.column_name = dcc.column_name) + AND r.schema_name = dcc.schema_name + AND r.table_name = dcc.table_name + AND r.column_name = dcc.column_name) LEFT JOIN data_table_chars dtc ON dcc.table_id = dtc.table_id WHERE r.profile_run_id = :profile_run_id - {criteria} - {order_by} + {"AND t.issue_likelihood = :likelihood" if likelihood else ""} + {"AND t.id = :issue_type_id" if issue_type_id else ""} + {"AND r.table_name = :table_name" if table_name else ""} + {"AND r.column_name ILIKE :column_name" if column_name else ""} + {"AND r.disposition IS NULL" if action == "No Action" else "AND r.disposition = :disposition" if action else ""} + {f"ORDER BY {', '.join(' '.join(col) for col in sorting_columns)}" 
if sorting_columns else ""} """ - - results = db_session.execute(str_sql, params=params) - columns = [column.name for column in results.cursor.description] - - df = pd.DataFrame(list(results), columns=columns) + params = { + "profile_run_id": profile_run_id, + "likelihood": likelihood, + "issue_type_id": issue_type_id, + "table_name": table_name, + "column_name": column_name, + "disposition": { + "Muted": "Inactive", + }.get(action, action), + } + df = fetch_df_from_db(query, params) dct_replace = {"Confirmed": "✓", "Dismissed": "✘", "Inactive": "🔇"} df["action"] = df["disposition"].replace(dct_replace) @@ -518,15 +497,13 @@ def get_profiling_anomalies( @st.cache_data(show_spinner=False) -def get_anomaly_disposition(str_profile_run_id): - str_schema = st.session_state["dbschema"] - str_sql = f""" - SELECT id::VARCHAR, disposition - FROM {str_schema}.profile_anomaly_results s - WHERE s.profile_run_id = '{str_profile_run_id}'; +def get_anomaly_disposition(profile_run_id: str) -> pd.DataFrame: + query = """ + SELECT id::VARCHAR, disposition + FROM profile_anomaly_results s + WHERE s.profile_run_id = :profile_run_id; """ - # Retrieve data as df - df = db.retrieve_data(str_sql) + df = fetch_df_from_db(query, {"profile_run_id": profile_run_id}) dct_replace = {"Confirmed": "✓", "Dismissed": "✘", "Inactive": "🔇", "Passed": ""} df["action"] = df["disposition"].replace(dct_replace) @@ -534,45 +511,44 @@ def get_anomaly_disposition(str_profile_run_id): @st.cache_data(show_spinner=False) -def get_profiling_anomaly_summary(str_profile_run_id): - str_schema = st.session_state["dbschema"] - # Define the query - str_sql = f""" - SELECT - schema_name, - COUNT(DISTINCT s.table_name) as table_ct, - COUNT(DISTINCT s.column_name) as column_ct, - COUNT(*) as issue_ct, - SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' - AND t.issue_likelihood = 'Definite' THEN 1 ELSE 0 END) as definite_ct, - SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' - AND t.issue_likelihood = 'Likely' THEN 1 ELSE 0 END) as likely_ct, - SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' - AND t.issue_likelihood = 'Possible' THEN 1 ELSE 0 END) as possible_ct, - SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') - IN ('Dismissed', 'Inactive') - AND t.issue_likelihood <> 'Potential PII' THEN 1 ELSE 0 END) as dismissed_ct, - SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' AND t.issue_likelihood = 'Potential PII' AND s.detail LIKE 'Risk: HIGH%%' THEN 1 ELSE 0 END) as pii_high_ct, - SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' AND t.issue_likelihood = 'Potential PII' AND s.detail LIKE 'Risk: MODERATE%%' THEN 1 ELSE 0 END) as pii_moderate_ct, - SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') AND t.issue_likelihood = 'Potential PII' THEN 1 ELSE 0 END) as pii_dismissed_ct - FROM {str_schema}.profile_anomaly_results s - LEFT JOIN {str_schema}.profile_anomaly_types t ON (s.anomaly_id = t.id) - WHERE s.profile_run_id = '{str_profile_run_id}' - GROUP BY schema_name; +def get_profiling_anomaly_summary(profile_run_id: str) -> list[dict]: + query = """ + SELECT + schema_name, + COUNT(DISTINCT s.table_name) as table_ct, + COUNT(DISTINCT s.column_name) as column_ct, + COUNT(*) as issue_ct, + SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' + AND t.issue_likelihood = 'Definite' THEN 1 ELSE 0 END) as definite_ct, + SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' + AND t.issue_likelihood = 'Likely' THEN 1 ELSE 0 
END) as likely_ct, + SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' + AND t.issue_likelihood = 'Possible' THEN 1 ELSE 0 END) as possible_ct, + SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') + IN ('Dismissed', 'Inactive') + AND t.issue_likelihood <> 'Potential PII' THEN 1 ELSE 0 END) as dismissed_ct, + SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' AND t.issue_likelihood = 'Potential PII' AND s.detail LIKE 'Risk: HIGH%%' THEN 1 ELSE 0 END) as pii_high_ct, + SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') = 'Confirmed' AND t.issue_likelihood = 'Potential PII' AND s.detail LIKE 'Risk: MODERATE%%' THEN 1 ELSE 0 END) as pii_moderate_ct, + SUM(CASE WHEN COALESCE(s.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') AND t.issue_likelihood = 'Potential PII' THEN 1 ELSE 0 END) as pii_dismissed_ct + FROM profile_anomaly_results s + LEFT JOIN profile_anomaly_types t ON (s.anomaly_id = t.id) + WHERE s.profile_run_id = :profile_run_id + GROUP BY schema_name; """ - df = db.retrieve_data(str_sql) + result = fetch_one_from_db(query, {"profile_run_id": profile_run_id}) return [ - { "label": "Definite", "value": int(df.at[0, "definite_ct"]), "color": "red" }, - { "label": "Likely", "value": int(df.at[0, "likely_ct"]), "color": "orange" }, - { "label": "Possible", "value": int(df.at[0, "possible_ct"]), "color": "yellow" }, - { "label": "Dismissed", "value": int(df.at[0, "dismissed_ct"]), "color": "grey" }, - { "label": "High Risk", "value": int(df.at[0, "pii_high_ct"]), "color": "red", "type": "PII" }, - { "label": "Moderate Risk", "value": int(df.at[0, "pii_moderate_ct"]), "color": "orange", "type": "PII" }, - { "label": "Dismissed", "value": int(df.at[0, "pii_dismissed_ct"]), "color": "grey", "type": "PII" }, + { "label": "Definite", "value": result.definite_ct, "color": "red" }, + { "label": "Likely", "value": result.likely_ct, "color": "orange" }, + { "label": "Possible", "value": result.possible_ct, "color": "yellow" }, + { "label": "Dismissed", "value": result.dismissed_ct, "color": "grey" }, + { "label": "High Risk", "value": result.pii_high_ct, "color": "red", "type": "PII" }, + { "label": "Moderate Risk", "value": result.pii_moderate_ct, "color": "orange", "type": "PII" }, + { "label": "Dismissed", "value": result.pii_dismissed_ct, "color": "grey", "type": "PII" }, ] +@with_database_session def get_excel_report_data( update_progress: PROGRESS_UPDATE_TYPE, table_group: str, @@ -603,12 +579,8 @@ def get_excel_report_data( ) -@st.cache_data(show_spinner=False) -def get_source_data(hi_data, limit): - return get_source_data_uncached(hi_data, limit) - - @st.dialog(title="Source Data") +@with_database_session def source_data_dialog(selected_row): st.markdown(f"#### {selected_row['anomaly_name']}") st.caption(selected_row["anomaly_description"]) @@ -618,7 +590,7 @@ def source_data_dialog(selected_row): fm.render_html_list(selected_row, ["detail"], None, 700, ["Hygiene Issue Detail"]) with st.spinner("Retrieving source data..."): - bad_data_status, bad_data_msg, _, df_bad = get_source_data(selected_row, limit=500) + bad_data_status, bad_data_msg, _, df_bad = get_hygiene_issue_source_data(selected_row, limit=500) if bad_data_status in {"ND", "NA"}: st.info(bad_data_msg) elif bad_data_status == "ERR": @@ -645,8 +617,7 @@ def do_disposition_update(selected, str_new_status): elif len(selected) == 1: str_which = f"of one issue to {str_new_status}" - str_schema = st.session_state["dbschema"] - if not dq.update_anomaly_disposition(selected, str_schema, str_new_status): 
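The replacement call just below goes through the new update_anomaly_disposition helper added at the end of this file, which binds the selected IDs as an array parameter instead of interpolating the schema name and raw values into the SQL string. A minimal standalone sketch of that UNNEST-over-array-bind pattern, assuming a plain SQLAlchemy engine against the app's PostgreSQL database (TestGen's own execute_db_query wrapper is not reproduced here, and the connection URL is hypothetical):

from sqlalchemy import create_engine, text

# Hypothetical connection URL; TestGen resolves its app-database engine internally.
engine = create_engine("postgresql+psycopg2://user:pass@localhost:5432/testgen")

# Anomaly result UUIDs picked in the grid (placeholder value).
selected_ids = ["0b6c2f2e-9c1d-4f5a-9a51-1d2f3e4a5b6c"]

with engine.begin() as conn:
    # The ID list is passed as a single bound parameter and expanded with UNNEST,
    # so no per-row string interpolation is needed.
    conn.execute(
        text("""
            WITH selects AS (SELECT UNNEST(CAST(:ids AS uuid[])) AS selected_id)
            UPDATE profile_anomaly_results AS r
            SET disposition = NULLIF(:disposition, 'No Decision')
            FROM selects s
            WHERE r.id = s.selected_id
        """),
        {"ids": selected_ids, "disposition": "Dismissed"},
    )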
+ if not update_anomaly_disposition(selected, str_new_status): str_result = f":red[**The update {str_which} did not succeed.**]" return str_result @@ -661,3 +632,27 @@ def get_report_file_data(update_progress, tr_data) -> FILE_DATA_TYPE: update_progress(1.0) buffer.seek(0) return file_name, "application/pdf", buffer.read() + + +def update_anomaly_disposition( + selected: list[dict], + disposition: typing.Literal["Confirmed", "Dismissed", "Inactive", "No Decision"], +): + execute_db_query( + """ + WITH selects + AS (SELECT UNNEST(ARRAY [:anomaly_result_ids]) AS selected_id) + UPDATE profile_anomaly_results + SET disposition = NULLIF(:disposition, 'No Decision') + FROM profile_anomaly_results r + INNER JOIN selects s + ON (r.id = s.selected_id::UUID) + WHERE r.id = profile_anomaly_results.id; + """, + { + "anomaly_result_ids": [row["id"] for row in selected if "id"], + "disposition": disposition, + } + ) + + return True diff --git a/testgen/ui/views/profiling_results.py b/testgen/ui/views/profiling_results.py index 9b4f2142..faff90f5 100644 --- a/testgen/ui/views/profiling_results.py +++ b/testgen/ui/views/profiling_results.py @@ -7,10 +7,10 @@ import streamlit as st import testgen.ui.queries.profiling_queries as profiling_queries -import testgen.ui.services.database_service as db import testgen.ui.services.form_service as fm from testgen.common import date_service from testgen.common.models import with_database_session +from testgen.common.models.profiling_run import ProfilingRun from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -21,7 +21,8 @@ from testgen.ui.components.widgets.page import css_class, flex_row_end from testgen.ui.components.widgets.testgen_component import testgen_component from testgen.ui.navigation.page import Page -from testgen.ui.services import project_service, user_session_service +from testgen.ui.services import user_session_service +from testgen.ui.services.database_service import fetch_df_from_db from testgen.ui.session import session from testgen.ui.views.dialogs.data_preview_dialog import data_preview_dialog @@ -37,23 +38,23 @@ class ProfilingResultsPage(Page): ] def render(self, run_id: str, table_name: str | None = None, column_name: str | None = None, **_kwargs) -> None: - run_df = profiling_queries.get_run_by_id(run_id) - if run_df.empty: + run = ProfilingRun.get_minimal(run_id) + if not run: self.router.navigate_with_warning( f"Profiling run with ID '{run_id}' does not exist. 
Redirecting to list of Profiling Runs ...", "profiling-runs", ) return - run_date = date_service.get_timezoned_timestamp(st.session_state, run_df["profiling_starttime"]) - project_service.set_sidebar_project(run_df["project_code"]) + run_date = date_service.get_timezoned_timestamp(st.session_state, run.profiling_starttime) + session.set_sidebar_project(run.project_code) testgen.page_header( "Data Profiling Results", "view-data-profiling-results", breadcrumbs=[ - { "label": "Profiling Runs", "path": "profiling-runs", "params": { "project_code": run_df["project_code"] } }, - { "label": f"{run_df['table_groups_name']} | {run_date}" }, + { "label": "Profiling Runs", "path": "profiling-runs", "params": { "project_code": run.project_code } }, + { "label": f"{run.table_groups_name} | {run_date}" }, ], ) @@ -143,7 +144,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: download_dialog( dialog_title="Download Excel Report", file_content_func=get_excel_report_data, - args=(run_df["table_groups_name"], run_date, run_id, data), + args=(run.table_groups_name, run_date, run_id, data), ) with popover_container.container(key="tg--export-popover"): @@ -281,25 +282,27 @@ def get_excel_report_data( @st.cache_data(show_spinner=False) -def get_profiling_run_tables(profiling_run_id: str): - schema: str = st.session_state["dbschema"] - query = f""" +def get_profiling_run_tables(profiling_run_id: str) -> pd.DataFrame: + query = """ SELECT DISTINCT table_name - FROM {schema}.profile_results - WHERE profile_run_id = '{profiling_run_id}' - ORDER BY table_name + FROM profile_results + WHERE profile_run_id = :profiling_run_id + ORDER BY table_name; """ - return db.retrieve_data(query) + return fetch_df_from_db(query, {"profiling_run_id": profiling_run_id}) @st.cache_data(show_spinner=False) -def get_profiling_run_columns(profiling_run_id: str, table_name: str): - schema: str = st.session_state["dbschema"] - query = f""" +def get_profiling_run_columns(profiling_run_id: str, table_name: str) -> pd.DataFrame: + query = """ SELECT DISTINCT column_name - FROM {schema}.profile_results - WHERE profile_run_id = '{profiling_run_id}' - AND table_name = '{table_name}' - ORDER BY column_name + FROM profile_results + WHERE profile_run_id = :profiling_run_id + AND table_name = :table_name + ORDER BY column_name; """ - return db.retrieve_data(query) + params = { + "profiling_run_id": profiling_run_id, + "table_name": table_name or "", + } + return fetch_df_from_db(query, params) diff --git a/testgen/ui/views/profiling_runs.py b/testgen/ui/views/profiling_runs.py index 263b4b4d..540623b8 100644 --- a/testgen/ui/views/profiling_runs.py +++ b/testgen/ui/views/profiling_runs.py @@ -1,25 +1,26 @@ +import json import logging import typing +from collections.abc import Iterable from functools import partial -import pandas as pd import streamlit as st import testgen.common.process_service as process_service -import testgen.ui.services.database_service as db import testgen.ui.services.form_service as fm -import testgen.ui.services.query_service as dq from testgen.common.models import with_database_session +from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.project import Project +from testgen.common.models.table_group import TableGroup, TableGroupMinimal from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets import testgen_component from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page -from testgen.ui.queries 
import profiling_run_queries, project_queries from testgen.ui.services import user_session_service from testgen.ui.session import session, temp_value from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog from testgen.ui.views.dialogs.run_profiling_dialog import run_profiling_dialog -from testgen.utils import friendly_score, to_int +from testgen.utils import friendly_score, to_dataframe, to_int LOG = logging.getLogger("testgen") FORM_DATA_WIDTH = 400 @@ -56,7 +57,9 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs group_filter_column, actions_column = st.columns([.3, .7], vertical_alignment="bottom") with group_filter_column: - table_groups_df = get_db_table_group_choices(project_code) + table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups_df = to_dataframe(table_groups, TableGroupMinimal.columns()) + table_groups_df["id"] = table_groups_df["id"].apply(lambda x: str(x)) table_group_id = testgen.select( options=table_groups_df, value_column="id", @@ -87,18 +90,25 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs testgen.whitespace(0.5) list_container = st.container() - profiling_runs_df = get_db_profiling_runs(project_code, table_group_id) + with st.spinner("Loading data ..."): + profiling_runs = ProfilingRun.select_summary(project_code, table_group_id) - run_count = len(profiling_runs_df) - page_index = testgen.paginator(count=run_count, page_size=PAGE_SIZE) - profiling_runs_df["dq_score_profiling"] = profiling_runs_df["dq_score_profiling"].map(lambda score: friendly_score(score)) - paginated_df = profiling_runs_df[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)] + paginated = [] + if run_count := len(profiling_runs): + page_index = testgen.paginator(count=run_count, page_size=PAGE_SIZE) + profiling_runs = [ + { + **row.to_dict(json_safe=True), + "dq_score_profiling": friendly_score(row.dq_score_profiling), + } for row in profiling_runs + ] + paginated = profiling_runs[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)] with list_container: testgen_component( "profiling_runs", props={ - "items": paginated_df.to_json(orient="records"), + "items": json.dumps(paginated), "permissions": { "can_run": user_can_run, "can_edit": user_can_run, @@ -116,36 +126,35 @@ class ProfilingScheduleDialog(ScheduleDialog): title = "Profiling Schedules" arg_label = "Table Group" job_key = "run-profile" - table_groups: pd.DataFrame | None = None + table_groups: Iterable[TableGroupMinimal] | None = None def init(self) -> None: - self.table_groups = get_db_table_group_choices(self.project_code) + self.table_groups = TableGroup.select_minimal_where(TableGroup.project_code == self.project_code) def get_arg_value(self, job): - return self.table_groups.loc[ - self.table_groups["id"] == job.kwargs["table_group_id"], "table_groups_name" - ].iloc[0] + return next(item.table_groups_name for item in self.table_groups if str(item.id) == job.kwargs["table_group_id"]) def arg_value_input(self) -> tuple[bool, list[typing.Any], dict[str, typing.Any]]: + table_groups_df = to_dataframe(self.table_groups, TableGroupMinimal.columns()) tg_id = testgen.select( label="Table Group", - options=self.table_groups, + options=table_groups_df, value_column="id", display_column="table_groups_name", required=True, placeholder="Select table group", ) - return bool(tg_id), [], {"table_group_id": tg_id} + return bool(tg_id), [], {"table_group_id": str(tg_id)} def render_empty_state(project_code: str, 
user_can_run: bool) -> bool: - project_summary_df = project_queries.get_summary_by_code(project_code) - if project_summary_df["profiling_runs_ct"]: + project_summary = Project.get_summary(project_code) + if project_summary.profiling_run_count: return False label = "No profiling runs yet" testgen.whitespace(5) - if not project_summary_df["connections_ct"]: + if not project_summary.connection_count: testgen.empty_state( label=label, icon=PAGE_ICON, @@ -154,7 +163,7 @@ def render_empty_state(project_code: str, user_can_run: bool) -> bool: link_href="connections", link_params={ "project_code": project_code }, ) - elif not project_summary_df["table_groups_ct"]: + elif not project_summary.table_group_count: testgen.empty_state( label=label, icon=PAGE_ICON, @@ -163,7 +172,7 @@ def render_empty_state(project_code: str, user_can_run: bool) -> bool: link_href="table-groups", link_params={ "project_code": project_code, - "connection_id": str(project_summary_df["default_connection_id"]), + "connection_id": str(project_summary.default_connection_id), }, ) else: @@ -179,10 +188,10 @@ def render_empty_state(project_code: str, user_can_run: bool) -> bool: return True -def on_cancel_run(profiling_run: pd.Series) -> None: +def on_cancel_run(profiling_run: dict) -> None: process_status, process_message = process_service.kill_profile_run(to_int(profiling_run["process_id"])) if process_status: - profiling_run_queries.update_status(profiling_run["profiling_run_id"], "Cancelled") + ProfilingRun.update_status(profiling_run["profiling_run_id"], "Cancelled") fm.reset_post_updates(str_message=f":{'green' if process_status else 'red'}[{process_message}]", as_toast=True) @@ -202,6 +211,9 @@ def on_delete_confirmed(*_args) -> None: message = "Are you sure you want to delete the selected profiling run?" constraint["confirmation"] = "Yes, cancel and delete the profiling run." 
+ if not ProfilingRun.has_running_process(profiling_run_ids): + constraint = None + result, set_result = temp_value("profiling-runs:result-value", default=None) delete_confirmed, set_delete_confirmed = temp_value("profiling-runs:confirm-delete", default=False) @@ -223,14 +235,13 @@ def on_delete_confirmed(*_args) -> None: if delete_confirmed(): try: with st.spinner("Deleting runs ..."): - profiling_runs = get_db_profiling_runs(project_code, table_group_id, profiling_run_ids=profiling_run_ids) - for _, profiling_run in profiling_runs.iterrows(): - profiling_run_id = profiling_run["profiling_run_id"] - if profiling_run["status"] == "Running": - process_status, process_message = process_service.kill_profile_run(to_int(profiling_run["process_id"])) + profiling_runs = ProfilingRun.select_summary(project_code, table_group_id, profiling_run_ids) + for profiling_run in profiling_runs: + if profiling_run.status == "Running": + process_status, _ = process_service.kill_profile_run(to_int(profiling_run.process_id)) if process_status: - profiling_run_queries.update_status(profiling_run_id, "Cancelled") - profiling_run_queries.cascade_delete_multiple_profiling_runs(profiling_run_ids) + ProfilingRun.update_status(profiling_run.profiling_run_id, "Cancelled") + ProfilingRun.cascade_delete(profiling_run_ids) st.rerun() except Exception: LOG.exception("Failed to delete profiling runs") @@ -239,87 +250,3 @@ def on_delete_confirmed(*_args) -> None: "message": "Unable to delete the selected profiling runs, try again.", }) st.rerun(scope="fragment") - - -@st.cache_data(show_spinner=False) -def get_db_table_group_choices(project_code: str) -> pd.DataFrame: - schema = st.session_state["dbschema"] - return dq.run_table_groups_lookup_query(schema, project_code) - - -@st.cache_data(show_spinner="Loading data ...") -def get_db_profiling_runs( - project_code: str, - table_group_id: str | None = None, - profiling_run_ids: list[str] | None = None, -) -> pd.DataFrame: - schema = st.session_state["dbschema"] - table_group_condition = f" AND v_profiling_runs.table_groups_id = '{table_group_id}' " if table_group_id else "" - - profling_runs_condition = "" - if profiling_run_ids and len(profiling_run_ids) > 0: - profiling_run_ids_ = [f"'{run_id}'" for run_id in profiling_run_ids] - profling_runs_condition = f" AND v_profiling_runs.profiling_run_id::VARCHAR IN ({', '.join(profiling_run_ids_)})" - - sql = f""" - WITH profile_anomalies AS ( - SELECT profile_anomaly_results.profile_run_id, - SUM( - CASE - WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') = 'Confirmed' - AND profile_anomaly_types.issue_likelihood = 'Definite' THEN 1 - ELSE 0 - END - ) as definite_ct, - SUM( - CASE - WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') = 'Confirmed' - AND profile_anomaly_types.issue_likelihood = 'Likely' THEN 1 - ELSE 0 - END - ) as likely_ct, - SUM( - CASE - WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') = 'Confirmed' - AND profile_anomaly_types.issue_likelihood = 'Possible' THEN 1 - ELSE 0 - END - ) as possible_ct, - SUM( - CASE - WHEN COALESCE(profile_anomaly_results.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') - AND profile_anomaly_types.issue_likelihood <> 'Potential PII' THEN 1 - ELSE 0 - END - ) as dismissed_ct - FROM {schema}.profile_anomaly_results - LEFT JOIN {schema}.profile_anomaly_types ON ( - profile_anomaly_types.id = profile_anomaly_results.anomaly_id - ) - GROUP BY profile_anomaly_results.profile_run_id - ) - SELECT v_profiling_runs.profiling_run_id::VARCHAR, - 
v_profiling_runs.start_time, - v_profiling_runs.table_groups_name, - v_profiling_runs.status, - v_profiling_runs.process_id, - v_profiling_runs.duration, - v_profiling_runs.log_message, - v_profiling_runs.schema_name, - v_profiling_runs.table_ct, - v_profiling_runs.column_ct, - v_profiling_runs.anomaly_ct, - profile_anomalies.definite_ct as anomalies_definite_ct, - profile_anomalies.likely_ct as anomalies_likely_ct, - profile_anomalies.possible_ct as anomalies_possible_ct, - profile_anomalies.dismissed_ct as anomalies_dismissed_ct, - v_profiling_runs.dq_score_profiling - FROM {schema}.v_profiling_runs - LEFT JOIN profile_anomalies ON (v_profiling_runs.profiling_run_id = profile_anomalies.profile_run_id) - WHERE project_code = '{project_code}' - {table_group_condition} - {profling_runs_condition} - ORDER BY start_time DESC; - """ - - return db.retrieve_data(sql) diff --git a/testgen/ui/views/project_dashboard.py b/testgen/ui/views/project_dashboard.py index 8e28f5db..44586419 100644 --- a/testgen/ui/views/project_dashboard.py +++ b/testgen/ui/views/project_dashboard.py @@ -1,16 +1,16 @@ import typing -import pandas as pd import streamlit as st -import testgen.ui.services.database_service as db +from testgen.common.models.project import Project +from testgen.common.models.table_group import TableGroup +from testgen.common.models.test_suite import TestSuite from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page -from testgen.ui.queries import project_queries -from testgen.ui.services import test_suite_service, user_session_service +from testgen.ui.services import user_session_service from testgen.ui.session import session -from testgen.utils import format_field, friendly_score, score +from testgen.utils import friendly_score, make_json_safe, score PAGE_TITLE = "Project Dashboard" PAGE_ICON = "home" @@ -36,63 +36,40 @@ def render(self, project_code: str, **_kwargs): ) with st.spinner("Loading data ..."): - table_groups = get_table_groups_summary(project_code) - test_suites = test_suite_service.get_by_project(project_code) - project_summary_df = project_queries.get_summary_by_code(project_code) - - table_groups_fields: list[str] = [ - "id", - "table_groups_name", - "latest_profile_id", - "latest_profile_start", - "latest_profile_table_ct", - "latest_profile_column_ct", - "latest_anomalies_ct", - "latest_anomalies_definite_ct", - "latest_anomalies_likely_ct", - "latest_anomalies_possible_ct", - "latest_anomalies_dismissed_ct", - ] - test_suite_fields: list[str] = [ - "id", - "test_suite", - "test_ct", - "latest_run_start", - "latest_run_id", - "last_run_test_ct", - "last_run_passed_ct", - "last_run_warning_ct", - "last_run_failed_ct", - "last_run_error_ct", - "last_run_dismissed_ct", - ] + table_groups = TableGroup.select_summary(project_code, for_dashboard=True) + test_suites = TestSuite.select_summary(project_code) + project_summary = Project.get_summary(project_code) table_groups_sort = st.session_state.get("overview_table_groups_sort") or "latest_activity_date" testgen.testgen_component( "project_dashboard", props={ - "project": { - "project_code": project_code, - "test_runs_count": int(project_summary_df["test_runs_ct"]), - "profiling_runs_count": int(project_summary_df["profiling_runs_ct"]), - "connections_count": int(project_summary_df["connections_ct"]), - "default_connection_id": str(project_summary_df["default_connection_id"]), - }, + "project_summary": project_summary.to_dict(json_safe=True), 
"table_groups": [ { - **{field: format_field(table_group[field]) for field in table_groups_fields}, + **table_group.to_dict(json_safe=True), "test_suites": [ - { field: format_field(test_suite[field]) for field in test_suite_fields} - for _, test_suite in test_suites[test_suites["table_groups_id"] == table_group_id].iterrows() + test_suite.to_dict(json_safe=True) + for test_suite in test_suites + if test_suite.table_groups_id == table_group.id ], - "latest_tests_start": format_field(test_suites[test_suites["table_groups_id"] == table_group_id]["latest_run_start"].max()), - "dq_score": friendly_score(score(table_group["dq_score_profiling"], table_group["dq_score_testing"])), - "dq_score_profiling": friendly_score(table_group["dq_score_profiling"]), - "dq_score_testing": friendly_score(table_group["dq_score_testing"]), + "latest_tests_start": make_json_safe( + max( + ( + test_suite.latest_run_start + for test_suite in test_suites + if test_suite.table_groups_id == table_group.id + and test_suite.latest_run_start + ), + default=None, + ) + ), + "dq_score": friendly_score(score(table_group.dq_score_profiling, table_group.dq_score_testing)), + "dq_score_profiling": friendly_score(table_group.dq_score_profiling), + "dq_score_testing": friendly_score(table_group.dq_score_testing), } - for _, table_group in table_groups.iterrows() - if (table_group_id := str(table_group["id"])) + for table_group in table_groups ], "table_groups_sort_options": [ { @@ -113,76 +90,3 @@ def render(self, project_code: str, **_kwargs): ], }, ) - - -@st.cache_data(show_spinner=False) -def get_table_groups_summary(project_code: str) -> pd.DataFrame: - schema = st.session_state["dbschema"] - sql = f""" - WITH latest_profile AS ( - SELECT latest_run.table_groups_id, - latest_run.id, - latest_run.profiling_starttime, - latest_run.table_ct, - latest_run.column_ct, - latest_run.anomaly_ct, - SUM( - CASE - WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') = 'Confirmed' - AND anomaly_types.issue_likelihood = 'Definite' THEN 1 - ELSE 0 - END - ) as definite_ct, - SUM( - CASE - WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') = 'Confirmed' - AND anomaly_types.issue_likelihood = 'Likely' THEN 1 - ELSE 0 - END - ) as likely_ct, - SUM( - CASE - WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') = 'Confirmed' - AND anomaly_types.issue_likelihood = 'Possible' THEN 1 - ELSE 0 - END - ) as possible_ct, - SUM( - CASE - WHEN COALESCE(latest_anomalies.disposition, 'Confirmed') IN ('Dismissed', 'Inactive') - AND anomaly_types.issue_likelihood <> 'Potential PII' THEN 1 - ELSE 0 - END - ) as dismissed_ct - FROM {schema}.table_groups groups - LEFT JOIN {schema}.profiling_runs latest_run ON ( - groups.last_complete_profile_run_id = latest_run.id - ) - LEFT JOIN {schema}.profile_anomaly_results latest_anomalies ON ( - latest_run.id = latest_anomalies.profile_run_id - ) - LEFT JOIN {schema}.profile_anomaly_types anomaly_types ON ( - anomaly_types.id = latest_anomalies.anomaly_id - ) - GROUP BY latest_run.id - ) - SELECT groups.id::VARCHAR(50), - groups.table_groups_name, - groups.dq_score_profiling, - groups.dq_score_testing, - latest_profile.id as latest_profile_id, - latest_profile.profiling_starttime as latest_profile_start, - latest_profile.table_ct as latest_profile_table_ct, - latest_profile.column_ct as latest_profile_column_ct, - latest_profile.anomaly_ct as latest_anomalies_ct, - latest_profile.definite_ct as latest_anomalies_definite_ct, - latest_profile.likely_ct as latest_anomalies_likely_ct, - 
latest_profile.possible_ct as latest_anomalies_possible_ct, - latest_profile.dismissed_ct as latest_anomalies_dismissed_ct - FROM {schema}.table_groups as groups - LEFT JOIN latest_profile ON (groups.id = latest_profile.table_groups_id) - WHERE groups.project_code = '{project_code}' - AND groups.include_in_dashboard IS TRUE; - """ - - return db.retrieve_data(sql) diff --git a/testgen/ui/views/project_settings.py b/testgen/ui/views/project_settings.py index 18e9980a..06a93fc8 100644 --- a/testgen/ui/views/project_settings.py +++ b/testgen/ui/views/project_settings.py @@ -6,10 +6,12 @@ from streamlit.delta_generator import DeltaGenerator from testgen.commands.run_observability_exporter import test_observability_exporter +from testgen.common.models import with_database_session +from testgen.common.models.project import Project from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page -from testgen.ui.services import project_service, user_session_service +from testgen.ui.services import user_session_service from testgen.ui.session import session PAGE_TITLE = "Project Settings" @@ -30,11 +32,11 @@ class ProjectSettingsPage(Page): roles=[ "admin" ], ) - project: dict | None = None + project: Project | None = None existing_names: list[str] | None = None def render(self, project_code: str | None = None, **_kwargs) -> None: - self.project = project_service.get_project_by_code(project_code) + self.project = Project.get(project_code) testgen.page_header( PAGE_TITLE, @@ -52,18 +54,18 @@ def show_edit_form(self) -> None: with testgen.card(): name_input = st.text_input( label="Project Name", - value=self.project["project_name"], + value=self.project.project_name, max_chars=30, key="project_settings:keys:project_name", ) st.text_input( label="Observability API URL", - value=self.project["observability_api_url"], + value=self.project.observability_api_url, key="project_settings:keys:observability_api_url", ) st.text_input( label="Observability API Key", - value=self.project["observability_api_key"], + value=self.project.observability_api_key, key="project_settings:keys:observability_api_key", ) @@ -97,22 +99,27 @@ def show_edit_form(self) -> None: key="project-settings:keys:edit", ) + @with_database_session def edit_project(self) -> None: - project = self._get_edited_project() - if project["project_name"] and (not self.existing_names or project["project_name"] not in self.existing_names): - project_service.edit_project(project) + edited_project = self._get_edited_project() + if edited_project["project_name"] and (not self.existing_names or edited_project["project_name"] not in self.existing_names): + self.project.project_name = edited_project["project_name"] + self.project.observability_api_url = edited_project["observability_api_url"] + self.project.observability_api_key = edited_project["observability_api_key"] + self.project.save() st.toast("Changes have been saved.") def _get_edited_project(self) -> None: edited_project = { - "id": self.project["id"], - "project_code": self.project["project_code"], + "id": self.project.id, + "project_code": self.project.project_code, } # We have to get the input widget values from the session state # The return values for st.text_input do not reflect the latest user input if the button is clicked without unfocusing the input # https://discuss.streamlit.io/t/issue-with-modifying-text-using-st-text-input-and-st-button/56619/5 for key in [ "project_name", "observability_api_url", 
"observability_api_key" ]: - edited_project[key] = st.session_state[f"project_settings:keys:{key}"].strip() + value = st.session_state.get(f"project_settings:keys:{key}") + edited_project[key] = value.strip() if value else None return edited_project def _display_connection_status(self, status_container: DeltaGenerator) -> None: @@ -126,7 +133,7 @@ def _display_connection_status(self, status_container: DeltaGenerator) -> None: project["observability_api_url"], project["observability_api_key"], ) - status_container.success("The connection was successful.") + single_element_container.success("The connection was successful.") except Exception as e: with single_element_container.container(): st.error("Error attempting the connection.") diff --git a/testgen/ui/views/quality_dashboard.py b/testgen/ui/views/quality_dashboard.py index 665366cd..d66f746b 100644 --- a/testgen/ui/views/quality_dashboard.py +++ b/testgen/ui/views/quality_dashboard.py @@ -2,10 +2,10 @@ import streamlit as st +from testgen.common.models.project import Project from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page -from testgen.ui.queries import project_queries from testgen.ui.queries.scoring_queries import get_all_score_cards from testgen.ui.services import user_session_service from testgen.ui.session import session @@ -29,19 +29,13 @@ class QualityDashboardPage(Page): ) def render(self, *, project_code: str, **_kwargs) -> None: - project_summary = project_queries.get_summary_by_code(project_code) + project_summary = Project.get_summary(project_code) testgen.page_header(PAGE_TITLE) testgen.testgen_component( "quality_dashboard", props={ - "project_summary": { - "project_code": project_code, - "connections_count": int(project_summary["connections_ct"]), - "default_connection_id": str(project_summary["default_connection_id"]), - "table_groups_count": int(project_summary["table_groups_ct"]), - "profiling_runs_count": int(project_summary["profiling_runs_ct"]), - }, + "project_summary": project_summary.to_dict(json_safe=True), "scores": [ format_score_card(score) for score in get_all_score_cards(project_code) diff --git a/testgen/ui/views/score_details.py b/testgen/ui/views/score_details.py index e490217b..70877754 100644 --- a/testgen/ui/views/score_details.py +++ b/testgen/ui/views/score_details.py @@ -1,4 +1,5 @@ import logging +import typing from io import BytesIO from typing import ClassVar @@ -8,14 +9,21 @@ from testgen.commands.run_refresh_score_cards_results import run_recalculate_score_card from testgen.common.mixpanel_service import MixpanelService from testgen.common.models import with_database_session -from testgen.common.models.scores import ScoreCategory, ScoreDefinition, ScoreDefinitionBreakdownItem, SelectedIssue +from testgen.common.models.scores import ( + Categories, + ScoreCategory, + ScoreDefinition, + ScoreDefinitionBreakdownItem, + ScoreTypes, + SelectedIssue, +) from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import FILE_DATA_TYPE, download_dialog, zip_multi_file_data from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router from testgen.ui.pdf import hygiene_issue_report, test_result_report from testgen.ui.queries.scoring_queries import get_all_score_cards, get_score_card_issue_reports -from testgen.ui.services import project_service, user_session_service +from testgen.ui.services import user_session_service from 
testgen.ui.session import session, temp_value from testgen.ui.views.dialogs.profiling_results_dialog import profiling_results_dialog from testgen.utils import format_score_card, format_score_card_breakdown, format_score_card_issues @@ -50,7 +58,7 @@ def render( ) return - project_service.set_sidebar_project(score_definition.project_code) + session.set_sidebar_project(score_definition.project_code) testgen.page_header( "Score Details", @@ -60,6 +68,9 @@ def render( ], ) + if category not in typing.get_args(Categories): + category = None + if not category and score_definition.category: category = score_definition.category.value @@ -72,6 +83,8 @@ def render( with st.spinner(text="Loading data :gray[:small[(This might take a few minutes)]] ..."): user_can_edit = user_session_service.user_can_edit() score_card = format_score_card(score_definition.as_cached_score_card()) + if score_type not in typing.get_args(ScoreTypes): + score_type = None if not score_type: score_type = "cde_score" if score_card["cde_score"] and not score_card["score"] else "score" if not drilldown: @@ -157,7 +170,7 @@ def get_report_file_data(update_progress, issue) -> FILE_DATA_TYPE: hygiene_issue_report.create_report(buffer, issue) else: issue_id = issue["test_result_id"][:8] - timestamp = pd.Timestamp(issue["test_time"]).strftime("%Y%m%d_%H%M%S") + timestamp = pd.Timestamp(issue["test_date"]).strftime("%Y%m%d_%H%M%S") test_result_report.create_report(buffer, issue) update_progress(1.0) diff --git a/testgen/ui/views/score_explorer.py b/testgen/ui/views/score_explorer.py index 80d9ab82..f90e3786 100644 --- a/testgen/ui/views/score_explorer.py +++ b/testgen/ui/views/score_explorer.py @@ -12,13 +12,14 @@ run_refresh_score_cards_results, ) from testgen.common.mixpanel_service import MixpanelService +from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scores import ScoreCategory, ScoreDefinition, ScoreDefinitionCriteria, SelectedIssue +from testgen.common.models.test_run import TestRun from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import FILE_DATA_TYPE, download_dialog, zip_multi_file_data from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router from testgen.ui.pdf import hygiene_issue_report, test_result_report -from testgen.ui.queries import profiling_queries, test_run_queries from testgen.ui.queries.scoring_queries import ( get_all_score_cards, get_column_filters, @@ -104,7 +105,7 @@ def render( score_definition.name = name score_definition.total_score = total_score and total_score.lower() == "true" score_definition.cde_score = cde_score and cde_score.lower() == "true" - score_definition.category = ScoreCategory(category) if category else None + score_definition.category = ScoreCategory(category) if category in [cat.value for cat in ScoreCategory] else None if filters: applied_filters: list[dict] = try_json(filters, default=[]) @@ -224,7 +225,7 @@ def get_report_file_data(update_progress, issue) -> FILE_DATA_TYPE: hygiene_issue_report.create_report(buffer, issue) else: issue_id = issue["test_result_id"][:8] - timestamp = pd.Timestamp(issue["test_time"]).strftime("%Y%m%d_%H%M%S") + timestamp = pd.Timestamp(issue["test_date"]).strftime("%Y%m%d_%H%M%S") test_result_report.create_report(buffer, issue) update_progress(1.0) @@ -330,8 +331,8 @@ def save_score_definition(_) -> None: if is_new: latest_run = max( - profiling_queries.get_latest_run_date(project_code), - 
test_run_queries.get_latest_run_date(project_code), + ProfilingRun.get_latest_run(project_code), + TestRun.get_latest_run(project_code), key=lambda run: getattr(run, "run_time", datetime.min), ) diff --git a/testgen/ui/views/table_groups.py b/testgen/ui/views/table_groups.py index daecc088..e0d4158d 100644 --- a/testgen/ui/views/table_groups.py +++ b/testgen/ui/views/table_groups.py @@ -1,18 +1,21 @@ import logging import typing +from collections.abc import Iterable from dataclasses import asdict from functools import partial import streamlit as st from sqlalchemy.exc import IntegrityError -import testgen.ui.services.connection_service as connection_service -import testgen.ui.services.table_group_service as table_group_service from testgen.commands.run_profiling_bridge import run_profiling_in_background from testgen.common.models import with_database_session +from testgen.common.models.connection import Connection +from testgen.common.models.project import Project +from testgen.common.models.table_group import TableGroup, TableGroupMinimal from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page +from testgen.ui.queries import table_group_queries from testgen.ui.services import user_session_service from testgen.ui.session import session, temp_value from testgen.ui.views.connections import FLAVOR_OPTIONS, format_connection @@ -27,6 +30,7 @@ class TableGroupsPage(Page): can_activate: typing.ClassVar = [ lambda: session.authentication_status, lambda: not user_session_service.user_has_catalog_role(), + lambda: "project_code" in st.query_params, ] menu_item = MenuItem( icon="table_view", @@ -40,23 +44,30 @@ def render(self, project_code: str, connection_id: str | None = None, **_kwargs) testgen.page_header(PAGE_TITLE, "create-a-table-group") user_can_edit = user_session_service.user_can_edit() + project_summary = Project.get_summary(project_code) + if connection_id and not connection_id.isdigit(): + connection_id = None + if connection_id: - table_groups = table_group_service.get_by_connection(project_code, connection_id) + table_groups = TableGroup.select_minimal_where( + TableGroup.project_code == project_code, + TableGroup.connection_id == connection_id, + ) else: - table_groups = table_group_service.get_all(project_code) + table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + + connections = self._get_connections(project_code) return testgen.testgen_component( "table_group_list", props={ - "project_code": project_code, + "project_summary": project_summary.to_dict(json_safe=True), "connection_id": connection_id, "permissions": { "can_edit": user_can_edit, }, - "connections": self._get_connections(project_code), - "table_groups": self._format_table_group_list([ - table_group.to_dict() for _, table_group in table_groups.iterrows() - ]), + "connections": connections, + "table_groups": self._format_table_group_list(table_groups, connections), }, on_change_handlers={ "RunSchedulesClicked": lambda *_: ProfilingScheduleDialog().open(project_code), @@ -76,7 +87,6 @@ def render(self, project_code: str, connection_id: str | None = None, **_kwargs) def add_table_group_dialog(self, project_code: str, connection_id: str | None, *_args): return self._table_group_wizard( project_code, - save_table_group_fn=table_group_service.add, connection_id=connection_id, steps=[ "tableGroup", @@ -86,10 +96,10 @@ def add_table_group_dialog(self, project_code: str, connection_id: str | None, * ) 
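+    # with_database_session presumably keeps a database session open for the dialog, so the wizard below can call TableGroup.get() and table_group.save() directly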
@st.dialog(title="Edit Table Group") + @with_database_session def edit_table_group_dialog(self, project_code: str, table_group_id: str): return self._table_group_wizard( project_code, - save_table_group_fn=table_group_service.edit, table_group_id=table_group_id, steps=[ "tableGroup", @@ -101,13 +111,16 @@ def _table_group_wizard( self, project_code: str, *, - save_table_group_fn: typing.Callable[[dict], str], steps: list[str] | None = None, connection_id: str | None = None, table_group_id: str | None = None, ): - def on_preview_table_group_clicked(table_group: dict): + def on_preview_table_group_clicked(payload: dict): + table_group = payload["table_group"] + verify_table_access = payload.get("verify_access") or False + mark_for_preview(True) + mark_for_access_preview(verify_table_access) set_table_group(table_group) def on_save_table_group_clicked(payload: dict): @@ -131,6 +144,7 @@ def on_go_to_profiling_runs(params: dict) -> None: self.router.navigate(to="profiling-runs", with_args=params) should_preview, mark_for_preview = temp_value("table_groups:preview:new", default=False) + should_verify_access, mark_for_access_preview = temp_value("table_groups:preview_access:new", default=False) should_save, set_save = temp_value("table_groups:save:new", default=False) get_table_group, set_table_group = temp_value("table_groups:updated:new", default={}) is_table_group_verified, set_table_group_verified = temp_value( @@ -144,54 +158,56 @@ def on_go_to_profiling_runs(params: dict) -> None: is_table_group_used = False connections = self._get_connections(project_code) - original_table_group = {"project_code": project_code} + table_group = TableGroup(project_code=project_code) + original_table_group_schema = None if table_group_id: - original_table_group = table_group_service.get_by_id(table_group_id=table_group_id).to_dict() - is_table_group_used = table_group_service.is_table_group_used(table_group_id) + table_group = TableGroup.get(table_group_id) + original_table_group_schema = table_group.table_group_schema + is_table_group_used = TableGroup.is_in_use([table_group_id]) + + add_scorecard_definition = False + for key, value in get_table_group().items(): + if key == "add_scorecard_definition": + add_scorecard_definition = value + else: + setattr(table_group, key, value) - table_group = { - **original_table_group, - **get_table_group(), - } table_group_preview = None if is_table_group_used: - table_group["table_group_schema"] = original_table_group["table_group_schema"] + table_group.table_group_schema = original_table_group_schema if len(connections) == 1: - table_group["connection_id"] = connections[0]["connection_id"] + table_group.connection_id = connections[0]["connection_id"] - if not table_group.get("connection_id"): + if not table_group.connection_id: if connection_id: - table_group["connection_id"] = int(connection_id) + table_group.connection_id = int(connection_id) elif len(connections) == 1: - table_group["connection_id"] = connections[0]["connection_id"] - elif table_group.get("id"): + table_group.connection_id = connections[0]["connection_id"] + elif table_group.id: connections = [ conn for conn in connections - if int(conn["connection_id"]) == int(table_group.get("connection_id")) + if int(conn["connection_id"]) == int(table_group.connection_id) ] if should_preview(): - connection = connection_service.get_by_id(table_group["connection_id"], hide_passwords=False) - table_group_preview = table_group_service.get_table_group_preview( - project_code, - connection, - {"id": 
table_group.get("id") or "temp", **table_group}, + table_group_preview = table_group_queries.get_table_group_preview( + table_group, + verify_table_access=should_verify_access(), ) success = None message = "" - table_group_id = None if should_save(): success = True if is_table_group_verified(): try: - table_group_id = save_table_group_fn(table_group) + table_group.save(add_scorecard_definition) if should_run_profiling(): try: - run_profiling_in_background(table_group_id) - message = f"Profiling run started for table group {table_group['table_groups_name']}." + run_profiling_in_background(table_group.id) + message = f"Profiling run started for table group {table_group.table_groups_name}." except Exception: success = False message = "Profiling run encountered errors" @@ -210,14 +226,14 @@ def on_go_to_profiling_runs(params: dict) -> None: props={ "project_code": project_code, "connections": connections, - "table_group": table_group, + "table_group": table_group.to_dict(json_safe=True), "is_in_use": is_table_group_used, "table_group_preview": table_group_preview, "steps": steps, "results": { "success": success, "message": message, - "table_group_id": table_group_id, + "table_group_id": str(table_group.id), } if success is not None else None, }, on_change_handlers={ @@ -229,24 +245,33 @@ def on_go_to_profiling_runs(params: dict) -> None: def _get_connections(self, project_code: str, connection_id: str | None = None) -> list[dict]: if connection_id: - connections = [connection_service.get_by_id(connection_id, hide_passwords=True)] + connections = [Connection.get_minimal(connection_id)] else: - connections = [ - connection for _, connection in connection_service.get_connections( - project_code, hide_passwords=True - ).iterrows() - ] + connections = Connection.select_minimal_where(Connection.project_code == project_code) return [ format_connection(connection) for connection in connections ] - def _format_table_group_list(self, table_groups: list[dict]) -> list[dict]: + def _format_table_group_list( + self, + table_groups: Iterable[TableGroupMinimal], + connections: list[dict], + ) -> list[dict]: + connections_by_id = { con["connection_id"]: con for con in connections } + formatted_list = [] + for table_group in table_groups: - flavors = [f for f in FLAVOR_OPTIONS if f.value == table_group["sql_flavor_code"]] + formatted_table_group = table_group.to_dict(json_safe=True) + connection = connections_by_id[table_group.connection_id] + + flavors = [f for f in FLAVOR_OPTIONS if f.value == connection["sql_flavor_code"]] if flavors and (flavor := flavors[0]): - table_group["connection"] = { - "name": table_group["connection_name"], + formatted_table_group["connection"] = { + "name": connection["connection_name"], "flavor": asdict(flavor), } - return table_groups + + formatted_list.append(formatted_table_group) + + return formatted_list @st.dialog(title="Run Profiling") def run_profiling_dialog(self, project_code: str, table_group_id: str) -> None: @@ -268,7 +293,7 @@ def on_run_profiling_confirmed(*_args) -> None: default=False, ) - table_group = table_group_service.get_by_id(table_group_id).to_dict() + table_group = TableGroup.get_minimal(table_group_id) result = None if should_run_profiling(): success = True @@ -285,7 +310,7 @@ def on_run_profiling_confirmed(*_args) -> None: "run_profiling_dialog", props={ "project_code": project_code, - "table_group": table_group, + "table_group": table_group.to_dict(json_safe=True), "result": result, }, on_change_handlers={ @@ -295,13 +320,13 @@ def 
on_run_profiling_confirmed(*_args) -> None: ) @st.dialog(title="Delete Table Group") + @with_database_session def delete_table_group_dialog(self, project_code: str, table_group_id: str): def on_delete_confirmed(*_args): confirm_deletion(True) - table_group = table_group_service.get_by_id(table_group_id=table_group_id) - table_group_name = table_group["table_groups_name"] - can_be_deleted = table_group_service.cascade_delete([table_group_name], dry_run=True) + table_group = TableGroup.get_minimal(table_group_id) + can_be_deleted = not TableGroup.is_in_use([table_group_id]) is_deletion_confirmed, confirm_deletion = temp_value( f"table_groups:confirm_delete:{table_group_id}", default=False, @@ -311,9 +336,9 @@ def on_delete_confirmed(*_args): result = None if is_deletion_confirmed(): - if not table_group_service.are_table_groups_in_use([table_group_name]): - table_group_service.cascade_delete([table_group_name]) - message = f"Table Group {table_group_name} has been deleted. " + if not TableGroup.has_running_process([table_group_id]): + TableGroup.cascade_delete([table_group_id]) + message = f"Table Group {table_group.table_groups_name} has been deleted. " st.rerun() else: message = "This Table Group is in use by a running process and cannot be deleted." @@ -323,7 +348,7 @@ def on_delete_confirmed(*_args): "table_group_delete", props={ "project_code": project_code, - "table_group": table_group.to_dict(), + "table_group": table_group.to_dict(json_safe=True), "can_be_deleted": can_be_deleted, "result": result, }, diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index 18ddfa39..8123049f 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -6,16 +6,17 @@ import pandas as pd import streamlit as st +from sqlalchemy import asc, tuple_ from streamlit.delta_generator import DeltaGenerator from streamlit_extras.no_default_selectbox import selectbox -import testgen.ui.services.database_service as db import testgen.ui.services.form_service as fm -import testgen.ui.services.query_service as dq -import testgen.ui.services.table_group_service as table_group_service -import testgen.ui.services.test_definition_service as test_definition_service -import testgen.ui.services.test_suite_service as test_suite_service from testgen.common import date_service +from testgen.common.models import with_database_session +from testgen.common.models.connection import Connection +from testgen.common.models.table_group import TableGroup, TableGroupMinimal +from testgen.common.models.test_definition import TestDefinition, TestDefinitionMinimal, TestDefinitionSummary +from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -25,11 +26,13 @@ ) from testgen.ui.components.widgets.page import css_class, flex_row_end from testgen.ui.navigation.page import Page -from testgen.ui.services import project_service, user_session_service +from testgen.ui.services import user_session_service +from testgen.ui.services.database_service import fetch_all_from_db, fetch_df_from_db, fetch_from_target_db from testgen.ui.services.string_service import empty_if_null, snake_case_to_title_case from testgen.ui.session import session, temp_value from testgen.ui.views.dialogs.profiling_results_dialog import view_profiling_button from testgen.ui.views.dialogs.run_tests_dialog import run_tests_dialog +from testgen.utils import 
to_dataframe LOG = logging.getLogger("testgen") @@ -43,16 +46,16 @@ class TestDefinitionsPage(Page): ] def render(self, test_suite_id: str, table_name: str | None = None, column_name: str | None = None, **_kwargs) -> None: - test_suite = test_suite_service.get_by_id(test_suite_id) - if test_suite.empty: + test_suite = TestSuite.get(test_suite_id) + if not test_suite: self.router.navigate_with_warning( f"Test suite with ID '{test_suite_id}' does not exist. Redirecting to list of Test Suites ...", "test-suites", ) - table_group = table_group_service.get_by_id(test_suite["table_groups_id"]) - project_code = table_group["project_code"] - project_service.set_sidebar_project(project_code) + table_group = TableGroup.get_minimal(test_suite.table_groups_id) + project_code = table_group.project_code + session.set_sidebar_project(project_code) user_can_edit = user_session_service.user_can_edit() user_can_disposition = user_session_service.user_can_disposition() @@ -61,7 +64,7 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name: "testgen-test-types", breadcrumbs=[ { "label": "Test Suites", "path": "test-suites", "params": { "project_code": project_code } }, - { "label": test_suite["test_suite"] }, + { "label": test_suite.test_suite }, ], ) @@ -111,7 +114,7 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name: if user_can_edit and actions_column.button( ":material/add: Add", help="Add a new Test Definition" ): - add_test_dialog(project_code, table_group, test_suite, table_name, column_name) + add_test_dialog(table_group, test_suite, table_name, column_name) if user_can_edit and table_actions_column.button( ":material/play_arrow: Run Tests", @@ -120,8 +123,7 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name: run_tests_dialog(project_code, test_suite) selected = show_test_defs_grid( - project_code, test_suite["test_suite"], table_name, column_name, test_type, do_multi_select, table_actions_column, - table_group["id"] + test_suite, table_name, column_name, test_type, do_multi_select, table_actions_column, table_group ) fm.render_refresh_button(table_actions_column) @@ -138,7 +140,7 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name: ]) for action in disposition_actions: - action_disabled = not selected or all(sel[action["attribute"]] == ("Y" if action["value"] else "N") for sel in selected) + action_disabled = not selected or all(sel[action["attribute"]] == action["value"] for sel in selected) action["button"] = disposition_column.button(action["icon"], help=action["help"], disabled=action_disabled) # This has to be done as a second loop - otherwise, the rest of the buttons after the clicked one are not displayed briefly while refreshing @@ -163,7 +165,7 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name: ":material/edit: Edit", disabled=not selected, ): - edit_test_dialog(project_code, table_group, test_suite, table_name, column_name, selected_test_def) + edit_test_dialog(table_group, test_suite, table_name, column_name, selected_test_def) if actions_column.button( ":material/file_copy: Copy/Move", @@ -179,6 +181,7 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name: @st.dialog("Delete Tests") +@with_database_session def delete_test_dialog(test_definitions: list[dict]): delete_clicked, set_delete_clicked = temp_value("test-definitions:confirm-delete-tests-val") st.html(f""" @@ -199,42 +202,33 @@ def 
delete_test_dialog(test_definitions: list[dict]): ) if delete_clicked(): - test_definition_service.delete([ item["id"] for item in test_definitions ]) + TestDefinition.delete_where(TestDefinition.id.in_([ item["id"] for item in test_definitions ])) st.success("Test definitions have been deleted.") time.sleep(1) st.rerun() def show_test_form_by_id(test_definition_id): - selected_test_raw = test_definition_service.get_test_definitions(test_definition_ids=[test_definition_id]) - test_definition = selected_test_raw.iloc[0].to_dict() - - project_code = test_definition["project_code"] - table_group_id = test_definition["table_groups_id"] - test_suite_name = test_definition["test_suite"] - table_name = test_definition["table_name"] - column_name = test_definition["column_name"] - - table_group_raw = run_table_groups_lookup_query(project_code, table_group_id=None) - table_group = table_group_raw.iloc[0].to_dict() - - test_suite_raw = run_test_suite_lookup_query(table_group_id, test_suite_name) - if not test_suite_raw.empty: - test_suite = test_suite_raw.iloc[0].to_dict() - + test_definition = TestDefinition.get(test_definition_id) + table_group = TableGroup.get_minimal(test_definition.table_groups_id) + test_suite = TestSuite.get(test_definition.test_suite_id) + if test_suite: edit_test_dialog( - project_code, table_group, test_suite, table_name, column_name, test_definition + table_group, + test_suite, + test_definition.table_name, + test_definition.column_name, + test_definition.to_dict(), ) def show_test_form( - mode, - project_code, - table_group, - test_suite, - str_table_name, - str_column_name, - selected_test_def=None, + mode: typing.Literal["add", "edit"], + table_group: TableGroupMinimal, + test_suite: TestSuite, + table_name: str, + column_name: str, + selected_test_def: dict | None = None, ): # test_type logic if mode == "add": @@ -262,7 +256,7 @@ def show_test_form( test_description_placeholder = f"Inherited ({test_type_test_description})" # severity - test_suite_severity = test_suite["severity"] + test_suite_severity = test_suite.severity test_types_severity = selected_test_type_row["default_severity"] inherited_severity = test_suite_severity if test_suite_severity else test_types_severity @@ -273,25 +267,19 @@ def show_test_form( severity_index = severity_options.index(selected_test_def["severity"]) # general value parsing - entity_id = selected_test_def["id"] if mode == "edit" else "" - cat_test_id = selected_test_def["cat_test_id"] if mode == "edit" else "" - project_code = selected_test_def["project_code"] if mode == "edit" else project_code - table_groups_id = selected_test_def["table_groups_id"] if mode == "edit" else table_group["id"] - profile_run_id = selected_test_def["profile_run_id"] if mode == "edit" else "" - test_suite_name = selected_test_def["test_suite"] if mode == "edit" else test_suite["test_suite"] - test_suite_id = test_suite["id"] - test_action = empty_if_null(selected_test_def["test_action"]) if mode == "edit" else "" - schema_name = selected_test_def["schema_name"] if mode == "edit" else table_group["table_group_schema"] - table_name = empty_if_null(selected_test_def["table_name"]) if mode == "edit" else empty_if_null(str_table_name) + table_groups_id = selected_test_def["table_groups_id"] if mode == "edit" else table_group.id + test_suite_id = test_suite.id + schema_name = selected_test_def["schema_name"] if mode == "edit" else table_group.table_group_schema + table_name = empty_if_null(selected_test_def["table_name"]) if mode == "edit" else 
empty_if_null(table_name) skip_errors = selected_test_def["skip_errors"] or 0 if mode == "edit" else 0 - test_active = selected_test_def["test_active"] == "Y" if mode == "edit" else True - lock_refresh = selected_test_def["lock_refresh"] == "Y" if mode == "edit" else False + test_active = bool(selected_test_def["test_active"]) if mode == "edit" else True + lock_refresh = bool(selected_test_def["lock_refresh"]) if mode == "edit" else False test_definition_status = selected_test_def["test_definition_status"] if mode == "edit" else "" - check_result = selected_test_def["check_result"] if mode == "edit" else None - column_name = empty_if_null(selected_test_def["column_name"]) if mode == "edit" else empty_if_null(str_column_name) + column_name = empty_if_null(selected_test_def["column_name"]) if mode == "edit" else empty_if_null(column_name) last_auto_gen_date = empty_if_null(selected_test_def["last_auto_gen_date"]) if mode == "edit" else "" profiling_as_of_date = empty_if_null(selected_test_def["profiling_as_of_date"]) if mode == "edit" else "" profile_run_id = empty_if_null(selected_test_def["profile_run_id"]) if mode == "edit" else "" + # dynamic attributes custom_query = empty_if_null(selected_test_def["custom_query"]) if mode == "edit" else "" @@ -316,19 +304,16 @@ def show_test_form( match_groupby_names = empty_if_null(selected_test_def["match_groupby_names"]) if mode == "edit" else "" match_having_condition = empty_if_null(selected_test_def["match_having_condition"]) if mode == "edit" else "" window_days = selected_test_def["window_days"] or 0 if mode == "edit" else 0 - test_mode = empty_if_null(selected_test_def["test_mode"]) if mode == "edit" else "" # export_to_observability - test_suite_export_to_observability = test_suite["export_to_observability"] - inherited_export_to_observability = "Yes" if test_suite_export_to_observability == "Y" else "No" - + inherited_export_to_observability = "Yes" if test_suite.export_to_observability else "No" inherited_legend = f"Inherited ({inherited_export_to_observability})" export_to_observability_options = [inherited_legend, "Yes", "No"] if mode == "edit": - match selected_test_def["export_to_observability_raw"]: - case "N": + match selected_test_def["export_to_observability"]: + case False: export_to_observability = "No" - case "Y": + case True: export_to_observability = "Yes" case _: export_to_observability = inherited_legend @@ -336,9 +321,6 @@ def show_test_form( export_to_observability = inherited_legend export_to_observability_index = export_to_observability_options.index(export_to_observability) - # watch_level - watch_level = selected_test_def["watch_level"] if mode == "edit" else "WARN" - # dynamic attributes dynamic_attributes_raw = selected_test_type_row["default_parm_columns"] dynamic_attributes = dynamic_attributes_raw.split(",") @@ -371,18 +353,13 @@ def show_test_form( st.info(f"**Usage Notes:**\n\n{selected_test_type_row['usage_notes']}") left_column, right_column = st.columns([0.5, 0.5]) + left_column.text_input( + label="Test Suite Name", max_chars=200, value=test_suite.test_suite, disabled=True + ) test_definition = { - "id": entity_id, - "cat_test_id": cat_test_id, - "watch_level": watch_level, - "project_code": project_code, "table_groups_id": table_groups_id, - "profile_run_id": profile_run_id, "test_type": test_type, - "test_suite": left_column.text_input( - label="Test Suite Name", max_chars=200, value=test_suite_name, disabled=True - ), "test_suite_id": test_suite_id, "test_description": left_column.text_area( label="Test 
Description Override", @@ -392,15 +369,12 @@ def show_test_form( value=test_description, help=test_description_help, ), - "test_action": test_action, - "test_mode": test_mode, "lock_refresh": left_column.toggle( label="Lock Refresh", value=lock_refresh, help="Protects test parameters from being overwritten when tests in this Test Suite are regenerated.", ), "test_active": left_column.toggle(label="Test Active", value=test_active), - "check_result": check_result, "custom_query": custom_query, "baseline_ct": baseline_ct, "baseline_unique_ct": baseline_unique_ct, @@ -436,21 +410,25 @@ def show_test_form( # export_to_observability export_to_observability_help = "Send results to DataKitchen Observability - overrides Test Suite toggle" - test_definition["export_to_observability_raw"] = right_column.selectbox( + export_to_observability_select = right_column.selectbox( label="Send to Observability - Override", options=export_to_observability_options, index=export_to_observability_index, help=export_to_observability_help, ) + test_definition["export_to_observability"] = ( + True if export_to_observability_select == "Yes" else (False if export_to_observability_select == "No" else None) + ) # severity severity_help = "Urgency is defined by default for the Test Type, but can be overridden for all tests in the Test Suite, and ultimately here for each individual test." - test_definition["severity"] = right_column.selectbox( + severity_select = right_column.selectbox( label="Urgency Override", options=severity_options, index=severity_index, help=severity_help, ) + test_definition["severity"] = None if severity_select.startswith("Inherited") else severity_select if mode == "edit": columns = st.columns([0.5, 0.5]) @@ -463,7 +441,7 @@ def show_test_form( with container: testgen.link( href="profiling-runs:results", - params={"run_id": profile_run_id}, + params={"run_id": str(profile_run_id)}, label=formatted_time, open_new=True, ) @@ -634,7 +612,7 @@ def render_dynamic_attribute(attribute: str, container: DeltaGenerator): ) if validate: try: - test_definition_service.validate_test(test_definition) + validate_test(test_definition, table_group) bottom_right_column.success("Validation is successful.") except Exception as e: bottom_right_column.error(f"Test validation failed with error: {e}") @@ -650,42 +628,51 @@ def render_dynamic_attribute(attribute: str, container: DeltaGenerator): if submit: if validate_form(test_scope, test_type, test_definition, column_name_label): if mode == "edit": - test_definition_service.update(test_definition) - st.rerun() - else: - test_definition_service.add(test_definition) - st.rerun() + test_definition["id"] = selected_test_def["id"] + TestDefinition(**test_definition).save() + get_test_suite_columns.clear() + st.rerun() @st.dialog(title="Add Test") -def add_test_dialog(project_code, table_group, test_suite, str_table_name, str_column_name): - show_test_form("add", project_code, table_group, test_suite, str_table_name, str_column_name) +@with_database_session +def add_test_dialog(table_group, test_suite, str_table_name, str_column_name): + show_test_form("add", table_group, test_suite, str_table_name, str_column_name) @st.dialog(title="Edit Test") -def edit_test_dialog(project_code, table_group, test_suite, str_table_name, str_column_name, selected_test_def): - show_test_form("edit", project_code, table_group, test_suite, str_table_name, str_column_name, selected_test_def) +@with_database_session +def edit_test_dialog(table_group, test_suite, str_table_name, str_column_name, 
selected_test_def): + show_test_form("edit", table_group, test_suite, str_table_name, str_column_name, selected_test_def) @st.dialog(title="Copy/Move Tests") -def copy_move_test_dialog(project_code, origin_table_group, origin_test_suite, selected_test_definitions): +@with_database_session +def copy_move_test_dialog( + project_code: str, + origin_table_group: TableGroup, + origin_test_suite: TestSuite, + selected_test_definitions: list[dict], +): st.text(f"Selected tests: {len(selected_test_definitions)}") group_filter_column, suite_filter_column, table_filter_column = st.columns([.33, .33, .33], vertical_alignment="bottom") with group_filter_column: - table_groups_df = run_table_groups_lookup_query(project_code) + table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups_df = to_dataframe(table_groups, TableGroupMinimal.columns()) target_table_group_id = testgen.select( options=table_groups_df, value_column="id", display_column="table_groups_name", - default_value=origin_table_group["id"], + default_value=origin_table_group.id, required=True, label="Target Table Group", ) with suite_filter_column: - test_suites_df = run_test_suite_lookup_query(target_table_group_id) + test_suites = TestSuite.select_minimal_where(TestSuite.table_groups_id == target_table_group_id) + test_suites_df = to_dataframe(test_suites, TestSuiteMinimal.columns()) target_test_suite_id = testgen.select( options=test_suites_df, value_column="id", @@ -695,36 +682,33 @@ def copy_move_test_dialog(project_code, origin_table_group, origin_test_suite, s label="Target Test Suite", ) - target_table_column = None - if target_test_suite_id == origin_test_suite["id"]: + target_table_name = None + target_column_name = None + if target_test_suite_id == origin_test_suite.id: with table_filter_column: - columns_df = get_test_suite_columns(origin_test_suite["id"]) - table_name = testgen.select( + columns_df = get_test_suite_columns(origin_test_suite.id) + target_table_name = testgen.select( options=list(columns_df["table_name"].unique()), value_column="table_name", default_value=None, required=True, label="Target Table Name", ) - column_options = list(columns_df.loc[columns_df["table_name"] == table_name]["column_name"].unique()) - column_name = testgen.select( + column_options = list(columns_df.loc[columns_df["table_name"] == target_table_name]["column_name"].unique()) + target_column_name = testgen.select( options=column_options, default_value=None, required=True, label="Column Name", - disabled=not table_name, + disabled=not target_table_name, ) - target_table_column = { - "table_name": table_name, - "column_name":column_name - } movable_test_definitions = [] if target_table_group_id and target_test_suite_id: - collision_test_definitions = test_definition_service.get_test_definitions_collision(selected_test_definitions, target_table_group_id, target_test_suite_id) + collision_test_definitions = get_test_definitions_collision(selected_test_definitions, target_table_group_id, target_test_suite_id) if not collision_test_definitions.empty: - unlocked = collision_test_definitions[collision_test_definitions["lock_refresh"] == "N"] - locked = collision_test_definitions[collision_test_definitions["lock_refresh"] == "Y"] + unlocked = collision_test_definitions[collision_test_definitions["lock_refresh"] == False] + locked = collision_test_definitions[collision_test_definitions["lock_refresh"] == True] locked_tuples = [ (test["table_name"], test["column_name"], test["test_type"]) for test in 
locked.iterrows() ] movable_test_definitions = [ test for test in selected_test_definitions if (test["table_name"], test["column_name"], test["test_type"]) not in locked_tuples ] @@ -750,16 +734,19 @@ def copy_move_test_dialog(project_code, origin_table_group, origin_test_suite, s use_container_width=True, ) + test_definition_ids = [item["id"] for item in movable_test_definitions] if move: - test_definition_service.move(movable_test_definitions, target_table_group_id, target_test_suite_id, target_table_column) + TestDefinition.move(test_definition_ids, target_table_group_id, target_test_suite_id, target_table_name, target_column_name) success_message = "Test Definitions have been moved." st.success(success_message) + get_test_suite_columns.clear() time.sleep(1) st.rerun() elif copy: - test_definition_service.copy(movable_test_definitions, target_table_group_id, target_test_suite_id, target_table_column) + TestDefinition.copy(test_definition_ids, target_table_group_id, target_test_suite_id, target_table_name, target_column_name) success_message = "Test Definitions have been copied." st.success(success_message) + get_test_suite_columns.clear() time.sleep(1) st.rerun() @@ -776,36 +763,17 @@ def validate_form(test_scope, test_type, test_definition, column_name_label): return True -def validate_test_definition_uniqueness(test_definition, test_scope): - record_count = test_definition_service.check_test_definition_uniqueness(test_definition) - if record_count > 0: - match test_scope: - case "column": - message_bit = "and Column Name " - case "referential": - message_bit = "and Column Names " - case "custom": - message_bit = "and Test Focus " - case "table": - message_bit = "" - case _: - message_bit = "" - - return f"Validation error: the combination of Table Name, Test Type {message_bit}must be unique within a Test Suite" - - def prompt_for_test_type(): col0, col1, col2, col3, col4, col5 = st.columns([0.1, 0.2, 0.2, 0.2, 0.2, 0.1]) col0.write("Show Types") - boo_show_referential = col1.checkbox(":green[⧉] Referential", True) - boo_show_table = col2.checkbox(":green[⊞] Table", True) - boo_show_column = col3.checkbox(":green[â‰Ŗ] Column", True) - boo_show_custom = col4.checkbox(":green[⛭] Custom", True) - - df = run_test_type_lookup_query(str_test_type=None, boo_show_referential=boo_show_referential, - boo_show_table=boo_show_table, boo_show_column=boo_show_column, - boo_show_custom=boo_show_custom) + + df = run_test_type_lookup_query( + include_referential=col1.checkbox(":green[⧉] Referential", True), + include_table=col2.checkbox(":green[⊞] Table", True), + include_column=col3.checkbox(":green[â‰Ŗ] Column", True), + include_custom=col4.checkbox(":green[⛭] Custom", True), + ) lst_choices = df["select_name"].tolist() str_selected = selectbox("Test Type", lst_choices) @@ -819,6 +787,7 @@ def prompt_for_test_type(): @st.dialog(title="Unlock Test Definition") +@with_database_session def confirm_unlocking_test_definition(test_definitions: list[dict]): unlock_confirmed, set_unlock_confirmed = temp_value("test-definitions:confirm-unlock-tests") @@ -852,24 +821,23 @@ def confirm_unlocking_test_definition(test_definitions: list[dict]): def update_test_definition(selected, attribute, value, message): result = None test_definition_ids = [row["id"] for row in selected if "id" in row] - test_definition_service.update_attribute(test_definition_ids, attribute, value) + TestDefinition.set_status_attribute(attribute, test_definition_ids, value) st.success(message) return result def show_test_defs_grid( - 
str_project_code, str_test_suite, str_table_name, str_column_name, str_test_type, do_multi_select, export_container, - str_table_groups_id + test_suite: TestSuite, + table_name: str | None, + column_name: str | None, + test_type: str | None, + do_multi_select: bool, + export_container: DeltaGenerator, + table_group: TableGroupMinimal, ): with st.container(): with st.spinner("Loading data ..."): - df = test_definition_service.get_test_definitions( - str_project_code, str_test_suite, str_table_name, str_column_name, str_test_type - ) - date_service.accommodate_dataframe_to_timezone(df, st.session_state) - - for col in df.select_dtypes(include=["datetime"]).columns: - df[col] = df[col].astype(str).replace("NaT", "") + df = get_test_definitions(test_suite, table_name, column_name, test_type) lst_show_columns = [ "schema_name", @@ -879,7 +847,7 @@ def show_test_defs_grid( "test_active_display", "lock_refresh_display", "urgency", - "export_to_observability", + "export_to_observability_display", "profiling_as_of_date", "last_manual_update", ] @@ -895,8 +863,9 @@ def show_test_defs_grid( "Based on Profiling", "Last Manual Update", ] - - # show_column_headers = list(map(snake_case_to_title_case, show_column_headers)) + # Multiselect checkboxes do not display correctly if the dataframe column order does not start with the first displayed column -_- + columns = [lst_show_columns[0]] + [ col for col in df.columns.to_list() if col != lst_show_columns[0] ] + df = df.reindex(columns=columns) dct_selected_row = fm.render_grid_select( df, @@ -919,7 +888,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: download_dialog( dialog_title="Download Excel Report", file_content_func=get_excel_report_data, - args=(str_project_code, str_test_suite, data), + args=(test_suite, data), ) with popover_container.container(key="tg--export-popover"): @@ -987,7 +956,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: view_profiling_button( selected_row["column_name"], selected_row["table_name"], - str_table_groups_id, + str(table_group.id), ) with right_column: @@ -996,17 +965,16 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: return dct_selected_row +@with_database_session def get_excel_report_data( update_progress: PROGRESS_UPDATE_TYPE, - project_code: str, - test_suite: str, + test_suite: TestSuite, data: pd.DataFrame | None = None, ) -> FILE_DATA_TYPE: if data is not None: data = data.copy() else: - data = test_definition_service.get_test_definitions(project_code, test_suite) - date_service.accommodate_dataframe_to_timezone(data, st.session_state) + data = get_test_definitions(test_suite) for key in ["test_active_display", "lock_refresh_display"]: data[key] = data[key].apply(lambda val: val if val == "Yes" else None) @@ -1033,7 +1001,7 @@ def get_excel_report_data( return get_excel_file_data( data, "Test Definitions", - details={"Test suite": test_suite}, + details={"Test suite": test_suite.test_suite}, columns=columns, update_progress=update_progress, ) @@ -1068,47 +1036,160 @@ def generate_test_defs_help(str_test_type): @st.cache_data(show_spinner=False) -def run_test_type_lookup_query(str_test_type=None, boo_show_referential=True, boo_show_table=True, - boo_show_column=True, boo_show_custom=True): - str_schema = st.session_state["dbschema"] - return dq.run_test_type_lookup_query(str_schema, str_test_type, boo_show_referential, boo_show_table, - boo_show_column, boo_show_custom) +def run_test_type_lookup_query( + test_type: str | None = None, + 
include_referential: bool = True, + include_table: bool = True, + include_column: bool = True, + include_custom: bool = True, +) -> pd.DataFrame: + scope_map = { + "referential": include_referential, + "table": include_table, + "column": include_column, + "custom": include_custom, + } + scopes = [ key for key, include in scope_map.items() if include ] + + query = f""" + SELECT + tt.id, tt.test_type, tt.id as cat_test_id, + tt.test_name_short, tt.test_name_long, tt.test_description, + tt.measure_uom, COALESCE(tt.measure_uom_description, '') as measure_uom_description, + tt.default_parm_columns, tt.default_severity, + tt.run_type, tt.test_scope, tt.dq_dimension, tt.threshold_description, + tt.column_name_prompt, tt.column_name_help, + tt.default_parm_prompts, tt.default_parm_help, tt.usage_notes, + CASE tt.test_scope + WHEN 'referential' THEN '⧉ ' + WHEN 'custom' THEN '⛭ ' + WHEN 'table' THEN '⊞ ' + WHEN 'column' THEN 'â‰Ŗ ' + ELSE '? ' + END + || tt.test_name_short + || ': ' + || lower(tt.test_name_long) + || CASE + WHEN tt.selection_criteria > '' THEN ' [auto-generated]' + ELSE '' + END as select_name + FROM test_types tt + WHERE tt.active = 'Y' + {"AND tt.test_type = :test_type" if test_type else ""} + {"AND tt.test_scope in :scopes" if scopes else ""} + ORDER BY + CASE tt.test_scope + WHEN 'referential' THEN 1 + WHEN 'custom' THEN 2 + WHEN 'table' THEN 3 + WHEN 'column' THEN 4 + ELSE 5 + END, + tt.test_name_short; + """ + params = { + "test_type": test_type, + "scopes": tuple(scopes), + } + return fetch_df_from_db(query, params) @st.cache_data(show_spinner=False) -def run_table_groups_lookup_query(str_project_code, str_connection_id=None, table_group_id=None): - str_schema = st.session_state["dbschema"] - return dq.run_table_groups_lookup_query(str_schema, str_project_code, str_connection_id, table_group_id) +def get_test_suite_columns(test_suite_id: str) -> pd.DataFrame: + results = TestDefinition.select_minimal_where( + TestDefinition.test_suite_id == test_suite_id, + order_by = (asc(TestDefinition.table_name), asc(TestDefinition.column_name)), + ) + return to_dataframe(results, TestDefinitionMinimal.columns()) + + +def get_test_definitions( + test_suite: TestSuite, + table_name: str | None = None, + column_name: str | None = None, + test_type: str | None = None, +) -> pd.DataFrame: + clauses = [TestDefinition.test_suite_id == test_suite.id] + if table_name: + clauses.append(TestDefinition.table_name == table_name) + if column_name: + clauses.append(TestDefinition.column_name.ilike(column_name)) + if test_type: + clauses.append(TestDefinition.test_type == test_type) + test_definitions = TestDefinition.select_where(*clauses) + + df = to_dataframe(test_definitions, TestDefinitionSummary.columns()) + date_service.accommodate_dataframe_to_timezone(df, st.session_state) + for key in ["id", "table_groups_id", "profile_run_id", "test_suite_id"]: + df[key] = df[key].apply(lambda value: str(value)) + + df["test_active_display"] = df["test_active"].apply(lambda value: "Yes" if value else "No") + df["lock_refresh_display"] = df["lock_refresh"].apply(lambda value: "Yes" if value else "No") + df["urgency"] = df.apply(lambda row: row["severity"] or test_suite.severity or row["default_severity"], axis=1) + df["final_test_description"] = df.apply(lambda row: row["test_description"] or row["default_test_description"], axis=1) + df["export_uom"] = df.apply(lambda row: row["measure_uom_description"] or row["measure_uom"], axis=1) + + def get_export_to_observability_display(value: str) -> str: + if value 
is not None: + return "Yes" if value else "No" + return f"Inherited ({'Yes' if test_suite.export_to_observability else 'No'})" + df["export_to_observability_display"] = df["export_to_observability"].apply(get_export_to_observability_display) + + for col in df.select_dtypes(include=["datetime"]).columns: + df[col] = df[col].astype(str).replace("NaT", "") + + return df + + +def get_test_definitions_collision( + test_definitions: list[dict], + target_table_group_id: str, + target_test_suite_id: str, +) -> pd.DataFrame: + results = TestDefinition.select_minimal_where( + TestDefinition.table_groups_id == target_table_group_id, + TestDefinition.test_suite_id == target_test_suite_id, + TestDefinition.last_auto_gen_date.isnot(None), + tuple_(TestDefinition.table_name, TestDefinition.column_name, TestDefinition.test_type).in_( + [(item["table_name"], item["column_name"], item["test_type"]) for item in test_definitions] + ), + ) + return to_dataframe(results, TestDefinitionMinimal.columns()) + + +def get_column_names(table_groups_id: str, table_name: str) -> list[str]: + results = fetch_all_from_db( + """ + SELECT column_name + FROM data_column_chars + WHERE table_groups_id = :table_groups_id + AND table_name = :table_name + AND drop_date IS NULL + ORDER BY column_name + """, + { + "table_groups_id": table_groups_id, + "table_name": table_name, + }, + ) + return [ row.column_name for row in results ] -@st.cache_data(show_spinner=False) -def get_test_suite_columns(test_suite_id: str) -> pd.DataFrame: - schema: str = st.session_state["dbschema"] - sql = f""" - SELECT d.table_name, d.column_name, t.test_name_short, d.test_type - FROM {schema}.test_definitions d - LEFT JOIN {schema}.test_types as t on t.test_type = d.test_type - WHERE test_suite_id = '{test_suite_id}' - ORDER BY table_name, column_name; - """ - return db.retrieve_data(sql) +def validate_test(test_definition, table_group: TableGroupMinimal): + schema = test_definition["schema_name"] + table_name = test_definition["table_name"] + if test_definition["test_type"] == "Condition_Flag": + condition = test_definition["custom_query"] + query = f""" + SELECT + COALESCE(CAST(SUM(CASE WHEN {condition} THEN 1 ELSE 0 END) AS VARCHAR(1000) ) || '|' ,'|') + FROM {schema}.{table_name}; + """ + else: + query = test_definition["custom_query"] + query = query.replace("{DATA_SCHEMA}", schema) -@st.cache_data(show_spinner=False) -def run_test_suite_lookup_query(str_table_groups_id, test_suite_name=None): - str_schema = st.session_state["dbschema"] - return dq.run_test_suite_lookup_by_tgroup_query(str_schema, str_table_groups_id, test_suite_name) - - -def get_column_names(table_groups_id: str, table_name: str) -> list: - schema: str = st.session_state["dbschema"] - sql = f""" - SELECT column_name - FROM {schema}.data_column_chars - WHERE table_groups_id = '{table_groups_id}'::UUID - AND table_name = '{table_name}' - AND drop_date IS NULL - ORDER BY column_name - """ - df = db.retrieve_data(sql) - return df["column_name"].tolist() + connection = Connection.get_by_table_group(table_group.id) + fetch_from_target_db(connection, query) diff --git a/testgen/ui/views/test_results.py b/testgen/ui/views/test_results.py index 67335b4e..6f03fe83 100644 --- a/testgen/ui/views/test_results.py +++ b/testgen/ui/views/test_results.py @@ -2,7 +2,8 @@ from functools import partial from io import BytesIO from itertools import zip_longest -from operator import itemgetter +from operator import attrgetter +from uuid import UUID import pandas as pd import plotly.express as px @@ 
-10,12 +11,14 @@ import streamlit as st from streamlit.delta_generator import DeltaGenerator -import testgen.ui.services.database_service as db import testgen.ui.services.form_service as fm -import testgen.ui.services.query_service as dq from testgen.commands.run_rollup_scores import run_test_rollup_scoring_queries from testgen.common import date_service from testgen.common.mixpanel_service import MixpanelService +from testgen.common.models import with_database_session +from testgen.common.models.test_definition import TestDefinition +from testgen.common.models.test_run import TestRun +from testgen.common.models.test_suite import TestSuite from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -27,14 +30,16 @@ from testgen.ui.components.widgets.page import css_class, flex_row_end from testgen.ui.navigation.page import Page from testgen.ui.pdf.test_result_report import create_report -from testgen.ui.services import project_service, test_definition_service, test_results_service, user_session_service +from testgen.ui.queries import test_result_queries +from testgen.ui.queries.source_data_queries import get_test_issue_source_data, get_test_issue_source_data_custom +from testgen.ui.services import user_session_service +from testgen.ui.services.database_service import execute_db_query, fetch_df_from_db, fetch_one_from_db from testgen.ui.services.string_service import empty_if_null, snake_case_to_title_case from testgen.ui.session import session from testgen.ui.views.dialogs.profiling_results_dialog import view_profiling_button from testgen.ui.views.test_definitions import show_test_form_by_id -from testgen.utils import friendly_score, is_uuid4 +from testgen.utils import friendly_score -ALWAYS_SPIN = False PAGE_PATH = "test-runs:results" @@ -53,31 +58,32 @@ def render( test_type: str | None = None, table_name: str | None = None, column_name: str | None = None, + action: str | None = None, **_kwargs, ) -> None: - run_df = get_run_by_id(run_id) - if run_df.empty: + run = TestRun.get_minimal(run_id) + if not run: self.router.navigate_with_warning( f"Test run with ID '{run_id}' does not exist. 
Redirecting to list of Test Runs ...", "test-runs", ) return - run_date = date_service.get_timezoned_timestamp(st.session_state, run_df["test_starttime"]) - project_service.set_sidebar_project(run_df["project_code"]) + run_date = date_service.get_timezoned_timestamp(st.session_state, run.test_starttime) + session.set_sidebar_project(run.project_code) testgen.page_header( "Test Results", "view-testgen-test-results", breadcrumbs=[ - { "label": "Test Runs", "path": "test-runs", "params": { "project_code": run_df["project_code"] } }, - { "label": f"{run_df['test_suite']} | {run_date}" }, + { "label": "Test Runs", "path": "test-runs", "params": { "project_code": run.project_code } }, + { "label": f"{run.test_suite} | {run_date}" }, ], ) - summary_column, score_column, actions_column = st.columns([.4, .2, .4], vertical_alignment="bottom") - table_filter_column, column_filter_column, test_type_filter_column, status_filter_column, sort_column, export_button_column = st.columns( - [.175, .175, .2, .2, .1, .15], vertical_alignment="bottom" + summary_column, score_column, actions_column, export_button_column = st.columns([.3, .15, .3, .15], vertical_alignment="bottom") + status_filter_column, table_filter_column, column_filter_column, test_type_filter_column, action_filter_column, sort_column = st.columns( + [.175, .2, .2, .175, .15, .1], vertical_alignment="bottom" ) testgen.flex_row_end(actions_column) @@ -87,8 +93,21 @@ def render( tests_summary = get_test_result_summary(run_id) testgen.summary_bar(items=tests_summary, height=20, width=800) - with score_column: - render_score(run_df["project_code"], run_id) + with status_filter_column: + status_options = [ + "Failed + Warning", + "Failed", + "Warning", + "Passed", + "Error", + ] + status = testgen.select( + options=status_options, + default_value=status or "Failed + Warning", + bind_to_query="status", + bind_empty_value=True, + label="Status", + ) run_columns_df = get_test_run_columns(run_id) with table_filter_column: @@ -126,21 +145,14 @@ def render( label="Test Type", ) - with status_filter_column: - status_options = [ - "Failed + Warning", - "Failed", - "Warning", - "Passed", - "Error", - ] - status = testgen.select( - options=status_options, - default_value=status or "Failed + Warning", - bind_to_query="status", - bind_empty_value=True, - label="Status", + with action_filter_column: + action = testgen.select( + options=["✓ Confirmed", "✘ Dismissed", "🔇 Muted", "â†Šī¸Ž No Action"], + default_value=action, + bind_to_query="action", + label="Action", ) + action = action.split(" ", 1)[1] if action else None with sort_column: sortable_columns = ( @@ -163,30 +175,31 @@ def render( case "Failed + Warning": status = ["Failed", "Warning"] case "Failed": - status = "Failed" + status = ["Failed"] case "Warning": - status = "Warning" + status = ["Warning"] case "Passed": - status = "Passed" + status = ["Passed"] case "Error": - status = "Error" + status = ["Error"] # Display main grid and retrieve selection selected = show_result_detail( run_id, run_date, - run_df["test_suite"], + run.test_suite_id, export_button_column, status, test_type, table_name, column_name, + action, sorting_columns, do_multi_select, ) # Need to render toolbar buttons after grid, so selection status is maintained - affected_cached_functions = [get_test_disposition, get_test_results] + affected_cached_functions = [get_test_disposition, test_result_queries.get_test_results] disposition_actions = [ { "icon": "✓", "help": "Confirm this issue as relevant for this run", "status": 
"Confirmed" }, @@ -216,6 +229,11 @@ def render( lst_cached_functions=affected_cached_functions, ) + # Needs to be after all data loading/updating + # Otherwise the database session is lost for any queries after the fragment -_- + with score_column: + render_score(run.project_code, run_id) + # Help Links st.markdown( "[Help on Test Types](https://docs.datakitchen.io/article/dataops-testgen-help/testgen-test-types)" @@ -223,13 +241,14 @@ def render( @st.fragment +@with_database_session def render_score(project_code: str, run_id: str): - run_df = get_run_by_id(run_id) + run = TestRun.get_minimal(run_id) testgen.flex_row_center() with st.container(): testgen.caption("Score", "text-align: center;") testgen.text( - friendly_score(run_df["dq_score_test_run"]) or "--", + friendly_score(run.dq_score_test_run) or "--", "font-size: 28px;", ) @@ -240,12 +259,12 @@ def render_score(project_code: str, run_id: str): style="color: var(--secondary-text-color);", icon="autorenew", icon_size=22, - tooltip=f"Recalculate scores for run {'and table group' if run_df["is_latest_run"] else ''}", + tooltip=f"Recalculate scores for run {'and table group' if run.is_latest_run else ''}", on_click=partial( refresh_score, project_code, run_id, - run_df["table_groups_id"] if run_df["is_latest_run"] else None, + run.table_groups_id if run.is_latest_run else None, ), ) @@ -255,77 +274,35 @@ def refresh_score(project_code: str, run_id: str, table_group_id: str | None) -> st.cache_data.clear() -@st.cache_data(show_spinner=ALWAYS_SPIN) -def get_run_by_id(test_run_id: str) -> pd.Series: - if not is_uuid4(test_run_id): - return pd.Series() - - schema: str = st.session_state["dbschema"] - sql = f""" - SELECT tr.test_starttime, - ts.test_suite, - ts.project_code, - ts.table_groups_id::VARCHAR, - tr.dq_score_test_run, - CASE WHEN tr.id = ts.last_complete_test_run_id THEN true ELSE false END AS is_latest_run - FROM {schema}.test_runs tr - INNER JOIN {schema}.test_suites ts ON tr.test_suite_id = ts.id - WHERE tr.id = '{test_run_id}'::UUID; - """ - df = db.retrieve_data(sql) - if not df.empty: - return df.iloc[0] - else: - return pd.Series() - - @st.cache_data(show_spinner=False) def get_test_run_columns(test_run_id: str) -> pd.DataFrame: - schema: str = st.session_state["dbschema"] - sql = f""" + query = """ SELECT r.table_name as table_name, r.column_names AS column_name, t.test_name_short as test_name_short, t.test_type as test_type - FROM {schema}.test_results r - LEFT JOIN {schema}.test_types t ON t.test_type = r.test_type - WHERE test_run_id = '{test_run_id}' + FROM test_results r + LEFT JOIN test_types t ON t.test_type = r.test_type + WHERE test_run_id = :test_run_id ORDER BY table_name, column_names; """ - return db.retrieve_data(sql) + return fetch_df_from_db(query, {"test_run_id": test_run_id}) @st.cache_data(show_spinner=False) -def get_test_results( - run_id: str, - test_status: str | list[str] | None = None, - test_type_id: str | None = None, - table_name: str | None = None, - column_name: str | None = None, - sorting_columns: list[str] | None = None, -) -> pd.DataFrame: - schema: str = st.session_state["dbschema"] - return test_results_service.get_test_results(schema, run_id, test_status, test_type_id, table_name, column_name, sorting_columns) - - -@st.cache_data(show_spinner=False) -def get_test_disposition(str_run_id): - str_schema = st.session_state["dbschema"] - str_sql = f""" - SELECT id::VARCHAR, disposition - FROM {str_schema}.test_results - WHERE test_run_id = '{str_run_id}'; +def 
get_test_disposition(test_run_id: str) -> pd.DataFrame: + query = """ + SELECT id::VARCHAR, disposition + FROM test_results + WHERE test_run_id = :test_run_id; """ - - df = db.retrieve_data(str_sql) - + df = fetch_df_from_db(query, {"test_run_id": test_run_id}) dct_replace = {"Confirmed": "✓", "Dismissed": "✘", "Inactive": "🔇", "Passed": ""} df["action"] = df["disposition"].replace(dct_replace) return df[["id", "action"]] -@st.cache_data(show_spinner=ALWAYS_SPIN) -def get_test_result_summary(run_id): - schema = st.session_state["dbschema"] - sql = f""" +@st.cache_data(show_spinner=False) +def get_test_result_summary(test_run_id: str) -> list[dict]: + query = """ SELECT SUM( CASE WHEN COALESCE(test_results.disposition, 'Confirmed') = 'Confirmed' @@ -360,77 +337,49 @@ def get_test_result_summary(run_id): ELSE 0 END ) as dismissed_ct - FROM {schema}.test_runs - LEFT JOIN {schema}.test_results ON ( - test_runs.id = test_results.test_run_id - ) - WHERE test_runs.id = '{run_id}'::UUID; + FROM test_runs + LEFT JOIN test_results ON ( + test_runs.id = test_results.test_run_id + ) + WHERE test_runs.id = :test_run_id; """ - df = db.retrieve_data(sql) + result = fetch_one_from_db(query, {"test_run_id": test_run_id}) return [ - { "label": "Passed", "value": int(df.at[0, "passed_ct"]), "color": "green" }, - { "label": "Warning", "value": int(df.at[0, "warning_ct"]), "color": "yellow" }, - { "label": "Failed", "value": int(df.at[0, "failed_ct"]), "color": "red" }, - { "label": "Error", "value": int(df.at[0, "error_ct"]), "color": "brown" }, - { "label": "Dismissed", "value": int(df.at[0, "dismissed_ct"]), "color": "grey" }, + { "label": "Passed", "value": result.passed_ct, "color": "green" }, + { "label": "Warning", "value": result.warning_ct, "color": "yellow" }, + { "label": "Failed", "value": result.failed_ct, "color": "red" }, + { "label": "Error", "value": result.error_ct, "color": "brown" }, + { "label": "Dismissed", "value": result.dismissed_ct, "color": "grey" }, ] -@st.cache_data(show_spinner=ALWAYS_SPIN) -def get_test_definition(str_test_def_id): - str_schema = st.session_state["dbschema"] - return test_definition_service.get_test_definition(str_schema, str_test_def_id) - - -@st.cache_data(show_spinner=False) -def do_source_data_lookup(selected_row): - schema = st.session_state["dbschema"] - return test_results_service.do_source_data_lookup(schema, selected_row, limit=500) - - -@st.cache_data(show_spinner=False) -def do_source_data_lookup_custom(selected_row): - schema = st.session_state["dbschema"] - return test_results_service.do_source_data_lookup_custom(schema, selected_row, limit=500) - - -@st.cache_data(show_spinner=False) -def get_test_result_history(selected_row): - schema = st.session_state["dbschema"] - return test_results_service.get_test_result_history(schema, selected_row) - - -def show_test_def_detail(test_def_id: str): - def readable_boolean(v: typing.Literal["Y", "N"]): - return "Yes" if v == "Y" else "No" - - if not test_def_id: +def show_test_def_detail(test_definition_id: str, test_suite: TestSuite): + def readable_boolean(v: bool): + return "Yes" if v else "No" + + if not test_definition_id: st.warning("Test definition no longer exists.") return + + test_definition = TestDefinition.get(test_definition_id) - df = get_test_definition(test_def_id) - - specs = [] - if not df.empty: - test_definition = df.iloc[0] - row = test_definition - - dynamic_attributes_labels_raw: str = test_definition["default_parm_prompts"] + if test_definition: + dynamic_attributes_labels_raw = 
test_definition.default_parm_prompts if not dynamic_attributes_labels_raw: dynamic_attributes_labels_raw = "" dynamic_attributes_labels = dynamic_attributes_labels_raw.split(",") - dynamic_attributes_raw: str = test_definition["default_parm_columns"] + dynamic_attributes_raw = test_definition.default_parm_columns dynamic_attributes_fields = dynamic_attributes_raw.split(",") - dynamic_attributes_values = itemgetter(*dynamic_attributes_fields)(test_definition)\ + dynamic_attributes_values = attrgetter(*dynamic_attributes_fields)(test_definition)\ if len(dynamic_attributes_fields) > 1\ - else (test_definition[dynamic_attributes_fields[0]],) + else (getattr(test_definition, dynamic_attributes_fields[0]),) for field_name in dynamic_attributes_fields[len(dynamic_attributes_labels):]: dynamic_attributes_labels.append(snake_case_to_title_case(field_name)) - dynamic_attributes_help_raw: str = test_definition["default_parm_help"] + dynamic_attributes_help_raw = test_definition.default_parm_help if not dynamic_attributes_help_raw: dynamic_attributes_help_raw = "" dynamic_attributes_help = dynamic_attributes_help_raw.split("|") @@ -439,22 +388,21 @@ def readable_boolean(v: typing.Literal["Y", "N"]): "test_definition_summary", props={ "test_definition": { - "schema": test_definition["schema_name"], - "test_suite_name": test_definition["test_suite_name"], - "table_name": test_definition["table_name"], - "test_focus": test_definition["column_name"], - "export_to_observability": readable_boolean(test_definition["export_to_observability"]) - if test_definition["export_to_observability"] - else f"Inherited ({readable_boolean(test_definition["default_export_to_observability"])})", - "severity": test_definition["severity"] or f"Test Default ({test_definition['default_severity']})", - "locked": readable_boolean(test_definition["lock_refresh"]), - "active": readable_boolean(test_definition["test_active"]), - "status": test_definition["status"], - "usage_notes": test_definition["usage_notes"], - "last_manual_update": test_definition["last_manual_update"].isoformat() - if test_definition["last_manual_update"] + "schema": test_definition.schema_name, + "test_suite_name": test_suite.test_suite, + "table_name": test_definition.table_name, + "test_focus": test_definition.column_name, + "export_to_observability": readable_boolean(test_definition.export_to_observability) + if test_definition.export_to_observability is not None + else f"Inherited ({readable_boolean(test_suite.export_to_observability)})", + "severity": test_definition.severity or f"Test Default ({test_definition.default_severity})", + "locked": readable_boolean(test_definition.lock_refresh), + "active": readable_boolean(test_definition.test_active), + "usage_notes": test_definition.usage_notes, + "last_manual_update": test_definition.last_manual_update.isoformat() + if test_definition.last_manual_update else None, - "custom_query": test_definition["custom_query"] + "custom_query": test_definition.custom_query if "custom_query" in dynamic_attributes_fields else None, "attributes": [ @@ -474,26 +422,32 @@ def readable_boolean(v: typing.Literal["Y", "N"]): def show_result_detail( run_id: str, run_date: str, - test_suite: str, + test_suite_id: UUID, export_container: DeltaGenerator, - test_status: str | None = None, + test_statuses: list[str] | None = None, test_type_id: str | None = None, table_name: str | None = None, column_name: str | None = None, + action: typing.Literal["Confirmed", "Dismissed", "Muted", "No Action"] | None = None, sorting_columns: 
list[str] | None = None, do_multi_select: bool = False, ): with st.container(): with st.spinner("Loading data ..."): # Retrieve test results (always cached, action as null) - df = get_test_results(run_id, test_status, test_type_id, table_name, column_name, sorting_columns) + df = test_result_queries.get_test_results(run_id, test_statuses, test_type_id, table_name, column_name, action, sorting_columns) # Retrieve disposition action (cache refreshed) df_action = get_test_disposition(run_id) + # Update action from disposition df + action_map = df_action.set_index("id")["action"].to_dict() + df["action"] = df["test_result_id"].map(action_map).fillna(df["action"]) # Update action from disposition df action_map = df_action.set_index("id")["action"].to_dict() df["action"] = df["test_result_id"].map(action_map).fillna(df["action"]) + test_suite = TestSuite.get_minimal(test_suite_id) + lst_show_columns = [ "table_name", "column_names", @@ -536,7 +490,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: download_dialog( dialog_title="Download Excel Report", file_content_func=get_excel_report_data, - args=(test_suite, run_date, run_id, data), + args=(test_suite.test_suite, run_date, run_id, data), ) with popover_container.container(key="tg--export-popover"): @@ -553,7 +507,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: st.markdown(":orange[Select a record to see more information.]") else: selected_row = selected_rows[0] - dfh = get_test_result_history(selected_row) + dfh = test_result_queries.get_test_result_history(selected_row) show_hist_columns = ["test_date", "threshold_value", "result_measure", "result_status"] time_columns = ["test_date"] @@ -639,10 +593,11 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None: else: write_history_graph(dfh) with ut_tab2: - show_test_def_detail(selected_row["test_definition_id_current"]) + show_test_def_detail(selected_row["test_definition_id_current"], test_suite) return selected_rows +@with_database_session def get_excel_report_data( update_progress: PROGRESS_UPDATE_TYPE, test_suite: str, @@ -651,7 +606,7 @@ def get_excel_report_data( data: pd.DataFrame | None = None, ) -> FILE_DATA_TYPE: if data is None: - data = get_test_results(run_id) + data = test_result_queries.get_test_results(run_id) columns = { "schema_name": {"header": "Schema"}, @@ -764,14 +719,14 @@ def do_disposition_update(selected, str_new_status): elif len(selected) == 1: str_which = f"of one result to {str_new_status}" - str_schema = st.session_state["dbschema"] - if not dq.update_result_disposition(selected, str_schema, str_new_status): + if not update_result_disposition(selected, str_new_status): str_result = f":red[**The update {str_which} did not succeed.**]" return str_result @st.dialog(title="Source Data") +@with_database_session def source_data_dialog(selected_row): st.markdown(f"#### {selected_row['test_name_short']}") st.caption(selected_row["test_description"]) @@ -784,13 +739,13 @@ def source_data_dialog(selected_row): with st.spinner("Retrieving source data..."): if selected_row["test_type"] == "CUSTOM": - bad_data_status, bad_data_msg, query, df_bad = do_source_data_lookup_custom(selected_row) + bad_data_status, bad_data_msg, _, df_bad = get_test_issue_source_data_custom(selected_row, limit=500) else: - bad_data_status, bad_data_msg, query, df_bad = do_source_data_lookup(selected_row) + bad_data_status, bad_data_msg, _, df_bad = get_test_issue_source_data(selected_row, limit=500) if bad_data_status in {"ND", "NA"}: 
st.info(bad_data_msg) elif bad_data_status == "ERR": - st.error(f"{bad_data_msg}\n\n{query}") + st.error(bad_data_msg) elif df_bad is None: st.error("An unknown error was encountered.") else: @@ -814,7 +769,7 @@ def view_edit_test(button_container, test_definition_id): def get_report_file_data(update_progress, tr_data) -> FILE_DATA_TYPE: tr_id = tr_data["test_result_id"][:8] - tr_time = pd.Timestamp(tr_data["test_time"]).strftime("%Y%m%d_%H%M%S") + tr_time = pd.Timestamp(tr_data["test_date"]).strftime("%Y%m%d_%H%M%S") file_name = f"testgen_test_issue_report_{tr_id}_{tr_time}.pdf" with BytesIO() as buffer: @@ -822,3 +777,53 @@ def get_report_file_data(update_progress, tr_data) -> FILE_DATA_TYPE: update_progress(1.0) buffer.seek(0) return file_name, "application/pdf", buffer.read() + + +def update_result_disposition( + selected: list[dict], + disposition: typing.Literal["Confirmed", "Dismissed", "Inactive", "No Decision"], +): + test_result_ids = [row["test_result_id"] for row in selected] + + execute_db_query( + """ + WITH selects + AS (SELECT UNNEST(ARRAY [:test_result_ids]) AS selected_id) + UPDATE test_results + SET disposition = NULLIF(:disposition, 'No Decision') + FROM test_results r + INNER JOIN selects s + ON (r.id = s.selected_id::UUID) + WHERE r.id = test_results.id + AND r.result_status != 'Passed'; + """, + { + "test_result_ids": test_result_ids, + "disposition": disposition, + }, + ) + + execute_db_query( + """ + WITH selects + AS (SELECT UNNEST(ARRAY [:test_result_ids]) AS selected_id) + UPDATE test_definitions + SET test_active = :test_active, + last_manual_update = CURRENT_TIMESTAMP AT TIME ZONE 'UTC', + lock_refresh = :lock_refresh + FROM test_definitions d + INNER JOIN test_results r + ON (d.id = r.test_definition_id) + INNER JOIN selects s + ON (r.id = s.selected_id::UUID) + WHERE d.id = test_definitions.id + AND r.result_status != 'Passed'; + """, + { + "test_result_ids": test_result_ids, + "test_active": "N" if disposition == "Inactive" else "Y", + "lock_refresh": "Y" if disposition == "Inactive" else "N", + }, + ) + + return True diff --git a/testgen/ui/views/test_runs.py b/testgen/ui/views/test_runs.py index 0b50d649..8f2ebd59 100644 --- a/testgen/ui/views/test_runs.py +++ b/testgen/ui/views/test_runs.py @@ -1,25 +1,27 @@ +import json import logging import typing +from collections.abc import Iterable from functools import partial -import pandas as pd import streamlit as st import testgen.common.process_service as process_service -import testgen.ui.services.database_service as db import testgen.ui.services.form_service as fm -import testgen.ui.services.query_service as dq from testgen.common.models import with_database_session +from testgen.common.models.project import Project +from testgen.common.models.table_group import TableGroup, TableGroupMinimal +from testgen.common.models.test_run import TestRun +from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets import testgen_component from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page -from testgen.ui.queries import project_queries, test_run_queries from testgen.ui.services import user_session_service from testgen.ui.session import session, temp_value from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog from testgen.ui.views.dialogs.run_tests_dialog import run_tests_dialog -from testgen.utils import friendly_score, to_int +from testgen.utils import 
friendly_score, to_dataframe, to_int PAGE_SIZE = 50 PAGE_ICON = "labs" @@ -55,7 +57,9 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit group_filter_column, suite_filter_column, actions_column = st.columns([.3, .3, .4], vertical_alignment="bottom") with group_filter_column: - table_groups_df = get_db_table_group_choices(project_code) + table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups_df = to_dataframe(table_groups, TableGroupMinimal.columns()) + table_groups_df["id"] = table_groups_df["id"].apply(lambda x: str(x)) table_group_id = testgen.select( options=table_groups_df, value_column="id", @@ -67,7 +71,12 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit ) with suite_filter_column: - test_suites_df = get_db_test_suite_choices(project_code, table_group_id) + clauses = [TestSuite.project_code == project_code] + if table_group_id: + clauses.append(TestSuite.table_groups_id == table_group_id) + test_suites = TestSuite.select_where(*clauses) + test_suites_df = to_dataframe(test_suites, TestSuite.columns()) + test_suites_df["id"] = test_suites_df["id"].apply(lambda x: str(x)) test_suite_id = testgen.select( options=test_suites_df, value_column="id", @@ -99,16 +108,25 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit testgen.whitespace(0.5) list_container = st.container() - test_runs_df = get_db_test_runs(project_code, table_group_id, test_suite_id) - page_index = testgen.paginator(count=len(test_runs_df), page_size=PAGE_SIZE) - test_runs_df["dq_score_testing"] = test_runs_df["dq_score_testing"].map(lambda score: friendly_score(score)) - paginated_df = test_runs_df[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)] + with st.spinner("Loading data ..."): + test_runs = TestRun.select_summary(project_code, table_group_id, test_suite_id) + + paginated = [] + if run_count := len(test_runs): + page_index = testgen.paginator(count=run_count, page_size=PAGE_SIZE) + test_runs = [ + { + **row.to_dict(json_safe=True), + "dq_score_testing": friendly_score(row.dq_score_testing), + } for row in test_runs + ] + paginated = test_runs[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)] with list_container: testgen_component( "test_runs", props={ - "items": paginated_df.to_json(orient="records"), + "items": json.dumps(paginated), "permissions": { "can_run": user_can_run, "can_edit": user_can_run, @@ -126,18 +144,19 @@ class TestRunScheduleDialog(ScheduleDialog): title = "Test Run Schedules" arg_label = "Test Suite" job_key = "run-tests" - test_suites: pd.DataFrame | None = None + test_suites: Iterable[TestSuiteMinimal] | None = None def init(self) -> None: - self.test_suites = get_db_test_suite_choices(self.project_code) + self.test_suites = TestSuite.select_minimal_where(TestSuite.project_code == self.project_code) def get_arg_value(self, job): return job.kwargs["test_suite_key"] def arg_value_input(self) -> tuple[bool, list[typing.Any], dict[str, typing.Any]]: + test_suites_df = to_dataframe(self.test_suites, TestSuiteMinimal.columns()) ts_name = testgen.select( label="Test Suite", - options=self.test_suites, + options=test_suites_df, value_column="test_suite", display_column="test_suite", required=True, @@ -147,13 +166,13 @@ def arg_value_input(self) -> tuple[bool, list[typing.Any], dict[str, typing.Any] def render_empty_state(project_code: str, user_can_run: bool) -> bool: - project_summary_df = project_queries.get_summary_by_code(project_code) - if 
project_summary_df["test_runs_ct"]: + project_summary = Project.get_summary(project_code) + if project_summary.test_run_count: return False label="No test runs yet" testgen.whitespace(5) - if not project_summary_df["connections_ct"]: + if not project_summary.connection_count: testgen.empty_state( label=label, icon=PAGE_ICON, @@ -162,7 +181,7 @@ def render_empty_state(project_code: str, user_can_run: bool) -> bool: link_href="connections", link_params={ "project_code": project_code }, ) - elif not project_summary_df["table_groups_ct"]: + elif not project_summary.table_group_count: testgen.empty_state( label=label, icon=PAGE_ICON, @@ -171,10 +190,10 @@ def render_empty_state(project_code: str, user_can_run: bool) -> bool: link_href="table-groups", link_params={ "project_code": project_code, - "connection_id": str(project_summary_df["default_connection_id"]), + "connection_id": str(project_summary.default_connection_id), } ) - elif not project_summary_df["test_suites_ct"] or not project_summary_df["test_definitions_ct"]: + elif not project_summary.test_suite_count or not project_summary.test_definition_count: testgen.empty_state( label=label, icon=PAGE_ICON, @@ -196,10 +215,11 @@ def render_empty_state(project_code: str, user_can_run: bool) -> bool: return True -def on_cancel_run(test_run: pd.Series) -> None: +def on_cancel_run(test_run: dict) -> None: process_status, process_message = process_service.kill_test_run(to_int(test_run["process_id"])) if process_status: - test_run_queries.update_status(test_run["test_run_id"], "Cancelled") + TestRun.update_status(test_run["test_run_id"], "Cancelled") + fm.reset_post_updates(str_message=f":{'green' if process_status else 'red'}[{process_message}]", as_toast=True) @@ -214,11 +234,11 @@ def on_delete_confirmed(*_args) -> None: "warning": "Any running processes will be canceled.", "confirmation": "Yes, cancel and delete the test runs.", } - if len(test_run_ids) == 1 and (test_run_id := test_run_ids[0]): + if len(test_run_ids) == 1: message = "Are you sure you want to delete the selected test run?" constraint["confirmation"] = "Yes, cancel and delete the test run." 
- if not test_run_queries.is_running(test_run_ids): + if not TestRun.has_running_process(test_run_ids): constraint = None result = None @@ -241,139 +261,16 @@ def on_delete_confirmed(*_args) -> None: if delete_confirmed(): try: with st.spinner("Deleting runs ..."): - test_runs = _get_db_test_runs(project_code, table_group_id, test_suite_id, test_runs_ids=test_run_ids) - for _, test_run in test_runs.iterrows(): - test_run_id = test_run["test_run_id"] - if test_run["status"] == "Running": - process_status, _ = process_service.kill_test_run(to_int(test_run["process_id"])) + test_runs = TestRun.select_summary(project_code, table_group_id, test_suite_id, test_run_ids) + for test_run in test_runs: + if test_run.status == "Running": + process_status, _ = process_service.kill_test_run(to_int(test_run.process_id)) if process_status: - test_run_queries.update_status(test_run_id, "Cancelled") - test_run_queries.cascade_delete_multiple_test_runs(test_run_ids) + TestRun.update_status(test_run.test_run_id, "Cancelled") + TestRun.cascade_delete(test_run_ids) st.rerun() except Exception: LOG.exception("Failed to delete test run") result = {"success": False, "message": "Unable to delete the test run, try again."} - - -@st.cache_data(show_spinner=False) -def run_test_suite_lookup_query(schema: str, project_code: str, table_groups_id: str | None = None) -> pd.DataFrame: - table_group_condition = f" AND test_suites.table_groups_id = '{table_groups_id}' " if table_groups_id else "" - sql = f""" - SELECT test_suites.id::VARCHAR(50), - test_suites.test_suite - FROM {schema}.test_suites - LEFT JOIN {schema}.table_groups ON test_suites.table_groups_id = table_groups.id - WHERE test_suites.project_code = '{project_code}' - {table_group_condition} - ORDER BY test_suites.test_suite - """ - return db.retrieve_data(sql) - - -@st.cache_data(show_spinner=False) -def get_db_table_group_choices(project_code: str) -> pd.DataFrame: - schema = st.session_state["dbschema"] - return dq.run_table_groups_lookup_query(schema, project_code) - - -@st.cache_data(show_spinner=False) -def get_db_test_suite_choices(project_code: str, table_groups_id: str | None = None) -> pd.DataFrame: - schema = st.session_state["dbschema"] - return run_test_suite_lookup_query(schema, project_code, table_groups_id) - - -@st.cache_data(show_spinner="Loading data ...") -def get_db_test_runs( - project_code: str, - table_groups_id: str | None = None, - test_suite_id: str | None = None, - test_runs_ids: list[str] | None = None, -) -> pd.DataFrame: - return _get_db_test_runs( - project_code, table_groups_id=table_groups_id, test_suite_id=test_suite_id, test_runs_ids=test_runs_ids - ) - - -def _get_db_test_runs( - project_code: str, - table_groups_id: str | None = None, - test_suite_id: str | None = None, - test_runs_ids: list[str] | None = None, -) -> pd.DataFrame: - schema = st.session_state["dbschema"] - table_group_condition = f" AND test_suites.table_groups_id = '{table_groups_id}' " if table_groups_id else "" - test_suite_condition = f" AND test_suites.id = '{test_suite_id}' " if test_suite_id else "" - - test_runs_conditions = "" - if test_runs_ids and len(test_runs_ids) > 0: - test_runs_ids_ = [f"'{run_id}'" for run_id in test_runs_ids] - test_runs_conditions = f" AND test_runs.id::VARCHAR IN ({', '.join(test_runs_ids_)})" - - sql = f""" - WITH run_results AS ( - SELECT test_run_id, - SUM( - CASE - WHEN COALESCE(disposition, 'Confirmed') = 'Confirmed' - AND result_status = 'Passed' THEN 1 - ELSE 0 - END - ) as passed_ct, - SUM( - CASE - WHEN 
COALESCE(disposition, 'Confirmed') = 'Confirmed' - AND result_status = 'Warning' THEN 1 - ELSE 0 - END - ) as warning_ct, - SUM( - CASE - WHEN COALESCE(disposition, 'Confirmed') = 'Confirmed' - AND result_status = 'Failed' THEN 1 - ELSE 0 - END - ) as failed_ct, - SUM( - CASE - WHEN COALESCE(disposition, 'Confirmed') = 'Confirmed' - AND result_status = 'Error' THEN 1 - ELSE 0 - END - ) as error_ct, - SUM( - CASE - WHEN COALESCE(disposition, 'Confirmed') IN ('Dismissed', 'Inactive') THEN 1 - ELSE 0 - END - ) as dismissed_ct - FROM {schema}.test_results - GROUP BY test_run_id - ) - SELECT test_runs.id::VARCHAR as test_run_id, - test_runs.test_starttime, - table_groups.table_groups_name, - test_suites.test_suite, - test_runs.status, - test_runs.duration, - test_runs.process_id, - test_runs.log_message, - test_runs.test_ct, - run_results.passed_ct, - run_results.warning_ct, - run_results.failed_ct, - run_results.error_ct, - run_results.dismissed_ct, - test_runs.dq_score_test_run AS dq_score_testing - FROM {schema}.test_runs - LEFT JOIN run_results ON (test_runs.id = run_results.test_run_id) - INNER JOIN {schema}.test_suites ON (test_runs.test_suite_id = test_suites.id) - INNER JOIN {schema}.table_groups ON (test_suites.table_groups_id = table_groups.id) - INNER JOIN {schema}.projects ON (test_suites.project_code = projects.project_code) - WHERE test_suites.project_code = '{project_code}' - {table_group_condition} - {test_suite_condition} - {test_runs_conditions} - ORDER BY test_runs.test_starttime DESC; - """ - - return db.retrieve_data(sql) + st.rerun(scope="fragment") + \ No newline at end of file diff --git a/testgen/ui/views/test_suites.py b/testgen/ui/views/test_suites.py index 0ae9b654..83169174 100644 --- a/testgen/ui/views/test_suites.py +++ b/testgen/ui/views/test_suites.py @@ -1,30 +1,30 @@ import time import typing +from collections.abc import Iterable from functools import partial import streamlit as st -import testgen.ui.services.form_service as fm -import testgen.ui.services.query_service as dq -import testgen.ui.services.test_suite_service as test_suite_service from testgen.commands.run_observability_exporter import export_test_results +from testgen.common.models import with_database_session +from testgen.common.models.project import Project +from testgen.common.models.table_group import TableGroup, TableGroupMinimal +from testgen.common.models.test_suite import TestSuite from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router -from testgen.ui.queries import project_queries from testgen.ui.services import user_session_service from testgen.ui.services.string_service import empty_if_null from testgen.ui.session import session from testgen.ui.views.dialogs.generate_tests_dialog import generate_tests_dialog from testgen.ui.views.dialogs.run_tests_dialog import run_tests_dialog from testgen.ui.views.test_runs import TestRunScheduleDialog -from testgen.utils import format_field +from testgen.utils import to_dataframe PAGE_ICON = "rule" PAGE_TITLE = "Test Suites" - class TestSuitesPage(Page): path = "test-suites" can_activate: typing.ClassVar = [ @@ -46,50 +46,22 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs "create-a-test-suite", ) - table_groups = get_db_table_group_choices(project_code) + table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) user_can_edit = 
user_session_service.user_can_edit() - test_suites = test_suite_service.get_by_project(project_code, table_group_id) - project_summary = project_queries.get_summary_by_code(project_code) - - test_suite_fields = [ - "id", - "connection_name", - "table_groups_name", - "test_suite", - "test_suite_description", - "test_ct", - "latest_run_start", - "latest_run_id", - "last_run_test_ct", - "last_run_passed_ct", - "last_run_warning_ct", - "last_run_failed_ct", - "last_run_error_ct", - "last_run_dismissed_ct", - "last_complete_profile_run_id", - ] + test_suites = TestSuite.select_summary(project_code, table_group_id) + project_summary = Project.get_summary(project_code) + testgen.testgen_component( "test_suites", props={ - "project_summary": { - "project_code": project_code, - "test_suites_ct": format_field(project_summary["test_suites_ct"]), - "connections_ct": format_field(project_summary["connections_ct"]), - "table_groups_ct": format_field(project_summary["table_groups_ct"]), - "default_connection_id": format_field(project_summary["default_connection_id"]), - "can_export_to_observability": format_field(project_summary["can_export_to_observability"]), - }, - "test_suites": [ - { - fieldname: format_field(test_suite[fieldname]) for fieldname in test_suite_fields - } for _, test_suite in test_suites.iterrows() - ], + "project_summary": project_summary.to_dict(json_safe=True), + "test_suites": [test_suite.to_dict(json_safe=True) for test_suite in test_suites], "table_group_filter_options": [ { - "value": format_field(table_group["id"]), - "label": format_field(table_group["table_groups_name"]), - "selected": str(table_group_id) == str(table_group["id"]), - } for _, table_group in table_groups.iterrows() + "value": str(table_group.id), + "label": table_group.table_groups_name, + "selected": str(table_group_id) == str(table_group.id), + } for table_group in table_groups ], "permissions": { "can_edit": user_can_edit, @@ -102,8 +74,8 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs "ExportActionClicked": observability_export_dialog, "EditActionClicked": partial(edit_test_suite_dialog, project_code, table_groups), "DeleteActionClicked": delete_test_suite_dialog, - "RunTestsClicked": lambda test_suite_id: run_tests_dialog(project_code, test_suite_service.get_by_id(test_suite_id)), - "GenerateTestsClicked": lambda test_suite_id: generate_tests_dialog(test_suite_service.get_by_id(test_suite_id)), + "RunTestsClicked": lambda test_suite_id: run_tests_dialog(project_code, TestSuite.get_minimal(test_suite_id)), + "GenerateTestsClicked": lambda test_suite_id: generate_tests_dialog(TestSuite.get_minimal(test_suite_id)), }, ) @@ -112,44 +84,41 @@ def on_test_suites_filtered(table_group_id: str | None = None) -> None: Router().set_query_params({ "table_group_id": table_group_id }) -@st.cache_data(show_spinner=False) -def get_db_table_group_choices(project_code): - schema = st.session_state["dbschema"] - return dq.run_table_groups_lookup_query(schema, project_code) - - @st.dialog(title="Add Test Suite") -def add_test_suite_dialog(project_code, table_groups_df): - show_test_suite("add", project_code, table_groups_df) +@with_database_session +def add_test_suite_dialog(project_code, table_groups): + show_test_suite("add", project_code, table_groups) @st.dialog(title="Edit Test Suite") -def edit_test_suite_dialog(project_code, table_groups_df, test_suite_id: str) -> None: - selected = test_suite_service.get_by_id(test_suite_id) - show_test_suite("edit", project_code, table_groups_df, 
selected) +@with_database_session +def edit_test_suite_dialog(project_code, table_groups, test_suite_id: str) -> None: + selected = TestSuite.get(test_suite_id) + show_test_suite("edit", project_code, table_groups, selected) -def show_test_suite(mode, project_code, table_groups_df, selected=None): +def show_test_suite(mode, project_code, table_groups: Iterable[TableGroupMinimal], selected: TestSuite | None = None): severity_options = ["Inherit", "Failed", "Warning"] selected_test_suite = selected if mode == "edit" else None - - if mode == "edit" and not selected_test_suite["severity"]: - selected_test_suite["severity"] = severity_options[0] + table_groups_df = to_dataframe(table_groups, TableGroupMinimal.columns()) # establish default values - test_suite_id = selected_test_suite["id"] if mode == "edit" else None - test_suite = empty_if_null(selected_test_suite["test_suite"]) if mode == "edit" else "" - connection_id = selected_test_suite["connection_id"] if mode == "edit" else None - table_groups_id = selected_test_suite["table_groups_id"] if mode == "edit" else None - test_suite_description = empty_if_null(selected_test_suite["test_suite_description"]) if mode == "edit" else "" - test_action = empty_if_null(selected_test_suite["test_action"]) if mode == "edit" else "" - severity_index = severity_options.index(selected_test_suite["severity"]) if mode == "edit" else 0 - export_to_observability = selected_test_suite["export_to_observability"] == "Y" if mode == "edit" else False - dq_score_exclude = selected_test_suite["dq_score_exclude"] if mode == "edit" else False - test_suite_schema = empty_if_null(selected_test_suite["test_suite_schema"]) if mode == "edit" else "" - component_key = empty_if_null(selected_test_suite["component_key"]) if mode == "edit" else "" - component_type = empty_if_null(selected_test_suite["component_type"]) if mode == "edit" else "dataset" - component_name = empty_if_null(selected_test_suite["component_name"]) if mode == "edit" else "" + test_suite_id = selected_test_suite.id if mode == "edit" else None + test_suite_name = empty_if_null(selected_test_suite.test_suite) if mode == "edit" else "" + connection_id = selected_test_suite.connection_id if mode == "edit" else None + table_groups_id = selected_test_suite.table_groups_id if mode == "edit" else None + test_suite_description = empty_if_null(selected_test_suite.test_suite_description) if mode == "edit" else "" + test_action = empty_if_null(selected_test_suite.test_action) if mode == "edit" else "" + try: + severity_index = severity_options.index(selected_test_suite.severity) if mode == "edit" else 0 + except ValueError: + severity_index = 0 + export_to_observability = selected_test_suite.export_to_observability if mode == "edit" else False + dq_score_exclude = selected_test_suite.dq_score_exclude if mode == "edit" else False + test_suite_schema = empty_if_null(selected_test_suite.test_suite_schema) if mode == "edit" else "" + component_key = empty_if_null(selected_test_suite.component_key) if mode == "edit" else "" + component_type = empty_if_null(selected_test_suite.component_type) if mode == "edit" else "dataset" + component_name = empty_if_null(selected_test_suite.component_name) if mode == "edit" else "" left_column, right_column = st.columns([0.50, 0.50]) expander = st.expander("", expanded=True) @@ -161,7 +130,7 @@ def show_test_suite(mode, project_code, table_groups_df, selected=None): "id": test_suite_id, "project_code": project_code, "test_suite": left_column.text_input( - label="Test Suite Name", 
max_chars=40, value=test_suite, disabled=(mode != "add") + label="Test Suite Name", max_chars=40, value=test_suite_name, disabled=(mode != "add") ), "connection_id": connection_id, "table_groups_id": table_groups_id, @@ -224,14 +193,18 @@ def show_test_suite(mode, project_code, table_groups_df, selected=None): f"Blank spaces not allowed in field 'Test Suite Name'. Use dash or underscore instead. i.e.: {proposed_test_suite}" ) else: + test_suite = selected or TestSuite() + for key, value in entity.items(): + setattr(test_suite, key, value) + if mode == "edit": - test_suite_service.edit(entity) + test_suite.save() else: selected_table_group_name = entity["table_groups_name"] selected_table_group = table_groups_df[table_groups_df["table_groups_name"] == selected_table_group_name].iloc[0] - entity["connection_id"] = selected_table_group["connection_id"] - entity["table_groups_id"] = selected_table_group["id"] - test_suite_service.add(entity) + test_suite.connection_id = int(selected_table_group["connection_id"]) + test_suite.table_groups_id = selected_table_group["id"] + test_suite.save() success_message = ( "Changes have been saved successfully. " if mode == "edit" @@ -243,37 +216,21 @@ def show_test_suite(mode, project_code, table_groups_df, selected=None): @st.dialog(title="Delete Test Suite") +@with_database_session def delete_test_suite_dialog(test_suite_id: str) -> None: - selected_test_suite = test_suite_service.get_by_id(test_suite_id) - test_suite_id = selected_test_suite["id"] - test_suite_name = selected_test_suite["test_suite"] - can_be_deleted = test_suite_service.cascade_delete([test_suite_id], dry_run=True) - - fm.render_html_list( - selected_test_suite, - [ - "id", - "test_suite", - "test_suite_description", - ], - "Test Suite Information", - int_data_width=700, - ) + selected_test_suite = TestSuite.get_minimal(test_suite_id) + test_suite_id = selected_test_suite.id + test_suite_name = selected_test_suite.test_suite + is_in_use = TestSuite.is_in_use([test_suite_id]) + + st.markdown(f"Are you sure you want to delete the test suite **{test_suite_name}**?") - if not can_be_deleted: - st.html( - """ -
- - This Test Suite has related data, which includes test definitions and may - include test results. If you proceed, all related data will be permanently deleted. - -
- Are you sure you want to proceed? -
- """ + if is_in_use: + st.warning( + """This Test Suite has related data, which may include test definitions and test results. + \nIf you proceed, all related data will be permanently deleted.""" ) - accept_cascade_delete = st.toggle("I accept deletion of this Test Suite and all related TestGen data.") + accept_cascade_delete = st.toggle(f"Yes, delete the test suite **{test_suite_name}** and related TestGen data.") with st.form("Delete Test Suite", clear_on_submit=True, border=False): delete = False @@ -282,15 +239,15 @@ def delete_test_suite_dialog(test_suite_id: str) -> None: delete = st.form_submit_button( "Delete", type="primary", - disabled=not can_be_deleted and not accept_cascade_delete, + disabled=is_in_use and not accept_cascade_delete, use_container_width=True, ) if delete: - if test_suite_service.are_test_suites_in_use([test_suite_id]): + if TestSuite.has_running_process([test_suite_id]): st.error("This Test Suite is in use by a running process and cannot be deleted.") else: - test_suite_service.cascade_delete([test_suite_id]) + TestSuite.cascade_delete([test_suite_id]) success_message = f"Test Suite {test_suite_name} has been deleted. " st.success(success_message) time.sleep(1) @@ -299,9 +256,9 @@ def delete_test_suite_dialog(test_suite_id: str) -> None: @st.dialog(title="Export to Observability") def observability_export_dialog(test_suite_id: str) -> None: - selected_test_suite = test_suite_service.get_by_id(test_suite_id) - project_key = selected_test_suite["project_code"] - test_suite_key = selected_test_suite["test_suite"] + selected_test_suite = TestSuite.get_minimal(test_suite_id) + project_key = selected_test_suite.project_code + test_suite_key = selected_test_suite.test_suite start_process_button_message = "Start" with st.container(): @@ -328,7 +285,7 @@ def observability_export_dialog(test_suite_id: str) -> None: status_container.info("Executing Export ...") try: - qty_of_exported_events = export_test_results(selected_test_suite["id"]) + qty_of_exported_events = export_test_results(selected_test_suite.id) status_container.empty() status_container.success( f"Process has successfully finished, {qty_of_exported_events} events have been exported." 
diff --git a/testgen/utils/__init__.py b/testgen/utils/__init__.py index a73b2770..3e864034 100644 --- a/testgen/utils/__init__.py +++ b/testgen/utils/__init__.py @@ -1,5 +1,8 @@ from __future__ import annotations +from collections.abc import Iterable +from datetime import UTC, datetime +from decimal import Decimal from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -10,7 +13,6 @@ from typing import Any, TypeVar from uuid import UUID -import numpy as np import pandas as pd import streamlit as st @@ -23,7 +25,26 @@ def to_int(value: float | int) -> int: return 0 +def to_dataframe( + data: Iterable[Any], + columns: list[str] | None = None, +) -> pd.DataFrame: + records = [] + for item in data: + if hasattr(item, "to_dict") and callable(item.to_dict): + row = item.to_dict() + elif hasattr(item, "__dict__"): + row = item.__dict__ + else: + row = dict(item) + records.append(row) + return pd.DataFrame.from_records(records, columns=columns) + + def is_uuid4(value: str) -> bool: + if isinstance(value, UUID): + return True + try: uuid = UUID(value, version=4) except Exception: @@ -45,24 +66,18 @@ def get_base_url() -> str: return urllib.parse.urlunparse([session.client.request.protocol, session.client.request.host, "", "", "", ""]) -def format_field(field: Any) -> Any: - defaults = { - float: 0.0, - int: 0, - } - if isinstance(field, UUID): - return str(field) - elif isinstance(field, pd.Timestamp): - return field.value / 1_000_000 - elif pd.isnull(field): - return defaults.get(type(field), None) - elif isinstance(field, np.integer): - return int(field) - elif isinstance(field, np.floating): - return float(field) - elif isinstance(field, np.bool_): - return bool(field) - return field +def make_json_safe(value: Any) -> str | bool | int | float | None: + if isinstance(value, UUID): + return str(value) + elif isinstance(value, datetime): + return int(value.replace(tzinfo=UTC).timestamp()) + elif isinstance(value, Decimal): + return float(value) + elif isinstance(value, list): + return [ make_json_safe(item) for item in value ] + elif isinstance(value, dict): + return { key: make_json_safe(value) for key, value in value.items() } + return value def chunk_queries(queries: list[str], join_string: str, max_query_length: int) -> list[str]: @@ -165,6 +180,7 @@ def format_score_card_breakdown(breakdown: list[dict], category: str) -> dict: "table_groups_id": str(row["table_groups_id"]) if row.get("table_groups_id") else None, "score": friendly_score(row["score"]), "impact": friendly_score_impact(row["impact"]), + "issue_ct": int(row["issue_ct"]), } for row in breakdown], } @@ -175,7 +191,10 @@ def format_score_card_issues(issues: list[dict], category: str) -> dict: columns.insert(0, "column") return { "columns": columns, - "items": issues, + "items": [{ + **row, + "time": int(row["time"]), + } for row in issues], } diff --git a/tests/unit/test_profiling_query.py b/tests/unit/test_profiling_query.py index 826faad1..6ca71ecc 100644 --- a/tests/unit/test_profiling_query.py +++ b/tests/unit/test_profiling_query.py @@ -14,15 +14,16 @@ def test_include_exclude_mask_basic(): profiling_query.parm_table_exclude_mask = "temp%,tmp%,raw_slot_utilization%,gps_product_step_change_log" # test run - query = profiling_query.GetDDFQuery() + query, _ = profiling_query.GetDDFQuery() # test assertions assert "SELECT 'dummy_project_code'" in query - assert r"AND ((c.table_name LIKE 'important%' ) OR (c.table_name LIKE '%useful%' ))" in query - assert ( - r"AND NOT ((c.table_name LIKE 'temp%' ) OR (c.table_name LIKE 'tmp%' ) OR 
(c.table_name LIKE 'raw\_slot\_utilization%' ) OR (c.table_name LIKE 'gps\_product\_step\_change\_log' ))" - in query - ) + assert r"""AND ( + (c.table_name LIKE 'important%' ) OR (c.table_name LIKE '%useful%' ) + )""" in query + assert r"""AND NOT ( + (c.table_name LIKE 'temp%' ) OR (c.table_name LIKE 'tmp%' ) OR (c.table_name LIKE 'raw\_slot\_utilization%' ) OR (c.table_name LIKE 'gps\_product\_step\_change\_log' ) + )""" in query @pytest.mark.unit @@ -37,13 +38,13 @@ def test_include_empty_exclude_mask(mask): profiling_query.parm_table_exclude_mask = "temp%,tmp%,raw_slot_utilization%,gps_product_step_change_log" # test run - query = profiling_query.GetDDFQuery() + query, _ = profiling_query.GetDDFQuery() + print(query) # test assertions - assert ( - r"AND NOT ((c.table_name LIKE 'temp%' ESCAPE '\\') OR (c.table_name LIKE 'tmp%' ESCAPE '\\') OR (c.table_name LIKE 'raw\\_slot\\_utilization%' ESCAPE '\\') OR (c.table_name LIKE 'gps\\_product\\_step\\_change\\_log' ESCAPE '\\')" - in query - ) + assert r"""AND NOT ( + (c.table_name LIKE 'temp%' ESCAPE '\\') OR (c.table_name LIKE 'tmp%' ESCAPE '\\') OR (c.table_name LIKE 'raw\\_slot\\_utilization%' ESCAPE '\\') OR (c.table_name LIKE 'gps\\_product\\_step\\_change\\_log' ESCAPE '\\') + )""" in query @pytest.mark.unit @@ -58,7 +59,10 @@ def test_include_empty_include_mask(mask): profiling_query.parm_table_exclude_mask = mask # test run - query = profiling_query.GetDDFQuery() + query, _ = profiling_query.GetDDFQuery() + print(query) # test assertions - assert r"AND ((c.table_name LIKE 'important%' ) OR (c.table_name LIKE '%useful[_]%' ))" in query + assert r"""AND ( + (c.table_name LIKE 'important%' ) OR (c.table_name LIKE '%useful[_]%' ) + )""" in query