diff --git a/src/stan/callbacks/stream_writer.hpp b/src/stan/callbacks/stream_writer.hpp index 6519531c0ac..35a06a3e64b 100644 --- a/src/stan/callbacks/stream_writer.hpp +++ b/src/stan/callbacks/stream_writer.hpp @@ -24,8 +24,9 @@ class stream_writer : public writer { * each comment line. Default is "". */ explicit stream_writer(std::ostream& output, - const std::string& comment_prefix = "") - : output_(output), comment_prefix_(comment_prefix) {} + const std::string& comment_prefix = "", + bool is_empty = false) + : output_(output), comment_prefix_(comment_prefix), empty_(is_empty) {} /** * Virtual destructor @@ -57,7 +58,11 @@ class stream_writer : public writer { /** * Writes the comment_prefix to the stream followed by a newline. */ - void operator()() { output_ << comment_prefix_ << std::endl; } + void operator()() { + if (!empty_) { + output_ << comment_prefix_ << std::endl; + } + } /** * Writes the comment_prefix then the message followed by a newline. @@ -65,9 +70,16 @@ class stream_writer : public writer { * @param[in] message A string */ void operator()(const std::string& message) { - output_ << comment_prefix_ << message << std::endl; + if (!empty_) { + output_ << comment_prefix_ << message << std::endl; + } } + /** + * Check if the writer is writing to an empty stream + */ + inline bool is_empty() const noexcept final { return empty_; } + private: /** * Output stream @@ -79,6 +91,10 @@ class stream_writer : public writer { */ std::string comment_prefix_; + /** + * Used as check for whether output stream needs to be written to. + */ + bool empty_{false}; /** * Writes a set of values in csv format followed by a newline. * @@ -89,16 +105,18 @@ class stream_writer : public writer { */ template void write_vector(const std::vector& v) { - if (v.empty()) - return; + if (!empty_) { + if (v.empty()) + return; - typename std::vector::const_iterator last = v.end(); - --last; + typename std::vector::const_iterator last = v.end(); + --last; - for (typename std::vector::const_iterator it = v.begin(); it != last; - ++it) - output_ << *it << ","; - output_ << v.back() << std::endl; + for (typename std::vector::const_iterator it = v.begin(); it != last; + ++it) + output_ << *it << ","; + output_ << v.back() << std::endl; + } } }; diff --git a/src/stan/callbacks/tee_writer.hpp b/src/stan/callbacks/tee_writer.hpp index c491832e9ec..d0acd77baf4 100644 --- a/src/stan/callbacks/tee_writer.hpp +++ b/src/stan/callbacks/tee_writer.hpp @@ -49,6 +49,13 @@ class tee_writer final : public writer { writer2_(message); } + /** + * Check if both writers are writing to an empty stream + */ + inline bool is_empty() const noexcept { + return writer1_.is_empty() && writer2_.is_empty(); + } + private: /** * The first writer diff --git a/src/stan/callbacks/unique_stream_writer.hpp b/src/stan/callbacks/unique_stream_writer.hpp index 60e1c4e3843..79fee41fcdc 100644 --- a/src/stan/callbacks/unique_stream_writer.hpp +++ b/src/stan/callbacks/unique_stream_writer.hpp @@ -27,14 +27,18 @@ class unique_stream_writer final : public writer { * Default is "". */ explicit unique_stream_writer(std::unique_ptr&& output, - const std::string& comment_prefix = "") - : output_(std::move(output)), comment_prefix_(comment_prefix) {} + const std::string& comment_prefix = "", + bool is_empty = false) + : output_(std::move(output)), + comment_prefix_(comment_prefix), + empty_(is_empty) {} unique_stream_writer(); unique_stream_writer(unique_stream_writer& other) = delete; unique_stream_writer(unique_stream_writer&& other) : output_(std::move(other.output_)), - comment_prefix_(std::move(other.comment_prefix_)) {} + comment_prefix_(std::move(other.comment_prefix_)), + empty_(other.empty_) {} /** * Virtual destructor */ @@ -54,7 +58,7 @@ class unique_stream_writer final : public writer { /** * Get the underlying stream */ - auto& get_stream() { return *output_; } + inline auto& get_stream() noexcept { return *output_; } /** * Writes a set of values in csv format followed by a newline. @@ -70,10 +74,9 @@ class unique_stream_writer final : public writer { * Writes the comment_prefix to the stream followed by a newline. */ void operator()() { - std::stringstream streamer; - streamer.precision(output_.get()->precision()); - streamer << comment_prefix_ << std::endl; - *output_ << streamer.str(); + if (!empty_) { + *output_ << comment_prefix_ << std::endl; + } } /** @@ -82,12 +85,16 @@ class unique_stream_writer final : public writer { * @param[in] message A string */ void operator()(const std::string& message) { - std::stringstream streamer; - streamer.precision(output_.get()->precision()); - streamer << comment_prefix_ << message << std::endl; - *output_ << streamer.str(); + if (!empty_) { + *output_ << comment_prefix_ << message << std::endl; + } } + /** + * Check if the writer is writing to an empty stream + */ + inline bool is_empty() const noexcept { return empty_; } + private: /** * Output stream @@ -99,6 +106,10 @@ class unique_stream_writer final : public writer { */ std::string comment_prefix_; + /** + * Used as check for whether output stream needs to be written to. + */ + bool empty_{false}; /** * Writes a set of values in csv format followed by a newline. * @@ -109,18 +120,17 @@ class unique_stream_writer final : public writer { */ template void write_vector(const std::vector& v) { - if (v.empty()) - return; - using const_iter = typename std::vector::const_iterator; - const_iter last = v.end(); - --last; - std::stringstream streamer; - streamer.precision(output_.get()->precision()); - for (const_iter it = v.begin(); it != last; ++it) { - streamer << *it << ","; + if (!empty_) { + if (v.empty()) { + return; + } + auto last = v.end(); + --last; + for (auto it = v.begin(); it != last; ++it) { + *output_ << *it << ","; + } + *output_ << v.back() << std::endl; } - streamer << v.back() << std::endl; - *output_ << streamer.str(); } }; diff --git a/src/stan/callbacks/writer.hpp b/src/stan/callbacks/writer.hpp index a12108b256f..4e0df8878f2 100644 --- a/src/stan/callbacks/writer.hpp +++ b/src/stan/callbacks/writer.hpp @@ -45,6 +45,11 @@ class writer { * @param[in] message A string */ virtual void operator()(const std::string& message) {} + + /** + * Check if the writer is writing to an empty stream + */ + virtual bool is_empty() const noexcept { return false; } }; } // namespace callbacks diff --git a/src/stan/services/util/mcmc_writer.hpp b/src/stan/services/util/mcmc_writer.hpp index 22b0b4fbb94..77e7e9d2b4c 100644 --- a/src/stan/services/util/mcmc_writer.hpp +++ b/src/stan/services/util/mcmc_writer.hpp @@ -149,17 +149,19 @@ class mcmc_writer { template void write_diagnostic_names(stan::mcmc::sample sample, stan::mcmc::base_mcmc& sampler, Model& model) { - std::vector names; + if (!diagnostic_writer_.is_empty()) { + std::vector names; - sample.get_sample_param_names(names); - sampler.get_sampler_param_names(names); + sample.get_sample_param_names(names); + sampler.get_sampler_param_names(names); - std::vector model_names; - model.unconstrained_param_names(model_names, false, false); + std::vector model_names; + model.unconstrained_param_names(model_names, false, false); - sampler.get_sampler_diagnostic_names(model_names, names); + sampler.get_sampler_diagnostic_names(model_names, names); - diagnostic_writer_(names); + diagnostic_writer_(names); + } } /** @@ -170,13 +172,15 @@ class mcmc_writer { */ void write_diagnostic_params(stan::mcmc::sample& sample, stan::mcmc::base_mcmc& sampler) { - std::vector values; + if (!diagnostic_writer_.is_empty()) { + std::vector values; - sample.get_sample_params(values); - sampler.get_sampler_params(values); - sampler.get_sampler_diagnostics(values); + sample.get_sample_params(values); + sampler.get_sampler_params(values); + sampler.get_sampler_diagnostics(values); - diagnostic_writer_(values); + diagnostic_writer_(values); + } } /** @@ -247,7 +251,9 @@ class mcmc_writer { */ void write_timing(double warmDeltaT, double sampleDeltaT) { write_timing(warmDeltaT, sampleDeltaT, sample_writer_); - write_timing(warmDeltaT, sampleDeltaT, diagnostic_writer_); + if (!diagnostic_writer_.is_empty()) { + write_timing(warmDeltaT, sampleDeltaT, diagnostic_writer_); + } log_timing(warmDeltaT, sampleDeltaT); } }; diff --git a/src/stan/variational/advi.hpp b/src/stan/variational/advi.hpp index 4dca19d0335..35fdc19be67 100644 --- a/src/stan/variational/advi.hpp +++ b/src/stan/variational/advi.hpp @@ -392,12 +392,14 @@ class advi { = std::chrono::duration_cast(end - start) .count() / 1000.0; - std::vector print_vector; - print_vector.clear(); - print_vector.push_back(iter_counter); - print_vector.push_back(delta_t); - print_vector.push_back(elbo); - diagnostic_writer(print_vector); + if (!diagnostic_writer.is_empty()) { + std::vector print_vector; + print_vector.clear(); + print_vector.push_back(iter_counter); + print_vector.push_back(delta_t); + print_vector.push_back(elbo); + diagnostic_writer(print_vector); + } if (delta_elbo_ave < tol_rel_obj) { ss << " MEAN ELBO CONVERGED"; @@ -459,8 +461,9 @@ class advi { double tol_rel_obj, int max_iterations, callbacks::logger& logger, callbacks::writer& parameter_writer, callbacks::writer& diagnostic_writer) const { - diagnostic_writer("iter,time_in_seconds,ELBO"); - + if (!diagnostic_writer.is_empty()) { + diagnostic_writer("iter,time_in_seconds,ELBO"); + } // Initialize variational approximation Q variational = Q(cont_params_); diff --git a/src/test/unit/callbacks/stream_writer_test.cpp b/src/test/unit/callbacks/stream_writer_test.cpp index e0bc01bb9be..18f7839362f 100644 --- a/src/test/unit/callbacks/stream_writer_test.cpp +++ b/src/test/unit/callbacks/stream_writer_test.cpp @@ -5,7 +5,10 @@ class StanInterfaceCallbacksStreamWriter : public ::testing::Test { public: StanInterfaceCallbacksStreamWriter() - : ss(), writer(ss), writer_prefix(ss, "# ") {} + : ss(), + writer(ss), + writer_prefix(ss, "# "), + empty_writer(ss, "# ", true) {} void SetUp() { ss.str(std::string()); @@ -16,6 +19,7 @@ class StanInterfaceCallbacksStreamWriter : public ::testing::Test { std::stringstream ss; stan::callbacks::stream_writer writer; stan::callbacks::stream_writer writer_prefix; + stan::callbacks::stream_writer empty_writer; }; TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) { @@ -28,6 +32,16 @@ TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) { EXPECT_EQ("0,1,2,3,4\n", ss.str()); } +TEST_F(StanInterfaceCallbacksStreamWriter, empty_vector) { + const int N = 5; + std::vector x; + for (int n = 0; n < N; ++n) + x.push_back(n); + + EXPECT_NO_THROW(empty_writer(x)); + EXPECT_EQ("", ss.str()); +} + TEST_F(StanInterfaceCallbacksStreamWriter, string_vector) { const int N = 5; std::vector x; diff --git a/src/test/unit/callbacks/tee_writer_test.cpp b/src/test/unit/callbacks/tee_writer_test.cpp index 0dc4dbd4e2f..276f2c618ad 100644 --- a/src/test/unit/callbacks/tee_writer_test.cpp +++ b/src/test/unit/callbacks/tee_writer_test.cpp @@ -5,26 +5,53 @@ namespace test { class mock_writer : public stan::callbacks::writer { public: int N; + bool empty_; mock_writer() : N(0) {} + mock_writer(bool is_empty) : N(0), empty_(is_empty) {} - void operator()(const std::vector& names) { ++N; } + void operator()(const std::vector& names) { + if (!empty_) { + ++N; + } + } - void operator()(const std::vector& state) { ++N; } + void operator()(const std::vector& state) { + if (!empty_) { + ++N; + } + } - void operator()() { ++N; } + void operator()() { + if (!empty_) { + ++N; + } + } - void operator()(const std::string& message) { ++N; } + void operator()(const std::string& message) { + if (!empty_) { + ++N; + } + } + + inline bool is_empty() const noexcept { return false; } }; } // namespace test class StanCallbacksTeeWriter : public ::testing::Test { public: StanCallbacksTeeWriter() - : writer1(), writer2(), tee_writer(writer1, writer2) {} + : writer1(), + writer2(), + tee_writer(writer1, writer2), + empty_writer1(true), + empty_writer2(true), + empty_tee_writer(empty_writer1, empty_writer2) {} test::mock_writer writer1, writer2; stan::callbacks::tee_writer tee_writer; + test::mock_writer empty_writer1, empty_writer2; + stan::callbacks::tee_writer empty_tee_writer; }; TEST_F(StanCallbacksTeeWriter, names) { @@ -35,6 +62,14 @@ TEST_F(StanCallbacksTeeWriter, names) { EXPECT_EQ(1, writer2.N); } +TEST_F(StanCallbacksTeeWriter, empty_names) { + std::vector names; + + empty_tee_writer(names); + EXPECT_EQ(0, empty_writer1.N); + EXPECT_EQ(0, empty_writer2.N); +} + TEST_F(StanCallbacksTeeWriter, state) { std::vector state; diff --git a/src/test/unit/callbacks/unique_stream_writer_test.cpp b/src/test/unit/callbacks/unique_stream_writer_test.cpp index 43d783f48c4..7505e63e3e6 100644 --- a/src/test/unit/callbacks/unique_stream_writer_test.cpp +++ b/src/test/unit/callbacks/unique_stream_writer_test.cpp @@ -5,15 +5,21 @@ class StanInterfaceCallbacksStreamWriter : public ::testing::Test { public: StanInterfaceCallbacksStreamWriter() - : writer(std::make_unique(std::stringstream{})) {} + : writer(std::make_unique(std::stringstream{})), + empty_writer(std::make_unique(std::stringstream{}), + "#", true) {} void SetUp() { static_cast(writer.get_stream()).str(std::string()); static_cast(writer.get_stream()).clear(); + static_cast(empty_writer.get_stream()) + .str(std::string()); + static_cast(empty_writer.get_stream()).clear(); } void TearDown() {} stan::callbacks::unique_stream_writer writer; + stan::callbacks::unique_stream_writer empty_writer; }; TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) { @@ -27,6 +33,17 @@ TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) { static_cast(writer.get_stream()).str()); } +TEST_F(StanInterfaceCallbacksStreamWriter, empty_test) { + const int N = 5; + std::vector x; + for (int n = 0; n < N; ++n) + x.push_back(n); + + EXPECT_NO_THROW(empty_writer(x)); + EXPECT_EQ("", + static_cast(empty_writer.get_stream()).str()); +} + TEST_F(StanInterfaceCallbacksStreamWriter, double_vector_precision2) { const int N = 5; std::vector x{1.23456789, 2.3456789, 3.45678910, 4.567890123}; diff --git a/src/test/unit/services/experimental/advi/meanfield_test.cpp b/src/test/unit/services/experimental/advi/meanfield_test.cpp index 118b497922f..cf820561d32 100644 --- a/src/test/unit/services/experimental/advi/meanfield_test.cpp +++ b/src/test/unit/services/experimental/advi/meanfield_test.cpp @@ -1,8 +1,8 @@ #include -#include #include #include #include +#include class ServicesExperimentalAdvi : public testing::Test { public: diff --git a/src/test/unit/services/instrumented_callbacks.hpp b/src/test/unit/services/instrumented_callbacks.hpp index 797959ef0ee..61d7ced0e10 100644 --- a/src/test/unit/services/instrumented_callbacks.hpp +++ b/src/test/unit/services/instrumented_callbacks.hpp @@ -41,6 +41,8 @@ class instrumented_writer : public stan::callbacks::writer { public: instrumented_writer() {} + inline bool is_empty() const noexcept { return false; } + void operator()(const std::string& key, double value) { counter_["string_double"]++; string_double.push_back(std::make_pair(key, value));