Splits a data set into exploration (training), validation, and inference subsamples, stratifying by exposure level.
Required Libraries: caret, dplyr
split_dataset(data, num_exposure_cats, treat_col, stratify_cols)
data: A data.frame with a variable "treat" corresponding to the exposure level.
num_exposure_cats: The number of categories into which the exposure level is binned for stratification.
treat_col: Name of the treatment (exposure) column in data.
stratify_cols: Vector of column names used for stratification.
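To make the roles of num_exposure_cats and stratify_cols more concrete, here is a minimal base-R sketch of one way strata could be formed. It assumes that the exposure is cut into num_exposure_cats equal-width bins and crossed with every stratification column. The helper make_strata() is hypothetical and not part of the package, and split_dataset() may bin or combine the columns differently.

```r
# Hypothetical helper (not the package's code): one stratum label per row,
# built by crossing the binned exposure with the stratification columns.
make_strata <- function(data, treat_col, num_exposure_cats, stratify_cols) {
  # A single number as `breaks` gives that many equal-width intervals;
  # cut() widens the outer limits slightly, so no value becomes NA
  exposure_bin <- cut(data[[treat_col]], breaks = num_exposure_cats)
  # interaction() accepts a single list of factors; drop = TRUE removes
  # combinations that never occur in the data
  interaction(c(list(exposure_bin), as.list(data[stratify_cols])), drop = TRUE)
}

# Usage on hypothetical toy data
set.seed(1)
toy <- data.frame(treat = runif(100),
                  em1 = rbinom(100, 1, 0.5),
                  em2 = rbinom(100, 1, 0.5))
strata <- make_strata(toy, "treat", 5, c("em1", "em2"))
print(table(strata))
```

Crossing all columns keeps each exposure level balanced within every effect-modifier combination. The trade-off is that the number of strata grows multiplicatively, so small data sets can end up with very sparse strata.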
Returns a data frame containing the original data plus an additional column named subsample, which assigns each row to one of three categories: "inference", "validation", or "exploration".
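Assigning rows to the three subsample labels within strata can be sketched in base R as follows. This sketch assumes that a stratum label is already available for each row and that the split uses equal one-third proportions. The function assign_subsample() and its props argument are hypothetical illustrations, not the package's API.

```r
# Hypothetical sketch: shuffle rows within each stratum, then hand out the
# subsample labels in proportion so each stratum contributes to every subsample.
assign_subsample <- function(strata,
                             props = c(exploration = 1/3,
                                       validation = 1/3,
                                       inference = 1/3)) {
  subsample <- character(length(strata))
  for (idx in split(seq_along(strata), strata)) {
    # sample.int() avoids sample()'s surprise when idx has length 1
    idx <- idx[sample.int(length(idx))]
    # Enough labels to cover the stratum, truncated to its exact size
    labels <- rep(names(props), times = ceiling(props * length(idx)))
    subsample[idx] <- labels[seq_along(idx)]
  }
  subsample
}

# Usage on a hypothetical stratum factor
set.seed(42)
strata <- factor(rep(c("a", "b", "c"), times = c(30, 31, 9)))
sub <- assign_subsample(strata)
print(table(strata, sub))
```

Labels are handed out in a fixed order, so in strata whose size is not a multiple of three, the later labels ("inference" here) can receive slightly fewer rows. A stratum with a single row always goes to "exploration".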
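# Pipe operator from magrittr
`%>%` <- magrittr::`%>%`

# Simulate a data set with effect modifiers em1-em4
data <- generate_syn_data_het(
  sample_size = 500, outcome_type = "continuous",
  gps_spec = 1, em_spec = 1, cova_spec = 1,
  heterogenous_intercept = FALSE,
  em_as_confounder = FALSE, outcome_sd = 1, beta = 0.3)

treat_col <- "treat"
num_exposure_cats <- 5

# num_exposure_cats evenly spaced breakpoints spanning the observed exposure range
a.vals <- seq(min(data[[treat_col]]),
              max(data[[treat_col]]),
              length.out = num_exposure_cats)

# Label each row with its exposure interval; include.lowest = TRUE keeps the
# row at the minimum exposure from being assigned NA
data <-
  data %>%
  dplyr::mutate(treat_level = cut(treat, breaks = a.vals, include.lowest = TRUE))

# Split into exploration, validation, and inference subsamples
sp_data <- split_dataset(data = data,
                         num_exposure_cats = num_exposure_cats,
                         treat_col = treat_col,
                         stratify_cols = c("em1", "em2", "em3", "em4"))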
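The cut() call in the example turns the breakpoints in a.vals into exposure intervals. The base-R snippet below uses hypothetical toy values to illustrate two properties of cut(). First, k breakpoints produce k - 1 intervals. Second, because the intervals are right-closed by default, the smallest value is returned as NA unless include.lowest = TRUE is set.

```r
# Hypothetical toy exposure values spanning 0 to 10
x <- c(0, 2.5, 5, 7.5, 10)

# Five evenly spaced breakpoints, as produced by seq(min, max, length.out = 5)
breaks <- seq(min(x), max(x), length.out = 5)

# Default cut(): right-closed intervals, so the minimum falls outside all of them
default_bins <- cut(x, breaks = breaks)
print(nlevels(default_bins))   # 4 intervals from 5 breakpoints
print(is.na(default_bins[1]))  # TRUE: x = 0 is not in (0, 2.5]

# include.lowest = TRUE closes the first interval on the left
closed_bins <- cut(x, breaks = breaks, include.lowest = TRUE)
print(anyNA(closed_bins))      # FALSE

# k + 1 breakpoints are needed to obtain k intervals
five_bins <- cut(x, breaks = seq(min(x), max(x), length.out = 5 + 1),
                 include.lowest = TRUE)
print(nlevels(five_bins))      # 5
```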