diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py
index f6426817..75730c1c 100644
--- a/cmdstanpy/cmdstan_args.py
+++ b/cmdstanpy/cmdstan_args.py
@@ -55,6 +55,7 @@ def __init__(
         adapt_metric_window: Optional[int] = None,
         adapt_step_size: Optional[int] = None,
         fixed_param: bool = False,
+        num_chains: int = 1,
     ) -> None:
         """Initialize object."""
         self.iter_warmup = iter_warmup
@@ -73,6 +74,7 @@ def __init__(
         self.adapt_step_size = adapt_step_size
         self.fixed_param = fixed_param
         self.diagnostic_file = None
+        self.num_chains = num_chains
 
     def validate(self, chains: Optional[int]) -> None:
         """
@@ -316,6 +318,10 @@ def validate(self, chains: Optional[int]) -> None:
                     'Argument "adapt_step_size" must be a non-negative integer,'
                     'found {}'.format(self.adapt_step_size)
                 )
+        if self.num_chains < 1 or not isinstance(
+            self.num_chains, (int, np.integer)
+        ):
+            raise ValueError("num_chains must be positive")
 
         if self.fixed_param and (
             self.max_treedepth is not None
@@ -378,6 +384,8 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
             cmd.append('window={}'.format(self.adapt_metric_window))
         if self.adapt_step_size is not None:
             cmd.append('term_buffer={}'.format(self.adapt_step_size))
+        if self.num_chains > 1:
+            cmd.append('num_chains={}'.format(self.num_chains))
         return cmd
 
@@ -921,8 +929,12 @@ def validate(self) -> None:
                         )
                     )
             elif isinstance(self.inits, str):
-                if not os.path.exists(self.inits):
-                    raise ValueError('no such file {}'.format(self.inits))
+                if not (
+                    isinstance(self.method_args, SamplerArgs)
+                    and self.method_args.num_chains > 1
+                ):
+                    if not os.path.exists(self.inits):
+                        raise ValueError('no such file {}'.format(self.inits))
             elif isinstance(self.inits, list):
                 if self.chain_ids is None:
                     raise ValueError(
@@ -948,7 +960,6 @@ def compose_command(
         *,
         diagnostic_file: Optional[str] = None,
         profile_file: Optional[str] = None,
-        num_chains: Optional[int] = None,
     ) -> List[str]:
         """
         Compose CmdStan command for non-default arguments.
@@ -992,6 +1003,4 @@ def compose_command(
         if self.sig_figs is not None:
             cmd.append('sig_figs={}'.format(self.sig_figs))
         cmd = self.method_args.compose(idx, cmd)
-        if num_chains:
-            cmd.append('num_chains={}'.format(num_chains))
         return cmd
diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py
index 7fd9d33f..045b3428 100644
--- a/cmdstanpy/model.py
+++ b/cmdstanpy/model.py
@@ -58,7 +58,6 @@
 )
 from cmdstanpy.utils import (
     EXTENSION,
-    MaybeDictToFilePath,
     SanitizedOrTmpFilePath,
     cmdstan_path,
     cmdstan_version,
@@ -67,6 +66,7 @@
     get_logger,
     returncode_msg,
 )
+from cmdstanpy.utils.filesystem import temp_inits, temp_single_json
 
 from . import progress as progbar
 
@@ -573,7 +573,7 @@ def optimize(
         self,
         data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
         seed: Optional[int] = None,
-        inits: Union[Dict[str, float], float, str, os.PathLike, None] = None,
+        inits: Union[Mapping[str, Any], float, str, os.PathLike, None] = None,
         output_dir: OptionalPath = None,
         sig_figs: Optional[int] = None,
         save_profile: bool = False,
@@ -722,7 +722,9 @@ def optimize(
                     "in CmdStan 2.32 and above."
                 )
-        with MaybeDictToFilePath(data, inits) as (_data, _inits):
+        with temp_single_json(data) as _data, temp_inits(
+            inits, allow_multiple=False
+        ) as _inits:
             args = CmdStanArgs(
                 self._name,
                 self._exe_file,
@@ -766,7 +768,14 @@ def sample(
         threads_per_chain: Optional[int] = None,
         seed: Union[int, List[int], None] = None,
         chain_ids: Union[int, List[int], None] = None,
-        inits: Union[Dict[str, float], float, str, List[str], None] = None,
+        inits: Union[
+            Mapping[str, Any],
+            float,
+            str,
+            List[str],
+            List[Mapping[str, Any]],
+            None,
+        ] = None,
         iter_warmup: Optional[int] = None,
         iter_sampling: Optional[int] = None,
         save_warmup: bool = False,
@@ -1006,6 +1015,69 @@ def sample(
                     chains
                 )
             )
+
+        if parallel_chains is None:
+            parallel_chains = max(min(cpu_count(), chains), 1)
+        elif parallel_chains > chains:
+            get_logger().info(
+                'Requested %u parallel_chains but only %u required, '
+                'will run all chains in parallel.',
+                parallel_chains,
+                chains,
+            )
+            parallel_chains = chains
+        elif parallel_chains < 1:
+            raise ValueError(
+                'Argument parallel_chains must be a positive integer, '
+                'found {}.'.format(parallel_chains)
+            )
+        if threads_per_chain is None:
+            threads_per_chain = 1
+        if threads_per_chain < 1:
+            raise ValueError(
+                'Argument threads_per_chain must be a positive integer, '
+                'found {}.'.format(threads_per_chain)
+            )
+
+        parallel_procs = parallel_chains
+        num_threads = threads_per_chain
+        one_process_per_chain = True
+        info_dict = self.exe_info()
+        stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
+        # run multi-chain sampler unless algo is fixed_param or 1 chain
+        if fixed_param or (chains == 1):
+            force_one_process_per_chain = True
+
+        if (
+            force_one_process_per_chain is None
+            and not cmdstan_version_before(2, 28, info_dict)
+            and stan_threads == 'true'
+        ):
+            one_process_per_chain = False
+            num_threads = parallel_chains * num_threads
+            parallel_procs = 1
+        if force_one_process_per_chain is False:
+            if not cmdstan_version_before(2, 28, info_dict):
+                one_process_per_chain = False
+                num_threads = parallel_chains * num_threads
+                parallel_procs = 1
+                if stan_threads == 'false':
+                    get_logger().warning(
+                        'Stan program not compiled for threading, '
+                        'process will run chains sequentially. '
+                        'For multi-chain parallelization, recompile '
+                        'the model with argument '
+                        '"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
+                    )
+            else:
+                get_logger().warning(
+                    'Installed version of CmdStan cannot multi-process '
+                    'chains, will run %d processes. '
+                    'Run "install_cmdstan" to upgrade to latest version.',
+                    chains,
+                )
+        os.environ['STAN_NUM_THREADS'] = str(num_threads)
+
         if chain_ids is None:
             chain_ids = [i + 1 for i in range(chains)]
         else:
@@ -1017,6 +1089,13 @@ def sample(
                     )
                 chain_ids = [i + chain_ids for i in range(chains)]
             else:
+                if not one_process_per_chain:
+                    for i, j in zip(chain_ids, chain_ids[1:]):
+                        if i != j - 1:
+                            raise ValueError(
+                                'chain_ids must be sequential list of integers,'
+                                ' found {}.'.format(chain_ids)
+                            )
                 if not len(chain_ids) == chains:
                     raise ValueError(
                         'Chain_ids must correspond to number of chains'
@@ -1032,6 +1111,7 @@ def sample(
                        )
 
         sampler_args = SamplerArgs(
+            num_chains=1 if one_process_per_chain else chains,
             iter_warmup=iter_warmup,
             iter_sampling=iter_sampling,
             save_warmup=save_warmup,
@@ -1046,14 +1126,25 @@ def sample(
             adapt_step_size=adapt_step_size,
             fixed_param=fixed_param,
         )
-        with MaybeDictToFilePath(data, inits) as (_data, _inits):
+
+        with temp_single_json(data) as _data, temp_inits(
+            inits, id=chain_ids[0]
+        ) as _inits:
+            cmdstan_inits: Union[str, List[str], int, float, None]
+            if one_process_per_chain and isinstance(inits, list):  # legacy
+                cmdstan_inits = [
+                    f"{_inits[:-5]}_{i}.json" for i in chain_ids  # type: ignore
+                ]
+            else:
+                cmdstan_inits = _inits
+
             args = CmdStanArgs(
                 self._name,
                 self._exe_file,
                 chain_ids=chain_ids,
                 data=_data,
                 seed=seed,
-                inits=_inits,
+                inits=cmdstan_inits,
                 output_dir=output_dir,
                 sig_figs=sig_figs,
                 save_latent_dynamics=save_latent_dynamics,
@@ -1062,68 +1153,6 @@ def sample(
                 refresh=refresh,
             )
 
-        if parallel_chains is None:
-            parallel_chains = max(min(cpu_count(), chains), 1)
-        elif parallel_chains > chains:
-            get_logger().info(
-                'Requested %u parallel_chains but only %u required, '
-                'will run all chains in parallel.',
-                parallel_chains,
-                chains,
-            )
-            parallel_chains = chains
-        elif parallel_chains < 1:
-            raise ValueError(
-                'Argument parallel_chains must be a positive integer, '
-                'found {}.'.format(parallel_chains)
-            )
-        if threads_per_chain is None:
-            threads_per_chain = 1
-        if threads_per_chain < 1:
-            raise ValueError(
-                'Argument threads_per_chain must be a positive integer, '
-                'found {}.'.format(threads_per_chain)
-            )
-
-        parallel_procs = parallel_chains
-        num_threads = threads_per_chain
-        one_process_per_chain = True
-        info_dict = self.exe_info()
-        stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
-        # run multi-chain sampler unless algo is fixed_param or 1 chain
-        if fixed_param or (chains == 1):
-            force_one_process_per_chain = True
-
-        if (
-            force_one_process_per_chain is None
-            and not cmdstan_version_before(2, 28, info_dict)
-            and stan_threads == 'true'
-        ):
-            one_process_per_chain = False
-            num_threads = parallel_chains * num_threads
-            parallel_procs = 1
-        if force_one_process_per_chain is False:
-            if not cmdstan_version_before(2, 28, info_dict):
-                one_process_per_chain = False
-                num_threads = parallel_chains * num_threads
-                parallel_procs = 1
-                if stan_threads == 'false':
-                    get_logger().warning(
-                        'Stan program not compiled for threading, '
-                        'process will run chains sequentially. '
-                        'For multi-chain parallelization, recompile '
-                        'the model with argument '
-                        '"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
-                    )
-            else:
-                get_logger().warning(
-                    'Installed version of CmdStan cannot multi-process '
-                    'chains, will run %d processes. '
-                    'Run "install_cmdstan" to upgrade to latest version.',
-                    chains,
-                )
-        os.environ['STAN_NUM_THREADS'] = str(num_threads)
-
         if show_console:
             show_progress = False
         else:
@@ -1376,7 +1405,7 @@ def generate_quantities(
             csv_files=fit_csv_files
         )
         generate_quantities_args.validate(chains)
-        with MaybeDictToFilePath(data, None) as (_data, _inits):
+        with temp_single_json(data) as _data:
             args = CmdStanArgs(
                 self._name,
                 self._exe_file,
@@ -1551,7 +1580,9 @@ def variational(
             output_samples=output_samples,
         )
 
-        with MaybeDictToFilePath(data, inits) as (_data, _inits):
+        with temp_single_json(data) as _data, temp_inits(
+            inits, allow_multiple=False
+        ) as _inits:
             args = CmdStanArgs(
                 self._name,
                 self._exe_file,
@@ -1658,7 +1689,9 @@ def log_prob(
                 "Method 'log_prob' not available for CmdStan versions "
                 "before 2.31"
             )
-        with MaybeDictToFilePath(data, params) as (_data, _params):
+        with temp_single_json(data) as _data, temp_single_json(
+            params
+        ) as _params:
             cmd = [
                 str(self.exe_file),
                 "log_prob",
@@ -1766,7 +1799,7 @@ def laplace_sample(
                 cmdstan_mode.runset.csv_files[0], draws, jacobian
             )
 
-        with MaybeDictToFilePath(data) as (_data,):
+        with temp_single_json(data) as _data:
             args = CmdStanArgs(
                 self._name,
                 self._exe_file,
diff --git a/cmdstanpy/stanfit/runset.py b/cmdstanpy/stanfit/runset.py
index d00f634d..a4da75a9 100644
--- a/cmdstanpy/stanfit/runset.py
+++ b/cmdstanpy/stanfit/runset.py
@@ -179,7 +179,6 @@ def cmd(self, idx: int) -> List[str]:
             profile_file=self.file_path(".csv", extra="-profile")
             if self._args.save_profile
             else None,
-            num_chains=self._chains,
         )
 
     @property
diff --git a/cmdstanpy/utils/__init__.py b/cmdstanpy/utils/__init__.py
index 5b245bf8..cc1d3e45 100644
--- a/cmdstanpy/utils/__init__.py
+++ b/cmdstanpy/utils/__init__.py
@@ -22,7 +22,6 @@
 from .command import do_command, returncode_msg
 from .data_munging import build_xarray_data, flatten_chains
 from .filesystem import (
-    MaybeDictToFilePath,
     SanitizedOrTmpFilePath,
     create_named_text_file,
     pushd,
@@ -116,7 +115,6 @@ def show_versions(output: bool = True) -> str:
 __all__ = [
     'BaseType',
     'EXTENSION',
-    'MaybeDictToFilePath',
     'SanitizedOrTmpFilePath',
     'build_xarray_data',
     'check_sampler_csv',
diff --git a/cmdstanpy/utils/filesystem.py b/cmdstanpy/utils/filesystem.py
index 045a82dc..233898e1 100644
--- a/cmdstanpy/utils/filesystem.py
+++ b/cmdstanpy/utils/filesystem.py
@@ -7,7 +7,7 @@
 import re
 import shutil
 import tempfile
-from typing import Any, Iterator, List, Mapping, Tuple, Union
+from typing import Any, Iterator, List, Mapping, Optional, Tuple, Union
 
 from cmdstanpy import _TMPDIR
 
@@ -103,76 +103,92 @@ def pushd(new_dir: str) -> Iterator[None]:
         os.chdir(previous_dir)
 
 
-class MaybeDictToFilePath:
+def _temp_single_json(
+    data: Union[str, os.PathLike, Mapping[str, Any], None]
+) -> Iterator[Optional[str]]:
     """Context manager for json files."""
+    if data is None:
+        yield None
+        return
+    if isinstance(data, (str, os.PathLike)):
+        yield str(data)
+        return
+
+    data_file = create_named_text_file(dir=_TMPDIR, prefix='', suffix='.json')
+    get_logger().debug('input tempfile: %s', data_file)
+    write_stan_json(data_file, data)
+    try:
+        yield data_file
+    finally:
+        with contextlib.suppress(PermissionError):
+            os.remove(data_file)
 
-    def __init__(
-        self,
-        *objs: Union[
-            str, Mapping[str, Any], List[Any], int, float, os.PathLike, None
-        ],
-    ):
-        self._unlink = [False] * len(objs)
-        self._paths: List[Any] = [''] * len(objs)
-        i = 0
-        # pylint: disable=isinstance-second-argument-not-valid-type
-        for obj in objs:
-            if isinstance(obj, Mapping):
-                data_file = create_named_text_file(
-                    dir=_TMPDIR, prefix='', suffix='.json'
-                )
-                get_logger().debug('input tempfile: %s', data_file)
-                write_stan_json(data_file, obj)
-                self._paths[i] = data_file
-                self._unlink[i] = True
-            elif isinstance(obj, (str, os.PathLike)):
-                if not os.path.exists(obj):
-                    raise ValueError("File doesn't exist {}".format(obj))
-                self._paths[i] = obj
-            elif isinstance(obj, list):
-                err_msgs = []
-                missing_obj_items = []
-                for j, obj_item in enumerate(obj):
-                    if not isinstance(obj_item, str):
-                        err_msgs.append(
-                            (
-                                'List element {} must be a filename string,'
-                                ' found {}'
-                            ).format(j, obj_item)
-                        )
-                    elif not os.path.exists(obj_item):
-                        missing_obj_items.append(
-                            "File doesn't exist: {}".format(obj_item)
-                        )
-                if err_msgs:
-                    raise ValueError('\n'.join(err_msgs))
-                if missing_obj_items:
-                    raise ValueError('\n'.join(missing_obj_items))
-                self._paths[i] = obj
-            elif obj is None:
-                self._paths[i] = None
-            elif i == 1 and isinstance(obj, (int, float)):
-                self._paths[i] = obj
+
+temp_single_json = contextlib.contextmanager(_temp_single_json)
+
+
+def _temp_multiinput(
+    input: Union[str, os.PathLike, Mapping[str, Any], List[Any], None],
+    base: int = 1,
+) -> Iterator[Optional[str]]:
+    if isinstance(input, list):
+        # most complicated case: list of inits
+        # for multiple chains, we need to create multiple files
+        # which look like somename_{i}.json and then pass somename.json
+        # to CmdStan
+
+        mother_file = create_named_text_file(
+            dir=_TMPDIR, prefix='', suffix='.json', name_only=True
+        )
+        new_files = [
+            os.path.splitext(mother_file)[0] + f'_{i+base}.json'
+            for i in range(len(input))
+        ]
+        for init, file in zip(input, new_files):
+            if isinstance(init, dict):
+                write_stan_json(file, init)
+            elif isinstance(init, str):
+                shutil.copy(init, file)
             else:
-                raise ValueError('data must be string or dict')
-            i += 1
+                raise ValueError(
+                    'A list of inits must contain dicts or strings, not'
+                    + str(type(init))
+                )
+        try:
+            yield mother_file
+        finally:
+            for file in new_files:
+                with contextlib.suppress(PermissionError):
+                    os.remove(file)
+    else:
+        yield from _temp_single_json(input)
 
-    def __enter__(self) -> List[str]:
-        return self._paths
 
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore
-        for can_unlink, path in zip(self._unlink, self._paths):
-            if can_unlink and path:
-                try:
-                    os.remove(path)
-                except PermissionError:
-                    pass
+@contextlib.contextmanager
+def temp_inits(
+    inits: Union[
+        str, os.PathLike, Mapping[str, Any], float, int, List[Any], None
+    ],
+    *,
+    allow_multiple: bool = True,
+    id: int = 1,
+) -> Iterator[Union[str, float, int, None]]:
+    if isinstance(inits, (float, int)):
+        yield inits
+        return
+    if allow_multiple:
+        yield from _temp_multiinput(inits, base=id)
+    else:
+        if isinstance(inits, list):
+            raise ValueError('Expected single initialization, got list')
+        yield from _temp_single_json(inits)
 
 
 class SanitizedOrTmpFilePath:
     """
     Context manager for tmpfiles, handles special characters in filepath.
""" + UNIXISH_PATTERN = re.compile(r"[\s~]") WINDOWS_PATTERN = re.compile(r"\s") diff --git a/test/test_sample.py b/test/test_sample.py index f600252d..d34d90b9 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -277,10 +277,24 @@ def test_init_types() -> None: iter_sampling=100, inits=[inits_path1, inits_path2], show_progress=False, + force_one_process_per_chain=False, ) - assert 'init={}'.format(inits_path1.replace('\\', '\\\\')) in repr( - bern_fit.runset + + # will be copied, given basename + assert isinstance(bern_fit.runset._args.inits, str) + + bern_fit = bern_model.sample( + data=jdata, + chains=2, + seed=12345, + iter_warmup=100, + iter_sampling=100, + inits=[inits_path1, inits_path2], + show_progress=False, + force_one_process_per_chain=True, ) + # one per process + assert isinstance(bern_fit.runset._args.inits, list) with pytest.raises(ValueError): bern_model.sample( @@ -290,19 +304,36 @@ def test_init_types() -> None: seed=12345, iter_warmup=100, iter_sampling=100, - inits=(1, 2), + inits=-1, ) - with pytest.raises(ValueError): - bern_model.sample( + # test that inits are actually used by having a bad one + init_1 = {"theta": 0.2} + init_2 = {"theta": 4.0} + with pytest.raises(RuntimeError): + bern_fit = bern_model.sample( data=jdata, chains=2, - parallel_chains=2, seed=12345, + inits=[init_1, init_2], iter_warmup=100, iter_sampling=100, - inits=-1, + force_one_process_per_chain=True, + show_progress=False, ) + if not cmdstan_version_before(2, 33): + # https://github.com/stan-dev/cmdstan/pull/1191 + with pytest.raises(RuntimeError): + bern_fit = bern_model.sample( + data=jdata, + chains=2, + seed=12345, + inits=[init_1, init_2], + iter_warmup=100, + iter_sampling=100, + force_one_process_per_chain=False, + show_progress=False, + ) def test_bernoulli_bad() -> None: diff --git a/test/test_utils.py b/test/test_utils.py index fe42f86b..6a7ef368 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -27,7 +27,6 @@ from cmdstanpy.utils import ( EXTENSION, BaseType, - MaybeDictToFilePath, SanitizedOrTmpFilePath, check_sampler_csv, cmdstan_path, @@ -49,6 +48,7 @@ windows_short_path, write_stan_json, ) +from cmdstanpy.utils.filesystem import temp_inits, temp_single_json HERE = os.path.dirname(os.path.abspath(__file__)) DATAFILES_PATH = os.path.join(HERE, 'data') @@ -249,17 +249,46 @@ def test_dict_to_file() -> None: file_good = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv') dict_good = {'a': 0.5} created_tmp = None - with MaybeDictToFilePath(file_good, dict_good) as (fg1, fg2): + + with temp_single_json(file_good) as fg1: assert os.path.exists(fg1) + assert os.path.exists(file_good) + + with temp_single_json(dict_good) as fg2: assert os.path.exists(fg2) with open(fg2) as fg2_d: assert json.load(fg2_d) == dict_good created_tmp = fg2 - assert os.path.exists(file_good) + assert not os.path.exists(created_tmp) + with pytest.raises(AttributeError): + with temp_single_json(123) as _: + pass + + +def test_temp_inits(): + dict_good = {'a': 0.5} + with temp_inits([dict_good, dict_good]) as base_file: + fg1 = base_file[:-5] + '_1.json' + fg2 = base_file[:-5] + '_2.json' + assert os.path.exists(fg1) + assert os.path.exists(fg2) + with open(fg1) as fg1_d: + assert json.load(fg1_d) == dict_good + with open(fg2) as fg2_d: + assert json.load(fg2_d) == dict_good + created_tmp = (fg1, fg2) + + assert not os.path.exists(created_tmp[0]) + assert not os.path.exists(created_tmp[1]) + + with pytest.raises(ValueError): + with temp_inits([123]) as _: + pass + with 
pytest.raises(ValueError): - with MaybeDictToFilePath(123, dict_good) as (fg1, fg2): + with temp_inits([dict_good], allow_multiple=False) as _: pass
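
Usage sketch (assumes the patch above is applied; the data and init dictionaries are made-up example values, only the helper names and signatures come from the diff):

# Minimal sketch of the new cmdstanpy.utils.filesystem helpers introduced above.
from cmdstanpy.utils.filesystem import temp_inits, temp_single_json

data = {'N': 2, 'y': [0, 1]}              # illustrative Stan data dict
inits = [{'theta': 0.2}, {'theta': 0.8}]  # one init dict per chain

# A dict is written to a temporary .json file that is removed on exit;
# an existing path or None would be passed through unchanged.
with temp_single_json(data) as data_file:
    print(data_file)

# A list of inits is written out as <base>_1.json, <base>_2.json, ... and the
# base name is yielded: sample() either hands the base name to CmdStan (native
# num_chains mode) or expands it to <base>_<chain_id>.json per process.
with temp_inits(inits, id=1) as base_file:
    per_chain = [f"{base_file[:-5]}_{i}.json" for i in (1, 2)]
    print(base_file, per_chain)

# allow_multiple=False (the optimize/variational path) rejects lists outright.
with temp_inits({'theta': 0.5}, allow_multiple=False) as single_init:
    print(single_init)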