CmdStanR サンプル - you1025/my_something_flagments GitHub Wiki
- R Packages
- 参考
Installing CmdStan を参照。
Stan ファイルを指定して CmdStanModel
モデルを作成する。
コード上に Stan コードを記述する場合は下記のように一時ファイルを経由する。
# Stan コード
stan_code <- "
data {
int x;
int n;
}
parameters {
real<lower=0, upper=1> theta;
}
model {
x ~ binomial(n, theta);
}
"
filepath <- cmdstanr::write_stan_file(code = stan_code)
# ファイルから CmdStanModel を作成
model <- cmdstanr::cmdstan_model(stan_file = filepath)
4/15 の二項分布におけるパラメータ theta
を推定する CmdStanMCMC
を作成する。
fit <- model$sample(
data = list(
x = 4,
n = 15
),
chains = 4, # チェイン数
parallel_chains = 4, # 並列コア数
iter_warmup = 1000, # warm-up
iter_sampling = 5000, # sample
seed = 1234
)
結果の概要
fit
# variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
# lp__ -10.81 -10.53 0.73 0.31 -12.27 -10.30 1.00 8826 8683
# theta 0.29 0.28 0.11 0.11 0.13 0.48 1.00 7469 7257
パラメータ(variable)の指定も可能
fit$summary("theta")
# A tibble: 1 × 10
# variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
# <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#1 theta 0.293 0.285 0.106 0.108 0.132 0.481 1.00 7469. 7258.
代表値(と信用区間)の指定も可能
fit$summary("theta", "mean", "sd")
# A tibble: 1 × 3
# variable mean sd
# <chr> <dbl> <dbl>
#1 theta 0.293 0.106
draw で 3 次元配列(iteration x chain x variable)を取得。
fit$draws()
# A draws_array: 5000 iterations, 4 chains, and 2 variables
#, , variable = lp__
#
# chain
#iteration 1 2 3 4
# 1 -11 -11 -11 -11
# 2 -10 -11 -10 -11
# 3 -10 -10 -10 -11
# 4 -10 -10 -10 -11
# 5 -10 -10 -10 -11
#
#, , variable = theta
#
# chain
#iteration 1 2 3 4
# 1 0.46 0.21 0.21 0.37
# 2 0.28 0.18 0.24 0.37
# 3 0.29 0.34 0.23 0.41
# 4 0.31 0.32 0.27 0.39
# 5 0.36 0.28 0.35 0.39
#
# ... with 4995 more iterations
各チェインが 1 つの配列にまとめられる。
チェイン毎に取得したい場合は extract_variable_matrix を使う。
fit$draws() %>% posterior::extract_variable("theta")
# [1] 0.4551250 0.2844410 0.2899700 0.3052940 0.3598680
tibble 形式で MCMC サンプルを抽出。
各チェインが 1 つのデータにまとめられる。
fit$draws() %>% posterior::as_draws_df()
# A draws_df: 5000 iterations, 4 chains, and 2 variables
# lp__ theta
#1 -11 0.46
#2 -10 0.28
#3 -10 0.29
#4 -10 0.31
#5 -10 0.36
#6 -10 0.28
#7 -10 0.34
#8 -11 0.40
#9 -11 0.37
#10 -10 0.23
# ... with 19990 more draws
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}
fit$draws("theta") %>% bayesplot::mcmc_hist()
fit$draws() %>%
posterior::extract_variable("theta") %>%
# 予測値を算出
purrr::map_int(~ rbinom(1, 10, .)) %>%
# 集計して比率を算出
tibble::as_tibble_col(column_name = "k") %>%
dplyr::count(k) %>%
dplyr::mutate(ratio = n / sum(n)) %>%
# 可視化
ggplot(aes(k, ratio)) +
geom_col() +
scale_x_continuous(breaks = 0:10, minor_breaks = NULL) +
scale_y_continuous(labels = purrr::partial(formattable::percent, digits = 0)) +
labs(
x = "count of True",
y = "probability"
)