#---- R environment ----#
# CRAN packages used by this script
# ("tools" ships with R itself, so it is never reported as missing)
pkgs <- c(
  "tools",
  "bayesplot",
  "posterior",
  "tibble",
  "tidyr",
  "tidybayes",
  "dplyr",
  "ggplot2"
)

# install only the CRAN packages that are not available yet
pkgs_missing <- setdiff(pkgs, rownames(installed.packages()))
if (length(pkgs_missing) > 0) {
  install.packages(pkgs_missing)
}

# cmdstanr is not distributed on CRAN, so it is installed from the Stan
# r-universe repository when it is missing.
# NOTE: CmdStan itself must also be installed (cmdstanr::install_cmdstan())
# before the model further down can be compiled.
if (!requireNamespace("cmdstanr", quietly = TRUE)) {
  install.packages(
    "cmdstanr",
    repos = c("https://stan-dev.r-universe.dev", getOption("repos"))
  )
}

# load libraries
library(cmdstanr)
library(bayesplot)
library(posterior)
library(tools)
library(tibble)
library(tidyr)
library(tidybayes)
library(dplyr)
library(ggplot2)

# output directory (created under the current working directory);
# showWarnings = FALSE keeps re-runs quiet when it already exists
outdir <- file.path(getwd(), "synthetic_controls")
dir.create(outdir, recursive = TRUE, showWarnings = FALSE)
#---- Raw data ----#
# local copy of the migration-flow CSV and the URL it is fetched from
dat_filename <- file.path(outdir, "international_migration_flow.csv")
dat_url <- "https://data.humdata.org/dataset/e09595bd-f4c5-4a66-8130-5a05f14d5e64/resource/67b2d6fc-ff4f-4d04-a75e-80e3de81b072/download/international_migration_flow.csv"

# download the data if we don't already have it.
# The file is first written to a ".part" path and only renamed once the
# download succeeded, so an interrupted download can never be mistaken for a
# complete data file on the next run.
if (!file.exists(dat_filename)) {
  dat_partial <- paste0(dat_filename, ".part")
  download_status <- tryCatch(
    download.file(dat_url, destfile = dat_partial),
    error = function(e) {
      unlink(dat_partial)
      stop(
        "Failed to download migration data: ", conditionMessage(e),
        call. = FALSE
      )
    }
  )
  # download.file() returns 0 on success
  if (download_status != 0 || !file.rename(dat_partial, dat_filename)) {
    unlink(dat_partial)
    stop("Failed to download migration data to ", dat_filename, call. = FALSE)
  }
}

# load the data into R
dat <- read.csv(dat_filename)
#---- Model data ----#
# NOTE(review): `origin`, `destination`, `donors`, `months` and
# `last_pretreatment_month` are read below but are not defined in this part of
# the script; they must exist before this section runs.

# get treated (y) data: flows from origin to destination, one value per month
treated_flows <- dat %>%
  filter(country_from == origin, country_to == destination)
# Index with match() rather than slice(): slice() drops NA positions, so a
# month without a record would silently shorten y. Here such a month becomes
# NA and is then set to 0, keeping length(y_treated) == length(months).
y_treated <- replace_na(
  treated_flows$num_migrants[match(months, treated_flows$migration_month)],
  0
)

# get donor (x) data: flows from other countries to destination,
# reshaped to one row per month and one column per donor
donor_wide <- dat %>%
  filter(country_to == destination, country_from %in% donors) %>%
  pivot_wider(
    id_cols = migration_month,
    names_from = country_from,
    values_from = num_migrants,
    values_fill = 0
  ) %>%
  as.data.frame()
# rows follow `months`, columns follow `donors`; months without any donor
# record become rows of zeros instead of being dropped
x_donors <- as.matrix(
  donor_wide[match(months, donor_wide$migration_month), donors, drop = FALSE]
)
x_donors[is.na(x_donors)] <- 0
rownames(x_donors) <- as.character(months)

# identify pre-treatment months
months_pre_treat <- months[months <= last_pretreatment_month]

# get the scaling factor
# (i.e. max of each donor in pre-treatment period);
# drop = FALSE keeps a matrix even with a single donor or a single month
scalar <- apply(x_donors[months_pre_treat, , drop = FALSE], 2, max)
unscalable <- !is.finite(scalar) | scalar <= 0
if (any(unscalable)) {
  # dividing by zero would put NaN/Inf into the Stan data
  stop(
    "Donors without positive pre-treatment flows cannot be scaled: ",
    paste(colnames(x_donors)[unscalable], collapse = ", "),
    call. = FALSE
  )
}

