diff --git a/.gitignore b/.gitignore index 084e882fe3..46b3a146e2 100644 --- a/.gitignore +++ b/.gitignore @@ -46,5 +46,7 @@ make/local src/docs/**/*.pdf output.csv - +output*.csv *.d +# gdb +.gdb_history diff --git a/make/command b/make/command index 883d4cfa18..320908806a 100644 --- a/make/command +++ b/make/command @@ -2,11 +2,11 @@ ifeq ($(CMDSTAN_SUBMODULES),1) bin/cmdstan/stansummary.o : src/cmdstan/stansummary_helper.hpp bin/cmdstan/%.o : src/cmdstan/%.cpp @mkdir -p $(dir $@) - $(COMPILE.cpp) -fvisibility=hidden $(OUTPUT_OPTION) $< + $(COMPILE.cpp) -fvisibility=hidden $(OUTPUT_OPTION) $(LDLIBS) $< .PRECIOUS: bin/print$(EXE) bin/stansummary$(EXE) bin/diagnose$(EXE) bin/print$(EXE) bin/stansummary$(EXE) bin/diagnose$(EXE) : CPPFLAGS_MPI = -bin/print$(EXE) bin/stansummary$(EXE) bin/diagnose$(EXE) : LDFLAGS_MPI = +bin/print$(EXE) bin/stansummary$(EXE) bin/diagnose$(EXE) : LDFLAGS_MPI = bin/print$(EXE) bin/stansummary$(EXE) bin/diagnose$(EXE) : LDLIBS_MPI = bin/print$(EXE) bin/stansummary$(EXE) bin/diagnose$(EXE) : bin/%$(EXE) : bin/cmdstan/%.o @mkdir -p $(dir $@) diff --git a/make/tests b/make/tests index 870e31e6d7..dcc6a2e3e7 100644 --- a/make/tests +++ b/make/tests @@ -57,8 +57,8 @@ test-headers: $(HEADER_TESTS) TEST_MODELS := $(wildcard src/test/test-models/*.stan) .PHONY: test-models-hpp -test-models-hpp: - $(MAKE) $(patsubst %.stan,%$(EXE),$(TEST_MODELS)) +test-models-hpp: + $(MAKE) $(patsubst %.stan,%$(EXE),$(TEST_MODELS)) ## # Tests that depend on compiled models diff --git a/makefile b/makefile index 25ed545851..8c2541ad2c 100644 --- a/makefile +++ b/makefile @@ -244,7 +244,7 @@ build-mpi: $(MPI_TARGETS) ifeq ($(CMDSTAN_SUBMODULES),1) .PHONY: build -build: bin/stanc$(EXE) bin/stansummary$(EXE) bin/print$(EXE) bin/diagnose$(EXE) $(LIBSUNDIALS) $(MPI_TARGETS) $(TBB_TARGETS) $(CMDSTAN_MAIN_O) $(PRECOMPILED_MODEL_HEADER) +build: bin/stanc$(EXE) $(LIBSUNDIALS) $(MPI_TARGETS) $(TBB_TARGETS) $(CMDSTAN_MAIN_O) $(PRECOMPILED_MODEL_HEADER) bin/stansummary$(EXE) bin/print$(EXE) bin/diagnose$(EXE) @echo '' ifeq ($(OS),Windows_NT) @echo 'NOTE: Please add $(TBB_BIN_ABSOLUTE_PATH) to your PATH variable.' @@ -338,3 +338,6 @@ compile_info: ## .PHONY: print-% print-% : ; @echo $* = $($*) + +.PHONY: clean-build +clean-build: clean-all build diff --git a/src/cmdstan/arguments/arg_id.hpp b/src/cmdstan/arguments/arg_id.hpp index 08bbf60a8d..bfaaa7c72b 100644 --- a/src/cmdstan/arguments/arg_id.hpp +++ b/src/cmdstan/arguments/arg_id.hpp @@ -11,8 +11,8 @@ class arg_id : public int_argument { _name = "id"; _description = "Unique process identifier"; _validity = "id >= 0"; - _default = "0"; - _default_value = 0; + _default = "1"; + _default_value = 1; _constrained = true; _good_value = 2.0; _bad_value = -1.0; diff --git a/src/cmdstan/arguments/arg_num_chains.hpp b/src/cmdstan/arguments/arg_num_chains.hpp new file mode 100644 index 0000000000..89165b4c1a --- /dev/null +++ b/src/cmdstan/arguments/arg_num_chains.hpp @@ -0,0 +1,27 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_NUM_CHAINS_HPP +#define CMDSTAN_ARGUMENTS_ARG_NUM_CHAINS_HPP + +#include + +namespace cmdstan { + +class arg_num_chains : public int_argument { + public: + arg_num_chains() : int_argument() { + _name = "num_chains"; + _description = std::string("Number of chains"); + _validity = "num_chains > 0"; + _default = "1"; + _constrained = true; + _good_value = 2.0; + _bad_value = 0.0; + _default = "1"; + _default_value = 1; + _value = _default_value; + } + + bool is_valid(int value) { return value > 0; } +}; + +} // namespace cmdstan +#endif diff --git a/src/cmdstan/arguments/arg_num_threads.hpp b/src/cmdstan/arguments/arg_num_threads.hpp new file mode 100644 index 0000000000..02d272ea9d --- /dev/null +++ b/src/cmdstan/arguments/arg_num_threads.hpp @@ -0,0 +1,29 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_NUM_THREADS_HPP +#define CMDSTAN_ARGUMENTS_ARG_NUM_THREADS_HPP + +#include + +namespace cmdstan { + +class arg_num_threads : public int_argument { + public: + arg_num_threads() : int_argument() { + _name = "num_threads"; + _description = std::string("Number of threads available to the program."); + _validity = "num_threads > 0 || num_threads == -1"; + _default = "1"; + _default_value = 1; + _good_value = 1.0; + _bad_value = -2.0; + _constrained = true; + _value = _default_value; + } +#ifdef STAN_THREADS + bool is_valid(int value) { return value > -2 && value != 0; } +#else + bool is_valid(int value) { return value == 1; } +#endif +}; + +} // namespace cmdstan +#endif diff --git a/src/cmdstan/arguments/arg_sample.hpp b/src/cmdstan/arguments/arg_sample.hpp index 5c20a191da..49ee17fe0f 100644 --- a/src/cmdstan/arguments/arg_sample.hpp +++ b/src/cmdstan/arguments/arg_sample.hpp @@ -2,6 +2,7 @@ #define CMDSTAN_ARGUMENTS_ARG_SAMPLE_HPP #include +#include #include #include #include @@ -23,6 +24,7 @@ class arg_sample : public categorical_argument { _subarguments.push_back(new arg_thin()); _subarguments.push_back(new arg_adapt()); _subarguments.push_back(new arg_sample_algo()); + _subarguments.push_back(new arg_num_chains()); } }; diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index cfaa20c71c..51c0c27eff 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -5,11 +5,14 @@ #include #include #include +#include +#include #include #include #include #include #include +#include #include #include #include @@ -21,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -49,6 +53,7 @@ #include #include #include +#include #include #include #include @@ -81,7 +86,13 @@ stan::math::mpi_cluster &get_mpi_cluster() { } #endif -std::shared_ptr get_var_context(const std::string file) { +using shared_context_ptr = std::shared_ptr; + +/** + * Given the name of a file, return a shared pointer holding the data contents. + * @param file A system file to read from. + */ +inline shared_context_ptr get_var_context(const std::string file) { std::fstream stream(file.c_str(), std::fstream::in); if (file != "" && (stream.rdstate() & std::ifstream::failbit)) { std::stringstream msg; @@ -90,19 +101,186 @@ std::shared_ptr get_var_context(const std::string file) { } if (stan::io::ends_with(".json", file)) { cmdstan::json::json_data var_context(stream); - stream.close(); - std::shared_ptr result - = std::make_shared(var_context); - return result; + return std::make_shared(var_context); } stan::io::dump var_context(stream); - stream.close(); - std::shared_ptr result - = std::make_shared(var_context); - return result; + return std::make_shared(var_context); +} + +using context_vector = std::vector; +/** + * Make a vector of shared pointers to contexts. + * @param file The name of the file. For multi-chain we will attempt to find + * {file_name}_1{file_ending} and if that fails try to use the named file as + * the data for each chain. + * @param num_chains The number of chains to run. + * @return An std vector of shared pointers to var contexts + */ +context_vector get_vec_var_context(const std::string &file, size_t num_chains) { + using stan::io::var_context; + if (num_chains == 1) { + return context_vector(1, get_var_context(file)); + } + auto make_context = [](auto &&file, auto &&stream, + auto &&file_ending) -> shared_context_ptr { + if (file_ending == ".json") { + using cmdstan::json::json_data; + return std::make_shared(json_data(stream)); + } else if (file_ending == ".csv") { + using stan::io::dump; + return std::make_shared(dump(stream)); + } else { + std::stringstream msg; + msg << "file ending of " << file_ending << " is not supported by cmdstan"; + throw std::invalid_argument(msg.str()); + using stan::io::dump; + return std::make_shared(dump(stream)); + } + }; + // use default for all chain inits + if (file == "") { + using stan::io::dump; + std::fstream stream(file.c_str(), std::fstream::in); + return context_vector(num_chains, std::make_shared(dump(stream))); + } else { + size_t file_marker_pos = file.find_last_of("."); + if (file_marker_pos > file.size()) { + std::stringstream msg; + msg << "Found: \"" << file + << "\" but user specied files must end in .json or .csv"; + throw std::invalid_argument(msg.str()); + } + std::string file_name = file.substr(0, file_marker_pos); + std::string file_ending = file.substr(file_marker_pos, file.size()); + if (file_ending != ".json" || file_ending != ".csv") { + std::stringstream msg; + msg << "file ending of " << file_ending << " is not supported by cmdstan"; + throw std::invalid_argument(msg.str()); + } + std::string file_1 + = std::string(file_name + "_" + std::to_string(1) + file_ending); + std::fstream stream_1(file_1.c_str(), std::fstream::in); + // Check if file_1 exists, if so then we'll assume num_chains of these + // exist. + if (stream_1.rdstate() & std::ifstream::failbit) { + // if that fails we will try to find a base file + std::fstream stream(file.c_str(), std::fstream::in); + if (stream.rdstate() & std::ifstream::failbit) { + std::string file_name_err + = std::string("\"" + file_1 + "\" and base file \"" + file + "\""); + std::stringstream msg; + msg << "Searching for \"" << file_name_err << std::endl; + msg << "Can't open either of specified files," << file_name_err + << std::endl; + throw std::invalid_argument(msg.str()); + } else { + return context_vector(1, make_context(file, stream, file_ending)); + } + } else { + // If we found file_1 then we'll assume file_{1...N} exists + context_vector ret; + ret.reserve(num_chains); + ret.push_back(make_context(file_1, stream_1, file_ending)); + for (size_t i = 1; i < num_chains; ++i) { + std::string file_i + = std::string(file_name + "_" + std::to_string(i) + file_ending); + std::fstream stream_i(file_1.c_str(), std::fstream::in); + // If any stream fails at this point something went wrong with file + // names. + if (stream_i.rdstate() & std::ifstream::failbit) { + std::string file_name_err = std::string( + "\"" + file_1 + "\" but cannot open \"" + file_i + "\""); + std::stringstream msg; + msg << "Found " << file_name_err << std::endl; + throw std::invalid_argument(msg.str()); + } + ret.push_back(make_context(file_i, stream_i, file_ending)); + } + return ret; + } + } + // This should not happen + using stan::io::dump; + std::fstream stream(file.c_str(), std::fstream::in); + return context_vector(num_chains, std::make_shared(dump(stream))); +} + +static constexpr int hmc_fixed_cols + = 7; // hmc sampler outputs columns __lp + 6 + +namespace internal { + +/** + * Base of helper function for getting arguments + * @param x A pointer to an argument in the argument pointer list + */ +template +inline constexpr auto get_arg_pointer(T &&x) { + return x; +} + +/** + * Given a pointer to a list of argument pointers, extract the named argument + * from the list. + * @tparam List A pointer to a list that has a valid arg(const char*) method + * @tparam Args A paramter pack of const char* + * @param arg_list The list argument to access the arg from + * @param arg1 The name of the first argument to extract + * @param args An optional pack of named arguments to access from the first arg. + */ +template +inline constexpr auto get_arg_pointer(List &&arg_list, const char *arg1, + Args &&... args) { + return get_arg_pointer(arg_list->arg(arg1), args...); } -static int hmc_fixed_cols = 7; // hmc sampler outputs columns __lp + 6 +} // namespace internal + +/** + * Given a list of argument pointers, extract the named argument from the list. + * @tparam List An list argument that has a valid arg(const char*) method + * @tparam Args A paramter pack of const char* + * @param arg_list The list argument to access the arg from + * @param arg1 The name of the first argument to extract + * @param args An optional pack of named arguments to access from the first arg. + */ +template +inline constexpr auto get_arg(List &&arg_list, const char *arg1, + Args &&... args) { + return internal::get_arg_pointer(arg_list.arg(arg1), args...); +} + +/** + * Given an argument return its value. Because all of the elements in + * our list of command line arguments is an `argument` class with no + * `value()` method, we must give the function the type of the argument class we + * want to access. + * @tparam caster The type to cast the `argument` class in the list to. + * @tparam Arg An object that inherits from `argument`. + * @param argument holds the argument to access + * @param arg_name The name of the argument to access. + */ +template +inline constexpr auto get_arg_val(Arg &&argument, const char *arg_name) { + return dynamic_cast *>(argument.arg(arg_name))->value(); +} + +/** + * Given a list of arguments, index into the args and return the value held + * by the underlying element in the list. Because all of the elements in + * our list of command line arguments is an `argument` class with no + * `value()` method, we must give the function the type of the argument class we + * want to access. + * @tparam caster The type to cast the `argument` class in the list to. + * @tparam List A pointer or object that inherits from `argument`. + * @param arg_list holds the arguments to access + * @param args A parameter pack of names of arguments to index into. + */ +template +inline constexpr auto get_arg_val(List &&arg_list, Args &&... args) { + return dynamic_cast *>(get_arg(arg_list, args...)) + ->value(); +} int command(int argc, const char *argv[]) { stan::callbacks::stream_writer info(std::cout); @@ -117,8 +295,6 @@ int command(int argc, const char *argv[]) { return 0; #endif - stan::math::init_threadpool_tbb(); - // Read arguments std::vector valid_arguments; valid_arguments.push_back(new arg_id()); @@ -126,6 +302,7 @@ int command(int argc, const char *argv[]) { valid_arguments.push_back(new arg_init()); valid_arguments.push_back(new arg_random()); valid_arguments.push_back(new arg_output()); + valid_arguments.push_back(new arg_num_threads()); #ifdef STAN_OPENCL valid_arguments.push_back(new arg_opencl()); #endif @@ -143,6 +320,49 @@ int command(int argc, const char *argv[]) { if (parser.help_printed()) return return_codes::OK; + int num_threads = get_arg_val(parser, "num_threads"); + // Need to make sure these two ways to set thread # match. + int env_threads = stan::math::internal::get_num_threads(); + if (env_threads != num_threads) { + if (env_threads != 1) { + std::stringstream thread_msg; + thread_msg << "STAN_NUM_THREADS= " << env_threads + << " but argument num_threads= " << num_threads + << ". Please either only set one or make sure they are equal."; + throw std::invalid_argument(thread_msg.str()); + } + } + stan::math::init_threadpool_tbb(num_threads); + + unsigned int num_chains = 1; + auto user_method = parser.arg("method"); + // num_chains > 1 is only supported in diag_e and dense_e of hmc + if (user_method->arg("sample")) { + num_chains + = get_arg_val(parser, "method", "sample", "num_chains"); + auto sample_arg = parser.arg("method")->arg("sample"); + list_argument *algo + = dynamic_cast(sample_arg->arg("algorithm")); + categorical_argument *adapt + = dynamic_cast(sample_arg->arg("adapt")); + const bool adapt_engaged + = dynamic_cast(adapt->arg("engaged"))->value(); + const bool is_hmc = algo->value() == "hmc"; + if (num_chains > 1) { + if (is_hmc && adapt_engaged) { + list_argument *engine + = dynamic_cast(algo->arg("hmc")->arg("engine")); + list_argument *metric + = dynamic_cast(algo->arg("hmc")->arg("metric")); + if (engine->value() != "nuts" + && (metric->value() != "dense_e" || metric->value() == "diag_e")) { + throw std::invalid_argument( + "num_chains can currently only be used for NUTS with adaptation " + "and dense_e or diag_e metric"); + } + } + } + } arg_seed *random_arg = dynamic_cast(parser.arg("random")->arg("seed")); unsigned int random_seed = random_arg->random_value(); @@ -195,31 +415,11 @@ int command(int argc, const char *argv[]) { stan::callbacks::writer init_writer; stan::callbacks::interrupt interrupt; - std::fstream output_stream( - dynamic_cast(parser.arg("output")->arg("file")) - ->value() - .c_str(), - std::fstream::out); + ////////////////////////////////////////////////// + // Initialize Model // + ////////////////////////////////////////////////// + std::string filename = get_arg_val(parser, "data", "file"); - int_argument *sig_figs_arg - = dynamic_cast(parser.arg("output")->arg("sig_figs")); - if (!sig_figs_arg->is_default()) { - output_stream << std::setprecision(sig_figs_arg->value()); - } - stan::callbacks::stream_writer sample_writer(output_stream, "# "); - - std::fstream diagnostic_stream( - dynamic_cast( - parser.arg("output")->arg("diagnostic_file")) - ->value() - .c_str(), - std::fstream::out); - stan::callbacks::stream_writer diagnostic_writer(diagnostic_stream, "# "); - - // Read input data - std::string filename( - dynamic_cast(parser.arg("data")->arg("file")) - ->value()); std::shared_ptr var_context = get_var_context(filename); @@ -229,22 +429,96 @@ int command(int argc, const char *argv[]) { std::vector model_compile_info = model.model_compile_info(); - write_stan(sample_writer); - write_model(sample_writer, model.model_name()); - write_datetime(sample_writer); - parser.print(sample_writer); - write_parallel_info(sample_writer); - write_opencl_device(sample_writer); - write_compile_info(sample_writer, model_compile_info); + ////////////////////////////////////////////////// + // Initialize Writers // + ////////////////////////////////////////////////// - write_stan(diagnostic_writer); - write_model(diagnostic_writer, model.model_name()); - parser.print(diagnostic_writer); + std::string output_file + = get_arg_val(parser, "output", "file"); + if (output_file == "") { + throw std::invalid_argument( + std::string("File output name must not be blank")); + } + std::string output_name; + std::string output_ending; + size_t output_marker_pos = output_file.find_last_of("."); + if (output_marker_pos > output_file.size()) { + output_name = output_file; + output_ending = ""; + } else { + output_name = output_file.substr(0, output_marker_pos); + output_ending = output_file.substr(output_marker_pos, output_file.size()); + } + + std::string diagnostic_file + = get_arg_val(parser, "output", "diagnostic_file"); + size_t diagnostic_marker_pos = diagnostic_file.find_last_of("."); + std::string diagnostic_name; + std::string diagnostic_ending; + // no . seperator found. + if (diagnostic_marker_pos > diagnostic_file.size()) { + diagnostic_name = diagnostic_file; + diagnostic_ending = ""; + } else { + diagnostic_name = diagnostic_file.substr(0, diagnostic_marker_pos); + diagnostic_ending + = diagnostic_file.substr(diagnostic_marker_pos, diagnostic_file.size()); + } + + std::vector> + sample_writers; + sample_writers.reserve(num_chains); + std::vector> + diagnostic_writers; + diagnostic_writers.reserve(num_chains); + std::vector init_writers{num_chains, + stan::callbacks::writer{}}; + unsigned int id = dynamic_cast(parser.arg("id"))->value(); + int_argument *sig_figs_arg + = dynamic_cast(parser.arg("output")->arg("sig_figs")); + auto name_iterator = [num_chains, id](auto i) { + if (num_chains == 1) { + return std::string(""); + } else { + return std::string("_" + std::to_string(i + id)); + } + }; + for (int i = 0; i < num_chains; i++) { + auto output_filename = output_name + name_iterator(i) + output_ending; + auto unique_fstream + = std::make_unique(output_filename, std::fstream::out); + if (!sig_figs_arg->is_default()) { + (*unique_fstream.get()) << std::setprecision(sig_figs_arg->value()); + } + sample_writers.emplace_back(std::move(unique_fstream), "# "); + if (diagnostic_file != "") { + auto diagnostic_filename + = diagnostic_name + name_iterator(i) + diagnostic_ending; + diagnostic_writers.emplace_back( + std::make_unique(diagnostic_filename, + std::fstream::out), + "# "); + } else { + diagnostic_writers.emplace_back( + std::make_unique("", std::fstream::out), "# "); + } + } + for (int i = 0; i < num_chains; i++) { + write_stan(sample_writers[i]); + write_model(sample_writers[i], model.model_name()); + write_datetime(sample_writers[i]); + parser.print(sample_writers[i]); + write_parallel_info(sample_writers[i]); + write_opencl_device(sample_writers[i]); + write_compile_info(sample_writers[i], model_compile_info); + write_stan(diagnostic_writers[i]); + write_model(diagnostic_writers[i], model.model_name()); + parser.print(diagnostic_writers[i]); + } int refresh = dynamic_cast(parser.arg("output")->arg("refresh")) ->value(); - unsigned int id = dynamic_cast(parser.arg("id"))->value(); // Read initial parameter values or user-specified radius std::string init @@ -256,12 +530,11 @@ int command(int argc, const char *argv[]) { init = ""; } catch (const boost::bad_lexical_cast &e) { } - std::shared_ptr init_context = get_var_context(init); - - // Invoke specified method - int return_code = return_codes::OK; - - if (parser.arg("method")->arg("generate_quantities")) { + std::vector> init_contexts + = get_vec_var_context(init, num_chains); + int return_code = stan::services::error_codes::CONFIG; + if (user_method->arg("generate_quantities")) { + // read sample from cmdstan csv output file string_argument *fitted_params_file = dynamic_cast( parser.arg("method")->arg("generate_quantities")->arg("fitted_params")); if (fitted_params_file->is_default()) { @@ -313,8 +586,8 @@ int command(int argc, const char *argv[]) { return_code = stan::services::standalone_generate( model, fitted_params.samples.block(0, hmc_fixed_cols, num_rows, num_cols), - random_seed, interrupt, logger, sample_writer); - } else if (parser.arg("method")->arg("diagnose")) { + random_seed, interrupt, logger, sample_writers[0]); + } else if (user_method->arg("diagnose")) { list_argument *test = dynamic_cast( parser.arg("method")->arg("diagnose")->arg("test")); @@ -326,10 +599,10 @@ int command(int argc, const char *argv[]) { = dynamic_cast(test->arg("gradient")->arg("error")) ->value(); return_code = stan::services::diagnose::diagnose( - model, *init_context, random_seed, id, init_radius, epsilon, error, - interrupt, logger, init_writer, sample_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, epsilon, + error, interrupt, logger, init_writers[0], sample_writers[0]); } - } else if (parser.arg("method")->arg("optimize")) { + } else if (user_method->arg("optimize")) { list_argument *algo = dynamic_cast( parser.arg("method")->arg("optimize")->arg("algorithm")); int num_iterations = dynamic_cast( @@ -342,8 +615,9 @@ int command(int argc, const char *argv[]) { if (algo->value() == "newton") { return_code = stan::services::optimize::newton( - model, *init_context, random_seed, id, init_radius, num_iterations, - save_iterations, interrupt, logger, init_writer, sample_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + num_iterations, save_iterations, interrupt, logger, init_writers[0], + sample_writers[0]); } else if (algo->value() == "bfgs") { double init_alpha = dynamic_cast(algo->arg("bfgs")->arg("init_alpha")) @@ -365,10 +639,10 @@ int command(int argc, const char *argv[]) { ->value(); return_code = stan::services::optimize::bfgs( - model, *init_context, random_seed, id, init_radius, init_alpha, + model, *(init_contexts[0]), random_seed, id, init_radius, init_alpha, tol_obj, tol_rel_obj, tol_grad, tol_rel_grad, tol_param, num_iterations, save_iterations, refresh, interrupt, logger, - init_writer, sample_writer); + init_writers[0], sample_writers[0]); } else if (algo->value() == "lbfgs") { int history_size = dynamic_cast( algo->arg("lbfgs")->arg("history_size")) @@ -393,30 +667,26 @@ int command(int argc, const char *argv[]) { ->value(); return_code = stan::services::optimize::lbfgs( - model, *init_context, random_seed, id, init_radius, history_size, - init_alpha, tol_obj, tol_rel_obj, tol_grad, tol_rel_grad, tol_param, - num_iterations, save_iterations, refresh, interrupt, logger, - init_writer, sample_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + history_size, init_alpha, tol_obj, tol_rel_obj, tol_grad, + tol_rel_grad, tol_param, num_iterations, save_iterations, refresh, + interrupt, logger, init_writers[0], sample_writers[0]); } - } else if (parser.arg("method")->arg("sample")) { - int num_warmup = dynamic_cast( - parser.arg("method")->arg("sample")->arg("num_warmup")) - ->value(); + } else if (user_method->arg("sample")) { + auto sample_arg = parser.arg("method")->arg("sample"); + int num_warmup + = dynamic_cast(sample_arg->arg("num_warmup"))->value(); int num_samples - = dynamic_cast( - parser.arg("method")->arg("sample")->arg("num_samples")) - ->value(); - int num_thin = dynamic_cast( - parser.arg("method")->arg("sample")->arg("thin")) - ->value(); + = dynamic_cast(sample_arg->arg("num_samples"))->value(); + int num_thin + = dynamic_cast(sample_arg->arg("thin"))->value(); bool save_warmup - = dynamic_cast( - parser.arg("method")->arg("sample")->arg("save_warmup")) + = dynamic_cast(sample_arg->arg("save_warmup")) ->value(); - list_argument *algo = dynamic_cast( - parser.arg("method")->arg("sample")->arg("algorithm")); - categorical_argument *adapt = dynamic_cast( - parser.arg("method")->arg("sample")->arg("adapt")); + list_argument *algo + = dynamic_cast(sample_arg->arg("algorithm")); + categorical_argument *adapt + = dynamic_cast(sample_arg->arg("adapt")); bool adapt_engaged = dynamic_cast(adapt->arg("engaged"))->value(); @@ -426,9 +696,9 @@ int command(int argc, const char *argv[]) { "Model contains no parameters, running fixed_param sampler, " "no updates to Markov chain"); return_code = stan::services::sample::fixed_param( - model, *init_context, random_seed, id, init_radius, num_samples, - num_thin, refresh, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, num_samples, + num_thin, refresh, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (algo->value() == "hmc") { list_argument *engine = dynamic_cast(algo->arg("hmc")->arg("engine")); @@ -441,18 +711,16 @@ int command(int argc, const char *argv[]) { std::string metric_filename( dynamic_cast(algo->arg("hmc")->arg("metric_file")) ->value()); - std::shared_ptr metric_context - = get_var_context(metric_filename); - - categorical_argument *adapt = dynamic_cast( - parser.arg("method")->arg("sample")->arg("adapt")); + context_vector metric_contexts + = get_vec_var_context(metric_filename, num_chains); + categorical_argument *adapt + = dynamic_cast(sample_arg->arg("adapt")); categorical_argument *hmc = dynamic_cast(algo->arg("hmc")); double stepsize = dynamic_cast(hmc->arg("stepsize"))->value(); double stepsize_jitter = dynamic_cast(hmc->arg("stepsize_jitter"))->value(); - if (adapt_engaged == true && num_warmup == 0) { info( "The number of warmup samples (num_warmup) must be greater than " @@ -466,10 +734,10 @@ int command(int argc, const char *argv[]) { ->arg("max_depth")) ->value(); return_code = stan::services::sample::hmc_nuts_dense_e( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, max_depth, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, max_depth, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "nuts" && metric->value() == "dense_e" && adapt_engaged == false && metric_supplied == true) { int max_depth = dynamic_cast( @@ -478,10 +746,10 @@ int command(int argc, const char *argv[]) { ->arg("max_depth")) ->value(); return_code = stan::services::sample::hmc_nuts_dense_e( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, max_depth, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), *(metric_contexts[0]), random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, max_depth, interrupt, logger, + init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "nuts" && metric->value() == "dense_e" && adapt_engaged == true && metric_supplied == false) { int max_depth = dynamic_cast( @@ -505,11 +773,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_nuts_dense_e_adapt( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, + model, num_chains, init_contexts, random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + term_buffer, window, interrupt, logger, init_writers, + sample_writers, diagnostic_writers); } else if (engine->value() == "nuts" && metric->value() == "dense_e" && adapt_engaged == true && metric_supplied == true) { int max_depth = dynamic_cast( @@ -533,11 +801,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_nuts_dense_e_adapt( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + model, num_chains, init_contexts, metric_contexts, random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, + t0, init_buffer, term_buffer, window, interrupt, logger, + init_writers, sample_writers, diagnostic_writers); } else if (engine->value() == "nuts" && metric->value() == "diag_e" && adapt_engaged == false && metric_supplied == false) { categorical_argument *base = dynamic_cast( @@ -545,10 +813,10 @@ int command(int argc, const char *argv[]) { int max_depth = dynamic_cast(base->arg("max_depth"))->value(); return_code = stan::services::sample::hmc_nuts_diag_e( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, max_depth, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, max_depth, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "nuts" && metric->value() == "diag_e" && adapt_engaged == false && metric_supplied == true) { categorical_argument *base = dynamic_cast( @@ -556,10 +824,10 @@ int command(int argc, const char *argv[]) { int max_depth = dynamic_cast(base->arg("max_depth"))->value(); return_code = stan::services::sample::hmc_nuts_diag_e( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, max_depth, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), *(metric_contexts[0]), random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, max_depth, interrupt, logger, + init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "nuts" && metric->value() == "diag_e" && adapt_engaged == true && metric_supplied == false) { categorical_argument *base = dynamic_cast( @@ -582,11 +850,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_nuts_diag_e_adapt( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, + model, num_chains, init_contexts, random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + term_buffer, window, interrupt, logger, init_writers, + sample_writers, diagnostic_writers); } else if (engine->value() == "nuts" && metric->value() == "diag_e" && adapt_engaged == true && metric_supplied == true) { categorical_argument *base = dynamic_cast( @@ -609,11 +877,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_nuts_diag_e_adapt( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + model, num_chains, init_contexts, metric_contexts, random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, + t0, init_buffer, term_buffer, window, interrupt, logger, + init_writers, sample_writers, diagnostic_writers); } else if (engine->value() == "nuts" && metric->value() == "unit_e" && adapt_engaged == false) { categorical_argument *base = dynamic_cast( @@ -621,10 +889,10 @@ int command(int argc, const char *argv[]) { int max_depth = dynamic_cast(base->arg("max_depth"))->value(); return_code = stan::services::sample::hmc_nuts_unit_e( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, max_depth, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, max_depth, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "nuts" && metric->value() == "unit_e" && adapt_engaged == true) { categorical_argument *base = dynamic_cast( @@ -639,10 +907,10 @@ int command(int argc, const char *argv[]) { = dynamic_cast(adapt->arg("kappa"))->value(); double t0 = dynamic_cast(adapt->arg("t0"))->value(); return_code = stan::services::sample::hmc_nuts_unit_e_adapt( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, interrupt, - logger, init_writer, sample_writer, diagnostic_writer); + logger, init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "dense_e" && adapt_engaged == false && metric_supplied == false) { categorical_argument *base = dynamic_cast( @@ -650,10 +918,10 @@ int command(int argc, const char *argv[]) { double int_time = dynamic_cast(base->arg("int_time"))->value(); return_code = stan::services::sample::hmc_static_dense_e( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, int_time, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, int_time, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "dense_e" && adapt_engaged == false && metric_supplied == true) { categorical_argument *base = dynamic_cast( @@ -661,10 +929,10 @@ int command(int argc, const char *argv[]) { double int_time = dynamic_cast(base->arg("int_time"))->value(); return_code = stan::services::sample::hmc_static_dense_e( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, int_time, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), *(metric_contexts[0]), random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, int_time, interrupt, logger, + init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "dense_e" && adapt_engaged == true && metric_supplied == false) { categorical_argument *base = dynamic_cast( @@ -687,11 +955,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_static_dense_e_adapt( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, int_time, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + term_buffer, window, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "dense_e" && adapt_engaged == true && metric_supplied == true) { categorical_argument *base = dynamic_cast( @@ -714,11 +982,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_static_dense_e_adapt( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, int_time, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + model, *(init_contexts[0]), *(metric_contexts[0]), random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, int_time, delta, gamma, kappa, + t0, init_buffer, term_buffer, window, interrupt, logger, + init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "diag_e" && adapt_engaged == false && metric_supplied == false) { categorical_argument *base = dynamic_cast( @@ -726,10 +994,10 @@ int command(int argc, const char *argv[]) { double int_time = dynamic_cast(base->arg("int_time"))->value(); return_code = stan::services::sample::hmc_static_diag_e( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, int_time, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, int_time, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "diag_e" && adapt_engaged == false && metric_supplied == true) { categorical_argument *base = dynamic_cast( @@ -737,10 +1005,10 @@ int command(int argc, const char *argv[]) { double int_time = dynamic_cast(base->arg("int_time"))->value(); return_code = stan::services::sample::hmc_static_diag_e( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, int_time, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), *(metric_contexts[0]), random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, int_time, interrupt, logger, + init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "diag_e" && adapt_engaged == true && metric_supplied == false) { categorical_argument *base = dynamic_cast( @@ -763,11 +1031,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_static_diag_e_adapt( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, int_time, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + term_buffer, window, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "diag_e" && adapt_engaged == true && metric_supplied == true) { categorical_argument *base = dynamic_cast( @@ -790,11 +1058,11 @@ int command(int argc, const char *argv[]) { unsigned int window = dynamic_cast(adapt->arg("window"))->value(); return_code = stan::services::sample::hmc_static_diag_e_adapt( - model, *init_context, *metric_context, random_seed, id, init_radius, - num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, int_time, delta, gamma, kappa, t0, init_buffer, - term_buffer, window, interrupt, logger, init_writer, sample_writer, - diagnostic_writer); + model, *(init_contexts[0]), *(metric_contexts[0]), random_seed, id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, + refresh, stepsize, stepsize_jitter, int_time, delta, gamma, kappa, + t0, init_buffer, term_buffer, window, interrupt, logger, + init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "unit_e" && adapt_engaged == false) { categorical_argument *base = dynamic_cast( @@ -802,10 +1070,10 @@ int command(int argc, const char *argv[]) { double int_time = dynamic_cast(base->arg("int_time"))->value(); return_code = stan::services::sample::hmc_static_unit_e( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, - stepsize_jitter, int_time, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, int_time, interrupt, logger, init_writers[0], + sample_writers[0], diagnostic_writers[0]); } else if (engine->value() == "static" && metric->value() == "unit_e" && adapt_engaged == true) { categorical_argument *base = dynamic_cast( @@ -820,13 +1088,13 @@ int command(int argc, const char *argv[]) { = dynamic_cast(adapt->arg("kappa"))->value(); double t0 = dynamic_cast(adapt->arg("t0"))->value(); return_code = stan::services::sample::hmc_static_unit_e_adapt( - model, *init_context, random_seed, id, init_radius, num_warmup, - num_samples, num_thin, save_warmup, refresh, stepsize, + model, *(init_contexts[0]), random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, int_time, delta, gamma, kappa, t0, interrupt, - logger, init_writer, sample_writer, diagnostic_writer); + logger, init_writers[0], sample_writers[0], diagnostic_writers[0]); } } - } else if (parser.arg("method")->arg("variational")) { + } else if (user_method->arg("variational")) { list_argument *algo = dynamic_cast( parser.arg("method")->arg("variational")->arg("algorithm")); int grad_samples @@ -869,16 +1137,16 @@ int command(int argc, const char *argv[]) { if (algo->value() == "fullrank") { return_code = stan::services::experimental::advi::fullrank( - model, *init_context, random_seed, id, init_radius, grad_samples, - elbo_samples, max_iterations, tol_rel_obj, eta, adapt_engaged, - adapt_iterations, eval_elbo, output_samples, interrupt, logger, - init_writer, sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + grad_samples, elbo_samples, max_iterations, tol_rel_obj, eta, + adapt_engaged, adapt_iterations, eval_elbo, output_samples, interrupt, + logger, init_writers[0], sample_writers[0], diagnostic_writers[0]); } else if (algo->value() == "meanfield") { return_code = stan::services::experimental::advi::meanfield( - model, *init_context, random_seed, id, init_radius, grad_samples, - elbo_samples, max_iterations, tol_rel_obj, eta, adapt_engaged, - adapt_iterations, eval_elbo, output_samples, interrupt, logger, - init_writer, sample_writer, diagnostic_writer); + model, *(init_contexts[0]), random_seed, id, init_radius, + grad_samples, elbo_samples, max_iterations, tol_rel_obj, eta, + adapt_engaged, adapt_iterations, eval_elbo, output_samples, interrupt, + logger, init_writers[0], sample_writers[0], diagnostic_writers[0]); } } stan::math::profile_map &profile_data = get_stan_profile_data(); @@ -894,10 +1162,9 @@ int command(int argc, const char *argv[]) { write_profiling(profile_stream, profile_data); profile_stream.close(); } - output_stream.close(); - diagnostic_stream.close(); - for (size_t i = 0; i < valid_arguments.size(); ++i) + for (size_t i = 0; i < valid_arguments.size(); ++i) { delete valid_arguments.at(i); + } #ifdef STAN_MPI cluster.stop_listen(); #endif diff --git a/src/cmdstan/write_chain.hpp b/src/cmdstan/write_chain.hpp new file mode 100644 index 0000000000..bf2756d5da --- /dev/null +++ b/src/cmdstan/write_chain.hpp @@ -0,0 +1,16 @@ +#ifndef CMDSTAN_WRITE_CHAIN_HPP +#define CMDSTAN_WRITE_CHAIN_HPP + +#include +#include +#include + +namespace cmdstan { + +inline void write_chain(stan::callbacks::writer& writer, + unsigned int chain_id) { + writer("chain_id = " + std::to_string(chain_id)); +} + +} // namespace cmdstan +#endif diff --git a/src/cmdstan/write_parallel_info.hpp b/src/cmdstan/write_parallel_info.hpp index feb595839a..929d78c1bd 100644 --- a/src/cmdstan/write_parallel_info.hpp +++ b/src/cmdstan/write_parallel_info.hpp @@ -10,13 +10,6 @@ namespace cmdstan { void write_parallel_info(stan::callbacks::writer &writer) { #ifdef STAN_MPI writer("mpi_enabled = 1"); -#else -#ifdef STAN_THREADS - std::stringstream msg_threads; - msg_threads << "num_threads = "; - msg_threads << stan::math::internal::get_num_threads(); - writer(msg_threads.str()); -#endif #endif } diff --git a/src/test/interface/multi_chain_test.cpp b/src/test/interface/multi_chain_test.cpp new file mode 100644 index 0000000000..8564cf4e8e --- /dev/null +++ b/src/test/interface/multi_chain_test.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include +#include + +TEST(interface, output_multi) { + std::vector model_path; + model_path.push_back("src"); + model_path.push_back("test"); + model_path.push_back("test-models"); + model_path.push_back("test_model"); + + std::string command + = cmdstan::test::convert_model_path(model_path) + + " id=10 sample num_warmup=200 num_samples=1 num_chains=2 random seed=1234" + + " output file=" + cmdstan::test::convert_model_path(model_path) + + ".csv diagnostic_file=" + cmdstan::test::convert_model_path(model_path) + "_diag.csv"; + + cmdstan::test::run_command_output out = cmdstan::test::run_command(command); + EXPECT_EQ(int(stan::services::error_codes::OK), out.err_code); + EXPECT_FALSE(out.hasError); + { + std::string csv_file + = cmdstan::test::convert_model_path(model_path) + "_10.csv"; + std::vector filenames; + filenames.push_back(csv_file); + stan::io::stan_csv_metadata metadata; + Eigen::VectorXd warmup_times(filenames.size()); + Eigen::VectorXd sampling_times(filenames.size()); + Eigen::VectorXi thin(filenames.size()); + stan::mcmc::chains<> chains = parse_csv_files( + filenames, metadata, warmup_times, sampling_times, thin, &std::cout); + constexpr std::array names{ + "lp__", "accept_stat__", "stepsize__", + "treedepth__", "n_leapfrog__", "divergent__", + "energy__", "mu1", "mu2", + }; + const auto chain_param_names = chains.param_names(); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(names[i], chain_param_names[i]); + } + std::string diag_name + = cmdstan::test::convert_model_path(model_path) + "_diag_10.csv"; + std::ifstream diag_file(diag_name); + EXPECT_TRUE(diag_file.good()); + } + { + std::string csv_file + = cmdstan::test::convert_model_path(model_path) + "_11.csv"; + std::vector filenames; + filenames.push_back(csv_file); + stan::io::stan_csv_metadata metadata; + Eigen::VectorXd warmup_times(filenames.size()); + Eigen::VectorXd sampling_times(filenames.size()); + Eigen::VectorXi thin(filenames.size()); + stan::mcmc::chains<> chains = parse_csv_files( + filenames, metadata, warmup_times, sampling_times, thin, &std::cout); + constexpr std::array names{ + "lp__", "accept_stat__", "stepsize__", + "treedepth__", "n_leapfrog__", "divergent__", + "energy__", "mu1", "mu2", + }; + const auto chain_param_names = chains.param_names(); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(names[i], chain_param_names[i]); + } + std::string diag_name + = cmdstan::test::convert_model_path(model_path) + "_diag_11.csv"; + std::ifstream diag_file(diag_name); + EXPECT_TRUE(diag_file.good()); + } +}