Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
cf1bb59
Merge branch 'main' into 'enterprise'
aarthy-dk Aug 25, 2025
18a0d63
fix(analytics): error in username
aarthy-dk Aug 27, 2025
f3d40cb
fix(test suites): allow spaces in name
aarthy-dk Aug 27, 2025
7bd44ef
fix(exports): add schema header detail and remove column
aarthy-dk Aug 27, 2025
c05c344
fix: apply case insensitive sorting
aarthy-dk Aug 27, 2025
f2fbce6
fix(application logs): make search case insensitive
aarthy-dk Aug 27, 2025
b6733d1
Merge branch 'aarthy/case-insensitive-sort' into 'enterprise'
Aug 28, 2025
7859cb5
feat(schedules): add cron expression editor component
luis-dk Aug 27, 2025
626e7eb
Merge branch 'cron-editor' into 'enterprise'
Sep 2, 2025
3712b8b
fix(connections): set snowflake account in connection url
luis-dk Aug 29, 2025
1702008
refactor: validate required fields for connections
luis-dk Sep 2, 2025
68b0dc2
feat(connections): add warehouse field to snowflake form
luis-dk Sep 2, 2025
8661083
refactor(connections): stop creating an initial connection
luis-dk Sep 2, 2025
ce5122c
fix(connections): enforce required url and private key
luis-dk Sep 3, 2025
0ff041a
style: fix linting errors
aarthy-dk Sep 4, 2025
0c9c6a2
Merge branch 'snowflake-account-fix' into 'enterprise'
Sep 4, 2025
94b67ab
fix: incorrect test counts in generate tests warning
aarthy-dk Sep 4, 2025
6b004cb
feat: display sql query in source data dialogs
aarthy-dk Sep 4, 2025
5c66fa0
Merge branch 'fix-tests' into 'enterprise'
Sep 10, 2025
af7dd51
fix: Fixing LOV_All input
rboni-dk Sep 10, 2025
7b379d9
Merge branch 'fix-lov-all-ui' into 'enterprise'
Sep 10, 2025
9cb87e8
fix: append snowflake computing domain when missing
luis-dk Sep 11, 2025
2e3895a
Merge branch 'snowflake-retrocompt-fix' into 'enterprise'
Sep 11, 2025
646f38f
fix(table freshness): filter general type in fingerprint
aarthy-dk Aug 26, 2025
a893d2c
fix: prevent STDEV overflow error
cbloche Sep 10, 2025
0ffba6d
fix(general types): update schema ddf queries
aarthy-dk Sep 10, 2025
80d630b
Merge branch 'aarthy/fix-fingerprint' into 'enterprise'
Sep 11, 2025
b30bc62
misc: Allow re-using the existing users and roles
rboni-dk Sep 10, 2025
1e4c453
Merge branch 'allow-test-and-ui' into 'enterprise'
Sep 11, 2025
97fb07a
feat(tests): generate a monitor suite for new table groups
luis-dk Sep 5, 2025
745a709
Merge branch 'monitor-test-suite-1' into 'enterprise'
Sep 12, 2025
e41338a
feat(ui): allow filtering timezone dropdown
luis-dk Sep 9, 2025
e27a48d
Merge branch 'filter-select-dropdown' into 'enterprise'
Sep 12, 2025
83ece03
Fix: Make authentication case insensitive
diogodk Sep 15, 2025
4fe163a
Merge branch 'fix_auth_case_sensitive_bug' into 'enterprise'
Sep 15, 2025
9c78277
fix(profiling): increase char limit when calculating pattern
aarthy-dk Sep 12, 2025
f44230c
fix(mssql): update is_date function
aarthy-dk Sep 12, 2025
ac2567c
Merge branch 'aarthy/sql-fixes' into 'enterprise'
aarthy-dk Sep 15, 2025
b3f1c90
fix(copy tests): bugs when copying multiple and table tests
aarthy-dk Sep 16, 2025
4dff068
fix(hygiene issues): sort by likelihood first
aarthy-dk Sep 16, 2025
71587c3
Merge branch 'copy-fix' into 'enterprise'
Sep 16, 2025
9de63e6
release: 4.22.2 -> 4.26.1
aarthy-dk Sep 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "dataops-testgen"
version = "4.22.2"
version = "4.26.1"
description = "DataKitchen's Data Quality DataOps TestGen"
authors = [
{ "name" = "DataKitchen, Inc.", "email" = "info@datakitchen.io" },
Expand Down Expand Up @@ -58,6 +58,7 @@ dependencies = [
"pydantic==1.10.13",
"streamlit-pydantic==0.6.0",
"cron-converter==1.2.1",
"cron-descriptor==2.0.5",

# Pinned to match the manually compiled libs or for security
"pyarrow==18.1.0",
Expand Down
5 changes: 4 additions & 1 deletion testgen/commands/run_execute_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,12 @@ def run_execution_steps_in_background(project_code, test_suite):
if settings.IS_DEBUG:
LOG.info(msg + ". Running in debug mode (new thread instead of new process).")
empty_cache()
username = None
if session.auth:
username = session.auth.user_display
background_thread = threading.Thread(
target=run_execution_steps,
args=(project_code, test_suite, session.auth.user_display),
args=(project_code, test_suite, username),
)
background_thread.start()
else:
Expand Down
10 changes: 7 additions & 3 deletions testgen/commands/run_launch_db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def _get_params_mapping() -> dict:


@with_database_session
def run_launch_db_config(delete_db: bool) -> None:
def run_launch_db_config(delete_db: bool, drop_users_and_roles: bool = True) -> None:
params_mapping = _get_params_mapping()

create_database(get_tg_db(), params_mapping, drop_existing=delete_db, drop_users_and_roles=True)
create_database(get_tg_db(), params_mapping, drop_existing=delete_db, drop_users_and_roles=drop_users_and_roles)

queries = get_queries_for_command("dbsetup", params_mapping)

Expand All @@ -91,4 +91,8 @@ def run_launch_db_config(delete_db: bool) -> None:
project_code=settings.PROJECT_KEY,
table_groups_name=settings.DEFAULT_TABLE_GROUPS_NAME,
)
).save()
).save()