# scale donor data (each column divided by its pre-treatment maximum)
x_donors_scaled <- sweep(x_donors, 2, scalar, FUN = "/")

# draw a seed and apply it; the seed is stored in the model data below so the
# run can be repeated later with the same value
seed <- round(runif(1, 1, 1e6))
set.seed(seed)

# index of the last pre-treatment month within the model time axis
t0_index <- which(rownames(x_donors) == last_pretreatment_month)
if (length(t0_index) != 1) {
  stop(
    "`last_pretreatment_month` must match exactly one element of `months`.",
    call. = FALSE
  )
}

# stan model data object
md <- list(
  T = length(months),
  T0 = t0_index,
  K = length(donors),
  y = y_treated,
  x = x_donors_scaled,
  seed = seed
)

# save data to disk
saveRDS(
  object = md,
  file = file.path(outdir, "md.rds")
)
#---- Stan model ----#
# Stan program, kept as a string and written to `outdir` so cmdstanr can
# compile it from a file.
# Changes relative to the earlier version of the program:
#   * phi is declared with <lower=0>: it has an exponential prior and is the
#     dispersion of neg_binomial_2, both of which require a positive value.
#   * the scalar exp(alpha) multiplies the vector x * weights with `*`.
stan_code <- "
// bayesian synthetic controls
// negative binomial likelihood
// horseshoe shrinkage
data {
  int<lower=1> T; // Total number of months
  int<lower=1> T0; // End of pre-treatment period
  int<lower=1> K; // Number of donor countries
  array[T] int y; // Treated flow migration counts
  matrix[T, K] x; // Donor country migration counts/rates
}
parameters {
  real alpha;
  real<lower=0> phi;
  vector<lower=0>[K] weights_raw;
  vector<lower=0>[K] local_scales;
  real<lower=0> global_scale;
}
transformed parameters {
  vector[T] mu;
  vector<lower=0>[K] weights_shrunk;
  simplex[K] weights;
  weights_shrunk = weights_raw .* local_scales * global_scale;
  weights = weights_shrunk / sum(weights_shrunk);
  mu = exp(alpha) * (x * weights);
}
model {
  // likelihood (pre-treatment months only)
  y[1 : T0] ~ neg_binomial_2(mu[1 : T0], phi);
  // priors
  alpha ~ normal(0, 10);
  phi ~ exponential(0.1);
  weights_raw ~ std_normal();
  local_scales ~ cauchy(0, 1);
  global_scale ~ cauchy(0, 1);
}
generated quantities {
  array[T] int y_synthetic;
  array[T] int effect;
  for (t in 1 : T) {
    y_synthetic[t] = neg_binomial_2_rng(mu[t], phi);
    effect[t] = y[t] - y_synthetic[t];
  }
}
"
writeLines(stan_code, con = file.path(outdir, "shrinkage.stan"))
#---- MCMC ----#
#' Generate initial values for one MCMC chain
#'
#' The names of the returned list are the entries of the Stan `parameters`
#' block (alpha, phi, weights_raw, local_scales, global_scale). `mu` is a
#' transformed parameter in that program, so the log-scale starting value is
#' supplied for `alpha` instead.
#'
#' @param md Model data list; `y`, `T0` and `K` are read.
#' @param chain_id Chain identifier. Accepted for interface compatibility;
#'   it is not used when generating the values.
#' @return A named list of initial values for a single chain.
init_generator <- function(md, chain_id = 1) {
  # mean of the treated series over the first T0 (pre-treatment) time points
  pre_treatment_mean <- mean(md$y[seq_len(md$T0)])
  list(
    alpha = rnorm(1, log(pre_treatment_mean), 0.1),
    phi = runif(1, 5, 15),
    weights_raw = rep(1, md$K),
    local_scales = rep(1, md$K),
    global_scale = 1
  )
}
# one list of initial values per chain (four chains)
inits <- lapply(
  1:4,
  function(chain) init_generator(md = md, chain_id = chain)
)
# compile the Stan program that was written to the output directory
mod <- cmdstan_model(file.path(outdir, "shrinkage.stan"))

