{label_div}
{item_spans}
@@ -60,6 +59,53 @@ def summary_bar(
{caption_div}
""")
+
+
+def summary_counts(
+ items: list["SummaryItem"],
+ label: str | None = None,
+) -> None:
+ """
+ Testgen component to display summary counts.
+
+ # Parameters
+ :param items: list of dicts with value, label, and color
+ """
+
+ label_div = ""
+ item_spans = ""
+ caption_div = ""
+
+ if label:
+ label_div = f"""
+
+ {label}
+
+ """
+
+ item_divs = "".join(
+ [
+ f"""
+
+
+
+
{item["label"]}
+
{item["value"]}
+
+
+ """
+ for item in items
+ ]
+ )
+
+ st.html(f"""
+
+ {label_div}
+
+ {item_divs}
+
+
+ """)
class SummaryItem(typing.TypedDict):
diff --git a/testgen/ui/navigation/router.py b/testgen/ui/navigation/router.py
index ab968569..243b6916 100644
--- a/testgen/ui/navigation/router.py
+++ b/testgen/ui/navigation/router.py
@@ -30,6 +30,11 @@ def run(self) -> None:
current_page = st.navigation(streamlit_pages, position="hidden")
+ if not session.initialized:
+ # Clear cache on initial load or page refresh
+ st.cache_data.clear()
+ session.initialized = True
+
# This hack is needed because the auth cookie is not set if navigation happens immediately after login
# We have to navigate on the next run
if session.auth.logging_in:
diff --git a/testgen/ui/pdf/hygiene_issue_report.py b/testgen/ui/pdf/hygiene_issue_report.py
index de5addb6..ba5d97a1 100644
--- a/testgen/ui/pdf/hygiene_issue_report.py
+++ b/testgen/ui/pdf/hygiene_issue_report.py
@@ -109,7 +109,7 @@ def build_summary_table(document, hi_data):
("Profiling Date", profiling_timestamp, "Table Group", hi_data["table_groups_name"]),
("Database/Schema", hi_data["schema_name"], "Disposition", hi_data["disposition"] or "No Decision"),
- ("Table", hi_data["table_name"], "Column Type", hi_data["column_type"]),
+ ("Table", hi_data["table_name"], "Data Type", hi_data["db_data_type"]),
("Column", hi_data["column_name"], "Semantic Data Type", hi_data["functional_data_type"]),
(
"Column Tags",
diff --git a/testgen/ui/queries/profiling_queries.py b/testgen/ui/queries/profiling_queries.py
index 94e18287..89119ef9 100644
--- a/testgen/ui/queries/profiling_queries.py
+++ b/testgen/ui/queries/profiling_queries.py
@@ -86,7 +86,7 @@ def get_profiling_results(profiling_run_id: str, table_name: str | None = None,
table_groups_id::VARCHAR AS table_group_id,
-- Characteristics
general_type,
- column_type,
+ db_data_type,
functional_data_type,
datatype_suggestion,
-- Profile Run
@@ -347,7 +347,7 @@ def get_columns_by_condition(
column_chars.ordinal_position,
-- Characteristics
column_chars.general_type,
- column_chars.column_type,
+ column_chars.db_data_type,
column_chars.functional_data_type,
datatype_suggestion,
column_chars.add_date,
diff --git a/testgen/ui/queries/scoring_queries.py b/testgen/ui/queries/scoring_queries.py
index ea2b3cd8..38b7387e 100644
--- a/testgen/ui/queries/scoring_queries.py
+++ b/testgen/ui/queries/scoring_queries.py
@@ -35,7 +35,7 @@ def get_score_card_issue_reports(selected_issues: list["SelectedIssue"]) -> list
results.schema_name,
results.table_name,
results.column_name,
- results.column_type,
+ results.db_data_type,
groups.table_groups_name,
results.disposition,
results.profile_run_id::VARCHAR,
diff --git a/testgen/ui/queries/source_data_queries.py b/testgen/ui/queries/source_data_queries.py
index d632457b..45afd5d7 100644
--- a/testgen/ui/queries/source_data_queries.py
+++ b/testgen/ui/queries/source_data_queries.py
@@ -6,7 +6,7 @@
import streamlit as st
from testgen.common.clean_sql import ConcatColumnList
-from testgen.common.database.database_service import replace_params
+from testgen.common.database.database_service import get_flavor_service, replace_params
from testgen.common.models.connection import Connection, SQLFlavor
from testgen.common.models.test_definition import TestDefinition
from testgen.common.read_file import replace_templated_functions
@@ -26,9 +26,14 @@ def generate_lookup_query(test_id: str, detail_exp: str, column_names: list[str]
start_index += len("Columns: ")
column_names_str = detail_exp[start_index:]
columns = [col.strip() for col in column_names_str.split(",")]
- quote = "`" if sql_flavor == "databricks" else '"'
+ quote = get_flavor_service(sql_flavor).quote_character
queries = [
- f"SELECT '{column}' AS column_name, MAX({quote}{column}{quote}) AS max_date_available FROM {{TARGET_SCHEMA}}.{{TABLE_NAME}}"
+ f"""
+ SELECT
+ '{column}' AS column_name,
+ MAX({quote}{column}{quote}) AS max_date_available
+ FROM {quote}{{TARGET_SCHEMA}}{quote}.{quote}{{TABLE_NAME}}{quote}
+ """
for column in columns
]
sql_query = " UNION ALL ".join(queries) + " ORDER BY max_date_available DESC;"
@@ -62,7 +67,7 @@ def generate_lookup_query(test_id: str, detail_exp: str, column_names: list[str]
lookup_query = replace_params(lookup_query, params)
lookup_query = replace_templated_functions(lookup_query, lookup_data.sql_flavor)
return lookup_query
-
+
@st.cache_data(show_spinner=False)
def get_hygiene_issue_source_data(
@@ -98,7 +103,7 @@ def get_test_issue_source_query(issue_data: dict) -> str:
lookup_data = _get_lookup_data(issue_data["table_groups_id"], issue_data["test_type_id"], "Test Results")
if not lookup_data or not lookup_data.lookup_query:
return None
-
+
test_definition = TestDefinition.get(issue_data["test_definition_id_current"])
if not test_definition:
return None
@@ -107,6 +112,7 @@ def get_test_issue_source_query(issue_data: dict) -> str:
"TARGET_SCHEMA": issue_data["schema_name"],
"TABLE_NAME": issue_data["table_name"],
"COLUMN_NAME": issue_data["column_names"],
+ "COLUMN_TYPE": issue_data["column_type"],
"TEST_DATE": str(issue_data["test_date"]),
"CUSTOM_QUERY": test_definition.custom_query,
"BASELINE_VALUE": test_definition.baseline_value,
@@ -146,7 +152,7 @@ def get_test_issue_source_data(
test_definition = TestDefinition.get(issue_data["test_definition_id_current"])
if not test_definition:
return "NA", "Test definition no longer exists.", None, None
-
+
lookup_query = get_test_issue_source_query(issue_data)
if not lookup_query:
return "NA", "Source data lookup is not available for this test.", None, None
@@ -189,7 +195,7 @@ def get_test_issue_source_data_custom(
test_definition = TestDefinition.get(issue_data["test_definition_id_current"])
if not test_definition:
return "NA", "Test definition no longer exists.", None, None
-
+
lookup_query = get_test_issue_source_query_custom(issue_data)
if not lookup_query:
return "NA", "Source data lookup is not available for this test.", None, None
@@ -249,7 +255,7 @@ def _get_lookup_data_custom(
) -> LookupData | None:
result = fetch_one_from_db(
"""
- SELECT
+ SELECT
d.custom_query as lookup_query
FROM test_definitions d
WHERE d.id = :test_definition_id;
diff --git a/testgen/ui/queries/table_group_queries.py b/testgen/ui/queries/table_group_queries.py
index eac18e06..c698212a 100644
--- a/testgen/ui/queries/table_group_queries.py
+++ b/testgen/ui/queries/table_group_queries.py
@@ -3,6 +3,7 @@
from sqlalchemy.engine import Row
from testgen.commands.queries.profiling_query import CProfilingSQL
+from testgen.common.database.database_service import get_flavor_service
from testgen.common.models.connection import Connection
from testgen.common.models.table_group import TableGroup
from testgen.ui.services.database_service import fetch_from_target_db
@@ -47,14 +48,17 @@ def get_table_group_preview(
)
if verify_table_access:
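+ # Resolve identifier quoting and TOP/LIMIT handling from the flavor service once, before the loop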
+ schema_name = table_group_preview["schema"]
+ flavor_service = get_flavor_service(connection.sql_flavor)
+ quote = flavor_service.quote_character
for table_name in table_group_preview["tables"].keys():
try:
results = fetch_from_target_db(
connection,
(
- f"SELECT 1 FROM {table_group_preview['schema']}.{table_name} LIMIT 1"
- if connection.sql_flavor != "mssql"
- else f"SELECT TOP 1 * FROM {table_group_preview['schema']}.{table_name}"
+ f"SELECT 1 FROM {quote}{schema_name}{quote}.{quote}{table_name}{quote} LIMIT 1"
+ if not flavor_service.use_top
+ else f"SELECT TOP 1 * FROM {quote}{schema_name}{quote}.{quote}{table_name}{quote}"
),
)
except Exception as error:
diff --git a/testgen/ui/queries/test_result_queries.py b/testgen/ui/queries/test_result_queries.py
index 52a51767..f11abea6 100644
--- a/testgen/ui/queries/test_result_queries.py
+++ b/testgen/ui/queries/test_result_queries.py
@@ -70,6 +70,7 @@ def get_test_results(
-- These are used in the PDF report
tt.threshold_description, tt.usage_notes, r.test_time,
dcc.description as column_description,
+ dcc.column_type as column_type,
COALESCE(dcc.critical_data_element, dtc.critical_data_element) as critical_data_element,
COALESCE(dcc.data_source, dtc.data_source, tg.data_source) as data_source,
COALESCE(dcc.source_system, dtc.source_system, tg.source_system) as source_system,
diff --git a/testgen/ui/services/database_service.py b/testgen/ui/services/database_service.py
index a094bc84..cf5c7280 100644
--- a/testgen/ui/services/database_service.py
+++ b/testgen/ui/services/database_service.py
@@ -53,12 +53,15 @@ def fetch_one_from_db(query: str, params: dict | None = None) -> RowMapping | No
return result._mapping if result else None
-def fetch_from_target_db(connection: Connection, query: str, params: dict | None = None) -> list[Row]:
+def fetch_from_target_db(connection: Connection, query: str, params: dict | None = None) -> list[Row]:
flavor_service = get_flavor_service(connection.sql_flavor)
flavor_service.init(connection.to_dict())
- connection_string = flavor_service.get_connection_string()
- connect_args = flavor_service.get_connect_args()
- engine = create_engine(connection_string, connect_args=connect_args)
+
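+ # Build the engine with flavor-specific engine kwargs in addition to the connect args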
+ engine = create_engine(
+ flavor_service.get_connection_string(),
+ connect_args=flavor_service.get_connect_args(),
+ **flavor_service.get_engine_args(),
+ )
with engine.connect() as connection:
cursor: CursorResult = connection.execute(text(query), params)
diff --git a/testgen/ui/services/form_service.py b/testgen/ui/services/form_service.py
index d9e7223c..2d9e99f3 100644
--- a/testgen/ui/services/form_service.py
+++ b/testgen/ui/services/form_service.py
@@ -9,6 +9,7 @@
from pandas.api.types import is_datetime64_any_dtype
from st_aggrid import AgGrid, ColumnsAutoSizeMode, DataReturnMode, GridOptionsBuilder, GridUpdateMode, JsCode
+from testgen.ui.components import widgets as testgen
from testgen.ui.navigation.router import Router
"""
@@ -142,30 +143,26 @@ def render_html_list(dct_row, lst_columns, str_section_header=None, int_data_wid
def render_grid_select(
df: pd.DataFrame,
- show_columns,
- str_prompt=None,
- int_height=400,
- do_multi_select: bool | None = None,
+ columns: list[str],
+ column_headers: list[str] | None = None,
+ id_column: str | None = None,
selection_mode: typing.Literal["single", "multiple", "disabled"] = "single",
- show_column_headers=None,
- render_highlights=True,
- bind_to_query_name: str | None = None,
- bind_to_query_prop: str | None = None,
+ page_size: int = 500,
+ reset_pagination: bool = False,
+ bind_to_query: bool = False,
+ render_highlights: bool = True,
key: str = "aggrid",
-):
+) -> tuple[list[dict], dict]:
"""
- :param do_multi_select: DEPRECATED. boolean to choose between single
- or multiple selection.
:param selection_mode: one of single, multiple or disabled. defaults
to single.
- :param bind_to_query_name: name of the query param where to bind the
- selected row.
- :param bind_to_query_prop: name of the property of the selected row
- which value will be set in the query param.
+ :param bind_to_query: whether to bind the selected row and page to
+ query params.
:param key: Streamlit cache key for the grid. required when binding
selection to query.
"""
- show_prompt(str_prompt)
+ if selection_mode != "disabled" and not id_column:
+ raise ValueError("id_column is required when using 'single' or 'multiple' selection mode")
# Set grid formatting
cellstyle_jscode = JsCode(
@@ -253,39 +250,62 @@ def render_grid_select(
rendering_counter = st.session_state.get(f"{key}_counter") or 0
previous_dataframe = st.session_state.get(f"{key}_dataframe")
- df = df.copy()
if previous_dataframe is not None:
data_changed = not df.equals(previous_dataframe)
- dct_col_to_header = dict(zip(show_columns, show_column_headers, strict=True)) if show_column_headers else None
+ page_changed = st.session_state.get(f"{key}_page_change", False)
+ if page_changed:
+ st.session_state[f"{key}_page_change"] = False
- gb = GridOptionsBuilder.from_dataframe(df)
- selection_mode_ = selection_mode
- if do_multi_select is not None:
- selection_mode_ = "multiple" if do_multi_select else "single"
+ grid_container = st.container()
+ selected_column, paginator_column = st.columns([.5, .5])
+ with paginator_column:
+ def on_page_change():
+ st.session_state[f"{key}_page_change"] = True
+
+ page_index = testgen.paginator(
+ count=len(df),
+ page_size=page_size,
+ page_index=0 if reset_pagination else None,
+ bind_to_query="page" if bind_to_query else None,
+ on_change=on_page_change,
+ key=f"{key}_paginator",
+ )
+ # Prevent flickering data when filters are changed (which triggers 2 reruns - one from filter and another from paginator)
+ page_index = 0 if reset_pagination else page_index
+ paginated_df = df.iloc[page_size * page_index : page_size * (page_index + 1)]
+
+ dct_col_to_header = dict(zip(columns, column_headers, strict=True)) if column_headers else None
+
+ gb = GridOptionsBuilder.from_dataframe(paginated_df)
pre_selected_rows: typing.Any = {}
- if bind_to_query_name and bind_to_query_prop:
- bound_value = st.query_params.get(bind_to_query_name)
- bound_items = df[df[bind_to_query_prop] == bound_value]
+ if selection_mode == "single" and bind_to_query:
+ bound_value = st.query_params.get("selected")
+ bound_items = paginated_df[paginated_df[id_column] == bound_value]
if len(bound_items) > 0:
# https://github.com/PablocFonseca/streamlit-aggrid/issues/207#issuecomment-1793039564
- pre_selected_rows = {str(bound_items.iloc[0][bind_to_query_prop]): True}
+ pre_selected_rows = {str(bound_value): True}
else:
- if data_changed and st.query_params.get(bind_to_query_name):
+ if data_changed and st.query_params.get("selected"):
rendering_counter += 1
- Router().set_query_params({bind_to_query_name: None})
+ Router().set_query_params({"selected": None})
+
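+ # Persist multi-select state in session so selections survive pagination and reruns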
+ selection = set()
+ if selection_mode == "multiple":
+ selection = st.session_state.get(f"{key}_multiselection", set())
+ pre_selected_rows = {str(item): True for item in selection}
gb.configure_selection(
- selection_mode=selection_mode_,
- use_checkbox=selection_mode_ == "multiple",
+ selection_mode=selection_mode,
+ use_checkbox=selection_mode == "multiple",
pre_selected_rows=pre_selected_rows,
)
- if bind_to_query_prop:
- gb.configure_grid_options(getRowId=JsCode(f"""function(row) {{ return row.data['{bind_to_query_prop}'] }}"""))
+ if id_column:
+ gb.configure_grid_options(getRowId=JsCode(f"function(row) {{ return row.data['{id_column}'] }}"))
- all_columns = list(df.columns)
+ all_columns = list(paginated_df.columns)
for column in all_columns:
# Define common kwargs for all columns: NOTE THAT FIRST COLUMN HOLDS CHECKBOX AND SHOULD BE SHOWN!
@@ -293,9 +313,11 @@ def render_grid_select(
common_kwargs = {
"field": column,
"header_name": str_header if str_header else ut_prettify_header(column),
- "hide": column not in show_columns,
- "headerCheckboxSelection": selection_mode_ == "multiple" and column == show_columns[0],
- "headerCheckboxSelectionFilteredOnly": selection_mode_ == "multiple" and column == show_columns[0],
+ "hide": column not in columns,
+ "headerCheckboxSelection": selection_mode == "multiple" and column == columns[0],
+ "headerCheckboxSelectionFilteredOnly": selection_mode == "multiple" and column == columns[0],
+ "sortable": False,
+ "filter": False,
}
highlight_kwargs = {
"cellStyle": cellstyle_jscode,
@@ -307,8 +329,8 @@ def render_grid_select(
}
# Check if the column is a date-time column
- if is_datetime64_any_dtype(df[column]):
- if (df[column].dt.time == pd.Timestamp("00:00:00").time()).all():
+ if is_datetime64_any_dtype(paginated_df[column]):
+ if (paginated_df[column].dt.time == pd.Timestamp("00:00:00").time()).all():
format_string = "yyyy-MM-dd"
else:
format_string = "yyyy-MM-dd HH:mm"
@@ -327,49 +349,66 @@ def render_grid_select(
# Apply configuration using kwargs
gb.configure_column(**all_kwargs)
- grid_options = gb.build()
-
# Render Grid: custom_css fixes spacing bug and tightens empty space at top of grid
- grid_data = AgGrid(
- df,
- gridOptions=grid_options,
- theme="balham",
- enable_enterprise_modules=False,
- allow_unsafe_jscode=True,
- update_mode=GridUpdateMode.NO_UPDATE,
- update_on=["selectionChanged"],
- data_return_mode=DataReturnMode.FILTERED_AND_SORTED,
- columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS,
- height=int_height,
- custom_css={
- "#gridToolBar": {
- "padding-bottom": "0px !important",
- },
- ".ag-row-hover .ag-cell.status-tag": {
- "border-color": "var(--ag-row-hover-color) !important",
+ with grid_container:
+ grid_options = gb.build()
+ grid_data = AgGrid(
+ paginated_df.copy(),
+ gridOptions=grid_options,
+ theme="balham",
+ enable_enterprise_modules=False,
+ allow_unsafe_jscode=True,
+ update_mode=GridUpdateMode.NO_UPDATE,
+ update_on=["selectionChanged"],
+ data_return_mode=DataReturnMode.FILTERED_AND_SORTED,
+ columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS,
+ height=400,
+ custom_css={
+ "#gridToolBar": {
+ "padding-bottom": "0px !important",
+ },
+ ".ag-row-hover .ag-cell.status-tag": {
+ "border-color": "var(--ag-row-hover-color) !important",
+ },
+ ".ag-row-selected .ag-cell.status-tag": {
+ "border-color": "var(--ag-selected-row-background-color) !important",
+ },
},
- ".ag-row-selected .ag-cell.status-tag": {
- "border-color": "var(--ag-selected-row-background-color) !important",
- },
- },
- key=f"{key}_{selection_mode_}_{rendering_counter}",
- reload_data=data_changed,
- )
+ key=f"{key}_{page_index}_{selection_mode}_{rendering_counter}",
+ reload_data=data_changed,
+ )
st.session_state[f"{key}_counter"] = rendering_counter
st.session_state[f"{key}_dataframe"] = df
- selected_rows = grid_data["selected_rows"]
- if len(selected_rows) > 0:
- if bind_to_query_name and bind_to_query_prop:
- Router().set_query_params({bind_to_query_name: selected_rows[0][bind_to_query_prop]})
-
+ if selection_mode != "disabled":
+ selected_rows = grid_data["selected_rows"]
+ # During page change, there are 2 reruns and the first one does not return the selected rows
+ # So we ignore that run to prevent flickering the selected count
+ if not page_changed:
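+ # Replace this page's previously stored selections with the rows currently selected in the grid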
+ selection.difference_update(paginated_df[id_column].to_list())
+ selection.update([row[id_column] for row in selected_rows])
+ st.session_state[f"{key}_multiselection"] = selection
+
+ if selection:
# We need to get the data from the original dataframe
# Otherwise changes to the dataframe (e.g., editing the current selection) do not get reflected in the returned rows
# Adding "modelUpdated" to AgGrid(update_on=...) does not work
# because it causes unnecessary reruns that cause dialogs to close abruptly
- selected_props = [row[bind_to_query_prop] for row in selected_rows]
- selected_df = df[df[bind_to_query_prop].isin(selected_props)]
- selected_rows = json.loads(selected_df.to_json(orient="records"))
-
- return selected_rows
+ selected_df = df[df[id_column].isin(selection)]
+ selected_data = json.loads(selected_df.to_json(orient="records"))
+
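+ # The most recently selected row drives the detail view and the "selected" query param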
+ selected_id, selected_item = None, None
+ if selected_rows:
+ selected_id = selected_rows[len(selected_rows) - 1][id_column]
+ selected_item = next((item for item in selected_data if item[id_column] == selected_id), None)
+ if bind_to_query:
+ Router().set_query_params({"selected": selected_id})
+
+ if selection_mode == "multiple" and (count := len(selected_data)):
+ with selected_column:
+ testgen.caption(f"{count} item{'s' if count != 1 else ''} selected")
+
+ return selected_data, selected_item
+
+ return None, None
diff --git a/testgen/ui/session.py b/testgen/ui/session.py
index e5cd7ebb..e1525d37 100644
--- a/testgen/ui/session.py
+++ b/testgen/ui/session.py
@@ -23,6 +23,7 @@ class TestgenSession(Singleton):
# streamlit_authenticator sets this attribute implicitly
authentication_status: bool
+ initialized: bool
page_pending_cookies: st.Page # type: ignore
page_pending_login: str
page_args_pending_login: dict
diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py
index ab52de96..c492dde5 100644
--- a/testgen/ui/views/connections.py
+++ b/testgen/ui/views/connections.py
@@ -51,6 +51,12 @@ class ConnectionsPage(Page):
"url",
"http_path",
]
+ encrypted_fields: typing.ClassVar[list[str]] = [
+ "project_pw_encrypted",
+ "private_key",
+ "private_key_passphrase",
+ "service_account_key",
+ ]
def render(self, project_code: str, **_kwargs) -> None:
testgen.page_header(
@@ -95,25 +101,23 @@ def on_save_connection_clicked(updated_connection):
if updated_connection.get("connect_by_key"):
updated_connection["project_pw_encrypted"] = ""
- if is_pristine(updated_connection["private_key_passphrase"]):
+ if is_pristine(updated_connection.get("private_key_passphrase")):
del updated_connection["private_key_passphrase"]
+ elif updated_connection.get("private_key_passphrase") == CLEAR_SENTINEL:
+ updated_connection["private_key_passphrase"] = ""
+
+ if is_pristine(updated_connection.get("private_key")):
+ del updated_connection["private_key"]
+ else:
+ updated_connection["private_key"] = base64.b64decode(updated_connection["private_key"]).decode()
else:
updated_connection["private_key"] = ""
updated_connection["private_key_passphrase"] = ""
- if updated_connection.get("private_key_passphrase") == CLEAR_SENTINEL:
- updated_connection["private_key_passphrase"] = ""
-
- if is_pristine(updated_connection.get("private_key")):
- del updated_connection["private_key"]
- else:
- updated_connection["private_key"] = base64.b64decode(updated_connection["private_key"]).decode()
-
- if is_pristine(updated_connection.get("project_pw_encrypted")):
- del updated_connection["project_pw_encrypted"]
-
- if updated_connection.get("project_pw_encrypted") == CLEAR_SENTINEL:
- updated_connection["project_pw_encrypted"] = ""
+ if is_pristine(updated_connection.get("project_pw_encrypted")):
+ del updated_connection["project_pw_encrypted"]
+ elif updated_connection.get("project_pw_encrypted") == CLEAR_SENTINEL:
+ updated_connection["project_pw_encrypted"] = ""
updated_connection["sql_flavor"] = self._get_sql_flavor_from_value(updated_connection["sql_flavor_code"]).flavor
@@ -162,7 +166,7 @@ def on_test_connection_clicked(updated_connection: dict) -> None:
message = "Error creating connection"
success = False
LOG.exception(message)
-
+
results = {
"success": success,
"message": message,
@@ -204,6 +208,8 @@ def _sanitize_connection_input(self, connection: dict) -> dict:
sanitized_value = value
if isinstance(value, str) and key in self.trim_fields:
sanitized_value = value.strip()
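+ # Blank values for encrypted fields are normalized to None (treated as unset)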
+ if isinstance(value, str) and key in self.encrypted_fields:
+ sanitized_value = value if value != "" else None
sanitized_connection_input[key] = sanitized_value
return sanitized_connection_input
@@ -441,6 +447,12 @@ class ConnectionFlavor:
flavor="redshift",
icon=get_asset_data_url("flavors/redshift.svg"),
),
+ ConnectionFlavor(
+ label="Amazon Redshift Spectrum",
+ value="redshift_spectrum",
+ flavor="redshift_spectrum",
+ icon=get_asset_data_url("flavors/redshift.svg"),
+ ),
ConnectionFlavor(
label="Azure SQL Database",
value="azure_mssql",
@@ -453,6 +465,18 @@ class ConnectionFlavor:
flavor="mssql",
icon=get_asset_data_url("flavors/azure_synapse_table.svg"),
),
+ ConnectionFlavor(
+ label="Databricks",
+ value="databricks",
+ flavor="databricks",
+ icon=get_asset_data_url("flavors/databricks.svg"),
+ ),
+ ConnectionFlavor(
+ label="Google BigQuery",
+ value="bigquery",
+ flavor="bigquery",
+ icon=get_asset_data_url("flavors/bigquery.svg"),
+ ),
ConnectionFlavor(
label="Microsoft SQL Server",
value="mssql",
@@ -471,10 +495,4 @@ class ConnectionFlavor:
flavor="snowflake",
icon=get_asset_data_url("flavors/snowflake.svg"),
),
- ConnectionFlavor(
- label="Databricks",
- value="databricks",
- flavor="databricks",
- icon=get_asset_data_url("flavors/databricks.svg"),
- ),
]
diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py
index 58ef1bee..f4984577 100644
--- a/testgen/ui/views/data_catalog.py
+++ b/testgen/ui/views/data_catalog.py
@@ -185,7 +185,7 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: Ta
data = pd.DataFrame(table_data + column_data)
data = data.sort_values(by=["table_name", "ordinal_position"], na_position="first", key=lambda x: x.str.lower() if x.dtype == "object" else x)
- for key in ["column_type", "datatype_suggestion"]:
+ for key in ["datatype_suggestion"]:
data[key] = data[key].apply(lambda val: val.lower() if not pd.isna(val) else None)
for key in ["avg_embedded_spaces", "avg_length", "avg_value", "stdev_value"]:
@@ -228,7 +228,7 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: Ta
"active_test_count": {"header": "Active tests"},
"ordinal_position": {"header": "Position"},
"general_type": {},
- "column_type": {"header": "Data type"},
+ "db_data_type": {"header": "Data type"},
"datatype_suggestion": {"header": "Suggested data type"},
"functional_data_type": {"header": "Semantic data type"},
"add_date": {"header": "First detected"},
@@ -396,7 +396,7 @@ def get_table_group_columns(table_group_id: str) -> list[dict]:
table_chars.table_name,
column_chars.schema_name,
column_chars.general_type,
- column_chars.column_type,
+ column_chars.db_data_type,
column_chars.functional_data_type,
profile_results.datatype_suggestion,
table_chars.record_ct,
diff --git a/testgen/ui/views/dialogs/data_preview_dialog.py b/testgen/ui/views/dialogs/data_preview_dialog.py
index 12a7648f..d2837e3e 100644
--- a/testgen/ui/views/dialogs/data_preview_dialog.py
+++ b/testgen/ui/views/dialogs/data_preview_dialog.py
@@ -1,6 +1,7 @@
import pandas as pd
import streamlit as st
+from testgen.common.database.database_service import get_flavor_service
from testgen.common.models.connection import Connection
from testgen.ui.components import widgets as testgen
from testgen.ui.services.database_service import fetch_from_target_db
@@ -45,12 +46,14 @@ def get_preview_data(
connection = Connection.get_by_table_group(table_group_id)
if connection:
- use_top = connection.sql_flavor == "mssql"
+ flavor_service = get_flavor_service(connection.sql_flavor)
+ use_top = flavor_service.use_top
+ quote = flavor_service.quote_character
query = f"""
SELECT DISTINCT
{"TOP 100" if use_top else ""}
- {column_name or "*"}
- FROM {schema_name}.{table_name}
+ {f"{quote}{column_name}{quote}" if column_name else "*"}
+ FROM {quote}{schema_name}{quote}.{quote}{table_name}{quote}
{"LIMIT 100" if not use_top else ""}
"""
diff --git a/testgen/ui/views/dialogs/manage_schedules.py b/testgen/ui/views/dialogs/manage_schedules.py
index 820cd782..85292565 100644
--- a/testgen/ui/views/dialogs/manage_schedules.py
+++ b/testgen/ui/views/dialogs/manage_schedules.py
@@ -2,7 +2,6 @@
import zoneinfo
from datetime import datetime
from typing import Any
-from uuid import UUID
import cron_converter
import cron_descriptor
@@ -14,7 +13,7 @@
from testgen.ui.components import widgets as testgen
from testgen.ui.session import session, temp_value
-
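+# Number of upcoming trigger times to show when previewing a cron expression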
+CRON_SAMPLE_COUNT = 3
class ScheduleDialog:
title: str = ""
@@ -43,16 +42,20 @@ def open(self, project_code: str) -> None:
return st.dialog(title=self.title)(self.render)()
def render(self) -> None:
+ @with_database_session
def on_delete_sched(item):
- with Session() as db_session:
- try:
- sched, = db_session.query(JobSchedule).where(JobSchedule.id == UUID(item["id"]))
- db_session.delete(sched)
- except ValueError:
- db_session.rollback()
- else:
- db_session.commit()
- st.rerun(scope="fragment")
+ JobSchedule.delete(item["id"])
+ st.rerun(scope="fragment")
+
+ @with_database_session
+ def on_pause_sched(item):
+ JobSchedule.update_active(item["id"], False)
+ st.rerun(scope="fragment")
+
+ @with_database_session
+ def on_resume_sched(item):
+ JobSchedule.update_active(item["id"], True)
+ st.rerun(scope="fragment")
def on_cron_sample(payload: dict[str, str]):
try:
@@ -66,7 +69,7 @@ def on_cron_sample(payload: dict[str, str]):
)
set_cron_sample({
- "sample": cron_schedule.next().strftime("%a %b %-d, %-I:%M %p"),
+ "samples": [cron_schedule.next().strftime("%a %b %-d, %-I:%M %p") for _ in range(CRON_SAMPLE_COUNT)],
"readable_expr": readble_cron_schedule,
})
except ValueError as e:
@@ -113,6 +116,7 @@ def on_add_schedule(payload: dict[str, str]):
key=self.job_key,
cron_expr=cron_obj.to_string(),
cron_tz=cron_tz,
+ active=True,
args=args,
kwargs=kwargs,
)
@@ -147,11 +151,13 @@ def on_add_schedule(payload: dict[str, str]):
"cronTz": job.cron_tz_str,
"sample": [
sample.strftime("%a %b %-d, %-I:%M %p")
- for sample in job.get_sample_triggering_timestamps(2)
+ for sample in job.get_sample_triggering_timestamps(CRON_SAMPLE_COUNT + 1)
],
+ "active": job.active,
}
scheduled_jobs_json.append(job_json)
+ testgen.css_class("l-dialog")
testgen.testgen_component(
"schedule_list",
props={
@@ -163,6 +169,8 @@ def on_add_schedule(payload: dict[str, str]):
"results": results,
},
event_handlers={
+ "PauseSchedule": on_pause_sched,
+ "ResumeSchedule": on_resume_sched,
"DeleteSchedule": on_delete_sched,
},
on_change_handlers={
diff --git a/testgen/ui/views/dialogs/table_create_script_dialog.py b/testgen/ui/views/dialogs/table_create_script_dialog.py
index 992568ba..1bcd386e 100644
--- a/testgen/ui/views/dialogs/table_create_script_dialog.py
+++ b/testgen/ui/views/dialogs/table_create_script_dialog.py
@@ -20,13 +20,13 @@ def generate_create_script(table_name: str, data: list[dict]) -> str | None:
col_defs = []
for index, col in enumerate(table_data):
comment = (
- f"-- WAS {col['column_type']}"
- if isinstance(col["column_type"], str)
+ f"-- WAS {col['db_data_type']}"
+ if isinstance(col["db_data_type"], str)
and isinstance(col["datatype_suggestion"], str)
- and col["column_type"].lower() != col["datatype_suggestion"].lower()
+ and col["db_data_type"].lower() != col["datatype_suggestion"].lower()
else ""
)
- col_type = col["datatype_suggestion"] or col["column_type"] or ""
+ col_type = col["datatype_suggestion"] or col["db_data_type"] or ""
separator = " " if index == len(table_data) - 1 else ","
col_defs.append(f"{col['column_name']:<{max_name}} {(col_type):<{max_type}}{separator} {comment}")
diff --git a/testgen/ui/views/hygiene_issues.py b/testgen/ui/views/hygiene_issues.py
index 7cc788d0..7c6d3b1d 100644
--- a/testgen/ui/views/hygiene_issues.py
+++ b/testgen/ui/views/hygiene_issues.py
@@ -44,9 +44,9 @@ def render(
self,
run_id: str,
likelihood: str | None = None,
- issue_type: str | None = None,
table_name: str | None = None,
column_name: str | None = None,
+ issue_type: str | None = None,
action: str | None = None,
**_kwargs,
) -> None:
@@ -70,13 +70,20 @@ def render(
],
)
- others_summary_column, pii_summary_column, score_column, actions_column, export_button_column = st.columns([.2, .2, .15, .3, .15], vertical_alignment="bottom")
+ others_summary_column, pii_summary_column, score_column, actions_column, export_button_column = st.columns([.25, .2, .1, .3, .15], vertical_alignment="bottom")
(liklihood_filter_column, table_filter_column, column_filter_column, issue_type_filter_column, action_filter_column, sort_column) = (
st.columns([.2, .15, .2, .2, .15, .1], vertical_alignment="bottom")
)
- testgen.flex_row_end(actions_column)
+ testgen.flex_row_end(actions_column, wrap=True)
testgen.flex_row_end(export_button_column)
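+ # Track filter changes between reruns so the results grid can reset to its first page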
+ filters_changed = False
+ current_filters = (likelihood, table_name, column_name, issue_type, action)
+ if (query_filters := st.session_state.get("hygiene_issues:filters")) != current_filters:
+ if query_filters:
+ filters_changed = True
+ st.session_state["hygiene_issues:filters"] = current_filters
+
with liklihood_filter_column:
likelihood = testgen.select(
options=["Definite", "Likely", "Possible", "Potential PII"],
@@ -114,7 +121,6 @@ def render(
)
column_name = testgen.select(
options=column_options,
- value_column="column_name",
default_value=column_name,
bind_to_query="column_name",
label="Column",
@@ -160,8 +166,10 @@ def render(
sorting_columns = testgen.sorting_selector(sortable_columns, default)
with actions_column:
- str_help = "Toggle on to perform actions on multiple Hygiene Issues"
- do_multi_select = st.toggle("Multi-Select", help=str_help)
+ multi_select = st.toggle(
+ "Multi-Select",
+ help="Toggle on to perform actions on multiple Hygiene Issues",
+ )
with st.container():
with st.spinner("Loading data ..."):
@@ -178,48 +186,27 @@ def render(
summaries = get_profiling_anomaly_summary(run_id)
others_summary = [summary for summary in summaries if summary.get("type") != "PII"]
with others_summary_column:
- testgen.summary_bar(
+ testgen.summary_counts(
items=others_summary,
label="Hygiene Issues",
- height=20,
- width=400,
)
anomalies_pii_summary = [summary for summary in summaries if summary.get("type") == "PII"]
if anomalies_pii_summary:
with pii_summary_column:
- testgen.summary_bar(
+ testgen.summary_counts(
items=anomalies_pii_summary,
- label="Potential PII",
- height=20,
- width=400,
+ label="Potential PII (Risk)",
)
- lst_show_columns = [
- "table_name",
- "column_name",
- "issue_likelihood",
- "action",
- "anomaly_name",
- "detail",
- ]
-
- # Show main grid and retrieve selections
- selected = fm.render_grid_select(
+ selected, selected_row = fm.render_grid_select(
df_pa,
- lst_show_columns,
- int_height=400,
- do_multi_select=do_multi_select,
- bind_to_query_name="selected",
- bind_to_query_prop="id",
- show_column_headers=[
- "Table",
- "Column",
- "Likelihood",
- "Action",
- "Issue Type",
- "Detail"
- ]
+ ["table_name", "column_name", "issue_likelihood", "action", "anomaly_name", "detail"],
+ ["Table", "Column", "Likelihood", "Action", "Issue Type", "Detail"],
+ id_column="id",
+ selection_mode="multiple" if multi_select else "single",
+ reset_pagination=filters_changed,
+ bind_to_query=True,
)
popover_container = export_button_column.empty()
@@ -245,22 +232,16 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
if selected:
st.button(label="Selected issues", type="tertiary", on_click=partial(open_download_dialog, pd.DataFrame(selected)))
- if not df_pa.empty:
- if selected:
- # Always show details for last selected row
- selected_row = selected[len(selected) - 1]
- else:
- selected_row = None
-
- # Display hygiene issue detail for selected row
- if not selected_row:
- st.markdown(":orange[Select a record to see more information.]")
- else:
- _, buttons_column = st.columns([0.5, 0.5])
+ # Display hygiene issue detail for selected row
+ if not selected:
+ st.markdown(":orange[Select a record to see more information.]")
+ else:
+ _, buttons_column = st.columns([0.5, 0.5])
- with buttons_column:
- col1, col2, col3 = st.columns([.3, .3, .3])
+ with buttons_column:
+ col1, col2, col3 = st.columns([.3, .3, .3])
+ if selected_row:
with col1:
view_profiling_button(
selected_row["column_name"], selected_row["table_name"], selected_row["table_groups_id"]
@@ -277,39 +258,40 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
)
source_data_dialog(selected_row)
- with col3:
- if st.button(
- ":material/download: Issue Report",
- use_container_width=True,
- help="Generate a PDF report for each selected issue",
- ):
- MixpanelService().send_event(
- "download-issue-report",
- page=self.path,
- issue_count=len(selected),
+ with col3:
+ if st.button(
+ ":material/download: Issue Report",
+ use_container_width=True,
+ help="Generate a PDF report for each selected issue",
+ ):
+ MixpanelService().send_event(
+ "download-issue-report",
+ page=self.path,
+ issue_count=len(selected),
+ )
+ dialog_title = "Download Issue Report"
+ if len(selected) == 1:
+ download_dialog(
+ dialog_title=dialog_title,
+ file_content_func=get_report_file_data,
+ args=(selected[0],),
)
- dialog_title = "Download Issue Report"
- if len(selected) == 1:
- download_dialog(
- dialog_title=dialog_title,
- file_content_func=get_report_file_data,
- args=(selected[0],),
- )
- else:
- zip_func = zip_multi_file_data(
- "testgen_hygiene_issue_reports.zip",
- get_report_file_data,
- [(arg,) for arg in selected],
- )
- download_dialog(dialog_title=dialog_title, file_content_func=zip_func)
+ else:
+ zip_func = zip_multi_file_data(
+ "testgen_hygiene_issue_reports.zip",
+ get_report_file_data,
+ [(arg,) for arg in selected],
+ )
+ download_dialog(dialog_title=dialog_title, file_content_func=zip_func)
+ if selected_row:
fm.render_html_list(
selected_row,
[
"anomaly_name",
"table_name",
"column_name",
- "column_type",
+ "db_data_type",
"anomaly_description",
"detail",
"likelihood_explanation",
@@ -318,8 +300,6 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
"Hygiene Issue Detail",
int_data_width=700,
)
- else:
- st.markdown(":green[**No Hygiene Issues Found**]")
cached_functions = [get_anomaly_disposition, get_profiling_anomaly_summary, get_profiling_anomalies]
@@ -421,7 +401,7 @@ def get_profiling_anomalies(
r.table_name,
r.column_name,
r.schema_name,
- r.column_type,
+ r.db_data_type,
t.anomaly_name,
t.issue_likelihood,
r.disposition,
@@ -540,8 +520,8 @@ def get_profiling_anomaly_summary(profile_run_id: str) -> list[dict]:
{ "label": "Likely", "value": result.likely_ct, "color": "orange" },
{ "label": "Possible", "value": result.possible_ct, "color": "yellow" },
{ "label": "Dismissed", "value": result.dismissed_ct, "color": "grey" },
- { "label": "High Risk", "value": result.pii_high_ct, "color": "red", "type": "PII" },
- { "label": "Moderate Risk", "value": result.pii_moderate_ct, "color": "orange", "type": "PII" },
+ { "label": "High", "value": result.pii_high_ct, "color": "red", "type": "PII" },
+ { "label": "Moderate", "value": result.pii_moderate_ct, "color": "orange", "type": "PII" },
{ "label": "Dismissed", "value": result.pii_dismissed_ct, "color": "grey", "type": "PII" },
]
@@ -591,7 +571,7 @@ def source_data_dialog(selected_row):
st.markdown("#### SQL Query")
query = get_hygiene_issue_source_query(selected_row)
if query:
- st.code(query, language="sql")
+ st.code(query, language="sql", height=100)
with st.spinner("Retrieving source data..."):
bad_data_status, bad_data_msg, _, df_bad = get_hygiene_issue_source_data(selected_row, limit=500)
@@ -610,7 +590,7 @@ def source_data_dialog(selected_row):
if len(df_bad) == 500:
testgen.caption("* Top 500 records displayed", "text-align: right;")
# Display the dataframe
- st.dataframe(df_bad, height=500, width=1050, hide_index=True)
+ st.dataframe(df_bad, width=1050, hide_index=True)
def do_disposition_update(selected, str_new_status):
diff --git a/testgen/ui/views/profiling_results.py b/testgen/ui/views/profiling_results.py
index 9519c314..e789f3a4 100644
--- a/testgen/ui/views/profiling_results.py
+++ b/testgen/ui/views/profiling_results.py
@@ -60,37 +60,53 @@ def render(self, run_id: str, table_name: str | None = None, column_name: str |
[.3, .3, .08, .32], vertical_alignment="bottom"
)
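+ # Reset the grid's pagination whenever the table/column filters change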
+ filters_changed = False
+ current_filters = (table_name, column_name)
+ if (query_filters := st.session_state.get("profiling_results:filters")) != current_filters:
+ if query_filters:
+ filters_changed = True
+ st.session_state["profiling_results:filters"] = current_filters
+
+ run_columns_df = get_profiling_run_columns(run_id)
with table_filter_column:
- # Table Name filter
- df = get_profiling_run_tables(run_id)
- df = df.sort_values("table_name", key=lambda x: x.str.lower())
table_name = testgen.select(
- options=df,
- value_column="table_name",
+ options=list(run_columns_df["table_name"].unique()),
default_value=table_name,
bind_to_query="table_name",
label="Table",
)
with column_filter_column:
- # Column Name filter
- df = get_profiling_run_columns(run_id, table_name)
- df = df.sort_values("column_name", key=lambda x: x.str.lower())
+ if table_name:
+ column_options = (
+ run_columns_df
+ .loc[run_columns_df["table_name"] == table_name]
+ ["column_name"]
+ .dropna()
+ .unique()
+ .tolist()
+ )
+ else:
+ column_options = (
+ run_columns_df
+ .groupby("column_name")
+ .first()
+ .reset_index()
+ .sort_values("column_name", key=lambda x: x.str.lower())
+ )
column_name = testgen.select(
- options=df,
- value_column="column_name",
+ options=column_options,
default_value=column_name,
bind_to_query="column_name",
label="Column",
- disabled=not table_name,
- accept_new_options=bool(table_name),
+ accept_new_options=True,
)
with sort_column:
sortable_columns = (
("Table", "LOWER(table_name)"),
("Column", "LOWER(column_name)"),
- ("Data Type", "LOWER(column_type)"),
+ ("Data Type", "LOWER(db_data_type)"),
("Semantic Data Type", "semantic_data_type"),
("Hygiene Issues", "hygiene_issues"),
)
@@ -107,27 +123,13 @@ def render(self, run_id: str, table_name: str | None = None, column_name: str |
sorting_columns=sorting_columns,
)
- show_columns = [
- "table_name",
- "column_name",
- "column_type",
- "semantic_data_type",
- "hygiene_issues",
- ]
- show_column_headers = [
- "Table",
- "Column",
- "Data Type",
- "Semantic Data Type",
- "Hygiene Issues",
- ]
-
- selected_row = fm.render_grid_select(
+ selected, selected_row = fm.render_grid_select(
df,
- show_columns,
- bind_to_query_name="selected",
- bind_to_query_prop="id",
- show_column_headers=show_column_headers,
+ ["table_name", "column_name", "db_data_type", "semantic_data_type", "hygiene_issues"],
+ ["Table", "Column", "Data Type", "Semantic Data Type", "Hygiene Issues"],
+ id_column="id",
+ reset_pagination=filters_changed,
+ bind_to_query=True,
)
popover_container = export_button_column.empty()
@@ -150,19 +152,18 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
css_class("tg--export-wrapper")
st.button(label="All results", type="tertiary", on_click=open_download_dialog)
st.button(label="Filtered results", type="tertiary", on_click=partial(open_download_dialog, df))
- if selected_row:
- st.button(label="Selected results", type="tertiary", on_click=partial(open_download_dialog, pd.DataFrame(selected_row)))
+ if selected:
+ st.button(label="Selected results", type="tertiary", on_click=partial(open_download_dialog, pd.DataFrame(selected)))
# Display profiling for selected row
if not selected_row:
st.markdown(":orange[Select a row to see profiling details.]")
else:
- item = selected_row[0]
- item["hygiene_issues"] = profiling_queries.get_hygiene_issues(run_id, item["table_name"], item.get("column_name"))
+ selected_row["hygiene_issues"] = profiling_queries.get_hygiene_issues(run_id, selected_row["table_name"], selected_row.get("column_name"))
testgen_component(
"column_profiling_results",
- props={ "column": json.dumps(item), "data_preview": True },
+ props={ "column": json.dumps(selected_row), "data_preview": True },
on_change_handlers={
"DataPreviewClicked": lambda item: data_preview_dialog(
item["table_group_id"],
@@ -189,7 +190,7 @@ def get_excel_report_data(
data = profiling_queries.get_profiling_results(run_id)
date_service.accommodate_dataframe_to_timezone(data, st.session_state)
- for key in ["column_type", "datatype_suggestion"]:
+ for key in ["datatype_suggestion"]:
data[key] = data[key].apply(lambda val: val.lower() if not pd.isna(val) else None)
for key in ["avg_embedded_spaces", "avg_length", "avg_value", "stdev_value"]:
@@ -221,7 +222,7 @@ def get_excel_report_data(
"column_name": {"header": "Column"},
"position": {},
"general_type": {},
- "column_type": {"header": "Data type"},
+ "db_data_type": {"header": "Data type"},
"datatype_suggestion": {"header": "Suggested data type"},
"semantic_data_type": {},
"record_ct": {"header": "Record count"},
@@ -279,27 +280,11 @@ def get_excel_report_data(
@st.cache_data(show_spinner=False)
-def get_profiling_run_tables(profiling_run_id: str) -> pd.DataFrame:
+def get_profiling_run_columns(profiling_run_id: str) -> pd.DataFrame:
query = """
- SELECT DISTINCT table_name
+ SELECT table_name, column_name
FROM profile_results
WHERE profile_run_id = :profiling_run_id
- ORDER BY table_name;
+ ORDER BY LOWER(table_name), LOWER(column_name);
"""
return fetch_df_from_db(query, {"profiling_run_id": profiling_run_id})
-
-
-@st.cache_data(show_spinner=False)
-def get_profiling_run_columns(profiling_run_id: str, table_name: str) -> pd.DataFrame:
- query = """
- SELECT DISTINCT column_name
- FROM profile_results
- WHERE profile_run_id = :profiling_run_id
- AND table_name = :table_name
- ORDER BY column_name;
- """
- params = {
- "profiling_run_id": profiling_run_id,
- "table_name": table_name or "",
- }
- return fetch_df_from_db(query, params)
diff --git a/testgen/ui/views/profiling_runs.py b/testgen/ui/views/profiling_runs.py
index 34f2154e..ffea8d10 100644
--- a/testgen/ui/views/profiling_runs.py
+++ b/testgen/ui/views/profiling_runs.py
@@ -1,4 +1,3 @@
-import json
import logging
import typing
from collections.abc import Iterable
@@ -17,14 +16,13 @@
from testgen.ui.components.widgets import testgen_component
from testgen.ui.navigation.menu import MenuItem
from testgen.ui.navigation.page import Page
+from testgen.ui.navigation.router import Router
from testgen.ui.session import session, temp_value
from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog
from testgen.ui.views.dialogs.run_profiling_dialog import run_profiling_dialog
-from testgen.utils import friendly_score, to_dataframe, to_int
+from testgen.utils import friendly_score, to_int
LOG = logging.getLogger("testgen")
-FORM_DATA_WIDTH = 400
-PAGE_SIZE = 50
PAGE_ICON = "data_thresholding"
PAGE_TITLE = "Profiling Runs"
@@ -48,75 +46,54 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs
"investigate-profiling",
)
- user_can_run = session.auth.user_has_permission("edit")
- if render_empty_state(project_code, user_can_run):
- return
-
- group_filter_column, actions_column = st.columns([.3, .7], vertical_alignment="bottom")
-
- with group_filter_column:
- table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code)
- table_groups_df = to_dataframe(table_groups, TableGroupMinimal.columns())
- table_groups_df["id"] = table_groups_df["id"].apply(lambda x: str(x))
- table_group_id = testgen.select(
- options=table_groups_df,
- value_column="id",
- display_column="table_groups_name",
- default_value=table_group_id,
- bind_to_query="table_group_id",
- label="Table Group",
- placeholder="---",
- )
-
- with actions_column:
- testgen.flex_row_end()
-
- st.button(
- ":material/today: Profiling Schedules",
- help="Manage when profiling should run for table groups",
- on_click=partial(ProfilingScheduleDialog().open, project_code)
- )
-
- if user_can_run:
- st.button(
- ":material/play_arrow: Run Profiling",
- help="Run profiling for a table group",
- on_click=partial(run_profiling_dialog, project_code, None, table_group_id)
- )
- fm.render_refresh_button(actions_column)
-
- testgen.whitespace(0.5)
- list_container = st.container()
-
with st.spinner("Loading data ..."):
+ project_summary = Project.get_summary(project_code)
profiling_runs = ProfilingRun.select_summary(project_code, table_group_id)
+ table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code)
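+ # The table group filter, action buttons, and empty states now live inside the profiling_runs component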
- paginated = []
- if run_count := len(profiling_runs):
- page_index = testgen.paginator(count=run_count, page_size=PAGE_SIZE)
- profiling_runs = [
- {
- **row.to_dict(json_safe=True),
- "dq_score_profiling": friendly_score(row.dq_score_profiling),
- } for row in profiling_runs
- ]
- paginated = profiling_runs[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)]
-
- with list_container:
- testgen_component(
- "profiling_runs",
- props={
- "items": json.dumps(paginated),
- "permissions": {
- "can_run": user_can_run,
- "can_edit": user_can_run,
- },
+ testgen_component(
+ "profiling_runs",
+ props={
+ "project_summary": project_summary.to_dict(json_safe=True),
+ "profiling_runs": [
+ {
+ **run.to_dict(json_safe=True),
+ "dq_score_profiling": friendly_score(run.dq_score_profiling),
+ } for run in profiling_runs
+ ],
+ "table_group_options": [
+ {
+ "value": str(table_group.id),
+ "label": table_group.table_groups_name,
+ "selected": str(table_group_id) == str(table_group.id),
+ } for table_group in table_groups
+ ],
+ "permissions": {
+ "can_edit": session.auth.user_has_permission("edit"),
},
- event_handlers={
- "RunCanceled": on_cancel_run,
- "RunsDeleted": partial(on_delete_runs, project_code, table_group_id),
- }
- )
+ },
+ on_change_handlers={
+ "FilterApplied": on_profiling_runs_filtered,
+ "RunSchedulesClicked": lambda *_: ProfilingScheduleDialog().open(project_code),
+ "RunProfilingClicked": lambda *_: run_profiling_dialog(project_code, None, table_group_id),
+ "RefreshData": refresh_data,
+ "RunsDeleted": partial(on_delete_runs, project_code, table_group_id),
+ },
+ event_handlers={
+ "RunCanceled": on_cancel_run,
+ },
+ )
+
+
+class ProfilingRunFilters(typing.TypedDict):
+ table_group_id: str
+
+def on_profiling_runs_filtered(filters: ProfilingRunFilters) -> None:
+ Router().set_query_params(filters)
+
+
+def refresh_data(*_) -> None:
+ ProfilingRun.select_summary.clear()
class ProfilingScheduleDialog(ScheduleDialog):
@@ -142,47 +119,6 @@ def get_job_arguments(self, arg_value: str) -> tuple[list[typing.Any], dict[str,
return [], {"table_group_id": str(arg_value)}
-def render_empty_state(project_code: str, user_can_run: bool) -> bool:
- project_summary = Project.get_summary(project_code)
- if project_summary.profiling_run_count:
- return False
-
- label = "No profiling runs yet"
- testgen.whitespace(5)
- if not project_summary.connection_count:
- testgen.empty_state(
- label=label,
- icon=PAGE_ICON,
- message=testgen.EmptyStateMessage.Connection,
- action_label="Go to Connections",
- link_href="connections",
- link_params={ "project_code": project_code },
- )
- elif not project_summary.table_group_count:
- testgen.empty_state(
- label=label,
- icon=PAGE_ICON,
- message=testgen.EmptyStateMessage.TableGroup,
- action_label="Go to Table Groups",
- link_href="table-groups",
- link_params={
- "project_code": project_code,
- "connection_id": str(project_summary.default_connection_id),
- },
- )
- else:
- testgen.empty_state(
- label=label,
- icon=PAGE_ICON,
- message=testgen.EmptyStateMessage.Profiling,
- action_label="Run Profiling",
- action_disabled=not user_can_run,
- button_onclick=partial(run_profiling_dialog, project_code),
- button_icon="play_arrow",
- )
- return True
-
-
def on_cancel_run(profiling_run: dict) -> None:
process_status, process_message = process_service.kill_profile_run(to_int(profiling_run["process_id"]))
if process_status:
diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py
index 4dbe4c76..8b8a2a89 100644
--- a/testgen/ui/views/test_definitions.py
+++ b/testgen/ui/views/test_definitions.py
@@ -44,7 +44,14 @@ class TestDefinitionsPage(Page):
lambda: "test_suite_id" in st.query_params or "test-suites",
]
- def render(self, test_suite_id: str, table_name: str | None = None, column_name: str | None = None, **_kwargs) -> None:
+ def render(
+ self,
+ test_suite_id: str,
+ table_name: str | None = None,
+ column_name: str | None = None,
+ test_type: str | None = None,
+ **_kwargs,
+ ) -> None:
test_suite = TestSuite.get(test_suite_id)
if not test_suite:
self.router.navigate_with_warning(
@@ -74,25 +81,35 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name:
testgen.flex_row_start(actions_column)
testgen.flex_row_end(disposition_column)
+ filters_changed = False
+ current_filters = (table_name, column_name, test_type)
+ if (query_filters := st.session_state.get("test_definitions:filters")) != current_filters:
+ if query_filters:
+ filters_changed = True
+ st.session_state["test_definitions:filters"] = current_filters
+
with table_filter_column:
columns_df = get_test_suite_columns(test_suite_id)
table_options = list(columns_df["table_name"].unique())
table_name = testgen.select(
options=table_options,
value_column="table_name",
- default_value=table_name or (table_options[0] if table_options else None),
+ default_value=table_name,
bind_to_query="table_name",
- required=True,
label="Table",
)
with column_filter_column:
- column_options = columns_df.loc[columns_df["table_name"] == table_name]["column_name"].dropna().unique().tolist()
+ if table_name:
+ column_options = columns_df.loc[
+ columns_df["table_name"] == table_name
+ ]["column_name"].dropna().unique().tolist()
+ else:
+ column_options = columns_df.groupby("column_name").first().reset_index().sort_values("column_name", key=lambda x: x.str.lower())
column_name = testgen.select(
options=column_options,
default_value=column_name,
bind_to_query="column_name",
label="Column",
- disabled=not table_name,
accept_new_options=True,
)
with test_filter_column:
@@ -101,29 +118,57 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name:
options=test_options,
value_column="test_type",
display_column="test_name_short",
- default_value=None,
+ default_value=test_type,
bind_to_query="test_type",
label="Test Type",
)
- with disposition_column:
- str_help = "Toggle on to perform actions on multiple test definitions"
- do_multi_select = user_can_disposition and st.toggle("Multi-Select", help=str_help)
+ if user_can_disposition:
+ with disposition_column:
+ multi_select = st.toggle("Multi-Select", help="Toggle on to perform actions on multiple test definitions")
- if user_can_edit and actions_column.button(
- ":material/add: Add", help="Add a new Test Definition"
- ):
- add_test_dialog(table_group, test_suite, table_name, column_name)
+ if user_can_edit:
+ if actions_column.button(
+ ":material/add: Add",
+ help="Add a new Test Definition",
+ ):
+ add_test_dialog(table_group, test_suite, table_name, column_name)
- if user_can_edit and table_actions_column.button(
- ":material/play_arrow: Run Tests",
- help="Run test suite's tests",
- ):
- run_tests_dialog(project_code, test_suite)
+ if table_actions_column.button(
+ ":material/play_arrow: Run Tests",
+ help="Run test suite's tests",
+ ):
+ run_tests_dialog(project_code, test_suite)
+
+ with st.container():
+ with st.spinner("Loading data ..."):
+ df = get_test_definitions(test_suite, table_name, column_name, test_type)
+
+ selected, selected_test_def = render_grid(df, multi_select, filters_changed)
+
+ popover_container = table_actions_column.empty()
+
+ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
+ # Hack to programmatically close popover: https://github.com/streamlit/streamlit/issues/8265#issuecomment-3001655849
+ with popover_container.container():
+ flex_row_end()
+ st.button(label="Export", icon=":material/download:", disabled=True)
+
+ download_dialog(
+ dialog_title="Download Excel Report",
+ file_content_func=get_excel_report_data,
+ args=(test_suite, table_group.table_group_schema, data),
+ )
+
+ with popover_container.container(key="tg--export-popover"):
+ flex_row_end()
+ with st.popover(label="Export", icon=":material/download:", help="Download test definitions to Excel"):
+ css_class("tg--export-wrapper")
+ st.button(label="All tests", type="tertiary", on_click=open_download_dialog)
+ st.button(label="Filtered tests", type="tertiary", on_click=partial(open_download_dialog, df))
+ if selected:
+ st.button(label="Selected tests", type="tertiary", on_click=partial(open_download_dialog, pd.DataFrame(selected)))
- selected = show_test_defs_grid(
- test_suite, table_name, column_name, test_type, do_multi_select, table_actions_column, table_group
- )
fm.render_refresh_button(table_actions_column)
if user_can_disposition:
@@ -156,9 +201,6 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name:
lst_cached_functions=[],
)
- if selected:
- selected_test_def = selected[0]
-
if user_can_edit:
if actions_column.button(
":material/edit: Edit",
@@ -178,6 +220,102 @@ def render(self, test_suite_id: str, table_name: str | None = None, column_name:
):
delete_test_dialog(selected)
+ if selected_test_def:
+ render_selected_details(selected_test_def, table_group)
+
+
+def render_grid(df: pd.DataFrame, multi_select: bool, filters_changed: bool) -> tuple[list[dict], dict]:
+ columns = [
+ "table_name",
+ "column_name",
+ "test_name_short",
+ "test_active_display",
+ "lock_refresh_display",
+ "urgency",
+ "export_to_observability_display",
+ "profiling_as_of_date",
+ "last_manual_update",
+ ]
+ # Multiselect checkboxes do not display correctly if the dataframe column order does not start with the first displayed column -_-
+ df = df.reindex(columns=[columns[0]] + [ col for col in df.columns.to_list() if col != columns[0] ])
+
+ selected, selected_row = fm.render_grid_select(
+ df,
+ columns,
+ [
+ "Table",
+ "Columns / Focus",
+ "Test Type",
+ "Active",
+ "Locked",
+ "Urgency",
+ "Export to Observabilty",
+ "Based on Profiling",
+ "Last Manual Update",
+ ],
+ id_column="id",
+ selection_mode="multiple" if multi_select else "single",
+ reset_pagination=filters_changed,
+ bind_to_query=True,
+ render_highlights=False,
+ )
+
+ return selected, selected_row
+
+
+def render_selected_details(selected_test: dict, table_group: TableGroupMinimal) -> None:
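+ """Show details, profiling link, and help for the selected test definition."""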
+ columns = [
+ "schema_name",
+ "table_name",
+ "column_name",
+ "test_type",
+ "test_active_display",
+ "test_definition_status",
+ "lock_refresh_display",
+ "urgency",
+ "export_to_observability",
+ ]
+
+ labels = [
+ "schema_name",
+ "table_name",
+ "column_name",
+ "test_type",
+ "test_active",
+ "test_definition_status",
+ "lock_refresh",
+ "urgency",
+ "export_to_observability",
+ ]
+
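+ # Test-specific columns to show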
+ additional_columns = [val.strip() for val in selected_test["default_parm_columns"].split(",")]
+ columns = columns + additional_columns
+ labels = labels + additional_columns
+ labels = list(map(snake_case_to_title_case, labels))
+
+ left_column, right_column = st.columns([0.5, 0.5])
+
+ with left_column:
+ fm.render_html_list(
+ selected_test,
+ columns,
+ "Test Definition Information",
+ int_data_width=700,
+ lst_labels=labels,
+ )
+
+ _, col_profile_button = right_column.columns([0.7, 0.3])
+ if selected_test["test_scope"] == "column" and selected_test["profile_run_id"]:
+ with col_profile_button:
+ view_profiling_button(
+ selected_test["column_name"],
+ selected_test["table_name"],
+ str(table_group.id),
+ )
+
+ with right_column:
+ st.write(generate_test_defs_help(selected_test["test_type"]))
+
@st.dialog("Delete Tests")
@with_database_session
@@ -472,51 +610,51 @@ def show_test_form(
# schema_name
test_definition["schema_name"] = left_column.text_input(
- label="Schema Name", max_chars=100, value=schema_name, disabled=True
+ label="Schema", max_chars=100, value=schema_name, disabled=True
)
# table_name
- test_definition["table_name"] = left_column.text_input(
- label="Table Name", max_chars=100, value=table_name, disabled=False
- )
-
- # column_name
- if selected_test_type_row["column_name_prompt"]:
- column_name_label = selected_test_type_row["column_name_prompt"]
- else:
- column_name_label = "Test Focus"
- if selected_test_type_row["column_name_help"]:
- column_name_help = selected_test_type_row["column_name_help"]
+ table_column_list = get_columns(table_groups_id)
+ if test_scope == "custom":
+ test_definition["table_name"] = left_column.text_input(
+ label="Table", max_chars=100, value=table_name, disabled=False
+ )
else:
- column_name_help = "Help is not available"
+ table_name_options = { item["table_name"] for item in table_column_list }
+ if table_name and table_name not in table_name_options:
+ table_name_options.add(table_name)
+ table_name_options = list(table_name_options)
+ table_name_options.sort(key=lambda x: x.lower())
+ test_definition["table_name"] = st.selectbox(
+ label="Table",
+ options=table_name_options,
+ index=table_name_options.index(table_name) if table_name else 0,
+ disabled=mode == "edit",
+ key="table-name-form",
+ )
+ column_name_label = None
if test_scope == "table":
test_definition["column_name"] = None
- column_name_label = None
- elif test_scope == "referential":
+ elif test_scope in ("referential", "custom"):
+ column_name_label = selected_test_type_row["column_name_prompt"] or "Test Focus"
test_definition["column_name"] = left_column.text_input(
label=column_name_label,
value=column_name,
max_chars=500,
- help=column_name_help,
- )
- elif test_scope == "custom":
- test_definition["column_name"] = left_column.text_input(
- label=column_name_label,
- value=column_name,
- max_chars=100,
- help=column_name_help,
+ help=selected_test_type_row["column_name_help"] or None,
)
elif test_scope == "column": # CAT column test
- column_name_label = "Column Name"
- column_name_options = get_column_names(table_groups_id, test_definition["table_name"])
- column_name_help = "Select the column to test"
- column_name_index = column_name_options.index(column_name) if column_name else 0
+ column_name_label = "Column"
+ column_name_options = { item["column_name"] for item in table_column_list if item["table_name"] == test_definition["table_name"] }
+ if column_name and column_name not in column_name_options:
+ column_name_options.add(column_name)
+ column_name_options = list(column_name_options)
+ column_name_options.sort(key=lambda x: x.lower())
test_definition["column_name"] = st.selectbox(
label=column_name_label,
options=column_name_options,
- index=column_name_index,
- help=column_name_help,
+ index=column_name_options.index(column_name) if column_name else 0,
key="column-name-form",
)
@@ -865,143 +1003,6 @@ def update_test_definition(selected, attribute, value, message):
return result
-def show_test_defs_grid(
- test_suite: TestSuite,
- table_name: str | None,
- column_name: str | None,
- test_type: str | None,
- do_multi_select: bool,
- export_container: DeltaGenerator,
- table_group: TableGroupMinimal,
-):
- with st.container():
- with st.spinner("Loading data ..."):
- df = get_test_definitions(test_suite, table_name, column_name, test_type)
-
- lst_show_columns = [
- "table_name",
- "column_name",
- "test_name_short",
- "test_active_display",
- "lock_refresh_display",
- "urgency",
- "export_to_observability_display",
- "profiling_as_of_date",
- "last_manual_update",
- ]
- show_column_headers = [
- "Table",
- "Columns / Focus",
- "Test Type",
- "Active",
- "Locked",
- "Urgency",
- "Export to Observabilty",
- "Based on Profiling",
- "Last Manual Update",
- ]
- # Multiselect checkboxes do not display correctly if the dataframe column order does not start with the first displayed column -_-
- columns = [lst_show_columns[0]] + [ col for col in df.columns.to_list() if col != lst_show_columns[0] ]
- df = df.reindex(columns=columns)
-
- dct_selected_row = fm.render_grid_select(
- df,
- lst_show_columns,
- do_multi_select=do_multi_select,
- show_column_headers=show_column_headers,
- render_highlights=False,
- bind_to_query_name="selected",
- bind_to_query_prop="id",
- )
-
- popover_container = export_container.empty()
-
- def open_download_dialog(data: pd.DataFrame | None = None) -> None:
- # Hack to programmatically close popover: https://github.com/streamlit/streamlit/issues/8265#issuecomment-3001655849
- with popover_container.container():
- flex_row_end()
- st.button(label="Export", icon=":material/download:", disabled=True)
-
- download_dialog(
- dialog_title="Download Excel Report",
- file_content_func=get_excel_report_data,
- args=(test_suite, table_group.table_group_schema, data),
- )
-
- with popover_container.container(key="tg--export-popover"):
- flex_row_end()
- with st.popover(label="Export", icon=":material/download:", help="Download test definitions to Excel"):
- css_class("tg--export-wrapper")
- st.button(label="All tests", type="tertiary", on_click=open_download_dialog)
- st.button(label="Filtered tests", type="tertiary", on_click=partial(open_download_dialog, df))
- if dct_selected_row:
- st.button(label="Selected tests", type="tertiary", on_click=partial(open_download_dialog, pd.DataFrame(dct_selected_row)))
-
- if dct_selected_row:
- st.html(" ")
- selected_row = dct_selected_row[0]
- str_test_id = selected_row["id"]
- row_selected = df[df["id"] == str_test_id].iloc[0]
- str_parm_columns = selected_row["default_parm_columns"]
-
- # Shared columns to show
- lst_show_columns = [
- "schema_name",
- "table_name",
- "column_name",
- "test_type",
- "test_active_display",
- "test_definition_status",
- "lock_refresh_display",
- "urgency",
- "export_to_observability",
- ]
-
- labels = [
- "schema_name",
- "table_name",
- "column_name",
- "test_type",
- "test_active",
- "test_definition_status",
- "lock_refresh",
- "urgency",
- "export_to_observability",
- ]
-
- # Test-specific columns to show
- additional_columns = [val.strip() for val in str_parm_columns.split(",")]
- lst_show_columns = lst_show_columns + additional_columns
- labels = labels + additional_columns
-
- labels = list(map(snake_case_to_title_case, labels))
-
- left_column, right_column = st.columns([0.5, 0.5])
-
- with left_column:
- fm.render_html_list(
- selected_row,
- lst_show_columns,
- "Test Definition Information",
- int_data_width=700,
- lst_labels=labels,
- )
-
- _, col_profile_button = right_column.columns([0.7, 0.3])
- if selected_row["test_scope"] == "column" and selected_row["profile_run_id"]:
- with col_profile_button:
- view_profiling_button(
- selected_row["column_name"],
- selected_row["table_name"],
- str(table_group.id),
- )
-
- with right_column:
- st.write(generate_test_defs_help(row_selected["test_type"]))
-
- return dct_selected_row
-
-
@with_database_session
def get_excel_report_data(
update_progress: PROGRESS_UPDATE_TYPE,
@@ -1198,22 +1199,19 @@ def get_test_definitions_collision(
return to_dataframe(results, TestDefinitionMinimal.columns())
-def get_column_names(table_groups_id: str, table_name: str) -> list[str]:
+def get_columns(table_groups_id: str) -> list[dict]:
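+ # Fetch all (table_name, column_name) pairs for the table group in one query; callers filter by table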
results = fetch_all_from_db(
"""
- SELECT column_name
+ SELECT table_name, column_name
FROM data_column_chars
WHERE table_groups_id = :table_groups_id
- AND table_name = :table_name
AND drop_date IS NULL
- ORDER BY column_name
""",
{
"table_groups_id": table_groups_id,
- "table_name": table_name,
},
)
- return [ row.column_name for row in results ]
+ return [ dict(row) for row in results ]
def validate_test(test_definition, table_group: TableGroupMinimal):
@@ -1223,7 +1221,9 @@ def validate_test(test_definition, table_group: TableGroupMinimal):
if test_definition["test_type"] == "Condition_Flag":
condition = test_definition["custom_query"]
- concat_operator = get_flavor_service(connection.sql_flavor).get_concat_operator()
+ flavor_service = get_flavor_service(connection.sql_flavor)
+ concat_operator = flavor_service.concat_operator
+ quote = flavor_service.quote_character
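+ # Quote the schema and table identifiers with the flavor-specific quote character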
query = f"""
SELECT
COALESCE(
@@ -1235,7 +1235,7 @@ def validate_test(test_definition, table_group: TableGroupMinimal):
{concat_operator} '|',
'
|'
)
- FROM {schema}.{table_name};
+ FROM {quote}{schema}{quote}.{quote}{table_name}{quote};
"""
else:
query = replace_params(
diff --git a/testgen/ui/views/test_results.py b/testgen/ui/views/test_results.py
index 80d24bf1..e88fc859 100644
--- a/testgen/ui/views/test_results.py
+++ b/testgen/ui/views/test_results.py
@@ -5,13 +5,11 @@
from io import BytesIO
from itertools import zip_longest
from operator import attrgetter
-from uuid import UUID
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
-from streamlit.delta_generator import DeltaGenerator
import testgen.ui.services.form_service as fm
from testgen.commands.run_rollup_scores import run_test_rollup_scoring_queries
@@ -21,7 +19,7 @@
from testgen.common.models.table_group import TableGroup
from testgen.common.models.test_definition import TestDefinition
from testgen.common.models.test_run import TestRun
-from testgen.common.models.test_suite import TestSuite
+from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal
from testgen.ui.components import widgets as testgen
from testgen.ui.components.widgets.download_dialog import (
FILE_DATA_TYPE,
@@ -41,7 +39,7 @@
get_test_issue_source_query_custom,
)
from testgen.ui.services.database_service import execute_db_query, fetch_df_from_db, fetch_one_from_db
-from testgen.ui.services.string_service import empty_if_null, snake_case_to_title_case
+from testgen.ui.services.string_service import snake_case_to_title_case
from testgen.ui.session import session
from testgen.ui.views.dialogs.profiling_results_dialog import view_profiling_button
from testgen.ui.views.test_definitions import show_test_form_by_id
@@ -61,9 +59,9 @@ def render(
self,
run_id: str,
status: str | None = None,
- test_type: str | None = None,
table_name: str | None = None,
column_name: str | None = None,
+ test_type: str | None = None,
action: str | None = None,
**_kwargs,
) -> None:
@@ -92,9 +90,16 @@ def render(
[.175, .2, .2, .175, .15, .1], vertical_alignment="bottom"
)
- testgen.flex_row_end(actions_column)
+ testgen.flex_row_end(actions_column, wrap=True)
testgen.flex_row_end(export_button_column)
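+ # Track the active filters across reruns so the grid resets pagination when they change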
+ filters_changed = False
+ current_filters = (status, table_name, column_name, test_type, action)
+ if (query_filters := st.session_state.get("test_results:filters")) != current_filters:
+ if query_filters:
+ filters_changed = True
+ st.session_state["test_results:filters"] = current_filters
+
with summary_column:
tests_summary = get_test_result_summary(run_id)
testgen.summary_bar(items=tests_summary, height=20, width=800)
@@ -175,8 +180,10 @@ def render(
sorting_columns = testgen.sorting_selector(sortable_columns, default)
with actions_column:
- str_help = "Toggle on to perform actions on multiple results"
- do_multi_select = st.toggle("Multi-Select", help=str_help)
+ multi_select = st.toggle(
+ "Multi-Select",
+ help="Toggle on to perform actions on multiple results",
+ )
match status:
case None:
@@ -186,22 +193,80 @@ def render(
case _:
status = [status]
- # Display main grid and retrieve selection
- selected = show_result_detail(
- run_id,
- run_date,
- run.test_suite_id,
- export_button_column,
- session.auth.user_has_permission("edit"),
- status,
- test_type,
- table_name,
- column_name,
- action,
- sorting_columns,
- do_multi_select,
+ with st.container():
+ with st.spinner("Loading data ..."):
+ # Retrieve test results (always cached, action as null)
+ df = test_result_queries.get_test_results(
+ run_id, status, test_type, table_name, column_name, action, sorting_columns
+ )
+ # Retrieve disposition action (cache refreshed)
+ df_action = get_test_disposition(run_id)
+ # Update action from disposition df
+ action_map = df_action.set_index("id")["action"].to_dict()
+ df["action"] = df["test_result_id"].map(action_map).fillna(df["action"])
+
+ test_suite = TestSuite.get_minimal(run.test_suite_id)
+ table_group = TableGroup.get_minimal(test_suite.table_groups_id)
+
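+ # Display main grid and retrieve selection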
+ selected, selected_row = fm.render_grid_select(
+ df,
+ [
+ "table_name",
+ "column_names",
+ "test_name_short",
+ "result_measure",
+ "measure_uom",
+ "result_status",
+ "action",
+ "result_message",
+ ],
+ [
+ "Table",
+ "Columns/Focus",
+ "Test Type",
+ "Result Measure",
+ "Unit of Measure",
+ "Status",
+ "Action",
+ "Details",
+ ],
+ id_column="test_result_id",
+ selection_mode="multiple" if multi_select else "single",
+ reset_pagination=filters_changed,
+ bind_to_query=True,
)
+ popover_container = export_button_column.empty()
+
+ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
+ # Hack to programmatically close popover: https://github.com/streamlit/streamlit/issues/8265#issuecomment-3001655849
+ with popover_container.container():
+ flex_row_end()
+ st.button(label="Export", icon=":material/download:", disabled=True)
+
+ download_dialog(
+ dialog_title="Download Excel Report",
+ file_content_func=get_excel_report_data,
+ args=(test_suite.test_suite, table_group.table_group_schema, run_date, run_id, data),
+ )
+
+ with popover_container.container(key="tg--export-popover"):
+ flex_row_end()
+ with st.popover(label="Export", icon=":material/download:", help="Download test results to Excel"):
+ css_class("tg--export-wrapper")
+ st.button(label="All tests", type="tertiary", on_click=open_download_dialog)
+ st.button(label="Filtered tests", type="tertiary", on_click=partial(open_download_dialog, df))
+ if selected:
+ st.button(
+ label="Selected tests",
+ type="tertiary",
+ on_click=partial(open_download_dialog, pd.DataFrame(selected)),
+ )
+
# Need to render toolbar buttons after grid, so selection status is maintained
affected_cached_functions = [get_test_disposition, test_result_queries.get_test_results]
@@ -238,10 +303,17 @@ def render(
with score_column:
render_score(run.project_code, run_id)
+ if selected:
+ render_selected_details(
+ selected,
+ selected_row,
+ test_suite,
+ session.auth.user_has_permission("edit"),
+ multi_select,
+ )
+
# Help Links
- st.markdown(
- "[Help on Test Types](https://docs.datakitchen.io/article/dataops-testgen-help/testgen-test-types)"
- )
+ st.markdown("[Help on Test Types](https://docs.datakitchen.io/article/dataops-testgen-help/testgen-test-types)")
@st.fragment
@@ -366,7 +438,7 @@ def get_test_result_summary(test_run_id: str) -> list[dict]:
]
-def show_test_def_detail(test_definition_id: str, test_suite: TestSuite):
+def show_test_def_detail(test_definition_id: str, test_suite: TestSuiteMinimal):
def readable_boolean(v: bool):
return "Yes" if v else "No"
@@ -431,128 +503,51 @@ def readable_boolean(v: bool):
)
-def show_result_detail(
- run_id: str,
- run_date: str,
- test_suite_id: UUID,
- export_container: DeltaGenerator,
+@with_database_session
+def render_selected_details(
+ selected_rows: list[dict],
+ selected_item: dict,
+ test_suite: TestSuiteMinimal,
user_can_edit: bool,
- test_statuses: list[str] | None = None,
- test_type_id: str | None = None,
- table_name: str | None = None,
- column_name: str | None = None,
- action: typing.Literal["Confirmed", "Dismissed", "Muted", "No Action"] | None = None,
- sorting_columns: list[str] | None = None,
- do_multi_select: bool = False,
-):
- with st.container():
- with st.spinner("Loading data ..."):
- # Retrieve test results (always cached, action as null)
- df = test_result_queries.get_test_results(run_id, test_statuses, test_type_id, table_name, column_name, action, sorting_columns)
- # Retrieve disposition action (cache refreshed)
- df_action = get_test_disposition(run_id)
- # Update action from disposition df
- action_map = df_action.set_index("id")["action"].to_dict()
- df["action"] = df["test_result_id"].map(action_map).fillna(df["action"])
-
- # Update action from disposition df
- action_map = df_action.set_index("id")["action"].to_dict()
- df["action"] = df["test_result_id"].map(action_map).fillna(df["action"])
-
- test_suite = TestSuite.get_minimal(test_suite_id)
- table_group = TableGroup.get_minimal(test_suite.table_groups_id)
-
- lst_show_columns = [
- "table_name",
- "column_names",
- "test_name_short",
- "result_measure",
- "measure_uom",
- "result_status",
- "action",
- "result_message",
- ]
-
- lst_show_headers = [
- "Table",
- "Columns/Focus",
- "Test Type",
- "Result Measure",
- "Unit of Measure",
- "Status",
- "Action",
- "Details",
- ]
-
- selected_rows = fm.render_grid_select(
- df,
- lst_show_columns,
- do_multi_select=do_multi_select,
- show_column_headers=lst_show_headers,
- bind_to_query_name="selected",
- bind_to_query_prop="test_result_id",
- )
-
- popover_container = export_container.empty()
-
- def open_download_dialog(data: pd.DataFrame | None = None) -> None:
- # Hack to programmatically close popover: https://github.com/streamlit/streamlit/issues/8265#issuecomment-3001655849
- with popover_container.container():
- flex_row_end()
- st.button(label="Export", icon=":material/download:", disabled=True)
-
- download_dialog(
- dialog_title="Download Excel Report",
- file_content_func=get_excel_report_data,
- args=(test_suite.test_suite, table_group.table_group_schema, run_date, run_id, data),
- )
-
- with popover_container.container(key="tg--export-popover"):
- flex_row_end()
- with st.popover(label="Export", icon=":material/download:", help="Download test results to Excel"):
- css_class("tg--export-wrapper")
- st.button(label="All tests", type="tertiary", on_click=open_download_dialog)
- st.button(label="Filtered tests", type="tertiary", on_click=partial(open_download_dialog, df))
- if selected_rows:
- st.button(label="Selected tests", type="tertiary", on_click=partial(open_download_dialog, pd.DataFrame(selected_rows)))
-
- # Display history and detail for selected row
+ multi_select: bool = False,
+) -> None:
if not selected_rows:
st.markdown(":orange[Select a record to see more information.]")
else:
- selected_row = selected_rows[0]
- dfh = test_result_queries.get_test_result_history(selected_row)
- show_hist_columns = ["test_date", "threshold_value", "result_measure", "result_status"]
-
- time_columns = ["test_date"]
- date_service.accommodate_dataframe_to_timezone(dfh, st.session_state, time_columns)
-
pg_col1, pg_col2 = st.columns([0.5, 0.5])
with pg_col2:
v_col1, v_col2, v_col3, v_col4 = st.columns([.25, .25, .25, .25])
- if user_can_edit:
- view_edit_test(v_col1, selected_row["test_definition_id_current"])
-
- if selected_row["test_scope"] == "column":
- with v_col2:
- view_profiling_button(
- selected_row["column_names"],
- selected_row["table_name"],
- selected_row["table_groups_id"],
- )
- with v_col3:
- if st.button(
- ":material/visibility: Source Data", help="View current source data for highlighted result",
- use_container_width=True
- ):
- MixpanelService().send_event(
- "view-source-data",
- page=PAGE_PATH,
- test_type=selected_row["test_name_short"],
- )
- source_data_dialog(selected_row)
+ if selected_item:
+ dfh = test_result_queries.get_test_result_history(selected_item)
+ show_hist_columns = ["test_date", "threshold_value", "result_measure", "result_status"]
+
+ time_columns = ["test_date"]
+ date_service.accommodate_dataframe_to_timezone(dfh, st.session_state, time_columns)
+
+ if user_can_edit:
+ view_edit_test(v_col1, selected_item["test_definition_id_current"])
+
+ if selected_item["test_scope"] == "column":
+ with v_col2:
+ view_profiling_button(
+ selected_item["column_names"],
+ selected_item["table_name"],
+ selected_item["table_groups_id"],
+ )
+
+ with v_col3:
+ if st.button(
+ ":material/visibility: Source Data", help="View current source data for highlighted result",
+ use_container_width=True
+ ):
+ MixpanelService().send_event(
+ "view-source-data",
+ page=PAGE_PATH,
+ test_type=selected_item["test_name_short"],
+ )
+ source_data_dialog(selected_item)
with v_col4:
@@ -561,7 +556,7 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
if row["result_status"] != "Passed" and row["disposition"] in (None, "Confirmed")
]
- if do_multi_select:
+ if multi_select:
report_btn_help = (
"Generate PDF reports for the selected results that are not muted or dismissed and are not Passed"
)
@@ -594,21 +589,24 @@ def open_download_dialog(data: pd.DataFrame | None = None) -> None:
)
download_dialog(dialog_title=dialog_title, file_content_func=zip_func)
- with pg_col1:
- fm.show_subheader(selected_row["test_name_short"])
- st.markdown(f"###### {selected_row['test_description']}")
- st.caption(empty_if_null(selected_row["measure_uom_description"]))
- fm.render_grid_select(dfh, show_hist_columns, selection_mode="disabled")
- with pg_col2:
- ut_tab1, ut_tab2 = st.tabs(["History", "Test Definition"])
- with ut_tab1:
- if dfh.empty:
- st.write("Test history not available.")
- else:
- write_history_graph(dfh)
- with ut_tab2:
- show_test_def_detail(selected_row["test_definition_id_current"], test_suite)
- return selected_rows
+ if selected_item:
+ with pg_col1:
+ fm.show_subheader(selected_item["test_name_short"])
+ st.markdown(f"###### {selected_item['test_description']}")
+ if selected_item["measure_uom_description"]:
+ st.caption(selected_item["measure_uom_description"])
+ if selected_item["result_message"]:
+ st.caption(selected_item["result_message"])
+ fm.render_grid_select(dfh, show_hist_columns, selection_mode="disabled", key="test_history")
+ with pg_col2:
+ ut_tab1, ut_tab2 = st.tabs(["History", "Test Definition"])
+ with ut_tab1:
+ if dfh.empty:
+ st.write("Test history not available.")
+ else:
+ write_history_graph(dfh)
+ with ut_tab2:
+ show_test_def_detail(selected_item["test_definition_id_current"], test_suite)
@with_database_session
@@ -805,7 +803,7 @@ def source_data_dialog(selected_row):
st.caption(selected_row["test_description"])
st.markdown("#### Test Parameters")
- testgen.caption(selected_row["input_parameters"], styles="max-height: 100px; overflow: auto;")
+ testgen.caption(selected_row["input_parameters"], styles="max-height: 75px; overflow: auto;")
st.markdown("#### Result Detail")
st.caption(selected_row["result_message"])
@@ -816,7 +814,7 @@ def source_data_dialog(selected_row):
else:
query = get_test_issue_source_query(selected_row)
if query:
- st.code(query, language="sql")
+ st.code(query, language="sql", height=100)
with st.spinner("Retrieving source data..."):
if selected_row["test_type"] == "CUSTOM":
@@ -838,7 +836,7 @@ def source_data_dialog(selected_row):
if len(df_bad) == 500:
testgen.caption("* Top 500 records displayed", "text-align: right;")
# Display the dataframe
- st.dataframe(df_bad, height=500, width=1050, hide_index=True)
+ st.dataframe(df_bad, width=1050, hide_index=True)
def view_edit_test(button_container, test_definition_id):
diff --git a/testgen/ui/views/test_runs.py b/testgen/ui/views/test_runs.py
index 5e4f594b..33cf379d 100644
--- a/testgen/ui/views/test_runs.py
+++ b/testgen/ui/views/test_runs.py
@@ -1,4 +1,3 @@
-import json
import logging
import typing
from collections.abc import Iterable
@@ -11,19 +10,19 @@
from testgen.common.models import with_database_session
from testgen.common.models.project import Project
from testgen.common.models.scheduler import RUN_TESTS_JOB_KEY
-from testgen.common.models.table_group import TableGroup, TableGroupMinimal
+from testgen.common.models.table_group import TableGroup
from testgen.common.models.test_run import TestRun
from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal
from testgen.ui.components import widgets as testgen
from testgen.ui.components.widgets import testgen_component
from testgen.ui.navigation.menu import MenuItem
from testgen.ui.navigation.page import Page
+from testgen.ui.navigation.router import Router
from testgen.ui.session import session, temp_value
from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog
from testgen.ui.views.dialogs.run_tests_dialog import run_tests_dialog
-from testgen.utils import friendly_score, to_dataframe, to_int
+from testgen.utils import friendly_score, to_int
-PAGE_SIZE = 50
PAGE_ICON = "labs"
PAGE_TITLE = "Test Runs"
LOG = logging.getLogger("testgen")
@@ -48,93 +47,64 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit
"test-results",
)
- user_can_run = session.auth.user_has_permission("edit")
- if render_empty_state(project_code, user_can_run):
- return
-
- group_filter_column, suite_filter_column, actions_column = st.columns([.3, .3, .4], vertical_alignment="bottom")
-
- with group_filter_column:
+ with st.spinner("Loading data ..."):
+ project_summary = Project.get_summary(project_code)
+ test_runs = TestRun.select_summary(project_code, table_group_id, test_suite_id)
table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code)
- table_groups_df = to_dataframe(table_groups, TableGroupMinimal.columns())
- table_groups_df["id"] = table_groups_df["id"].apply(lambda x: str(x))
- table_group_id = testgen.select(
- options=table_groups_df,
- value_column="id",
- display_column="table_groups_name",
- default_value=table_group_id,
- bind_to_query="table_group_id",
- label="Table Group",
- placeholder="---",
- )
-
- with suite_filter_column:
- clauses = [TestSuite.project_code == project_code]
- if table_group_id:
- clauses.append(TestSuite.table_groups_id == table_group_id)
- test_suites = TestSuite.select_where(*clauses)
- test_suites_df = to_dataframe(test_suites, TestSuite.columns())
- test_suites_df["id"] = test_suites_df["id"].apply(lambda x: str(x))
- test_suite_id = testgen.select(
- options=test_suites_df,
- value_column="id",
- display_column="test_suite",
- default_value=test_suite_id,
- bind_to_query="test_suite_id",
- label="Test Suite",
- placeholder="---",
- )
-
- with actions_column:
- testgen.flex_row_end(actions_column)
-
- st.button(
- ":material/today: Test Run Schedules",
- help="Manage when test suites should run",
- on_click=partial(TestRunScheduleDialog().open, project_code)
- )
-
- if user_can_run:
- st.button(
- ":material/play_arrow: Run Tests",
- help="Run tests for a test suite",
- on_click=partial(run_tests_dialog, project_code, None, test_suite_id)
- )
+ test_suites = TestSuite.select_minimal_where(TestSuite.project_code == project_code)
+
+ testgen_component(
+ "test_runs",
+ props={
+ "project_summary": project_summary.to_dict(json_safe=True),
+ "test_runs": [
+ {
+ **run.to_dict(json_safe=True),
+ "dq_score_testing": friendly_score(run.dq_score_testing),
+ } for run in test_runs
+ ],
+ "table_group_options": [
+ {
+ "value": str(table_group.id),
+ "label": table_group.table_groups_name,
+ "selected": str(table_group_id) == str(table_group.id),
+ } for table_group in table_groups
+ ],
+ "test_suite_options": [
+ {
+ "value": str(test_suite.id),
+ "label": test_suite.test_suite,
+ "selected": str(test_suite_id) == str(test_suite.id),
+ } for test_suite in test_suites
+ if not table_group_id or str(table_group_id) == str(test_suite.table_groups_id)
+ ],
+ "permissions": {
+ "can_edit": session.auth.user_has_permission("edit"),
+ },
+ },
+ on_change_handlers={
+ "FilterApplied": on_test_runs_filtered,
+ "RunSchedulesClicked": lambda *_: TestRunScheduleDialog().open(project_code),
+ "RunTestsClicked": lambda *_: run_tests_dialog(project_code, None, test_suite_id),
+ "RefreshData": refresh_data,
+ "RunsDeleted": partial(on_delete_runs, project_code, table_group_id, test_suite_id),
+ },
+ event_handlers={
+ "RunCanceled": on_cancel_run,
+ },
+ )
- fm.render_refresh_button(actions_column)
- testgen.whitespace(0.5)
- list_container = st.container()
+class TestRunFilters(typing.TypedDict):
+ table_group_id: str
+ test_suite_id: str
- with st.spinner("Loading data ..."):
- test_runs = TestRun.select_summary(project_code, table_group_id, test_suite_id)
+def on_test_runs_filtered(filters: TestRunFilters) -> None:
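+ # Persist the selected filters in the URL query params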
+ Router().set_query_params(filters)
- paginated = []
- if run_count := len(test_runs):
- page_index = testgen.paginator(count=run_count, page_size=PAGE_SIZE)
- test_runs = [
- {
- **row.to_dict(json_safe=True),
- "dq_score_testing": friendly_score(row.dq_score_testing),
- } for row in test_runs
- ]
- paginated = test_runs[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)]
- with list_container:
- testgen_component(
- "test_runs",
- props={
- "items": json.dumps(paginated),
- "permissions": {
- "can_run": user_can_run,
- "can_edit": user_can_run,
- },
- },
- event_handlers={
- "RunCanceled": on_cancel_run,
- "RunsDeleted": partial(on_delete_runs, project_code, table_group_id, test_suite_id),
- }
- )
+def refresh_data(*_) -> None:
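+ # Clear the cached test run summaries so the next render re-queries the database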
+ TestRun.select_summary.clear()
class TestRunScheduleDialog(ScheduleDialog):
@@ -160,56 +130,6 @@ def get_job_arguments(self, arg_value: str) -> tuple[list[typing.Any], dict[str,
return [], {"project_key": self.project_code, "test_suite_key": arg_value}
-def render_empty_state(project_code: str, user_can_run: bool) -> bool:
- project_summary = Project.get_summary(project_code)
- if project_summary.test_run_count:
- return False
-
- label="No test runs yet"
- testgen.whitespace(5)
- if not project_summary.connection_count:
- testgen.empty_state(
- label=label,
- icon=PAGE_ICON,
- message=testgen.EmptyStateMessage.Connection,
- action_label="Go to Connections",
- link_href="connections",
- link_params={ "project_code": project_code },
- )
- elif not project_summary.table_group_count:
- testgen.empty_state(
- label=label,
- icon=PAGE_ICON,
- message=testgen.EmptyStateMessage.TableGroup,
- action_label="Go to Table Groups",
- link_href="table-groups",
- link_params={
- "project_code": project_code,
- "connection_id": str(project_summary.default_connection_id),
- }
- )
- elif not project_summary.test_suite_count or not project_summary.test_definition_count:
- testgen.empty_state(
- label=label,
- icon=PAGE_ICON,
- message=testgen.EmptyStateMessage.TestSuite,
- action_label="Go to Test Suites",
- link_href="test-suites",
- link_params={ "project_code": project_code },
- )
- else:
- testgen.empty_state(
- label=label,
- icon=PAGE_ICON,
- message=testgen.EmptyStateMessage.TestExecution,
- action_label="Run Tests",
- action_disabled=not user_can_run,
- button_onclick=partial(run_tests_dialog, project_code),
- button_icon="play_arrow",
- )
- return True
-
-
def on_cancel_run(test_run: dict) -> None:
process_status, process_message = process_service.kill_test_run(to_int(test_run["process_id"]))
if process_status:
diff --git a/tests/unit/test_profiling_query.py b/tests/unit/test_profiling_query.py
index 6ca71ecc..368fb5b6 100644
--- a/tests/unit/test_profiling_query.py
+++ b/tests/unit/test_profiling_query.py
@@ -18,7 +18,7 @@ def test_include_exclude_mask_basic():
# test assertions
assert "SELECT 'dummy_project_code'" in query
- assert r"""AND (
+ assert r"""AND (
(c.table_name LIKE 'important%' ) OR (c.table_name LIKE '%useful%' )
)""" in query
assert r"""AND NOT (
@@ -63,6 +63,6 @@ def test_include_empty_include_mask(mask):
print(query)
# test assertions
- assert r"""AND (
+ assert r"""AND (
(c.table_name LIKE 'important%' ) OR (c.table_name LIKE '%useful[_]%' )
)""" in query