The post Evolve new colour palettes in R with evoPalette appeared first on Daniel Oehm | Gradient Descending.
What’s that saying? Choice is the enemy of happiness, or something like that. There are so many colour palettes out there that the saying tends to resonate when either choosing one or creating a new one. So, I started to think about a way I could make this more organic, and came up with the idea for evoPalette.
evoPalette allows for the discovery and generation of new colour palettes for R. Taking user input it uses an evolutionary algorithm to spawn new palettes with the expectation the colours you like will be combined over time. This can be repeated as many times as needed until you find something that works for you.
There’s a lot to colour theory, no choice is trivial, and so leaving some of those choices largely to an algorithm will have varying success. This is trading best practice for fun and discovery but I’m okay with that!
Very briefly, evolutionary algorithms operate in 3 main steps:
The process can continue for as long as needed.
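As a rough illustration of such a loop, the sketch below uses the steps most commonly associated with evolutionary algorithms: pick parents, cross them over, and mutate the children. It is a minimal base R sketch and not evoPalette's actual implementation; crossover, mutate_colour, evolve and the default rates are all made-up names and values.

```r
# Hypothetical helpers for illustration only; not evoPalette internals.

# Crossover: each position in the child comes from one of the two parents
crossover <- function(p1, p2) {
  stopifnot(length(p1) == length(p2))
  pick <- sample(c(TRUE, FALSE), length(p1), replace = TRUE)
  ifelse(pick, p1, p2)
}

# Mutation: jitter the RGB channels of a single hex colour
mutate_colour <- function(hex, sd = 10) {
  rgb_vals <- col2rgb(hex)[, 1] + rnorm(3, 0, sd)
  rgb_vals <- pmin(pmax(round(rgb_vals), 0), 255)
  rgb(rgb_vals[1], rgb_vals[2], rgb_vals[3], maxColorValue = 255)
}

# One generation: sample two parents per child, cross over, then mutate
evolve <- function(parents, n_children = 6, mutation_rate = 0.2, sd = 10) {
  lapply(seq_len(n_children), function(i) {
    ids <- sample(length(parents), 2, replace = TRUE)
    child <- crossover(parents[[ids[1]]], parents[[ids[2]]])
    mutate_ids <- runif(length(child)) < mutation_rate
    child[mutate_ids] <- vapply(child[mutate_ids], mutate_colour,
                                character(1), sd = sd)
    child
  })
}

set.seed(1)
parents <- list(
  c("#C7B6C8", "#74B6BA", "#7783C4"),
  c("#B56D4E", "#F28E47", "#FCE269")
)
kids <- evolve(parents, n_children = 4)
```

Feeding a selection of kids back in as parents repeats the process for another generation.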
You can install evoPalette from Github with the following command.
devtools::install_github("doehm/evoPalette")
(It’s only on GitHub for the moment. We’ll see how it goes.)
To jump straight into it launch the Shiny app with
launch_evo_palette()
To generate the first set of palettes, click ‘evolve’. Using the default setting the initial palettes will be randomly drawn from a standard set saved in evoPalette and paletteer. Paletteer contains some 1700+ palettes from a number of packages.
Select one or more palettes that you like and think may work well together from the check box and click ‘evolve’. The next generation of palettes will be created.
With each generation you’ll notice the palettes converging to a similar theme based on your selections, along with some variation in each individual colour and a chance of a completely new colour. The features that you liked about one palette have a chance to be combined with other colours and flow through to the next generation. This isn’t guaranteed, since there is still a stochastic element at play, but you may discover another combination that works well for you. The selected parent palettes are shown in the sidebar for reference.
Selecting only a single parent is convenient for generating minor variations to tweak the colours.
Select ‘Example: Fill aesthetic’ to see how the palette works in practice (similarly for ‘Example: Colour aesthetic’).
To save a palette, select the desired palette from the check box and click ‘save’. A box will appear giving you the opportunity to change the randomly generated palette name, but why would you want to!
The palette is now accessible from palette_box() once the app is closed. Multiple palettes can be saved and are collected in the palette box; just save one at a time. The palette box is refreshed when starting a new session, so remember to save it to disk.
palette_box()
$instructive_customers
[1] "#C7B6C8" "#74B6BA" "#7783C4" "#B56D4E" "#F28E47" "#FCE269"
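Since the palette box only lives for the session, one way to keep it is to write it to an .rds file. The sketch below uses a plain named list as a stand-in for the output of palette_box() (assumed here to be a named list of hex vectors, as the printout suggests) and a temporary file path chosen for illustration.

```r
# Stand-in for palette_box(): a named list of hex colour vectors
pal_box <- list(
  instructive_customers = c("#C7B6C8", "#74B6BA", "#7783C4",
                            "#B56D4E", "#F28E47", "#FCE269")
)

# Write to disk and read back; the path is an arbitrary temporary location
path <- file.path(tempdir(), "palette_box.rds")
saveRDS(pal_box, path)
restored <- readRDS(path)
```

In a real session the same saveRDS call would take palette_box() in place of the stand-in list.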
To begin again, deselect all palettes, click ‘evolve’. New palettes are randomly selected beginning a new process.
A palette is viewed by using show_palette().
show_palette(palette_box()["instructive_customers"])
The continuous scales are simply created from the two colours on the ends and one in the middle (room for improvement).
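A minimal sketch of that idea in base R, assuming the middle colour is taken as the one at position ceiling(n / 2); the package may pick it differently.

```r
pal <- c("#C7B6C8", "#74B6BA", "#7783C4", "#B56D4E", "#F28E47", "#FCE269")

# Two end colours plus one from the middle act as anchors
anchors <- pal[c(1, ceiling(length(pal) / 2), length(pal))]

# Interpolate between the anchors to get a continuous ramp
ramp <- colorRampPalette(anchors)
cont <- ramp(9)
```

Using only three anchors keeps the ramp smooth, at the cost of ignoring the remaining colours in the palette.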
The user has some control over the algorithm by adjusting the mutation rate and variation allowed for each colour. Select the parameters from the drop down in the menu.
palette_box()
.When you are happy with your selections, click ‘evolve’ and begin a new process.
The generated colour palettes are easily added to ggplots using the scale functions scale_fill_evo and scale_colour_evo.
library(ggplot2)
library(dplyr) # for %>%
library(evoPalette)
mpg %>%
  ggplot(aes(x = displ, fill = class)) +
  geom_histogram() +
  scale_fill_evo("instructive_customers")
The first palette in the list is used if no name is given. The scales are also parameterised to reverse the palette and switch discrete / continuous.
Colour theory is pretty complex stuff so choosing a good palette isn’t easy, let alone evolving one. So, you’re going to have some hits and some misses. This is definitely more for fun seeing what you discover rather than finding the perfect palette. Having said that you could discover some gold!
There are best practices when choosing a palette for data visualisation depending on the context and what is to be shown. For example, people tend to respond to certain colours representing high / low, hot / cold or good / bad, and there are also colour blindness considerations. evoPalette won’t necessarily adhere to these ideals.
This is best for generating discrete colour palettes with 4-7 distinct colours. It will help you find colours which are to your liking and generate a palette which is unlikely to already exist. It’s fun to experiment and see what you come up with!
Some of the initial starting palettes I’ve created are shown below. A few were created with evoPalette, some from a clustering process. The 90’s rock palettes were created from the album covers. Sorry for doubling up on Alice in Chains.
library(purrr)     # map2
library(patchwork) # wrap_plots
g_evo <- readRDS(system.file("extdata/palettes.rds", package = "evoPalette"))
map2(g_evo$palette, g_evo$name, ~show_palette(.x, title = .y)) %>%
  wrap_plots(ncol = 2)
The post Some basics and intuition behind GAN’s in R and Python appeared first on Daniel Oehm | Gradient Descending.
Generative Adversarial Networks are great for generating something from essentially nothing and there are some interesting uses for them. Most uses are some sort of image processing. Nvidia’s GauGAN is an impressive example, giving the user an MS Paint-like interface and generating landscapes. You can give the beta a shot here.
I wanted to take a step back and use a small example to understand the basics and build some intuition behind GAN’s. There’s a tonne of information out there on how to fit a GAN to generate new hand drawn numbers, faces or Pokemon to varying success (the jury is still out as to whether or not the output can pass as Pokemon, but anyway). This isn’t the focus for this post. Instead, I’m going to simplify things further and use a GAN to generate sine waves.
Keras and Tensorflow are used for this analysis. While there are R libraries I personally find it easier using Python via reticulate for deep learning tasks. You can find the code on Github. All the code bits in this post refer to functions from this repo.
For this to run correctly you’ll need Anaconda and Python 3.6-3.7 installed with keras and Tensorflow, as well as the standard libraries (numpy, pandas, etc.). Along with Python you’ll need reticulate installed and configured to use the appropriate version of Python. In short, run py_config() to initialise Python for the session and py_available() to check if it’s all good to go. This can be tricky to set up and depends on how you’ve installed Python. If you have trouble, refer to the reticulate cheat sheet and documentation.
For training data I’m going to use two wave functions, $y_1 = \sin(2\pi \cdot 2x)$ and $y_2 = 3\sin(2\pi \cdot 10x)$ for $x \in [0, 1]$, with random noise added to throw in some minor variation.
get_training_data <- function(n = 200, m = 250, shape = list(a = c(1, 3), p = c(2, 10))) {
  mat <- matrix(NA, nrow = n, ncol = m)
  n_waves <- length(shape$a)
  for(k in 1:n){
    # alternate between the wave shapes: amplitude ak and frequency pk
    ak <- shape$a[(k-1) %% n_waves + 1]
    pk <- shape$p[(k-1) %% n_waves + 1]
    mat[k,] <- ak*sin(2*pi*seq(0, 1, length = m)*pk) + rnorm(m, 0, 0.05)
  }
  mat
}
train <- get_training_data()
plot_waves(train)
Nothing too complicated, just two distinct waves. Because we are generating training data using these two wave functions we only need to generate a handful of observations.
There are two main components, the generator and discriminator. The generator generates new waves from random input, in this case a standard normal distribution. The discriminator sorts the real from the fake data. During training it will switch between training the discriminator and the generator. At each iteration both components perform better – the generator gets better at generating real observations and the discriminator gets better at determining whether or not the observation is real or fake.
Like any neural network determining the number of hidden layers and sizes is more a process of experimentation than it is a science. For this example what I found worked well is,
All are dense layers with LeakyReLU activation and 20% dropout. Given the input data is distinct it seems like overkill however I found this worked best. I’m sure other network configurations also work.
The input data is time series data, in which case it is appropriate to use recurrent layers for the generator and discriminator. I actually found this to not be very successful. Either it takes far longer to train or just has trouble converging to a good solution, not saying it can’t be done though. For more challenging problems you’ll need more advanced models. Fortunately for this example we can keep it simple.
To train the GAN on these two functions we simply run.
gan(train, epochs = 2000)
This will call the Python script with the GAN code, run it in Python for 2000 epochs and return the results. The training data is saved in the global environment as x_train, which can then be imported into the Python environment with r.x_train. A log file is created within the working directory and records the progress every 100 epochs.
Once training has finished, view the output with py$gan_r$output. At each iteration set by trace a set of waves is generated. The plot_waves function will plot a set from the final iteration.
Recall, we only fed the GAN two sine functions, which makes the output below pretty cool. Here we see 12 randomly generated waves from the final iteration of the GAN.
plot_waves(py$gan_r$output, seq(1, 24, 2))
With a single draw from a random normal distribution the GAN iteratively adjusted its weights until it learnt to generate and identify each wave. But it also learnt to generate new waves. What stands out here is that the new waves appear to be some combination of the input waves. What we’ve done is found a really, really expensive and not particularly accurate way to estimate this…
$$y = \theta \sin(2\pi \cdot 2x) + (1 - \theta)\, 3\sin(2\pi \cdot 10x)$$
where $\theta$ is between 0 and 1. This can be seen in the plot below where 12 waves have been plotted for different values of $\theta$.
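The blend of the two input waves can also be computed directly, following the same form as the title-graphic code at the end of this post. This is a small sketch for comparison with the GAN output; mix_waves and theta are illustrative names, not functions from the repo.

```r
# Weighted blend of the two training waves; theta = 1 gives the first wave,
# theta = 0 gives the second
mix_waves <- function(theta, m = 250, a = c(1, 3), p = c(2, 10)) {
  x <- seq(0, 1, length.out = m)
  theta * a[1] * sin(2 * pi * p[1] * x) +
    (1 - theta) * a[2] * sin(2 * pi * p[2] * x)
}

# 12 waves for evenly spaced weights; one wave per column
waves <- sapply(seq(0, 1, length.out = 12), mix_waves)
```

Each column can then be plotted with matplot(waves, type = "l") or similar to compare against the generated waves.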
Without explicitly telling the GAN what the functions were, it managed to learn them and explore the space between. While it estimated the frequency well it didn’t quite explore the whole range of amplitudes. They tend to range between 1.5-2.5 rather than 1-3. With more training it would probably get there.
This took a few goes as training the GAN tends to converge to one of the input functions. By generating only one of the waves with high accuracy it can trick the discriminator into thinking it’s real every time. It’s a solution to the problem just not a very exciting one. With tuning we can get the desired result.
Each starting value will correspond to some kind of wave. Out of the 12 random waves, 4 are very similar, right down to the two little spikes at the top of the second crest (see the waves in the third column). This suggests this wave is mapped to a set of values that may be drawn with a higher probability.
This isn’t as sexy as generating new landscape images using Paint but it’s helpful to understand what is going on within the guts of the GAN. It attempts to identify the key features in the observations making it distinct from random noise, pass as a real observation and map the space between.
The same process is essentially happening at scale for more complex tasks. In this example it’s very inefficient to get the result we were looking for, but as the problem becomes more and more complex the trade-off makes sense.
With an image example such as faces, the GAN will identify what a nose looks like and the range of noses that exist in the population, what an eye looks like and that they generally come in pairs at a slightly varying distance apart and so on. It will then generate eyes and noses along some space of possible eyes and noses, in the same way it generated a wave along some space of possible waves given the input.
What’s interesting, the GAN only maps the space between. There are no examples where it generated a wave with a frequency greater than 10 or less than 2. Nor did it generate a wave with an amplitude greater than 3. The input waves essentially formed the boundaries of what could be considered real. There may be GAN variations that allow for this exploration.
The easiest way to get the code for this example is from Github. Either clone or install.
devtools::install_github("doehm/rgan")
The Python component is quite long and I don’t think there is much to gain from pasting it here. As mentioned, the trickiest part is going to be configuring Python, keras and tensorflow, but the R bits should work.
If you want to explore the output of the GAN the data is found at inst/extdata/gan-output.Rdata in the Github repo. This will show you how the GAN improved with each iteration. Each element of the list is a sample of generated waves at iteration 100, 200, …, 2000. This data is the basis of the output waves above. e.g.
plot_waves(gan_output[[1]], id = 1:12, nrow = 3)
The code below created the title graphic. It is an area chart of waves using 48 values of the mixing weight between 0 and 1. Perhaps worthy of an accidental aRt submission.
library(purrr)   # map_dfr
library(dplyr)   # filter, %>%
library(ggplot2)
shape <- list(a = c(1, 3), p = c(2, 10))
m <- 250
ln <- 48
pal <- c("#4E364B", "#8D4E80", "#D86C15", "#F3C925", "#48B2A8")
# one blended wave per mixing weight .x
map_dfr(seq(0, 1, length = ln), ~data.frame(
  x = seq(0, 1, length = m),
  y = .x*shape$a[1]*sin(2*pi*seq(0, 1, length = m)*shape$p[1]) +
    (1-.x)*shape$a[2]*sin(2*pi*seq(0, 1, length = m)*shape$p[2]),
  a = round(.x, 2))) %>%
  filter(x < 0.5) %>%
  ggplot(aes(x = x, y = y, colour = a, fill = a, group = as.factor(a))) +
  geom_area() +
  theme_void() +
  theme(legend.position = "none") +
  scale_fill_gradientn(colors = colorRampPalette(pal)(200)) +
  scale_colour_gradientn(colors = colorRampPalette(pal)(200))
The post Simulating data with Bayesian networks appeared first on Daniel Oehm | Gradient Descending.
Bayesian networks are really useful for many applications and one of those is to simulate new data. Bayes nets represent data as a probabilistic graph and from this structure it is then easy to simulate new data. This post will demonstrate how to do this with bnlearn.
Before simulating new data we need a model to simulate data from. Using the same Australian Institute of Sport dataset from my previous post on Bayesian networks we’ll set up a simple model. For convenience I’ll subset the data to 6 variables.
library(DAAG)
library(tidyverse)
# ais data from the DAAG package
ais_sub <- ais %>% dplyr::select(sex, sport, pcBfat, hg, rcc, hc)
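The variables sex and sport are pretty straightforward. The remaining four are percentage of body fat (pcBfat), hemoglobin concentration (hg), red cell count (rcc) and hematocrit (hc).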
I’ve allowed the data to learn the structure of the network, bar one arc, sport to percentage of body fat. The details are not shown here, but check out the post above on how to fit the structure algorithmically (also I suggest heading to the bnlearn doco which has great examples of a number of networks that can be downloaded). The structure is defined by the string and converted to a bn class object.
library(visNetwork)
library(bnlearn)
# set structure
bn_struct <- model2network("[sex][sport|sex][hg|sex][pcBfat|sex:sport][hc|hg:pcBfat][rcc|hc]")
bn_struct
##
##   Random/Generated Bayesian network
##
##   model:
##    [sex][hg|sex][sport|sex][pcBfat|sex:sport][hc|hg:pcBfat][rcc|hc]
##   nodes:                                 6
##   arcs:                                  7
##     undirected arcs:                     0
##     directed arcs:                       7
##   average markov blanket size:           2.67
##   average neighbourhood size:            2.33
##   average branching factor:              1.17
##
##   generation algorithm:                  Empty
# plot network - code for function at end of post
plot.network(bn_struct)
Now that we have set the structure of the model it is fit to the data with bn.fit using maximum likelihood estimation.
bn_mod <- bn.fit(bn_struct, data = ais_sub, method = "mle")
The output is quite detailed so it’s worth running bn_mod to view the conditional probability tables and Gaussian distributions.
New data is simulated from a Bayes net by first sampling from each of the root nodes, in this case sex. Then followed by the children conditional on their parent(s) (e.g. sport | sex and hg | sex) until data for all nodes has been drawn. The numbers on the nodes below indicate the sequence in which the data is simulated, noting that rcc is the terminal node.
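To make the sampling order concrete, here is a hand-rolled sketch of the same idea for a three-node chain sex → hg → hc in base R. The probabilities, means and coefficients are made up for illustration and are not the fitted values from bn_mod.

```r
set.seed(42)
n <- 1000

# Made-up parameters for illustration only; not estimates from the ais data
p_sex   <- c(f = 0.5, m = 0.5)
hg_mean <- c(f = 13.5, m = 15.5)

# 1. sample the root node
sex <- sample(names(p_sex), n, replace = TRUE, prob = p_sex)

# 2. sample the child conditional on the sampled parent
hg <- rnorm(n, mean = unname(hg_mean[sex]), sd = 1)

# 3. sample the next node conditional on its (already sampled) parent
hc <- rnorm(n, mean = 5 + 2.8 * hg, sd = 1)

sim <- data.frame(sex = sex, hg = hg, hc = hc)
```

Every value is drawn only from previously simulated values, which is why no real observations are needed once the parameters are fitted.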
From this point it’s easy to simulate new data using rbn. Here we simulate a dataset the same size as the original, but you can simulate as many rows as needed.
ais_sim <- rbn(bn_mod, 202)
head(ais_sim)
##         hc       hg    pcBfat      rcc sex   sport
## 1 38.00754 12.43565 13.772499 3.917082   f    Swim
## 2 45.54250 15.79388 13.586402 4.824458   m   Field
## 3 49.87429 17.31542  5.308675 5.814398   m  T_400m
## 4 49.05707 17.19443  9.230973 5.337367   m  W_Polo
## 5 37.66307 12.99088 13.685909 4.020170   f     Gym
## 6 42.33715 14.62894 15.147165 4.440556   f Netball
Done. We now have a fully synthetic dataset which retains the properties of the original data. And it only took a few lines of code.
An important property of generating synthetic data is that it doesn’t use real data to do so, meaning any predictors need to be simulated first (my post on the synthpop package explains this in more detail). This property is retained since the data is generated sequentially as per the structure of network. Also, when using synthpop the order in which the variables are simulated needs to be set. The order can alter the accuracy of simulated dataset and so it’s important to spend the time to get it right. For a Bayesian network the order is determined by the structure, so in effect this step is already done.
The original and simulated datasets are compared in three ways: 1) observing the distributions of the variables, 2) comparing the output from various models and 3) comparing conditional probability queries. The third test is more of a sanity check. If the data is generated from the original Bayes net then a new one fit on the simulated data should be approximately the same. The more rows we generate the closer the parameters will be to the original values.
The variable distributions are very close to the original with only a small amount of variation, mostly observed in sport. Red cell count may have a slight bi-modal distribution but for the most part it’s a good fit. This amount of variation is reasonable since there are only 202 simulated observations. Simulating more rows will give a closer fit but there are often practical considerations for retaining the same size dataset.
For the second check, two linear models are fit to the original and simulated data to predict hematocrit levels with sex, hemoglobin concentration, percentage of body fat and red cell count as predictors. Sport was left out of the model since it was not a strong predictor of hc and only increased the error.
# fit models
glm_og <- glm(hc ~ hg + rcc + pcBfat + sex, data = ais_sub)
glm_sim <- glm(hc ~ hg + rcc + pcBfat + sex, data = ais_sim)
# print summary
summary(glm_og)
##
## Call:
## glm(formula = hc ~ hg + rcc + pcBfat + sex, data = ais_sub)
##
## Deviance Residuals:
##     Min       1Q   Median       3Q      Max
## -3.9218  -0.4868   0.0523   0.5470   2.8983
##
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)
## (Intercept)  5.15656    1.04109   4.953 1.57e-06 ***
## hg           1.64639    0.11520  14.291  < 2e-16 ***
## rcc          3.04366    0.31812   9.568  < 2e-16 ***
## pcBfat      -0.02271    0.01498  -1.517    0.131
## sexm        -0.20182    0.23103  -0.874    0.383
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for gaussian family taken to be 0.888242)
##
##     Null deviance: 2696.92  on 201  degrees of freedom
## Residual deviance:  174.98  on 197  degrees of freedom
## AIC: 556.25
##
## Number of Fisher Scoring iterations: 2
summary(glm_sim)
##
## Call:
## glm(formula = hc ~ hg + rcc + pcBfat + sex, data = ais_sim)
##
## Deviance Residuals:
##     Min       1Q   Median       3Q      Max
## -2.2117  -0.6232   0.0089   0.5910   2.0073
##
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)
## (Intercept)  5.318057   0.953131   5.580  7.9e-08 ***
## hg           1.633765   0.107235  15.235  < 2e-16 ***
## rcc          2.905111   0.299397   9.703  < 2e-16 ***
## pcBfat      -0.001506   0.014768  -0.102    0.919
## sexm         0.319853   0.220904   1.448    0.149
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for gaussian family taken to be 0.7520993)
##
##     Null deviance: 3014.20  on 201  degrees of freedom
## Residual deviance:  148.16  on 197  degrees of freedom
## AIC: 522.64
##
## Number of Fisher Scoring iterations: 2
The coefficients and test statistics of the models are very similar, so both datasets lead to the same conclusions. Percentage of body fat is the least accurate but still leads to the same conclusion. In practice you should fit more models to assess the quality of the simulated data.
As mentioned, the third is more of a sanity check but it is also a good demonstration of the process. By fitting the same structure to the simulated data we expect to estimate the same parameters and calculate very similar conditional probabilities. Here we simulate 20 000 observations to better estimate the parameters. The query is the conditional probability of the athlete’s red cell count given the sport they compete in, i.e. what is the probability the athlete’s red cell count will be greater than a threshold, where the threshold is the 33rd or 66th percentile?
library(progress)
# fit bayes net with the same structure using simulated
ais_sim <- rbn(bn_mod, 2e4)
bn_mod_sim <- bn.fit(bn_struct, ais_sim, method = "mle")
# sport - mcpquery function at end of post - it generalises bnlearn::cpquery
# rcc is continuous so we're calculating the probability rcc > some value, here the 33rd and 66th percentile
# the function replaces xx and yy with actual values and evaluates the text
given <- "sport == 'xx'"
event <- "rcc > yy"
vals <- as.matrix(expand.grid(sport = unique(ais_sub$sport), rcc = quantile(ais_sub$rcc, c(0.33, 0.66))))
orig <- mcpquery(bn_mod, values = vals, token = c("xx", "yy"), event, given, ns = 1e6) %>% spread(rcc, cp)
## P( rcc > yy | sport == 'xx' )
sim <- mcpquery(bn_mod_sim, values = vals, token = c("xx", "yy"), event, given, ns = 1e6) %>% spread(rcc, cp)
## P( rcc > yy | sport == 'xx' )
# join for display
left_join(orig, sim, by = "sport", suffix = c("_orig", "_sim"))
## # A tibble: 10 x 5
##    sport   `4.4533_orig` `4.9466_orig` `4.4533_sim` `4.9466_sim`
##    <chr>           <dbl>         <dbl>        <dbl>        <dbl>
##  1 B_Ball          0.687        0.310         0.682       0.302
##  2 Field           0.764        0.386         0.768       0.384
##  3 Gym             0.477        0.0658        0.475       0.0717
##  4 Netball         0.444        0.0584        0.442       0.0575
##  5 Row             0.652        0.270         0.654       0.271
##  6 Swim            0.752        0.370         0.758       0.376
##  7 T_400m          0.767        0.388         0.769       0.386
##  8 T_Sprnt         0.822        0.449         0.826       0.445
##  9 Tennis          0.640        0.251         0.639       0.249
## 10 W_Polo          0.945        0.571         0.944       0.566
The conditional probabilities from the simulated data are very close to the original as expected. Now we can be confident that our simulated data can be used as an alternative to the original.
Another useful application of Bayes nets is imputing missing values. This is easily done using impute. We’ll remove 25% of the observations from variables hg and hc, and allow the Bayes net to impute them.
ais_miss <- ais_sub
miss_id <- sample(1:nrow(ais_sub), 50)
ais_miss[miss_id, c("hg", "hc")] <- NA
# table counting the NA values in hg and hc
apply(ais_miss[,c("hg", "hc")], 2, function(x) table(is.na(x)))
##        hg  hc
## FALSE 152 152
## TRUE   50  50
The table confirms there are 50 missing observations from hemoglobin and hematocrit variables. Now impute using Bayesian likelihood weighting.
ais_imp <- impute(bn_mod, data = ais_miss, method = "bayes-lw")
Plotting the imputed against the true values shows the Bayes net imputed the missing values quite well.
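Alongside the plot, a single error figure per variable can be handy. The helper below is a small sketch (imputation_rmse is a made-up name) that computes the root mean squared error on the rows that were blanked out, demonstrated on toy data rather than the ais data.

```r
# RMSE between true and imputed values, restricted to the rows that were missing
imputation_rmse <- function(truth, imputed, rows, vars) {
  vapply(vars, function(v) {
    sqrt(mean((truth[rows, v] - imputed[rows, v])^2))
  }, numeric(1))
}

# Toy data standing in for the true and imputed data frames
set.seed(1)
truth <- data.frame(hg = rnorm(20, 14.5, 1.3), hc = rnorm(20, 43, 3.6))
rows <- c(2, 5, 11)
imputed <- truth
imputed[rows, "hg"] <- truth[rows, "hg"] + 0.5 # pretend hg is off by 0.5

res <- imputation_rmse(truth, imputed, rows, c("hg", "hc"))
```

With the objects from this post the call would be imputation_rmse(ais_sub, ais_imp, miss_id, c("hg", "hc")).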
I’ve only tested and shown two variables but the others would perform similarly for the subset of data I have chosen. This data is normally distributed so I expected it to work well, however if your data has more complex relationships you’ll need to be more rigorous with defining the structure.
# plot network function
plot.network <- function(structure, ht = "400px", cols = "darkturquoise", labels = nodes(structure)){
if(is.null(labels)) labels <- rep("", length(nodes(structure)))
nodes <- data.frame(id = nodes(structure),
label = labels,
color = cols,
shadow = TRUE
)
edges <- data.frame(from = structure$arcs[,1],
to = structure$arcs[,2],
arrows = "to",
smooth = FALSE,
shadow = TRUE,
color = "black")
return(visNetwork(nodes, edges, height = ht, width = "100%"))
}
# conditional probability query functions
mcpquery <- function(mod, values, token, event, given, ns = 1e4){
cat("P(", event, "|", given, ")\n")
UseMethod("mcpquery", values)
}
# numeric
mcpquery.numeric <- function(mod, values, token, event, given, ns = 1e4){
y <- rep(NA, length(values))
pb <- progress_bar$new(
format = "time :elapsedfull // eta :eta // :k of :n // P( :event | :given )",
clear = FALSE, total = length(values))
for(k in 1:length(values)){
givenk <- gsub(token, values[k], given)
eventk <- gsub(token, values[k], event)
pb$tick(token = list(k = k, n = length(values), event = eventk, given = givenk))
y[k] <- eval(parse(text = paste0("cpquery(mod,", eventk, ",", givenk, ", n = ", ns, ", method = 'ls')")))
}
return(tibble(values = values, cp = y) %>% arrange(desc(cp)))
}
# character
mcpquery.character <- function(mod, values, token, event, given, ns = 1e4){
y <- rep(NA, length(values))
pb <- progress_bar$new(
format = "time :elapsedfull // eta :eta // :k of :n // P( :event | :given )",
clear = FALSE, total = length(values))
for(k in 1:length(values)){
givenk <- gsub(token, values[k], given)
eventk <- gsub(token, values[k], event)
pb$tick(token = list(k = k, n = length(values), event = eventk, given = givenk))
y[k] <- eval(parse(text = paste0("cpquery(mod,", eventk, ",", givenk, ", n = ", ns, ", method = 'ls')")))
}
return(tibble(values = values, cp = y) %>% arrange(desc(cp)))
}
# matrix
mcpquery.matrix <- function(mod, values, token, event, given, ns = 1e4){
n <- nrow(values)
y <- rep(NA, n)
pb <- progress_bar$new(
format = "time :elapsedfull // eta :eta // :k of :n // P( :event | :given )",
clear = FALSE, total = n)
for(k in 1:n){
givenk <- given
eventk <- event
for(j in 1:ncol(values)){
givenk <- gsub(token[j], values[k,j], givenk)
eventk <- gsub(token[j], values[k,j], eventk)
}
pb$tick(token = list(k = k, n = n, event = eventk, given = givenk))
y[k] <- eval(parse(text = paste0("cpquery(mod,", eventk, ",", givenk, ", n = ", ns, ", method = 'ls')")))
}
out <- as_tibble(values) %>%
bind_cols(tibble(cp = y))
colnames(out)[1:ncol(values)] <- colnames(values)
return(out)
}
# compare the synthetic and original data frames
df <- ais_sub %>%
mutate(type = "orig") %>%
bind_rows(
rbn(bn_mod, 202) %>%
mutate(type = "sim")
)
gg_list <- list()
grp_var <- "type"
vars <- colnames(df)[colnames(df) != grp_var]
for(k in 1:length(vars)){
var_k <- vars[k]
gg_list[[k]] <- ggplot(df, aes_string(x = var_k, fill = grp_var, col = grp_var))
if(is.numeric(df[[var_k]])){
gg_list[[k]] <- gg_list[[k]] + geom_density(alpha = 0.85, size = 0)
}else{
gg_list[[k]] <- gg_list[[k]] + geom_bar(position = "dodge")
}
gg_list[[k]] <- gg_list[[k]] +
theme(
axis.text.x = element_text(angle = 90),
axis.title.x = element_blank()
) +
labs(title = var_k)
}
The post Use more of your data with matrix factorisation appeared first on Daniel Oehm | Gradient Descending.
Previously I posted on how to apply gradient descent on linear regression as an example. With that as background it’s relatively easy to extend the logic to other problems. One of those is matrix factorisation.
There are many ways to factorise a matrix into components, such as PCA and singular value decomposition (SVD), but one way is to use gradient descent. While its inception is in image processing, it was popularised by its use in recommender systems (Funk SVD). It is also useful in other areas such as market basket analysis and topic modelling. Once you know the nuts and bolts and how to apply it you’ll find uses beyond recommenders.
Firstly, the maths.
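Consider a matrix $Y$ with dimensions $n \times m$. We want to factorise $Y$ into two matrices $U$ and $V$ such that
$$Y \approx U V^T$$
$U$ and $V$ have dimensions $n \times k$ and $m \times k$ where typically $k \ll \min(n, m)$. These are estimated in the following steps: 1) initialise each matrix with random values, 2) iteratively adjust each matrix in the direction of the negative gradient to minimise the MSE and 3) stop at the maximum iterations or a specified tolerance.
The objective function is the mean squared error and a regularisation term.
$$J = \frac{1}{N} \sum_{i,j} \delta_{ij} \left(y_{ij} - u_i^T v_j\right)^2 + \lambda \left(\lVert U \rVert^2 + \lVert V \rVert^2\right)$$
Here $u_i$ and $v_j$ are the $i$th row of $U$ and the $j$th row of $V$, $\delta_{ij}$ is an indicator where $\delta_{ij} = 1$ if the rating or value exists and 0 if it is missing, and $N = \sum_{i,j} \delta_{ij}$. The gradients with respect to $u_i$ and $v_j$ are
$$\frac{\partial J}{\partial u_i} = -\frac{1}{N} \sum_{j} \delta_{ij} \left(y_{ij} - u_i^T v_j\right) v_j + \lambda u_i$$
$$\frac{\partial J}{\partial v_j} = -\frac{1}{N} \sum_{i} \delta_{ij} \left(y_{ij} - u_i^T v_j\right) u_i + \lambda v_j$$
The factor of 2 is dropped to keep it tidy. By doing so this quantity is effectively rolled into the learning rate, if you work through the maths. For each iteration $u_i$ and $v_j$ are adjusted by
$$u_i \leftarrow u_i - \eta \frac{\partial J}{\partial u_i}, \qquad v_j \leftarrow v_j - \eta \frac{\partial J}{\partial v_j}$$
where $\eta$ is the learning rate.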
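The update rule can be sketched in a few lines of base R. This is a minimal illustration under my own notation (Y the data matrix with NAs, U and V the two factors, lambda the regularisation weight, eta the learning rate) and is not the code of the package introduced below; factorise and its defaults are made up.

```r
factorise <- function(Y, k = 2, lambda = 0.01, eta = 0.1, epochs = 2000) {
  obs <- !is.na(Y)              # indicator: TRUE where a value is observed
  N   <- sum(obs)
  Y0  <- ifelse(obs, Y, 0)      # zero-fill so matrix algebra works; masked below

  # 1) initialise both factors with small random values
  U <- matrix(rnorm(nrow(Y) * k, sd = 0.1), nrow(Y), k)
  V <- matrix(rnorm(ncol(Y) * k, sd = 0.1), ncol(Y), k)

  for (e in seq_len(epochs)) {
    # residuals on observed cells only
    E <- (Y0 - U %*% t(V)) * obs
    # 2) gradients of the masked MSE plus regularisation (factor 2 dropped)
    grad_U <- -(E %*% V) / N + lambda * U
    grad_V <- -(t(E) %*% U) / N + lambda * V
    U <- U - eta * grad_U
    V <- V - eta * grad_V
  }

  pred <- U %*% t(V)
  list(U = U, V = V, pred = pred,
       rmse = sqrt(sum(((Y0 - pred) * obs)^2) / N))
}

set.seed(1)
Y <- matrix(sample(c(NA, 1:5), 60, replace = TRUE,
                   prob = c(0.2, rep(0.16, 5))), nrow = 10)
fit <- factorise(Y, k = 2)
```

The prediction matrix fit$pred has a value in every cell, including those that were missing in Y. A stopping tolerance (step 3) is left out here to keep the loop short.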
Factorising a matrix into two lower dimensional matrices learns the latent information. For example, if your data is a patient-drug profile it may learn the diagnosis of the patient. In the case of a recommender, if your data are song ratings it will learn those that prefer metal over pop, classical over electro, etc. The first column of the matrix may represent the level of distortion in a song, or tempo as a numeric value. What the algorithm picks out as important isn’t necessarily obvious, and it is worth inspecting those on the extremes to get a feel for the data. The user matrix works similarly. The numeric values in the user row represent how much weight that user gives to the particular feature.
Why this is useful over PCA or SVD is that it handles missing values. The other methods treat every cell as a value, so missing values need to be imputed before factorising, and the imputed values are then used to learn the data structure. The beauty of this method is that it factorises only on the observed values and imputes the missing ones.
I’ve developed a simple package to apply the method. There are other packages which are more optimised but this is just simple and easy to use. You can install it from Github (not on CRAN).
devtools::install_github("doehm/matrixfactorisation")
Simulate some data.
m <- matrix(sample(c(NA, 1:5), 60, replace = TRUE, prob = c(0.2, rep(0.8/5, 5))), nrow = 10)
m
##       [,1] [,2] [,3] [,4] [,5] [,6]
##  [1,]   NA   NA    1    4    2    2
##  [2,]   NA    2   NA    1    5    4
##  [3,]    5    3    3    1    4    4
##  [4,]   NA    1    3    5    3    5
##  [5,]    4    4    3    2    4    3
##  [6,]   NA    4    5   NA    4    4
##  [7,]   NA   NA    2    3    3    2
##  [8,]    2   NA    2   NA    3    2
##  [9,]    2    3    5    2    3   NA
## [10,]    3    3    2    2    5    5
Select a test set and factorise. Selecting a test set holds out a proportion of randomly selected observations. It ensures there is at least 1 observation remaining in each row. This can be increased with min_row.
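The hold-out logic can be sketched as follows. This is an illustrative base R version written for this explanation, not the package's select_test; hold_out and its return format are assumptions.

```r
hold_out <- function(Y, prop = 0.2, min_row = 1) {
  obs <- which(!is.na(Y))                   # linear indices of observed cells
  shuffled <- obs[sample.int(length(obs))]  # visit candidates in random order
  n_test <- floor(prop * length(obs))
  remaining <- rowSums(!is.na(Y))           # observed values left per row
  test <- integer(0)

  for (idx in shuffled) {
    if (length(test) >= n_test) break
    r <- (idx - 1) %% nrow(Y) + 1           # row of a column-major index
    # only hold out a cell if its row keeps at least min_row observations
    if (remaining[r] > min_row) {
      test <- c(test, idx)
      remaining[r] <- remaining[r] - 1
    }
  }

  train <- Y
  train[test] <- NA
  list(train = train, test = test)
}

set.seed(1)
Y <- matrix(sample(1:5, 60, replace = TRUE), nrow = 10)
Y[1, 1:5] <- NA                             # row 1 has a single observation
split <- hold_out(Y, prop = 0.2)
```

Row 1 keeps its only value in the training matrix because removing it would leave the row empty.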
id <- select_test(m, 0.2)
mf <- matrix_factorisation(m, 2, test = id$test, pbar = TRUE)
## 00:00:02 // dimensions 2 // epoch 1000 // train 1.2703 // test 1.6059 // delta 9.467e-04
Plot the convergence.
plot(mf)
Viewing the convergence profile is useful to see whether or not the data has been overfit.
Typically missing values are treated before clustering which may include removing certain rows or columns. Usually this also means throwing away actual observations. If the data has a proportionally large amount of missing values you could be throwing away most of the data. This is quite typical for ‘user-item’ type data. Matrix factorisation enables the use of more data.
This is demonstrated in a small example. A matrix is randomised with 12 users, 6 items and 2 groups with a high degree of missingness. The group is the latent variable. The values are drawn from a Poisson distribution: users in the first group have rate 5 for the first three items and rate 1 for the last three, and users in the second group have the reverse. It is then simple to see the relationship between the users, items and groups. Count data is very common and often has its own challenges when clustering, particularly when counts are small or only part of the unit’s history is observed, e.g. a data set where the grain is the patient, the items are prescription drugs and the values are counts of that drug.
user_a <- function(m) c(rpois(m/2, 5), rpois(m/2, 1))
user_b <- function(m) c(rpois(m/2, 1), rpois(m/2, 5))
df <- NULL
n <- 12
m <- 6
grp <- as.factor(sort(rep(letters[1:2], n/2)))
for(k in 1:n){
df <- rbind(df, if(k <= 6) user_a(m) else user_b(m))
}
df
##       [,1] [,2] [,3] [,4] [,5] [,6]
##  [1,]    6    7    6    0    3    2
##  [2,]    5    5    3    1    1    1
##  [3,]    7    3    7    1    0    0
##  [4,]    5    5    5    0    0    0
##  [5,]    3    1    3    2    0    0
##  [6,]    5    7    4    0    1    0
##  [7,]    0    0    2    7    6    6
##  [8,]    3    1    0    6    9    8
##  [9,]    2    0    2    6    4    5
## [10,]    0    2    1    5    4    9
## [11,]    0    2    2    4    4    5
## [12,]    0    0    2    1    6    8
We’ll now suppress 60% of the values.
df_na <- select_test(df, 0.6)$train
df_na
##       [,1] [,2] [,3] [,4] [,5] [,6]
##  [1,]   NA    7   NA   NA   NA    2
##  [2,]    5   NA    3   NA   NA   NA
##  [3,]   NA    3    7   NA    0   NA
##  [4,]    5    5    5    0    0   NA
##  [5,]    3    1    3   NA   NA    0
##  [6,]    5   NA   NA    0   NA    0
##  [7,]    0   NA   NA   NA   NA    6
##  [8,]    3    1   NA    6    9    8
##  [9,]   NA    0    2   NA    4   NA
## [10,]    0    2   NA   NA   NA    9
## [11,]    0    2    2   NA   NA    5
## [12,]   NA   NA    2   NA    6   NA
This data can be challenging to use, but by factorising using only the observed data it is embedded in a much more usable form (mathematically this is similar to using the autoencoder – PCA vs autoencoders). Using the prediction matrix all units can be clustered. While there is error in the prediction matrix it should be accurate enough to find the two groups.
# factorise
df_mf <- matrix_factorisation(df_na, 2, epochs = 2000)
# cluster
km <- kmeans(df_mf$pred, 2, nstart = 100)
clus <- as.factor(km$cluster)
names(clus) <- grp
clus
## a a a a a a b b b b b b
## 1 1 1 1 1 1 2 2 2 2 2 2
## Levels: 1 2
table(clus, grp)
##     grp
## clus a b
##    1 6 0
##    2 0 6
A simple K-means clustering process finds the two groups, which is encouraging with a highly patchy and small data set. K-means can struggle with low value count data so embedding the data in this form can really help.
Often the original data set can be large and therefore inefficient to cluster. The same result (or at least a very good approximation) is achieved using the user matrix (the u element of the fit) instead of the prediction matrix.
# cluster
km <- kmeans(df_mf$u, 2, nstart = 100)
clus <- as.factor(km$cluster)
names(clus) <- grp
clus
## a a a a a a b b b b b b
## 1 1 1 1 1 1 2 2 2 2 2 2
## Levels: 1 2
table(clus, grp)
##     grp
## clus a b
##    1 6 0
##    2 0 6
In this case we achieve the same result which is great. This is an approximation so may not always work as well, in fact if you randomised the input matrix a number of times I’m sure you’ll see different results. But in general you should achieve very reasonable results most of the time.
In practice I have had success using this technique on large data sets. While my preferred method is to use the autoencoder for dimension reduction and embedding, there is an overhead in setting it up. On the other hand matrix factorisation is quick and simple to get you exploring the data sooner. Clustering is just one example and I’m sure you’ll find more applications that benefit from matrix factorisation.
The post Use more of your data with matrix factorisation appeared first on Daniel Oehm | Gradient Descending.
The post Fun with progress bars: Fish, daggers and the Star Wars trench run appeared first on Daniel Oehm | Gradient Descending.
]]>If you’re like me, when running a process through a loop you’ll add in counters and progress indicators. That way you’ll know if it will take 5 minutes or much longer. It’s also good for debugging to know when the code wigged-out.
This is typically what’s done. You take a time stamp at the start (start <- Sys.time()), print out some indicators at each iteration (cat("iteration", k, "// reading file", file, "\n")) and print out how long it took at the end (print(Sys.time() - start)). The problem is it will print out a new line each time it is called, which is fine but ugly. You can reduce the number of lines printed by only printing out every 10th or 100th iteration, e.g. if(k %% 10 == 0) ….
A simple way to make this better is instead of using "\n" for a new line use "\r" for carriage return. This will overwrite the same line which is much neater. It’s much more satisfying watching a number go up, or down, whichever way is the good direction. Try it out…
y <- matrix(0, nrow = 31, ncol = 5)
for(sim in 1:5){
y[1, sim] <- rnorm(1, 0, 8)
for(j in 1:30){
y[j+1, sim] <- y[j, sim] + rnorm(1) # random walk
cat("simulation", sim, "// time step", sprintf("%2.0f", j), "// random walk", sprintf(y[j+1, sim], fmt='% 6.2f'), "\r")
Sys.sleep(0.1)
}
}
## simulation 5 // time step 30 // random walk 8.97
The best way is to use the {progress} package. This package allows you to simply add running time, eta, progress bars, percentage complete as well as custom counters to your code. First decide on what counters you want and the format of the string. The function identifies counters by using a colon at the beginning of the label. Check the doco for built-in tokens.
To add your own token add the label to the format string and add the token to tick(). To make it pretty I recommend formatting digits with sprintf(). Here’s an example.
library(progress)
pb <- progress_bar$new(format = ":elapsedfull // eta :eta // simulation :sim // time step :ts // random walk :y [:bar]", total = 30*5, clear = FALSE)
y <- matrix(0, nrow = 31, ncol = 5)
for(sim in 1:5){
y[1, sim] <- rnorm(1, 0, 8)
for(j in 1:30){
y[j+1, sim] <- y[j, sim] + rnorm(1) # random walk
pb$tick(tokens = list(sim = sim, ts = sprintf("%2.0f", j), y = sprintf(y[j+1, sim], fmt='% 6.2f')))
Sys.sleep(0.1)
}
}
00:00:17 // eta 0s // simulation 5 // time step 30 // random walk -12.91 [====================================================]
You can also jazz it up with a bit of colour with {crayon}. Be careful with this, it doesn’t handle varying string lengths very well and can start a new line exploding your console.
library(crayon)
pb <- progress_bar$new(format = green$bold(":elapsedfull // eta :eta // simulation :sim // time step :ts // random walk :y [:bar]"), total = 30*5, clear = FALSE)
...
00:00:17 // eta 0s // simulation 5 // time step 30 // random walk -12.91 [====================================================]
That’s a much neater progress bar.
Procrastination set in and creative tangents were followed. So, I made a progress bar into a big fish which eats smaller fish … and made it green.
n <- 300
bar_fmt <- green$bold(":elapsedfull | :icon |")
pb <- progress_bar$new(format = bar_fmt, total = n, clear = FALSE)
icon <- progress_bar_icon("fish", n, 75)
for(j in 1:n){
pb$tick(tokens = list(
icon = token(icon, j)
))
Sys.sleep(0.03)
}
Each fish represents 25% completion. Once they’re all gobbled up, the job is done.
I also threw knives at boxes. Each box represents 20% completion.
n <- 300
bar_fmt <- green$bold(":elapsedfull | :icon |")
pb <- progress_bar$new(format = bar_fmt, total = n, clear = FALSE)
icon <- progress_bar_icon("dagger", n, 75)
for(j in 1:n){
pb$tick(tokens = list(
icon = token(icon, j)
))
Sys.sleep(0.03)
}
And my personal favourite, the Star Wars trench run.
n <- 500
bar_fmt <- green$bold(":elapsedfull | :icon |")
pb <- progress_bar$new(format = bar_fmt, total = n, clear = FALSE)
icon <- progress_bar_icon("tiefighter", n, 75)
for(j in 1:n){
pb$tick(tokens = list(
icon = token(icon, j)
))
Sys.sleep(0.03)
}
Ok… I have spent way too long on this! But at least it was fun. If you want to play around with it, feel free to download it from Git.
devtools::install_github("doehm/progressart")
The post Fun with progress bars: Fish, daggers and the Star Wars trench run appeared first on Daniel Oehm | Gradient Descending.
The post Bayesian estimation of fatality rates and accidents involving cyclists on Queensland roads appeared first on Daniel Oehm | Gradient Descending.
]]>In my previous post I built a Shiny app mapping accidents on Queensland roads which was great at showing the problematic areas within cities and regional areas. I have added to this by estimating the fatality rate given the observed accidents and the rate of accidents involving cyclists for SA3 and SA4 areas. I have also updated the filters making them tidier. What follows is commentary on what can be found in the Shiny app. If you want to jump straight to it, run
shiny::runGitHub("doehm/road-accidents/", "doehm", launch.browser = TRUE)
I used a Bayesian approach to estimate the fatality rate for 2017 (the data isn’t complete for 2018) and presented it as a proportion of the number of observed road accidents. The data dates back to 2001 but it’s reasonable to use the most recent data to allow for improvements in road conditions, policies, advertising, population growth, etc which may have an underlying impact on the risk of road accidents.
To construct a prior I used years 2014-2016 at SA3 level. By taking 10,000 draws from the posterior we estimate the fatality rate distributions for each area. The bands represent the 80% and 95% credible intervals.
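The posterior draws can be sketched for a single area as follows. This assumes, as the posterior_f() code at the end of this post does, a Gamma prior with shape a and rate b, updated with y fatalities out of n accidents to give a Gamma(a + y, b + n) posterior. The counts y and n here are invented for illustration.

```r
# Sketch of the conjugate update: Gamma(a, b) prior, y events in n accidents.
# The counts are invented; a and b match the defaults in the post's posterior_f().
set.seed(3)
a <- 1.3
b <- 77
y <- 4      # hypothetical number of fatalities in an area
n <- 250    # hypothetical number of accidents in the same area
draws <- rgamma(1e4, shape = a + y, rate = b + n)

# median with 80% and 95% bands, expressed as a percentage
round(100 * quantile(draws, c(0.025, 0.1, 0.5, 0.9, 0.975)), 2)
```

The posterior mean (a + y)/(b + n) sits between the raw rate y/n and the prior mean a/b, so areas with few accidents are pulled towards the prior more strongly than areas with many.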
The top 3 most dangerous SA4’s for 2017 are Mackay – Isaac – Whitsunday, Wide Bay and the Queensland Outback. It’s not surprising that the large majority of fatalities occurred at high speeds on highways and back roads. Having said that, a few have occurred closer to the town centers in these areas. It’s possible the road conditions are not particularly good but that’s hard to determine. The road condition filter only has 4 states, sealed (dry / wet) and unsealed (dry / wet), and doesn’t offer much more information. It appears most accidents occur in dry conditions on sealed roads.
The 3 least dangerous areas are Brisbane North, Brisbane Inner City and Brisbane South. Again, it shouldn’t be surprising: the speed limit is lower in the city, so if accidents occur they are generally less severe. There are a lot of rear-end accidents and accidents when leaving a car park – the classic.
At the SA3 level the incident rates focus on smaller geographical areas which can highlight other features. For example, at the SA4 level the area often includes the city centers and the Hinterland regions which tend to have different rates where it’s often higher in the Hinterlands. This is largely due to the differing speed limits. The most dangerous areas are Burnett, Port Douglas and the Ipswich Hinterlands. The least dangerous are Chermside, Brisbane Inner West and Springfield – Redbank. Click the image to zoom in, there are a lot of SA3’s.
As a cyclist it’s good to know which roads to pay particular attention on. In a similar way we’ll look at the rate of road accidents involving cyclists by area. An influencing factor is how many cyclists are on the roads relative to cars and other vehicles. If the area has a strong cycling culture it’s more likely that an accident will involve a cyclist.
It’s not surprising the dense city centers with more traffic have more accidents involving cyclists. At the SA4 level it’s pretty clear that Brisbane Inner City is a dangerous place to ride a bike, particularly on Adelaide Street which you can see from the app (go to show filter > unit type involved > bicycle). The rate of accidents involving cyclists in Brisbane Inner City is significantly higher than all other SA4’s. I’m curious to know if the growth of the CityCycle scheme (which is awesome by the way) is somewhat a contributing factor. Although, given this was introduced in 2011, and 2005-2010 saw the highest rate of growth in accidents involving cyclists, it probably isn’t the biggest factor in the high rate of accidents, but it should contribute if more people are on bikes – they’re also not the easiest bikes in the world to ride. The other areas in the top 3 are Brisbane City West and Cairns. Sheridan Street, the main drag in Cairns, is where you need to be extra vigilant.
Here Noosa tops the chart for 2017. It sits within the Sunshine Coast SA4 region, which ranks 5th in the chart above; that region also includes the Hinterland regions, which is why it scores lower above. Noosa has a strong cycling culture which could be influencing its high rate. With a median of a whopping 18% definitely keep your eyes peeled when cycling on those roads. Also, take care in Sherwood – Indooroopilly. Keep in mind that these areas are not necessarily where the most accidents involving cyclists occur. Rather if there is an accident it’s more likely to have involved a cyclist relative to other areas. If you’re driving in those areas keep an eye out for cyclists on the road, perhaps there aren’t marked cycling lanes or off-road bike paths or there are some dodgy intersections.
Ideally the rate would be adjusted for the number of cyclists on the road, however this is still a useful statistic to find the more dangerous areas. This analysis has also been done for pedestrians involved in car accidents, which can be found in the app. Just a heads up, be careful in Fortitude Valley.
Run the shiny app
library(shiny)
runGitHub("doehm/road-accidents/", "doehm", launch.browser = TRUE)
Load the data from Github.
library(repmis)
source_data("https://github.com/doehm/road-accidents/raw/master/data/road-accident-data.Rdata")
Estimate hyperparameters
# get prior information from previous 3 years
# use method of moments to get alpha and beta
# ideally we would use a hierarchical model but this is suitable
hyper_params <- function(y){
mu <- mean(y)
s2 <- var(y)
gamma <- mu*(1-mu)/s2-1
alpha <- gamma*mu
beta <- gamma*(1-mu)
return(list(alpha = alpha, beta = beta))
}
# estimating the actual parameters
library(tidyverse) # for %>%, filter(), group_by(), summarise() and mutate()
df <- accidents_raw %>%
filter(
Loc_Queensland_Transport_Region != "Unknown",
Crash_Year >= 2014,
Crash_Year <= 2016
) %>%
group_by(Loc_ABS_Statistical_Area_3) %>%
summarise(
count = n(),
n_fatalities = sum(Count_Casualty_Fatality)
) %>%
mutate(theta = n_fatalities/(count))
hyper_params(df$theta)
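As a sanity check on the method of moments, the sketch below simulates rates from a Beta distribution with known parameters and confirms the estimator recovers them. For a Beta(alpha, beta) distribution the mean is mu = alpha/(alpha + beta) and the variance is mu(1 - mu)/(alpha + beta + 1), which rearranges to the formulas in hyper_params(). The function mom_beta is a copy made for illustration and the true values 2 and 50 are arbitrary assumptions.

```r
# Sketch: check that the method of moments recovers known Beta parameters.
# mom_beta mirrors the post's hyper_params(); the true values 2 and 50 are arbitrary.
mom_beta <- function(y) {
  mu <- mean(y)
  s2 <- var(y)
  gamma <- mu * (1 - mu) / s2 - 1   # estimate of alpha + beta
  list(alpha = gamma * mu, beta = gamma * (1 - mu))
}

set.seed(4)
est <- mom_beta(rbeta(1e5, shape1 = 2, shape2 = 50))
est   # alpha and beta should land close to 2 and 50
```

With only a few dozen areas, as in the real data, the estimates will be considerably noisier than in this large simulated sample.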
Forest plots
library(tidyverse)
library(showtext)
# font
font_add_google(name = "Montserrat", family = "mont")
showtext_auto()
# SA3 forest plot of fatality rate
# set my theme and colours
theme_forest <- function(scale = 1){
theme_minimal() +
theme(
legend.position = "none",
axis.text = element_text(family = "mont", size = 16*scale),
axis.text.x = element_text(vjust = 0.5),
axis.title.y = element_blank(),
axis.title.x = element_text(family = "mont", size = 16*scale),
plot.title = element_text(family = "mont", hjust = 0.5, size = 26*scale, face = "bold"),
plot.subtitle = element_text(family = "mont", hjust = 0.5, size = 20*scale),
plot.caption = element_text(size = 12*scale)
)
}
my_cols <- function(n = 16) colorRampPalette(c("darkmagenta", "turquoise"))(n)
# simulate from the posterior
posterior_f <- function(df, y, n, a = 1.3, b = 77, inflator = 100) {
out <- data.frame()
qs <- c(0.025, 0.1, 0.5, 0.9, 0.975)
for(k in 1:nrow(df)){
out <- rbind(out, inflator*rgamma(1e4, shape = a+y[k], rate = b+n[k]) %>% quantile(qs))
}
colnames(out) <- paste0("q", 100*qs)
return(out)
}
# SA4
# fatalities
library(lazyeval) # interp() used below comes from lazyeval
areas <- grep("Area", colnames(accidents_raw), value = TRUE)[3:4]
names(areas) <- c("sa3", "sa4")
fatality_fn <- function(area){
accidents_raw %>%
group_by_(area) %>%
filter(
Crash_Year == 2017,
) %>%
summarise(
count = length(Count_Casualty_Total),
n_fatalities = sum(Count_Casualty_Fatality)
) %>%
bind_cols(posterior_f(df = ., y = .$n_fatalities, n = .$count)) %>%
arrange(q50) %>%
mutate_(area = interp(~factor(v, level = v), v = as.name(area))) %>%
ggplot() +
geom_segment(mapping = aes(x = q2.5, xend = q97.5, y = area, yend = area)) +
geom_segment(mapping = aes(x = q10, xend = q90, y = area, yend = area, col = q50), size = 2) +
geom_point(mapping = aes(x = q50, y = area), pch = 3) +
theme_forest() +
scale_colour_gradientn(colors = my_cols()) +
labs(
title = "Fatality rate given observed road accidents",
subtitle = paste("Bayesian estimate of the fatality rate for", toupper(names(area)), "areas in 2017"),
x = "Fatality rate (%)")
}
fatality_plots <- list(sa3 = fatality_fn(areas[1]), sa4 = fatality_fn(areas[2]))
Cyclist plot
# cyclists
df <- accidents_raw %>%
filter(
Crash_Year >= 2014,
Crash_Year <= 2016,
Loc_ABS_Statistical_Area_3 != "Unknown"
) %>%
group_by(Loc_ABS_Statistical_Area_3) %>%
summarise(
n_bicycles = sum(Count_Unit_Bicycle > 0),
n_accidents = n()
) %>%
mutate(p = n_bicycles/n_accidents) %>%
arrange(desc(p))
# estimate hyperparameters
hyper_params(df$p)
# cyclists
cyclist_fn <- function(area){
accidents_raw %>%
group_by_(area) %>%
filter(Crash_Year == 2017) %>%
summarise(
count = n(),
n_bicycles = sum(Count_Unit_Bicycle > 0)
) %>%
bind_cols(posterior_f(df = ., y = .$n_bicycles, n = .$count, a = 1.55, b = 25)) %>%
arrange(q50) %>%
mutate_(area = interp(~factor(v, level = v), v = as.name(area))) %>%
ggplot() +
geom_segment(mapping = aes(x = q2.5, xend = q97.5, y = area, yend = area)) +
geom_segment(mapping = aes(x = q10, xend = q90, y = area, yend = area, col = q50), size = 2) +
geom_point(mapping = aes(x = q50, y = area), pch = 3) +
theme_forest() +
scale_colour_gradientn(colors = my_cols()) +
labs(
title = "Rate of cyclists involved in road accidents",
subtitle = paste("Bayesian estimate of the rate of accidents involving cyclists for", toupper(names(area)), "areas in 2017"),
x = "Accidents involving cyclists (%)"
)
}
cyclist_plots <- list(sa3 = cyclist_fn(areas[1]), sa4 = cyclist_fn(areas[2]))
# Brisbane inner
accidents_raw %>%
filter(Loc_ABS_Statistical_Area_4 == "Brisbane Inner City", Crash_Year < 2018) %>%
group_by(Crash_Year) %>%
summarise(
n_bikes = sum(Count_Unit_Bicycle > 0),
n_accidents = n(),
p_bikes = n_bikes/n_accidents
) %>%
bind_cols(posterior_f(df = ., y = .$n_bikes, n = .$n_accidents, a = 1.55, b = 25)/100) %>%
ggplot(aes(x = Crash_Year, y = q50)) +
geom_line(col = "darkmagenta") +
geom_point(col = "darkmagenta") +
theme_minimal() +
theme_forest() +
labs(
x = "Year",
title = "Accidents involving cyclists - Brisbane Inner City",
subtitle = "Accidents involving cyclists are increasing, indicative of growth in popularity of cycling - Be careful on the roads"
)
The post Bayesian estimation of fatality rates and accidents involving cyclists on Queensland roads appeared first on Daniel Oehm | Gradient Descending.
The post Queensland road accidents mapped with Shiny and leaflet in R appeared first on Daniel Oehm | Gradient Descending.
]]>The Queensland government collects data on road accidents dating back to 1st January 2001 and details characteristics of the incident including,
Mapping this data highlights hot spots where car accidents occur more often. In particular the dangerous areas in wet conditions, problematic intersections and the areas of Queensland which are more dangerous than others in terms of fatality rates.
I developed a Shiny App utilising leaflet to easily explore the data (and just for fun). It features,
This data is of road accidents, so the estimate of fatality rate in this case is the fatality rate given the vehicle was involved in an accident, rather than the fatality rate by road accident in the population. It is a slightly different take on how this statistic is usually published, but a useful one.
The best way to view the app is to run the following code. Firstly, check to make sure you have the packages installed by running
check_packages <- function(packages){
if(all(packages %in% rownames(installed.packages()))){
TRUE
}else{
cat("Install the following packages before proceeding\n", packages[!(packages %in% rownames(installed.packages()))], "\n")
}
}
packages_needed <- c("tidyverse", "shiny", "leaflet", "leaflet.extras", "magrittr", "htmltools", "htmlwidgets", "showtext", "data.table")
check_packages(packages_needed)
If all good, now run the line below and it will load the app.
shiny::runGitHub("doehm/road-accidents/", "doehm", launch.browser = TRUE)
This will launch it directly on your machine. Or you can follow the link directly to the Shiny app.
There are a lot of neat things we can do with this data and I’ll be adding to the app over time.
A subset of the app focuses on the “Brisbane Inner” SA3 area to give a taste of what to expect. It shows car accidents in the city since 1st January 2013. When zooming in, hover over the marker to get a short description of the crash.
View the full screen map here.
Below is the underlying code of the example leaflet map above, but I strongly recommend running the code above to view the Shiny app. See Github for the full code.
# queensland road accident data
# libraries
library(tidyverse)
library(shiny)
library(leaflet)
library(leaflet.extras)
library(magrittr)
library(htmltools)
library(htmlwidgets)
library(showtext)
library(data.table)
# font
try({
font_add_google(name = "Montserrat", family = "mont")
showtext_auto()
}, TRUE)
# load data
# or if it doesn't work grab the Rdata file from Github - see link above
load_data <- function(){
if(!file.exists("locations.csv")){
cat('\n Download may take a few minutes...\n')
url <- "http://www.tmr.qld.gov.au/~/media/aboutus/corpinfo/Open%20data/crash/locations.csv"
download.file(url, destfile = "locations.csv", method="libcurl")
}
accidents_raw <- read_csv("locations.csv")
return(accidents_raw)
}
accidents_raw <- load_data() %>% filter(Crash_Severity != "Property damage only")
# sample of brisbane inner
accidents <- accidents_raw %>%
filter(
Loc_ABS_Statistical_Area_3 == "Brisbane Inner",
Crash_Year >= 2013
) %>%
mutate(fatality = Count_Casualty_Fatality > 0)
# basic leaflet
m <- leaflet(accidents) %>%
addProviderTiles(providers$Stamen.Toner, group = "Black and white") %>%
addTiles(options = providerTileOptions(noWrap = TRUE), group="Colour") %>%
addMarkers(
lng = ~Crash_Longitude_GDA94,
lat = ~Crash_Latitude_GDA94,
clusterOptions = markerClusterOptions(),
label = ~htmlEscape(Crash_DCA_Description)
) %>%
addCircleMarkers(
lng = ~Crash_Longitude_GDA94[accidents$fatality],
lat = ~Crash_Latitude_GDA94[accidents$fatality],
color = "#8B0000",
stroke = FALSE,
fillOpacity = 0.8,
group = "Fatalities"
) %>%
addHeatmap(
lng = ~Crash_Longitude_GDA94,
lat = ~Crash_Latitude_GDA94,
radius = 17,
blur = 25,
cellSize = 25
) %>%
addLayersControl(
overlayGroups = c("Fatalities"),
baseGroups = c("Black and white","Colour"),
options = layersControlOptions(collapsed = FALSE)
)
The post Queensland road accidents mapped with Shiny and leaflet in R appeared first on Daniel Oehm | Gradient Descending.
The post Buffalo Stampede Ultra 75km 2019: Race report appeared first on Daniel Oehm | Gradient Descending.
]]>The Buffalo Stampede is one of the most brutal races on the ultra calendar covering 75km and 4545m of climbing. An out-and-back from Bright to the top of Mount Buffalo tapping Mystic and Clear Spot along the way. It’s iconic for its brutally steep ascents that are technical and demanding.
This was my second Buffalo Ultra. My first was in 2016 and one of the hardest I’ve done. In 2016 I finished 9th in 9:56:00 and that feeling of crossing the finish line is one of the greatest.
This year my goal was to do better than my 2016 effort and given the work I’ve put in training I was confident heading into it. My goal time was 9:30:00 which was reasonable since Buffalo does its best to throw you into the pain cave and ruin your day.
Kicking off at 6am it was still a little dark. It doesn’t take long to hit the first climb up to Mystic, a 500m ascent to get warmed up. I took it very easy. Trying to run this first climb is a death sentence but people still try.
One sure way to destroy your race is attacking the descents too hard, too early. This is easily done on the infamous Mick’s track, at -45% gradient it’s a monster. I took it very easy, saving the quads for later and just got to the bottom safely.
It’s always a good feeling tapping Clear Spot (10km mark) since you know much of the hard climbing is over, for now. I got into a nice rhythm, smashed my mash potato, cruised up and over Keating Ridge and into the Eurobin aid station clocking up 26kms.
At this point the field was pretty spread out. I came in in 3rd place running with 4th. Refueled, left in 4th place and began the 1000m climb to Buffalo Summit. It wasn’t long until I caught and moved into 3rd place. I was feeling very strong on the climbs and took advantage of that.
At the Chalet Aid station, again refueled and quickly began the loop still in 3rd place and about 10 minutes behind second. The Galleries are always fun, scrambling through the rocks. I recall in 2016 I was cramping up massively at this point. This time however, not even close. I was in good shape.
Coming into the Chalet aid station at the 42km mark again, still feeling great. I had really closed the gap to 2nd place who was only 2 minutes ahead, and apparently looking way more fresh which was a good sign!
On the descent I caught and passed 2nd place. I tried to not get too carried away and just kept the legs turning over and remained comfortable.
About halfway down Buffalo I felt the first tweak in my ITB. I’ve been having ITB issues for the last month leading into the race which is just depressing when you’re in top form but can’t run due to an injury. It was well enough to start but this is what I feared.
It did not feel good and brought me to a walk. I didn’t walk for long and tried to take the edge off on the steep downs hoping it wouldn’t do it again.
Entering Eurobin aid station at the 53km mark in 2nd place, 14 mins behind 1st and a few minutes in front of 3rd and 4th. It was a quick stop for me, grabbed more mash, more watermelon and electrolyte. I got moving pretty quickly. I was very happy with my energy levels and felt I had heaps left in the tank. My ITB was hanging in there but it did not feel good at all and I was just hoping it would hold together for the last 26kms. I smashed the anti-inflammatories hoping that would do something.
I took the climb over Keating Ridge fairly easy since there were two monsters waiting for me at the 65km mark. A few minutes into the descent it happened, my ITB locked up, I couldn’t straighten my leg and every step was very painful. Running was impossible.
I walked for 500m trying to get some movement and start running again. Limping along for 100m it again flared up, very painful to put any weight on it and forced into a walk. I knew deep down my race was over.
3rd and 4th place passed me and there was nothing I could do about it. I tried running after walking for about 2km but again 100m down the road running became too painful.
At this point, 60km deep, I had to make a decision, do I continue even though the race I wanted is over and just finish by walking it in and hope I don’t do any more damage? Or tap out and not risk major damage? The thought of rupturing my ITB and turning a few weeks of recovery into a few months was not a risk I was willing to take. So, I made the tough decision to drop. No matter how right that decision was it never feels good to drop from a race.
I gave the word and got a lift back to the finish line for first aid. The medical officer felt my knee and said “Faaaark!”. It was in some ways comforting to hear, reassuring me I made the right decision to drop.
I haven’t had a good run in my last 3 races so the way it went bums me out even more. But I’m remaining positive and spending the next few weeks doing all the right things to fix it.
I was on track to crush my goal time and have one of my best races. I guess now this race goes onto the ‘I need redemption’ pile, so I’ll be back next year. For now, it’s preparing for UTA and hoping I don’t see a repeat.
View my race on Strava
The post Buffalo Stampede Ultra 75km 2019: Race report appeared first on Daniel Oehm | Gradient Descending.
The post Applying gradient descent – primer / refresher appeared first on Daniel Oehm | Gradient Descending.
]]>Every so often a problem arises where it’s appropriate to use gradient descent, and it’s fun (and / or easier) to apply it manually. Recently I’ve applied it to problems ranging from optimising a basic recommender system to ‘unsuppressing’ suppressed tabular data. I thought I’d do a series of posts about how I’ve used gradient descent, but figured it was worthwhile starting with the basics as a primer / refresher.
To understand how gradient descent is applied we’ll use the classic example, linear regression.
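A simple linear regression model is of the form
$$y_i = \beta_0 + \beta_1 x_i + \epsilon_i$$
where $\epsilon_i \sim N(0, \sigma^2)$.
The objective is to find the parameters $\beta_0$ and $\beta_1$ such that they minimise the mean squared error
$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \beta_0 - \beta_1 x_i)^2.$$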
This is a good problem since we know the analytical solution and can check our results.
In practice you would never use gradient descent to solve a regression problem, but it is useful for learning the concepts.
Set up
library(ggplot2)
set.seed(241)
nobs <- 250
b0 <- 4
b1 <- 2
# simulate data
x <- rnorm(nobs)
y <- b0 + b1*x + rnorm(nobs, 0, 0.5)
df <- data.frame(x, y)
# plot data
g1 <- ggplot(df, aes(x = x, y = y)) +
geom_point(size = 2) +
theme_minimal()
The analytical solution is given by $\hat{\beta} = (X^\top X)^{-1} X^\top y$.
# set model matrix
X <- model.matrix(y ~ x, data = df)
beta <- solve(t(X) %*% X) %*% t(X) %*% y
beta
##                 [,1]
## (Intercept) 4.009283
## x           2.016444
And just to convince ourselves this is correct
# linear model formulation
lm1 <- lm(y ~ x, data = df)
coef(lm1)
## (Intercept)           x
##    4.009283    2.016444
g1 + geom_abline(slope = coef(lm1)[2], intercept = coef(lm1)[1], col = "darkmagenta", size = 1)
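The objective is to achieve the same result using gradient descent. It works by updating the parameters with each iteration in the direction of negative gradient to minimise the mean squared error i.e.
$$\beta^{(k+1)} = \beta^{(k)} - \gamma \nabla L(\beta^{(k)})$$
where $\gamma$ is the learning rate. Here $L(\beta)$ is the MSE with respect to the regression parameters $\beta = (\beta_0, \beta_1)$. Firstly, we find the partial derivatives of $L$.
$$\frac{\partial L}{\partial \beta_0} = -\frac{2}{n} \sum_{i=1}^{n} (y_i - \beta_0 - \beta_1 x_i)$$
$$\frac{\partial L}{\partial \beta_1} = -\frac{2}{n} \sum_{i=1}^{n} x_i (y_i - \beta_0 - \beta_1 x_i)$$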
The learning rate is to ensure we don’t jump too far with each iteration and rather some proportion of the gradient, otherwise we could end up overshooting the minimum and taking much longer to converge or not find the optimal solution at all.
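To see the effect of the learning rate, here is a small sketch that runs the same descent with three different rates on data simulated in the same way as the set up above. The helper gd_error and the three rates are assumptions chosen for illustration. It returns the distance between the final parameters and the analytical solution after 200 iterations.

```r
# Sketch: the same descent with three learning rates on simulated data.
# gd_error is a hypothetical helper; the rates are chosen for illustration.
set.seed(241)
xs <- rnorm(250)
ys <- 4 + 2 * xs + rnorm(250, 0, 0.5)
Xs <- cbind(1, xs)                                   # model matrix with intercept
target <- solve(crossprod(Xs), crossprod(Xs, ys))    # analytical solution

gd_error <- function(lr, iters = 200, par = c(1, 1)) {
  for (k in seq_len(iters)) {
    grad <- -2 / nrow(Xs) * crossprod(Xs, ys - Xs %*% par)  # gradient of the MSE
    par <- par - lr * grad
  }
  sqrt(sum((par - target)^2))                        # distance to the solution
}

errs <- sapply(c(slow = 0.01, good = 0.1, too_big = 1.1), gd_error)
errs
```

The small rate is still on its way after 200 iterations, the moderate rate has converged, and the large rate overshoots further with every step so the error grows instead of shrinking.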
Applying this to the problem above, we’ll initialise our values for $\beta$ to something sensible e.g. $\beta = (1, 1)$. I’ll choose a learning rate of $\gamma = 0.01$. This is a slow burn, a learning rate of 0.1-0.2 is more appropriate for this problem but we’ll get to see the movement of the gradient better. It’s worth trying different values of $\gamma$ to see how it changes convergence. The algorithm is set up as
# gradient descent function
gradientDescent <- function(formula, data, par.init, loss.fun, lr, iters){
formula <- as.formula(formula)
X <- model.matrix(formula, data = data)
y <- data[,all.vars(formula)[1]]
par <- loss <- matrix(NA, nrow = iters+1, ncol = 2)
par[1,] <- par.init
for(k in 1:iters){
loss[k,] <- loss.fun(X=X, y=y, par=par[k,])
par[k+1,] <- par[k,] - lr*loss[k,]
}
return(list(par = par))
}
# gradient of the loss function (MSE) with respect to the parameters
loss.fun <- function(X, y, par) return(-2/nrow(X)*(t(X) %*% (y - X %*% par)))
# gradient descent. not much to it really
beta <- gradientDescent(y ~ x, data = df, par.init = c(1, 1), loss.fun = loss.fun, lr = 0.01, iters = 1000)$par
# plotting results
z <- seq(1, 1001, 10)
g1 + geom_abline(slope = beta[z,2], intercept = beta[z,1], col = "darkmagenta", alpha = 0.2, size = 1)
tail(beta, 1)
##             [,1]     [,2]
## [1001,] 4.009283 2.016444
As expected we obtain the same result. The lines show the gradient and how the parameters converge to the optimal values. A less reasonable set of starting values still converges quickly to the optimal solution, showing how well gradient descent works on linear regression.
beta <- gradientDescent(y ~ x, data = df, par.init = c(6, -1), loss.fun = loss.fun, lr = 0.01, iters = 1000)$par
# plotting results
z <- seq(1, 1001, 10)
beta.df <- data.frame(b0 = beta[z,1], b1 = beta[z,2])
g1 + geom_abline(data = beta.df, mapping = aes(slope = b1, intercept = b0), col = "darkmagenta", alpha = 0.2, size = 1)
tail(beta, 1)
##             [,1]     [,2]
## [1001,] 4.009283 2.016444
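To get a feel for how the learning rate changes convergence, here is a minimal sketch on a toy one-parameter problem rather than the regression above. The helper gd_toy() and the objective (b - 3)^2 are made up for illustration; only the update rule is the same as in gradientDescent().

```r
# toy objective f(b) = (b - 3)^2 with gradient 2 * (b - 3); the minimum is at b = 3
# (gd_toy is a hypothetical helper, not part of the code above)
gd_toy <- function(lr, b.init = 0, iters = 50){
  b <- numeric(iters + 1)
  b[1] <- b.init
  for(k in 1:iters){
    grad <- 2 * (b[k] - 3)
    b[k + 1] <- b[k] - lr * grad   # same update rule as gradientDescent()
  }
  b
}

slow <- gd_toy(lr = 0.01)   # creeps towards 3
good <- gd_toy(lr = 0.2)    # converges well within 50 steps
bad  <- gd_toy(lr = 1.1)    # overshoots further on every step and diverges

# distance from the minimum after 50 iterations
abs(c(slow = tail(slow, 1), good = tail(good, 1), bad = tail(bad, 1)) - 3)
```

With lr = 1.1 each step multiplies the distance to the minimum by -1.2, which is the overshooting described earlier.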
library(gganimate)
library(magrittr)
ggif_minimal <- df %>%
ggplot(aes(x = x, y = y)) +
geom_point(size = 2) +
theme_minimal() +
geom_abline(data = beta.df, mapping = aes(slope = b1, intercept = b0), col = "darkmagenta", size = 1) +
geom_text(
data = data.frame(z, b0 = beta[z,1], b1 = beta[z,2]),
mapping = aes(
x = -2.8, y = 9,
label = paste("b0 = ", round(b0, 2), "\nb1 = ", round(b1, 2))),
hjust = 0,
size = 6
) +
transition_reveal(z) +
ease_aes("linear") +
enter_appear() +
exit_fade()
animate(ggif_minimal, width = 1920, height = 1080, fps = 80)
Those are the basics of applying gradient descent. In practice there is no need to use gradient descent to solve a regression problem, but once you know how to apply it you’ll find real-world applications elsewhere that are more complicated (and interesting). If you can define the objective function and it is differentiable, you can apply gradient descent. In later posts I’ll demonstrate how I’ve applied it to real-world problems. Stay tuned!
The post Applying gradient descent – primer / refresher appeared first on Daniel Oehm | Gradient Descending.
The post Martingale strategies don’t work, but we knew that – Simulation analysis in R appeared first on Daniel Oehm | Gradient Descending.
It’s generally accepted that Martingale betting strategies don’t work. But people still gravitate towards them because they are intuitive. I was curious to find out how they actually perform.
Disclaimer: I do not encourage gambling, nor do I gamble myself but the games are good examples of stochastic processes.
The Classic Martingale strategy is as follows. Make an initial bet of $b$. If you win, bet $b$ on the next round. If you lose, double your previous bet. In a nutshell you bet $2^k b$ where $k$ is the number of losses in a row. The intention is to make your money back on the next win.
Assume we start with $100 and our initial bet is $1. We keep playing until there is not enough money in the cash pool to make the next bet. Also assume we are playing roulette on an American table and place bets on red or black which both have the probability 18/38. It only takes a streak of 6 losses before the game is over because we don’t have enough cash to double our bet for a 7th time. If we see a streak of 10 or more it really starts to get out of control.
trials | bet | cumulative_loss |
---|---|---|
1 | 1 | 1 |
2 | 2 | 3 |
3 | 4 | 7 |
4 | 8 | 15 |
5 | 16 | 31 |
6 | 32 | 63 |
7 | 64 | 127 |
8 | 128 | 255 |
9 | 256 | 511 |
10 | 512 | 1023 |
The probability of losing 6 in a row is $(20/38)^6 \approx 0.021$. Sounds unlikely, but it will occur more often than you think. With each win we will win $1, so once we have won 27 times we’ll have enough cash to afford a losing streak of 6 and bet on the 7th.
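As a quick check of the numbers above, this small sketch rebuilds the doubling table and the chance of a losing streak. The variable names are my own; the only inputs are the $1 initial bet and the 18/38 win probability on an American table.

```r
bet0 <- 1
p_win <- 18/38

# bet and cumulative loss after n losses in a row
streak <- data.frame(trials = 1:10)
streak$bet <- bet0 * 2^(streak$trials - 1)
streak$cumulative_loss <- cumsum(streak$bet)
streak

# probability of losing 6 in a row
p_lose6 <- (1 - p_win)^6
round(p_lose6, 4)   # 0.0213

# cash needed to survive 6 losses and still place the 7th bet
streak$cumulative_loss[7]   # 127
```

Starting from $100, the 27 wins mentioned above are exactly what it takes to reach the $127 needed.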
It’s more likely we’ll have a few wins and losses before observing a long losing streak that takes us out of the game. The question is how many trials (spins of the roulette wheel) will we place bets on before we lose our money and play stops? A slight variation I’ve applied is, if there is not enough money left to double the bet, we will simply bet the remainder in the cash pool, in other words go all-in. This is more in line with what someone might actually do.
This simulation was a relatively lucky one, winning over $170 and lasting almost 400 trials. However, one bad streak and it’s all over. You can plot as many simulations as you like. Some are shorter and some are longer, but they all end the same way.
This is a typical pattern for the classic strategy. If you wish to see more try running the code at the end of this post or view this set.
Instead of doubling the bet upon a loss, double the bet upon a win. Not surprisingly this ends the same way as the classic strategy.
The player’s cash slowly decreases by $1 on each loss. Occasionally there is a big win, rather than a big loss. With this strategy you can watch your money dwindle away rather than vanish in front of your eyes.
This plot is using the same win-lose sequence as the one above. In this case the maximum cash held throughout the game is higher with the classic strategy than the reverse.
Click here to see more simulations.
These strategies were simulated 20,000 times. The distribution of the number of trials shows how long a typical game will last until bankruptcy. The classic strategy has a very long tail, so you could potentially be playing for a very long time. The maximum number of trials in this simulation was 64254. But you could also be playing for a very short time.
The reverse strategy has a slightly higher median number of trials but is much less variable than the classic strategy, meaning you can be fairly sure to play between 166 and 217 spins.
strategy | min | 2.5% | 10% | 50% | 90% | 97.5% | max |
---|---|---|---|---|---|---|---|
classic | 7 | 10 | 22 | 170 | 1132 | 2759 | 64254 |
reverse | 152 | 167 | 172 | 191 | 207 | 213 | 226 |
Assume the goal is to double your money. What is the probability you’ll double your money before going bust? After 20,000 simulations for both strategies, the probability you will double your money using…
The Classic Martingale strategy tends to do better on average, but only slightly. Neither of these strategies is better than simply playing once and going all-in, which doubles your money with probability 18/38 ≈ 0.47.
The distribution of the maximum amount of cash held by the player at any given time during the game shows the classic strategy has the slight edge over the reverse strategy on average. Although the reverse strategy has the potential for massive wins if you score a good winning streak.
strategy | min | 2.5% | 10% | 50% | 90% | 97.5% | max |
---|---|---|---|---|---|---|---|
classic | 100 | 101 | 107 | 158 | 425 | 926 | 19894 |
reverse | 100 | 100 | 102 | 133 | 534 | 2062 | 131163 |
However, keep in mind these simulations all resulted in total loss at the end of the game. Perhaps the key is to know when to stop?
Rather than stopping once you reach a specified amount which may not happen, stop when the bet reaches a specified amount.
We could specify a streak length, however a better idea is to specify a ratio of bet to total cash. This way the stopping condition is dynamic. For example, if there is a winning streak we’ll have more money with which to bet.
Essentially by using this ratio we are fixing a certain level of risk rather than a betting amount. The ratio is calculated as
$$r = \frac{\text{bet}}{\text{total cash}}$$
If we fix a stopping ratio of 0.1, we could place 4 bets before the ratio exceeds 0.1. If our initial cash pool was $200 we could place 5 bets until the ratio exceeds 0.1.
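The counts above can be checked with a short sketch. It assumes every bet is lost (the worst case for the classic strategy) and that the ratio is the next bet divided by the cash in hand before that bet, as in the stop condition of the martingale function at the end of this post. The helper name bets_before_stop() is hypothetical.

```r
# number of doubling bets that can be placed before bet / cash exceeds the ratio,
# assuming every bet is lost (hypothetical helper for illustration only)
bets_before_stop <- function(cash, bet = 1, ratio = 0.1){
  n <- 0
  while(bet <= cash && bet / cash <= ratio){
    n <- n + 1
    cash <- cash - bet   # the bet is lost
    bet <- 2 * bet       # double for the next round
  }
  n
}

bets_before_stop(100)   # 4
bets_before_stop(200)   # 5
```

Doubling the cash pool buys only one extra bet, which is why the ratio fixes the level of risk rather than the betting amount.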
If we stop as soon as this ratio is reached it means we’re on a losing streak so it makes sense in the case of the classic strategy to bet again until our next win and walk away.
In the case of the reverse it makes sense to leave as soon as the ratio is met since we’re on a winning streak.
There are clear differences and similarities between the two strategies.
Overall the stopping strategies help to minimise loss rather than lock in wins, so on average you will still lose your money.
The code for the plots in this post can be found on github. The Martingale function is below.
# libraries
library(tidyverse)
# martingale function
martingale <- function(bet, cash, p, stop_condn = Inf, stop_factor = Inf, reverse = FALSE, plot = TRUE, stop_on_next_win = TRUE){
bet_vec <- vector(mode = "numeric")
cash_vec <- vector(mode = "numeric")
outcome <- vector(mode = "numeric")
winnings <- vector(mode = "numeric")
total_cash <- vector(mode = "numeric")
trial <- 0
total_cash[1] <- cash
while(total_cash[max(trial, 1)] > 0){
# iterate through trials
trial <- trial + 1
# update cash pool
if(trial == 1){
cash_vec[trial] <- cash
}else{
cash_vec[trial] <- total_cash[trial-1]
}
# set bet
if(!reverse){
# classic: reset to the initial bet after a win, otherwise double
# (trial == 1 is checked first so outcome[0] is never evaluated)
if(trial == 1 || outcome[trial - 1] == 1){
bet_vec[trial] <- bet
}else{
bet_vec[trial] <- min(2*bet_vec[trial-1], cash_vec[trial]) # if there isn't enough to double the bet just bet what is left
}
}else{
# reverse: reset to the initial bet after a loss, otherwise double
if(trial == 1 || outcome[trial - 1] == 0){
bet_vec[trial] <- bet
}else{
bet_vec[trial] <- min(2*bet_vec[trial-1], cash_vec[trial]) # if there isn't enough to double the bet just bet what is left
}
}
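# stop condition - triggered when the bet to cash ratio exceeds stop_condn
stop_trigger <- FALSE
if(bet_vec[trial]/cash_vec[trial] > stop_condn){
if(stop_on_next_win && !reverse){
# classic strategy: keep betting and walk away on the next win
stop_trigger <- TRUE
}else{
outcome[trial] <- NA
winnings[trial] <- NA
total_cash[trial] <- cash_vec[trial]
break
}
}
outcome[trial] <- sample(c(0,1), 1, prob = c(1-p, p))
winnings[trial] <- bet_vec[trial]*outcome[trial] - bet_vec[trial]*(1-outcome[trial])
total_cash[trial] <- cash_vec[trial] + winnings[trial]
# stop conditions - a win after the ratio was exceeded, or the cash target is reached
if(stop_trigger && outcome[trial] == 1) break
if(total_cash[trial] >= stop_factor*cash) break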
}
# make the plot
gg <- NULL # returned as NULL when plot = FALSE
if(plot){
df <- data.frame(trials = 1:trial, cash = total_cash)
gg <- ggplot() +
geom_line(data = df, mapping = aes(x = trials, y = cash), col = "darkmagenta", lty = 1, size = 1) +
geom_hline(yintercept = cash_vec[1], col = "grey", lty = 2) +
theme_minimal() +
labs(
x = "Number of spins",
y = "Total cash in hand",
title = ifelse(reverse, "Reverse Martingale strategy", "Martingale strategy"),
subtitle = "The growth and decline of the gamblers cash pool - it always ends the same way"
) +
ylim(0, NA)
print(gg)
}
return(list(
bet = bet_vec,
cash = cash_vec,
outcome = outcome,
total_cash = total_cash,
trials = trial,
plot = gg))
}
# run the simulation and plot the output
# try different parameters to see the effect
martingale(1, 100, 18/38, reverse = FALSE, plot = TRUE, stop_condn = 1,
stop_on_next_win = TRUE)
The post Adding Custom Fonts to ggplot in R appeared first on Daniel Oehm | Gradient Descending.
ggplot – You can spot one from a mile away, which is great! And when you do it’s a silent fist bump. But sometimes you want more than the standard theme.
Fonts can breathe new life into your plots, helping to match the theme of your presentation, poster or report. This is always an afterthought for me and I need to work out how to do it again each time, hence the post.
There are two main packages for managing fonts – extrafont and showtext.
extrafont is a relatively old package and unfortunately it’s not well supported. You can run into problems, however the base functions work well enough.
The fonts in the system directory are first imported into the extrafontdb using font_import(). This only needs to be run once in order to load the fonts into the right directory. Secondly, they are registered in R using loadfonts() for your specific device. The fonts need to be loaded each session.
library(tidyverse)
library(extrafont)
library(cowplot)
# import fonts - only once
font_import()
# load fonts - every session
loadfonts(device = "win", quiet = TRUE)
Below are all the available fonts (that I have – click on the image to enlarge).
To use these fonts in a plot, change the text family using one of the names above. For demonstration I’ll use the Antigua corn data from the DAAG package.
library(DAAG)
# corn plot
corn <- antigua %>%
dplyr::filter(ears > 0) %>%
ggplot(aes(x = ears, y = harvwt, col = site)) + # the data arrives through the pipe
geom_point(size = 4) +
scale_colour_manual(values = colorRampPalette(c("orange", "darkmagenta", "turquoise"))(8)) +
labs(title = "ANTIGUA CORN YIELDS",
x = "Ears of corn harvested",
y = "Harvest weight") +
theme(
text = element_text(family = "candara", size = 24),
plot.title = element_text(size = 30),
plot.caption = element_text(size = 28))
corn
It’s likely you’ll want more than what is available in the standard font set. You can add custom fonts with extrafont, however I’ve only had limited success. A better option is using showtext.
showtext is a package by Yixuan Qiu and it makes adding new fonts simple. There are a tonne of websites where you can download free fonts to suit pretty much any style you are going for. I’ll only touch on the key bits, so check out the vignette for more details.
The simplest way to add fonts is via font_add_google(). Find the font you like on Google Fonts and add it to R using the following.
library(showtext)
font_add_google(name = "Amatic SC", family = "amatic-sc")
Amatic SC can now be used by changing the font family to “amatic-sc”. For R to know how to properly render the text we first need to run showtext_auto() prior to displaying the plot. One downside is it currently does not display in RStudio. Either open a new graphics window with windows() or save as an external file e.g. .png.
# turn on showtext
showtext_auto()
Custom fonts are added by first,
1. downloading the .ttf file and unzipping if needed
2. using font_add() to register the font
3. running showtext_auto() to load the fonts
font_add(family = "docktrin", regular = "./fonts/docktrin/docktrin.ttf")
showtext_auto()
And that’s pretty much it. Given how effortless it is to add new fonts you can experiment with many different styles.
These fonts are outrageous but demonstrate that you really can go for any style, from something minimal and easy reading to something fit for a heavy metal band. For professional reports you’ll want to go for something sensible, but if you’re making a poster, website or infographic you may want to get creative e.g.
The tools are there for you to be as creative as you want to be.
The last thing to note is you’ll need to play around with different sizes given the resolution of your screen.
To turn off showtext and use the standard fonts, run
showtext_auto(FALSE)
# load fonts
ft <- data.frame(x = sort(rep(1:4, 31))[1:nrow(fonttable())], y = rep(31:1, 4)[1:nrow(fonttable())], text_name = fonttable()$FullName)
font_plot <- ggplot(ft, aes(x = x, y = y)) +
geom_text(aes(label = text_name, family = text_name), size = 20) +
coord_cartesian(xlim = c(0.5, 4.5)) +
theme_void()
font_plot
# amatic sc
scale <- 4.5 # scale is an adjustment for a 4k screen
corn <- ggplot(antigua %>% dplyr::filter(ears > 0), aes(x = ears, y = harvwt, col = site)) +
geom_point(size = 4) +
scale_colour_manual(values = colorRampPalette(c("orange", "darkmagenta", "turquoise"))(8)) +
labs(title = "ANTIGUA CORN YIELDS",
subtitle = "Study of different treatments and their effect on corn yields in Antigua",
x = "Ears of corn harvested",
y = "Harvest weight",
caption = "@danoehm | gradientdescending.com") + theme_minimal() +
theme(
text = element_text(family = "amatic-sc", size = 22*scale),
plot.title = element_text(size = 26*scale, hjust = 0.5),
plot.subtitle = element_text(size = 14*scale, hjust = 0.5),
plot.caption = element_text(size = 12*scale),
legend.text = element_text(size = 16*scale))
png(file = "./fonts/corn yields amatic-sc.png", width = 3840, height = 2160, units = "px", res = 72*4)
corn
dev.off()
# wanted dead or alive
# generate data - unfortunately I couldn't find data on actual outlaws
n <- 20
x <- 100*runif(n)
y <- (25 + 0.005*x^2 + rnorm(n, 0, 10))*10
wanted <- data.frame(x, y) %>%
ggplot(aes(x = x, y = y)) +
geom_smooth(col = "black", lty = 2) +
geom_point(size = 4) +
theme_minimal() +
labs(title = "WANTED: DEAD OR ALIVE",
subtitle = "Relationship of the crimes committed by outlaws and the bounty on their head",
x = "CRIMES",
y = "BOUNTY",
caption = "@danoehm | gradientdescending.com") + theme_minimal() +
theme(
text = element_text(family = "docktrin", size = 16*scale),
plot.title = element_text(size = 40*scale, hjust = 0.5),
plot.subtitle = element_text(size = 14*scale, hjust = 0.5, margin = margin(t = 40)),
plot.caption = element_text(size = 10*scale),
legend.text = element_text(size = 12*scale),
panel.grid = element_line(color = "black"),
axis.title = element_text(size = 26*scale),
axis.text = element_text(color = "black"))
png(file = "./fonts/wanted1.png", width = 3840, height = 2160, units = "px", res = 72*4)
ggdraw() +
draw_image("./fonts/wanted_dead_or_alive_copped.png", scale = 1.62) + # a png of the background for the plot
draw_plot(wanted)
dev.off()
# Horror movies
horror <- read_csv("./fonts/imdb.csv")
gghorror <- horror %>%
dplyr::filter(str_detect(keywords, "horror")) %>%
dplyr::select(original_title, release_date, vote_average) %>%
ggplot(aes(x = release_date, y = vote_average)) +
geom_smooth(col = "#003300", lty = 2) +
geom_point(size = 4, col = "white") +
labs(title = "HORROR MOVIES!",
subtitle = "Average critics ratings of horror films and relationship over time",
x = "YEAR",
y = "RATING",
caption = "Data from IMDB.\n@danoehm | gradientdescending.com") + theme_minimal() +
theme(
text = element_text(family = "swamp-witch", size = 16*scale, color = "#006600"),
plot.title = element_text(size = 48*scale, hjust = 0.5),
plot.subtitle = element_text(family = "montserrat", size = 14*scale, hjust = 0.5, margin = margin(t = 30)),
plot.caption = element_text(family = "montserrat", size = 8*scale),
legend.text = element_text(size = 12*scale),
panel.grid = element_line(color = "grey30"),
axis.title = element_text(size = 26*scale),
axis.title.y = element_text(margin = margin(r = 15)),
axis.text = element_text(color = "#006600"),
plot.background = element_rect(fill = "grey10"))
png(file = "./fonts/swamp.png", width = 3840, height = 2160, units = "px", res = 72*4)
gghorror
dev.off()
The post The Most Amount of Rain over a 10 Day Period on Record appeared first on Daniel Oehm | Gradient Descending.
Townsville, Qld, has been inundated with torrential rain and has broken the record of the largest rainfall over a 10 day period. It has been devastating for the farmers and residents of Townsville. I looked at Townsville’s weather data to understand how significant this event was and if there have been comparable events in the past.
Where this may interest the R community is in obtaining the data. The package ‘bomrang’ is an API allowing R users to fetch weather data directly from the Australian Bureau of Meteorology (BOM) and have it returned in a tidy data frame.
Historical weather observations including rainfall, min/max temperatures and sun exposure are obtained via get_historical(). Either the ID or the lat-long coordinates of the weather station are needed to extract the data. The ID information can be found on the BOM website by navigating to the observations page.
Using the station ID the rainfall data is extracted with the following.
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(bomrang))
suppressPackageStartupMessages(library(gridExtra))
suppressPackageStartupMessages(library(magrittr))
suppressPackageStartupMessages(library(MCMCpack))
mycols <- c("darkmagenta", "turquoise")
# import data - simple as
townsville <- get_historical("032040")
And cleaning up the dates.
# build a proper date column from the year, month and day fields
townsville$date <- as.Date(paste(townsville$Year, townsville$Month, townsville$Day, sep = "-"), format = "%Y-%m-%d")
# trim - for replication of when this was first run
townsville %<>% dplyr::filter(date < as.Date("2019-02-13"))
# take a look
head(townsville)
##   Product_code Station_number Year Month Day Rainfall Period Quality
## 1   IDCJAC0009          32040 1941     1   1      0.0     NA       Y
## 2   IDCJAC0009          32040 1941     1   2      6.6      1       Y
## 3   IDCJAC0009          32040 1941     1   3     16.5      1       Y
## 4   IDCJAC0009          32040 1941     1   4    205.5      1       Y
## 5   IDCJAC0009          32040 1941     1   5    175.0      1       Y
## 6   IDCJAC0009          32040 1941     1   6     72.9      1       Y
##         date
## 1 1941-01-01
## 2 1941-01-02
## 3 1941-01-03
## 4 1941-01-04
## 5 1941-01-05
## 6 1941-01-06
Applying a 10 day rolling window over the entire historical record it’s easy to see the significance of this rainfall event. The 8th February recorded 1259.8mm of rain in the 10 days prior. It dwarfs the previous record of 925.5mm set in 1953. It also highlights other significant events in the past, in particular 1968, 1998 and 2009 but these don’t come close to the 2019 event.
# get 10 day total
townsville$rolling10 <- 0
for(k in 10:nrow(townsville)){
townsville$rolling10[k] <- sum(townsville$Rainfall[(k-9):k])
}
# plot
ggplot(
townsville %>% dplyr::filter(date > as.Date("1940-01-01")),
aes(x = date, y = rolling10, col = rolling10)) + geom_line() +
scale_color_gradientn(colors = colorRampPalette(mycols)(32)) +
labs(y = "Total rainfall in the last 10 days")
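The rolling total above is built with an explicit loop. A vectorised alternative can be sketched with cumulative sums. The helper roll_sum() and the rain vector below are made up for illustration and are not BOM observations.

```r
# rolling sum over a fixed window using cumulative sums
# (roll_sum is a hypothetical helper; rain is made-up data)
roll_sum <- function(x, width = 10){
  out <- rep(0, length(x))   # the first width-1 values stay at 0, as in the loop version
  if(length(x) < width) return(out)
  cs <- c(0, cumsum(x))      # cs[j + 1] holds sum(x[1:j])
  k <- width:length(x)
  out[k] <- cs[k + 1] - cs[k - width + 1]
  out
}

set.seed(1)
rain <- sample(0:60, 40, replace = TRUE)

# loop version for comparison
loop10 <- rep(0, length(rain))
for(k in 10:length(rain)) loop10[k] <- sum(rain[(k-9):k])

all.equal(roll_sum(rain), loop10)
```

One caveat with this approach is that cumsum() carries an NA forward to every later value, so missing rainfall readings would need to be handled first.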
It really was a phenomenal amount of rain. This was not the largest rainfall in a day, however. That record was set in 1998 with a massive 548.8mm of rain. In fact the 2019 floods don’t feature in the top 10 wettest days, but the consistency over 10 days made it the wettest 10 day period on record.
townsville %>% arrange(desc(Rainfall)) %>% dplyr::select(date, Rainfall) %>% head(10)
##          date Rainfall
## 1  1998-01-11    548.8
## 2  1946-03-03    366.5
## 3  1953-01-16    346.7
## 4  1977-02-01    317.6
## 5  1997-03-24    302.8
## 6  1978-01-31    273.4
## 7  2000-04-04    271.6
## 8  1946-02-10    251.7
## 9  2009-02-03    236.8
## 10 1944-03-29    233.4
Townsville received over a year’s worth of rain in 10 days. The graph below shows the annual rainfall measurements and average annual rainfall (dotted line) given the historical records. Even with only 5-6 weeks of the year gone, 2019 is already one of the wettest years on record.
# calculate the total annual rainfall and rainfall to date
annual.rainfall <- townsville %>%
dplyr::filter(date > as.Date("1940-01-01")) %>%
mutate(
rainfall_to_date = as.numeric(as.POSIXlt(date)$yday < 40)*Rainfall,
rainfall_after_wet = as.numeric(as.POSIXlt(date)$yday < 90)*Rainfall
) %>%
group_by(Year) %>%
summarise(
annual = sum(Rainfall),
Feb12th = sum(rainfall_to_date),
april = sum(rainfall_after_wet),
remaining = sum(Rainfall) - sum(rainfall_to_date)
)
# bar plot
ggplot(annual.rainfall, aes(x = Year, y = annual, fill = annual)) +
geom_bar(stat = "identity") +
scale_fill_gradientn(colors = colorRampPalette(mycols)(32)) +
geom_hline(yintercept = mean(annual.rainfall$annual, na.rm = TRUE), lty = 3, col = "grey20", lwd = 1) +
labs(y = "Total annual rainfall")
On close inspection the data suggests that the first 40 years of records are less variable than those from 1980 onwards. There appear to be both drier and wetter years in the latter half.
The current record was set in 2000 at 2400mm. In that year Townsville had a few heavy rainfall events in the months up until April and some lesser events near the end of the year. Comparing 2019 to the wettest years on record, there is definitely potential for 2019 to be the wettest of them all.
ggplot(townsville %>% dplyr::filter(Year %in% c(1950, 1956, 2000, 2019)), aes(x = as.POSIXlt(date)$yday, y = Rainfall)) +
geom_line() +
facet_grid(Year ~ .) + labs(x = "Day of the year") +
labs(title = "Wettest 3 years on record vs 2019")
The plots below clearly show which years have had significant rainfall in the first part of the year. The years which have received greater than 700mm (dotted line) are quite distinct from the bulk of the data. Since the wet season ends in April, the other heavy years (like 2000) haven’t had their major events yet. This is shown in the April plot at the bottom which has a much stronger relationship (obviously). The years which experienced heavy rainfall at this time of year in general didn’t get too much afterwards.
grid.arrange(
ggplot(annual.rainfall, aes(x = Feb12th, y = annual, col = annual)) +
geom_point(size = c(rep(2, nrow(annual.rainfall)-1), 4)) +
scale_color_gradientn(colors = colorRampPalette(mycols)(32)) +
labs(y = "Annual rainfall", x = "Total rainfall as at 12th Feb") +
geom_vline(xintercept = 700, lty = 3, col = "grey20", lwd = 1),
ggplot(annual.rainfall, aes(x = april, y = annual, col = annual)) +
geom_point(size = c(rep(2, nrow(annual.rainfall)-1), 4)) +
scale_color_gradientn(colors = colorRampPalette(mycols)(32)) +
labs(y = "Annual rainfall", x = "Total rainfall as at 1st April") +
geom_vline(xintercept = 700, lty = 3, col = "grey20", lwd = 1)
)
For what it’s worth I’ll offer a prediction for the expected annual rainfall and the probability of 2019 being the wettest year on record (which, to be honest, is a fool’s errand – tropical weather systems are pretty complex stuff).
# bayesian model
blm <- MCMCregress(annual ~ Feb12th, data = annual.rainfall %>% dplyr::filter(Feb12th > 0, Year != 2019),
sigma.mu = 235, sigma.var = 35^2) # priors from exploration - details skipped here
# prediction
x <- matrix(c(1, 1444))
pred.annual.rainfall <- data.frame(
annual = blm[,-3] %*% x + rnorm(10000, 0, sqrt(blm[,"sigma2"])), # posterior predictive distribution
exp.val = blm[,-3] %*% x) # mean distribution
# c(min(pred.annual.rainfall$annual), table(pred.annual.rainfall$annual < 1444)/10000)
n <- 1000
xstart <- rep(0, n)
xend <- rep(1700, n)
ystart <- blm[1:n,1] + blm[1:n,2]*xstart
yend <- blm[1:n,1] + blm[1:n,2]*xend
ystartp <- blm[1:n,1] + blm[1:n,2]*xstart + rnorm(n, 0, sqrt(blm[1:n,3]))
yendp <- blm[1:n,1] + blm[1:n,2]*xend + rnorm(n, 0, sqrt(blm[1:n,3]))
df.seg <- data.frame(xstart, xend, ystart, yend, ystartp, yendp)
exp.val <- quantile(pred.annual.rainfall$exp.val, c(0.025, 0.5, 0.975))
post.pred <- quantile(pred.annual.rainfall$annual, c(0.025, 0.5, 0.975))
prob.record.rainfall <- sum(as.numeric(pred.annual.rainfall$annual > 2400))/10000
# drawing each line of the posterior draws takes time but it's a nice aesthetic
df.pred <- data.frame(x = c(1444, 1444), y = range(exp.val))
a <- 6
df.post <- data.frame(xs = rep(1444, 3) + c(0, -a, -a), xe = rep(1444, 3) + c(0, a, a),
ys = c(post.pred[1], post.pred[1], post.pred[3]), ye = c(post.pred[3], post.pred[1], post.pred[3]))
ggplot(annual.rainfall %>% dplyr::filter(Year != 2019, Feb12th > 0), aes(x = Feb12th, y = annual)) +
geom_segment(mapping = aes(x = xstart, xend = xend, y = ystartp, yend = yendp), data = df.seg, col = "grey20", alpha = I(0.05)) +
geom_segment(mapping = aes(x = xstart, xend = xend, y = ystart, yend = yend), data = df.seg, col = "darkmagenta", alpha = I(0.025)) +
geom_point() +
labs(y = "Annual rainfall", x = "Total rainfall as at 12th Feb") +
geom_line(mapping = aes(x = x, y = y), data = df.pred, size = 2) +
geom_segment(mapping = aes(x = xs, xend = xe, y = ys, yend = ye), data = df.post) +
geom_hline(yintercept = 2400, col = "red", lty = 2)
Based on this (crude) model, the 95% credible interval for the expected annual rainfall is (2027, 2472) mm. Using the posterior predictive distribution for 2019, the probability that 2019 will experience record rainfall is 0.29.
But let’s all hope there’s not too much more rain. There has been enough.
The post Synthesising Multiple Linked Data Sets and Sequences in R appeared first on Daniel Oehm | Gradient Descending.
In my last post I looked at generating synthetic data sets with the ‘synthpop’ package, some of the challenges and neat things the package can do. It is simple to use which is great when you have a single data set with independent features.
This post will build on the last post by tackling other complications when attempting to synthesise data. These challenges occur regularly in practice and this post will offer a couple of solutions, but there are plenty more. If you haven’t read my previous post, check that out first and swing back here. I’ll detail how more complex synthesis can be done using synthpop.
When synthesising data, it is generally the finest grain which is synthesised. For population data the person is the finest grain, making synthesis relatively straightforward. But what happens when there are multiple tables to synthesise which are all linked by primary keys and have different grains finer than the person level? Let’s look at an example.
Consider a health database which has two tables 1) a patient level table with demographic features and 2) an appointment level fact table consisting of the date of each appointment, including the diagnosis, treatment, prescriptions, etc, all the things you expect on a health record. In the second case the grain is patient x appointment.
Preserving this structure is important, otherwise there could be flow-on complications, plus there could be a loss of information. The obvious solution is to join the patient table onto the appointment table consolidating all the information, then synthesise. Right? Not always, for two key reasons,
Some variables such as appointment dates should be treated as a sequence where the current appointment is dependent on the previous appointment date. The reason being, if a patient has an initial consultation, is diagnosed with some ailment e.g. fracture, fatigue, etc, it’s likely there are follow up appointments with the doctor for treatment. Synthesis should reflect this sequence.
Other examples where sequences should be preserved (not necessarily in the health domain) may include:
or any other time series for that matter. It may not always be the case where the finest grain is a sequence, but it is likely to be. The key point is to know your data and what structure you need to preserve.
With these constraints the synthesised tables need to retain the properties of properly synthesised data, meaning no variable is synthesised using the original data.
Synthesising data at multiple grains can be done in a number of different ways. I’ll outline one particular way:
This requires some fancy footwork with synthpop but it can be done. There are many other ways to do this, however where possible I opt for the path of least resistance.
Appointment level data is first generated to use as original data. Key points,
Appointment date will be simulated to replicate the attendance pattern as outlined above. This pattern is likely to be correlated with other variables such as age and sex e.g. older people visit the doctor more. For this post 2 patterns will be simulated which are loosely correlated with age. Hidden Markov Models are used for the simulation. More complicated models such as state space models, ARIMA models or RNNs can be used for simulation. (As a side note HMMs are essentially a discrete form of state space models.)
The cost of the appointment will be simulated from a multivariate normal distribution with mean vector $\mu$ and covariance matrix $\Sigma$ (pulling some numbers out of the air), truncated to 0 where appropriate. These are not correlated with any other variable (in the real world they would be, but let’s keep it simple for now) and the assumption of independence is reasonable. This means we don’t have to do anything fancy with synthpop and can simply use the built-in functions.
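As a rough illustration of that cost simulation, here is a minimal base R sketch. The means, covariance matrix and column names are placeholders of my own and are not the values used in this post. Draws are generated through a Cholesky factor, and any negative costs are set to 0.

```r
set.seed(20190201)

# placeholder parameters - NOT the values used in this post
mu <- c(cost_a = 80, cost_b = 200)                 # hypothetical mean costs
Sigma <- matrix(c(30^2, 1800, 1800, 120^2), 2, 2)  # hypothetical covariance (correlation 0.5)

# multivariate normal draws: standard normals transformed by the Cholesky factor
n <- 1000
z <- matrix(rnorm(n * length(mu)), nrow = n)
cost <- sweep(z %*% chol(Sigma), 2, mu, "+")
colnames(cost) <- names(mu)

# truncate to 0 where a negative cost was drawn
cost[cost < 0] <- 0

head(cost)
```

In practice rmvnorm() from the mvtnorm package (loaded below) does the sampling step in one call. The Cholesky version is used here only to keep the sketch free of dependencies.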
Dates are particularly annoying to deal with at the best of times and simulating them is no exception. Here is one way.
suppressPackageStartupMessages(library(synthpop))
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(HMM))
suppressPackageStartupMessages(library(cluster))
suppressPackageStartupMessages(library(doParallel))
suppressPackageStartupMessages(library(foreach))
suppressPackageStartupMessages(library(mvtnorm))
suppressPackageStartupMessages(library(magrittr))
suppressPackageStartupMessages(library(gridExtra))
mycols <- c("darkmagenta", "turquoise")
myseed <- 20190201
# set data
patient.df <- SD2011 %>% dplyr::select(sex, age, socprof, income, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi)
# set unique id
patient.df$patient_id <- 1:nrow(patient.df)
head(patient.df)
##      sex age          socprof income marital depress sport nofriend smoke
## 1 FEMALE  57          RETIRED    800 MARRIED       6    NO        6    NO
## 2   MALE  20 PUPIL OR STUDENT    350  SINGLE       0    NO        4    NO
## 3 FEMALE  18 PUPIL OR STUDENT     NA  SINGLE       0    NO       20    NO
## 4 FEMALE  78          RETIRED    900 WIDOWED      16   YES        0    NO
## 5 FEMALE  54    SELF-EMPLOYED   1500 MARRIED       4   YES        6   YES
## 6   MALE  20 PUPIL OR STUDENT     -8  SINGLE       5    NO       10    NO
##   nociga alcabuse      bmi patient_id
## 1     -8       NO 30.79585          1
## 2     -8       NO 23.44934          2
## 3     -8       NO 18.36547          3
## 4     -8       NO 30.46875          4
## 5     20       NO 20.02884          5
## 6     -8       NO 23.87511          6
# patient group
low <- sample(1:2, nrow(patient.df), replace = TRUE, prob = c(0.25, 0.75))
high <- sample(1:2, nrow(patient.df), replace = TRUE, prob = c(0.75, 0.25))
patient.df$group <- ifelse(patient.df$age < median(patient.df$age), low, high)
# set my HMM simulation function
mysimHMM <- function (hmm, num.events, which.event = NULL){
hmm$transProbs[is.na(hmm$transProbs)] = 0
hmm$emissionProbs[is.na(hmm$emissionProbs)] = 0
states = c()
emission = c()
states = c(states, sample(hmm$States, 1, prob = hmm$startProbs))
i <- 2
if(is.null(which.event)) which.event <- hmm$Symbols[1]
while(sum(emission == which.event) < num.events) {
state = sample(hmm$States, 1, prob = hmm$transProbs[states[i-1], ])
states = c(states, state)
emi = sample(hmm$Symbols, 1, prob = hmm$emissionProbs[states[i-1], ])
emission = c(emission, emi)
i <- i + 1
}
return(list(states = states, observation = emission))
}
# initialise HMM
hmm.list <- list(initHMM(States = c("episode", "no episode"),
Symbols = c("Y", "N"),
transProbs = matrix(rbind(
c(0.92, 0.08),
c(0.02, 0.98)
)),
emissionProbs = matrix(rbind(
c(0.25, 0.75),
c(0.001, 0.999)
))),
initHMM(States = c("episode", "no episode"),
Symbols = c("Y", "N"),
transProbs = matrix(rbind(
c(0.92, 0.08),
c(0.1, 0.9)
)),
emissionProbs = matrix(rbind(
c(0.25, 0.75),
c(0.001, 0.999)
))))
# appointments start dates will be contained within a 3 year window
start.dates <- seq(as.Date("2014-01-01"), as.Date("2016-12-31"), 1)
simulate.appointment.dates <- function(){
n <- rpois(1, rgamma(1, 1)*12) + 1
appointment.dates <- mysimHMM(hmm.list[[patient.df$group[k]]], n)$observation
l <- length(appointment.dates)
dts <- sample(start.dates, 1) + 0:(l-1)
appointment_dates = dts[which(appointment.dates == "Y")]
out <- data.frame(patient_id = rep(k, length = n), appointment_dates,
cost = rmvnorm(n, mean = c(100, 500, 1000), sigma = diag(c(75, 100, 125)^2)) %>% sample(n) %>% pmax(0))
return(out)
}
ncores <- min(detectCores(), 4)
registerDoParallel(ncores)
appointment.df <- foreach(k = 1:nrow(patient.df), .combine = rbind, .export = "rmvnorm") %dopar% simulate.appointment.dates()
appointment.df %<>% dplyr::filter(appointment_dates < as.Date("2019-01-31"))
Plotting the sequence of appointment dates, we can see the structure we want to preserve.
# look at the clustering of appointments
df <- appointment.df %>% dplyr::filter(patient_id < 7)
g.date.sequences <- ggplot(df, aes(x = appointment_dates, y = patient_id)) +
geom_hline(yintercept = 1:6, col = "grey80") +
geom_point(col = "turquoise", size = 4) +
coord_cartesian(xlim = range(df$appointment_dates))
# costs
g.costs <- ggplot(appointment.df, aes(x = cost)) + geom_histogram(fill = "turquoise")
# plot them
grid.arrange(g.date.sequences, g.costs)
To create a link between the patient level data and the appointment level data the patient table is clustered using a simple hierarchical clustering procedure (it can be whatever method you think is suitable but I’m choosing hierarchical for this example).
Here 20 clusters are defined. In practice, more thought may need to go into how many clusters are defined given the problem. The clusters are joined to the appointment table to create the link.
# --------------- CLUSTER PATIENTS TO EXPAND TO APPOINTMENT LEVEL
dsim <- patient.df %>%
dplyr::select(-patient_id) %>%
dist()
clust <- hclust(dsim, method = "ward.D2")
nc <- 20
cut.clust <- cutree(clust, nc)
table(cut.clust)
## cut.clust
##   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
## 385 179 800 626 569 532 464 277 349  92  60  35 165 199  74   4 168  11
##  19  20
##   8   3
# put cluster on patient table
patient.df$cluster <- cut.clust
appointment.df %<>% left_join(patient.df %>% dplyr::select(patient_id, cluster, group), by = "patient_id")
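One caveat on the distance step: dist() is designed for numeric input, so factor columns such as sex or smoke may not contribute to the distances as intended. A possible alternative, which is not what this post uses, is the Gower dissimilarity from daisy() in the cluster package, which handles mixed numeric and categorical columns. The sketch below runs on a made-up toy data frame; the names toy, d_gower and groups are illustrative only.

```r
library(cluster)  # recommended package that ships with R

# made-up mixed-type data
toy <- data.frame(
  age   = c(23, 25, 61, 64, 40, 42),
  smoke = factor(c("NO", "NO", "YES", "YES", "NO", "YES")),
  sex   = factor(c("F", "M", "F", "M", "F", "M"))
)

# Gower dissimilarity copes with numeric and factor columns together
d_gower <- daisy(toy, metric = "gower")

# same hierarchical clustering and cut as in the post, on the new dissimilarities
cl     <- hclust(d_gower, method = "ward.D2")
groups <- cutree(cl, k = 2)
table(groups)
```

The rest of the workflow (attaching the cluster label to the patient table and joining it to the appointments) would stay the same.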
The patient table is now synthesised, including the clusters, and a new patient id is created.
# --------------- SYNTHESISE PATIENT DATA FRAME
p.df <- patient.df %>% dplyr::select(-patient_id, -group)
# apply rules to ensure consistency
rules.list <- list(marital = "age < 18", nociga = "smoke == 'NO'")
rules.value.list <- list(marital = "SINGLE", nociga = -8)
# setting continuous variable NA list
cont.na.list <- list(income = c(NA, -8), nofriend = c(NA, -8), nociga = c(NA, -8))
# synth
patient.synth <- syn(p.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list)$syn
##
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
##  sex age socprof income marital depress sport nofriend smoke nociga
##  alcabuse bmi cluster
# --------------- NEW PATIENT ID
patient.synth$patient_id <- 1:nrow(patient.synth)
head(patient.synth)
##      sex age                     socprof income marital depress sport
## 1 FEMALE  31   EMPLOYED IN PUBLIC SECTOR   1240 MARRIED       4   YES
## 2 FEMALE  49 OTHER ECONOMICALLY INACTIVE    600 MARRIED       7   YES
## 3 FEMALE  20 OTHER ECONOMICALLY INACTIVE     -8  SINGLE       0   YES
## 4 FEMALE  26                        <NA>    400  SINGLE       4    NO
## 5   MALE  27 OTHER ECONOMICALLY INACTIVE    900  SINGLE       0   YES
## 6 FEMALE  53  EMPLOYED IN PRIVATE SECTOR     -8  SINGLE       5   YES
##   nofriend smoke nociga alcabuse      bmi cluster patient_id
## 1        3    NO     -8       NO 20.06920       7          1
## 2       20   YES     10       NO 25.14861       8          2
## 3       10    NO     -8       NO 18.02596       3          3
## 4        5    NO     -8       NO 22.49135       2          4
## 5        7   YES     20       NO 23.56663       4          5
## 6        2   YES     10       NO 26.53376       3          6
The next step is to randomly select a true patient id from within the same cluster. By doing so each synthesised patient gets a random patient’s appointment history. Since the history is randomly selected within the same cluster we can be confident the structure of the data is retained e.g. a 50 year old male that doesn’t smoke gets a history which is realistic and supported by the data.
# --------------- SAMPLE PATIENT ID WITHIN CLUSTER IN ORDER TO LINK TO APPOINTMENT DATASET
sample.patient <- function(x){
z <- which(patient.df$cluster == x)
n <- length(which(patient.synth$cluster == x))
out <- sample(patient.df$patient_id[z], n, replace = TRUE)
return(out)
}
patient.list <- lapply(as.list(1:nc), sample.patient)
patient.synth$link.patient <- NA
for(k in 1:nc){
z <- which(patient.synth$cluster == k)
patient.synth$link.patient[z] <- patient.list[[k]]
}
# -------------- LINK TO APPOINTMENT
patient.synth %<>%
left_join(appointment.df %>% dplyr::select(-cluster), by = c("link.patient" = "patient_id")) %>%
arrange(patient_id, appointment_dates)
head(patient.synth %>% dplyr::select(patient_id, sex, age, socprof, appointment_dates, cost, cluster)) # only displaying a few variables for convenience
##   patient_id    sex age                     socprof appointment_dates
## 1          1 FEMALE  31   EMPLOYED IN PUBLIC SECTOR        2015-06-18
## 2          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-04-26
## 3          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-05-30
## 4          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-06-02
## 5          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-07-02
## 6          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-07-03
##         cost cluster
## 1 1063.82122       7
## 2  380.20829       8
## 3  492.43475       8
## 4 1085.44437       8
## 5 1038.01058       8
## 6   68.89496       8
This data is suitable to be simulated using HMM’s (we know this for sure since we created the data set!). A new function is written to be consumed by syn() and the data synthesised during the same process.
This is where things can get tricky and it is important to have a good handle on your data. To synthesise the sequence data, a model needs to be trained on the original data first. In practice syn() takes care of the training in most situations. Either the function is defined to train the model before synthesis or consume pre-trained models. This will depend on the problem at hand.
I am going to skip the training step here. Training HMM’s (or any state space models) can be tricky and require careful tuning. Instead I am simply going to use the predefined HMM’s above that were used to create the original data. The primary aim is to demonstrate how to synthesise the data rather than train HMM’s.
syn.hmm() function
The function to pass to syn() is given below. The list of HMM’s is passed directly to the function which returns an array of dates.
syn.hmm <- function (y, x, xp, id = patient.synth$link.patient, hmm_list = hmm.list, group = patient.synth$group, ...){
start.dates <- seq(as.Date("2014-01-01"), as.Date("2016-12-31"), 1)
uniq.id <- unique(id)
n <- length(uniq.id)
new.y <- vector(mode = "numeric", length = nrow(xp))
for(k in 1:n){
ii <- which(id == uniq.id[k])
# group of the k-th linked patient (all rows for that patient share one group)
gr <- group[ii[1]]
obs <- mysimHMM(hmm = hmm_list[[gr]], length(ii))$observation
l <- length(obs)
dts <- sample(start.dates, 1) + 0:(l-1)
new.y[ii] <- dts[which(obs == "Y")]
}
return(list(res = as.Date(new.y, origin = "1970-01-01"), fit = "hmm"))
}
To synthesise only the appointment dates and cost but ignore the remaining variables, we define the predictor matrix and methods vector.
Firstly, syn() will switch the method to “sample” if the variable does not have a predictor. To trick the function into doing what we want we can set the appointment dates variable to have at least 1 predictor, which it won’t use because we haven’t told it to in syn.hmm().
Secondly, cost can be synthesised as normal and for this example is synthesised using age, sex and income for demonstration.
Thirdly, to tell the function to ignore all other variables we set the visit sequence to be appointment dates followed by cost and nothing else. It will ignore the other already synthesised variables and attach them to the data frame.
# set data frame
df <- patient.synth %>% dplyr::select(-cluster, -group)
synth.obj <- syn(df, method = rep("sample", ncol(df)), m = 0)
# create new predictor matrix
new.pm <- synth.obj$predictor.matrix
new.pm["cost", c("sex", "age", "income")] <- 1
new.pm["appointment_dates", c("link.patient")] <- 1
# set visit sequences
vseq <- c(which(colnames(new.pm) == "appointment_dates"), which(colnames(new.pm) == "cost"))
# set methods array
mth <- rep("", ncol(new.pm))
mth[vseq] <- c("hmm", "cart")
# synthesise data
synth.obj <- syn(df, method = mth, visit.sequence = vseq, predictor.matrix = new.pm)
##
## Variable(s): socprof, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi, patient_id not synthesised or used in prediction.
## CAUTION: The synthesised data will contain the variable(s) unchanged.
##
## Synthesis
## -----------
##  appointment_dates cost
head(synth.obj$syn %>% dplyr::select(sex, age, income, appointment_dates, cost)) # only displaying a few variables for convenience
##      sex age income appointment_dates       cost
## 1 FEMALE  31   1240        2016-04-30   74.43372
## 2 FEMALE  49    600        2016-06-25   68.89496
## 3 FEMALE  49    600        2016-07-27  475.31540
## 4 FEMALE  49    600        2016-07-28  126.08026
## 5 FEMALE  49    600        2016-11-03 1009.78769
## 6 FEMALE  49    600        2016-11-04 1038.01058
synth.obj$method
##               sex               age           socprof            income
##                ""                ""                ""                ""
##           marital           depress             sport          nofriend
##                ""                ""                ""                ""
##             smoke            nociga          alcabuse               bmi
##                ""                ""                ""                ""
##        patient_id      link.patient appointment_dates              cost
##                ""                ""             "hmm"            "cart"
synth.obj$visit.sequence
## appointment_dates              cost
##                15                16
# compare cost
# look at the clustering of appointments
g.date.sequences <- ggplot(synth.obj$syn %>% dplyr::filter(patient_id < 7), aes(x = appointment_dates, y = patient_id)) +
geom_hline(yintercept = 1:6, col = "grey80") +
geom_point(col = "turquoise", size = 4) +
coord_cartesian(xlim = range(df$appointment_dates))
grid.arrange(g.date.sequences, compare(synth.obj, df, vars = "cost", cols = mycols)$plot)
Again, a solid effort. Appointment dates display similar patterns and the cost variable matches the original distributions.
The plots below show the comparison between the synthetic and original data for the mean and total cost, appointment count and mean appointment count per patient. All given by sex and marital status.
The outputs are very similar with the exception of the mean number of appointments for those in de facto relationships. This is likely a property of the small sample size in this group.
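The code behind these comparison plots is not shown in the post. As a rough sketch of how such summaries might be computed, the example below builds two small made-up data frames (orig_toy and synth_toy, both hypothetical stand-ins for the original and synthesised appointment tables) and computes the mean cost, total cost and appointment count by sex with base R.

```r
set.seed(1)

# make_toy is a hypothetical helper producing fake appointment-level data
make_toy <- function(n) {
  data.frame(
    sex  = sample(c("MALE", "FEMALE"), n, replace = TRUE),
    cost = pmax(rnorm(n, mean = 500, sd = 300), 0)
  )
}
orig_toy  <- make_toy(300)
synth_toy <- make_toy(300)

# mean cost, total cost and appointment count by sex, labelled by data source
summarise_cost <- function(d, label) {
  out <- aggregate(cost ~ sex, data = d,
                   FUN = function(x) c(mean = mean(x), total = sum(x), n = length(x)))
  out <- do.call(data.frame, out)  # flatten the matrix column into cost.mean, cost.total, cost.n
  out$source <- label
  out
}

comparison <- rbind(summarise_cost(orig_toy, "original"),
                    summarise_cost(synth_toy, "synthetic"))
comparison
```

Stacking the two summaries with a source column makes it straightforward to plot original and synthetic values side by side.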
When synthesising data from any relational database these challenges will be present. There will be a business need to preserve the same relational form to ensure the synthetic data is used in the same way as the original.
There are many ways of dealing with the problem of different grains and synthesising sequences – this is only one. With respect to sequences and time series data, careful thought is needed to use the right model for synthesis and define the appropriate function.
This example is relatively straightforward in order to highlight a couple of key challenges. The synthpop package is flexible enough to handle these more complex real world scenarios.
If you have solved this in other ways, let me know in the comments!
The post Synthesising Multiple Linked Data Sets and Sequences in R appeared first on Daniel Oehm | Gradient Descending.
The post Generating Synthetic Data Sets with ‘synthpop’ in R appeared first on Daniel Oehm | Gradient Descending.
Synthpop – A great music genre and an aptly named R package for synthesising population data. I recently came across this package while looking for an easy way to synthesise unit record data sets for public release. The goal is to generate a data set which contains no real units, and is therefore safe for public release, while retaining the structure of the data, so that any inference returns the same conclusion as the original. This will be a quick look into synthesising data, some challenges that can arise from common data structures and some things to watch out for.
This example will use the same data set as in the synthpop documentation and will cover similar ground, but perhaps an abridged version with a few other things that weren’t mentioned. The SD2011 data set contains 5000 observations and 35 variables on social characteristics of Poland. A subset of 12 of these variables is considered.
suppressPackageStartupMessages(library(synthpop))
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(sampling))
suppressPackageStartupMessages(library(partykit))
mycols <- c("darkmagenta", "turquoise")
options(xtable.floating = FALSE)
options(xtable.timestamp = "")
myseed <- 20190110
# filtering the dataset
original.df <- SD2011 %>% dplyr::select(sex, age, socprof, income, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi)
head(original.df)
##      sex age          socprof income marital depress sport nofriend smoke
## 1 FEMALE  57          RETIRED    800 MARRIED       6    NO        6    NO
## 2   MALE  20 PUPIL OR STUDENT    350  SINGLE       0    NO        4    NO
## 3 FEMALE  18 PUPIL OR STUDENT     NA  SINGLE       0    NO       20    NO
## 4 FEMALE  78          RETIRED    900 WIDOWED      16   YES        0    NO
## 5 FEMALE  54    SELF-EMPLOYED   1500 MARRIED       4   YES        6   YES
## 6   MALE  20 PUPIL OR STUDENT     -8  SINGLE       5    NO       10    NO
##   nociga alcabuse      bmi
## 1     -8       NO 30.79585
## 2     -8       NO 23.44934
## 3     -8       NO 18.36547
## 4     -8       NO 30.46875
## 5     20       NO 20.02884
## 6     -8       NO 23.87511
The objective of synthesising data is to generate a data set which resembles the original as closely as possible, warts and all, meaning also preserving the missing value structure. There are two ways to deal with missing values: 1) impute/treat missing values before synthesis, or 2) synthesise the missing values and deal with them later. The second option is generally better since the purpose the data is supporting may influence how the missing values are treated.
Missing values can be simply NA or some numeric code specified by the collection. A useful feature is that the syn function allows for different NA types; for example, income, nofriend and nociga use -8 as a missing value. A list is passed to the function in the following form.
# setting continuous variable NA list
cont.na.list <- list(income = c(NA, -8), nofriend = c(NA, -8), nociga = c(NA, -8))
By not including this the -8’s will be treated as a numeric value and may distort the synthesis.
After synthesis, there is often a need to post process the data to ensure it is logically consistent. For example, anyone who is married must be over 18 and anyone who doesn’t smoke shouldn’t have a value recorded for ‘number of cigarettes consumed’. These rules can be applied during synthesis rather than needing ad hoc post processing.
# apply rules to ensure consistency
rules.list <- list(
marital = "age < 18",
nociga = "smoke == 'NO'")
rules.value.list <- list(
marital = "SINGLE",
nociga = -8)
The variables in the condition need to be synthesised before applying the rule otherwise the function will throw an error. In this case age should be synthesised before marital and smoke should be synthesised before nociga.
There is one person with a bmi of 450.
SD2011[which.max(SD2011$bmi),]
##         sex age agegr            placesize    region       edu
## 3953 FEMALE  58 45-59 URBAN 20,000-100,000 Lubelskie SECONDARY
##                         eduspec                 socprof unempdur income
## 3953 economy and administration LONG-TERM SICK/DISABLED       -8   1300
##      marital mmarr ymarr msepdiv ysepdiv               ls depress
## 3953 MARRIED     4  1982      NA      NA MOSTLY SATISFIED       1
##                         trust trustfam trustneigh sport nofriend smoke
## 3953 ONE CAN`T BE TOO CAREFUL      YES        YES   YES        2    NO
##      nociga alcabuse alcsol workab wkabdur wkabint wkabintdur emcc englang
## 3953     -8       NO     NO     NO      -8      NO       <NA> <NA>    NONE
##      height weight      bmi
## 3953    149     NA 449.9797
Their weight is missing from the data set and would need to be close to 1,000 kg (at the recorded height of 149 cm) for this to be accurate. I don’t believe this is correct! So, any bmi over 75 (which is still very high) will be considered a missing value and corrected before synthesis.
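As a quick sanity check on why this record is implausible, the weight implied by the recorded bmi and height can be backed out. This assumes the usual definition of BMI as weight in kilograms divided by height in metres squared; the variable names are just for illustration.

```r
bmi    <- 449.9797  # recorded bmi for row 3953
height <- 149       # recorded height in cm

# bmi = weight / height_m^2, so weight = bmi * height_m^2
implied_weight <- bmi * (height / 100)^2
implied_weight
```

The implied weight comes out at roughly a tonne, which supports treating the value as an error rather than a genuine observation.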
# treat implausible bmi values (over 75) as missing before synthesis
original.df$bmi <- ifelse(original.df$bmi > 75, NA, original.df$bmi)
Consider a data set with variables. In a nutshell, synthesis follows these steps:
The data can now be synthesised using the following code.
# synthesise data
synth.obj <- syn(original.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, seed = myseed)
##
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
##  sex age socprof income marital depress sport nofriend smoke nociga
##  alcabuse bmi
synth.obj
## Call:
## ($call) syn(data = original.df, rules = rules.list, rvalues = rules.value.list,
##     cont.na = cont.na.list, seed = myseed)
##
## Number of synthesised data sets:
## ($m)  1
##
## First rows of synthesised data set:
## ($syn)
##      sex age                     socprof income marital depress sport
## 1 FEMALE  45   EMPLOYED IN PUBLIC SECTOR   1500  SINGLE       5   YES
## 2   MALE  65 OTHER ECONOMICALLY INACTIVE    500  SINGLE       5   YES
## 3 FEMALE  17            PUPIL OR STUDENT     NA  SINGLE       3    NO
## 4 FEMALE  48  EMPLOYED IN PRIVATE SECTOR   1000 MARRIED       0    NO
## 5   MALE  50  EMPLOYED IN PRIVATE SECTOR   1300 MARRIED       0    NO
## 6 FEMALE  65                     RETIRED   1200 WIDOWED       6    NO
##   nofriend smoke nociga alcabuse      bmi
## 1        3    NO     -8       NO 26.12245
## 2       30    NO     -8       NO 29.32099
## 3        7    NO     -8       NO 22.40588
## 4       10    NO     -8       NO 25.81663
## 5       12   YES     20      YES 27.17063
## 6       15    NO     -8       NO 27.51338
## ...
##
## Synthesising methods:
## ($method)
##      sex      age  socprof   income  marital  depress    sport nofriend
## "sample"   "cart"   "cart"   "cart"   "cart"   "cart"   "cart"   "cart"
##    smoke   nociga alcabuse      bmi
##   "cart"   "cart"   "cart"   "cart"
##
## Order of synthesis:
## ($visit.sequence)
##      sex      age  socprof   income  marital  depress    sport nofriend
##        1        2        3        4        5        6        7        8
##    smoke   nociga alcabuse      bmi
##        9       10       11       12
##
## Matrix of predictors:
## ($predictor.matrix)
##          sex age socprof income marital depress sport nofriend smoke
## sex        0   0       0      0       0       0     0        0     0
## age        1   0       0      0       0       0     0        0     0
## socprof    1   1       0      0       0       0     0        0     0
## income     1   1       1      0       0       0     0        0     0
## marital    1   1       1      1       0       0     0        0     0
## depress    1   1       1      1       1       0     0        0     0
## sport      1   1       1      1       1       1     0        0     0
## nofriend   1   1       1      1       1       1     1        0     0
## smoke      1   1       1      1       1       1     1        1     0
## nociga     1   1       1      1       1       1     1        1     1
## alcabuse   1   1       1      1       1       1     1        1     1
## bmi        1   1       1      1       1       1     1        1     1
##          nociga alcabuse bmi
## sex           0        0   0
## age           0        0   0
## socprof       0        0   0
## income        0        0   0
## marital       0        0   0
## depress       0        0   0
## sport         0        0   0
## nofriend      0        0   0
## smoke         0        0   0
## nociga        0        0   0
## alcabuse      1        0   0
## bmi           1        1   0
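The methods and predictor matrix printed above show the sequential nature of the synthesis: sex is drawn by sampling and every later variable is modelled on the variables synthesised before it. The toy sketch below illustrates that idea on made-up data with base R only. It is a simplification rather than what syn() does internally: resampling age within each sex group stands in for the CART step, and the names orig, synth, syn_sex and syn_age are illustrative only.

```r
set.seed(1)

# made-up "original" data in which age depends on sex
orig <- data.frame(sex = factor(sample(c("MALE", "FEMALE"), 200, replace = TRUE)))
orig$age <- round(ifelse(orig$sex == "MALE",
                         rnorm(200, mean = 40, sd = 10),
                         rnorm(200, mean = 50, sd = 10)))
n <- nrow(orig)

# Step 1: the first variable has no predictors, so it is resampled with replacement
syn_sex <- sample(orig$sex, n, replace = TRUE)

# Step 2: the next variable is drawn conditional on the synthesised predictor,
# here by resampling observed ages from records with the same sex
syn_age <- numeric(n)
for (s in levels(orig$sex)) {
  idx <- which(syn_sex == s)
  syn_age[idx] <- sample(orig$age[orig$sex == s], length(idx), replace = TRUE)
}

synth <- data.frame(sex = syn_sex, age = syn_age)

# the relationship between sex and age should carry over to the synthetic data
aggregate(age ~ sex, data = orig,  FUN = mean)
aggregate(age ~ sex, data = synth, FUN = mean)
```

With more variables the same pattern repeats, each new variable being generated from a model fitted on the original data and predicted from the already synthesised columns.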
The compare function allows for easy checking of the synthesised data.
# compare the synthetic and original data frames
compare(synth.obj, original.df, nrow = 3, ncol = 4, cols = mycols)$plot
Solid. The distributions are very well preserved. Did the rules work on the smoking variable?
# checking rules worked
table(synth.obj$syn[,c("smoke", "nociga")])
##      nociga
## smoke   -8    1    2    3    4    5    6    7    8   10   12   14   15
##   YES   13   13   13   32    5   49   17   12   28  308   26    3  135
##   NO  3777    0    0    0    0    0    0    0    0    0    0    0    0
##      nociga
## smoke   16   18   20   21   22   24   25   26   29   30   35   40   50
##   YES    5    7  418    2    1    2   31    2    2   51    1   33    1
##   NO     0    0    0    0    0    0    0    0    0    0    0    0    0
##      nociga
## smoke   60
##   YES    1
##   NO     0
They did. All non-smokers have missing values for the number of cigarettes consumed.
compare can also be used for model output checking. A logistic regression model will be fit to find the important predictors of depression. The depression variable ranges from 0-21. This will be converted to a binary variable: 0 if the score is 7 or less and 1 if it is greater than 7.
This split leaves 3822 (0)’s and 1089 (1)’s for modelling.
# ------------ MODEL COMPARISON
glm1 <- glm.synds(ifelse(depress > 7, 1, 0) ~ sex + age + log(income) + sport + nofriend + smoke + alcabuse + bmi, data = synth.obj, family = "binomial")
## Warning in log(income): NaNs produced
summary(glm1)
## Warning: Note that all these results depend on the synthesis model being correct.
##
## Fit to synthetic data set with a single synthesis.
## Inference to coefficients and standard errors that
## would be obtained from the observed data.
##
## Call:
## glm.synds(formula = ifelse(depress > 7, 1, 0) ~ sex + age + log(income) +
##     sport + nofriend + smoke + alcabuse + bmi, family = "binomial",
##     data = synth.obj)
##
## Combined estimates:
##              xpct(Beta) xpct(se.Beta) xpct(z) Pr(>|xpct(z)|)
## (Intercept) -1.00819811    0.71605764 -1.4080      0.1591356
## sexFEMALE    0.35681507    0.10010909  3.5643      0.0003649 ***
## age          0.09004527    0.00384758 23.4031      < 2.2e-16 ***
## log(income) -0.68041355    0.08829602 -7.7060      1.298e-14 ***
## sportNO     -0.66446597    0.11880451 -5.5929      2.233e-08 ***
## nofriend     0.00028325    0.00679104  0.0417      0.9667305
## smokeNO      0.08544511    0.11894243  0.7184      0.4725269
## alcabuseNO  -0.72124437    0.21444108 -3.3634      0.0007700 ***
## bmi          0.00644972    0.01036016  0.6226      0.5335801
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# compare to the original data set
compare(glm1, original.df, lcol = mycols)
## Warning in log(income): NaNs produced
##
## Call used to fit models to the data:
## glm.synds(formula = ifelse(depress > 7, 1, 0) ~ sex + age + log(income) +
##     sport + nofriend + smoke + alcabuse + bmi, family = "binomial",
##     data = synth.obj)
##
## Differences between results based on synthetic and observed data:
##             Std. coef diff p value CI overlap
## (Intercept)    -1.26517500   0.206  0.6772453
## sexFEMALE       0.27373709   0.784  0.9301678
## age             0.85530291   0.392  0.7818065
## log(income)     0.98568572   0.324  0.7485449
## sportNO         0.16485706   0.869  0.9579439
## nofriend        1.31926470   0.187  0.6634467
## smokeNO        -0.07386025   0.941  0.9811578
## alcabuseNO      0.02501199   0.980  0.9936193
## bmi            -0.17971185   0.857  0.9541543
##
## Measures for one synthesis and 9 coefficients
## Mean confidence interval overlap:  0.8542318
## Mean absolute std. coef diff:  0.5714007
## Lack-of-fit: 5.49732; p-value 0.789 for test that synthesis model is compatible
## with a chi-squared test with 9 degrees of freedom
##
## Confidence interval plot:
While the model needs more work, the same conclusions would be made from both the original and synthetic data set, as can be seen from the confidence intervals. Occasionally there may be contradicting conclusions made about a variable, accepting it in the observed data but not in the synthetic data for example. This scenario could be corrected by using different synthesis methods (see documentation) or altering the visit sequence.
Released population data are often counts of people in geographical areas by demographic variables (age, sex, etc). Some cells in the table can be very small e.g. <5. For privacy reasons these cells are suppressed to protect people’s identity. With synthetic data, suppression is not required given it contains no real people, assuming there is enough uncertainty in how the records are synthesised.
The existence of small cell counts opens a few questions,
To test this, 200 areas will be simulated to replicate possible real world scenarios. Area size will be randomly allocated ensuring a good mix of large and small population sizes. Population sizes are randomly drawn from a Poisson distribution with mean $\lambda$. If large, $\lambda$ is drawn from a uniform distribution on the interval [20, 40]. If small, $\lambda$ is set to 1.
# ---------- AREA
# add area flag to the data frame
area.label <- paste0("area", 1:200)
a <- sample(0:1, 200, replace = TRUE, prob = c(0.5, 0.5))
lambda <- runif(200, 20, 40)*(1-a) + a
prob.dist <- rpois(200, lambda)
area <- sample(area.label, 5000, replace = TRUE, prob = prob.dist)
# attached to original data frame
original.df <- SD2011 %>% dplyr::select(sex, age, socprof, income, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi)
original.df$bmi <- ifelse(original.df$bmi > 75, NA, original.df$bmi)
original.df <- cbind(original.df, area) %>% arrange(area)
The sequence of synthesising variables and the choice of predictors is important when there are rare events or low sample areas. If area is synthesised very early in the procedure and used as a predictor for following variables, it’s likely the subsequent models will over-fit. Synthetic data sets require a level of uncertainty to reduce the risk of statistical disclosure, so this is not ideal.
Fortunately syn allows for modification of the predictor matrix. To avoid over-fitting, ‘area’ is the last variable to be synthesised and will only use sex and age as predictors. This is reasonable to capture the key population characteristics. Additionally, syn throws an error unless maxfaclevels is changed to the number of areas (the default is 60). The limit exists to guard against poorly synthesised data, and a warning message suggests checking the results, which is good practice.
# synthesise data
# m is set to 0 as a hack to set the synds object and the predictor matrix
synth.obj.b <- syn(original.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, maxfaclevels = 200, seed = myseed, m = 0)
##
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
# changing the predictor matrix to predict area with only age and sex
new.pred.mat <- synth.obj.b$predictor.matrix
new.pred.mat["area",] <- 0
new.pred.mat["area",c("age", "sex")] <- 1
new.pred.mat
##          sex age socprof income marital depress sport nofriend smoke
## sex        0   0       0      0       0       0     0        0     0
## age        1   0       0      0       0       0     0        0     0
## socprof    1   1       0      0       0       0     0        0     0
## income     1   1       1      0       0       0     0        0     0
## marital    1   1       1      1       0       0     0        0     0
## depress    1   1       1      1       1       0     0        0     0
## sport      1   1       1      1       1       1     0        0     0
## nofriend   1   1       1      1       1       1     1        0     0
## smoke      1   1       1      1       1       1     1        1     0
## nociga     1   1       1      1       1       1     1        1     1
## alcabuse   1   1       1      1       1       1     1        1     1
## bmi        1   1       1      1       1       1     1        1     1
## area       1   1       0      0       0       0     0        0     0
##          nociga alcabuse bmi area
## sex           0        0   0    0
## age           0        0   0    0
## socprof       0        0   0    0
## income        0        0   0    0
## marital       0        0   0    0
## depress       0        0   0    0
## sport         0        0   0    0
## nofriend      0        0   0    0
## smoke         0        0   0    0
## nociga        0        0   0    0
## alcabuse      1        0   0    0
## bmi           1        1   0    0
## area          0        0   0    0
# synthesising with new predictor
synth.obj.b <- syn(original.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, maxfaclevels = 200, seed = myseed, proper = TRUE, predictor.matrix = new.pred.mat)
##
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
##  sex age socprof income marital depress sport nofriend smoke nociga
##  alcabuse bmi area
# compare the synthetic and original data frames
compare(synth.obj.b, original.df, vars = "area", nrow = 1, ncol = 1, cols = c("darkmagenta", "turquoise"), stat = "counts")$plot
The area variable is simulated fairly well on simply age and sex. It captures the large and small areas, however the large areas are relatively more variable. This could use some fine tuning, but will stick with this for now.
tab.syn <- synth.obj.b$syn %>%
dplyr::select(area, sex) %>%
table()
tab.orig <- original.df %>%
dplyr::select(area, sex) %>%
table()
##
## synthetic
## sex
## area MALE FEMALE
## area1 2 0
## area10 15 19
## area101 0 6
## area103 35 22
## area105 3 0
## area106 30 31
## area107 0 3
## area108 28 11
## area110 17 37
## area112 23 24
## area113 0 2
## area114 21 52
## area115 30 28
##
## original
## sex
## area MALE FEMALE
## area1 1 0
## area10 19 27
## area101 1 3
## area103 29 26
## area105 1 0
## area106 22 33
## area107 0 4
## area108 28 18
## area110 18 33
## area112 23 25
## area113 1 2
## area114 28 39
## area115 23 44
d <- data.frame(difference = as.numeric(tab.syn - tab.orig), sex = c(rep("Male", 150), rep("Female", 150)))
ggplot(d, aes(x = difference, fill = sex)) + geom_histogram() + facet_grid(sex ~ .) + scale_fill_manual(values = mycols)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
The method does a good job at preserving the structure for the areas. How much variability is acceptable is up to the user and intended purpose. Using more predictors may provide a better fit. The errors are distributed around zero, a good sign no bias has leaked into the data from the synthesis.
tab.syn <- synth.obj.b$syn %>%
dplyr::select(marital, sex) %>%
table()
tab.orig <- original.df %>%
dplyr::select(marital, sex) %>%
table()
##
## synthetic
## sex
## marital MALE FEMALE
## SINGLE 667 565
## MARRIED 1352 1644
## WIDOWED 76 398
## DIVORCED 81 169
## LEGALLY SEPARATED 2 0
## DE FACTO SEPARATED 6 36
##
## original
## sex
## marital MALE FEMALE
## SINGLE 657 596
## MARRIED 1382 1597
## WIDOWED 66 465
## DIVORCED 62 137
## LEGALLY SEPARATED 6 1
## DE FACTO SEPARATED 4 18
At higher levels of aggregation the structure of the tables is better maintained.
‘synthpop’ is built in a similar fashion to the ‘mice’ package, in that user defined methods can be specified and passed to the syn function using the form syn.newmethod. To demonstrate this we’ll build our own neural net method.
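As a minimum the function takes as input
y – observed variable to be synthesised
x – observed predictors
xp – synthesised predictors
# nnet() comes from the nnet package, which needs to be attached
suppressPackageStartupMessages(library(nnet))
syn.nn <- function (y, x, xp, smoothing, size = 6, ...) {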
for (i in which(sapply(x, class) != sapply(xp, class))) xp[,
i] <- eval(parse(text = paste0("as.", class(x[, i]),
"(xp[,i])", sep = "")))
# model and prediction
nn <- nnet(y ~ ., data = as.data.frame(cbind(y, x)), size = size, trace = FALSE)
probs <- predict(nn, newdata = xp)
probs[is.na(probs)] <- 0
for(k in 1:nrow(probs)){
if(sum(probs[k,]) == 0){
probs[k,] <- 1
}
}
new <- apply(probs, 1, function(x) colnames(probs)[sample(1:ncol(probs), 1, prob = x)]) %>% unname()
# if smoothing
if (!is.factor(y) & smoothing == "density")
new <- syn.smooth(new, y)
# return
return(list(res = new, fit = nn))
}
Set the method vector to apply the new neural net method for the factors, ctree for the others and pass to syn.
# methods vector
meth.vec <- c("sample", ifelse(sapply(original.df[,-1], is.factor), "nn", "ctree"))
meth.vec[13] <- "ctree"
# synthesise
synth.obj.c <- syn(original.df, method = meth.vec, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, maxfaclevels = 200, seed = myseed, predictor.matrix = new.pred.mat)
##
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
## sex age socprof income marital depress sport nofriend smoke nociga
## alcabuse bmi area
# compare the synthetic and original data frames
compare(synth.obj.c, original.df, vars = colnames(original.df)[-13], nrow = 3, ncol = 4, cols = c("darkmagenta", "turquoise"), stat = "counts")$plot
The results are very similar to above with the exception of ‘alcabuse’, but this demonstrates how new methods can be applied.
The ‘synthpop’ package is great for synthesising data for statistical disclosure control or creating training data for model development. Other things to note,
Following posts tackle complications that arise when there are multiple tables at different grains that are to be synthesised. Further complications arise when their relationships in the database also need to be preserved. Ideally the data is synthesised and stored alongside the original enabling any report or analysis to be conducted on either the original or synthesised data. This will require some trickery to get synthpop to do the right thing, but is possible.
The post Generating Synthetic Data Sets with ‘synthpop’ in R appeared first on Daniel Oehm | Gradient Descending.
The post Q-learning example with Liar’s Dice in R appeared first on Daniel Oehm | Gradient Descending.
In my last post I coded Liar’s Dice in R and some brainless bots to play against. I build on that post by using Q-learning to train an agent to play Liar’s Dice well.
Spoiler alert: The brainless bots aren’t actually that brainless! More on that later.
Note – I’ll share enough code to run the simulations however the full code can be found on Github. Check out my previous post for the rules to Liar’s Dice.
Firstly, some background. Q-learning is a reinforcement learning algorithm which trains an agent to make the right decisions given the environment it is in and the tasks it needs to complete. The task may be navigating a maze, playing a game, driving a car, flying a drone or learning which offers to make to increase customer retention. In the case of Liar’s Dice, the task is deciding when to call or raise the bid. The agent learns through
This cycle is repeated until eventually it learns the best decisions to make given its situation.
The environment is formulated as a Markov Decision Process defined by a set of states $s \in S$. By taking an action $a \in A$ the agent will transition to a new state $s'$ with probability $P_a(s, s')$. After transitioning to the new state the agent receives a reward $R_a(s, s')$ which will either tell it this was a good move, or this was a bad move. It may also tell the agent this was neither a good nor a bad move until it reaches a win/lose state.
Finding the optimal policy of an MDP can be done through value iteration and policy iteration. The optimal policy and state value functions are given by
and
where $\alpha$ and $\gamma$ are the learning rate and discount parameters. The above equations rely on knowing (or often, making crude approximations to) the transition probabilities. The benefit of Q-learning is that the transition probabilities are not required; instead they are derived through simulation. The Q function is given by
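In standard notation, and assuming the usual textbook forms, the optimal policy, the state value function and the Q update can be written as below. The policy and value forms are an assumption about what is intended here, as are the symbols $P_a$, $R_a$, $V$ and $\pi$; the Q update is the one the update.Q() function further down implements, with learning rate $\alpha$ and discount $\gamma$.

```latex
\pi(s) = \arg\max_{a} \sum_{s'} P_a(s, s') \left[ R_a(s, s') + \gamma V(s') \right]

V(s) = \sum_{s'} P_{\pi(s)}(s, s') \left[ R_{\pi(s)}(s, s') + \gamma V(s') \right]

Q(s, a) \leftarrow (1 - \alpha)\, Q(s, a) + \alpha \left[ R_a(s, s') + \gamma \max_{a'} Q(s', a') \right]
```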
The cells in the Q matrix represent the ‘quality’ of taking action $a$ given state $s$. After each action the Q matrix is updated. After many iterations, the agent will have explored many states and determined which state and action pairs led to the best outcomes. Now it has the information needed to make the optimal choice by taking the action which leads to the maximum overall reward, indicated by the largest Q value.
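A toy example may make the update concrete. The sketch below uses a made-up Q matrix with 3 states and the two actions used in this post, applies a single update for one observed transition and then picks the action with the largest Q value. The state numbers and the reward of 1 are illustrative only.

```r
# toy Q matrix: 3 hypothetical states, 2 actions, initialised to zero
Q <- matrix(0, nrow = 3, ncol = 2, dimnames = list(NULL, c("raise", "call")))
alpha <- 0.1      # learning rate
discount <- 0.9   # discount parameter

# one observed transition: in state 1 the agent raised,
# moved to state 2 and received a reward of 1
s <- 1; a <- "raise"; s.next <- 2; r <- 1

# Q update for the (state, action) pair that was just played
Q[s, a] <- (1 - alpha) * Q[s, a] + alpha * (r + discount * max(Q[s.next, ]))

# the trained agent picks the action with the largest Q value in its state
best.action <- names(which.max(Q[s, ]))
Q
best.action
```

After this single step only the cell for raising in state 1 has moved away from zero, so raising is now the preferred action in that state.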
The key is to formulate the states the agent can be in at any point in the game and the reward for transitioning from one state to another. MDPs can become very large very quickly if every possible state is accounted for, so it’s important to identify the key information and the redundancies.
The key pieces of information needed to make a decision in Liar’s Dice are
Suppose the player has 6 dice. This gives $6^6 = 46656$ possible hands (counting each die separately), and that hasn’t yet factored in the bid or the total number of dice on the table. You can see how the number of states blows out.
To make a good decision on whether to raise or call, the player only needs to know how many dice of the current bid value they have in their hand and the chance the remainder are among the opponents’ unseen dice. Essentially, the dice value isn’t required in the formulation of the game state.
The states are given by 3 values.
The last point is the combination of the information given by the player’s dice and the bid. The probability there is at least the bid quantity on the table is calculated and reduced to a bucket
$$b = \lfloor 20\,p \rfloor$$
where
$$p = \sum_{k=q}^{n} \binom{n}{k} \left(\frac{1}{6}\right)^{k} \left(\frac{5}{6}\right)^{n-k}$$
and $q$ is the unknown quantity needed and $n$ is the number of unobserved dice on the table. This reduces the state space down to something more manageable. For this example we’ll use a maximum of 20 buckets i.e. (5%, 10%, …, 100%). Overkill for small numbers of dice, but it doesn’t hurt.
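A minimal sketch of the bucket calculation, assuming each unseen die shows the bid value with probability 1/6. The function name prob.bucket and its arguments are made up for this example; it uses pbinom() for the binomial tail, so its edge cases may differ slightly from the calc.prob() function in the full code listing at the end of the post.

```r
# probability that at least q of the n unseen dice show the bid value,
# reduced to a bucket between 0 and bins
prob.bucket <- function(q, n, bins = 20) {
  # nothing more is needed from the unseen dice: certain
  if (q <= 0) return(bins)
  # P(X >= q) where X ~ Binomial(n, 1/6)
  p <- pbinom(q - 1, size = n, prob = 1/6, lower.tail = FALSE)
  floor(p * bins)
}

# bid needs 2 more matching dice among 18 unseen dice
prob.bucket(q = 2, n = 18)
```

With 18 unseen dice the chance of at least 2 matches is roughly 83%, which falls in bucket 16 of 20.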
The function below generates the complete game states given the number of dice and players.
# generate games states
generate.game.states <- function(nplayers, ndice, bin.size = 20){

  # create the basic data frame with total dice and player dice count
  total.dice <- nplayers*ndice
  states <- t(combn(rep(0:ndice, nplayers), nplayers)) %>%
    as.data.frame() %>%
    distinct() %>%
    set_names(paste0("p", 1:nplayers))
  states$total <- rowSums(states)
  states <- states %>%
    select(total, p1) %>%
    arrange(total, p1) %>%
    filter(total > 0) %>%
    distinct

  # add in the probability bucket state
  game.states <- data.frame()
  for(k in 1:nrow(states)){
    total <- states$total[k]
    p1 <- states$p1[k]
    for(j in 0:bin.size){
      game.states <- rbind(game.states, c(total, p1, j))
    }
  }
  colnames(game.states) <- c("total", "p1", "prob_cat")
  return(game.states)
}
gs <- generate.game.states(4, 6)
dim(gs)
## [1] 2772    3
The state space reduces to only 2772 states for a game with 4 players with 6 dice each where previously it would have been several orders of magnitude larger. There are still redundant states in this formulation (mostly because I’m lazy) but it’s been reduced enough to be viable and won’t significantly slow down training.
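The count of 2772 can be checked without enumerating every combination. The sketch below assumes a state is a (total dice, player 1 dice, probability bucket) triple, as in generate.game.states(), and simply counts the feasible pairs; the names pairs and n.states are made up for this example.

```r
nplayers <- 4
ndice <- 6
bin.size <- 20

# player 1 holds 0..6 dice, the other three players hold 0..18 between them
pairs <- expand.grid(p1 = 0:ndice, others = 0:((nplayers - 1) * ndice))
pairs$total <- pairs$p1 + pairs$others

# drop the empty table, as the filter(total > 0) step does
pairs <- pairs[pairs$total > 0, ]

# every (total, p1) pair is crossed with the buckets 0..bin.size
n.states <- nrow(pairs) * (bin.size + 1)
n.states
```

There are 7 x 19 - 1 = 132 feasible (total, p1) pairs and 21 bucket values (0 to 20), which gives the 2772 rows reported above.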
To simplify the problem, the agent only needs to decide on whether to call or raise. A more complicated problem would be to allow the agent to choose what the new bid should be (this is for a later post, for now we’ll keep it simple).
The agent will explore the states randomly and will eventually learn when it’s a good time to call and a good time to raise. For example if the bid is three 5’s and the agent has three 5’s in hand, the obvious action is to raise. The agent won’t know this at first but will soon work it out.
If the agent raises, it will first randomly select whether to bluff or play the numbers. By bluffing the agent randomly selects a dice value and increases the bid by 1. If the agent plays the numbers it selects the value it has the most of and raises the quantity by 1.
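The raise rule can be sketched in a few lines. The helper below is hypothetical (choose.raise is not part of the post’s code) and only mirrors the description above: a bluff picks a random dice value, otherwise the most common value in hand is used, and in both cases the quantity goes up by 1.

```r
# hypothetical helper mirroring the raise rule described above
choose.raise <- function(hand, bid.quantity, bluff) {
  # how many of each face value (1 to 6) are in the hand
  counts <- tabulate(hand, nbins = 6)
  # bluff: any value at random; otherwise the value held most often
  value <- if (bluff) sample(1:6, 1) else which.max(counts)
  list(dice.value = value, dice.quantity = bid.quantity + 1)
}

# playing the numbers with three 5s in hand and a current bid quantity of 3
choose.raise(c(2, 5, 5, 5, 6, 6), bid.quantity = 3, bluff = FALSE)
```

Here the agent holds three 5s, so it bids four 5s. Ties are broken in favour of the lowest face value because which.max() returns the first maximum.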
When the agent has been trained it makes the optimal decision by selecting the maximum Q value given the current state $s$.
The agent needs to know how good taking action $a$ was. The reward matrix is defined by rewarding
The reward values are arbitrary but work well in this case. We want to emphasize that losing a die is bad but losing the game is worse. Other than the terminal states, i.e. when the number of dice the player has is 0 (lose) or the same as the total number of dice on the table (win), no state is particularly good or bad; it is the transition from one state to another that triggers the reward or penalty. Therefore, each reward matrix will be an $n \times n$ square matrix where $n$ is the total number of states.
There is a reward matrix for each action, stored in a list. For Liar’s Dice this isn’t necessary since the rewards and penalties are the same whether the player raises or calls and transitions to another state. However, the framework is there for actions to have different rewards.
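To make the transition rewards concrete, here is a scalar restatement of the rules as I read them from generate.reward.matrix() below. The function transition.reward is hypothetical and takes the player’s dice count and the table total before and after a transition; the values -1, 1, -10 and 10 are the ones used in that function.

```r
# hypothetical scalar version of the reward rules in generate.reward.matrix()
transition.reward <- function(p1, total, p1.next, total.next) {
  # rewards only fire when exactly one die leaves the table
  if (total - total.next != 1) return(0)

  if (p1.next == p1 - 1) {
    if (p1.next == 0) return(-10)          # the player lost their last die
    if (p1.next != total.next) return(-1)  # the player lost a die but is still in
  }
  if (p1.next == p1 && p1 > 0) {
    if (total.next == p1) return(10)       # everyone else is out: win
    return(1)                              # another player lost a die
  }
  0
}

transition.reward(p1 = 3, total = 10, p1.next = 2, total.next = 9)  # lost a die
transition.reward(p1 = 3, total = 4, p1.next = 3, total.next = 3)   # won the game
```

Writing it this way shows that the reward depends only on the dice counts and not on the probability bucket.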
# generate reward matrix
generate.reward.matrix <- function(game.states){
  reward <- matrix(0, nrow = nrow(game.states), ncol = nrow(game.states))
  for(i in 1:nrow(reward)){

    # which state are we in
    total <- game.states$total[i]
    p1 <- game.states$p1[i]

    # small penalty for losing a die
    reward[i, p1 - game.states$p1 == 1 & total - game.states$total == 1 & game.states$p1 != game.states$total] <- -1

    # small reward for others losing a die
    reward[i, p1 == game.states$p1 & total - game.states$total == 1 & game.states$p1 != game.states$total & p1 > 0] <- 1

    # fail states - when players dice count is 0
    if(p1 == 1){
      reward[i, which(total - game.states$total == 1 & game.states$p1 == 0)] <- -10
    }

    # win states when the player dice count equals the total dice count
    if(total - p1 == 1){
      reward[i, game.states$total == p1 & game.states$p1 == p1] <- 10
    }
  }
  return(reward)
}
rw <- generate.reward.matrix(gs)
reward <- list(raise = rw, call = rw)
The process follows the steps below.
Assume player 1 raises on their turn. In a 4 person game, player 1 may actually transition to multiple other states before control returns. For each other raise or call by the other players, the game state will change for player 1. For the context of the model the action player 1 took is considered to be the last action for all subsequent transitions.
Here is an example of the state transition table for 4 players each with 3 dice.
play.liars.dice(auto = TRUE, players = 4, num.dice = 3, verbose = 0, agents = agents, Q.mat = Q.mat, print.trans = TRUE)$winner
##    y.ctrl y.state y.action
## 1       3     808    raise
## 2       3     787    raise
## 3       3     778    raise
## 4       4     778    raise
## 5       4     736    raise
## 6       4     736    raise
## 7       1     736     call
## 8       1     673     call
## 9       1     678    raise
## 10      2     674    raise
## 11      3     673    raise
## 12      4     673    raise
## 13      4     589    raise
## 14      4     592    raise
## 15      1     589     call
## 16      1     505     call
## 17      1     507     call
## 18      1     423     call
## 19      2     422    raise
## 20      3     421    raise
## 21      1     421    raise
## 22      2     421    raise
## 23      2     316    raise
## 24      2     324    raise
## 25      2     240    raise
## 26      3     232    raise
## 27      3     148    raise
## 28      2     151    raise
## 29      2      88    raise
## [1] 1
This table is then passed to the update.Q() function.
# update Q matrix
update.Q <- function(play, Q.mat, reward, alpha = 0.1, discount = 0.9){
  for(k in 2:nrow(play)){
    curr.state <- play$y.state[k]
    prev.state <- play$y.state[k-1]
    action <- play$y.action[k]

    # Q update
    Q.mat[prev.state, action] <- (1-alpha)*Q.mat[prev.state, action] + alpha*(reward[[action]][prev.state, curr.state] + discount*max(Q.mat[curr.state,]))
  }
  return(Q.mat)
}
The learning rate and discount values have been initialised to 0.1 and 0.9 respectively.
Liar’s Dice is now simulated 5000 times and Q value iteration is conducted with the above functions (see github for the full code). The first agent will be the only one that uses the Q matrix to decide its actions and therefore the only agent that is trained. It will bluff with a probability of 50% to add some more realism to the agent’s decisions. The other 3 will be random agents, bluffing 100% of the time and randomly deciding to call or raise at each decision point. It is expected that after training agent 1 will outperform the other 3 random agents.
# set the agents
agent1 <- build.agent(c(0.5,0.5), method = "Q.decide")
agent2 <- build.agent(c(1,0), method = "random")
agent3 <- build.agent(c(1,0), method = "random")
agent4 <- build.agent(c(1,0), method = "random")
agents <- list(agent1, agent2, agent3, agent4)

# training
its <- 5000

# setting the vector to store the winners
winners <- vector("numeric", length = its)

# loopy loop
pb <- txtProgressBar(min = 0, max = its, style = 3)
for(k in 1:its){
  out <- play.liars.dice(auto = TRUE, players = 4, num.dice = 6, verbose = 0, agents = agents, Q.mat = Q.mat, train = TRUE)
  winners[k] <- out$winner
  Q.mat <- out$Q.mat
  setTxtProgressBar(pb, k)
}

# table of winners
table(winners)
## winners
##    1    2    3    4
## 3288  434  496  782

# agent win percentage
x <- 1000:5000
y <- (cumsum(winners == 1)/(1:5000))[x]
qplot(x = x, y = y, geom = "line", xlab = "iterations", ylab = "Proportion of agent 1 wins")
After only 5000 iterations (which isn’t a lot given there are approximately 2000 valid states) the results show that agent 1 performs very well against the random agents. If each agent were equivalent the win percentage would be on average 25%, whereas here agent 1 won about 66% of the games (3288 of 5000).
The graph shows the percentage of wins for agent 1 continuing to increase as it is trained. Further training will improve the Q matrix and hence the performance of the agent. Given the stochastic nature of the game we wouldn’t expect a win percentage of 100%, so this is a great result.
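The win proportions can be recomputed directly from the printed table of winners. The vector below re-enters those counts by hand; the names wins and win.prop are made up for this example.

```r
# winner counts copied from the table(winners) output above
wins <- c(`1` = 3288, `2` = 434, `3` = 496, `4` = 782)

# proportion of games won by each agent
win.prop <- wins / sum(wins)
round(win.prop, 3)

# how far agent 1 sits above the 25% expected from four equal agents
win.prop[["1"]] - 0.25
```

Agent 1 wins 3288 of the 5000 training games, a little over 40 percentage points above the 25% baseline.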
Here are another 100 games with the trained agent.
library(pbapply)
sim <- pbsapply(1:100, function(x) play.liars.dice(auto = TRUE, players = 4, num.dice = 6, verbose = 0, agents = agents, Q.mat = Q.mat)[["winner"]])
table(sim)
## sim
##  1  2  3  4
## 65  6 15 14
sum(sim == 1)/100
## [1] 0.65
Solid effort.
What’s really happening here? The last variable in our state space formulation is the probability bucket, which is in essence an approximation of the actual probability that the bid quantity exists on the table. At first the agent doesn’t know what to do with that information and will decide to call or raise randomly. Over time it learns how best to use that information and either calls or raises. In my previous post we simply used the probability directly by randomly choosing to raise with probability $p$ and call with probability $1 - p$. So in truth the original bots weren’t too bad.
The Q-learning algorithm has an advantage in being able to solve more complex scenarios. The original agents only had the probability on which to base a decision, whereas under an MDP framework the agent is free to also make decisions based on how many dice it has in hand and how many are on the table. It has the ability to vary the risk depending on how close it is to winning or losing.
There are ways we can expand the state space to allow for potentially more complex decisions, such as factoring in the remaining dice of the person to the left or right and allowing the agent to learn each player’s bluffing likelihoods. The state space could also be reduced to whether a player has 0 dice or 1 or more, since whether the player has 2 or 6 dice may not matter too much. It’s worth an experiment to test this and see if it performs just as well.
In short a few things to take away are,
# set dice value
set.dice.value <- function(note, max.val, prev.val = 0){
  good.val <- FALSE
  while(!good.val){
    val <- readline(note) %>% as.numeric()
    if(val > 0 & val <= max.val & !is.na(val) & (val > prev.val)){
      good.val <- TRUE
    }else{
      cat("please select a value between 1 and", max.val, "\n")
    }
  }
  return(val)
}

# roll table
roll.table.fn <- function(rolls){
  rt <- table(unlist(rolls))
  roll.table <- rep(0, 6)
  names(roll.table) <- 1:6
  roll.table[names(rt)] <- rt
  return(roll.table)
}

# call probability that there is at least the bid quantity on the table and converts to a bucket
calc.prob <- function(x, bin.size = 20) {
  if(x[3] <= x[4]){
    return(1*bin.size)
  }else{
    n <- x[1]-x[2]
    k <- seq(min(x[3]-x[4], n), n, 1)
    return(floor(sum(choose(n, k)*(1/6)^k*(5/6)^(n-k))*bin.size))
  }
}

# agent function chooses the best action e.g. raise or call
# it needs to take in as input dice, total dice, dice value and dice quantity
# as output action (raise or call), if raised also new dice value and quantity
# dice, total.dice, dice.value, dice.quantity.
# this is wrapped by a building function to make it easier to change certain
# parameters and decisions an agent might make and be able to play them off against
# each other to see which is the better strategy
build.agent <- function(bluff.prob, method = "random"){
  return(
    function(pars, Q.mat){

      # bluff or truth
      bluff <- sample(c(TRUE, FALSE), 1, prob = bluff.prob)

      # probability table
      roll.table <- roll.table.fn(pars$dice)
      ptable <- roll.table/sum(roll.table)

      # if the initial bid do this
      if(is.null(pars$dice.value)){
        new.dice.value <- which.max(ptable[1:6]) %>% names() %>% as.numeric()
        new.dice.quantity <- max(roll.table) + 1
        return(list(dice.value = new.dice.value, dice.quantity = new.dice.quantity))
      }

      # are you gonna call?
      # use the Q matrix to make a decision
      if(method == "Q.decide"){
        if(abs(max(Q.mat[pars$p1.state,]) - min(Q.mat[pars$p1.state,])) < 1e-6)
          call <- sample(c(TRUE, FALSE), 1)
        else{

          # exploration vs exploitation
          if(runif(1) < 0.1){
            call <- sample(c(TRUE, FALSE), 1)
          }else{
            call <- names(which.max(Q.mat[pars$p1.state,])) == "call"
          }
        }

      # the random agent
      }else if(method == "random"){
        prob <- 0.5
        call <- sample(c(TRUE, FALSE), 1, prob = c(1-prob, prob))

      # playing the actual numbers
      }else if(method == "true.prob"){
        prob <- 1-sum(dbinom(0:max(c(pars$dice.quantity - roll.table[pars$dice.value], 0)), pars$total.dice-length(pars$dice), prob = 1/6))
        call <- sample(c(TRUE, FALSE), 1, prob = c(1-prob, prob))
      }

      # if called return the values
      if(call){
        return(list(action = "call", dice.value = pars$dice.value, dice.quantity = pars$dice.quantity))
      }else{

        # raise
        # if choosing to bluff randomly select a number and increase by 1
        if(bluff){
          new.dice.value <- sample(1:6, 1)
          new.dice.quantity <- pars$dice.quantity + 1
        }else{

          # if not bluffing choose the maximum number in hand and increase by one
          # this should be made to be more flexible however in general raising by
          # 1 occurs 99% of the time
          new.dice.value <- which.max(ptable) %>% names() %>% as.numeric()
          new.dice.quantity <- pars$dice.quantity + 1
        }

        # return the new values and action
        return(list(action = "raise", dice.value = new.dice.value, dice.quantity = new.dice.quantity))
      }
    }
  )
}

#------ play a round of liars dice
liars.dice.round <- function(players, control, player.dice.count, agents, game.states, reward, Q.mat, a = 1, verbose = 1){

  # set array for recording results
  y.ctrl = c(); y.state = c(); y.action = c()

  # roll the dice for each player
  if(verbose > 0) cat("\n\n")
  rolls <- lapply(1:players, function(x) sort(sample(1:6, player.dice.count[[x]], replace = TRUE)))
  if(verbose > 1) lapply(rolls, function(x) cat("dice: ", x, "\n"))
  total.dice <- sum(unlist(player.dice.count))

  # set penalty
  penalty <- sapply(1:players, function(x) 0, simplify = FALSE)

  # print dice blocks
  if(verbose > 0) Dice(rolls[[1]])

  # set up roll table
  roll.table <- roll.table.fn(rolls)

  # initial bid
  if(verbose > 0) cat("place first bid\nPlayer", control, "has control\n")
  if(control == a){
    dice.value <- set.dice.value("dice value: ", 6)
    dice.quantity <- set.dice.value("quantity; ", sum(roll.table))
  }else{

    # agent plays
    p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == total.dice)
    pars <- list(dice = rolls[[control]], total.dice = total.dice, dice.value = NULL, dice.quantity = 0, p1.state = p1.state)
    agent.action <- agents[[control]](pars = pars, Q.mat = Q.mat)
    dice.value <- agent.action$dice.value
    dice.quantity <- agent.action$dice.quantity
  }

  # calculate probability cat and determine the game state
  # action set to raise because you can't call without an initial bid
  # this could be a 3rd action (initial bid) but it's not really necessary
  player.dice.qty <- table(rolls[[1]])[as.character(dice.value)]
  player.dice.qty <- ifelse(is.na(player.dice.qty), 0, player.dice.qty) %>% unname
  prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
  p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == prob.cat)
  p1.action <- "raise"

  # storing states for Q iteration
  y.ctrl = c(); y.state = c(); y.action = c()

  # moving control to the next player
  # storing the previous player since if the next player calls the previous player could lose a die
  prev <- control
  control <- control %% players + 1
  if(verbose > 0) cat("dice value ", dice.value, "; dice quantity ", dice.quantity, "\n")

  # loop through each player and continue until there is a winner and loser
  called <- FALSE
  while(!called){

    # check if the player with control is still in the game - if not skip
    if(player.dice.count[[control]] > 0){
      if(control == a){
        action <- readline("raise or call (r/c)? ")
      }else{

        # the agent makes a decision
        pars <- list(dice = rolls[[control]], total.dice = total.dice, dice.value = dice.value, dice.quantity = dice.quantity, p1.state = p1.state)
        agent.action <- agents[[control]](pars = pars, Q.mat = Q.mat)
        action <- agent.action$action
      }

      # storing states for reward iteration
      if(control == 1 & !is.null(agent.action$action)){
        player.dice.qty <- table(rolls[[1]])[as.character(dice.value)]
        player.dice.qty <- ifelse(is.na(player.dice.qty), 0, player.dice.qty) %>% unname
        p1.action <- agent.action$action
        prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
        p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == prob.cat)
      }

      # called
      if(action %in% c("call", "c")){
        if(verbose > 0) {
          cat("player", control, "called\nRoll table\n")
          print(roll.table)
        }

        # dice are revealed
        # check if the quantity of dice value is less or more than the total in the pool
        # if more control loses otherwise control-1 win
        if(dice.quantity > roll.table[dice.value]){
          penalty[[prev]] <- penalty[[prev]] - 1
          if(verbose > 0) cat("player", prev, "lost a die\n")
        }else{
          penalty[[control]] <- penalty[[control]] - 1
          if(verbose > 0) cat("player", control, "lost a die\n")
        }

        # for Q iteration
        y.ctrl <- c(y.ctrl, control); y.state <- c(y.state, p1.state); y.action <- c(y.action, p1.action)

        # if called use the penalty array to change states
        prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
        p1.state <- which(game.states$total == total.dice-1 & game.states$p1 == player.dice.count[[1]]+penalty[[1]] & game.states$prob_cat == prob.cat)

        # break the loop
        called <- TRUE

      }else{

        if(verbose > 0) cat("player", control, "raised\n")
        if(control == a){

          # player sets next dice value
          dice.value <- set.dice.value("dice value: ", 6)
          dice.quantity <- set.dice.value("quantity; ", sum(roll.table))
        }else{
          dice.value <- agent.action$dice.value
          dice.quantity <- agent.action$dice.quantity
        }

        # p1 state after the raise
        prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
        p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == prob.cat)
        if(verbose > 0) cat("dice value", dice.value, "; dice quantity", dice.quantity, "\n")
      }

      # store info for Q update
      y.ctrl <- c(y.ctrl, control); y.state <- c(y.state, p1.state); y.action <- c(y.action, p1.action)

      # set the control player to now be the previous player
      prev <- control
    }

    # next player has control
    control <- control %% players + 1
  }

  # play results and return
  play <- data.frame(y.ctrl, y.state, y.action)
  return(list(penalty = penalty, play = play))
}

# play a full game of liars dice
play.liars.dice <- function(players = 4, num.dice = 6, auto = FALSE, verbose = 1, agents, Q.mat = NULL, train = FALSE, print.trans = FALSE){

  # begin!
  if(verbose > 0) liars.dice.title()

  # setting the number of dice each player has
  ndice <- sapply(rep(num.dice, players), function(x) x, simplify = FALSE)
  players.left <- sum(unlist(ndice) > 0)

  # setting game states matrix
  game.states <- generate.game.states(players, num.dice)

  # set up reward matrix
  reward <- generate.reward.matrix(game.states)
  reward <- list(raise = reward, call = reward)

  # set Q matrix if null
  if(is.null(Q.mat)) Q.mat <- matrix(0, nrow = nrow(reward$raise), ncol = length(reward), dimnames = list(c(), names(reward)))

  # while there is at least 2 left in the game
  # who has control
  ctrl <- sample(1:players, 1)
  play.df <- data.frame()
  while(players.left > 1){

    # play a round
    results <- liars.dice.round(
      players = players,
      control = ctrl,
      player.dice.count = ndice,
      game.states = game.states,
      reward = reward,
      Q.mat = Q.mat,
      agents = agents,
      a = as.numeric(!auto),
      verbose = verbose
    )

    # update how many dice the players are left with given the
    # outcomes of the round
    for(k in seq_along(ndice)){
      ndice[[k]] <- ndice[[k]] + results$penalty[[k]]
      if(ndice[[k]] == 0 & results$penalty[[k]] == -1){
        if(verbose > 0) cat("player", k, "is out of the game\n")
      }

      # update who has control so they can start the bidding
      if(results$penalty[[k]] == -1){
        ctrl <- k
        while(ndice[[ctrl]] == 0){
          ctrl <- ctrl %% players + 1
        }
      }
    }

    # checking how many are left and if anyone won the game
    players.left <- sum(unlist(ndice) > 0)
    if(players.left == 1){
      if(verbose > 0) cat("player", which(unlist(ndice) > 0), "won the game\n")
    }

    # appending play
    play.df <- rbind(play.df, results$play)
  }

  if(print.trans) print(play.df)

  # update Q
  # rather than training after each action, training at the
  # end of each game in bulk
  # just easier this way
  if(train) Q.mat <- update.Q(play.df, Q.mat, reward)

  # return the winner and Q matrix
  return(list(winner = which(unlist(ndice) > 0), Q.mat = Q.mat))
}