diff --git a/.gitignore b/.gitignore index 139decb6..62b713e7 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ cython_debug/ **/jsconfig.json .pytest_cache/ .ruff_cache/ + +testgen/ui/components/frontend/js/plugins.js +testgen/ui/components/frontend/js/plugin_pages/ diff --git a/deploy/testgen.dockerfile b/deploy/testgen.dockerfile index f8ba88fd..318a3add 100644 --- a/deploy/testgen.dockerfile +++ b/deploy/testgen.dockerfile @@ -20,7 +20,7 @@ RUN addgroup -S testgen && adduser -S testgen -G testgen # Streamlit has to be able to write to these dirs RUN mkdir /var/lib/testgen -RUN chown -R testgen:testgen /var/lib/testgen /dk/lib/python3.12/site-packages/streamlit/static +RUN chown -R testgen:testgen /var/lib/testgen /dk/lib/python3.12/site-packages/streamlit/static /dk/lib/python3.12/site-packages/testgen/ui/components/frontend ENV TESTGEN_VERSION=${TESTGEN_VERSION} ENV TESTGEN_DOCKER_HUB_REPO=${TESTGEN_DOCKER_HUB_REPO} diff --git a/pyproject.toml b/pyproject.toml index b0f36486..9b0879ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataops-testgen" -version = "4.1.2" +version = "4.1.3" description = "DataKitchen's Data Quality DataOps TestGen" authors = [ { "name" = "DataKitchen, Inc.", "email" = "info@datakitchen.io" }, diff --git a/testgen/common/mixpanel_service.py b/testgen/common/mixpanel_service.py index 77396eb9..dd6608b3 100644 --- a/testgen/common/mixpanel_service.py +++ b/testgen/common/mixpanel_service.py @@ -45,8 +45,10 @@ def _hash_value(self, value: bytes | str, digest_size: int = 8) -> str: @safe_method def send_event(self, event_name, **properties): properties.setdefault("instance_id", self.instance_id) + properties.setdefault("edition", settings.DOCKER_HUB_REPOSITORY) properties.setdefault("version", settings.VERSION) properties.setdefault("distinct_id", self.distinct_id) + properties.setdefault("username", session.username) track_payload = { "event": event_name, diff --git a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql index 6b46af6d..897cb7c5 100644 --- a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql +++ b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql @@ -65,6 +65,7 @@ CREATE TABLE connections ( CONSTRAINT connections_connection_id_pk PRIMARY KEY, sql_flavor VARCHAR(30), + sql_flavor_code VARCHAR(30), project_host VARCHAR(250), project_port VARCHAR(5), project_user VARCHAR(50), diff --git a/testgen/template/dbsetup/040_populate_new_schema_project.sql b/testgen/template/dbsetup/040_populate_new_schema_project.sql index 7e6672c2..84d4d961 100644 --- a/testgen/template/dbsetup/040_populate_new_schema_project.sql +++ b/testgen/template/dbsetup/040_populate_new_schema_project.sql @@ -9,11 +9,12 @@ SELECT '{PROJECT_CODE}' as project_code, '{OBSERVABILITY_API_URL}' as observability_api_url; INSERT INTO connections -(project_code, sql_flavor, +(project_code, sql_flavor, sql_flavor_code, project_host, project_port, project_user, project_db, connection_name, project_pw_encrypted, http_path, max_threads, max_query_chars) SELECT '{PROJECT_CODE}' as project_code, '{SQL_FLAVOR}' as sql_flavor, + '{SQL_FLAVOR}' as sql_flavor_code, '{PROJECT_HOST}' as project_host, '{PROJECT_PORT}' as project_port, '{PROJECT_USER}' as project_user, diff --git a/testgen/template/dbupgrade/0140_incremental_upgrade.sql b/testgen/template/dbupgrade/0140_incremental_upgrade.sql new file mode 100644 index 00000000..c2a1770f --- /dev/null +++ b/testgen/template/dbupgrade/0140_incremental_upgrade.sql @@ -0,0 +1,3 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +ALTER TABLE connections ADD COLUMN sql_flavor_code VARCHAR(30) DEFAULT NULL; diff --git a/testgen/ui/bootstrap.py b/testgen/ui/bootstrap.py index 4d6efec5..6b0fed7a 100644 --- a/testgen/ui/bootstrap.py +++ b/testgen/ui/bootstrap.py @@ -1,14 +1,9 @@ import dataclasses -import importlib -import inspect import logging -import streamlit as st - from testgen import settings from testgen.commands.run_upgrade_db_config import get_schema_revision from testgen.common import configure_logging, version_service -from testgen.ui.assets import get_asset_path from testgen.ui.navigation.menu import Menu, Version from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router @@ -53,19 +48,8 @@ LOG = logging.getLogger("testgen") -class Logo: - image_path: str = get_asset_path("dk_logo.svg") - icon_path: str = get_asset_path("dk_icon.svg") - - def render(self): - st.logo( - image=self.image_path, - icon_image=self.icon_path, - ) - - class Application(singleton.Singleton): - def __init__(self, logo: Logo, router: Router, menu: Menu, logger: logging.Logger) -> None: + def __init__(self, logo: plugins.Logo, router: Router, menu: Menu, logger: logging.Logger) -> None: self.logo = logo self.router = router self.menu = menu @@ -87,20 +71,30 @@ def run(log_level: int = logging.INFO) -> Application: pages = [*BUILTIN_PAGES] installed_plugins = plugins.discover() + if not settings.IS_DEBUG: + """ + This cleanup is called so that TestGen can remove uninstalled + plugins without having to be reinstalled. + + The check for DEBUG mode is because multithreading for Streamlit + fragments loads before the plugins can be re-loaded. + """ + plugins.cleanup() + configure_logging(level=log_level) - logo_class = Logo + logo_class = plugins.Logo for plugin in installed_plugins: - module = importlib.import_module(plugin.package) - for property_name in dir(module): - if ( - (maybe_class := getattr(module, property_name, None)) - and inspect.isclass(maybe_class) - ): - if issubclass(maybe_class, Page): - pages.append(maybe_class) - elif issubclass(maybe_class, Logo): - logo_class = maybe_class + spec = plugin.load() + + if spec.page: + pages.append(spec.page) + + if spec.logo: + logo_class = spec.logo + + if spec.component: + spec.component.provide() return Application( logo=logo_class(), diff --git a/testgen/ui/components/frontend/css/shared.css b/testgen/ui/components/frontend/css/shared.css index 20452b94..643b4ffb 100644 --- a/testgen/ui/components/frontend/css/shared.css +++ b/testgen/ui/components/frontend/css/shared.css @@ -426,6 +426,38 @@ body { margin-left: 40px; } +.p-0 { + padding: 0; +} + +.p-1 { + padding: 4px; +} + +.p-2 { + padding: 8px; +} + +.p-3 { + padding: 12px; +} + +.p-4 { + padding: 16px; +} + +.p-5 { + padding: 24px; +} + +.p-6 { + padding: 32px; +} + +.p-7 { + padding: 40px; +} + .pt-0 { padding-top: 0; } diff --git a/testgen/ui/components/frontend/js/components/alert.js b/testgen/ui/components/frontend/js/components/alert.js index 8b4f6e34..a797d6aa 100644 --- a/testgen/ui/components/frontend/js/components/alert.js +++ b/testgen/ui/components/frontend/js/components/alert.js @@ -1,38 +1,41 @@ /** - * @typedef Alert - * @type {object} - * @property {string} value - * @property {string} color - * @property {string} label - * * @typedef Properties * @type {object} * @property {string?} icon - * @property {'info'|'success'|'error'} type - * @property {string?} message + * @property {number?} timeout + * @property {boolean?} closeable + * @property {'info'|'success'|'warn'|'error'} type */ import van from '../van.min.js'; -import { getValue, loadStylesheet } from '../utils.js'; +import { getValue, loadStylesheet, getRandomId } from '../utils.js'; import { Icon } from './icon.js'; +import { Button } from './button.js'; const { div } = van.tags; const alertTypeColors = { info: {backgroundColor: 'rgba(28, 131, 225, 0.1)', color: 'rgb(0, 66, 128)'}, success: {backgroundColor: 'rgba(33, 195, 84, 0.1)', color: 'rgb(23, 114, 51)'}, + warn: {backgroundColor: 'rgba(255, 227, 18, 0.2)', color: 'rgb(255, 255, 194)'}, error: {backgroundColor: 'rgba(255, 43, 43, 0.09)', color: 'rgb(125, 53, 59)'}, }; const Alert = (/** @type Properties */ props, /** @type Array */ ...children) => { loadStylesheet('alert', stylesheet); + const elementId = getValue(props.id) ?? 'tg-alert-' + getRandomId(); + const close = () => { + document.getElementById(elementId)?.remove(); + }; + const timeout = getValue(props.timeout); + if (timeout && timeout > 0) { + setTimeout(close, timeout); + } + return div( { ...props, - class: () => (getValue(props.class) ?? '') + ` tg-alert flex-row`, - style: () => { - const colors = alertTypeColors[getValue(props.type)]; - return `color: ${colors.color}; background-color: ${colors.backgroundColor};`; - }, + id: elementId, + class: () => `tg-alert flex-row ${getValue(props.class) ?? ''} tg-alert-${getValue(props.type)}`, role: 'alert', }, () => { @@ -43,6 +46,19 @@ const Alert = (/** @type Properties */ props, /** @type Array */ .. {class: 'flex-column'}, ...children, ), + () => { + const isCloseable = getValue(props.closeable) ?? false; + if (!isCloseable) { + return ''; + } + + const colors = alertTypeColors[getValue(props.type)]; + return Button({ + type: 'icon', + icon: 'close', + style: `margin-left: auto; color: ${colors.color};`, + }); + }, ); }; @@ -54,6 +70,34 @@ stylesheet.replace(` font-size: 16px; line-height: 24px; } + +.tg-alert-info { + background-color: rgba(28, 131, 225, 0.1); + color: rgb(0, 66, 128); +} + +.tg-alert-success { + background-color: rgba(33, 195, 84, 0.1); + color: rgb(23, 114, 51); +} + +.tg-alert-error { + background-color: rgba(255, 43, 43, 0.09); + color: rgb(125, 53, 59); +} + +.tg-alert-warn { + background-color: rgba(255, 227, 18, 0.1); + color: rgb(146, 108, 5); +} + +@media (prefers-color-scheme: dark) { + .tg-alert-warn { + background-color: rgba(255, 227, 18, 0.2); + color: rgb(255, 255, 194); + } +} + .tg-alert > .tg-icon { color: inherit !important; } diff --git a/testgen/ui/components/frontend/js/components/button.js b/testgen/ui/components/frontend/js/components/button.js index d90b0034..08b32393 100644 --- a/testgen/ui/components/frontend/js/components/button.js +++ b/testgen/ui/components/frontend/js/components/button.js @@ -214,6 +214,10 @@ button.tg-button.tg-warn-button.tg-stroked-button { color: var(--button-warn-stroked-text-color); background: var(--button-warn-stroked-background); } + +button.tg-button.tg-warn-button[disabled] { + color: rgba(255, 255, 255, .5) !important; +} /* ... */ `); diff --git a/testgen/ui/components/frontend/js/components/checkbox.js b/testgen/ui/components/frontend/js/components/checkbox.js index b2619b9e..75bbe743 100644 --- a/testgen/ui/components/frontend/js/components/checkbox.js +++ b/testgen/ui/components/frontend/js/components/checkbox.js @@ -1,7 +1,9 @@ /** * @typedef Properties * @type {object} + * @property {string?} name * @property {string} label + * @property {string?} help * @property {boolean?} checked * @property {boolean?} indeterminate * @property {function(boolean, Event)?} onChange @@ -10,6 +12,8 @@ */ import van from '../van.min.js'; import { getValue, loadStylesheet } from '../utils.js'; +import { withTooltip } from './tooltip.js'; +import { Icon } from './icon.js'; const { input, label, span } = van.tags; @@ -19,11 +23,12 @@ const Checkbox = (/** @type Properties */ props) => { return label( { class: 'flex-row fx-gap-2 clickable', - 'data-testid': props.testId ?? '', + 'data-testid': props.testId ?? props.name ?? '', style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, }, input({ type: 'checkbox', + name: props.name ?? '', class: 'tg-checkbox--input clickable', checked: props.checked, indeterminate: props.indeterminate, @@ -33,6 +38,12 @@ const Checkbox = (/** @type Properties */ props) => { }), }), span({'data-testid': 'checkbox-label'}, props.label), + () => getValue(props.help) + ? withTooltip( + Icon({ size: 16, classes: 'text-disabled' }, 'help'), + { text: props.help, position: 'top', width: 200 } + ) + : null, ); }; diff --git a/testgen/ui/components/frontend/js/components/connection_form.js b/testgen/ui/components/frontend/js/components/connection_form.js new file mode 100644 index 00000000..5486fd91 --- /dev/null +++ b/testgen/ui/components/frontend/js/components/connection_form.js @@ -0,0 +1,624 @@ +/** + * @import { FileValue } from './file_input.js'; + * + * @typedef Flavor + * @type {object} + * @property {string} label + * @property {string} value + * @property {string} icon + * @property {string} flavor + * @property {string} connection_string + * + * @typedef ConnectionStatus + * @type {object} + * @property {string} message + * @property {boolean} successful + * @property {string?} details + * + * @typedef Connection + * @type {object} + * @property {string} connection_id + * @property {string} connection_name + * @property {string} sql_flavor + * @property {string} sql_flavor_code + * @property {string} project_host + * @property {string} project_port + * @property {string} project_db + * @property {string} project_user + * @property {string} password + * @property {boolean} connect_by_url + * @property {string?} url + * @property {boolean} connect_by_key + * @property {string?} private_key + * @property {string?} private_key_passphrase + * @property {string?} http_path + * @property {ConnectionStatus?} status + * + * @typedef FormState + * @type {object} + * @property {boolean} dirty + * @property {boolean} valid + * + * @typedef FieldsCache + * @type {object} + * @property {FileValue} privateKey + * + * @typedef Properties + * @type {object} + * @property {Connection} connection + * @property {Array.} flavors + * @property {boolean} disableFlavor + * @property {FileValue?} cachedPrivateKeyFile + * @property {(c: Connection, state: FormState, cache?: FieldsCache) => void} onChange + */ +import van from '../van.min.js'; +import { Button } from './button.js'; +import { Alert } from './alert.js'; +import { getValue, emitEvent, loadStylesheet, isEqual } from '../utils.js'; +import { Input } from './input.js'; +import { Slider } from './slider.js'; +import { Checkbox } from './checkbox.js'; +import { Select } from './select.js'; +import { maxLength, minLength, sizeLimit } from '../form_validators.js'; +import { RadioGroup } from './radio_group.js'; +import { FileInput } from './file_input.js'; + +const { div, hr, i, span } = van.tags; +const clearSentinel = ''; +const secretsPlaceholder = ''; +const defaultPorts = { + redshift: '5439', + azure_mssql: '1433', + synapse_mssql: '1433', + mssql: '1433', + postgresql: '5432', + snowflake: '443', + databricks: '443', +}; + +/** + * + * @param {Properties} props + * @param {(any|undefined)} saveButton + * @returns {HTMLElement} + */ +const ConnectionForm = (props, saveButton) => { + loadStylesheet('connectionform', stylesheet); + + const connection = getValue(props.connection); + const isEditMode = !!connection?.connection_id; + const defaultPort = defaultPorts[connection?.sql_flavor]; + + const connectionFlavor = van.state(connection?.sql_flavor_code); + const connectionName = van.state(connection?.connection_name); + const connectionHost = van.state(connection?.project_host); + const connectionPort = van.state(connection?.project_port ?? defaultPort); + const connectionDatabase = van.state(connection?.project_db); + const connectionUsername = van.state(connection?.project_user); + const connectionPassword = van.state(connection?.password); + const connectionMaxThreads = van.state(connection?.max_threads ?? 4); + const connectionQueryChars = van.state(connection?.max_query_chars ?? 9000); + const connectByUrl = van.state(connection?.connect_by_url ?? false); + const connectByKey = van.state(connection?.connect_by_key ?? false); + const privateKey = van.state(connection?.private_key); + const privateKeyPhrase = van.state(connection?.private_key_passphrase); + const httpPath = van.state(connection?.http_path); + + const privateKeyFile = van.state(getValue(props.cachedPrivateKeyFile) ?? null); + van.derive(() => { + const fileInputValue = privateKeyFile.val; + if (fileInputValue?.content) { + privateKey.val = fileInputValue.content.split(',')?.[1] ?? ''; + } + }); + const clearPrivateKeyPhrase = van.state(false); + + if (isEditMode) { + connectionPassword.val = ''; + privateKey.val = ''; + privateKeyPhrase.val = ''; + } + + const flavor = getValue(props.flavors).find(f => f.value === connectionFlavor.val); + const originalURLTemplate = van.state(flavor.connection_string); + const [prefixPart, sufixPart] = originalURLTemplate.val.split('@'); + + const connectionStringPrefix = van.state(prefixPart); + const connectionStringSuffix = van.state(connection?.url ?? ''); + if (!connectionStringSuffix.val) { + connectionStringSuffix.val = formatURL(sufixPart ?? '', connectionHost.val, connectionPort.val, connectionDatabase.val); + } + + const updatedConnection = van.derive(() => { + return { + project_code: connection.project_code, + connection_id: connection.connection_id, + sql_flavor: connection?.sql_flavor ?? undefined, + sql_flavor_code: connectionFlavor.val ?? '', + connection_name: connectionName.val ?? '', + project_host: connectionHost.val ?? '', + project_port: connectionPort.val ?? '', + project_db: connectionDatabase.val ?? '', + project_user: connectionUsername.val ?? '', + password: connectionPassword.val ?? '', + max_threads: connectionMaxThreads.val ?? 4, + max_query_chars: connectionQueryChars.val ?? 9000, + connect_by_url: connectByUrl.val ?? false, + url: connectionStringSuffix.val, + connect_by_key: connectByKey.val ?? false, + private_key: privateKey.val ?? '', + private_key_passphrase: clearPrivateKeyPhrase.val ? clearSentinel : (privateKeyPhrase.val ?? ''), + http_path: httpPath.val ?? '', + }; + }); + const dirty = van.derive(() => !isEqual(updatedConnection.val, connection)); + const validityPerField = van.state({}); + + van.derive(() => { + const fieldsValidity = validityPerField.val; + const isValid = Object.keys(fieldsValidity).length > 0 && + Object.values(fieldsValidity).every(v => v); + props.onChange?.(updatedConnection.val, { dirty: dirty.val, valid: isValid }, { privateKey: privateKeyFile.rawVal }); + }); + + const setFieldValidity = (field, validity) => { + validityPerField.val = {...validityPerField.val, [field]: validity}; + } + + const authenticationForms = { + redshift: () => PasswordConnectionForm( + connection, + connectionPassword, + (value, state) => { + connectionPassword.val = value; + setFieldValidity('password', state.valid); + }, + isEditMode, + ), + mssql: () => PasswordConnectionForm( + connection, + connectionPassword, + (value, state) => { + connectionPassword.val = value; + setFieldValidity('password', state.valid); + }, + isEditMode, + ), + postgresql: () => PasswordConnectionForm( + connection, + connectionPassword, + (value, state) => { + connectionPassword.val = value; + setFieldValidity('password', state.valid); + }, + isEditMode, + ), + snowflake: () => KeyPairConnectionForm( + connection, + connectByKey, + connectionPassword, + privateKeyFile, + privateKeyPhrase, + clearPrivateKeyPhrase, + (value, state) => { + connectByKey.val = value.connect_by_key; + connectionPassword.val = value.password; + privateKeyFile.val = value.private_key; + privateKeyPhrase.val = value.private_key_passphrase; + setFieldValidity('key_pair_form', state.valid); + }, + isEditMode, + ), + databricks: () => HttpPathConnectionForm( + connection, + connectionPassword, + httpPath, + (value, state) => { + connectionPassword.val = value.password; + httpPath.val = value.http_path; + setFieldValidity('http_path_form', state.valid); + }, + isEditMode, + ), + }; + const authenticationForm = van.derive(() => { + const selectedFlavorCode = connectionFlavor.val; + const flavor = getValue(props.flavors).find(f => f.value === selectedFlavorCode); + return authenticationForms[flavor.flavor](); + }); + + van.derive(() => { + const selectedFlavorCode = connectionFlavor.val; + const previousFlavorCode = connectionFlavor.oldVal; + const isCustomPort = connectionPort.rawVal !== defaultPorts[previousFlavorCode]; + if (selectedFlavorCode !== previousFlavorCode && (!isCustomPort || !connectionPort.rawVal)) { + connectionPort.val = defaultPorts[selectedFlavorCode]; + } + }); + + van.derive(() => { + const connectionHost_ = connectionHost.val; + const connectionPort_ = connectionPort.val; + const connectionDatabase_ = connectionDatabase.val; + const connectionHttpPath_ = httpPath.val; + const urlTemplate = originalURLTemplate.val; + + if (!connectByUrl.rawVal && urlTemplate.includes('@')) { + const [originalURLPrefix, originalURLSuffix] = urlTemplate.split('@'); + connectionStringPrefix.val = originalURLPrefix; + connectionStringSuffix.val = formatURL(originalURLSuffix, connectionHost_, connectionPort_, connectionDatabase_, connectionHttpPath_); + } + }); + + return div( + { class: 'flex-column fx-gap-3 fx-align-stretch', style: 'overflow-y: auto;' }, + div( + { class: 'flex-row fx-gap-3 fx-align-stretch' }, + div( + { class: 'flex-column fx-gap-3', style: 'flex: 2' }, + Select({ + label: 'Database Type', + value: connectionFlavor, + options: props.flavors, + disabled: props.disableFlavor, + height: 38, + help: 'Type of database server to connect to. This determines the database driver and SQL dialect that will be used by TestGen.', + testId: 'sql_flavor', + onChange: (value) => { + const flavor = getValue(props.flavors).find(f => f.value === value); + originalURLTemplate.val = flavor.connection_string; + }, + }), + Input({ + name: 'connection_name', + label: 'Connection Name', + value: connectionName, + height: 38, + help: 'Unique name to describe the connection', + onChange: (value, state) => { + connectionName.val = value; + setFieldValidity('connection_name', state.valid); + }, + validators: [ minLength(3), maxLength(40) ], + }), + div( + { class: 'flex-row fx-gap-3 fx-flex' }, + Input({ + name: 'db_host', + label: 'Host', + value: connectionHost, + height: 38, + class: 'fx-flex', + disabled: connectByUrl, + onChange: (value, state) => { + connectionHost.val = value; + setFieldValidity('db_host', state.valid); + }, + validators: [ maxLength(250) ], + }), + Input({ + name: 'db_port', + label: 'Port', + value: connectionPort, + height: 38, + type: 'number', + disabled: connectByUrl, + onChange: (value, state) => { + connectionPort.val = value; + setFieldValidity('db_port', state.valid); + }, + validators: [ minLength(3), maxLength(5) ], + }), + ), + Input({ + name: 'db_name', + label: 'Database', + value: connectionDatabase, + height: 38, + disabled: connectByUrl, + onChange: (value, state) => { + connectionDatabase.val = value; + setFieldValidity('db_name', state.valid); + }, + validators: [ maxLength(100) ], + }), + Input({ + name: 'db_user', + label: 'Username', + value: connectionUsername, + height: 38, + onChange: (value, state) => { + connectionUsername.val = value; + setFieldValidity('db_user', state.valid); + }, + validators: [ maxLength(50) ], + }), + ), + div( + { class: 'flex-column fx-gap-3', style: 'padding: 2px; flex: 1;' }, + Slider({ + label: 'Max Threads (Advanced Tuning)', + hint: 'Maximum number of concurrent threads that run tests. Default values should be retained unless test queries are failing.', + value: connectionMaxThreads, + min: 1, + max: 8, + onChange: (value) => connectionMaxThreads.val = value, + }), + Slider({ + label: 'Max Expression Length (Advanced Tuning)', + hint: 'Some tests are consolidated into queries for maximum performance. Default values should be retained unless test queries are failing.', + value: connectionQueryChars, + min: 500, + max: 14000, + onChange: (value) => connectionQueryChars.val = value, + }), + ), + ), + authenticationForm, + hr({ style: 'width: 100%;', class: 'mt-2 mb-2' }), + Checkbox({ + name: 'connect_by_url', + label: 'URL Override', + help: 'When checked, the connection string will be driven by the field below, along with the username and password from the fields above', + checked: connectByUrl.val, + onChange: (checked) => connectByUrl.val = checked, + }), + () => { + const connectByUrl_ = getValue(connectByUrl); + + if (!connectByUrl_) { + return ''; + } + + return div( + { class: 'flex-row fx-gap-3 fx-align-stretch' }, + Input({ + label: 'URL Prefix', + disabled: true, + value: connectionStringPrefix, + height: 38, + width: 255, + name: 'url_prefix', + }), + Input({ + label: 'URL Suffix', + value: connectionStringSuffix, + class: 'fx-flex', + height: 38, + name: 'url_suffix', + onChange: (value, state) => connectionStringSuffix.val = value, + }), + ); + }, + div( + { class: 'flex-row fx-gap-3 fx-justify-space-between' }, + Button({ + label: 'Test Connection', + color: 'basic', + type: 'stroked', + width: 'auto', + onclick: () => emitEvent('TestConnectionClicked', { payload: updatedConnection.val }), + }), + saveButton, + ), + () => { + const conn = getValue(props.connection); + const connectionStatus = conn.status; + return connectionStatus + ? Alert( + {type: connectionStatus.successful ? 'success' : 'error', closeable: true}, + div( + { class: 'flex-column' }, + span(connectionStatus.message), + connectionStatus.details ? span(connectionStatus.details) : '', + ) + ) + : ''; + }, + ); +}; + +const PasswordConnectionForm = (connection, password, onValueChange, useSecretsPlaceholder) => { + return div( + { class: 'flex-row fx-gap-3 fx-align-stretch' }, + div( + { class: 'flex-column fx-gap-3', style: 'flex: 2' }, + Input({ + name: 'password', + label: 'Password', + value: password, + height: 38, + type: 'password', + placeholder: (useSecretsPlaceholder && connection.password) ? secretsPlaceholder : '', + onChange: onValueChange, + }), + ), + div( + { class: 'flex-column fx-gap-3', style: 'padding: 2px; flex: 1;' }, + '', + ), + ); +}; + +const HttpPathConnectionForm = ( + connection, + password, + httpPath, + onValueChange, + useSecretsPlaceholder, +) => { + const passwordFieldState = van.state({value: password.val, valid: false}); + const httpPathFieldState = van.state({value: httpPath.val, valid: false}); + + van.derive(() => { + const passwordField = passwordFieldState.val; + const httpPathField = httpPathFieldState.val; + onValueChange({password: passwordField.value, http_path: httpPathField.value}, { valid: passwordField.valid && httpPathField.valid }); + }); + + return div( + { class: 'flex-row fx-gap-3 fx-align-stretch' }, + div( + { class: 'flex-column fx-gap-3', style: 'flex: 2' }, + Input({ + name: 'password', + label: 'Password', + value: password, + height: 38, + type: 'password', + placeholder: (useSecretsPlaceholder && connection.password) ? secretsPlaceholder : '', + onChange: (value, state) => passwordFieldState.val = {value, valid: state.valid}, + }), + Input({ + label: 'HTTP Path', + value: httpPath, + class: 'fx-flex', + height: 38, + name: 'http_path', + onChange: (value, state) => httpPathFieldState.val = {value, valid: state.valid}, + validators: [ maxLength(50) ], + }) + ), + div( + { class: 'flex-column fx-gap-3', style: 'padding: 2px; flex: 1;' }, + '', + ), + ); +}; + +const KeyPairConnectionForm = ( + connection, + connectByKey, + password, + privateKey, + privateKeyPhrase, + clearPrivateKeyPhrase, + onValueChange, + useSecretsPlaceholder, +) => { + const connectByKeyFieldState = van.state({value: connectByKey.val, valid: true}); + const passwordFieldState = van.state({value: password.val, valid: true}); + const privateKeyFieldState = van.state({value: privateKey.val, valid: true}); + const privateKeyPhraseFieldState = van.state({value: privateKeyPhrase.val, valid: true}); + + van.derive(() => { + const connectByKeyField = connectByKeyFieldState.val; + const passwordField = passwordFieldState.val; + const privateKeyField = privateKeyFieldState.val; + const privateKeyPhraseField = privateKeyPhraseFieldState.val; + + let isValid = passwordField.valid; + if (connectByKeyField.value) { + isValid = privateKeyField.valid && privateKeyPhraseField.valid; + } + + onValueChange( + { + connect_by_key: connectByKeyField.value, + password: passwordField.value, + private_key: privateKeyField.value, + private_key_passphrase: privateKeyPhraseField.value, + }, + { valid: isValid }, + ); + }); + + return div( + { class: 'flex-column' }, + hr({ style: 'width: 100%;', class: 'mt-2 mb-2' }), + RadioGroup({ + label: 'Connection Strategy', + options: [ + {label: 'Connect By Password', value: false}, + {label: 'Connect By Key-Pair', value: true}, + ], + value: connectByKey, + onChange: (value) => connectByKeyFieldState.val = {value, valid: true}, + }), + () => { + if (connectByKey.val) { + return div( + { class: 'flex-column fx-gap-3' }, + div( + { class: 'key-pair-passphrase-field'}, + Input({ + name: 'private_key_passphrase', + label: 'Private Key Passphrase', + value: privateKeyPhrase, + height: 38, + type: 'password', + help: 'Passphrase used when creating the private key. Leave empty if the private key is not encrypted.', + placeholder: () => (useSecretsPlaceholder && connection.private_key_passphrase && !clearPrivateKeyPhrase.val) ? secretsPlaceholder : '', + onChange: (value, state) => { + if (value) { + clearPrivateKeyPhrase.val = false; + } + privateKeyPhraseFieldState.val = {value, valid: state.valid}; + }, + }), + () => { + const hasPrivateKeyPhrase = connection.private_key_passphrase || privateKeyPhraseFieldState.val?.value; + if (!hasPrivateKeyPhrase) { + return ''; + } + + return i( + { + class: 'material-symbols-rounded clickable text-secondary', + onclick: () => { + clearPrivateKeyPhrase.val = true; + privateKeyPhraseFieldState.val = {value: '', valid: true}; + }, + }, + 'clear', + ); + }, + ), + FileInput({ + name: 'private_key', + label: 'Upload private key (rsa_key.p8)', + placeholder: connection.private_key ? 'Drop file here or browse files to replace existing key' : undefined, + value: privateKey, + onChange: (value, state) => privateKeyFieldState.val = {value, valid: state.valid}, + validators: [ + sizeLimit(200 * 1024 * 1024), + ], + }), + ); + } + + return Input({ + name: 'password', + label: 'Password', + value: password, + height: 38, + type: 'password', + placeholder: (useSecretsPlaceholder && connection.password) ? secretsPlaceholder : '', + onChange: (value, state) => passwordFieldState.val = {value, valid: state.valid}, + }); + }, + ); +}; + +function formatURL(url, host, port, database, httpPath) { + return url.replace('', host) + .replace('', port) + .replace('', database) + .replace('', httpPath); +} + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.key-pair-passphrase-field { + position: relative; +} + +.key-pair-passphrase-field > i { + position: absolute; + top: 26px; + right: 8px; +} + +`); + +export { ConnectionForm }; diff --git a/testgen/ui/components/frontend/js/components/expansion_panel.js b/testgen/ui/components/frontend/js/components/expansion_panel.js new file mode 100644 index 00000000..777cdf0d --- /dev/null +++ b/testgen/ui/components/frontend/js/components/expansion_panel.js @@ -0,0 +1,66 @@ +/** + * @typedef Options + * @type {object} + * @property {string} title + * @property {string?} testId + */ + +import van from '../van.min.js'; +import { loadStylesheet } from '../utils.js'; +import { Icon } from './icon.js'; + +const { div, span } = van.tags; + +/** + * + * @param {Options} options + * @param {...HTMLElement} children + */ +const ExpansionPanel = (options, ...children) => { + loadStylesheet('expansion-panel', stylesheet); + + const expanded = van.state(false); + const icon = van.derive(() => expanded.val ? 'keyboard_arrow_up' : 'keyboard_arrow_down'); + const expansionClass = van.derive(() => expanded.val ? '' : 'collapsed'); + + return div( + { class: () => `tg-expansion-panel ${expansionClass.val}`, 'data-testid': options.testId ?? '' }, + div( + { + class: 'tg-expansion-panel--title flex-row fx-justify-space-between clickable', + 'data-testid': 'expansion-panel-trigger', + onclick: () => expanded.val = !expanded.val, + }, + span({}, options.title), + Icon({}, icon), + ), + div( + { class: 'tg-expansion-panel--content mt-4' }, + ...children, + ), + ); +}; + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.tg-expansion-panel { + border: 1px solid var(--border-color); + border-radius: 8px; + padding: 12px; +} + +.tg-expansion-panel--title:hover { + color: var(--primary-color); +} + +.tg-expansion-panel--title:hover i.tg-icon { + color: var(--primary-color) !important; +} + +.tg-expansion-panel.collapsed > .tg-expansion-panel--content { + height: 0; + display: none; +} +`); + +export { ExpansionPanel }; diff --git a/testgen/ui/components/frontend/js/components/file_input.js b/testgen/ui/components/frontend/js/components/file_input.js new file mode 100644 index 00000000..fb869f64 --- /dev/null +++ b/testgen/ui/components/frontend/js/components/file_input.js @@ -0,0 +1,231 @@ +/** + * @import {InputState} from './input.js'; + * @import {Validator} from '../form_validators.js'; + * + * @typedef FileValue + * @type {object} + * @property {string} name + * @property {string} content + * @property {number} size + * + * @typedef Options + * @type {object} + * @property {string} label + * @property {string?} placeholder + * @property {string} name + * @property {string} value + * @property {string?} class + * @property {Array?} validators + * @property {function(FileValue?, InputState)?} onChange + * + */ +import van from '../van.min.js'; +import { getRandomId, getValue, loadStylesheet } from "../utils.js"; +import { Icon } from './icon.js'; +import { Button } from './button.js'; +import { humanReadableSize } from '../display_utils.js'; + +const { div, input, label, span } = van.tags; + +/** + * File uploader component that emits change events with a base64 + * encoding of the uploaded file. + * + * @param {Options} options + * @returns {HTMLElement} + */ +const FileInput = (options) => { + loadStylesheet('file-uploader', stylesheet); + + const value = van.state(getValue(options.value)); + const inputId = `file-uploader-${getRandomId()}`; + const fileOver = van.state(false); + const cssClass = van.derive(() => `tg-file-uploader flex-column fx-gap-2 ${getValue(options.class) ?? ''}`) + const showLoading = van.state(false); + const loadingIndicatorProgress = van.state(0); + const loadingIndicatorStyle = van.derive(() => `width: ${loadingIndicatorProgress.val}%;`); + const errors = van.derive(() => { + const validators = getValue(options.validators) ?? []; + return validators.map(v => v(value.val)).filter(error => error); + }); + + let sizeLimit = undefined; + let sizeLimitValidator = (getValue(options.validators) ?? []).filter(v => v.args.name === 'sizeLimit')[0]; + if (sizeLimitValidator) { + sizeLimit = sizeLimitValidator.args.limit; + } + + van.derive(() => { + if (options.onChange && (value.val !== value.oldVal || errors.val.length !== errors.oldVal.length)) { + options.onChange(value.val, { errors: errors.val, valid: errors.val.length <= 0 }); + } + }); + + const browseFile = () => { + document.getElementById(inputId).click(); + }; + + const loadFile = (event) => { + const selectedFile = event.target.files[0]; + if (!selectedFile) { + value.val = null; + showLoading.val = false; + loadingIndicatorProgress.val = 0; + return; + } + + const fileReader = new FileReader(); + fileReader.addEventListener('loadstart', (event) => { + loadingIndicatorProgress.val = 0; + showLoading.val = event.lengthComputable; + }); + fileReader.addEventListener('progress', (event) => { + if (showLoading.val) { + loadingIndicatorProgress.val = event.loaded / event.total; + } + }); + fileReader.addEventListener('loadend', (event) => { + loadingIndicatorProgress.val = 100; + value.val = { + name: selectedFile.name, + content: fileReader.result, + size: event.loaded, + }; + }); + + fileReader.readAsDataURL(selectedFile); + }; + + const unloadFile = (event) => { + event.stopPropagation(); + value.val = null; + showLoading.val = false; + loadingIndicatorProgress.val = 0; + }; + + return div( + { class: cssClass }, + label( + { class: 'tg-file-uploader--label' }, + options.label, + ), + div( + { class: () => `tg-file-uploader--dropzone flex-column clickable ${fileOver.val ? 'on-dragover' : ''}` }, + div( + { + onclick: browseFile, + ondragenter: (event) => { + event.preventDefault(); + fileOver.val = true; + }, + ondragleave: (event) => { + if (!event.currentTarget.contains(event.relatedTarget)) { + fileOver.val = false; + } + }, + ondragover: (event) => event.preventDefault(), + ondrop: (/** @type {DragEvent} */event) => { + event.preventDefault(); + fileOver.val = false; + + let files = [...(event.dataTransfer.items ?? [])].filter((item) => item.kind === 'file').map((item) => item.getAsFile()); + if (!event.dataTransfer.items) { + files = [...(event.dataTransfer.files ?? [])]; + } + + loadFile({ target: { files }}); + }, + }, + input({ + id: inputId, + type: 'file', + name: options.name, + tabindex: '-1', + onchange: loadFile, + }), + () => value.val + ? FileSummary(value.val, unloadFile) + : FileSelectionDropZone(options.placeholder ?? 'Drop file here or browse files', sizeLimit) + ), + () => showLoading.val + ? div({ class: 'tg-file-uploader--loading', style: loadingIndicatorStyle }, '') + : '', + ), + ); +}; + +/** + * + * @param {string} placeholder + * @param {number} sizeLimit + * @returns + */ +const FileSelectionDropZone = (placeholder, sizeLimit) => { + return div( + { class: 'flex-row fx-gap-4' }, + Icon({size: 48}, 'cloud_upload'), + div( + { class: 'flex-column fx-gap-1' }, + span({}, placeholder), + span({ class: 'text-secondary text-caption' }, `Limit ${humanReadableSize(sizeLimit)} per file`), + ), + ); +}; + +const FileSummary = (value, onFileUnload) => { + const fileName = getValue(value).name; + const fileSize = humanReadableSize(getValue(value).size); + + return div( + { class: 'flex-row fx-gap-4' }, + Icon({size: 48}, 'draft'), + div( + { class: 'flex-column fx-gap-1' }, + span({}, fileName), + span({ class: 'text-secondary text-caption' }, `Size: ${fileSize}`), + ), + span({ style: 'margin: 0px auto;'}), + Button({ + type: 'icon', + color: 'basic', + icon: 'close', + onclick: onFileUnload, + }), + ); +}; + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.tg-file-uploader { +} + +.tg-file-uploader--dropzone { + border-radius: 8px; + background: var(--form-field-color); + padding: 16px; + position: relative; + border: 1px transparent dashed; +} + +.tg-file-uploader--dropzone.on-dragover { + border-color: var(--primary-color); +} + +.tg-file-uploader--dropzone input[type="file"] { + display: none; +} + +.tg-file-uploader--loading { + height: 3px; + background: var(--primary-color); + position: absolute; + width: 0%; + left: 0; + bottom: 0; + border-bottom-left-radius: 8px; + border-bottom-right-radius: 8px; + transition: 200ms width ease-in; +} +`); + +export { FileInput }; diff --git a/testgen/ui/components/frontend/js/components/flavor_selector.js b/testgen/ui/components/frontend/js/components/flavor_selector.js deleted file mode 100644 index 8cd1c173..00000000 --- a/testgen/ui/components/frontend/js/components/flavor_selector.js +++ /dev/null @@ -1,147 +0,0 @@ -/** - * @typedef Falvor - * @type {object} - * @property {string} label - * @property {string} value - * @property {string} icon - * @property {(boolean|null)} selected - * - * @typedef Properties - * @type {object} - * @property {Array.} flavors - * @property {((number|null))} selected - * @property {(number|null)} columns - */ - -import van from '../van.min.js'; -import { Streamlit } from '../streamlit.js'; -import { loadStylesheet } from '../utils.js'; - -const headerHeight = 35; -const rowGap = 16; -const rowHeight = 67; -const columnSize = '200px'; -const { div, span, img, h3 } = van.tags; - -const DatabaseFlavorSelector = (/** @type Properties */props) => { - loadStylesheet('databaseFlavorSelector', stylesheet); - - const flavors = props.flavors?.val ?? props.flavors; - const numberOfColumns = props.columns?.val ?? props.columns ?? 3; - const numberOfRows = Math.ceil(flavors.length / numberOfColumns); - const selectedIndex = van.state(props.selected?.val ?? props.selected); - - window.testgen.isPage = true; - Streamlit.setFrameHeight( - headerHeight - + rowHeight * numberOfRows - + rowGap * (numberOfRows - 1) - ); - - return div( - {class: 'tg-flavor-selector-page'}, - h3( - {class: 'tg-flavor-selector-header'}, - 'Select your database type' - ), - () => { - return div( - { - class: 'tg-flavor-selector', - style: `grid-template-columns: ${Array(numberOfColumns).fill(columnSize).join(' ')}; row-gap: ${rowGap}px;` - }, - flavors.map((flavor, idx) => - DatabaseFlavor( - { - label: van.state(flavor.label), - value: van.state(flavor.value), - icon: van.state(flavor.icon), - selected: van.derive(() => selectedIndex.val == idx), - }, - () => { - selectedIndex.val = idx; - Streamlit.sendData({index: idx, value: flavor.value}); - }, - ) - ), - ); - }, - ); -}; - -const DatabaseFlavor = ( - /** @type Falvor */ props, - /** @type Function */ onClick, -) => { - return div( - { - class: () => `tg-flavor ${props.selected.val ? 'selected' : ''}`, - onclick: onClick, - }, - span({class: 'tg-flavor-focus-state-indicator'}, ''), - img( - {class: 'tg-flavor--icon', src: props.icon}, - ), - span( - {class: 'tg-flavor--label'}, - props.label - ), - ); -}; - -const stylesheet = new CSSStyleSheet(); -stylesheet.replace(` - .tg-flavor-selector-header { - margin: unset; - margin-bottom: 16px; - font-weight: 400; - } - - .tg-flavor-selector { - display: grid; - grid-template-rows: auto; - column-gap: 32px; - } - - .tg-flavor { - display: flex; - align-items: center; - padding: 16px; - border: 1px solid var(--border-color); - border-radius: 4px; - cursor: pointer; - position: relative; - } - - .tg-flavor .tg-flavor-focus-state-indicator::before { - content: ""; - opacity: 0; - top: 0; - left: 0; - right: 0; - bottom: 0; - position: absolute; - pointer-events: none; - border-radius: inherit; - background: var(--button-primary-hover-state-background); - } - - .tg-flavor.selected { - border-color: var(--primary-color); - } - - .tg-flavor:hover .tg-flavor-focus-state-indicator::before, - .tg-flavor.selected .tg-flavor-focus-state-indicator::before { - opacity: var(--button-hover-state-opacity); - } - - .tg-flavor--icon { - margin-right: 16px; - } - - .tg-flavor--label { - font-weight: 500; - } -`); - -export { DatabaseFlavorSelector }; diff --git a/testgen/ui/components/frontend/js/components/input.js b/testgen/ui/components/frontend/js/components/input.js index 0d86d080..ad3cdee8 100644 --- a/testgen/ui/components/frontend/js/components/input.js +++ b/testgen/ui/components/frontend/js/components/input.js @@ -1,19 +1,33 @@ /** + * @import { Properties as TooltipProperties } from './tooltip.js'; + * @import { Validator } from '../form_validators.js'; + * + * @typedef InputState + * @type {object} + * @property {boolean} valid + * @property {string[]} errors + * * @typedef Properties * @type {object} * @property {string?} id + * @property {string?} name * @property {string?} label * @property {string?} help + * @property {TooltipProperties['position']} helpPlacement * @property {(string | number)?} value * @property {string?} placeholder * @property {string[]?} autocompleteOptions * @property {string?} icon * @property {boolean?} clearable - * @property {function(string)?} onChange + * @property {boolean?} disabled + * @property {function(string, InputState)?} onChange * @property {number?} width * @property {number?} height * @property {string?} style + * @property {string?} type + * @property {string?} class * @property {string?} testId + * @property {Array?} validators */ import van from '../van.min.js'; import { debounce, getValue, loadStylesheet, getRandomId } from '../utils.js'; @@ -21,7 +35,7 @@ import { Icon } from './icon.js'; import { withTooltip } from './tooltip.js'; import { Portal } from './portal.js'; -const { div,input, label, i } = van.tags; +const { div,input, label, i, small } = van.tags; const defaultHeight = 32; const iconSize = 22; const clearIconSize = 20; @@ -31,10 +45,22 @@ const Input = (/** @type Properties */ props) => { const domId = van.derive(() => getValue(props.id) ?? getRandomId()); const value = van.derive(() => getValue(props.value) ?? ''); + const errors = van.derive(() => { + const validators = getValue(props.validators) ?? []; + return validators.map(v => v(value.val)).filter(error => error); + }); + const firstError = van.derive(() => { + return errors.val[0] ?? ''; + }); + + const onChange = props.onChange?.val ?? props.onChange; + if (onChange) { + onChange(value.val, { errors: errors.val, valid: errors.val.length <= 0 }); + } van.derive(() => { const onChange = props.onChange?.val ?? props.onChange; - if (value.val !== value.oldVal) { - onChange(value.val); + if (onChange && (value.val !== value.oldVal || errors.val.length !== errors.oldVal.length)) { + onChange(value.val, { errors: errors.val, valid: errors.val.length <= 0 }); } }); @@ -54,9 +80,9 @@ const Input = (/** @type Properties */ props) => { return label( { id: domId, - class: 'flex-column fx-gap-1 tg-input--label', + class: () => `flex-column fx-gap-1 tg-input--label ${getValue(props.class) ?? ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': props.testId ?? '', + 'data-testid': props.testId ?? props.name ?? '', }, div( { class: 'flex-row fx-gap-1 text-caption' }, @@ -64,7 +90,7 @@ const Input = (/** @type Properties */ props) => { () => getValue(props.help) ? withTooltip( Icon({ size: 16, classes: 'text-disabled' }, 'help'), - { text: props.help, position: 'top', width: 200 } + { text: props.help, position: getValue(props.helpPlacement) ?? 'top', width: 200 } ) : null, ), @@ -84,9 +110,12 @@ const Input = (/** @type Properties */ props) => { 'clear', ) : '', input({ - class: 'tg-input--field', + class: () => `tg-input--field ${getValue(props.disabled) ? 'tg-input--disabled' : ''}`, style: () => `height: ${getValue(props.height) || defaultHeight}px;`, value, + name: props.name ?? '', + type: props.type ?? 'text', + disabled: props.disabled, placeholder: () => getValue(props.placeholder) ?? '', oninput: debounce((/** @type Event */ event) => value.val = event.target.value, 300), onclick: van.derive(() => autocompleteOptions.val?.length @@ -94,6 +123,10 @@ const Input = (/** @type Properties */ props) => { : null ), }), + () => + getValue(props.validators)?.length > 0 + ? small({ class: 'tg-input--error' }, firstError) + : '', Portal( { target: domId.val, targetRelative: true, opened: autocompleteOpened }, () => div( @@ -198,6 +231,16 @@ stylesheet.replace(` .tg-input--option:hover { background: var(--select-hover-background); } + +.tg-input--disabled { + cursor: not-allowed; + color: var(--disabled-text-color); +} + +.tg-input--label > .tg-input--error { + height: 12px; + color: var(--error-color); +} `); export { Input }; diff --git a/testgen/ui/components/frontend/js/components/slider.js b/testgen/ui/components/frontend/js/components/slider.js new file mode 100644 index 00000000..2582fc8b --- /dev/null +++ b/testgen/ui/components/frontend/js/components/slider.js @@ -0,0 +1,164 @@ +/** + * @typedef Properties + * @type {object} + * @property {string} label + * @property {number} value + * @property {number} min + * @property {number} max + * @property {number} step + * @property {function(number)?} onChange + * @property {string?} hint + */ +import van from '../van.min.js'; +import { getValue, loadStylesheet } from '../utils.js'; + +const { input, label, span } = van.tags; + +const Slider = (/** @type Properties */ props) => { + loadStylesheet('slider', stylesheet); + + const value = van.state(getValue(props.value) ?? getValue(props.min) ?? 0); + + const handleInput = e => { + value.val = Number(e.target.value); + props.onChange?.(value.val); + }; + + return label( + { class: 'flex-col fx-gap-1 clickable tg-slider--label text-caption' }, + props.label, + input({ + type: "range", + min: props.min ?? 0, + max: props.max ?? 100, + step: props.step ?? 1, + value: value, + oninput: handleInput, + class: 'tg-slider--input', + }), + span({ class: "tg-slider--value" }, () => value.val), + props.hint && span({ class: "tg-slider--hint" }, props.hint) + ); +}; + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.tg-slider--label { + display: flex; + flex-direction: column; + gap: 0.5em; + font-family: inherit; +} + +.tg-slider--value { + font-size: 0.9em; + color: var(--primary-text-color); +} + +.tg-slider--hint { + font-size: 0.8em; + color: var(--disabled-text-color); +} + +/* Basic reset and common styles for the range input */ +input[type=range].tg-slider--input { + -webkit-appearance: none; /* Override default WebKit styles */ + appearance: none; /* Override default pseudo-element styles */ + width: 100%; /* Full width */ + height: 20px; /* Set height to accommodate thumb; track will be smaller */ + cursor: pointer; + outline: none; + background: transparent; /* Make default track invisible, we'll style it manually */ + accent-color: var(--primary-color); /* Sets thumb and selected track color for modern browsers (Chrome, Edge, Firefox) */ +} + +/* --- Thumb Styling (#06a04a) --- */ +/* WebKit (Chrome, Safari, Opera, Edge Chromium) */ +input[type=range].tg-slider--input::-webkit-slider-thumb { + -webkit-appearance: none; /* Required to style */ + appearance: none; + height: 20px; /* Thumb height */ + width: 20px; /* Thumb width */ + background-color: var(--primary-color); /* Thumb color */ + border-radius: 50%; /* Make it circular */ + border: none; /* No border */ + margin-top: -7px; /* Vertically center thumb on track. (Thumb height - Track height) / 2 = (20px - 6px) / 2 = 7px */ + /* This assumes track height is 6px (defined below) */ +} + +/* Firefox */ +input[type=range].tg-slider--input::-moz-range-thumb { + height: 20px; /* Thumb height */ + width: 20px; /* Thumb width */ + background-color: var(--primary-color); /* Thumb color */ + border-radius: 50%; /* Make it circular */ + border: none; /* No border */ +} + +/* IE / Edge Legacy (EdgeHTML) */ +input[type=range].tg-slider--input::-ms-thumb { + height: 20px; /* Thumb height */ + width: 20px; /* Thumb width */ + background-color: var(--primary-color); /* Thumb color */ + border-radius: 50%; /* Make it circular */ + border: 0; /* No border */ + /* margin-top: 1px; /* IE may need slight adjustment if track style requires it */ +} + +/* --- Track Styling --- */ +/* Track "unselected" section: #EEEEEE */ +/* Track "selected" section: #06a04a */ + +/* WebKit browsers */ +input[type=range].tg-slider--input::-webkit-slider-runnable-track { + width: 100%; + height: 6px; /* Track height */ + background: var(--grey); /* Color of the "unselected" part of the track */ + /* accent-color (set on the input) will color the "selected" part */ +// background: transparent !important; + border-radius: 3px; /* Rounded track edges */ +} + +/* Firefox */ +input[type=range].tg-slider--input::-moz-range-track { + width: 100%; + height: 6px; /* Track height */ +// background: var(--grey); /* Color of the "unselected" part of the track */ + background: transparent !important; + border-radius: 3px; /* Rounded track edges */ +} + +/* For Firefox, the "selected" part of the track is ::-moz-range-progress */ +/* This is often handled by accent-color, but explicitly styling it provides a fallback. */ +input[type=range].tg-slider--input::-moz-range-progress { + height: 6px; /* Must match track height */ + background-color: var(--primary-color); /* Color of the "selected" part */ + border-radius: 3px; /* Rounded track edges */ +} + +/* IE / Edge Legacy (EdgeHTML) */ +input[type=range].tg-slider--input::-ms-track { + width: 100%; + height: 6px; /* Track height */ + cursor: pointer; + + /* Needs to be transparent for ms-fill-lower and ms-fill-upper to show through */ + background: transparent; + border-color: transparent; + color: transparent; + border-width: 7px 0; /* Adjust vertical positioning; (thumb height - track height) / 2 */ +} + +input[type=range].tg-slider--input::-ms-fill-lower { + background: var(--primary-color); /* Color of the "selected" part */ + border-radius: 3px; /* Rounded track edges */ +} + +input[type=range].tg-slider--input::-ms-fill-upper { + background: var(--grey); /* Color of the "unselected" part */ + border-radius: 3px; /* Rounded track edges */ +} + +`); + +export { Slider }; \ No newline at end of file diff --git a/testgen/ui/components/frontend/js/components/table_group_form.js b/testgen/ui/components/frontend/js/components/table_group_form.js new file mode 100644 index 00000000..05e4c490 --- /dev/null +++ b/testgen/ui/components/frontend/js/components/table_group_form.js @@ -0,0 +1,460 @@ +/** + * @typedef TableGroup + * @type {object} + * @property {string?} table_group_id + * @property {string?} table_groups_name + * @property {string?} profiling_include_mask + * @property {string?} profiling_exclude_mask + * @property {string?} profiling_table_set + * @property {string?} table_group_schema + * @property {string?} profile_id_column_mask + * @property {string?} profile_sk_column_mask + * @property {number?} profiling_delay_days + * @property {boolean?} profile_flag_cdes + * @property {boolean?} add_scorecard_definition + * @property {boolean?} profile_use_sampling + * @property {number?} profile_sample_percent + * @property {number?} profile_sample_min_count + * @property {string?} description + * @property {string?} data_source + * @property {string?} source_system + * @property {string?} source_process + * @property {string?} data_location + * @property {string?} business_domain + * @property {string?} stakeholder_group + * @property {string?} transform_level + * @property {string?} data_product + * + * @typedef FormState + * @type {object} + * @property {boolean} dirty + * @property {boolean} valid + * + * @typedef Properties + * @type {object} + * @property {TableGroup} tableGroup + * @property {(tg: TableGroup, state: FormState) => void} onChange + */ +import van from '../van.min.js'; +import { getValue, isEqual, loadStylesheet } from '../utils.js'; +import { Input } from './input.js'; +import { Checkbox } from './checkbox.js'; +import { ExpansionPanel } from './expansion_panel.js'; +import { required } from '../form_validators.js'; + +const { div, span } = van.tags; + +/** + * + * @param {Properties} props + * @returns + */ +const TableGroupForm = (props) => { + loadStylesheet('table-group-form', stylesheet); + + const tableGroup = getValue(props.tableGroup); + const tableGroupsName = van.state(tableGroup.table_groups_name); + const profilingIncludeMask = van.state(tableGroup.profiling_include_mask ?? '%'); + const profilingExcludeMask = van.state(tableGroup.profiling_exclude_mask ?? 'tmp%'); + const profilingTableSet = van.state(tableGroup.profiling_table_set); + const tableGroupSchema = van.state(tableGroup.table_group_schema); + const profileIdColumnMask = van.state(tableGroup.profile_id_column_mask ?? '%_id'); + const profileSkColumnMask = van.state(tableGroup.profile_sk_column_mask ?? '%_sk'); + const profilingDelayDays = van.state(tableGroup.profiling_delay_days ?? 0); + const profileFlagCdes = van.state(tableGroup.profile_flag_cdes ?? true); + const addScorecardDefinition = van.state(tableGroup.add_scorecard_definition ?? true); + const profileUseSampling = van.state(tableGroup.profile_use_sampling ?? false); + const profileSamplePercent = van.state(tableGroup.profile_sample_percent ?? 30); + const profileSampleMinCount = van.state(tableGroup.profile_sample_min_count ?? 15000); + const description = van.state(tableGroup.description); + const dataSource = van.state(tableGroup.data_source); + const sourceSystem = van.state(tableGroup.source_system); + const sourceProcess = van.state(tableGroup.source_process); + const dataLocation = van.state(tableGroup.data_location); + const businessDomain = van.state(tableGroup.business_domain); + const stakeholderGroup = van.state(tableGroup.stakeholder_group); + const transformLevel = van.state(tableGroup.transform_level); + const dataProduct = van.state(tableGroup.data_product); + + const updatedTableGroup = van.derive(() => { + return { + table_group_id: tableGroup.table_group_id, + table_groups_name: tableGroupsName.val, + profiling_include_mask: profilingIncludeMask.val, + profiling_exclude_mask: profilingExcludeMask.val, + profiling_table_set: profilingTableSet.val, + table_group_schema: tableGroupSchema.val, + profile_id_column_mask: profileIdColumnMask.val, + profile_sk_column_mask: profileSkColumnMask.val, + profiling_delay_days: profilingDelayDays.val, + profile_flag_cdes: profileFlagCdes.val, + add_scorecard_definition: addScorecardDefinition.val, + profile_use_sampling: profileUseSampling.val, + profile_sample_percent: profileSamplePercent.val, + profile_sample_min_count: profileSampleMinCount.val, + description: description.val, + data_source: dataSource.val, + source_system: sourceSystem.val, + source_process: sourceProcess.val, + data_location: dataLocation.val, + business_domain: businessDomain.val, + stakeholder_group: stakeholderGroup.val, + transform_level: transformLevel.val, + data_product: dataProduct.val, + }; + }); + const dirty = van.derive(() => !isEqual(updatedTableGroup.val, tableGroup)); + const validityPerField = van.state({}); + + van.derive(() => { + const fieldsValidity = validityPerField.val; + const isValid = Object.keys(fieldsValidity).length > 0 && + Object.values(fieldsValidity).every(v => v); + props.onChange?.(updatedTableGroup.val, { dirty: dirty.val, valid: isValid }); + }); + + const setFieldValidity = (field, validity) => { + validityPerField.val = {...validityPerField.val, [field]: validity}; + } + + return div( + { class: 'flex-column fx-gap-3' }, + MainForm( + { setValidity: setFieldValidity }, + tableGroupsName, + profilingIncludeMask, + profilingExcludeMask, + profilingTableSet, + tableGroupSchema, + profileIdColumnMask, + profileSkColumnMask, + profilingDelayDays, + profileFlagCdes, + addScorecardDefinition, + ), + SamplingForm( + { setValidity: setFieldValidity }, + profileUseSampling, + profileSamplePercent, + profileSampleMinCount, + ), + TaggingForm( + { setValidity: setFieldValidity }, + description, + dataSource, + sourceSystem, + sourceProcess, + dataLocation, + businessDomain, + stakeholderGroup, + transformLevel, + dataProduct, + ), + ); +}; + +const MainForm = ( + options, + tableGroupsName, + profilingIncludeMask, + profilingExcludeMask, + profilingTableSet, + tableGroupSchema, + profileIdColumnMask, + profileSkColumnMask, + profilingDelayDays, + profileFlagCdes, + addScorecardDefinition, +) => { + return div( + { class: 'tg-main-form flex-column fx-gap-3 fx-flex-wrap' }, + Input({ + name: 'table_groups_name', + label: 'Name', + value: tableGroupsName, + height: 38, + help: 'Unique name to describe the table group', + helpPlacement: 'bottom-right', + onChange: (value, state) => { + tableGroupsName.val = value; + options.setValidity?.('table_groups_name', state.valid); + }, + validators: [ required ], + }), + Input({ + name: 'profiling_include_mask', + label: 'Tables to Include Mask', + value: profilingIncludeMask, + height: 38, + help: 'SQL filter supported by your database\'s LIKE operator for table names to include', + onChange: (value, state) => { + profilingIncludeMask.val = value; + options.setValidity?.('profiling_include_mask', state.valid); + }, + }), + Input({ + name: 'profiling_exclude_mask', + label: 'Tables to Exclude Mask', + value: profilingExcludeMask, + height: 38, + help: 'SQL filter supported by your database\'s LIKE operator for table names to exclude', + onChange: (value, state) => { + profilingExcludeMask.val = value; + options.setValidity?.('profiling_exclude_mask', state.valid); + }, + }), + Input({ + name: 'profiling_table_set', + label: 'Explicit Table List', + value: profilingTableSet, + height: 38, + help: 'List of specific table names to include, separated by commas', + onChange: (value, state) => { + profilingTableSet.val = value; + options.setValidity?.('profiling_table_set', state.valid); + }, + }), + Checkbox({ + name: 'profile_flag_cdes', + label: 'Detect critical data elements (CDE) during profiling', + checked: profileFlagCdes, + onChange: (value) => profileFlagCdes.val = value, + }), + Input({ + name: 'table_group_schema', + label: 'Schema', + value: tableGroupSchema, + height: 38, + help: 'Database schema containing the tables for the Table Group', + helpPlacement: 'bottom-left', + onChange: (value, state) => { + tableGroupSchema.val = value; + options.setValidity?.('table_group_schema', state.valid); + }, + validators: [ required ], + }), + Input({ + name: 'profile_id_column_mask', + label: 'Profiling ID Column Mask', + value: profileIdColumnMask, + height: 38, + help: 'SQL filter supported by your database\'s LIKE operator representing ID columns (optional)', + onChange: (value, state) => { + profileIdColumnMask.val = value; + options.setValidity?.('profile_id_column_mask', state.valid); + }, + }), + Input({ + name: 'profile_sk_column_mask', + label: 'Profiling Surrogate Key Column Mask', + value: profileSkColumnMask, + height: 38, + help: 'SQL filter supported by your database\'s LIKE operator representing surrogate key columns (optional)', + onChange: (value, state) => { + profileSkColumnMask.val = value + options.setValidity?.('profile_sk_column_mask', state.valid); + }, + }), + Input({ + name: 'profiling_delay_days', + type: 'number', + label: 'Min Profiling Age (in days)', + value: profilingDelayDays, + height: 38, + help: 'Number of days to wait before new profiling will be available to generate tests', + onChange: (value, state) => { + profilingDelayDays.val = value; + options.setValidity?.('profiling_delay_days', state.valid); + }, + }), + Checkbox({ + name: 'add_scorecard_definition', + label: 'Add scorecard for table group', + help: 'Add a new scorecard to the Quality Dashboard upon creation of this table group', + checked: addScorecardDefinition, + onChange: (value) => addScorecardDefinition.val = value, + }), + ); +}; + +const SamplingForm = ( + options, + profileUseSampling, + profileSamplePercent, + profileSampleMinCount, +) => { + return div( + { class: 'tg-sampling-form flex-column fx-gap-3' }, + Checkbox({ + name: 'profile_use_sampling', + label: 'Use profile sampling', + help: 'When checked, profiling will be based on a sample of records instead of the full table', + checked: profileUseSampling, + onChange: (value) => profileUseSampling.val = value, + }), + ExpansionPanel( + { title: 'Sampling Parameters', testId: 'sampling-panel' }, + div( + { class: 'flex-row fx-gap-3' }, + Input({ + name: 'profile_sample_percent', + class: 'fx-flex', + type: 'number', + label: 'Sample percent', + value: profileSamplePercent, + height: 38, + help: 'Percent of records to include in the sample, unless the calculated count falls below the specified minimum', + onChange: (value, state) => { + profileSamplePercent.val = value; + options.setValidity?.('profile_sample_percent', state.valid); + }, + }), + Input({ + name: 'profile_sample_min_count', + class: 'fx-flex', + type: 'number', + label: 'Min Sample Record Count', + value: profileSampleMinCount, + height: 38, + help: 'Minimum number of records to be included in any sample (if available)', + onChange: (value, state) => { + profileSampleMinCount.val = value; + options.setValidity?.('profile_sample_min_count', state.valid); + }, + }), + ), + ), + ); +}; + +const TaggingForm = ( + options, + description, + dataSource, + sourceSystem, + sourceProcess, + dataLocation, + businessDomain, + stakeholderGroup, + transformLevel, + dataProduct, +) => { + return ExpansionPanel( + { title: 'Table Group Tags', testId: 'tags-panel' }, + Input({ + name: 'description', + class: 'fx-flex mb-3', + label: 'Description', + value: description, + height: 38, + onChange: (value, state) => { + description.val = value; + options.setValidity?.('description', state.valid); + }, + }), + div( + { class: 'tg-tagging-form-fields flex-column fx-gap-3 fx-flex-wrap' }, + Input({ + name: 'data_source', + label: 'Data Source', + value: dataSource, + height: 38, + help: 'Original source of the dataset', + onChange: (value, state) => { + dataSource.val = value; + options.setValidity?.('data_source', state.valid); + }, + }), + Input({ + name: 'source_process', + label: 'Source Process', + value: sourceProcess, + height: 38, + help: 'Process, program, or data flow that produced the dataset', + onChange: (value, state) => { + sourceProcess.val = value; + options.setValidity?.('source_process', state.valid); + }, + }), + Input({ + name: 'business_domain', + label: 'Business Domain', + value: businessDomain, + height: 38, + help: 'Business division responsible for the dataset, e.g., Finance, Sales, Manufacturing', + onChange: (value, state) => { + businessDomain.val = value; + options.setValidity?.('business_domain', state.valid); + }, + }), + Input({ + name: 'transform_level', + label: 'Transform Level', + value: transformLevel, + height: 38, + help: 'Data warehouse processing stage, e.g., Raw, Conformed, Processed, Reporting, or Medallion level (bronze, silver, gold)', + onChange: (value, state) => { + transformLevel.val = value; + options.setValidity?.('transform_level', state.valid); + }, + }), + Input({ + name: 'source_system', + label: 'Source System', + value: sourceSystem, + height: 38, + help: 'Enterprise system source for the dataset', + onChange: (value, state) => { + sourceSystem.val = value; + options.setValidity?.('source_system', state.valid); + }, + }), + Input({ + name: 'data_location', + label: 'Data Location', + value: dataLocation, + height: 38, + help: 'Physical or virtual location of the dataset, e.g., Headquarters, Cloud', + onChange: (value, state) => { + dataLocation.val = value; + options.setValidity?.('data_location', state.valid); + }, + }), + Input({ + name: 'stakeholder_group', + label: 'Stakeholder Group', + value: stakeholderGroup, + height: 38, + help: 'Data owners or stakeholders responsible for the dataset', + onChange: (value, state) => { + stakeholderGroup.val = value; + options.setValidity?.('stakeholder_group', state.valid); + }, + }), + Input({ + name: 'data_product', + label: 'Data Product', + value: dataProduct, + height: 38, + help: 'Data domain that comprises the dataset', + onChange: (value, state) => { + dataProduct.val = value; + options.setValidity?.('data_product', state.valid); + }, + }), + ), + ); +}; + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.tg-main-form { + height: 316px; +} + +.tg-tagging-form-fields { + height: 332px; +} +`); + +export { TableGroupForm }; diff --git a/testgen/ui/components/frontend/js/components/toggle.js b/testgen/ui/components/frontend/js/components/toggle.js index b8b5ca14..8d01755a 100644 --- a/testgen/ui/components/frontend/js/components/toggle.js +++ b/testgen/ui/components/frontend/js/components/toggle.js @@ -2,6 +2,7 @@ * @typedef Properties * @type {object} * @property {string} label + * @property {string?} name * @property {boolean?} checked * @property {function(boolean)?} onChange */ @@ -14,11 +15,12 @@ const Toggle = (/** @type Properties */ props) => { loadStylesheet('toggle', stylesheet); return label( - { class: 'flex-row fx-gap-2 clickable' }, + { class: 'flex-row fx-gap-2 clickable', 'data-testid': props.name ?? '' }, input({ type: 'checkbox', role: 'switch', class: 'tg-toggle--input clickable', + name: props.name ?? '', checked: props.checked, onchange: van.derive(() => { const onChange = props.onChange?.val ?? props.onChange; diff --git a/testgen/ui/components/frontend/js/display_utils.js b/testgen/ui/components/frontend/js/display_utils.js index 652d3822..bc7c1a9d 100644 --- a/testgen/ui/components/frontend/js/display_utils.js +++ b/testgen/ui/components/frontend/js/display_utils.js @@ -44,6 +44,27 @@ function capitalize(/** @type string */ text) { .join(' '); } +/** + * Display bytes in the closest unit with an integer part. + * + * @param {number} bytes + * @returns {string} + */ +function humanReadableSize(bytes) { + const thresholds = { + MB: 1024 * 1024, + KB: 1024, + }; + + for (const [unit, startsAt] of Object.entries(thresholds)) { + if (bytes > startsAt) { + return `${(bytes / startsAt).toFixed()}${unit}`; + } + } + + return `${bytes}B`; +} + // https://m2.material.io/design/color/the-color-system.html#tools-for-picking-colors const colorMap = { red: '#EF5350', // Red 400 @@ -68,4 +89,4 @@ const colorMap = { const DISABLED_ACTION_TEXT = 'You do not have permissions to perform this action. Contact your administrator.'; -export { formatTimestamp, formatDuration, roundDigits, capitalize, colorMap, DISABLED_ACTION_TEXT }; +export { formatTimestamp, formatDuration, roundDigits, capitalize, humanReadableSize, colorMap, DISABLED_ACTION_TEXT }; diff --git a/testgen/ui/components/frontend/js/form_validators.js b/testgen/ui/components/frontend/js/form_validators.js new file mode 100644 index 00000000..905003d3 --- /dev/null +++ b/testgen/ui/components/frontend/js/form_validators.js @@ -0,0 +1,72 @@ +/** + * @typedef Validator + * @type {Function} + * @param {any} value + * @param {object} form + * @returns {string} + */ + +function required(value) { + if (!value) { + return 'This field is required' + } + return null; +} + +/** + * + * @param {number} min + * @returns {Validator} + */ +function minLength(min) { + return (value) => { + if (typeof value !== 'string' || value.length < min) { + return `Value must be at least ${min} characters long.`; + } + return null; + }; +} + +/** + * + * @param {number} max + * @returns {Validator} + */ +function maxLength(max) { + return (value) => { + if (typeof value !== 'string' || value.length > max) { + return `Value must be ${max} characters long or shorter.`; + } + return null; + }; +} + +/** + * To use with FileInput, enforce a cap on file size + * allowed to upload. + * + * @param {number} size + * @returns {Validator} + */ +function sizeLimit(limit) { + /** + * @import {FileValue} from './components/file_input.js'; + * @param {FileValue} value + */ + const validator = (value) => { + if (value != null && value.size > limit) { + return `Uploaded file must be smaller than ${limit}.`; + } + return null; + }; + validator['args'] = { name: 'sizeLimit', limit }; + + return validator; +} + +export { + maxLength, + minLength, + required, + sizeLimit, +}; diff --git a/testgen/ui/components/frontend/js/main.js b/testgen/ui/components/frontend/js/main.js index 0061db66..d4854cc5 100644 --- a/testgen/ui/components/frontend/js/main.js +++ b/testgen/ui/components/frontend/js/main.js @@ -6,6 +6,7 @@ * @property {object} props - object with the props to pass to the rendered component */ import van from './van.min.js'; +import pluginSpec from './plugins.js'; import { Streamlit } from './streamlit.js'; import { isEqual, getParents } from './utils.js'; import { Button } from './components/button.js' @@ -17,7 +18,6 @@ import { SortingSelector } from './components/sorting_selector.js'; import { ColumnSelector } from './components/explorer_column_selector.js'; import { TestRuns } from './pages/test_runs.js'; import { ProfilingRuns } from './pages/profiling_runs.js'; -import { DatabaseFlavorSelector } from './components/flavor_selector.js'; import { DataCatalog } from './pages/data_catalog.js'; import { ProjectDashboard } from './pages/project_dashboard.js'; import { TestSuites } from './pages/test_suites.js'; @@ -27,6 +27,8 @@ import { ScoreExplorer } from './pages/score_explorer.js'; import { ColumnProfilingResults } from './data_profiling/column_profiling_results.js'; import { ColumnProfilingHistory } from './data_profiling/column_profiling_history.js'; import { ScheduleList } from './pages/schedule_list.js'; +import { Connections } from './pages/connections.js'; +import { TableGroupWizard } from './pages/table_group_wizard.js'; let currentWindowVan = van; let topWindowVan = window.top.van; @@ -42,7 +44,6 @@ const TestGenComponent = (/** @type {string} */ id, /** @type {object} */ props) sidebar: window.top.testgen.components.Sidebar, test_runs: TestRuns, profiling_runs: ProfilingRuns, - database_flavor_selector: DatabaseFlavorSelector, data_catalog: DataCatalog, column_profiling_results: ColumnProfilingResults, column_profiling_history: ColumnProfilingHistory, @@ -53,17 +54,22 @@ const TestGenComponent = (/** @type {string} */ id, /** @type {object} */ props) score_explorer: ScoreExplorer, schedule_list: ScheduleList, column_selector: ColumnSelector, + connections: Connections, + table_group_wizard: TableGroupWizard, }; - if (Object.keys(componentById).includes(id)) { + if (Object.keys(window.testgen.plugins).includes(id)) { + return window.testgen.plugins[id](props); + } else if (Object.keys(componentById).includes(id)) { return componentById[id](props); } - return ''; }; -window.addEventListener('message', (event) => { +window.addEventListener('message', async (event) => { if (event.data.type === 'streamlit:render') { + await loadPlugins(); + const componentId = event.data.args.id; const componentKey = event.data.args.key; @@ -134,8 +140,29 @@ function shouldRenderOutsideFrame(componentId) { return 'sidebar' === componentId; } +async function loadPlugins() { + if (!window.testgen.pluginsLoaded) { + try { + const modules = await Promise.all(Object.values(pluginSpec).map(plugin => import(plugin.entrypoint))) + for (const pluginModule of modules) { + if (pluginModule && pluginModule.components) { + Object.assign(window.testgen.plugins, pluginModule.components) + } else if (pluginModule) { + console.warn(`Plugin '${pluginModule}' does not export a member 'components'.`); + } + } + } catch (error) { + console.warn('Error loading plugins:', error); + } + } + + window.testgen.pluginsLoaded = true; +} + window.testgen = { states: {}, loadedStylesheets: {}, portals: {}, + plugins: {}, + pluginsLoaded: false, }; diff --git a/testgen/ui/components/frontend/js/pages/connections.js b/testgen/ui/components/frontend/js/pages/connections.js new file mode 100644 index 00000000..cd28e96b --- /dev/null +++ b/testgen/ui/components/frontend/js/pages/connections.js @@ -0,0 +1,130 @@ +/** + * @import { Connection, Flavor } from '../components/connection_form.js'; + * + * @typedef Results + * @type {object} + * @property {boolean} success + * @property {string} message + * + * @typedef Permissions + * @type {object} + * @property {boolean} is_admin + * + * @typedef Properties + * @type {object} + * @property {Connection} connection + * @property {boolean} has_table_groups + * @property {Array} flavors + * @property {Permissions} permissions + * @property {Results?} results + */ +import van from '../van.min.js'; +import { Streamlit } from '../streamlit.js'; +import { loadStylesheet, resizeFrameHeightToElement, resizeFrameHeightOnDOMChange, getValue, emitEvent } from '../utils.js'; +import { ConnectionForm } from '../components/connection_form.js'; +import { Button } from '../components/button.js'; +import { Link } from '../components/link.js'; +import { Alert } from '../components/alert.js'; + +const { div, span } = van.tags; + +/** + * + * @param {Properties} props + * @returns + */ +const Connections = (props) => { + loadStylesheet('connections', stylesheet); + Streamlit.setFrameHeight(1); + window.testgen.isPage = true; + + const wrapperId = 'connections-list-wrapper'; + const connection = getValue(props.connection); + const connectionId = connection.connection_id; + const updatedConnection = van.state(connection); + const formState = van.state({dirty: false, valid: false}); + + resizeFrameHeightToElement(wrapperId); + resizeFrameHeightOnDOMChange(wrapperId); + + return div( + { id: wrapperId, class: 'flex-column fx-gap-4' }, + div( + { class: 'flex-row fx-justify-content-flex-end' }, + () => getValue(props.has_table_groups) + ? Link({ + href: 'connections:table-groups', + params: {"connection_id": connectionId}, + label: 'Manage Table Groups', + right_icon: 'chevron_right', + class: 'tg-connections--link', + }) + : Button({ + type: 'stroked', + color: 'primary', + icon: 'table_view', + label: 'Setup Table Groups', + width: 'auto', + disabled: !getValue(props.permissions).is_admin, + tooltip: 'You do not have permissions to perform this action. Contact your administrator.', + onclick: () => emitEvent('SetupTableGroupClicked', {}), + }), + ), + div( + { class: 'flex-column fx-gap-4 tg-connections--border p-4' }, + ConnectionForm( + { + connection: props.connection, + flavors: props.flavors, + disableFlavor: false, + onChange: (connection, state) => { + formState.val = state; + updatedConnection.val = connection; + }, + }, + () => { + const hasSavePermission = getValue(props.permissions).is_admin; + if (!hasSavePermission) { + return ''; + } + + const formState_ = formState.val; + const canSave = formState_.dirty && formState_.valid; + return Button({ + label: 'Save', + color: 'primary', + type: 'flat', + width: 'auto', + disabled: !canSave, + onclick: () => emitEvent('SaveConnectionClicked', { payload: updatedConnection.val }), + }); + }, + ), + () => { + const results = getValue(props.results) ?? {}; + return Object.keys(results).length > 0 + ? Alert({ type: results.success ? 'success' : 'error' }, span(results.message)) + : ''; + }, + ), + ); +} + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.tg-connections--border { + border: var(--button-stroked-border); + border-radius: 8px; +} + +.tg-connections--link { + margin-left: auto; + border-radius: 4px; + background: var(--dk-card-background); + border: var(--button-stroked-border); + padding: 8px 8px 8px 16px; + color: var(--primary-color) !important; +} +`); + +export { Connections }; diff --git a/testgen/ui/components/frontend/js/pages/table_group_wizard.js b/testgen/ui/components/frontend/js/pages/table_group_wizard.js new file mode 100644 index 00000000..c2fcd7ea --- /dev/null +++ b/testgen/ui/components/frontend/js/pages/table_group_wizard.js @@ -0,0 +1,202 @@ +/** + * @typedef WizardResult + * @type {object} + * @property {boolean} success + * @property {string} message + * @property {string} table_group_id + * + * @typedef Properties + * @type {object} + * @property {string} project_code + * @property {string} connection_id + * @property {WizardResult?} results + */ +import van from '../van.min.js'; +import { Streamlit } from '../streamlit.js'; +import { TableGroupForm } from '../components/table_group_form.js'; +import { emitEvent, getValue, resizeFrameHeightOnDOMChange, resizeFrameHeightToElement } from '../utils.js'; +import { Button } from '../components/button.js'; +import { Alert } from '../components/alert.js'; +import { Checkbox } from '../components/checkbox.js'; +import { Icon } from '../components/icon.js'; + +const { div, i, span, strong } = van.tags; + +/** + * @param {Properties} props + */ +const TableGroupWizard = (props) => { + Streamlit.setFrameHeight(1); + window.testgen.isPage = true; + + const steps = [ + 'tableGroup', + 'runProfiling', + ]; + const stepsState = { + tableGroup: van.state({}), + runProfiling: van.state(true), + }; + const stepsValidity = { + tableGroup: van.state(false), + runProfiling: van.state(true), + }; + const currentStepIndex = van.state(0); + const currentStepIsInvalid = van.derive(() => { + const stepKey = steps[currentStepIndex.val]; + return !stepsValidity[stepKey].val; + }); + const nextButtonType = van.derive(() => { + const isLastStep = currentStepIndex.val === steps.length - 1; + return isLastStep ? 'flat' : 'stroked'; + }); + const nextButtonLabel = van.derive(() => { + const isLastStep = currentStepIndex.val === steps.length - 1; + if (isLastStep) { + return stepsState.runProfiling.val ? 'Save & Run Profiling' : 'Finish Setup'; + } + return 'Next'; + }); + const setStep = (stepIdx) => { + currentStepIndex.val = stepIdx; + }; + const saveTableGroup = () => { + const payload = { + table_group: stepsState.tableGroup.val, + run_profiling: stepsState.runProfiling.val, + }; + emitEvent('SaveTableGroupClicked', { payload }); + }; + + const domId = 'table-group-wizard-wrapper'; + resizeFrameHeightToElement(domId); + resizeFrameHeightOnDOMChange(domId); + + return div( + { id: domId, class: 'tg-table-group-wizard flex-column fx-gap-3' }, + WizardStep(0, currentStepIndex, () => { + currentStepIndex.val; + + return TableGroupForm({ + tableGroup: stepsState.tableGroup.rawVal, + onChange: (updatedTableGroup, state) => { + stepsState.tableGroup.val = updatedTableGroup; + stepsValidity.tableGroup.val = state.valid; + }, + }); + }), + () => { + const results = getValue(props.results); + const runProfiling = van.state(stepsState.runProfiling.rawVal); + + van.derive(() => { + stepsState.runProfiling.val = runProfiling.val; + }); + + return WizardStep(1, currentStepIndex, () => { + currentStepIndex.val; + + return RunProfilingStep( + stepsState.tableGroup.rawVal, + runProfiling, + results, + ); + }); + }, + div( + { class: 'tg-table-group-wizard--footer flex-row' }, + () => currentStepIndex.val > 0 + ? Button({ + label: 'Previous', + type: 'stroked', + color: 'basic', + width: 'auto', + style: 'margin-right: auto; min-width: 200px;', + onclick: () => setStep(currentStepIndex.val - 1), + }) + : '', + () => { + const results = getValue(props.results); + const runProfiling = stepsState.runProfiling.val; + + if (results && results.success && runProfiling) { + return Button({ + type: 'stroked', + color: 'primary', + label: 'Go to Profiling Runs', + width: 'auto', + icon: 'chevron_right', + onclick: () => emitEvent('GoToProfilingRunsClicked', { payload: { table_group_id: results.table_group_id } }), + }); + } + + return Button({ + label: nextButtonLabel, + type: nextButtonType, + color: 'primary', + width: 'auto', + style: 'margin-left: auto; min-width: 200px;', + disabled: currentStepIsInvalid, + onclick: () => { + if (currentStepIndex.val < steps.length - 1) { + return setStep(currentStepIndex.val + 1); + } + + saveTableGroup(); + }, + }); + }, + ), + ); +}; + +/** + * @param {object} tableGroup + * @param {boolean} runProfiling + * @param {WizardResult} result + * @returns + */ +const RunProfilingStep = (tableGroup, runProfiling, results) => { + return div( + { class: 'flex-column fx-gap-3' }, + Checkbox({ + label: div( + { class: 'flex-row'}, + span({ class: 'mr-1' }, 'Execute profiling for the table group'), + strong(() => tableGroup.table_groups_name), + span('?'), + ), + checked: runProfiling, + onChange: (value) => runProfiling.val = value, + }), + div( + { class: 'flex-row fx-gap-1' }, + Icon({}, 'info'), + () => runProfiling.val + ? i('Profiling will be performed in a background process.') + : i('Profiling will be skipped. You can run this step later from the Profiling Runs page.'), + ), + () => { + const results_ = getValue(results) ?? {}; + return Object.keys(results_).length > 0 + ? Alert({ type: results_.success ? 'success' : 'error' }, span(results_.message)) + : ''; + }, + ); +}; + +/** + * @param {number} index + * @param {number} currentIndex + * @param {any} content + */ +const WizardStep = (index, currentIndex, content) => { + const hidden = van.derive(() => getValue(currentIndex) !== getValue(index)); + + return div( + { class: () => hidden.val ? 'hidden' : ''}, + content, + ); +}; + +export { TableGroupWizard }; diff --git a/testgen/ui/components/frontend/js/utils.js b/testgen/ui/components/frontend/js/utils.js index caab512e..9edab31e 100644 --- a/testgen/ui/components/frontend/js/utils.js +++ b/testgen/ui/components/frontend/js/utils.js @@ -14,9 +14,12 @@ function enforceElementWidth( function resizeFrameHeightToElement(/** @type string */elementId) { const observer = new ResizeObserver(() => { - const height = document.getElementById(elementId).offsetHeight; - if (height) { - Streamlit.setFrameHeight(height); + const element = document.getElementById(elementId); + if (element) { + const height = element.offsetHeight; + if (height) { + Streamlit.setFrameHeight(height); + } } }); observer.observe(window.frameElement); @@ -24,9 +27,12 @@ function resizeFrameHeightToElement(/** @type string */elementId) { function resizeFrameHeightOnDOMChange(/** @type string */elementId) { const observer = new MutationObserver(() => { - const height = document.getElementById(elementId).offsetHeight; - if (height) { - Streamlit.setFrameHeight(height); + const element = document.getElementById(elementId); + if (element) { + const height = element.offsetHeight; + if (height) { + Streamlit.setFrameHeight(height); + } } }); observer.observe(window.frameElement.contentDocument.body, {subtree: true, childList: true}); diff --git a/testgen/ui/components/widgets/testgen_component.py b/testgen/ui/components/widgets/testgen_component.py index f4866bdd..ee80d18d 100644 --- a/testgen/ui/components/widgets/testgen_component.py +++ b/testgen/ui/components/widgets/testgen_component.py @@ -8,7 +8,6 @@ from testgen.ui.session import session AvailablePages = typing.Literal[ - "database_flavor_selector", "data_catalog", "column_profiling_results", "project_dashboard", @@ -19,6 +18,8 @@ "score_details", "schedule_list", "column_selector", + "connections", + "table_group_wizard", ] diff --git a/testgen/ui/queries/connection_queries.py b/testgen/ui/queries/connection_queries.py index dead1744..c3aad89c 100644 --- a/testgen/ui/queries/connection_queries.py +++ b/testgen/ui/queries/connection_queries.py @@ -10,7 +10,7 @@ def get_by_id(connection_id): str_schema = st.session_state["dbschema"] str_sql = f""" SELECT id::VARCHAR(50), project_code, connection_id, connection_name, - sql_flavor, project_host, project_port, project_user, + sql_flavor, sql_flavor_code, project_host, project_port, project_user, project_db, project_pw_encrypted, NULL as password, max_threads, max_query_chars, url, connect_by_url, connect_by_key, private_key, private_key_passphrase, http_path @@ -24,7 +24,7 @@ def get_connections(project_code): str_schema = st.session_state["dbschema"] str_sql = f""" SELECT id::VARCHAR(50), project_code, connection_id, connection_name, - sql_flavor, project_host, project_port, project_user, + sql_flavor, sql_flavor_code, project_host, project_port, project_user, project_db, project_pw_encrypted, NULL as password, max_threads, max_query_chars, connect_by_url, url, connect_by_key, private_key, private_key_passphrase, http_path @@ -42,31 +42,32 @@ def get_table_group_names_by_connection(schema: str, connection_ids: list[str]) def edit_connection(schema, connection, encrypted_password, encrypted_private_key, encrypted_private_key_passphrase): - sql = f"""UPDATE {schema}.connections SET - project_code = '{connection["project_code"]}', - sql_flavor = '{connection["sql_flavor"]}', - project_host = '{connection["project_host"]}', - project_port = '{connection["project_port"]}', - project_user = '{connection["project_user"]}', - project_db = '{connection["project_db"]}', - connection_name = '{connection["connection_name"]}', - max_threads = '{connection["max_threads"]}', - max_query_chars = '{connection["max_query_chars"]}', - url = '{connection["url"]}', - connect_by_key = '{connection["connect_by_key"]}', - connect_by_url = '{connection["connect_by_url"]}', - http_path = '{connection["http_path"]}'""" - - if encrypted_password: - sql += f""", project_pw_encrypted = '{encrypted_password}' """ - - if encrypted_private_key: - sql += f""", private_key = '{encrypted_private_key}' """ - - if encrypted_private_key_passphrase: - sql += f""", private_key_passphrase = '{encrypted_private_key_passphrase}' """ - - sql += f""" WHERE connection_id = '{connection["connection_id"]}';""" + encrypted_password_value = f"'{encrypted_password}'" if encrypted_password is not None else "null" + encrypted_private_key_value = f"'{encrypted_private_key}'" if encrypted_private_key is not None else "null" + encrypted_passphrase_value = f"'{encrypted_private_key_passphrase}'" if encrypted_private_key_passphrase is not None else "null" + + sql = f""" + UPDATE {schema}.connections + SET + project_code = '{connection["project_code"]}', + sql_flavor = '{connection["sql_flavor"]}', + sql_flavor_code = '{connection["sql_flavor_code"]}', + project_host = '{connection["project_host"]}', + project_port = '{connection["project_port"]}', + project_user = '{connection["project_user"]}', + project_db = '{connection["project_db"]}', + connection_name = '{connection["connection_name"]}', + max_threads = '{connection["max_threads"]}', + max_query_chars = '{connection["max_query_chars"]}', + url = '{connection["url"]}', + connect_by_key = '{connection["connect_by_key"]}', + connect_by_url = '{connection["connect_by_url"]}', + http_path = '{connection["http_path"]}', + project_pw_encrypted = {encrypted_password_value}, + private_key = {encrypted_private_key_value}, + private_key_passphrase = {encrypted_passphrase_value} + WHERE connection_id = '{connection["connection_id"]}'; + """ db.execute_sql(sql) st.cache_data.clear() @@ -79,13 +80,14 @@ def add_connection( encrypted_private_key_passphrase: str | None, ) -> int: sql_header = f"""INSERT INTO {schema}.connections - (project_code, sql_flavor, url, connect_by_url, connect_by_key, + (project_code, sql_flavor, sql_flavor_code, url, connect_by_url, connect_by_key, project_host, project_port, project_user, project_db, connection_name, http_path, """ sql_footer = f""" SELECT '{connection["project_code"]}' as project_code, '{connection["sql_flavor"]}' as sql_flavor, + '{connection["sql_flavor_code"]}' as sql_flavor_code, '{connection["url"]}' as url, {connection["connect_by_url"]} as connect_by_url, {connection["connect_by_key"]} as connect_by_key, diff --git a/testgen/ui/services/connection_service.py b/testgen/ui/services/connection_service.py index 8bf53a51..215fc7d4 100644 --- a/testgen/ui/services/connection_service.py +++ b/testgen/ui/services/connection_service.py @@ -165,12 +165,12 @@ def init_profiling_sql(project_code, connection, table_group_schema=None): return clsProfiling -def form_overwritten_connection_url(connection): +def form_overwritten_connection_url(connection) -> str: flavor = connection["sql_flavor"] connection_credentials = { "flavor": flavor, - "user": "", + "user": "", "host": connection["project_host"], "port": connection["project_port"], "dbname": connection["project_db"], @@ -188,3 +188,25 @@ def form_overwritten_connection_url(connection): connection_string = flavor_service.get_connection_string("") return connection_string + + +def get_connection_string(flavor: str) -> str: + db_type = get_db_type(flavor) + flavor_service = get_flavor_service(db_type) + flavor_service.init({ + "flavor": flavor, + "user": "", + "host": "", + "port": "", + "dbname": "", + "url": None, + "connect_by_url": None, + "connect_by_key": False, + "private_key": None, + "private_key_passphrase": "", + "dbschema": "", + "http_path": "", + }) + return flavor_service.get_connection_string( + "" + ).replace("%3E", ">").replace("%3C", "<") diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py new file mode 100644 index 00000000..b30029ef --- /dev/null +++ b/testgen/ui/views/connections.py @@ -0,0 +1,417 @@ +import base64 +import logging +import typing +from dataclasses import asdict, dataclass, field + +import streamlit as st + +try: + from pyodbc import Error as PyODBCError +except ImportError: + PyODBCError = None +from sqlalchemy.exc import DatabaseError, DBAPIError + +import testgen.ui.services.database_service as db +from testgen.commands.run_profiling_bridge import run_profiling_in_background +from testgen.common.database.database_service import empty_cache +from testgen.common.models import with_database_session +from testgen.ui.assets import get_asset_data_url +from testgen.ui.components import widgets as testgen +from testgen.ui.navigation.menu import MenuItem +from testgen.ui.navigation.page import Page +from testgen.ui.services import connection_service, table_group_service, user_session_service +from testgen.ui.session import session, temp_value +from testgen.utils import format_field + +LOG = logging.getLogger("testgen") +PAGE_TITLE = "Connection" +CLEAR_SENTINEL = "" + + +@dataclass(frozen=True, slots=True, kw_only=True) +class ConnectionFlavor: + value: str + label: str + icon: str + flavor: str + connection_string: str + + +class ConnectionsPage(Page): + path = "connections" + can_activate: typing.ClassVar = [ + lambda: session.authentication_status, + lambda: not user_session_service.user_has_catalog_role(), + lambda: "project_code" in st.query_params, + ] + menu_item = MenuItem( + icon="database", + label=PAGE_TITLE, + section="Data Configuration", + order=0, + roles=[ role for role in typing.get_args(user_session_service.RoleType) if role != "catalog" ], + ) + flavor_options: typing.ClassVar[list[ConnectionFlavor]] = [ + ConnectionFlavor( + label="Amazon Redshift", + value="redshift", + flavor="redshift", + icon=get_asset_data_url("flavors/redshift.svg"), + connection_string=connection_service.get_connection_string("redshift"), + ), + ConnectionFlavor( + label="Azure SQL Database", + value="azure_mssql", + flavor="mssql", + icon=get_asset_data_url("flavors/azure_sql.svg"), + connection_string=connection_service.get_connection_string("mssql"), + ), + ConnectionFlavor( + label="Azure Synapse Analytics", + value="synapse_mssql", + flavor="mssql", + icon=get_asset_data_url("flavors/azure_synapse_table.svg"), + connection_string=connection_service.get_connection_string("mssql"), + ), + ConnectionFlavor( + label="Microsoft SQL Server", + value="mssql", + flavor="mssql", + icon=get_asset_data_url("flavors/mssql.svg"), + connection_string=connection_service.get_connection_string("mssql"), + ), + ConnectionFlavor( + label="PostgreSQL", + value="postgresql", + flavor="postgresql", + icon=get_asset_data_url("flavors/postgresql.svg"), + connection_string=connection_service.get_connection_string("postgresql"), + ), + ConnectionFlavor( + label="Snowflake", + value="snowflake", + flavor="snowflake", + icon=get_asset_data_url("flavors/snowflake.svg"), + connection_string=connection_service.get_connection_string("snowflake"), + ), + ConnectionFlavor( + label="Databricks", + value="databricks", + flavor="databricks", + icon=get_asset_data_url("flavors/databricks.svg"), + connection_string=connection_service.get_connection_string("databricks"), + ), + ] + + def render(self, project_code: str, **_kwargs) -> None: + testgen.page_header( + PAGE_TITLE, + "connect-your-database", + ) + + dataframe = connection_service.get_connections(project_code) + connection = dataframe.iloc[0] + has_table_groups = ( + len(connection_service.get_table_group_names_by_connection([connection["connection_id"]]) or []) > 0 + ) + user_is_admin = user_session_service.user_is_admin() + should_check_status, set_check_status = temp_value( + "connections:status_check", + default=False, + ) + get_updated_connection, set_updated_connection = temp_value( + "connections:partial_value", + default={}, + ) + should_save, set_save = temp_value( + "connections:update_connection", + default=False, + ) + + def on_save_connection_clicked(updated_connection): + is_pristine = lambda value: value in ["", "***"] + + if updated_connection.get("connect_by_url", False): + url_parts = updated_connection.get("url", "").split("@") + if len(url_parts) > 1: + updated_connection["url"] = url_parts[1] + + if updated_connection.get("connect_by_key"): + updated_connection["password"] = "" + if is_pristine(updated_connection["private_key_passphrase"]): + del updated_connection["private_key_passphrase"] + else: + updated_connection["private_key"] = "" + updated_connection["private_key_passphrase"] = "" + + if updated_connection.get("private_key_passphrase") == CLEAR_SENTINEL: + updated_connection["private_key_passphrase"] = "" + + if is_pristine(updated_connection.get("private_key")): + del updated_connection["private_key"] + else: + updated_connection["private_key"] = base64.b64decode(updated_connection["private_key"]).decode() + + updated_connection["sql_flavor"] = self._get_sql_flavor_from_value(updated_connection["sql_flavor_code"]).flavor + + set_save(True) + set_updated_connection(updated_connection) + + def on_test_connection_clicked(updated_connection: dict) -> None: + password = updated_connection.get("password") + private_key = updated_connection.get("private_key") + private_key_passphrase = updated_connection.get("private_key_passphrase") + is_pristine = lambda value: value in ["", "***"] + + if is_pristine(password): + del updated_connection["password"] + + if is_pristine(private_key): + del updated_connection["private_key"] + else: + updated_connection["private_key"] = base64.b64decode(updated_connection["private_key"]).decode() + + if is_pristine(private_key_passphrase): + del updated_connection["private_key_passphrase"] + elif updated_connection.get("private_key_passphrase") == CLEAR_SENTINEL: + updated_connection["private_key_passphrase"] = "" + + updated_connection["sql_flavor"] = self._get_sql_flavor_from_value(updated_connection["sql_flavor_code"]).flavor + + set_check_status(True) + set_updated_connection(updated_connection) + + results = None + connection = {**connection.to_dict(), **get_updated_connection()} + if should_save(): + success = True + try: + connection_service.edit_connection(connection) + message = "Changes have been saved successfully." + except Exception as error: + message = "Error creating connection" + success = False + LOG.exception(message) + + results = { + "success": success, + "message": message, + } + + return testgen.testgen_component( + "connections", + props={ + "connection": self._format_connection(connection, should_test=should_check_status()), + "has_table_groups": has_table_groups, + "flavors": [asdict(flavor) for flavor in self.flavor_options], + "permissions": { + "is_admin": user_is_admin, + }, + "results": results, + }, + on_change_handlers={ + "TestConnectionClicked": on_test_connection_clicked, + "SaveConnectionClicked": on_save_connection_clicked, + "SetupTableGroupClicked": lambda _: self.setup_data_configuration(project_code, connection["connection_id"]), + }, + ) + + def _get_sql_flavor_from_value(self, value: str) -> ConnectionFlavor | None: + match = [f for f in self.flavor_options if f.value == value] + if match: + return match[0] + return None + + def _format_connection(self, connection: dict, should_test: bool = False) -> dict: + fields = [ + "project_code", + "connection_id", + "connection_name", + "sql_flavor", + "sql_flavor_code", + "project_host", + "project_port", + "project_db", + "project_user", + "password", + "max_threads", + "max_query_chars", + "connect_by_url", + "connect_by_key", + "private_key", + "private_key_passphrase", + "http_path", + "url", + ] + formatted_connection = {} + + for fieldname in fields: + formatted_connection[fieldname] = format_field(connection[fieldname]) + + if should_test: + formatted_connection["status"] = asdict(self.test_connection(connection)) + + if formatted_connection["password"]: + formatted_connection["password"] = "***" # noqa S105 + if formatted_connection["private_key"]: + formatted_connection["private_key"] = "***" # S105 + if formatted_connection["private_key_passphrase"]: + formatted_connection["private_key_passphrase"] = "***" # noqa S105 + + first_match = [f for f in self.flavor_options if f.flavor == formatted_connection.get("sql_flavor")] + if formatted_connection["sql_flavor"] and not formatted_connection.get("sql_flavor_code") and first_match: + formatted_connection["sql_flavor_code"] = first_match[0].flavor + + flavors = [f for f in self.flavor_options if f.value == formatted_connection["sql_flavor_code"]] + if flavors and (flavor := flavors[0]): + formatted_connection["flavor"] = asdict(flavor) + + return formatted_connection + + def test_connection(self, connection: dict) -> "ConnectionStatus": + empty_cache() + try: + sql_query = "select 1;" + results = db.retrieve_target_db_data( + connection["sql_flavor"], + connection["project_host"], + connection["project_port"], + connection["project_db"], + connection["project_user"], + connection["password"], + connection["url"], + connection["connect_by_url"], + connection["connect_by_key"], + connection.get("private_key", ""), + connection.get("private_key_passphrase", ""), + connection.get("http_path", ""), + sql_query, + ) + connection_successful = len(results) == 1 and results[0][0] == 1 + + if not connection_successful: + return ConnectionStatus(message="Error completing a query to the database server.", successful=False) + return ConnectionStatus(message="The connection was successful.", successful=True) + except KeyError: + return ConnectionStatus( + message="Error attempting the connection. ", + details="Complete all the required fields.", + successful=False, + ) + except DatabaseError as error: + LOG.exception("Error testing database connection") + return ConnectionStatus(message="Error attempting the connection.", details=str(error.orig), successful=False) + except DBAPIError as error: + LOG.exception("Error testing database connection") + details = str(error.orig) + if PyODBCError and isinstance(error.orig, PyODBCError) and error.orig.args: + details = error.orig.args[1] + return ConnectionStatus(message="Error attempting the connection.", details=details, successful=False) + except (TypeError, ValueError) as error: + LOG.exception("Error testing database connection") + details = str(error) + if is_open_ssl_error(error): + details = error.args[0] + return ConnectionStatus(message="Error attempting the connection.", details=details, successful=False) + except Exception as error: + details = "Try again" + if connection["connect_by_key"] and not connection.get("private_key", ""): + details = "The private key is missing." + LOG.exception("Error testing database connection") + return ConnectionStatus(message="Error attempting the connection.", details=details, successful=False) + + @st.dialog(title="Data Configuration Setup") + @with_database_session + def setup_data_configuration(self, project_code: str, connection_id: str) -> None: + def on_save_table_group_clicked(payload: dict) -> None: + table_group: dict = payload["table_group"] + run_profiling: bool = payload.get("run_profiling", False) + + set_new_table_group(table_group) + set_run_profiling(run_profiling) + + def on_go_to_profiling_runs(params: dict) -> None: + set_navigation_params({ **params, "project_code": project_code }) + + get_navigation_params, set_navigation_params = temp_value( + "connections:new_table_group:go_to_profiling_run", + default=None, + ) + if (params := get_navigation_params()): + self.router.navigate(to="profiling-runs", with_args=params) + + get_new_table_group, set_new_table_group = temp_value( + "connections:new_connection:table_group", + default={}, + ) + get_run_profiling, set_run_profiling = temp_value( + "connections:new_connection:run_profiling", + default=False, + ) + + results = None + table_group = get_new_table_group() + should_run_profiling = get_run_profiling() + if table_group: + success = True + message = None + table_group_id = None + + try: + table_group_id = table_group_service.add({ + **table_group, + "project_code": project_code, + "connection_id": connection_id, + }) + + if should_run_profiling: + try: + run_profiling_in_background(table_group_id) + message = f"Profiling run started for table group {table_group['table_groups_name']}." + except Exception as error: + message = "Profiling run encountered errors" + success = False + LOG.exception(message) + else: + LOG.info("Table group %s created", table_group_id) + st.rerun() + except Exception as error: + message = "Error creating table group" + success = False + LOG.exception(message) + + results = { + "success": success, + "message": message, + "table_group_id": table_group_id, + } + + testgen.testgen_component( + "table_group_wizard", + props={ + "project_code": project_code, + "connection_id": connection_id, + "results": results, + }, + on_change_handlers={ + "SaveTableGroupClicked": on_save_table_group_clicked, + "GoToProfilingRunsClicked": on_go_to_profiling_runs, + }, + ) + + +@dataclass(frozen=True, slots=True) +class ConnectionStatus: + message: str + successful: bool + details: str | None = field(default=None) + + +def is_open_ssl_error(error: Exception): + return ( + error.args + and len(error.args) > 1 + and isinstance(error.args[1], list) + and len(error.args[1]) > 0 + and type(error.args[1][0]).__name__ == "OpenSSLError" + ) diff --git a/testgen/ui/views/connections/__init__.py b/testgen/ui/views/connections/__init__.py deleted file mode 100644 index cc9b67f1..00000000 --- a/testgen/ui/views/connections/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# ruff: noqa: F401 - -from testgen.ui.views.connections.forms import BaseConnectionForm, KeyPairConnectionForm, PasswordConnectionForm -from testgen.ui.views.connections.models import ConnectionStatus -from testgen.ui.views.connections.page import ConnectionsPage diff --git a/testgen/ui/views/connections/forms.py b/testgen/ui/views/connections/forms.py deleted file mode 100644 index 3ece7588..00000000 --- a/testgen/ui/views/connections/forms.py +++ /dev/null @@ -1,304 +0,0 @@ -# type: ignore -import time -import typing - -import streamlit as st -from streamlit.delta_generator import DeltaGenerator -from streamlit.runtime.uploaded_file_manager import UploadedFile - -from testgen.ui.components import widgets as testgen -from testgen.ui.forms import BaseForm, Field, ManualRender, computed_field -from testgen.ui.services import connection_service - -SQL_FLAVORS = ["redshift", "snowflake", "mssql", "postgresql", "databricks"] -SQLFlavor = typing.Literal["redshift", "snowflake", "mssql", "postgresql", "databricks"] - - -class BaseConnectionForm(BaseForm, ManualRender): - connection_name: str = Field( - default="", - min_length=3, - max_length=40, - st_kwargs_max_chars=40, - st_kwargs_label="Connection Name", - st_kwargs_help="Your name for this connection. Can be any text.", - ) - project_host: str = Field( - default="", - max_length=250, - st_kwargs_max_chars=250, - st_kwargs_label="Host", - ) - project_port: str = Field(default="", max_length=5, st_kwargs_max_chars=5, st_kwargs_label="Port") - project_db: str = Field( - default="", - max_length=100, - st_kwargs_max_chars=100, - st_kwargs_label="Database", - st_kwargs_help="The name of the database defined on your host where your schemas and tables is present.", - ) - project_user: str = Field( - default="", - max_length=50, - st_kwargs_max_chars=50, - st_kwargs_label="User", - st_kwargs_help="Username to connect to your database.", - ) - connect_by_url: bool = Field( - default=False, - st_kwargs_label="URL override", - st_kwargs_help=( - "If this switch is set to on, the connection string will be driven by the field below. " - "Only user name and password will be passed per the relevant fields above." - ), - ) - url_prefix: str = Field( - default="", - readOnly=True, - st_kwargs_label="URL Prefix", - ) - url: str = Field( - default="", - max_length=200, - st_kwargs_label="URL Suffix", - st_kwargs_max_chars=200, - st_kwargs_help=( - "Provide a connection string directly. This will override connection parameters if " - "the 'Connect by URL' switch is set." - ), - ) - max_threads: int = Field( - default=4, - ge=1, - le=8, - st_kwargs_min_value=1, - st_kwargs_max_value=8, - st_kwargs_label="Max Threads (Advanced Tuning)", - st_kwargs_help=( - "Maximum number of concurrent threads that run tests. Default values should be retained unless " - "test queries are failing." - ), - ) - max_query_chars: int = Field( - default=9000, - ge=500, - le=14000, - st_kwargs_label="Max Expression Length (Advanced Tuning)", - st_kwargs_min_value=500, - st_kwargs_max_value=14000, - st_kwargs_help=( - "Some tests are consolidated into queries for maximum performance. Default values should be retained " - "unless test queries are failing." - ), - ) - - connection_id: int | None = Field(default=None) - - sql_flavor: SQLFlavor = Field( - ..., - st_kwargs_label="SQL Flavor", - st_kwargs_options=SQL_FLAVORS, - st_kwargs_help=( - "The type of database server that you will connect to. This determines TestGen's drivers and SQL dialect." - ), - ) - - def form_key(self): - return f"connection_form:{self.connection_id or 'new'}" - - def render_input_ui(self, container: DeltaGenerator, data: dict) -> "BaseConnectionForm": - time.sleep(0.1) - main_fields_container, optional_fields_container = container.columns([0.7, 0.3]) - - if self.get_field_value("connect_by_url", latest=True): - self.disable("project_host") - self.disable("project_port") - self.disable("project_db") - - self.render_field("sql_flavor", container=main_fields_container) - self.render_field("connection_name", container=main_fields_container) - host_field_container, port_field_container = main_fields_container.columns([0.8, 0.2]) - self.render_field("project_host", container=host_field_container) - self.render_field("project_port", container=port_field_container) - - self.render_field("project_db", container=main_fields_container) - self.render_field("project_user", container=main_fields_container) - self.render_field("max_threads", container=optional_fields_container) - self.render_field("max_query_chars", container=optional_fields_container) - - self.render_extra(container, main_fields_container, optional_fields_container, data) - - testgen.divider(margin_top=8, margin_bottom=8, container=container) - - self.url_prefix = data.get("url_prefix", "") - self.render_field("connect_by_url") - if self.connect_by_url: - connection_string = connection_service.form_overwritten_connection_url(data) - connection_string_beginning, connection_string_end = connection_string.split("@", 1) - - self.update_field_value( - "url_prefix", - f"{connection_string_beginning}@".replace("%3E", ">").replace("%3C", "<"), - ) - if not data.get("url", ""): - self.update_field_value("url", connection_string_end) - - url_override_left_column, url_override_right_column = st.columns([0.25, 0.75]) - self.render_field("url_prefix", container=url_override_left_column) - self.render_field("url", container=url_override_right_column) - - time.sleep(0.1) - - return self - - def render_extra( - self, - _container: DeltaGenerator, - _left_fields_container: DeltaGenerator, - _right_fields_container: DeltaGenerator, - _data: dict, - ) -> None: - ... - - @staticmethod - def set_default_port(sql_flavor: SQLFlavor, form: type["BaseConnectionForm"]) -> None: - if sql_flavor == "mssql": - form.project_port = 1433 - elif sql_flavor == "redshift": - form.project_port = 5439 - elif sql_flavor == "postgresql": - form.project_port = 5432 - elif sql_flavor == "snowflake": - form.project_port = 443 - elif sql_flavor == "databricks": - form.project_port = 443 - - @staticmethod - def for_flavor(flavor: SQLFlavor) -> type["BaseConnectionForm"]: - return { - "redshift": PasswordConnectionForm, - "snowflake": KeyPairConnectionForm, - "mssql": PasswordConnectionForm, - "postgresql": PasswordConnectionForm, - "databricks": HttpPathConnectionForm, - }[flavor] - - -class PasswordConnectionForm(BaseConnectionForm): - password: str = Field( - default="", - max_length=50, - writeOnly=True, - st_kwargs_label="Password", - st_kwargs_max_chars=50, - st_kwargs_help="Password to connect to your database.", - ) - - def render_extra( - self, - _container: DeltaGenerator, - left_fields_container: DeltaGenerator, - _right_fields_container: DeltaGenerator, - _data: dict, - ) -> None: - self.render_field("password", left_fields_container) - - -class HttpPathConnectionForm(PasswordConnectionForm): - http_path: str = Field( - default="", - max_length=200, - st_kwargs_label="HTTP Path", - st_kwargs_max_chars=50, - ) - - def render_extra( - self, - _container: DeltaGenerator, - left_fields_container: DeltaGenerator, - _right_fields_container: DeltaGenerator, - _data: dict, - ) -> None: - super().render_extra(_container, left_fields_container, _right_fields_container, _data) - self.render_field("http_path", left_fields_container) - - -class KeyPairConnectionForm(PasswordConnectionForm): - connect_by_key: bool = Field(default=None) - private_key_passphrase: str = Field( - default="", - max_length=200, - writeOnly=True, - st_kwargs_max_chars=200, - st_kwargs_help=( - "Passphrase used while creating the private Key (leave empty if not applicable)" - ), - st_kwargs_label="Private Key Passphrase", - ) - _uploaded_file: UploadedFile | None = None - - @computed_field(default="") - def private_key(self) -> str: - if self._uploaded_file is None: - return "" - - file_contents: bytes = self._uploaded_file.getvalue() - return file_contents.decode("utf-8") - - def render_extra( - self, - container: DeltaGenerator, - _left_fields_container: DeltaGenerator, - _right_fields_container: DeltaGenerator, - _data: dict, - ) -> None: - testgen.divider(margin_top=8, margin_bottom=8, container=container) - - connect_by_key = self.connect_by_key - if connect_by_key is None: - connect_by_key = self.get_field_value("connect_by_key") - - connection_option: typing.Literal["Connect by Password", "Connect by Key-Pair"] = container.radio( - "Connection options", - options=["Connect by Password", "Connect by Key-Pair"], - index=1 if connect_by_key else 0, - horizontal=True, - help="Connection strategy", - key=self.get_field_key("connection_option"), - ) - self.update_field_value("connect_by_key", connection_option == "Connect by Key-Pair") - - if connection_option == "Connect by Password": - self.render_field("password", container) - else: - self.render_field("private_key_passphrase", container) - - file_uploader_key = self.get_field_key("private_key_uploader") - cached_file_upload_key = self.get_field_key("previous_private_key_file") - - self._uploaded_file = container.file_uploader( - key=file_uploader_key, - label="Upload private key (rsa_key.p8)", - accept_multiple_files=False, - on_change=lambda: st.session_state.pop(cached_file_upload_key, None), - ) - - if self._uploaded_file: - st.session_state[cached_file_upload_key] = self._uploaded_file - elif self._uploaded_file is None and (cached_file_upload := st.session_state.get(cached_file_upload_key)): - self._uploaded_file = cached_file_upload - file_size = f"{round(self._uploaded_file.size / 1024, 2)}KB" - container.html( - f""" -
- draft - {self._uploaded_file.name} - {file_size} -
- """ - ) - - def reset_cache(self) -> None: - st.session_state.pop(self.get_field_key("private_key_uploader"), None) - st.session_state.pop(self.get_field_key("previous_private_key_file"), None) - return super().reset_cache() diff --git a/testgen/ui/views/connections/models.py b/testgen/ui/views/connections/models.py deleted file mode 100644 index 90f16cad..00000000 --- a/testgen/ui/views/connections/models.py +++ /dev/null @@ -1,8 +0,0 @@ -import dataclasses - - -@dataclasses.dataclass(frozen=True, slots=True) -class ConnectionStatus: - message: str - successful: bool - details: str | None = dataclasses.field(default=None) diff --git a/testgen/ui/views/connections/page.py b/testgen/ui/views/connections/page.py deleted file mode 100644 index a8dd1787..00000000 --- a/testgen/ui/views/connections/page.py +++ /dev/null @@ -1,353 +0,0 @@ -import logging -import time -import typing -from functools import partial - -import streamlit as st -import streamlit_pydantic as sp -from pydantic import ValidationError -from streamlit.delta_generator import DeltaGenerator - -import testgen.ui.services.database_service as db -from testgen.commands.run_profiling_bridge import run_profiling_in_background -from testgen.common.database.database_service import empty_cache -from testgen.common.models import with_database_session -from testgen.ui.components import widgets as testgen -from testgen.ui.navigation.menu import MenuItem -from testgen.ui.navigation.page import Page -from testgen.ui.services import connection_service, table_group_service, user_session_service -from testgen.ui.session import session, temp_value -from testgen.ui.views.connections.forms import BaseConnectionForm -from testgen.ui.views.connections.models import ConnectionStatus -from testgen.ui.views.table_groups import TableGroupForm - -LOG = logging.getLogger("testgen") -PAGE_TITLE = "Connection" - - -class ConnectionsPage(Page): - path = "connections" - can_activate: typing.ClassVar = [ - lambda: session.authentication_status, - lambda: not user_session_service.user_has_catalog_role(), - lambda: "project_code" in st.query_params, - ] - menu_item = MenuItem( - icon="database", - label=PAGE_TITLE, - section="Data Configuration", - order=0, - roles=[ role for role in typing.get_args(user_session_service.RoleType) if role != "catalog" ], - ) - - def render(self, project_code: str, **_kwargs) -> None: - dataframe = connection_service.get_connections(project_code) - connection = dataframe.iloc[0] - has_table_groups = ( - len(connection_service.get_table_group_names_by_connection([connection["connection_id"]]) or []) > 0 - ) - - testgen.page_header( - PAGE_TITLE, - "connect-your-database", - ) - - testgen.whitespace(0.3) - _, actions_column = st.columns([.1, .9]) - testgen.whitespace(0.3) - testgen.flex_row_end(actions_column) - - with st.container(border=True): - self.show_connection_form(connection.to_dict(), "edit", project_code) - - if has_table_groups: - with actions_column: - testgen.link( - label="Manage Table Groups", - href="connections:table-groups", - params={"connection_id": str(connection["connection_id"])}, - right_icon="chevron_right", - underline=False, - height=40, - style="margin-left: auto; border-radius: 4px; background: var(--dk-card-background);" - " border: var(--button-stroked-border); padding: 8px 8px 8px 16px; color: var(--primary-color)", - ) - else: - user_can_edit = user_session_service.user_can_edit() - with actions_column: - testgen.button( - type_="stroked", - color="primary", - icon="table_view", - label="Setup Table Groups", - style="var(--dk-card-background)", - width=200, - disabled=not user_can_edit, - tooltip=None if user_can_edit else user_session_service.DISABLED_ACTION_TEXT, - on_click=lambda: self.setup_data_configuration(project_code, connection.to_dict()), - ) - - def show_connection_form(self, selected_connection: dict, _mode: str, project_code) -> None: - connection = selected_connection or {} - connection_id = connection.get("connection_id", 1) - connection_name = connection.get("connection_name", "default") - sql_flavor = connection.get("sql_flavor", "postgresql") - data = {} - - try: - FlavorForm = BaseConnectionForm.for_flavor(sql_flavor) - if connection: - connection["password"] = connection["password"] or "" - connection["private_key"] = connection["private_key"] or "" - - form_kwargs = connection or {"sql_flavor": sql_flavor, "connection_id": connection_id, "connection_name": connection_name} - form = FlavorForm(**form_kwargs) - - BaseConnectionForm.set_default_port(sql_flavor, form) - - sql_flavor = form.get_field_value("sql_flavor", latest=True) or sql_flavor - if form.sql_flavor != sql_flavor: - form = BaseConnectionForm.for_flavor(sql_flavor)(sql_flavor=sql_flavor, connection_id=connection_id) - - form.disable("connection_name") - - form_errors_container = st.empty() - data = sp.pydantic_input( - key=f"connection_form:{connection_id}", - model=form, # type: ignore - ) - data.update({ - "project_code": project_code, - }) - if "private_key" not in data: - data.update({ - "connect_by_key": False, - "private_key_passphrase": None, - "private_key": None, - }) - - data.setdefault("http_path", "") - - try: - FlavorForm(**data) - except ValidationError as error: - form_errors_container.warning("\n".join([ - f"- {field_label}: {err['msg']}" for err in error.errors() - if (field_label := FlavorForm.get_field_label(str(err["loc"][0]))) - ])) - except Exception: - LOG.exception("unexpected form validation error") - st.error("Unexpected error displaying the form. Try again") - - test_button_column, _, save_button_column = st.columns([.2, .6, .2]) - is_submitted, set_submitted = temp_value(f"connection_form-{connection_id}:submit") - is_connecting, set_connecting = temp_value( - f"connection_form-{connection_id}:test_conn" - ) - - if user_session_service.user_is_admin(): - with save_button_column: - testgen.button( - type_="flat", - label="Save", - key=f"connection_form:{connection_id}:submit", - on_click=lambda: set_submitted(True), - ) - - with test_button_column: - testgen.button( - type_="stroked", - color="basic", - label="Test Connection", - key=f"connection_form:{connection_id}:test", - on_click=lambda: set_connecting(True), - ) - - if is_connecting(): - single_element_container = st.empty() - single_element_container.info("Connecting ...") - connection_status = self.test_connection(data) - - with single_element_container.container(): - renderer = { - True: st.success, - False: st.error, - }[connection_status.successful] - - renderer(connection_status.message) - if not connection_status.successful and connection_status.details: - st.caption("Connection Error Details") - - with st.container(border=True): - st.markdown(connection_status.details) - - connection_status = None - else: - # This is needed to fix a strange bug in Streamlit when using dialog + input fields + button - # If an input field is changed and the button is clicked immediately (without unfocusing the input first), - # two fragment reruns happen successively, one for unfocusing the input and the other for clicking the button - # Some or all (it seems random) of the input fields disappear when this happens - time.sleep(0.1) - - if is_submitted(): - if not data.get("password") and not data.get("connect_by_key"): - st.error("Enter a valid password.") - else: - if data.get("private_key"): - data["private_key"] = data["private_key"].getvalue().decode("utf-8") - - connection_service.edit_connection(data) - st.success("Changes have been saved successfully.") - time.sleep(1) - st.rerun() - - def test_connection(self, connection: dict) -> "ConnectionStatus": - if connection["connect_by_key"] and connection["connection_id"] is None: - return ConnectionStatus( - message="Please add the connection before testing it (so that we can get your private key file).", - successful=False, - ) - - empty_cache() - try: - sql_query = "select 1;" - results = db.retrieve_target_db_data( - connection["sql_flavor"], - connection["project_host"], - connection["project_port"], - connection["project_db"], - connection["project_user"], - connection["password"], - connection["url"], - connection["connect_by_url"], - connection["connect_by_key"], - connection["private_key"], - connection["private_key_passphrase"], - connection["http_path"], - sql_query, - ) - connection_successful = len(results) == 1 and results[0][0] == 1 - - if not connection_successful: - return ConnectionStatus(message="Error completing a query to the database server.", successful=False) - return ConnectionStatus(message="The connection was successful.", successful=True) - except Exception as error: - return ConnectionStatus(message="Error attempting the connection.", details=error.args[0], successful=False) - - @st.dialog(title="Data Configuration Setup") - @with_database_session - def setup_data_configuration(self, project_code: str, connection: dict) -> None: - will_run_profiling = st.session_state.get("connection_form-new:run-profiling-toggle", True) - testgen.wizard( - key="connections:setup-wizard", - steps=[ - testgen.WizardStep( - title="Create a Table Group", - body=partial(self.create_table_group_step, project_code, connection), - ), - testgen.WizardStep( - title="Run Profiling", - body=self.run_data_profiling_step, - ), - ], - on_complete=self.execute_setup, - complete_label="Save & Run Profiling" if will_run_profiling else "Finish Setup", - navigate_to=st.session_state.pop("setup_data_config:navigate-to", None), - navigate_to_args=st.session_state.pop("setup_data_config:navigate-to-args", {}), - ) - - def create_table_group_step(self, project_code: str, connection: dict) -> tuple[dict | None, bool]: - is_valid: bool = True - data: dict = {} - - try: - form = TableGroupForm.construct() - form_errors_container = st.empty() - data = sp.pydantic_input(key="table_form:new", model=form) # type: ignore - - try: - TableGroupForm(**data) - form_errors_container.empty() - data.update({"project_code": project_code, "connection_id": connection["connection_id"]}) - except ValidationError as error: - form_errors_container.warning("\n".join([ - f"- {field_label}: {err['msg']}" for err in error.errors() - if (field_label := TableGroupForm.get_field_label(str(err["loc"][0]))) - ])) - is_valid = False - except Exception: - LOG.exception("unexpected form validation error") - st.error("Unexpected error displaying the form. Try again") - is_valid = False - - return data, is_valid - - def run_data_profiling_step(self, step_0: testgen.WizardStep | None = None) -> tuple[bool, bool]: - if not step_0 or not step_0.results: - st.error("A table group is required to complete this step.") - return False, False - - run_profiling = True - profiling_message = "Profiling will be performed in a background process." - table_group = step_0.results - - with st.container(): - run_profiling = st.checkbox( - label=f"Execute profiling for the table group **{table_group['table_groups_name']}**?", - key="connection_form-new:run-profiling-toggle", - value=True, - ) - if not run_profiling: - profiling_message = ( - "Profiling will be skipped. You can run this step later from the Profiling Runs page." - ) - st.markdown(f":material/info: _{profiling_message}_") - - return run_profiling, True - - def execute_setup( - self, - container: DeltaGenerator, - step_0: testgen.WizardStep[dict], - step_1: testgen.WizardStep[bool], - ) -> bool: - table_group = step_0.results - table_group_name: str = table_group["table_groups_name"] - should_run_profiling: bool = step_1.results - - with container.container(): - status_container = st.empty() - - try: - status_container.info(f"Creating table group **{table_group_name.strip()}**.") - table_group_id = table_group_service.add(table_group) - TableGroupForm.construct().reset_cache() - except Exception as err: - status_container.error(f"Error creating table group: {err!s}.") - - if should_run_profiling: - try: - status_container.info("Starting profiling run ...") - run_profiling_in_background(table_group_id) - status_container.success(f"Profiling run started for table group **{table_group_name.strip()}**.") - except Exception as err: - status_container.error(f"Profiling run encountered errors: {err!s}.") - - _, link_column = st.columns([.7, .3]) - with link_column: - testgen.button( - type_="stroked", - color="primary", - label="Go to Profiling Runs", - icon="chevron_right", - key="setup_data_config:keys:go-to-runs", - on_click=lambda: ( - st.session_state.__setattr__("setup_data_config:navigate-to", "profiling-runs") - or st.session_state.__setattr__("setup_data_config:navigate-to-args", { - "project_code": table_group["project_code"], - "table_group": table_group_id, - }) - ), - ) - - return not should_run_profiling diff --git a/testgen/utils/plugins.py b/testgen/utils/plugins.py index 16607525..6dc2563b 100644 --- a/testgen/utils/plugins.py +++ b/testgen/utils/plugins.py @@ -1,17 +1,119 @@ import dataclasses import importlib.metadata -import typing +import inspect +import json +import os +from collections.abc import Generator +from pathlib import Path +from typing import ClassVar + +from testgen.ui.assets import get_asset_path +from testgen.ui.navigation.page import Page PLUGIN_PREFIX = "testgen_" +ui_plugins_components_directory = ( + Path(__file__).parent.parent / "ui" / "components" / "frontend" / "js" / "plugin_pages" +) +ui_plugins_provision_file = Path(__file__).parent.parent / "ui" / "components" / "frontend" / "js" / "plugins.js" +ui_plugins_entrypoint_prefix = "./plugin_pages" -def discover() -> typing.Generator["Plugin", None, None]: +def discover() -> Generator["Plugin", None, None]: + ui_plugins_provision_file.touch(exist_ok=True) for package_path, distribution_names in importlib.metadata.packages_distributions().items(): if package_path.startswith(PLUGIN_PREFIX): yield Plugin(package=package_path, version=importlib.metadata.version(distribution_names[0])) +def cleanup() -> None: + if ui_plugins_components_directory.exists(): + for item in ui_plugins_components_directory.iterdir(): + if item.is_symlink(): + try: + item.unlink() + except OSError as e: + ... + _reset_ui_plugin_spec() + + +def _reset_ui_plugin_spec() -> None: + ui_plugins_provision_file.touch(exist_ok=True) + ui_plugins_provision_file.write_text("export default {};") + + +class Logo: + image_path: str = get_asset_path("dk_logo.svg") + icon_path: str = get_asset_path("dk_icon.svg") + + def render(self): + import streamlit as st + + st.logo( + image=self.image_path, + icon_image=self.icon_path, + ) + + +@dataclasses.dataclass +class ComponentSpec: + name: str + root: Path + entrypoint: str + + def provide(self) -> None: + ui_plugins_components_directory.mkdir(exist_ok=True) + + target = ui_plugins_components_directory / self.name + try: + os.symlink(self.root, target) + except FileExistsError: + ... + except OSError as e: + ... + + plugins_provision: dict = _read_ui_plugin_spec() + plugins_provision[self.name] = { + "name": self.name, + "entrypoint": f"{ui_plugins_entrypoint_prefix}/{self.name}/{self.entrypoint}", + } + ui_plugins_provision_file.write_text(f"""export default {json.dumps(plugins_provision, indent=2)};""") + + +def _read_ui_plugin_spec() -> dict: + contents = ui_plugins_provision_file.read_text() or "export default {};" + return json.loads(contents.replace("export default ", "")[:-1]) + + +class PluginSpec: + page: ClassVar[type[Page] | None] = None + logo: ClassVar[type[Logo] | None] = None + component: ClassVar[ComponentSpec | None] = None + + @dataclasses.dataclass class Plugin: package: str version: str + + def load(self) -> PluginSpec: + plugin_page = None + plugin_logo = None + plugin_component_spec = None + + module = importlib.import_module(self.package) + for property_name in dir(module): + if ((maybe_class := getattr(module, property_name, None)) and inspect.isclass(maybe_class)): + if issubclass(maybe_class, PluginSpec): + return maybe_class + + if issubclass(maybe_class, Page): + plugin_page = maybe_class + + elif issubclass(maybe_class, Logo): + plugin_logo = maybe_class + + return type("AnyPlugin", (PluginSpec,), { + "page": plugin_page, + "logo": plugin_logo, + "component": plugin_component_spec, + })