diff --git a/R/csv.R b/R/csv.R index 8ac913937..46931d5d3 100644 --- a/R/csv.R +++ b/R/csv.R @@ -204,6 +204,7 @@ read_cmdstan_csv <- function(files, lp = lp )) } + user_variables_subset <- FALSE if (is.null(variables)) { # variables = NULL returns all variables <- metadata$variables } else if (!any(nzchar(variables))) { # if variables = "" returns none @@ -215,6 +216,7 @@ read_cmdstan_csv <- function(files, paste(res$not_found, collapse = ", "), call. = FALSE) } variables <- unrepair_variable_names(res$matching) + user_variables_subset <- TRUE } if (is.null(sampler_diagnostics)) { sampler_diagnostics <- metadata$sampler_diagnostics @@ -281,8 +283,13 @@ read_cmdstan_csv <- function(files, draws_list_id <- length(draws) + 1 warmup_draws_list_id <- length(warmup_draws) + 1 if (metadata$method == "pathfinder") { - metadata$variables = union(metadata$sampler_diagnostics, metadata$variables) - variables = union(metadata$sampler_diagnostics, variables) + metadata$variables <- union(metadata$sampler_diagnostics, metadata$variables) + if (!user_variables_subset) { + # because for pathfinder variables and diagnostics are read in together, + # if user hasn't selected a custom subset of variables we need to include + # all diagnostics + variables <- union(metadata$sampler_diagnostics, variables) + } } suppressWarnings( draws[[draws_list_id]] <- data.table::fread( @@ -489,10 +496,24 @@ read_sample_csv <- function(files, #' `TRUE` but set to `FALSE` to avoid checking for problems with divergences #' and treedepth. #' -as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("cmdstanr_draws_format")) { - csv_contents <- read_cmdstan_csv(files, format = format) +as_cmdstan_fit <- function(files, + variables = NULL, + check_diagnostics = TRUE, + format = getOption("cmdstanr_draws_format")) { + csv_contents <- read_cmdstan_csv(files, variables = variables, format = format) + method <- csv_contents$metadata$method + if (!is.null(variables)) { + if (method == "sample") { + variables <- posterior::variables(csv_contents$post_warmup_draws) + } else if (method == "optimize") { + variables <- posterior::variables(csv_contents$point_estimates) + } else { # variational, laplace, pathfinder + variables <- posterior::variables(csv_contents$draws) + } + csv_contents$metadata$variables <- variables + } switch( - csv_contents$metadata$method, + method, "sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics), "optimize" = CmdStanMLE_CSV$new(csv_contents, files), "variational" = CmdStanVB_CSV$new(csv_contents, files), @@ -638,6 +659,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) { CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) + CmdStanPathfinder_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) } diff --git a/man/read_cmdstan_csv.Rd b/man/read_cmdstan_csv.Rd index 86c4832df..895c9e833 100644 --- a/man/read_cmdstan_csv.Rd +++ b/man/read_cmdstan_csv.Rd @@ -14,6 +14,7 @@ read_cmdstan_csv( as_cmdstan_fit( files, + variables = NULL, check_diagnostics = TRUE, format = getOption("cmdstanr_draws_format") ) diff --git a/tests/testthat/helper-models.R b/tests/testthat/helper-models.R index b0773e8b0..f1e248b87 100644 --- a/tests/testthat/helper-models.R +++ b/tests/testthat/helper-models.R @@ -24,6 +24,7 @@ testing_fit <- "optimize", "laplace", "variational", + "pathfinder", "generate_quantities"), seed = 123, ...) { diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index dc678a6bc..6df1cb7c7 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -8,6 +8,7 @@ fit_logistic_optimize <- testing_fit("logistic", method = "optimize", seed = 123 fit_logistic_variational <- testing_fit("logistic", method = "variational", seed = 123) fit_logistic_variational_short <- testing_fit("logistic", method = "variational", output_samples = 100, seed = 123) fit_logistic_laplace <- testing_fit("logistic", method = "laplace", seed = 123) +fit_logistic_pathfinder <- testing_fit("logistic", method = "pathfinder", seed = 123) fit_bernoulli_diag_e_no_samples <- testing_fit("bernoulli", method = "sample", seed = 123, chains = 2, iter_sampling = 0, metric = "diag_e") @@ -520,64 +521,6 @@ test_that("time from read_cmdstan_csv matches time from fit$time()", { ) }) -test_that("as_cmdstan_fit creates fitted model objects from csv", { - fits <- list( - mle = as_cmdstan_fit(fit_logistic_optimize$output_files()), - vb = as_cmdstan_fit(fit_logistic_variational$output_files()), - laplace = as_cmdstan_fit(fit_logistic_laplace$output_files()), - mcmc = as_cmdstan_fit(fit_logistic_thin_1$output_files()) - ) - for (class in names(fits)) { - fit <- fits[[class]] - class_name <- if (class == "laplace") "Laplace" else toupper(class) - checkmate::expect_r6(fit, classes = paste0("CmdStan", class_name, "_CSV")) - expect_s3_class(fit$draws(), "draws") - checkmate::expect_numeric(fit$lp()) - expect_output(fit$print(), "variable") - expect_length(fit$output_files(), if (class == "mcmc") fit$num_chains() else 1) - expect_s3_class(fit$summary(), "draws_summary") - - if (class == "mcmc") { - expect_s3_class(fit$sampler_diagnostics(), "draws_array") - expect_type(fit$inv_metric(), "list") - expect_equal(fit$time()$total, NA_integer_) - expect_s3_class(fit$time()$chains, "data.frame") - } - if (class == "mle") { - checkmate::expect_numeric(fit$mle()) - } - if (class == "vb") { - checkmate::expect_numeric(fit$lp_approx()) - } - if (class == "laplace") { - checkmate::expect_numeric(fit$lp_approx()) - } - - for (method in unavailable_methods_CmdStanFit_CSV) { - if (!(method == "time" && class == "mcmc")) { - expect_error(fit[[method]](), "This method is not available") - } - } - } -}) - -test_that("as_cmdstan_fit can check MCMC diagnostics", { - fit_schools <- suppressMessages( - testing_fit("schools", chains = 2, - adapt_delta = 0.5, max_treedepth = 4, - show_messages = FALSE) - ) - expect_message( - as_cmdstan_fit(fit_schools$output_files()), - "transitions ended with a divergence" - ) - expect_message( - as_cmdstan_fit(fit_schools$output_files()), - "transitions hit the maximum treedepth" - ) - expect_silent(as_cmdstan_fit(fit_schools$output_files(), check_diagnostics = FALSE)) -}) - test_that("read_cmdstan_csv reads seed correctly", { opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files()) vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files()) @@ -896,3 +839,98 @@ test_that("read_cmdstan_csv() works with tilde expansion", { tildified_path <- file.path("~", fs::path_rel(full_path, "~")) expect_no_error(read_cmdstan_csv(tildified_path)) }) + + +test_that("as_cmdstan_fit creates fitted model objects from csv", { + fits <- list( + mle = as_cmdstan_fit(fit_logistic_optimize$output_files()), + vb = as_cmdstan_fit(fit_logistic_variational$output_files()), + laplace = as_cmdstan_fit(fit_logistic_laplace$output_files()), + pathfinder = as_cmdstan_fit(fit_logistic_pathfinder$output_files()), + mcmc = as_cmdstan_fit(fit_logistic_thin_1$output_files()) + ) + + for (class in names(fits)) { + fit <- fits[[class]] + if (class == "laplace") { + class_name <- "Laplace" + } else if (class == "pathfinder") { + class_name <- "Pathfinder" + } else { + class_name <- toupper(class) + } + checkmate::expect_r6(fit, classes = paste0("CmdStan", class_name, "_CSV")) + expect_s3_class(fit$draws(), "draws") + checkmate::expect_numeric(fit$lp()) + expect_output(fit$print(), "variable") + expect_length(fit$output_files(), if (class == "mcmc") fit$num_chains() else 1) + expect_s3_class(fit$summary(), "draws_summary") + + if (class == "mcmc") { + expect_s3_class(fit$sampler_diagnostics(), "draws_array") + expect_type(fit$inv_metric(), "list") + expect_equal(fit$time()$total, NA_integer_) + expect_s3_class(fit$time()$chains, "data.frame") + } + if (class == "mle") { + checkmate::expect_numeric(fit$mle()) + } + if (class %in% c("vb", "laplace", "pathfinder")) { + checkmate::expect_numeric(fit$lp_approx()) + } + for (method in unavailable_methods_CmdStanFit_CSV) { + if (!(method == "time" && class == "mcmc")) { + expect_error(fit[[method]](), "This method is not available", info = class) + } + } + } +}) + +test_that("as_cmdstan_fit can check MCMC diagnostics", { + fit_schools <- suppressMessages( + testing_fit("schools", chains = 2, + adapt_delta = 0.5, max_treedepth = 4, + show_messages = FALSE) + ) + expect_message( + as_cmdstan_fit(fit_schools$output_files()), + "transitions ended with a divergence" + ) + expect_message( + as_cmdstan_fit(fit_schools$output_files()), + "transitions hit the maximum treedepth" + ) + expect_silent(as_cmdstan_fit(fit_schools$output_files(), check_diagnostics = FALSE)) +}) + +test_that("as_cmdstan_fit filters variables across methods", { + mcmc_vars <- c("alpha", "beta[2]") + mcmc <- as_cmdstan_fit(fit_logistic_thin_1$output_files(), variables = mcmc_vars) + expect_equal(posterior::variables(mcmc$draws()), mcmc_vars) + expect_equal(mcmc$summary()$variable, mcmc_vars) + expect_equal(mcmc$metadata()$variables, mcmc_vars) + + mle_vars <- c("beta[1]", "beta[3]") + mle <- as_cmdstan_fit(fit_logistic_optimize$output_files(), variables = mle_vars) + expect_equal(posterior::variables(mle$draws()), mle_vars) + expect_equal(mle$summary()$variable, mle_vars) + expect_equal(mle$metadata()$variables, mle_vars) + + vb_vars <- "beta" + vb <- as_cmdstan_fit(fit_logistic_variational$output_files(), variables = vb_vars) + expect_equal(posterior::variables(vb$draws()), c("beta[1]", "beta[2]", "beta[3]")) + expect_equal(vb$summary()$variable, c("beta[1]", "beta[2]", "beta[3]")) + expect_equal(vb$metadata()$variables, c("beta[1]", "beta[2]", "beta[3]")) + + laplace_vars <- "alpha" + laplace <- as_cmdstan_fit(fit_logistic_laplace$output_files(), variables = laplace_vars) + expect_equal(posterior::variables(laplace$draws()), laplace_vars) + expect_equal(laplace$summary()$variable, laplace_vars) + expect_equal(laplace$metadata()$variables, laplace_vars) + + pathfinder_vars <- c("alpha", "beta[1]", "beta[3]") + pathfinder <- as_cmdstan_fit(fit_logistic_pathfinder$output_files(), variables = pathfinder_vars) + expect_equal(posterior::variables(pathfinder$draws()), pathfinder_vars) + expect_equal(pathfinder$summary()$variable, pathfinder_vars) + expect_equal(pathfinder$metadata()$variables, pathfinder_vars) +}) diff --git a/vignettes/posterior.Rmd b/vignettes/posterior.Rmd index 0039554d0..1e2729184 100644 --- a/vignettes/posterior.Rmd +++ b/vignettes/posterior.Rmd @@ -49,7 +49,6 @@ fit$summary(variables = c("mu", "tau"), mean, sd) To summarize all variables with non-default functions, it is necessary to set explicitly set the variables argument, either to `NULL` or the full vector of variable names. ```{r} -fit$metadata()$model_params fit$summary(variables = NULL, "mean", "median") ```