diff --git a/.gitignore b/.gitignore
index e35ed4f..8cc65b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -151,4 +151,4 @@
 src-stats.yaml
 config.yaml
 *.yaml.gz
-*_stories.py
+/*_stories.py
diff --git a/datafaker/install.py b/datafaker/install.py
new file mode 100644
index 0000000..33cc9e1
--- /dev/null
+++ b/datafaker/install.py
@@ -0,0 +1,244 @@
+"""Functions to install Python file references in ``config.yaml``."""
+from collections.abc import Mapping, MutableMapping, Sequence
+from inspect import Parameter, signature
+from pathlib import Path
+from typing import Any
+
+from datafaker.utils import import_file, logger
+
+
+def _make_where_from_annotation(
+    query_def: Mapping[str, Any],
+    fn_name: str,
+    param_name: str,
+) -> str:
+    """Make a WHERE clause from the ``where`` key of the annotation's query."""
+    if "where" not in query_def:
+        return ""
+    w = query_def["where"]
+    if isinstance(w, str):
+        return f" WHERE {w}"
+    if isinstance(w, Sequence):
+        return " WHERE " + " AND ".join(f"({clause})" for clause in w)
+    logger.warning(
+        '"where" in the query annotation of parameter "%s" of function "%s"'
+        " needs to be a string or a list of strings",
+        param_name,
+        fn_name,
+    )
+    return ""
+
+
+def _make_vars_from_annotation(
+    query_def: Mapping[str, Any],
+    fn_name: str,
+    param_name: str,
+) -> Mapping[str, Any]:
+    """Make a variables dict from the ``vars`` key of the annotation's query."""
+    if "vars" not in query_def:
+        return {}
+    vars_def = query_def["vars"]
+    if isinstance(vars_def, Mapping):
+        return vars_def
+    if isinstance(vars_def, Sequence) and not isinstance(vars_def, str):
+        return {v: v for v in vars_def}
+    logger.warning(
+        '"vars" in the query annotation of parameter "%s" of function "%s"'
+        " needs to be a list of strings or a dict of strings to strings",
+        param_name,
+        fn_name,
+    )
+    return {}
+
+
+def _add_count_vars_from_annotation(
+    group_vars_out: MutableMapping[str, Any],
+    query_def: Mapping[str, Any],
+    fn_name: str,
+    param_name: str,
+) -> None:
+    """Add aggregation clauses from ``count_vars``."""
+    if "count_vars" not in query_def:
+        return
+    cntv = query_def["count_vars"]
+    if isinstance(cntv, Mapping):
+        group_vars_out.update({k: f"COUNT({v})" for k, v in cntv.items()})
+        return
+    logger.warning(
+        '"count_vars" needs to be a dict in the annotation for parameter %s of function %s',
+        param_name,
+        fn_name,
+    )
+
+
+def _add_ms_vars_from_annotation(
+    group_vars_out: MutableMapping[str, Any],
+    query_def: Mapping[str, Any],
+    fn_name: str,
+    param_name: str,
+) -> None:
+    """Add aggregation clauses from ``ms_vars``."""
+    if "ms_vars" not in query_def:
+        return
+    msv = query_def["ms_vars"]
+    if not isinstance(msv, Mapping):
+        logger.warning(
+            '"ms_vars" needs to be a dict in the annotation for parameter %s of function %s',
+            param_name,
+            fn_name,
+        )
+        return
+    for k, v in msv.items():
+        group_vars_out[k + "_count"] = f"COUNT({v})"
+        group_vars_out[k + "_mean"] = f"AVG({v})"
+        group_vars_out[k + "_stddev"] = f"STDDEV({v})"
+
+
+def make_query_from_annotation(
+    annotation_data: Any,
+    fn_name: str,
+    param_name: str,
+) -> str | None:
+    """
+    Build an SQL query string from a parameter's annotation.
+
+    The query's result will be passed to the annotated function as this
+    parameter.
+
+    The annotation must be a dict with the following keys:
+
+    ``comment``: A string describing the query in natural language.
+
+    ``query``: Either a string containing the SQL query required, or
+    a dict containing the following keys:
+
+    * ``table``: The table to query. Could be "tablename AS alias" if you like.
+    * ``vars`` (optional): Either a list of columns to extract from the table(s),
+      or a dict of keys (the names of the keys in the dict to be passed to the
+      annotated function) to values (the names of the columns to be extracted).
+      At least one of ``vars``, ``ms_vars``, ``count_vars`` must be present.
+    * ``where`` (optional): A SQL expression to filter the results.
+    * ``count_vars`` (optional): A dict of keys to be passed to the function
+      to values that are the names of the columns to be counted (could be
+      ``*``; if the name of a column, the result will be the number of non-null
+      entries in that column). The query will be grouped by ``vars``.
+    * ``ms_vars`` (optional): A dict of value names to columns to be analysed.
+      The keys to be passed to the function will be name + ``_count`` for the
+      number of non-null values in that column, name + ``_mean`` for the
+      average value in that column and name + ``_stddev`` for the standard
+      deviation of values in that column.
+
+    :param annotation_data: The ``__metadata__`` tuple of the parameter's
+        ``Annotated`` type.
+    :param fn_name: The name of the function that the parameter belongs to.
+    :param param_name: The name of the parameter with the annotation.
+    :return: The SQL query string described by the annotation, if the
+        annotation had a well-defined query; otherwise ``None``.
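+
+    For example (illustrative only; the ``string`` table and ``frequency``
+    column come from the test fixtures, not from any real schema)::
+
+        {
+            "comment": "Frequency mean and standard deviation per model",
+            "query": {
+                "table": "string",
+                "vars": ["model_id"],
+                "ms_vars": {"freq": "frequency"},
+            },
+        }
+
+    produces a query equivalent to::
+
+        SELECT model_id AS "model_id", COUNT(frequency) AS "freq_count",
+            AVG(frequency) AS "freq_mean", STDDEV(frequency) AS "freq_stddev"
+        FROM string GROUP BY "model_id"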
+    """
+    if not isinstance(annotation_data, Sequence):
+        return None
+    ann = annotation_data[0]
+    if not isinstance(ann, Mapping) or "query" not in ann:
+        return None
+    if isinstance(ann["query"], str):
+        return ann["query"]
+    query_def = ann["query"]
+    if "table" not in query_def:
+        logger.warning(
+            '"table" needs to be a key in the annotation for'
+            ' the "query" value of parameter "%s" of function "%s"',
+            param_name,
+            fn_name,
+        )
+        return None
+    table = query_def["table"]
+    nongroup_vars = _make_vars_from_annotation(query_def, fn_name, param_name)
+    where = _make_where_from_annotation(query_def, fn_name, param_name)
+    group_vars: dict[str, Any] = {}
+    _add_count_vars_from_annotation(group_vars, query_def, fn_name, param_name)
+    _add_ms_vars_from_annotation(group_vars, query_def, fn_name, param_name)
+    if not nongroup_vars and not group_vars:
+        logger.warning(
+            'at least one of "vars", "count_vars" and "ms_vars" must be present'
+            ' in the query annotation of parameter "%s" of function "%s"',
+            param_name,
+            fn_name,
+        )
+        return None
+    if group_vars and nongroup_vars:
+        group_by = " GROUP BY " + ", ".join(f'"{v}"' for v in nongroup_vars)
+    else:
+        group_by = ""
+    vars_exprs = ", ".join(
+        f'{v} AS "{k}"' for k, v in {**nongroup_vars, **group_vars}.items()
+    )
+    return f"SELECT {vars_exprs} FROM {table}{where}{group_by}"
+
+
+def _add_kwarg(
+    kwargs_out: dict[str, Any], fn_name: str, param: Parameter
+) -> list[dict[str, Any]]:
+    """
+    Add a kwargs configuration and return ``src-stats`` query items.
+
+    :param kwargs_out: The story generator's ``kwargs`` value to be updated.
+    :param fn_name: The name of the story generator function.
+    :param param: The parameter to specify.
+    :return: A list of configuration items to add to the ``src-stats`` config,
+        for all the queries this parameter requires.
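+
+    For example, an annotated parameter ``stats`` of a function ``my_story``
+    (names illustrative) yields a ``src-stats`` entry named
+    ``story_auto__my_story__stats`` and sets ``kwargs_out["stats"]`` to
+    ``SRC_STATS["story_auto__my_story__stats"]["results"]``.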
+ """ + if param.annotation is Parameter.empty: + return [] + meta = param.annotation.__metadata__ + query = make_query_from_annotation( + param.annotation.__metadata__, fn_name, param.name + ) + if query is None: + return [] + stat_name = f"story_auto__{fn_name}__{param.name}" + if "comments" in meta[0]: + comments = [meta[0]["comment"]] + else: + comments = [] + ssc = { + "name": stat_name, + "query": query, + "comments": comments, + } + kwargs_out[param.name] = f'SRC_STATS["{stat_name}"]["results"]' + return [ssc] + + +def install_stories_from(config: MutableMapping[str, Any], story_file: Path) -> bool: + """ + Configure datafaker with the stories in a Python file. + + :param config: The contents of the configuration file, to be mutated. + :param story_file: Path to the Python file containing the story generators. + :return: True if the config was updated correctly, False if it was untouched + because problems were encountered. + """ + story_generators: list[Mapping[str, Any]] = [] + src_stats = [ + s + for s in config.get("src_stats", []) + if isinstance(s, Mapping) + and "name" in s + and not s["name"].startswith("story_auto__") + ] + story_module_name = story_file.stem + story_module = import_file(story_file, story_module_name) + for attr_name in dir(story_module): + attr = getattr(story_module, attr_name) + if ( + hasattr(attr, "__module__") + and attr.__module__ == story_module_name + and not attr_name.startswith("_") + and callable(attr) + ): + kwargs: dict[str, None] = {} + sig = signature(attr) + for param in sig.parameters.values(): + src_stats += _add_kwarg(kwargs, attr, param) + story_generators.append( + { + "name": f"{story_module_name}.{attr_name}", + "num_stories_per_pass": 1, + "kwargs": kwargs, + } + ) + config["story_generators_module"] = story_module_name + config["story_generators"] = story_generators + config["src-stats"] = src_stats + return True diff --git a/datafaker/main.py b/datafaker/main.py index 5a89d5e..2859538 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -16,6 +16,7 @@ from datafaker.create import create_db_data, create_db_tables, create_db_vocab from datafaker.dump import dump_db_tables +from datafaker.install import install_stories_from from datafaker.interactive import ( update_config_generators, update_config_tables, @@ -71,7 +72,7 @@ def _require_src_db_dsn(settings: Settings) -> str: def load_metadata_config( - orm_file_name: str, config: dict | None = None + orm_file_name: Path, config: dict | None = None ) -> dict[str, Any]: """ Load the ``orm.yaml`` file, returning a dict representation. @@ -82,7 +83,7 @@ def load_metadata_config( :return: A dict representing the ``orm.yaml`` file, with the tables the ``config`` says to ignore removed. """ - with open(orm_file_name, encoding="utf-8") as orm_fh: + with orm_file_name.open(encoding="utf-8") as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) if not isinstance(meta_dict, dict): return {} @@ -95,7 +96,7 @@ def load_metadata_config( return meta_dict -def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: +def load_metadata(orm_file_name: Path, config: dict | None = None) -> MetaData: """ Load metadata from ``orm.yaml``. 
+    """
+    story_generators: list[Mapping[str, Any]] = []
+    # Keep any manually written stats; the story_auto__ ones are regenerated.
+    src_stats = [
+        s
+        for s in config.get("src-stats", [])
+        if isinstance(s, Mapping)
+        and "name" in s
+        and not s["name"].startswith("story_auto__")
+    ]
+    story_module_name = story_file.stem
+    story_module = import_file(story_file, story_module_name)
+    for attr_name in dir(story_module):
+        attr = getattr(story_module, attr_name)
+        if (
+            hasattr(attr, "__module__")
+            and attr.__module__ == story_module_name
+            and not attr_name.startswith("_")
+            and callable(attr)
+        ):
+            kwargs: dict[str, str] = {}
+            sig = signature(attr)
+            for param in sig.parameters.values():
+                src_stats += _add_kwarg(kwargs, attr_name, param)
+            story_generators.append(
+                {
+                    "name": f"{story_module_name}.{attr_name}",
+                    "num_stories_per_pass": 1,
+                    "kwargs": kwargs,
+                }
+            )
+    config["story_generators_module"] = story_module_name
+    config["story_generators"] = story_generators
+    config["src-stats"] = src_stats
+    return True
diff --git a/datafaker/main.py b/datafaker/main.py
index 5a89d5e..2859538 100644
--- a/datafaker/main.py
+++ b/datafaker/main.py
@@ -16,6 +16,7 @@
 from datafaker.create import create_db_data, create_db_tables, create_db_vocab
 from datafaker.dump import dump_db_tables
+from datafaker.install import install_stories_from
 from datafaker.interactive import (
     update_config_generators,
     update_config_tables,
@@ -71,7 +72,7 @@ def _require_src_db_dsn(settings: Settings) -> str:
 
 
 def load_metadata_config(
-    orm_file_name: str, config: dict | None = None
+    orm_file_name: Path, config: dict | None = None
 ) -> dict[str, Any]:
     """
     Load the ``orm.yaml`` file, returning a dict representation.
@@ -82,7 +83,7 @@
     :return: A dict representing the ``orm.yaml`` file, with the tables
         the ``config`` says to ignore removed.
     """
-    with open(orm_file_name, encoding="utf-8") as orm_fh:
+    with orm_file_name.open(encoding="utf-8") as orm_fh:
         meta_dict = yaml.load(orm_fh, yaml.Loader)
     if not isinstance(meta_dict, dict):
         return {}
@@ -95,7 +96,7 @@
     return meta_dict
 
 
-def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData:
+def load_metadata(orm_file_name: Path, config: dict | None = None) -> MetaData:
     """
     Load metadata from ``orm.yaml``.
@@ -107,7 +108,7 @@ def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData:
     return dict_to_metadata(meta_dict, None)
 
 
-def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> Any:
+def load_metadata_for_output(orm_file_name: Path, config: dict | None = None) -> Any:
     """Load metadata excluding any foreign keys pointing to ignored tables."""
     meta_dict = load_metadata_config(orm_file_name, config)
     return dict_to_metadata(meta_dict, config)
@@ -123,12 +124,12 @@ def main(
 @app.command()
 def create_data(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
     df_file: str = Option(
         DF_FILENAME,
         help="The name of the generators file. Must be in the current working directory.",
     ),
-    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
     num_passes: int = Option(1, help="Number of passes (rows or stories) to make"),
 ) -> None:
     """Populate the schema in the target directory with synthetic data.
@@ -150,7 +151,7 @@
     $ datafaker create-data
     """
     logger.debug("Creating data.")
-    config = read_config_file(config_file) if config_file is not None else {}
+    config = read_config_file(config_file)
     orm_metadata = load_metadata_for_output(orm_file, config)
     df_module = import_file(df_file)
     try:
@@ -180,8 +181,8 @@
 @app.command()
 def create_vocab(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    config_file: str = Option(CONFIG_FILENAME, help="The configuration file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
 ) -> None:
     """Import vocabulary data into the target database.
@@ -199,8 +200,8 @@
 @app.command()
 def create_tables(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
 ) -> None:
     """Create schema from the ORM YAML file.
@@ -211,7 +212,7 @@
     $ datafaker create-tables
     """
     logger.debug("Creating tables.")
-    config = read_config_file(config_file) if config_file is not None else {}
+    config = read_config_file(config_file)
     orm_metadata = load_metadata_for_output(orm_file, config)
     create_db_tables(orm_metadata)
     logger.debug("Tables created.")
@@ -219,10 +220,10 @@
 @app.command()
 def create_generators(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    df_file: str = Option(DF_FILENAME, help="Path to write Python generators to."),
-    config_file: str = Option(CONFIG_FILENAME, help="The configuration file"),
-    stats_file: Optional[str] = Option(
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    df_file: Path = Option(DF_FILENAME, help="Path to write Python generators to."),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
+    stats_file: Optional[Path] = Option(
         None,
         help=(
             "Statistics file (output of make-stats); default is src-stats.yaml if the "
@@ -248,9 +249,9 @@
     if not force:
         _check_file_non_existence(df_file_path)
 
-    generator_config = read_config_file(config_file) if config_file is not None else {}
+    generator_config = read_config_file(config_file)
     if stats_file is None and generators_require_stats(generator_config):
-        stats_file = STATS_FILENAME
+        stats_file = Path(STATS_FILENAME)
     orm_metadata = load_metadata_for_output(orm_file, generator_config)
     result: str = make_table_generators(
         orm_metadata,
@@ -267,8 +268,8 @@
 @app.command()
 def make_vocab(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
     force: bool = Option(
         False,
         "--force/--no-force",
@@ -288,7 +289,7 @@
     settings = get_settings()
     _require_src_db_dsn(settings)
 
-    generator_config = read_config_file(config_file) if config_file is not None else {}
+    generator_config = read_config_file(config_file)
     orm_metadata = load_metadata(orm_file, generator_config)
     make_vocabulary_tables(
         orm_metadata,
@@ -301,7 +302,7 @@
 @app.command()
 def make_stats(
-    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
     stats_file: str = Option(STATS_FILENAME),
     force: bool = Option(
         False, "--force", "-f", help="Overwrite any existing vocabulary file."
@@ -320,7 +321,7 @@
     if not force:
         _check_file_non_existence(stats_file_path)
 
-    config = read_config_file(config_file) if config_file is not None else {}
+    config = read_config_file(config_file)
     settings = get_settings()
     src_dsn: str = _require_src_db_dsn(settings)
@@ -334,14 +335,14 @@
 @app.command()
 def make_tables(
-    config_file: Optional[str] = Option(
+    config_file: Optional[Path] = Option(
         None,
         help=(
             "The configuration file, used if you want"
             " an orm.yaml lacking data for the ignored tables"
         ),
     ),
-    orm_file: str = Option(ORM_FILENAME, help="Path to write the ORM yaml file to"),
+    orm_file: Path = Option(ORM_FILENAME, help="Path to write the ORM yaml file to"),
     force: bool = Option(
         False, "--force", "-f", help="Overwrite any existing orm yaml file."
     ),
@@ -371,7 +372,7 @@
 def configure_tables(
     config_file: str = Option(
         CONFIG_FILENAME, help="Path to write the configuration file to"
     ),
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
 ) -> None:
     """Interactively set tables to ignored, vocabulary or primary private."""
     logger.debug("Configuring tables in %s.", config_file)
@@ -398,20 +399,19 @@
 @app.command()
 def configure_missing(
-    config_file: str = Option(
+    config_file: Path = Option(
         CONFIG_FILENAME, help="Path to write the configuration file to"
     ),
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
 ) -> None:
     """Interactively set the missingness of the generated data."""
     logger.debug("Configuring missingness in %s.", config_file)
     settings = get_settings()
     src_dsn: str = _require_src_db_dsn(settings)
-    config_file_path = Path(config_file)
     config: dict[str, Any] = {}
-    if config_file_path.exists():
+    if config_file.exists():
         config_any = yaml.load(
-            config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader
+            config_file.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader
         )
         if isinstance(config_any, dict):
             config = config_any
@@ -421,16 +421,16 @@
         logger.debug("Cancelled")
         return
     content = yaml.dump(config_updated)
-    config_file_path.write_text(content, encoding="utf-8")
+    config_file.write_text(content, encoding="utf-8")
     logger.debug("Generators missingness in %s.", config_file)
 
 
 @app.command()
 def configure_generators(
-    config_file: str = Option(
+    config_file: Path = Option(
         CONFIG_FILENAME, help="Path of the configuration file to alter"
     ),
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
     spec: Path = Option(
         None,
         help=(
@@ -443,11 +443,10 @@
     logger.debug("Configuring generators in %s.", config_file)
     settings = get_settings()
     src_dsn: str = _require_src_db_dsn(settings)
-    config_file_path = Path(config_file)
     config = {}
-    if config_file_path.exists():
+    if config_file.exists():
         config = yaml.load(
-            config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader
+            config_file.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader
         )
     metadata = load_metadata(orm_file, config)
     config_updated = update_config_generators(
@@ -457,16 +456,16 @@
         logger.debug("Cancelled")
         return
     content = yaml.dump(config_updated)
-    config_file_path.write_text(content, encoding="utf-8")
+    config_file.write_text(content, encoding="utf-8")
     logger.debug("Generators configured in %s.", config_file)
 
 
 @app.command()
 def dump_data(
-    config_file: Optional[str] = Option(
+    config_file: Path = Option(
         CONFIG_FILENAME, help="Path of the configuration file to alter"
     ),
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
     table: str = Argument(help="The table to dump"),
     output: str | None = Option(None, help="output CSV file name"),
 ) -> None:
@@ -475,7 +474,7 @@
     dst_dsn: str = settings.dst_dsn or ""
     assert dst_dsn != "", "Missing DST_DSN setting."
     schema_name = settings.dst_schema
-    config = read_config_file(config_file) if config_file is not None else {}
+    config = read_config_file(config_file)
     metadata = load_metadata_for_output(orm_file, config)
     if output is None:
         if isinstance(sys.stdout, io.TextIOBase):
@@ -504,8 +503,8 @@
 @app.command()
 def remove_data(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
     yes: bool = Option(
         False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"
     ),
@@ -513,7 +512,7 @@
     """Truncate non-vocabulary tables in the destination schema."""
     if yes:
         logger.debug("Truncating non-vocabulary tables.")
-        config = read_config_file(config_file) if config_file is not None else {}
+        config = read_config_file(config_file)
         metadata = load_metadata_for_output(orm_file, config)
         remove_db_data(metadata, config)
         logger.debug("Non-vocabulary tables truncated.")
@@ -523,8 +522,8 @@
 @app.command()
 def remove_vocab(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
     yes: bool = Option(
         False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"
     ),
@@ -532,7 +531,7 @@
     """Truncate vocabulary tables in the destination schema."""
     if yes:
         logger.debug("Truncating vocabulary tables.")
-        config = read_config_file(config_file) if config_file is not None else {}
+        config = read_config_file(config_file)
         meta_dict = load_metadata_config(orm_file, config)
         orm_metadata = dict_to_metadata(meta_dict, config)
         remove_db_vocab(orm_metadata, meta_dict, config)
@@ -543,8 +542,8 @@
 @app.command()
 def remove_tables(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    config_file: str = Option(CONFIG_FILENAME, help="The configuration file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
     # pylint: disable=redefined-builtin
     all: bool = Option(
         False,
@@ -581,12 +580,12 @@ class TableType(str, Enum):
 @app.command()
 def list_tables(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
-    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
+    orm_file: Path = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
     tables: TableType = Option(TableType.GENERATED, help="Which tables to list"),
 ) -> None:
     """List the names of tables described in the metadata file."""
-    config = read_config_file(config_file) if config_file is not None else {}
+    config = read_config_file(config_file)
     orm_metadata = load_metadata(orm_file, config)
     all_table_names = set(orm_metadata.tables.keys())
     vocab_table_names = {
@@ -604,6 +603,26 @@
     print(name)
 
 
+@app.command()
+def install_stories(
+    config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"),
+    story_file: Path = Argument(help="The Python file containing stories"),
stories"), +) -> None: + """Add the story file's name and any contained query to the configuration file.""" + config_file_path = Path(config_file) + config = {} + if config_file_path.exists(): + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) + if not install_stories_from(config, story_file): + logger.debug("Cancelled") + sys.exit(1) + content = yaml.dump(config) + config_file_path.write_text(content, encoding="utf-8") + logger.debug("Stories configured in %s.", config_file) + + @app.command() def version() -> None: """Display version information.""" diff --git a/datafaker/make.py b/datafaker/make.py index 6f4cc9b..a496717 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -580,9 +580,9 @@ def make_vocabulary_tables( def make_table_generators( # pylint: disable=too-many-locals metadata: MetaData, config: Mapping, - orm_filename: str, - config_filename: str, - src_stats_filename: Optional[str], + orm_filename: Path, + config_filename: Path, + src_stats_filename: Optional[Path], ) -> str: """ Create datafaker generator classes. diff --git a/datafaker/utils.py b/datafaker/utils.py index 7ef91bf..2f20583 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -63,7 +63,7 @@ def iterable(cls) -> Iterable[T]: return (x for x in e) -def read_config_file(path: str) -> dict: +def read_config_file(path: Path) -> dict: """Read a config file, warning if it is invalid. Args: @@ -72,7 +72,7 @@ def read_config_file(path: str) -> dict: Returns: The config file as a dictionary. """ - with open(path, "r", encoding="utf8") as f: + with path.open(encoding="utf8") as f: config = yaml.safe_load(f) assert isinstance(config, dict) @@ -86,7 +86,7 @@ def read_config_file(path: str) -> dict: return config -def import_file(file_path: str) -> ModuleType: +def import_file(file_path: str | Path, module_name: str = "df") -> ModuleType: """Import a file. This utility function returns file_path imported as a module. 
@@ -97,7 +97,7 @@
     Returns:
         ModuleType
     """
-    spec = importlib.util.spec_from_file_location("df", file_path)
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
     if spec is None or spec.loader is None:
         raise ImportError(f"No loadable module at {file_path}")
     module = importlib.util.module_from_spec(spec)
diff --git a/tests/examples/annotated_stories.py b/tests/examples/annotated_stories.py
new file mode 100644
index 0000000..fcfce1e
--- /dev/null
+++ b/tests/examples/annotated_stories.py
@@ -0,0 +1,28 @@
+"""Story generators which describe their own queries and can therefore be installed."""
+from collections.abc import Iterable
+from typing import Annotated, Any
+
+
+def string_story_one_sd(
+    stats: Annotated[dict, {
+        "query": {
+            "ms_vars": {"freq": "frequency"},
+            "table": "string",
+        },
+        "comment": "Frequency mean and standard deviation",
+    }],
+) -> Iterable[tuple[str, dict[str, Any]]]:
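+    # datafaker sends each inserted row back into the generator at the yield,
+    # so man["id"] and mod["id"] below are the primary keys just generated.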
+    man = yield ("manufacturer", {"name": "one"})
+    mod = yield ("model", {
+        "name": "one_sd",
+        "manufacturer_id": man["id"],
+    })
+    yield ("string", {
+        "model_id": mod["id"],
+        "position": 0,
+        "frequency": stats[0]["freq_mean"] - stats[0]["freq_stddev"],
+    })
+    yield ("string", {
+        "model_id": mod["id"],
+        "position": stats[0]["freq_count"],
+        "frequency": stats[0]["freq_mean"] + stats[0]["freq_stddev"],
+    })
diff --git a/tests/examples/install_config.yaml b/tests/examples/install_config.yaml
new file mode 100644
index 0000000..c155d2b
--- /dev/null
+++ b/tests/examples/install_config.yaml
@@ -0,0 +1,3 @@
+tables:
+  string:
+    num_rows_per_pass: 0
diff --git a/tests/test_functional.py b/tests/test_functional.py
index bfb2f09..d89690c 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -52,17 +52,6 @@ def setUp(self) -> None:
             env=self.env,
         )
 
-        self.env = {
-            "src_dsn": self.dsn,
-            "src_schema": self.schema_name,
-            "dst_dsn": self.dsn,
-            "dst_schema": "dstschema",
-        }
-        self.runner = CliRunner(
-            mix_stderr=False,
-            env=self.env,
-        )
-
         # Copy some of the example files over to the workspace.
         self.test_dir = Path(tempfile.mkdtemp(prefix="df-"))
         for file in self.generator_file_paths + (self.config_file_path,):
diff --git a/tests/test_install_stories.py b/tests/test_install_stories.py
new file mode 100644
index 0000000..ce7f495
--- /dev/null
+++ b/tests/test_install_stories.py
@@ -0,0 +1,270 @@
+"""Tests for installing stories into ``config.yaml``."""
+import os
+import re
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Any, Mapping
+
+import yaml
+from sqlalchemy import Row, func, select, text
+from typer.testing import CliRunner, Result
+
+from datafaker.main import app, install_stories
+from tests.utils import GeneratesDBTestCase, create_db_engine, get_sync_engine
+
+# pylint: disable=subprocess-run-check
+
+
+class InstallTestCase(GeneratesDBTestCase):
+    """End-to-end tests that require a database."""
+
+    dump_file_path = "instrument.sql"
+    database_name = "instrument"
+    schema_name = "public"
+
+    examples_dir = Path("tests/examples")
+
+    orm_file_path = Path("orm.yaml")
+
+    input_file_paths = [Path("annotated_stories.py"), Path("install_config.yaml")]
+    stats_file_path = Path("example_stats.yaml")
+
+    src_stats_re = re.compile(r'SRC_STATS\["(.*)"\]\["results"\]')
+
+    start_dir = os.getcwd()
+
+    def setUp(self) -> None:
+        """Pre-test setup."""
+        super().setUp()
+        self.env = {
+            "src_dsn": self.dsn,
+            "src_schema": self.schema_name,
+            "dst_dsn": self.dsn,
+            "dst_schema": "dstschema",
+        }
+        self.runner = CliRunner(
+            mix_stderr=False,
+            env=self.env,
+        )
+
+        # Copy some of the example files over to the workspace.
+        self.test_dir = Path(tempfile.mkdtemp(prefix="df-"))
+        for file in self.input_file_paths:
+            src = self.examples_dir / file
+            dst = self.test_dir / file
+            shutil.copy(src, dst)
+
+        os.chdir(self.test_dir)
+
+    def tearDown(self) -> None:
+        """Tear down post test."""
+        os.chdir(self.start_dir)
+        super().tearDown()
+
+    def assert_silent_success(self, completed_process: Result) -> None:
+        """Assert that the process completed successfully without producing output."""
+        self.assertNoException(completed_process)
+        self.assertSuccess(completed_process)
+        self.assertEqual(completed_process.stderr, "")
+        self.assertEqual(completed_process.stdout, "")
+
+    def test_install_stories_simple(self) -> None:
+        """Test story gets expected parameters after installation."""
+        config_path = Path("config-iss.yaml")
+        config_path.write_text("{}", encoding="UTF-8")
+
+        install_stories(config_path, Path("annotated_stories.py"))
+
+        config = yaml.load(
+            config_path.read_text(encoding="UTF-8"),
+            Loader=yaml.SafeLoader,
+        )
+
+        # Module name configured
+        self.assertIn("story_generators_module", config)
+        self.assertEqual(config["story_generators_module"], "annotated_stories")
+
+        # Generator added with parameter
+        self.assertIn("story_generators", config)
+        st_gen = config["story_generators"]
+        self.assertEqual(len(st_gen), 1)
+        self.assertIn("name", st_gen[0])
+        self.assertEqual(st_gen[0]["name"], "annotated_stories.string_story_one_sd")
+        self.assertIn("kwargs", st_gen[0])
+        self.assertIn("stats", st_gen[0]["kwargs"])
+        stats_ref = st_gen[0]["kwargs"]["stats"]
+        stats_result = self.src_stats_re.match(stats_ref)
+        self.assertIsNotNone(
+            stats_result, f'parameter "{stats_ref}" is not a SRC_STATS reference'
+        )
+        assert stats_result is not None
+
+        # Source stats query
+        self.assertIn("src-stats", config)
+        src_stats = config["src-stats"]
+        assert src_stats is not None
+        self.assertEqual(len(src_stats), 1)
+        assert src_stats[0] is not None
+        self.assertIn("name", src_stats[0])
+        self.assertEqual(src_stats[0]["name"], stats_result.group(1))
+        self.assertIn("query", src_stats[0])
+        query = src_stats[0]["query"]
+        (mean, stddev, _count) = self.get_string_stats()
+
+        # Let's run the query and see what we get.
+        engine = get_sync_engine(
+            create_db_engine(
+                self.env["src_dsn"],
+                schema_name=self.env["src_schema"],
+            )
+        )
+        with engine.connect() as conn:
+            rows = conn.execute(text(query)).fetchall()
+        self.assertEqual(len(rows), 1)
+        self.assertEqual(rows[0].freq_mean, mean)
+        self.assertEqual(rows[0].freq_stddev, stddev)
+
+    def test_install_stories_end_to_end(self) -> None:
+        """Test the stories run with the expected parameters after installation."""
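+        # Exercise the whole pipeline: build the schema metadata, install the
+        # stories, gather source statistics, generate the generators, then
+        # rebuild the destination tables and create the story data.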
+        completed_process = self.invoke(
+            "make-tables",
+            "--force",
+        )
+        self.assert_silent_success(completed_process)
+
+        completed_process = self.invoke(
+            "install-stories",
+            "--config-file",
+            "install_config.yaml",
+            "annotated_stories.py",
+        )
+        self.assert_silent_success(completed_process)
+
+        completed_process = self.invoke(
+            "make-stats",
+            "--config-file",
+            "install_config.yaml",
+            "--force",
+        )
+        self.assert_silent_success(completed_process)
+
+        completed_process = self.invoke(
+            "create-generators",
+            "--config-file",
+            "install_config.yaml",
+            "--force",
+            "--stats-file=src-stats.yaml",
+        )
+        self.assert_silent_success(completed_process)
+
+        completed_process = self.invoke(
+            "remove-tables",
+            "--config-file",
+            "install_config.yaml",
+            "--yes",
+        )
+        self.assert_silent_success(completed_process)
+
+        completed_process = self.invoke(
+            "create-tables",
+            "--config-file",
+            "install_config.yaml",
+        )
+        self.assert_silent_success(completed_process)
+
+        completed_process = self.invoke(
+            "create-data",
+            "--config-file",
+            "install_config.yaml",
+        )
+        self.assertNoException(completed_process)
+        self.assertEqual("", completed_process.stderr)
+        self.assertSuccess(completed_process)
+        self.assertEqual(
+            "Generating data for story 'annotated_stories.string_story_one_sd'\n",
+            completed_process.stdout,
+        )
+
+        (mean, stddev, count) = self.get_string_stats()
+
+        model_table = self.metadata.tables["model"]
+        string_table = self.metadata.tables["string"]
+        engine = get_sync_engine(
+            create_db_engine(
+                self.env["dst_dsn"],
+                schema_name=self.env["dst_schema"],
+            )
+        )
+        with engine.connect() as conn:
+            row = conn.execute(
+                select(model_table.c.name, model_table.c.id).where(
+                    model_table.c.name == "one_sd"
+                )
+            ).fetchone()
+            assert row is not None
+            strs = conn.execute(
+                select(string_table).where(string_table.c.model_id == row.id)
+            ).fetchall()
+        lower = None
+        higher = None
+        for s in strs:
+            if s.position == 0:
+                self.assertIsNone(
+                    lower, "Multiple one_sd strings with zero position"
+                )
+                lower = s.frequency
+            else:
+                self.assertIsNone(
+                    higher, "Multiple one_sd strings with non-zero position"
+                )
+                self.assertEqual(s.position, count)
+                higher = s.frequency
+        assert lower is not None
+        assert higher is not None
+        self.assertAlmostEqual((higher + lower) / 2, mean)
+        self.assertAlmostEqual((higher - lower) / 2, stddev)
+
+    def get_string_stats(self) -> tuple[float | None, float | None, int | None]:
+        """Get the mean, standard deviation and count of frequencies in the string table."""
+        string_table = self.metadata.tables["string"]
+        engine = get_sync_engine(
+            create_db_engine(
+                self.env["src_dsn"],
+                schema_name=self.env["src_schema"],
+            )
+        )
+        with engine.connect() as conn:
+            results = conn.execute(
+                select(
+                    func.count(),  # pylint: disable=not-callable
+                    func.avg(string_table.c.frequency),
+                    func.stddev(string_table.c.frequency),
+                )
+            ).fetchone()
+        if not isinstance(results, Row):
+            return None, None, None
+        return results.avg_1, results.stddev_1, results.count_1
+
+    def invoke(
+        self,
+        *args: Any,
+        expected_error: str | None = None,
+        env: Mapping[str, str] | None = None,
+    ) -> Result:
+        """
+        Run datafaker with the given arguments and environment.
+
+        :param args: Arguments to provide to datafaker.
+        :param expected_error: If None, will assert that the invocation
+            passes successfully without throwing an exception. Otherwise,
+            the suggested error must be present in the standard error stream.
+        :param env: The environment variables to be set during invocation.
+        """
+        res = self.runner.invoke(app, args, env=env)
+        if expected_error is None:
+            self.assertNoException(res)
+            self.assertSuccess(res)
+        else:
+            self.assertIn(expected_error, res.stderr)
+        return res
diff --git a/tests/test_main.py b/tests/test_main.py
index 2167207..290d2af 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -75,8 +75,8 @@ def test_create_generators(
         mock_make.assert_called_once_with(
             mock_load_meta.return_value,
             mock_config.return_value,
-            "orm.yaml",
-            "config.yaml",
+            Path("orm.yaml"),
+            Path("config.yaml"),
             None,
         )
         mock_path.return_value.write_text.assert_called_once_with(
@@ -117,9 +117,9 @@ def test_create_generators_uses_default_stats_file_if_necessary(
         mock_make.assert_called_once_with(
             mock_load_meta.return_value,
             mock_config.return_value,
-            "orm.yaml",
-            "config.yaml",
-            "src-stats.yaml",
+            Path("orm.yaml"),
+            Path("config.yaml"),
+            mock_path("src-stats.yaml"),
         )
         mock_path.return_value.write_text.assert_called_once_with(
             "some text", encoding="utf-8"
@@ -170,6 +170,7 @@ def test_create_generators_with_force_enabled(
 
         for force_option in ["--force", "-f"]:
             with self.subTest(f"Using option {force_option}"):
+                mock_make.reset_mock()
                 result: Result = runner.invoke(
                     app,
                     [
@@ -181,8 +182,8 @@
                 mock_make.assert_called_once_with(
                     mock_load_meta.return_value,
                     mock_config.return_value,
-                    "orm.yaml",
-                    "config.yaml",
+                    Path("orm.yaml"),
+                    Path("config.yaml"),
                     None,
                 )
                 mock_path.return_value.write_text.assert_called_once_with(
@@ -554,7 +555,7 @@ def test_remove_vocab(
             catch_exceptions=False,
         )
         self.assertEqual(0, result.exit_code)
-        mock_read_config.assert_called_once_with("config.yaml")
+        mock_read_config.assert_called_once_with(Path("config.yaml"))
         mock_remove.assert_called_once_with(
             mock_d2m.return_value,
             mock_load_metadata.return_value,
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ac82d12..9f0ab9b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -111,7 +111,7 @@ class TestReadConfig(DatafakerTestCase):
     def test_warns_of_invalid_config(self) -> None:
         """Test that we get a warning if the config is invalid."""
         with patch("datafaker.utils.logger") as mock_logger:
-            read_config_file("tests/examples/invalid_config.yaml")
+            read_config_file(Path("tests/examples/invalid_config.yaml"))
         mock_logger.error.assert_called_with(
             "The config file is invalid: %s", "'a' is not of type 'integer'"
         )
diff --git a/tests/utils.py b/tests/utils.py
index 4a7e85b..e35f89c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -208,21 +208,23 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.generators_file_path = ""
         self.stats_fd = 0
-        self.stats_file_path = ""
-        self.config_file_path = ""
+        self.stats_file_path = Path("")
+        self.config_file_path = Path("")
Path("") self.config_fd = 0 def setUp(self) -> None: """Set up the test case with an actual orm.yaml file.""" super().setUp() # Generate the `orm.yaml` from the database - (self.orm_fd, self.orm_file_path) = mkstemp(".yaml", "orm_", text=True) + (self.orm_fd, orm_file_path) = mkstemp(".yaml", "orm_", text=True) + self.orm_file_path = Path(orm_file_path) with os.fdopen(self.orm_fd, "w", encoding="utf-8") as orm_fh: orm_fh.write(make_tables_file(self.dsn, self.schema_name, {})) def set_configuration(self, config: Mapping[str, Any]) -> None: """Accepts a configuration file, writes it out.""" - (self.config_fd, self.config_file_path) = mkstemp(".yaml", "config_", text=True) + (self.config_fd, config_file_path) = mkstemp(".yaml", "config_", text=True) + self.config_file_path = Path(config_file_path) with os.fdopen(self.config_fd, "w", encoding="utf-8") as config_fh: config_fh.write(yaml.dump(config)) @@ -237,9 +239,8 @@ def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]: make_src_stats(self.dsn, config, self.schema_name) ) loop.close() - (self.stats_fd, self.stats_file_path) = mkstemp( - ".yaml", "src_stats_", text=True - ) + (self.stats_fd, stats_file_path) = mkstemp(".yaml", "src_stats_", text=True) + self.stats_file_path = Path(stats_file_path) with os.fdopen(self.stats_fd, "w", encoding="utf-8") as stats_fh: stats_fh.write(yaml.dump(src_stats)) return src_stats