Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
adapt_metric_window: Optional[int] = None,
adapt_step_size: Optional[int] = None,
fixed_param: bool = False,
num_chains: int = 1,
) -> None:
"""Initialize object."""
self.iter_warmup = iter_warmup
Expand All @@ -73,6 +74,7 @@ def __init__(
self.adapt_step_size = adapt_step_size
self.fixed_param = fixed_param
self.diagnostic_file = None
self.num_chains = num_chains

def validate(self, chains: Optional[int]) -> None:
"""
Expand Down Expand Up @@ -316,6 +318,10 @@ def validate(self, chains: Optional[int]) -> None:
'Argument "adapt_step_size" must be a non-negative integer,'
'found {}'.format(self.adapt_step_size)
)
if self.num_chains < 1 or not isinstance(
self.num_chains, (int, np.integer)
):
raise ValueError("num_chains must be positive")

if self.fixed_param and (
self.max_treedepth is not None
Expand Down Expand Up @@ -378,6 +384,8 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
cmd.append('window={}'.format(self.adapt_metric_window))
if self.adapt_step_size is not None:
cmd.append('term_buffer={}'.format(self.adapt_step_size))
if self.num_chains > 1:
cmd.append('num_chains={}'.format(self.num_chains))

return cmd

Expand Down Expand Up @@ -921,8 +929,12 @@ def validate(self) -> None:
)
)
elif isinstance(self.inits, str):
if not os.path.exists(self.inits):
raise ValueError('no such file {}'.format(self.inits))
if not (
isinstance(self.method_args, SamplerArgs)
and self.method_args.num_chains > 1
):
if not os.path.exists(self.inits):
raise ValueError('no such file {}'.format(self.inits))
elif isinstance(self.inits, list):
if self.chain_ids is None:
raise ValueError(
Expand All @@ -948,7 +960,6 @@ def compose_command(
*,
diagnostic_file: Optional[str] = None,
profile_file: Optional[str] = None,
num_chains: Optional[int] = None,
) -> List[str]:
"""
Compose CmdStan command for non-default arguments.
Expand Down Expand Up @@ -992,6 +1003,4 @@ def compose_command(
if self.sig_figs is not None:
cmd.append('sig_figs={}'.format(self.sig_figs))
cmd = self.method_args.compose(idx, cmd)
if num_chains:
cmd.append('num_chains={}'.format(num_chains))
return cmd
177 changes: 105 additions & 72 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
)
from cmdstanpy.utils import (
EXTENSION,
MaybeDictToFilePath,
SanitizedOrTmpFilePath,
cmdstan_path,
cmdstan_version,
Expand All @@ -67,6 +66,7 @@
get_logger,
returncode_msg,
)
from cmdstanpy.utils.filesystem import temp_inits, temp_single_json

from . import progress as progbar

Expand Down Expand Up @@ -573,7 +573,7 @@ def optimize(
self,
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
seed: Optional[int] = None,
inits: Union[Dict[str, float], float, str, os.PathLike, None] = None,
inits: Union[Mapping[str, Any], float, str, os.PathLike, None] = None,
output_dir: OptionalPath = None,
sig_figs: Optional[int] = None,
save_profile: bool = False,
Expand Down Expand Up @@ -722,7 +722,9 @@ def optimize(
"in CmdStan 2.32 and above."
)

with MaybeDictToFilePath(data, inits) as (_data, _inits):
with temp_single_json(data) as _data, temp_inits(
inits, allow_multiple=False
) as _inits:
args = CmdStanArgs(
self._name,
self._exe_file,
Expand Down Expand Up @@ -766,7 +768,14 @@ def sample(
threads_per_chain: Optional[int] = None,
seed: Union[int, List[int], None] = None,
chain_ids: Union[int, List[int], None] = None,
inits: Union[Dict[str, float], float, str, List[str], None] = None,
inits: Union[
Mapping[str, Any],
float,
str,
List[str],
List[Mapping[str, Any]],
None,
] = None,
iter_warmup: Optional[int] = None,
iter_sampling: Optional[int] = None,
save_warmup: bool = False,
Expand Down Expand Up @@ -1006,6 +1015,69 @@ def sample(
chains
)
)

if parallel_chains is None:
parallel_chains = max(min(cpu_count(), chains), 1)
elif parallel_chains > chains:
get_logger().info(
'Requested %u parallel_chains but only %u required, '
'will run all chains in parallel.',
parallel_chains,
chains,
)
parallel_chains = chains
elif parallel_chains < 1:
raise ValueError(
'Argument parallel_chains must be a positive integer, '
'found {}.'.format(parallel_chains)
)
if threads_per_chain is None:
threads_per_chain = 1
if threads_per_chain < 1:
raise ValueError(
'Argument threads_per_chain must be a positive integer, '
'found {}.'.format(threads_per_chain)
)

parallel_procs = parallel_chains
num_threads = threads_per_chain
one_process_per_chain = True
info_dict = self.exe_info()
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
# run multi-chain sampler unless algo is fixed_param or 1 chain
if fixed_param or (chains == 1):
force_one_process_per_chain = True

if (
force_one_process_per_chain is None
and not cmdstan_version_before(2, 28, info_dict)
and stan_threads == 'true'
):
one_process_per_chain = False
num_threads = parallel_chains * num_threads
parallel_procs = 1
if force_one_process_per_chain is False:
if not cmdstan_version_before(2, 28, info_dict):
one_process_per_chain = False
num_threads = parallel_chains * num_threads
parallel_procs = 1
if stan_threads == 'false':
get_logger().warning(
'Stan program not compiled for threading, '
'process will run chains sequentially. '
'For multi-chain parallelization, recompile '
'the model with argument '
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
)
else:
get_logger().warning(
'Installed version of CmdStan cannot multi-process '
'chains, will run %d processes. '
'Run "install_cmdstan" to upgrade to latest version.',
chains,
)
os.environ['STAN_NUM_THREADS'] = str(num_threads)

if chain_ids is None:
chain_ids = [i + 1 for i in range(chains)]
else:
Expand All @@ -1017,6 +1089,13 @@ def sample(
)
chain_ids = [i + chain_ids for i in range(chains)]
else:
if not one_process_per_chain:
for i, j in zip(chain_ids, chain_ids[1:]):
if i != j - 1:
raise ValueError(
'chain_ids must be sequential list of integers,'
' found {}.'.format(chain_ids)
)
if not len(chain_ids) == chains:
raise ValueError(
'Chain_ids must correspond to number of chains'
Expand All @@ -1032,6 +1111,7 @@ def sample(
)

sampler_args = SamplerArgs(
num_chains=1 if one_process_per_chain else chains,
iter_warmup=iter_warmup,
iter_sampling=iter_sampling,
save_warmup=save_warmup,
Expand All @@ -1046,14 +1126,25 @@ def sample(
adapt_step_size=adapt_step_size,
fixed_param=fixed_param,
)
with MaybeDictToFilePath(data, inits) as (_data, _inits):

with temp_single_json(data) as _data, temp_inits(
inits, id=chain_ids[0]
) as _inits:
cmdstan_inits: Union[str, List[str], int, float, None]
if one_process_per_chain and isinstance(inits, list): # legacy
cmdstan_inits = [
f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore
]
else:
cmdstan_inits = _inits

args = CmdStanArgs(
self._name,
self._exe_file,
chain_ids=chain_ids,
data=_data,
seed=seed,
inits=_inits,
inits=cmdstan_inits,
output_dir=output_dir,
sig_figs=sig_figs,
save_latent_dynamics=save_latent_dynamics,
Expand All @@ -1062,68 +1153,6 @@ def sample(
refresh=refresh,
)

if parallel_chains is None:
parallel_chains = max(min(cpu_count(), chains), 1)
elif parallel_chains > chains:
get_logger().info(
'Requested %u parallel_chains but only %u required, '
'will run all chains in parallel.',
parallel_chains,
chains,
)
parallel_chains = chains
elif parallel_chains < 1:
raise ValueError(
'Argument parallel_chains must be a positive integer, '
'found {}.'.format(parallel_chains)
)
if threads_per_chain is None:
threads_per_chain = 1
if threads_per_chain < 1:
raise ValueError(
'Argument threads_per_chain must be a positive integer, '
'found {}.'.format(threads_per_chain)
)

parallel_procs = parallel_chains
num_threads = threads_per_chain
one_process_per_chain = True
info_dict = self.exe_info()
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
# run multi-chain sampler unless algo is fixed_param or 1 chain
if fixed_param or (chains == 1):
force_one_process_per_chain = True

if (
force_one_process_per_chain is None
and not cmdstan_version_before(2, 28, info_dict)
and stan_threads == 'true'
):
one_process_per_chain = False
num_threads = parallel_chains * num_threads
parallel_procs = 1
if force_one_process_per_chain is False:
if not cmdstan_version_before(2, 28, info_dict):
one_process_per_chain = False
num_threads = parallel_chains * num_threads
parallel_procs = 1
if stan_threads == 'false':
get_logger().warning(
'Stan program not compiled for threading, '
'process will run chains sequentially. '
'For multi-chain parallelization, recompile '
'the model with argument '
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
)
else:
get_logger().warning(
'Installed version of CmdStan cannot multi-process '
'chains, will run %d processes. '
'Run "install_cmdstan" to upgrade to latest version.',
chains,
)
os.environ['STAN_NUM_THREADS'] = str(num_threads)

if show_console:
show_progress = False
else:
Expand Down Expand Up @@ -1376,7 +1405,7 @@ def generate_quantities(
csv_files=fit_csv_files
)
generate_quantities_args.validate(chains)
with MaybeDictToFilePath(data, None) as (_data, _inits):
with temp_single_json(data) as _data:
args = CmdStanArgs(
self._name,
self._exe_file,
Expand Down Expand Up @@ -1551,7 +1580,9 @@ def variational(
output_samples=output_samples,
)

with MaybeDictToFilePath(data, inits) as (_data, _inits):
with temp_single_json(data) as _data, temp_inits(
inits, allow_multiple=False
) as _inits:
args = CmdStanArgs(
self._name,
self._exe_file,
Expand Down Expand Up @@ -1658,7 +1689,9 @@ def log_prob(
"Method 'log_prob' not available for CmdStan versions "
"before 2.31"
)
with MaybeDictToFilePath(data, params) as (_data, _params):
with temp_single_json(data) as _data, temp_single_json(
params
) as _params:
cmd = [
str(self.exe_file),
"log_prob",
Expand Down Expand Up @@ -1766,7 +1799,7 @@ def laplace_sample(
cmdstan_mode.runset.csv_files[0], draws, jacobian
)

with MaybeDictToFilePath(data) as (_data,):
with temp_single_json(data) as _data:
args = CmdStanArgs(
self._name,
self._exe_file,
Expand Down
1 change: 0 additions & 1 deletion cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def cmd(self, idx: int) -> List[str]:
profile_file=self.file_path(".csv", extra="-profile")
if self._args.save_profile
else None,
num_chains=self._chains,
)

@property
Expand Down
2 changes: 0 additions & 2 deletions cmdstanpy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from .command import do_command, returncode_msg
from .data_munging import build_xarray_data, flatten_chains
from .filesystem import (
MaybeDictToFilePath,
SanitizedOrTmpFilePath,
create_named_text_file,
pushd,
Expand Down Expand Up @@ -116,7 +115,6 @@ def show_versions(output: bool = True) -> str:
__all__ = [
'BaseType',
'EXTENSION',
'MaybeDictToFilePath',
'SanitizedOrTmpFilePath',
'build_xarray_data',
'check_sampler_csv',
Expand Down
Loading