def get_app_db_params_mapping() -> dict:
return _get_params_mapping()
25 changes: 24 additions & 1 deletion testgen/commands/run_profiling_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import testgen.common.process_service as process_service
from testgen import settings
from testgen.commands.queries.profiling_query import CProfilingSQL
from testgen.commands.run_execute_tests import run_execution_steps_in_background
from testgen.commands.run_generate_tests import run_test_gen_queries
from testgen.commands.run_refresh_score_cards_results import run_refresh_score_cards_results
from testgen.common import (
date_service,
Expand All @@ -25,6 +27,7 @@
from testgen.common.mixpanel_service import MixpanelService
from testgen.common.models import with_database_session
from testgen.common.models.connection import Connection
from testgen.common.models.test_suite import TestSuite
from testgen.ui.session import session

LOG = logging.getLogger("testgen")
Expand Down Expand Up @@ -211,7 +214,7 @@ def run_profiling_in_background(table_group_id):
empty_cache()
background_thread = threading.Thread(
target=run_profiling_queries,
args=(table_group_id, session.auth.user_display),
args=(table_group_id, session.auth.user_display if session.auth else None),
)
background_thread.start()
else:
Expand All @@ -238,6 +241,9 @@ def run_profiling_queries(table_group_id: str, username: str | None = None, spin
profiling_run_id = str(uuid.uuid4())

params = get_profiling_params(table_group_id)
needs_monitor_tests_generated = (
bool(params["monitor_test_suite_id"]) and not params["last_complete_profile_run_id"]
)

LOG.info("CurrentStep: Initializing Query Generator")
clsProfiling = CProfilingSQL(params["project_code"], connection.sql_flavor, minutes_offset=minutes_offset)
Expand Down Expand Up @@ -471,7 +477,24 @@ def run_profiling_queries(table_group_id: str, username: str | None = None, spin
scoring_duration=(datetime.now(UTC) - end_time).total_seconds(),
)

if needs_monitor_tests_generated:
_generate_monitor_tests(params["project_code"], table_group_id, params["monitor_test_suite_id"])

return f"""
Profiling completed {"with errors. Check log for details." if has_errors else "successfully."}
Run ID: {profiling_run_id}
"""


@with_database_session
def _generate_monitor_tests(project_code: str, table_group_id: str, test_suite_id: str) -> None:
try:
monitor_test_suite = TestSuite.get(test_suite_id)
if not monitor_test_suite:
LOG.info("Skipping test generation on missing monitor test suite")
else:
LOG.info("Generating monitor tests")
run_test_gen_queries(table_group_id, monitor_test_suite.test_suite, "Monitor")
run_execution_steps_in_background(project_code, monitor_test_suite.test_suite)
except Exception:
LOG.exception("Error generating monitor tests")
10 changes: 9 additions & 1 deletion testgen/commands/run_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import click

from testgen import settings
from testgen.commands.run_launch_db_config import run_launch_db_config
from testgen.commands.run_launch_db_config import get_app_db_params_mapping, run_launch_db_config
from testgen.common.credentials import get_tg_schema
from testgen.common.database.database_service import (
create_database,
Expand Down Expand Up @@ -117,6 +117,14 @@ def run_quick_start(delete_target_db: bool) -> None:
delete_db = True
run_launch_db_config(delete_db)

click.echo("Seeding the application db")
app_db_params = get_app_db_params_mapping()
execute_db_queries(
[
(replace_params(read_template_sql_file("initial_data_seeding.sql", "quick_start"), app_db_params), app_db_params),
],
)

# Schema and Populate target db
click.echo(f"Populating target db : {target_db_name}")
execute_db_queries(
Expand Down
2 changes: 1 addition & 1 deletion testgen/commands/run_refresh_score_cards_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def run_refresh_score_cards_results(
history_entry.add_as_cutoff()
definition.save()
LOG.info(
"CurrentStep: Done rereshing scorecard %s in project %s",
"CurrentStep: Done refreshing scorecard %s in project %s",
definition.name,
definition.project_code,
)
Expand Down
61 changes: 53 additions & 8 deletions testgen/common/database/flavor/flavor_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import Literal, TypedDict
from typing import Any, Literal, TypedDict
from urllib.parse import parse_qs, urlparse

from testgen.common.encrypt import DecryptText

Expand Down Expand Up @@ -37,19 +38,21 @@ class FlavorService:
private_key_passphrase = None
http_path = None
catalog = None
warehouse = None

def init(self, connection_params: ConnectionParams):
self.url = connection_params.get("url", None)
self.url = connection_params.get("url") or ""
self.connect_by_url = connection_params.get("connect_by_url", False)
self.username = connection_params.get("project_user")
self.host = connection_params.get("project_host")
self.port = connection_params.get("project_port")
self.dbname = connection_params.get("project_db")
self.username = connection_params.get("project_user") or ""
self.host = connection_params.get("project_host") or ""
self.port = connection_params.get("project_port") or ""
self.dbname = connection_params.get("project_db") or ""
self.flavor = connection_params.get("sql_flavor")
self.dbschema = connection_params.get("table_group_schema", None)
self.connect_by_key = connection_params.get("connect_by_key", False)
self.http_path = connection_params.get("http_path", None)
self.catalog = connection_params.get("catalog", None)
self.http_path = connection_params.get("http_path") or ""
self.catalog = connection_params.get("catalog") or ""
self.warehouse = connection_params.get("warehouse") or ""

password = connection_params.get("project_pw_encrypted", None)
if isinstance(password, memoryview) or isinstance(password, bytes):
Expand Down Expand Up @@ -90,3 +93,45 @@ def get_connection_string_from_fields(self) -> str:
@abstractmethod
def get_connection_string_head(self) -> str:
raise NotImplementedError("Subclasses must implement this method")

def get_parts_from_connection_string(self) -> dict[str, Any]:
if self.connect_by_url:
if not self.url:
return {}

parsed_url = urlparse(self.get_connection_string())
credentials, location = (
parsed_url.netloc if "@" in parsed_url.netloc else f"@{parsed_url.netloc}"
).split("@")
username, password = (
credentials if ":" in credentials else f"{credentials}:"
).split(":")
host, port = (
location if ":" in location else f"{location}:"
).split(":")

database = (path_patrs[0] if (path_patrs := parsed_url.path.strip("/").split("/")) else "")

extras = {
param_name: param_values[0]
for param_name, param_values in parse_qs(parsed_url.query or "").items()
}

return {
"username": username,
"password": password,
"host": host,
"port": port,
"dbname": database,
**extras,
}

return {
"username": self.username,
"password": self.password,
"host": self.host,
"port": self.port,
"dbname": self.dbname,
"http_path": self.http_path,
"catalog": self.catalog,
}
37 changes: 20 additions & 17 deletions testgen/common/database/flavor/snowflake_flavor_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from snowflake.sqlalchemy import URL

from testgen.common.database.flavor.flavor_service import FlavorService

Expand Down Expand Up @@ -38,25 +39,27 @@ def get_connection_string_from_fields(self):
# optionally + '/[schema]' + '?warehouse=xxx'
# NOTE: Snowflake host should NOT include ".snowflakecomputing.com"

def get_raw_host_name(host):
endings = [
".snowflakecomputing.com",
]
for ending in endings:
if host.endswith(ending):
i = host.index(ending)
return host[0:i]
return host
account, _ = self.host.split(".", maxsplit=1) if "." in self.host else ("", "")
host = self.host
if ".snowflakecomputing.com" not in host:
host = f"{host}.snowflakecomputing.com"

raw_host = get_raw_host_name(self.host)
host = raw_host
if self.port != "443":
host += ":" + self.port
extra_params = {}
if self.warehouse:
extra_params["warehouse"] = self.warehouse

if self.connect_by_key:
return f"snowflake://{self.username}@{host}/{self.dbname}/{self.dbschema}"
else:
return f"snowflake://{self.username}:{quote_plus(self.password)}@{host}/{self.dbname}/{self.dbschema}"
connection_url = URL(
host=host,
port=int(self.port if str(self.port).isdigit() else 443),
account=account,
user=self.username,
password="" if self.connect_by_key else self.password,
database=self.dbname,
schema=self.dbschema or "",
**extra_params,
)

return connection_url

def get_pre_connection_queries(self):
return [
Expand Down
2 changes: 2 additions & 0 deletions testgen/common/get_pipeline_parms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class ProfilingParams(BaseParams):
profile_sample_min_count: int
profile_do_pair_rules: str
profile_pair_rule_pct: int
monitor_test_suite_id: str | None
last_complete_profile_run_id: str | None


class TestGenerationParams(BaseParams):
Expand Down
2 changes: 1 addition & 1 deletion testgen/common/mixpanel_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def send_event(self, event_name, include_usage=False, **properties):
properties.setdefault("instance_id", self.instance_id)
properties.setdefault("edition", settings.DOCKER_HUB_REPOSITORY)
properties.setdefault("version", settings.VERSION)
properties.setdefault("username", session.auth.user_display)
properties.setdefault("username", session.auth.user_display if session.auth else None)
properties.setdefault("distinct_id", self.get_distinct_id(properties["username"]))
if include_usage:
properties.update(self.get_usage())
Expand Down
23 changes: 14 additions & 9 deletions testgen/common/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
Integer,
String,
asc,
func,
select,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import InstrumentedAttribute

from testgen.common.database.database_service import get_flavor_service
from testgen.common.database.flavor.flavor_service import SQLFlavor
from testgen.common.models import get_current_session
from testgen.common.models.custom_types import EncryptedBytea
Expand Down Expand Up @@ -58,9 +60,10 @@ class Connection(Entity):
private_key: str = Column(EncryptedBytea)
private_key_passphrase: str = Column(EncryptedBytea)
http_path: str = Column(String)
warehouse: str = Column(String)

_get_by = "connection_id"
_default_order_by = (asc(connection_name),)
_default_order_by = (asc(func.lower(connection_name)),)
_minimal_columns = ConnectionMinimal.__annotations__.keys()

@classmethod
Expand Down Expand Up @@ -114,13 +117,15 @@ def clear_cache(cls) -> bool:

def save(self) -> None:
if self.connect_by_url and self.url:
url_sections = self.url.split("/")
if url_sections:
host_port = url_sections[0]
host_port_sections = host_port.split(":")
self.project_host = host_port_sections[0] if host_port_sections else host_port
self.project_port = "".join(host_port_sections[1:]) if host_port_sections else ""
if len(url_sections) > 1:
self.project_db = url_sections[1]
flavor_service = get_flavor_service(self.sql_flavor)
flavor_service.init(self.to_dict())

connection_parts = flavor_service.get_parts_from_connection_string()
if connection_parts:
self.project_host = connection_parts["host"]
self.project_port = connection_parts["port"]
self.project_db = connection_parts["dbname"]
self.http_path = connection_parts.get("http_path") or None
self.warehouse = connection_parts.get("warehouse") or None

super().save()
2 changes: 2 additions & 0 deletions testgen/common/models/profiling_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ProfilingRunMinimal(EntityMinimal):
project_code: str
table_groups_id: UUID
table_groups_name: str
table_group_schema: str
profiling_starttime: datetime
dq_score_profiling: float
is_latest_run: bool
Expand Down Expand Up @@ -81,6 +82,7 @@ class ProfilingRun(Entity):
project_code,
table_groups_id,
TableGroup.table_groups_name,
TableGroup.table_group_schema,
profiling_starttime,
dq_score_profiling,
case(
Expand Down
4 changes: 2 additions & 2 deletions testgen/common/models/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID, uuid4

import streamlit as st
from sqlalchemy import Column, String, asc, text
from sqlalchemy import Column, String, asc, func, text
from sqlalchemy.dialects import postgresql

from testgen.common.models import get_current_session
Expand Down Expand Up @@ -34,7 +34,7 @@ class Project(Entity):
observability_api_key: str = Column(NullIfEmptyString)

_get_by = "project_code"
_default_order_by = (asc(project_name),)
_default_order_by = (asc(func.lower(project_name)),)

@classmethod
@st.cache_data(show_spinner=False)
Expand Down
Loading