# draw posterior samples
fit_sh <- mod$sample(
  seed = md$seed,        # seed stored with the model data
  data = md,             # model data list
  init = inits,          # per-chain initial values
  iter_warmup = 1000,    # warmup iterations per chain (discarded)
  iter_sampling = 2000,  # post-warmup iterations per chain (kept)
  parallel_chains = 4    # chains allowed to run at the same time
)

# store the fitted model object on disk
fit_sh$save_object(file.path(outdir, "fit_shrinkage.rds"))
#---- Diagnostics ----#
# get and save fit summary
fit_summary <- fit_sh$summary()
write.csv(fit_summary, file.path(outdir, "fit_summary.csv"))

# posterior parameter estimates
print(fit_summary)

# check convergence for all parameters
# rhat < 1.01 means the chains converged (< 1.1 is okay for testing purposes)
# The table is printed explicitly so it is also shown when the script is run
# through source(), where top-level values are not auto-printed.
not_converged <- fit_summary %>%
  filter(rhat > 1.01) %>%
  select(variable, rhat)
print(not_converged)

# traceplots, written to a multi-page PDF
draws <- fit_sh$draws()
params <- fit_sh$metadata()$variables

# page layout: plots_per_page panels arranged in plot_cols columns
plots_per_page <- 12
plot_cols <- 3

pdf(
  file.path(outdir, "traceplots.pdf"),
  width = 11,
  height = 8.5
)
for (i in seq(1, length(params), by = plots_per_page)) {
  # last parameter on this page; derived from plots_per_page so that changing
  # the page size cannot get out of step with the batching
  batch_end <- min(i + plots_per_page - 1, length(params))
  current_batch <- params[i:batch_end]
  p <- mcmc_trace(draws, pars = current_batch) +
    facet_wrap(
      ~parameter,
      ncol = plot_cols,
      nrow = ceiling(plots_per_page / plot_cols),
      scales = "free_y"
    ) +
    ggtitle(paste("Parameters", i, "to", batch_end))
  print(p)
}
dev.off()
#---- assess coverage ----#
# 95% posterior predictive intervals (2.5% and 97.5% quantiles, with the
# median as point estimate) for each time index, pre-treatment period only
predictive_intervals <- fit_sh$draws("y_synthetic", format = "df") %>%
  spread_draws(y_synthetic[t]) %>%
  median_qi(y_synthetic, .width = 0.95) %>%
  filter(t <= md$T0)

# join with observed data and flag observations inside the interval
coverage_df <- predictive_intervals %>%
  mutate(
    observed = md$y[t],
    is_covered = observed >= .lower & observed <= .upper
  )

# calculate coverage (percentage of pre-treatment points inside the interval)
coverage_rate <- mean(coverage_df$is_covered) * 100
print(paste0("Pre-treatment Coverage: ", round(coverage_rate, 2), "%"))

# visualise coverage
coverage_plot <- ggplot(coverage_df, aes(x = t)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2, fill = "blue") +
  geom_line(aes(y = y_synthetic), color = "blue", linetype = "dashed") +
  geom_point(aes(y = observed, color = is_covered)) +
  scale_color_manual(values = c("TRUE" = "black", "FALSE" = "red")) +
  labs(
    title = paste(
      "Interval Coverage Diagnostic (",
      round(coverage_rate, 1),
      "%)"
    ),
    subtitle = "Red points fall outside the 95% prediction interval in the pre-treatment period",
    x = "Month Index",
    y = "Migration Count"
  ) +
  theme_minimal()

pdf(
  file.path(outdir, "coverage_plot.pdf"),
  width = 11,
  height = 8.5
)
# print() is required for the plot to reach the PDF when the script is run
# through source(); a bare ggplot expression is only drawn when auto-printed
print(coverage_plot)
dev.off()
#---- synthetic controls plot ----#
# posterior predictive draws of y_synthetic, one column per time index
draws_df <- fit_sh$draws(variables = "y_synthetic", format = "df")

# per time index: posterior mean and the 2.5% / 97.5% quantiles
synth_draws <- draws_df %>%
  pivot_longer(
    cols = starts_with("y_synthetic"),
    names_to = "parameter",
    values_to = "value"
  ) %>%
  # the time index is the number inside the square brackets of the name
  mutate(t = as.integer(gsub(".*\\[(\\d+)\\]", "\\1", parameter))) %>%
  group_by(t) %>%
  summarise(
    y_hat = mean(value),
    lower = quantile(value, 0.025),
    upper = quantile(value, 0.975),
    .groups = "drop"
  ) %>%
  mutate(date = months[t])

