diff --git a/examples/bernoulli/.gitignore b/examples/bernoulli/.gitignore index 633613e748..88f8b7cd46 100644 --- a/examples/bernoulli/.gitignore +++ b/examples/bernoulli/.gitignore @@ -1,2 +1,3 @@ bernoulli bernoulli.hpp +output_config.json diff --git a/src/cmdstan/arguments/arg_output.hpp b/src/cmdstan/arguments/arg_output.hpp index ff8b642cac..4ae1becd3c 100644 --- a/src/cmdstan/arguments/arg_output.hpp +++ b/src/cmdstan/arguments/arg_output.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace cmdstan { @@ -21,6 +22,11 @@ class arg_output : public categorical_argument { _subarguments.push_back(new arg_refresh()); _subarguments.push_back(new arg_output_sig_figs()); _subarguments.push_back(new arg_profile_file()); + _subarguments.push_back(new arg_single_bool( + "save_cmdstan_config", + "Save the CmdStan configuration (parsed arguments + default values) as " + "JSON alongside the output files", + false)); } }; diff --git a/src/cmdstan/arguments/argument.hpp b/src/cmdstan/arguments/argument.hpp index f7c4366a4d..c98267011d 100644 --- a/src/cmdstan/arguments/argument.hpp +++ b/src/cmdstan/arguments/argument.hpp @@ -2,6 +2,7 @@ #define CMDSTAN_ARGUMENTS_ARGUMENT_HPP #include +#include #include #include #include @@ -26,6 +27,8 @@ class argument { const std::string &prefix) = 0; + virtual void print(stan::callbacks::structured_writer &j) = 0; + virtual void print_help(stan::callbacks::writer &w, const int depth, const bool recurse) = 0; diff --git a/src/cmdstan/arguments/argument_parser.hpp b/src/cmdstan/arguments/argument_parser.hpp index 07b5ca9805..f6f870158a 100644 --- a/src/cmdstan/arguments/argument_parser.hpp +++ b/src/cmdstan/arguments/argument_parser.hpp @@ -140,6 +140,12 @@ class argument_parser { } } + void print(stan::callbacks::structured_writer &j) { + for (size_t i = 0; i < _arguments.size(); ++i) { + _arguments.at(i)->print(j); + } + } + void print_help(stan::callbacks::writer &w, bool recurse) { for (size_t i = 0; i < _arguments.size(); ++i) { _arguments.at(i)->print_help(w, 1, recurse); diff --git a/src/cmdstan/arguments/categorical_argument.hpp b/src/cmdstan/arguments/categorical_argument.hpp index eb57b9ed2f..98328c215a 100644 --- a/src/cmdstan/arguments/categorical_argument.hpp +++ b/src/cmdstan/arguments/categorical_argument.hpp @@ -28,6 +28,14 @@ class categorical_argument : public argument { (*it)->print(w, depth + 1, prefix); } + void print(stan::callbacks::structured_writer &j) { + j.begin_record(_name); + for (std::vector::iterator it = _subarguments.begin(); + it != _subarguments.end(); ++it) + (*it)->print(j); + j.end_record(); + } + void print_help(stan::callbacks::writer &w, const int depth, const bool recurse) { std::string indent(indent_width * depth, ' '); diff --git a/src/cmdstan/arguments/list_argument.hpp b/src/cmdstan/arguments/list_argument.hpp index 0eccf24f56..b16726cf9f 100644 --- a/src/cmdstan/arguments/list_argument.hpp +++ b/src/cmdstan/arguments/list_argument.hpp @@ -28,6 +28,13 @@ class list_argument : public valued_argument { _values.at(_cursor)->print(w, depth + 1, prefix); } + virtual void print(stan::callbacks::structured_writer &j) { + j.begin_record(_name); + j.write("value", print_value()); + _values.at(_cursor)->print(j); + j.end_record(); + } + void print_help(stan::callbacks::writer &w, int depth, bool recurse) { _default = _values.at(_default_cursor)->name(); diff --git a/src/cmdstan/arguments/singleton_argument.hpp b/src/cmdstan/arguments/singleton_argument.hpp index f06f175964..8e5964eaac 100644 --- a/src/cmdstan/arguments/singleton_argument.hpp +++ b/src/cmdstan/arguments/singleton_argument.hpp @@ -115,6 +115,10 @@ class singleton_argument : public valued_argument { bool is_default() { return _value == _default_value; } + virtual void print(stan::callbacks::structured_writer &j) { + j.write(_name, _value); + } + protected: std::string _validity; virtual bool is_valid(T value) { return true; } diff --git a/src/cmdstan/arguments/unvalued_argument.hpp b/src/cmdstan/arguments/unvalued_argument.hpp index 3e84eb0cc7..b0069a51da 100644 --- a/src/cmdstan/arguments/unvalued_argument.hpp +++ b/src/cmdstan/arguments/unvalued_argument.hpp @@ -15,6 +15,8 @@ class unvalued_argument : public argument { void print(stan::callbacks::writer &w, const int depth, const std::string &prefix) {} + void print(stan::callbacks::structured_writer &j) {} + void print_help(stan::callbacks::writer &w, const int depth, const bool recurse = false) { std::string indent(indent_width * depth, ' '); diff --git a/src/cmdstan/arguments/valued_argument.hpp b/src/cmdstan/arguments/valued_argument.hpp index 93adc9c2b5..6af23ff1c2 100644 --- a/src/cmdstan/arguments/valued_argument.hpp +++ b/src/cmdstan/arguments/valued_argument.hpp @@ -19,6 +19,10 @@ class valued_argument : public argument { w(message); } + virtual void print(stan::callbacks::structured_writer &j) { + j.write(_name, print_value()); + } + virtual void print_help(stan::callbacks::writer &w, const int depth, const bool recurse = false) { std::string indent(indent_width * depth, ' '); diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index 4e7ac5226a..f7ff23d658 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -12,15 +12,11 @@ #include #include #include -#include -#include -#include #include -#include -#include -#include #include -#include +#include +#include +#include #include #include #include @@ -32,7 +28,6 @@ #include #include #include -#include #include #include #include @@ -66,8 +61,6 @@ #include #include -#include - #ifdef STAN_MPI #include #include @@ -210,16 +203,25 @@ int command(int argc, const char *argv[]) { std::vector> init_contexts = get_vec_var_context(init, num_chains, id); - std::vector model_compile_info = model.model_compile_info(); + + if (get_arg_val(parser, "output", "save_cmdstan_config")) { + auto filename = get_basename_suffix( + get_arg_val(parser, "output", "file")) + .first + + "_config.json"; + auto ofs_args = std::make_unique(filename); + if (sig_figs > -1) { + ofs_args->precision(sig_figs); + } + + stan::callbacks::json_writer json_args(std::move(ofs_args)); + + write_config(json_args, parser, model); + } for (int i = 0; i < num_chains; ++i) { - write_stan(sample_writers[i]); - write_model(sample_writers[i], model.model_name()); - write_datetime(sample_writers[i]); - parser.print(sample_writers[i]); - write_parallel_info(sample_writers[i]); - write_opencl_device(sample_writers[i]); - write_compile_info(sample_writers[i], model_compile_info); + write_config(sample_writers[i], parser, model); + write_stan(diagnostic_csv_writers[i]); write_model(diagnostic_csv_writers[i], model.model_name()); parser.print(diagnostic_csv_writers[i]); @@ -279,10 +281,8 @@ int command(int argc, const char *argv[]) { ofs->precision(sig_figs); stan::callbacks::unique_stream_writer pathfinder_writer( std::move(ofs), "# "); - write_stan(pathfinder_writer); - write_model(pathfinder_writer, model.model_name()); - write_datetime(pathfinder_writer); - parser.print(pathfinder_writer); + write_config(pathfinder_writer, parser, model); + return_code = stan::services::pathfinder::pathfinder_lbfgs_multi< stan::model::model_base>( model, init_contexts, random_seed, id, init_radius, history_size, diff --git a/src/cmdstan/command_helper.hpp b/src/cmdstan/command_helper.hpp index 3aa6946dcc..a05b248465 100644 --- a/src/cmdstan/command_helper.hpp +++ b/src/cmdstan/command_helper.hpp @@ -3,16 +3,6 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include #include diff --git a/src/cmdstan/write_chain.hpp b/src/cmdstan/write_chain.hpp deleted file mode 100644 index bf2756d5da..0000000000 --- a/src/cmdstan/write_chain.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef CMDSTAN_WRITE_CHAIN_HPP -#define CMDSTAN_WRITE_CHAIN_HPP - -#include -#include -#include - -namespace cmdstan { - -inline void write_chain(stan::callbacks::writer& writer, - unsigned int chain_id) { - writer("chain_id = " + std::to_string(chain_id)); -} - -} // namespace cmdstan -#endif diff --git a/src/cmdstan/write_config.hpp b/src/cmdstan/write_config.hpp new file mode 100644 index 0000000000..bfcdc8515b --- /dev/null +++ b/src/cmdstan/write_config.hpp @@ -0,0 +1,50 @@ +#ifndef CMDSTAN_WRITE_CONFIG_HPP +#define CMDSTAN_WRITE_CONFIG_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cmdstan { + +inline void write_config(stan::callbacks::writer &writer, + argument_parser &parser, + stan::model::model_base &model) { + write_stan(writer); + write_model(writer, model.model_name()); + write_datetime(writer); + parser.print(writer); + write_parallel_info(writer); + write_opencl_device(writer); + write_compile_info(writer, model); +} + +inline void write_config(stan::callbacks::structured_writer &writer, + argument_parser &parser, + stan::model::model_base &model) { + writer.begin_record(); + write_stan(writer); + writer.write("model_name", model.model_name()); + writer.write("start_datetime", current_datetime()); + parser.print(writer); +#ifdef STAN_MPI + writer.write("mpi_enabled", true); +#else + writer.write("mpi_enabled", false); +#endif + write_opencl_device(writer); + write_compile_info(writer, model); + writer.end_record(); +} + +} // namespace cmdstan +#endif diff --git a/src/cmdstan/write_datetime.hpp b/src/cmdstan/write_datetime.hpp index ead43d2f4f..869f0be9b6 100644 --- a/src/cmdstan/write_datetime.hpp +++ b/src/cmdstan/write_datetime.hpp @@ -9,18 +9,23 @@ namespace cmdstan { -void write_datetime(stan::callbacks::writer& writer) { +std::string current_datetime() { const std::time_t current_datetime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); std::tm* curr_tm = std::gmtime(¤t_datetime); std::stringstream current_datetime_msg; - current_datetime_msg << "start_datetime = " << std::setfill('0') - << (1900 + curr_tm->tm_year) << "-" << std::setw(2) - << (curr_tm->tm_mon + 1) << "-" << std::setw(2) - << curr_tm->tm_mday << " " << std::setw(2) - << curr_tm->tm_hour << ":" << std::setw(2) - << curr_tm->tm_min << ":" << std::setw(2) + current_datetime_msg << std::setfill('0') << (1900 + curr_tm->tm_year) << "-" + << std::setw(2) << (curr_tm->tm_mon + 1) << "-" + << std::setw(2) << curr_tm->tm_mday << " " + << std::setw(2) << curr_tm->tm_hour << ":" + << std::setw(2) << curr_tm->tm_min << ":" << std::setw(2) << curr_tm->tm_sec << " UTC"; + return current_datetime_msg.str(); +} + +void write_datetime(stan::callbacks::writer& writer) { + std::stringstream current_datetime_msg; + current_datetime_msg << "start_datetime = " << current_datetime(); writer(current_datetime_msg.str()); } diff --git a/src/cmdstan/write_model_compile_info.hpp b/src/cmdstan/write_model_compile_info.hpp index 9bebfa0ed9..625dcdba9b 100644 --- a/src/cmdstan/write_model_compile_info.hpp +++ b/src/cmdstan/write_model_compile_info.hpp @@ -2,16 +2,32 @@ #define CMDSTAN_WRITE_COMPILE_INFO_HPP #include -#include -#include -#include +#include +#include namespace cmdstan { void write_compile_info(stan::callbacks::writer& writer, - std::vector& compile_info) { - for (int i = 0; i < compile_info.size(); i++) { - writer(compile_info[i]); + stan::model::model_base& model) { + auto compile_info = model.model_compile_info(); + for (auto s : compile_info) { + writer(s); } } + +void write_compile_info(stan::callbacks::structured_writer& writer, + stan::model::model_base& model) { + auto compile_info = model.model_compile_info(); + for (auto s : compile_info) { + // split on "=" + std::string::size_type pos = s.find(" = "); + if (pos == std::string::npos) { + continue; + } + std::string key = s.substr(0, pos); + std::string value = s.substr(pos + 3); + writer.write(key, value); + } +} + } // namespace cmdstan #endif diff --git a/src/cmdstan/write_opencl_device.hpp b/src/cmdstan/write_opencl_device.hpp index 89cb744574..f7a2576d7f 100644 --- a/src/cmdstan/write_opencl_device.hpp +++ b/src/cmdstan/write_opencl_device.hpp @@ -2,6 +2,7 @@ #define CMDSTAN_WRITE_OPENCL_DEVICE_HPP #include +#include #ifdef STAN_OPENCL #include #endif @@ -27,5 +28,21 @@ void write_opencl_device(stan::callbacks::writer &writer) { #endif } +void write_opencl_device(stan::callbacks::structured_writer &writer) { +#ifdef STAN_OPENCL + if ((stan::math::opencl_context.platform().size() > 0) + && (stan::math::opencl_context.device().size() > 0)) { + std::stringstream msg_opencl_platform; + msg_opencl_platform + << stan::math::opencl_context.platform()[0].getInfo(); + writer.write("opencl_platform_name", msg_opencl_platform.str()); + std::stringstream msg_opencl_device; + msg_opencl_device + << stan::math::opencl_context.device()[0].getInfo(); + writer.write("opencl_device_name", msg_opencl_device.str()); + } +#endif +} + } // namespace cmdstan #endif diff --git a/src/cmdstan/write_stan.hpp b/src/cmdstan/write_stan.hpp index b08ea00ef9..94af938c6d 100644 --- a/src/cmdstan/write_stan.hpp +++ b/src/cmdstan/write_stan.hpp @@ -2,6 +2,7 @@ #define CMDSTAN_WRITE_STAN_HPP #include +#include #include #include @@ -13,5 +14,11 @@ void write_stan(stan::callbacks::writer &writer) { writer("stan_version_patch = " + stan::PATCH_VERSION); } +void write_stan(stan::callbacks::structured_writer &writer) { + writer.write("stan_major_version", stan::MAJOR_VERSION); + writer.write("stan_minor_version", stan::MINOR_VERSION); + writer.write("stan_patch_version", stan::PATCH_VERSION); +} + } // namespace cmdstan #endif diff --git a/src/test/interface/arguments/argument_test.cpp b/src/test/interface/arguments/argument_test.cpp index 7b9eb89d44..8060039399 100644 --- a/src/test/interface/arguments/argument_test.cpp +++ b/src/test/interface/arguments/argument_test.cpp @@ -5,6 +5,7 @@ class test_arg_impl : public cmdstan::argument { void print(stan::callbacks::writer &w, int depth, const std::string &prefix) { } + void print(stan::callbacks::structured_writer &j) {} void print_help(stan::callbacks::writer &w, int depth, bool recurse) {} }; diff --git a/src/test/interface/config_json_test.cpp b/src/test/interface/config_json_test.cpp new file mode 100644 index 0000000000..36b8e47ea4 --- /dev/null +++ b/src/test/interface/config_json_test.cpp @@ -0,0 +1,58 @@ +#include +#include +#include + +using cmdstan::test::convert_model_path; +using cmdstan::test::file_exists; +using cmdstan::test::run_command; +using cmdstan::test::run_command_output; + +class CmdStan : public testing::Test { + public: + void SetUp() { + multi_normal_model = {"src", "test", "test-models", "multi_normal_model"}; + arg_output = {"test", "output"}; + output_csv = {"test", "output.csv"}; + output_json = {"test", "output_config.json"}; + } + + void TearDown() { + std::remove(convert_model_path(output_csv).c_str()); + std::remove(convert_model_path(output_json).c_str()); + } + + std::vector multi_normal_model; + std::vector arg_output; + std::vector output_csv; + std::vector output_json; +}; + +TEST_F(CmdStan, config_json_output_valid) { + std::stringstream ss; + ss << convert_model_path(multi_normal_model) + << " sample output file=" << convert_model_path(arg_output) + << " save_cmdstan_config=1"; + run_command_output out = run_command(ss.str()); + ASSERT_FALSE(out.hasError) << out.output; + ASSERT_TRUE(file_exists(convert_model_path(output_csv))); + ASSERT_TRUE(file_exists(convert_model_path(output_json))); + + std::fstream json_in(convert_model_path(output_json)); + std::stringstream result_json_sstream; + result_json_sstream << json_in.rdbuf(); + json_in.close(); + std::string json = result_json_sstream.str(); + + ASSERT_FALSE(json.empty()); + ASSERT_TRUE(stan::test::is_valid_JSON(json)); +} + +TEST_F(CmdStan, config_json_output_not_requested) { + std::stringstream ss; + ss << convert_model_path(multi_normal_model) + << " sample output file=" << convert_model_path(arg_output); + run_command_output out = run_command(ss.str()); + ASSERT_FALSE(out.hasError); + ASSERT_TRUE(file_exists(convert_model_path(output_csv))); + ASSERT_FALSE(file_exists(convert_model_path(output_json))); +} diff --git a/src/test/interface/pathfinder_test.cpp b/src/test/interface/pathfinder_test.cpp index 249ff73a35..0b5a2829bd 100644 --- a/src/test/interface/pathfinder_test.cpp +++ b/src/test/interface/pathfinder_test.cpp @@ -1,10 +1,9 @@ #include -#include +#include #include #include using cmdstan::test::convert_model_path; -using cmdstan::test::count_matches; using cmdstan::test::file_exists; using cmdstan::test::parse_sample; using cmdstan::test::run_command; @@ -142,8 +141,7 @@ TEST_F(CmdStan, pathfinder_save_single_default_num_paths) { std::string single_json = result_json_sstream.str(); ASSERT_FALSE(single_json.empty()); - rapidjson::Document document; - ASSERT_FALSE(document.Parse<0>(single_json.c_str()).HasParseError()); + ASSERT_TRUE(stan::test::is_valid_JSON(single_json)); single_json.erase( std::remove_if(single_json.begin(), single_json.end(), is_whitespace), single_json.end());