Skip to content
Closed
42 changes: 30 additions & 12 deletions src/stan/callbacks/stream_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class stream_writer : public writer {
* each comment line. Default is "".
*/
explicit stream_writer(std::ostream& output,
const std::string& comment_prefix = "")
: output_(output), comment_prefix_(comment_prefix) {}
const std::string& comment_prefix = "",
bool is_empty = false)
: output_(output), comment_prefix_(comment_prefix), empty_(is_empty) {}

/**
* Virtual destructor
Expand Down Expand Up @@ -57,17 +58,28 @@ class stream_writer : public writer {
/**
* Writes the comment_prefix to the stream followed by a newline.
*/
void operator()() { output_ << comment_prefix_ << std::endl; }
void operator()() {
if (!empty_) {
output_ << comment_prefix_ << std::endl;
}
}

/**
* Writes the comment_prefix then the message followed by a newline.
*
* @param[in] message A string
*/
void operator()(const std::string& message) {
output_ << comment_prefix_ << message << std::endl;
if (!empty_) {
output_ << comment_prefix_ << message << std::endl;
}
}

/**
* Check if the writer is writing to an empty stream
*/
inline bool is_empty() const noexcept final { return empty_; }

private:
/**
* Output stream
Expand All @@ -79,6 +91,10 @@ class stream_writer : public writer {
*/
std::string comment_prefix_;

/**
* Used as check for whether output stream needs to be written to.
*/
bool empty_{false};
/**
* Writes a set of values in csv format followed by a newline.
*
Expand All @@ -89,16 +105,18 @@ class stream_writer : public writer {
*/
template <class T>
void write_vector(const std::vector<T>& v) {
if (v.empty())
return;
if (!empty_) {
if (v.empty())
return;

typename std::vector<T>::const_iterator last = v.end();
--last;
typename std::vector<T>::const_iterator last = v.end();
--last;

for (typename std::vector<T>::const_iterator it = v.begin(); it != last;
++it)
output_ << *it << ",";
output_ << v.back() << std::endl;
for (typename std::vector<T>::const_iterator it = v.begin(); it != last;
++it)
output_ << *it << ",";
output_ << v.back() << std::endl;
}
}
};

Expand Down
7 changes: 7 additions & 0 deletions src/stan/callbacks/tee_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class tee_writer final : public writer {
writer2_(message);
}

/**
* Check if both writers are writing to an empty stream
*/
inline bool is_empty() const noexcept {
return writer1_.is_empty() && writer2_.is_empty();
}

private:
/**
* The first writer
Expand Down
56 changes: 33 additions & 23 deletions src/stan/callbacks/unique_stream_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ class unique_stream_writer final : public writer {
* Default is "".
*/
explicit unique_stream_writer(std::unique_ptr<Stream>&& output,
const std::string& comment_prefix = "")
: output_(std::move(output)), comment_prefix_(comment_prefix) {}
const std::string& comment_prefix = "",
bool is_empty = false)
: output_(std::move(output)),
comment_prefix_(comment_prefix),
empty_(is_empty) {}

