Busemeyer & Diederich(2010), Heathcote (2014), Palminteri et al.(2017)を参考にまとめると,以下のような感じになります。
「強化学習モデル: 最尤推定」ではABCDとすすめてきました。今回は,ベイズ推定をcmdstanを使って行ってみます。そこで,一度Bに戻ってからDに取り組みます。
適宜必要なパッケージをインストールしてください。cmdstnrのインストールは,こちらをご確認ください。
rm(list = ls())
library(tidyverse)
library(cmdstanr)
library(rstan)
library(posterior)
library(bridgesampling)
改めて人工データを生成するために,実験状況の入ったsim_dataの準備とq_learning_sim関数とq_learning_ll関数の準備をします。
sim_data <- tibble(trial = 1:80,
prob_s1 = rep(c(0.2, 0.8), each = 40),
prob_s2 = rep(c(0.8, 0.2), each = 40),
reward_s1 = ifelse(runif(80) < prob_s1, 1, 0),
reward_s2 = ifelse(runif(80) < prob_s2, 1, 0))
q_learning_sim <- function(alpha, beta,data) {
#変数の準備
value_s1 <- 0 # s1の価値(初期値は0)
value_s2 <- 0 # s2の価値(初期値は0)
current_choice <- NULL # ある時点の選択(1=s1,0=s2)
choice_prob_s1 <- NULL # s1の選択確率
reward <- NULL # 報酬
# Qlearningモデル
for (i in 1:nrow(data)){
# s1を選ぶ確率を計算し,一様分布から発生させた乱数が行動選択確率よりも小さい時に1(s1),大きい時に0(s2)
choice_prob_s1[i] <- exp(beta*value_s1[i])/(exp(beta*value_s1[i])+exp(beta*value_s2[i]))
current_choice[i] <- as.integer(runif(1,min=0,max=1) <= choice_prob_s1[i])
#FBを報酬(r)として、価値の更新を行う。
if (current_choice[i] == 1){
reward[i] <- data$reward_s1[i]
#予測誤差の計算
prediction_error <- reward[i] - value_s1[i]
#予測誤差を使ってs1の価値を更新する
value_s1[i+1] <- value_s1[i]+alpha*prediction_error
#s2は更新なし
value_s2[i+1] <- value_s2[i]
}else{
reward[i] <- data$reward_s2[i]
#予測誤差の計算
prediction_error <- reward[i] - value_s2[i]
#予測誤差を使ってs2の価値を更新する
value_s2[i+1] <- value_s2[i]+alpha*prediction_error
#s1は更新なし
value_s1[i+1] <- value_s1[i]
}
}
result <- data.frame(trial = data$trial,
value_s1 = value_s1[1:nrow(data)],
value_s2 = value_s2[1:nrow(data)],
prob_s1 = choice_prob_s1,
choice = current_choice,
reward = reward)
return(result)
}
q_learning_ll <- function(alpha, beta,data) {
#変数の準備
value_s1 <- 0 # s1の価値(初期値は0)
value_s2 <- 0 # s2の価値(初期値は0)
prob_s1 <- NULL # s1の選択確率
ll <- 0 # 対数尤度
# Qlearningモデル
for (i in 1:nrow(data)){
# s1を選ぶ確率を計算
prob_s1[i] <- exp(beta*value_s1[i])/(exp(beta*value_s1[i])+exp(beta*value_s2[i]))
#FBを報酬(r)として、価値の更新を行う。
if (data$choice[i] == 1){
#予測誤差の計算
prediction_error <- data$reward[i] - value_s1[i]
#予測誤差を使ってs1の価値を更新する
value_s1[i+1] <- value_s1[i]+alpha*prediction_error
#s2は更新なし
value_s2[i+1] <- value_s2[i]
# 対数尤度の計算のために選択したs1を選ぶ確率の対数を加算
ll <- ll + log(prob_s1[i])
}else{
#予測誤差の計算
prediction_error <- data$reward[i] - value_s2[i]
#予測誤差を使ってs2の価値を更新する
value_s2[i+1] <- value_s2[i]+alpha*prediction_error
#s1は更新なし
value_s1[i+1] <- value_s1[i]
# 対数尤度の計算のために選択したs2を選ぶ確率の対数を加算
ll <- ll + log(1-prob_s1[i])
}
}
result <- data.frame(trial = data$trial,
value_s1 = value_s1[1:nrow(data)],
value_s2 = value_s2[1:nrow(data)],
prob_s1 = prob_s1,
choice = data$choice,
reward = data$reward)
return(list(result = result, ll = ll))
}
alpha = 0.3, beta = 2で,シミュレーション・データを生成します。
set.seed(1234)
data_1 <- q_learning_sim(alpha = 0.3, beta = 2, sim_data)
“q_learning.stan”というファイルを作成して,以下のコードを書き込みます。Stanは統計モデリング用のプラットフォームで,MCMCサンプリングによるベイズ推定,変分推定,最適化による最尤推定が可能です。Stanでは,data,parameters,modelのようにブロックごとに指定をして,書いていきます(この3つが最小限のブロック数かと思います)。dataブロックでは入力するデータについて記述します(型と範囲の指定が必要です)。parameterブロックでは,推定するパラメータについて記述します(型と範囲の指定が必要です)。modelブロックでは,データとパラメータを用いた生成モデルの記載をします。なお,データ~分布(y ~ normal(mu, sig))のような形で記述もできますし,target += 対数尤度のように,対数尤度でも記述ができます。以下は,Q学習の推定用のコードですが,グリッドサーチで用意したものに比べて,すっきりしているかと思います。
data {
int<lower=1> trial;
int<lower=1,upper=2> choice[trial]; // 1 or 2
int<lower=0,upper=1> reward[trial]; // 0 or 1
}
parameters {
real<lower=0.0,upper=1.0> alpha; //学習率
real<lower=0.0> beta; //逆温度
}
model {
//学習率と逆温度の事前分布の指定はしていないので,parametersで指定した範囲の無情報事前分布が使われる
matrix[trial,2] Q;
Q[1, 1] = 0;
Q[1, 2] = 0;
for ( t in 1:trial) {
// 対数尤度を足す
target += log(exp(beta*Q[t,choice[t]])/(exp(beta*Q[t,choice[t]])+exp(beta*Q[t,3-choice[t]])));
if (t < trial) {
// 選択された選択肢のQ値の更新
Q[t+1,choice[t]] = Q[t, choice[t]] + alpha * (reward[t] - Q[t, choice[t]]);
// 選択されなかった選択肢は更新しない
Q[t+1, 3- choice[t]] = Q[t, 3- choice[t]];
}
}
}
Stanコードがstanファイルとして保存できたら,cmdstan_model()コンパイルします。Rは便利ですが計算は遅い言語なので,c++のような高速な計算が可能な言語にコンパイルをします。ちなみに,stanをRで使う場合は,Rstanを使うことが多かったのですが,最近は,cmdstanrが開発されており,こっちのほうがコンパイルもサンプリングも早いのでおすすめです。
q_learning <- cmdstan_model('q_learning.stan')
コンパイルができたら,最適化による最尤推定をしてみましょう。最尤推定は,「強化学習モデル: 最尤推定」 でも見てきたように,optimやpsoなどのRパッケージでできますが,ここではstanで最尤推定値を推定します。コンパイルしたモデル$optimize()で推定ができます。
mle_cmdstan <- q_learning$optimize(
data = list(trial = nrow(data_1),
reward = data_1$reward,
choice = data_1$choice + 1),
seed = 123)
mle_cmdstan$summary()
stanのoptimizeを用いた最尤推定のパラメータリカバリをします。αは0.1から1の範囲で0.1刻み,βは0.5から5の範囲で0.5刻みでデータ生成とパラメータ推定を行います(つまり100個分チェックします)。
alpha_true <- NULL
beta_true <- NULL
alpha_estimated <- NULL
beta_estimated <- NULL
beta_set <- 0
for (i in 1:10) {
alpha_set <- 0
beta_set = beta_set + 0.5
for (j in 1:10) {
alpha_set = alpha_set + 0.1
#データ生成
data_2 <- q_learning_sim(alpha = alpha_set, beta = beta_set, sim_data)
alpha_true[(i-1)*10 + j] <- alpha_set
beta_true[(i-1)*10 + j] <- beta_set
#パラメータ推定(推定がミスった時用にtryCatch関数を準備)
tryCatch({
q_learning_mle <- q_learning$optimize(
data = list(trial = nrow(data_2),
reward = data_2$reward,
choice = data_2$choice + 1),
seed = 123)
alpha_estimated[(i-1)*10 + j] <- q_learning_mle$mle("alpha")
beta_estimated[(i-1)*10 + j] <- q_learning_mle$mle("beta")
},error = function(e) {message(e)})
}
}
parameter_recovery_mle <- data.frame(alpha_true = alpha_true[1:length(alpha_estimated)],
beta_true = beta_true[1:length(alpha_estimated)],
alpha_estimated = alpha_estimated[1:length(alpha_estimated)],
beta_estimated = beta_estimated[1:length(alpha_estimated)])
パラメータリカバリーのチェックをしましょう。散布図を書いて,真値(研究者がデータ生成時に設定した値)と推定された値が強い相関を示しているか確認します(データ生成時や推定時に確率的な変動が生じるので,完全一致はありません)。最尤推定だと,一部のパラメータは推定ミスをすることがあります。真値とはずれて,極端に大きなαの値(つまり1付近の値)や低いαの値が確認できます。ちょっと気になりますね。
parameter_recovery_mle %>%
ggplot(aes(x = alpha_true, y = alpha_estimated)) +
geom_point()
どうもβの推定値の中にとても大きな値になってしまったものがあるようです。βの最大値は制約をかけてないので,すごく大きくなることがあります。これも気になります。
parameter_recovery_mle %>%
ggplot(aes(x = beta_true, y = beta_estimated)) +
geom_point()
確認しにくいので,10以下の推定値のみをプロットして確認します。
parameter_recovery_mle %>%
filter(beta_estimated < 10) %>%
ggplot(aes(x = beta_true, y = beta_estimated)) +
geom_point()
q_learning_mcmc <- q_learning$sample(
data = list(trial = nrow(data_1),
reward = data_1$reward,
choice = data_1$choice + 1),
seed = 123,
chains = 4,
iter_warmup = 500,
iter_sampling = 1000,
parallel_chains = 4)
結果の要約を確認してみましょう。
q_learning_mcmc$summary()
## # A tibble: 3 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -43.0 -42.6 1.19 0.889 -45.3 -41.8 1.00 1399. 1646.
## 2 alpha 0.511 0.489 0.178 0.185 0.250 0.841 1.00 2129. 1688.
## 3 beta 2.09 2.07 0.512 0.500 1.28 2.97 1.00 1531. 1724.
以下はMCMCについての診断結果です。
q_learning_mcmc$cmdstan_diagnose()
## Processing csv files: /tmp/RtmpJJ3WAj/q_learning-202203310003-1-4c61b3.csv, /tmp/RtmpJJ3WAj/q_learning-202203310003-2-4c61b3.csv, /tmp/RtmpJJ3WAj/q_learning-202203310003-3-4c61b3.csv, /tmp/RtmpJJ3WAj/q_learning-202203310003-4-4c61b3.csv
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
trace plot と事後分布をプロットしてみましょう。trace plotはMCMCのチェーンがきれいにまざっているか確認をしましょう。
# ggplotやtidyverseで扱いやすく処理する
mcmc_samples = as_draws_df(q_learning_mcmc$draws())
# alphaのtrace plot
mcmc_samples %>%
mutate(chain = as.factor(.chain)) %>%
ggplot(aes(x = .iteration, y = alpha, group = .chain, color = chain)) +
geom_line()
# betaのtrace plot(ごく少数の大きな値でプロットできないので10以下に絞った)
mcmc_samples %>%
mutate(chain = as.factor(.chain)) %>%
filter(beta < 10) %>%
ggplot(aes(x = .iteration, y = beta, group = .chain, color = chain)) +
geom_line()
# alphaの事後分布
mcmc_samples %>%
ggplot() +
geom_histogram(aes(x=alpha),binwidth = 0.01)
# betaの事後分布(ごく少数の大きな値があるとプロットできないので6以下に絞った)
mcmc_samples %>%
filter(beta < 6) %>%
ggplot() +
geom_histogram(aes(x=beta),binwidth = 0.1)
ベイズ推定でもパラメータリカバリをしてみましょう。ただ,これは結構計算に時間がかかると思います。
alpha_true <- NULL
beta_true <- NULL
alpha_estimated <- NULL
beta_estimated <- NULL
beta_set <- 0
for (i in 1:10) {
alpha_set <- 0
beta_set = beta_set + 0.5
for (j in 1:10) {
alpha_set = alpha_set + 0.1
#データ生成
data_2 <- q_learning_sim(alpha = alpha_set, beta = beta_set, sim_data)
alpha_true[(i-1)*10 + j] <- alpha_set
beta_true[(i-1)*10 + j] <- beta_set
print(paste("進捗状況:",(i-1)*10 + j,"/100"))
#パラメータ推定(推定がミスった時用にtryCatch関数を準備)
tryCatch({
q_learning_mcmc <- q_learning$sample(
data = list(trial = nrow(data_2),
reward = data_2$reward,
choice = data_2$choice + 1),
seed = 123,
chains = 4,
iter_warmup = 500,
iter_sampling = 1000,
parallel_chains = 4)
mcmc_samples = as_draws_df(q_learning_mcmc$draws())
alpha_estimated[(i-1)*10 + j] <- mean(mcmc_samples$alpha)
beta_estimated[(i-1)*10 + j] <- mean(mcmc_samples$beta)
},error = function(e) {message(e)})
}
}
parameter_recovery_mcmc <- data.frame(alpha_true = alpha_true,
beta_true = beta_true,
alpha_estimated = alpha_estimated,
beta_estimated = beta_estimated)
パラメータリカバリーのチェックをしましょう。散布図を書いて,真値(研究者がデータ生成時に設定した値)と推定された値が強い相関を示しているか確認します(データ生成時や推定時に確率的な変動が生じるので,完全一致はありません)。最尤推定と同様に,真値と結構ずれて低い推定値などがあって気になるところです。
parameter_recovery_mcmc %>%
ggplot(aes(x = alpha_true, y = alpha_estimated)) +
geom_point()
最尤推定同様に,βがすごく大きくなることがあるのは,気になるところです。
parameter_recovery_mcmc %>%
ggplot(aes(x = beta_true, y = beta_estimated)) +
geom_point()
上記の最尤推定や無情報の事前分布の場合だと,ベータがとても大きな値になる点が気にかかるところです。そこで事前分布にも情報をもたせることにします。最尤推定にはないベイズ推定の特徴は,事前分布と尤度からパラメータの事後分布を推定する点です。例えば,逆転学習課題でのパタメータ推定において,おおよそパラメータの分布が分かっていると,それを事前分布に使って,推定を安定化させることができます。例えば,Kanen et al.(2019)では,学習率αはden Ouden et al.(2013)を参考にベータ分布を事前分布に用い,逆温度βはGershman(2016)を参考にガンマ分布を用いています(なお,Karen et al.(2019)で用いるのは逆温度ではなく報酬感受性ですが,意味は同じです)。
alpha = seq(0,1, length=100)
plot(alpha, dbeta(alpha, 1.2, 1.2), ylab="density", type ="l", col=4)
beta = seq(0,15, length=500)
plot(beta, dgamma(beta, 4.82, 0.88), ylab="density", type ="l", col=4)
上記を踏まえて,αの事前分布にベータ分布,βの事前分布にガンマ分布を仮定したStanコードを書いて,q_learning_prior.stanという名前で保存します。
data {
int<lower=1> trial;
int<lower=1,upper=2> choice[trial]; // 1 or 2
int<lower=0,upper=1> reward[trial]; // 0 or 1
}
parameters {
real<lower=0.0,upper=1.0> alpha; //学習率
real<lower=0.0> beta; //逆温度
}
model {
matrix[trial,2] Q;
Q[1, 1] = 0;
Q[1, 2] = 0;
//学習率の事前分布にベータ分布,逆温度の事前分布にガンマ分布
alpha ~ beta(1.2, 1.2);
beta ~ gamma(4.82, 0.88);
for ( t in 1:trial) {
// 対数尤度を足す
target += log(exp(beta*Q[t,choice[t]])/(exp(beta*Q[t,choice[t]])+exp(beta*Q[t,3-choice[t]])));
if (t < trial) {
// 選択された選択肢のQ値の更新
Q[t+1,choice[t]] = Q[t, choice[t]] + alpha * (reward[t] - Q[t, choice[t]]);
// 選択されなかった選択肢は更新しない
Q[t+1, 3- choice[t]] = Q[t, 3- choice[t]];
}
}
}
q_learning_prior.stanを使って,パラメータリカバリを実施してみます。
まず,上記のモデルをコンパイルします。
q_learning_prior <- cmdstan_model('q_learning_prior.stan')
まず,先程のdata_1で試してみます。
q_learning_prior_mcmc <- q_learning_prior$sample(
data = list(trial = nrow(data_1),
reward = data_1$reward,
choice = data_1$choice + 1),
seed = 123,
chains = 4,
iter_warmup = 500,
iter_sampling = 1000,
parallel_chains = 4)
結果の要約を確認してみましょう。
q_learning_prior_mcmc$summary()
## # A tibble: 3 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -42.0 -41.6 1.15 0.820 -44.4 -40.9 1.00 1313. 1796.
## 2 alpha 0.479 0.461 0.162 0.160 0.242 0.778 1.00 2206. 1671.
## 3 beta 2.33 2.32 0.486 0.475 1.55 3.16 1.00 2084. 2177.
以下はMCMCについての診断結果です。
q_learning_prior_mcmc$cmdstan_diagnose()
## Processing csv files: /tmp/RtmpJJ3WAj/q_learning_prior-202203310009-1-89764c.csv, /tmp/RtmpJJ3WAj/q_learning_prior-202203310009-2-89764c.csv, /tmp/RtmpJJ3WAj/q_learning_prior-202203310009-3-89764c.csv, /tmp/RtmpJJ3WAj/q_learning_prior-202203310009-4-89764c.csv
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
それでは,先程と同じデータでパラメータリカバリーをしてみましょう!
alpha_true <- NULL
beta_true <- NULL
alpha_estimated <- NULL
beta_estimated <- NULL
beta_set <- 0
for (i in 1:10) {
alpha_set <- 0
beta_set = beta_set + 0.5
for (j in 1:10) {
alpha_set = alpha_set + 0.1
#データ生成
data_2 <- q_learning_sim(alpha = alpha_set, beta = beta_set, sim_data)
alpha_true[(i-1)*10 + j] <- alpha_set
beta_true[(i-1)*10 + j] <- beta_set
print(paste("進捗状況:",(i-1)*10 + j,"/100"))
#パラメータ推定(推定がミスった時用にtryCatch関数を準備)
tryCatch({
q_learning_prior_mcmc <- q_learning_prior$sample(
data = list(trial = nrow(data_2),
reward = data_2$reward,
choice = data_2$choice + 1),
seed = 123,
chains = 4,
iter_warmup = 500,
iter_sampling = 1000,
parallel_chains = 4)
mcmc_samples = as_draws_df(q_learning_prior_mcmc$draws())
alpha_estimated[(i-1)*10 + j] <- mean(mcmc_samples$alpha)
beta_estimated[(i-1)*10 + j] <- mean(mcmc_samples$beta)
},error = function(e) {message(e)})
}
}
parameter_recovery_prior_mcmc <- data.frame(alpha_true = alpha_true,
beta_true = beta_true,
alpha_estimated = alpha_estimated,
beta_estimated = beta_estimated)
いい感じにパラメータリカバリーできています。無情報事前分布のときのようにすごく大きなβが発生しなくなっています。
parameter_recovery_prior_mcmc %>%
ggplot(aes(x = alpha_true, y = alpha_estimated)) +
geom_point()
parameter_recovery_prior_mcmc %>%
ggplot(aes(x = beta_true, y = beta_estimated)) +
geom_point()
「強化学習モデルを使ったモデル・フィッティング1」と同じデータを使います。以下のコードで5名分のlong形式のデータセットを準備します。
setwd("data")
file_names <- list.files()
setwd("..")
# 確認用の図を入れる場所を確保
plot_check <- NULL
# データを入れる場所を確保
data_long <- NULL
for (i in 1:length(file_names)) {
# file_namesのi番目のデータを読み込んで,上記の処理をして,tmp_dataに格納
tmp_data <- read_csv(paste("data",file_names[i], sep = "/")) %>%
filter(trial_type == "html-button-response") %>%
mutate(id = rep(i,80),
trial = 1:80,
s1_prob = rep(c(0.2,0.8),each = 40),
reward = ifelse(button_pressed == 0, reward_s1, reward_s2)) %>%
select(id, trial,choice=button_pressed, rt, reward, s1_prob,reward_s1, reward_s2)
# データの保存
data_long <- rbind(data_long, tmp_data)
# plot
plot_check[[i]] <- ggplot(tmp_data, aes(x = trial, y = s1_prob)) +
geom_line() +
geom_line(aes(x= trial, y = choice), colour = 'blue') +
geom_point(aes(x = trial, y = reward),colour = 'red')
}
Sub01
data_individual <- data_long %>%
filter(id == 1)
q_learning_mcmc <- q_learning$sample(
data = list(trial = nrow(data_individual),
reward = data_individual$reward,
choice = data_individual$choice + 1),
seed = 123,
chains = 4,
iter_warmup = 500,
iter_sampling = 1000,
parallel_chains = 4)
q_learning_mcmc$summary()
q_learning_mcmc$cmdstan_diagnose()
結果のプロット
q_learning_mcmc <- as_draws_df(q_learning_mcmc$draws())
# alphaのtrace plot
q_learning_mcmc %>%
mutate(chain = as.factor(.chain)) %>%
ggplot(aes(x = .iteration, y = alpha, group = .chain, color = chain)) +
geom_line()
# betaのtrace plot
q_learning_mcmc %>%
mutate(chain = as.factor(.chain)) %>%
filter(beta < 1000) %>%
ggplot(aes(x = .iteration, y = beta, group = .chain, color = chain)) +
geom_line()
# alphaの事後分布
q_learning_mcmc %>%
ggplot() +
geom_histogram(aes(x=alpha),binwidth = 0.01)
# betaの事後分布
q_learning_mcmc %>%
filter(beta < 1000) %>%
ggplot() +
geom_histogram(aes(x=beta),binwidth = 0.1)
q_values <- q_learning_ll(mean(mcmc_samples$alpha),mean(mcmc_samples$beta),data_individual)
sub01
data_individual <- data_long %>%
filter(id == 1)
q_learning_prior_mcmc <- q_learning_prior$sample(
data = list(trial = nrow(data_individual),
reward = data_individual$reward,
choice = data_individual$choice + 1),
seed = 123,
chains = 4,
iter_warmup = 500,
iter_sampling = 1000,
parallel_chains = 4)
q_learning_prior_mcmc$summary()
q_learning_prior_mcmc$cmdstan_diagnose()
結果のプロット
q_learning_prior_mcmc <- as_draws_df(q_learning_prior_mcmc$draws())
# alphaのtrace plot
q_learning_prior_mcmc %>%
mutate(chain = as.factor(.chain)) %>%
ggplot(aes(x = .iteration, y = alpha, group = .chain, color = chain)) +
geom_line()
# betaのtrace plot
q_learning_prior_mcmc %>%
mutate(chain = as.factor(.chain)) %>%
ggplot(aes(x = .iteration, y = beta, group = .chain, color = chain)) +
geom_line()
# alphaの事後分布
q_learning_prior_mcmc %>%
ggplot() +
geom_histogram(aes(x=alpha),binwidth = 0.01)
# betaの事後分布
q_learning_prior_mcmc %>%
ggplot() +
geom_histogram(aes(x=beta),binwidth = 0.1)
q_values <- q_learning_ll(mean(mcmc_samples$alpha),mean(mcmc_samples$beta),data_individual)
sub01のトレースプロットや事後分布をみる限りでは,有情報事前分布の方が推定がうまくいっているように見えます。
さて,ここまでで,ベイズ推定が出来るようになりました。次は「強化学習モデル: ベイズ推定(2)」で,モデル比較などに取り組みます。