# date marking the end of the pre-treatment period (first day of that month)
treatment_date <- as.Date(paste0(months[md$T0], "-01"))

# observed series, keyed by the row names of the donor matrix
real_data <- data.frame(date = row.names(md$x), y_obs = md$y)

# combine predictions and observations; `diff` is observed minus predicted
plot_df <- synth_draws %>%
  left_join(real_data, by = "date") %>%
  mutate(
    date = as.Date(paste0(date, "-01")),
    diff = y_obs - y_hat
  )
## plot reality and synthetic control
# observed series and posterior mean prediction with its 95% band, on a log
# y axis, with a vertical line at the end of the pre-treatment period
synth_plot <- ggplot(plot_df, aes(x = date)) +
  geom_ribbon(aes(ymin = lower, ymax = upper), fill = "skyblue", alpha = 0.4) +
  geom_line(aes(y = y_hat, color = "Synthetic Control"), linetype = "dashed") +
  geom_line(aes(y = y_obs, color = "Actual Data")) +
  geom_vline(xintercept = treatment_date, color = "black") +
  scale_y_log10(labels = scales::comma) + # This makes the ribbon visible!
  labs(
    title = "Migration reality and synthetic control timeseries",
    subtitle = paste(origin, "to", destination)
  ) +
  ylab("Migration flows (log scale)") +
  theme_minimal()

pdf(
  file.path(outdir, "synthetic_controls_plot.pdf"),
  width = 11,
  height = 8.5
)
# print() is required for the plot to reach the PDF when the script is run
# through source(); a bare ggplot expression is only drawn when auto-printed
print(synth_plot)
dev.off()
## plot difference
# observed minus synthetic: the line uses the posterior mean prediction, the
# band uses the 2.5% / 97.5% prediction quantiles
effect_plot <- ggplot(plot_df, aes(x = date)) +
  geom_ribbon(
    aes(ymin = y_obs - upper, ymax = y_obs - lower),
    fill = "firebrick",
    alpha = 0.4
  ) +
  geom_line(aes(y = diff)) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_vline(xintercept = treatment_date, color = "black") +
  labs(
    title = "Estimated Treatment Effect (Actual - Synthetic)",
    subtitle = paste(origin, "to", destination)
  ) +
  theme_minimal()

pdf(
  file.path(outdir, "treatment_effect_plot.pdf"),
  width = 11,
  height = 8.5
)
# print() is required for the plot to reach the PDF when the script is run
# through source(); a bare ggplot expression is only drawn when auto-printed
print(effect_plot)
dev.off()
#---- check shrinkage (i.e. dominant donors) ----#
# number of highest-weighted donors to report
top_x <- 10

# posterior summary of the donor weights: attach donor names by position,
# order by posterior mean, accumulate the weight and keep the first top_x rows
weights_summary <- fit_sh$summary("weights") %>%
  mutate(donor_index = 1:n()) %>%
  mutate(donor_name = donors[donor_index]) %>%
  arrange(desc(mean)) %>%
  mutate(cumulative_weight = cumsum(mean)) %>%
  slice_head(n = top_x)

# bar chart of posterior mean weights; error bars span the q5-q95 columns of
# the summary and each bar is labelled with its rounded mean
p <- ggplot(weights_summary, aes(x = reorder(donor_name, -mean), y = mean)) +
  geom_bar(stat = "identity", fill = "steelblue", alpha = 0.8) +
  geom_errorbar(
    aes(ymin = q5, ymax = q95),
    color = "firebrick",
    width = 0.2
  ) +
  geom_text(aes(label = round(mean, 3)), vjust = -1.2, size = 3.5) +
  labs(
    title = paste("Top", top_x, "donors for synthetic", origin, "->", destination),
    subtitle = paste0(
      "Total weight of top ", top_x, " donors: ",
      round(max(weights_summary$cumulative_weight) * 100, 1),
      "%"
    ),
    x = NULL, # donor names on the axis need no extra label
    y = "Posterior Weight (Mean)"
  ) +
  theme_minimal() +
  theme(
    panel.grid.major.x = element_blank(),
    axis.text.x = element_text(angle = 45, hjust = 1, size = 10)
  )

# write the chart to a PDF in the output directory
pdf(
  file.path(outdir, "top_donors_plot.pdf"),
  width = 11,
  height = 8.5
)
print(p)
dev.off()