unique_stream_writer();
unique_stream_writer(unique_stream_writer& other) = delete;
unique_stream_writer(unique_stream_writer&& other)
: output_(std::move(other.output_)),
comment_prefix_(std::move(other.comment_prefix_)) {}
comment_prefix_(std::move(other.comment_prefix_)),
empty_(other.empty_) {}
/**
* Virtual destructor
*/
Expand All @@ -54,7 +58,7 @@ class unique_stream_writer final : public writer {
/**
* Get the underlying stream
*/
auto& get_stream() { return *output_; }
inline auto& get_stream() noexcept { return *output_; }

/**
* Writes a set of values in csv format followed by a newline.
Expand All @@ -70,10 +74,9 @@ class unique_stream_writer final : public writer {
* Writes the comment_prefix to the stream followed by a newline.
*/
void operator()() {
std::stringstream streamer;
streamer.precision(output_.get()->precision());
streamer << comment_prefix_ << std::endl;
*output_ << streamer.str();
if (!empty_) {
*output_ << comment_prefix_ << std::endl;
}
}

/**
Expand All @@ -82,12 +85,16 @@ class unique_stream_writer final : public writer {
* @param[in] message A string
*/
void operator()(const std::string& message) {
std::stringstream streamer;
streamer.precision(output_.get()->precision());
streamer << comment_prefix_ << message << std::endl;
*output_ << streamer.str();
if (!empty_) {
*output_ << comment_prefix_ << message << std::endl;
}
}

/**
* Check if the writer is writing to an empty stream
*/
inline bool is_empty() const noexcept { return empty_; }

private:
/**
* Output stream
Expand All @@ -99,6 +106,10 @@ class unique_stream_writer final : public writer {
*/
std::string comment_prefix_;

/**
* Used as check for whether output stream needs to be written to.
*/
bool empty_{false};
/**
* Writes a set of values in csv format followed by a newline.
*
Expand All @@ -109,18 +120,17 @@ class unique_stream_writer final : public writer {
*/
template <class T>
void write_vector(const std::vector<T>& v) {
if (v.empty())
return;
using const_iter = typename std::vector<T>::const_iterator;
const_iter last = v.end();
--last;
std::stringstream streamer;
streamer.precision(output_.get()->precision());
for (const_iter it = v.begin(); it != last; ++it) {
streamer << *it << ",";
if (!empty_) {
if (v.empty()) {
return;
}
auto last = v.end();
--last;
for (auto it = v.begin(); it != last; ++it) {
*output_ << *it << ",";
}
*output_ << v.back() << std::endl;
}
streamer << v.back() << std::endl;
*output_ << streamer.str();
}
};

Expand Down
5 changes: 5 additions & 0 deletions src/stan/callbacks/writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class writer {
* @param[in] message A string
*/
virtual void operator()(const std::string& message) {}

/**
* Check if the writer is writing to an empty stream
*/
virtual bool is_empty() const noexcept { return false; }
};

} // namespace callbacks
Expand Down
32 changes: 19 additions & 13 deletions src/stan/services/util/mcmc_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,19 @@ class mcmc_writer {
template <class Model>
void write_diagnostic_names(stan::mcmc::sample sample,
stan::mcmc::base_mcmc& sampler, Model& model) {
std::vector<std::string> names;
if (!diagnostic_writer_.is_empty()) {
std::vector<std::string> names;

sample.get_sample_param_names(names);
sampler.get_sampler_param_names(names);
sample.get_sample_param_names(names);
sampler.get_sampler_param_names(names);

std::vector<std::string> model_names;
model.unconstrained_param_names(model_names, false, false);
std::vector<std::string> model_names;
model.unconstrained_param_names(model_names, false, false);

sampler.get_sampler_diagnostic_names(model_names, names);
sampler.get_sampler_diagnostic_names(model_names, names);

diagnostic_writer_(names);
diagnostic_writer_(names);
}
}

/**
Expand All @@ -170,13 +172,15 @@ class mcmc_writer {
*/
void write_diagnostic_params(stan::mcmc::sample& sample,
stan::mcmc::base_mcmc& sampler) {
std::vector<double> values;
if (!diagnostic_writer_.is_empty()) {
std::vector<double> values;

sample.get_sample_params(values);
sampler.get_sampler_params(values);
sampler.get_sampler_diagnostics(values);
sample.get_sample_params(values);
sampler.get_sampler_params(values);
sampler.get_sampler_diagnostics(values);

diagnostic_writer_(values);
diagnostic_writer_(values);
}
}

/**
Expand Down Expand Up @@ -247,7 +251,9 @@ class mcmc_writer {
*/
void write_timing(double warmDeltaT, double sampleDeltaT) {
write_timing(warmDeltaT, sampleDeltaT, sample_writer_);
write_timing(warmDeltaT, sampleDeltaT, diagnostic_writer_);
if (!diagnostic_writer_.is_empty()) {
write_timing(warmDeltaT, sampleDeltaT, diagnostic_writer_);
}
log_timing(warmDeltaT, sampleDeltaT);
}
};
Expand Down
19 changes: 11 additions & 8 deletions src/stan/variational/advi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,14 @@ class advi {
= std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
.count()
/ 1000.0;
std::vector<double> print_vector;
print_vector.clear();
print_vector.push_back(iter_counter);
print_vector.push_back(delta_t);
print_vector.push_back(elbo);
diagnostic_writer(print_vector);
if (!diagnostic_writer.is_empty()) {
std::vector<double> print_vector;
print_vector.clear();
print_vector.push_back(iter_counter);
print_vector.push_back(delta_t);
print_vector.push_back(elbo);
diagnostic_writer(print_vector);
}

if (delta_elbo_ave < tol_rel_obj) {
ss << " MEAN ELBO CONVERGED";
Expand Down Expand Up @@ -459,8 +461,9 @@ class advi {
double tol_rel_obj, int max_iterations, callbacks::logger& logger,
callbacks::writer& parameter_writer,
callbacks::writer& diagnostic_writer) const {
diagnostic_writer("iter,time_in_seconds,ELBO");

if (!diagnostic_writer.is_empty()) {
diagnostic_writer("iter,time_in_seconds,ELBO");
}
// Initialize variational approximation
Q variational = Q(cont_params_);

Expand Down
16 changes: 15 additions & 1 deletion src/test/unit/callbacks/stream_writer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
class StanInterfaceCallbacksStreamWriter : public ::testing::Test {
public:
StanInterfaceCallbacksStreamWriter()
: ss(), writer(ss), writer_prefix(ss, "# ") {}
: ss(),
writer(ss),
writer_prefix(ss, "# "),
empty_writer(ss, "# ", true) {}

void SetUp() {
ss.str(std::string());
Expand All @@ -16,6 +19,7 @@ class StanInterfaceCallbacksStreamWriter : public ::testing::Test {
std::stringstream ss;
stan::callbacks::stream_writer writer;
stan::callbacks::stream_writer writer_prefix;
stan::callbacks::stream_writer empty_writer;
};

TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) {
Expand All @@ -28,6 +32,16 @@ TEST_F(StanInterfaceCallbacksStreamWriter, double_vector) {
EXPECT_EQ("0,1,2,3,4\n", ss.str());
}

TEST_F(StanInterfaceCallbacksStreamWriter, empty_vector) {
const int N = 5;
std::vector<double> x;
for (int n = 0; n < N; ++n)
x.push_back(n);

EXPECT_NO_THROW(empty_writer(x));
EXPECT_EQ("", ss.str());
}

TEST_F(StanInterfaceCallbacksStreamWriter, string_vector) {
const int N = 5;
std::vector<std::string> x;
Expand Down
Loading