The post Fun with progress bars: Fish, daggers and the Star Wars trench run appeared first on Daniel Oehm | Gradient Descending.
If you’re like me, when running a process through a loop you’ll add in counters and progress indicators. That way you’ll know if it will take 5 minutes or much longer. It’s also good for debugging, so you know where the code wigged out.
This is typically what’s done: take a time stamp at the start with start <- Sys.time(), print some indicators at each iteration with cat("iteration", k, "// reading file", file, "\n"), and print how long it took at the end with print(Sys.time() - start). The problem is that it prints a new line each time it is called, which is fine but ugly. You can reduce the number of lines printed by only printing every 10th or 100th iteration, e.g. if(k %% 10 == 0) ….
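A minimal sketch of that basic pattern, assuming a hypothetical vector of 50 file names (nothing is actually read), printing a line only on every 10th iteration:

```r
# Hypothetical file names; the loop body is where the real work would go
files <- paste0("file_", 1:50, ".csv")

start <- Sys.time()                        # time stamp at the start
for (k in seq_along(files)) {
  # ... read and process files[k] here ...
  if (k %% 10 == 0) {                      # only report every 10th iteration
    cat("iteration", k, "// reading file", files[k], "\n")
  }
}
print(Sys.time() - start)                  # total elapsed time
```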
A simple way to make this better is to use "\r" (carriage return) instead of "\n" (new line). This overwrites the same line, which is much neater. It’s much more satisfying watching a number go up, or down, whichever way is the good direction. Try it out…
y <- matrix(0, nrow = 31, ncol = 5)
for(sim in 1:5){
  y[1, sim] <- rnorm(1, 0, 8)
  for(j in 1:30){
    y[j+1, sim] <- y[j, sim] + rnorm(1) # random walk
    cat("simulation", sim, "// time step", sprintf("%2.0f", j),
        "// random walk", sprintf(y[j+1, sim], fmt = '% 6.2f'), "\r")
    Sys.sleep(0.1)
  }
}
## simulation 5 // time step 30 // random walk 8.97
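One thing to watch with "\r" is that it only moves the cursor back to the start of the line; it doesn’t erase what is already there, so a shorter message leaves the tail of the previous one visible. A small sketch of one way around this, padding every message to a fixed width with formatC(). The helper name cr_print() and the width of 60 are arbitrary choices for illustration:

```r
# Hypothetical helper: overwrite the current console line with a message
# padded to a fixed width, so a shorter message fully covers a longer one
cr_print <- function(..., width = 60) {
  line <- formatC(paste(...), width = width, flag = "-")  # left-justify, pad with spaces
  cat("\r", line, sep = "")
  invisible(line)
}

for (k in c(1000, 10, 1)) {
  cr_print("items remaining:", k)
  Sys.sleep(0.2)
}
cat("\n")
```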
The best way is to use the {progress} package. It lets you easily add elapsed time, ETA, progress bars, percentage complete and custom counters to your code. First decide which counters you want and the format of the string. Tokens are identified by a colon at the start of the label, e.g. :eta. Check the doco for the built-in tokens.
To add your own token, add the label to the format string and pass its value to tick() through the tokens argument. To make it pretty I recommend formatting digits with sprintf(). Here’s an example.
library(progress)
pb <- progress_bar$new(
  format = ":elapsedfull // eta :eta // simulation :sim // time step :ts // random walk :y [:bar]",
  total = 30*5, clear = FALSE)
y <- matrix(0, nrow = 31, ncol = 5)
for(sim in 1:5){
  y[1, sim] <- rnorm(1, 0, 8)
  for(j in 1:30){
    y[j+1, sim] <- y[j, sim] + rnorm(1) # random walk
    pb$tick(tokens = list(sim = sim, ts = sprintf("%2.0f", j),
                          y = sprintf(y[j+1, sim], fmt = '% 6.2f')))
    Sys.sleep(0.1)
  }
}
00:00:17 // eta 0s // simulation 5 // time step 30 // random walk -12.91 [====================================================]
You can also jazz it up with a bit of colour with {crayon}. Be careful with this: it doesn’t handle varying string lengths very well and can start printing new lines, exploding your console.
library(crayon)
pb <- progress_bar$new(
  format = green$bold(":elapsedfull // eta :eta // simulation :sim // time step :ts // random walk :y [:bar]"),
  total = 30*5, clear = FALSE)
# ... (same loop as above)
00:00:17 // eta 0s // simulation 5 // time step 30 // random walk -12.91 [====================================================]
That’s a much neater progress bar.
Procrastination set in and creative tangents were followed. So, I made a progress bar into a big fish that eats smaller fish … and made it green.
library(progress)
library(crayon)
library(progressart)   # devtools::install_github("doehm/progressart")

n <- 300
bar_fmt <- green$bold(":elapsedfull | :icon |")
pb <- progress_bar$new(format = bar_fmt, total = n, clear = FALSE)
icon <- progress_bar_icon("fish", n, 75)
for(j in 1:n){
  pb$tick(tokens = list(
    icon = token(icon, j)
  ))
  Sys.sleep(0.03)
}
Each fish represents 25% completion. Once they’re all gobbled up, the job is done.
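For the curious, here is one way a frame of a fish bar could be built. It is a hypothetical sketch, not the progressart implementation: the big fish’s position tracks the fraction complete, and each small fish (one per 25%) disappears once the big fish reaches it. The function name fish_frame() and the 40-character width are assumptions for illustration.

```r
# Hypothetical sketch, not the progressart implementation:
# build the text frame for tick j of n total ticks
fish_frame <- function(j, n, width = 40, n_small = 4) {
  pos <- max(1, round(j / n * width))                     # big fish position
  small_at <- round(seq_len(n_small) / n_small * width)   # 25%, 50%, 75%, 100%
  chars <- rep(" ", width)
  chars[small_at[small_at > pos]] <- "o"                  # small fish not yet eaten
  chars[pos] <- ">"                                       # the big fish
  paste(chars, collapse = "")
}

n <- 300
for (j in 1:n) {
  cat("\r|", fish_frame(j, n), "|", sep = "")
  Sys.sleep(0.005)
}
cat("\n")
```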
I also threw knives at boxes. Each box represents 20% completion.
n <- 300
bar_fmt <- green$bold(":elapsedfull | :icon |")
pb <- progress_bar$new(format = bar_fmt, total = n, clear = FALSE)
icon <- progress_bar_icon("dagger", n, 75)
for(j in 1:n){
  pb$tick(tokens = list(
    icon = token(icon, j)
  ))
  Sys.sleep(0.03)
}
And my personal favourite, the Star Wars trench run.
n <- 500
bar_fmt <- green$bold(":elapsedfull | :icon |")
pb <- progress_bar$new(format = bar_fmt, total = n, clear = FALSE)
icon <- progress_bar_icon("tiefighter", n, 75)
for(j in 1:n){
  pb$tick(tokens = list(
    icon = token(icon, j)
  ))
  Sys.sleep(0.03)
}
Ok… I have spent way too long on this! But at least it was fun. If you want to play around with it, feel free to download it from GitHub.
devtools::install_github("doehm/progressart")
The post Bayesian estimation of fatality rates and accidents involving cyclists on Queensland roads appeared first on Daniel Oehm | Gradient Descending.
In my previous post I built a Shiny app mapping accidents on Queensland roads, which was great at showing the problematic areas within cities and regional areas. I have added to this by estimating the fatality rate given the observed accidents, and the rate of accidents involving cyclists, for SA3 and SA4 areas. I have also updated the filters, making them tidier. What follows is commentary on what can be found in the Shiny app. If you want to jump straight to it, run
shiny::runGitHub("doehm/road-accidents/", "doehm", launch.browser = TRUE)
I used a Bayesian approach to estimate the fatality rate for 2017 (the data isn’t complete for 2018) and presented it as a proportion of the number of observed road accidents. The data dates back to 2001, but it’s reasonable to use the most recent data to allow for improvements in road conditions, policies, advertising, population growth, etc., which may have an underlying impact on the risk of road accidents.
To construct a prior I used years 2014-2016 at SA3 level. By taking 10,000 draws from the posterior we estimate the fatality rate distribution for each area. The bands represent the 80% and 95% credible intervals.
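A sketch of how the bands for a single area can be computed, mirroring the posterior_f() function in the code at the end of this post: it draws from a Gamma posterior with shape a + y and rate b + n (the conjugate update for a Poisson rate per accident), using the default prior values a = 1.3 and b = 77 from that function. The counts y and n below are hypothetical, not values from the data.

```r
set.seed(2017)
a <- 1.3   # prior shape (default in posterior_f)
b <- 77    # prior rate (default in posterior_f)
y <- 12    # hypothetical number of fatalities in 2017
n <- 400   # hypothetical number of accidents in 2017

# 10,000 posterior draws of the fatality rate, expressed as a percentage
draws <- 100 * rgamma(1e4, shape = a + y, rate = b + n)

# 95% band: 2.5% to 97.5%; 80% band: 10% to 90%; centre: median
bands <- quantile(draws, c(0.025, 0.1, 0.5, 0.9, 0.975))
bands
```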
The top 3 most dangerous SA4s for 2017 are Mackay – Isaac – Whitsunday, Wide Bay and the Queensland outback. It’s not surprising that the large majority of fatalities occurred at high speeds on highways and back roads. Having said that, a few have occurred closer to the town centers in these areas. It’s possible the road conditions are not particularly good, but that’s hard to determine. The road condition filter only has 4 states, sealed (dry / wet) and unsealed (dry / wet), and doesn’t offer much more information. It appears most accidents occur on sealed roads in dry conditions.
The 3 least dangerous areas are Brisbane North, Brisbane Inner City and Brisbane South. Again, it shouldn’t be surprising: the speed limit is lower in the city, so if accidents occur they are generally less severe. There are a lot of rear-end accidents and accidents when leaving a car park, the classic.
At the SA3 level the incident rates focus on smaller geographical areas, which can highlight other features. For example, an SA4 area often includes both the city centers and the Hinterland regions, which tend to have different rates, often higher in the Hinterlands. This is largely due to the differing speed limits. The most dangerous areas are Burnett, Port Douglas and the Ipswich Hinterlands. The least dangerous are Chermside, Brisbane Inner West and Springfield – Redbank. There are a lot of SA3s.
As a cyclist it’s good to know which roads need particular attention. In a similar way we’ll look at the rate of road accidents involving cyclists by area. An influencing factor is how many cyclists are on the roads relative to cars and other vehicles: if an area has a strong cycling culture, it’s more likely that an accident will involve a cyclist.
It’s not surprising that the dense city centers with more traffic have more accidents involving cyclists. At the SA4 level it’s pretty clear that Brisbane Inner City is a dangerous place to ride a bike, particularly on Adelaide Street, which you can see from the app (go to show filter > unit type involved > bicycle). The rate of accidents involving cyclists in Brisbane Inner City is significantly higher than in all other SA4s. I’m curious to know if the growth of the CityCycle scheme (which is awesome by the way) is a contributing factor. However, given it was introduced in 2011, and 2005-2010 saw the highest rate of growth in accidents involving cyclists, it probably isn’t the biggest factor, though it should contribute if more people are on bikes (they’re also not the easiest bikes in the world to ride). The other areas in the top 3 are Brisbane City West and Cairns. Sheridan Street, the main drag in Cairns, is where you need to be extra vigilant.
Here Noosa tops the chart for 2017. It sits within the Sunshine Coast SA4 region, which ranks 5th in the SA4 chart above; that region also includes the Hinterland regions, which is why it scores lower there. Noosa has a strong cycling culture, which could be influencing its high rate. With a median of a whopping 18%, definitely keep your eyes peeled when cycling on those roads. Also, take care in Sherwood – Indooroopilly. Keep in mind that these areas are not necessarily where the most accidents involving cyclists occur. Rather, if there is an accident it’s more likely to have involved a cyclist relative to other areas. If you’re driving in those areas keep an eye out for cyclists on the road; perhaps there aren’t marked cycling lanes or off-road bike paths, or there are some dodgy intersections.
Ideally the rate would be adjusted for the number of cyclists on the road; however, this is still a useful statistic for finding the more dangerous areas. This analysis has also been done for pedestrians involved in car accidents, which can be found in the app. Just a heads up, be careful in Fortitude Valley.
Run the shiny app
library(shiny)
runGitHub("doehm/road-accidents/", "doehm", launch.browser = TRUE)
Load the data from GitHub.
library(repmis)
source_data("https://github.com/doehm/road-accidents/raw/master/data/road-accident-data.Rdata")
Estimate hyperparameters
library(dplyr)

# get prior information from previous 3 years
# use method of moments to get alpha and beta
# ideally we would use a hierarchical model but this is suitable
hyper_params <- function(y){
  mu <- mean(y)
  s2 <- var(y)
  gamma <- mu*(1-mu)/s2-1   # implied alpha + beta
  alpha <- gamma*mu
  beta <- gamma*(1-mu)
  return(list(alpha = alpha, beta = beta))
}

# estimating the actual parameters
df <- accidents_raw %>%
  filter(
    Loc_Queensland_Transport_Region != "Unknown",
    Crash_Year >= 2014,
    Crash_Year <= 2016
  ) %>%
  group_by(Loc_ABS_Statistical_Area_3) %>%
  summarise(
    count = n(),
    n_fatalities = sum(Count_Casualty_Fatality)
  ) %>%
  mutate(theta = n_fatalities/(count))
hyper_params(df$theta)
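A quick self-contained check of the method-of-moments step. For a Beta(α, β) distribution the mean is μ = α/(α+β) and the variance is μ(1-μ)/(α+β+1), so α + β = μ(1-μ)/s² - 1, which is the gamma computed in hyper_params(). Simulating rates from a known Beta distribution (Beta(2, 50) is an arbitrary choice for illustration) should give back roughly the same parameters:

```r
# Same method-of-moments estimator as hyper_params() above
hyper_params <- function(y){
  mu <- mean(y)
  s2 <- var(y)
  gamma <- mu*(1-mu)/s2 - 1      # implied alpha + beta
  list(alpha = gamma*mu, beta = gamma*(1-mu))
}

set.seed(1)
theta <- rbeta(1e5, shape1 = 2, shape2 = 50)   # simulated rates from a known Beta
hp <- hyper_params(theta)
hp   # should be close to alpha = 2, beta = 50
```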
Forest plots
library(tidyverse)
library(showtext)
library(lazyeval)   # interp() used below

# font
font_add_google(name = "Montserrat", family = "mont")
showtext_auto()

# SA3 forest plot of fatality rate
# set my theme and colours
theme_forest <- function(scale = 1){
  theme_minimal() +
    theme(
      legend.position = "none",
      axis.text = element_text(family = "mont", size = 16*scale),
      axis.text.x = element_text(vjust = 0.5),
      axis.title.y = element_blank(),
      axis.title.x = element_text(family = "mont", size = 16*scale),
      plot.title = element_text(family = "mont", hjust = 0.5, size = 26*scale, face = "bold"),
      plot.subtitle = element_text(family = "mont", hjust = 0.5, size = 20*scale),
      plot.caption = element_text(size = 12*scale)
    )
}
my_cols <- function(n = 16) colorRampPalette(c("darkmagenta", "turquoise"))(n)

# simulate from the posterior
posterior_f <- function(df, y, n, a = 1.3, b = 77, inflator = 100) {
  out <- data.frame()
  qs <- c(0.025, 0.1, 0.5, 0.9, 0.975)
  for(k in 1:nrow(df)){
    out <- rbind(out, inflator*rgamma(1e4, shape = a+y[k], rate = b+n[k]) %>% quantile(qs))
  }
  colnames(out) <- paste0("q", 100*qs)
  return(out)
}

# SA4
# fatalities
areas <- grep("Area", colnames(accidents_raw), value = TRUE)[3:4]
names(areas) <- c("sa3", "sa4")
# note: group_by_() and mutate_() are deprecated in recent dplyr versions
fatality_fn <- function(area){
  accidents_raw %>%
    group_by_(area) %>%
    filter(Crash_Year == 2017) %>%
    summarise(
      count = length(Count_Casualty_Total),
      n_fatalities = sum(Count_Casualty_Fatality)
    ) %>%
    bind_cols(posterior_f(df = ., y = .$n_fatalities, n = .$count)) %>%
    arrange(q50) %>%
    mutate_(area = interp(~factor(v, levels = v), v = as.name(area))) %>%
    ggplot() +
    geom_segment(mapping = aes(x = q2.5, xend = q97.5, y = area, yend = area)) +
    geom_segment(mapping = aes(x = q10, xend = q90, y = area, yend = area, col = q50), size = 2) +
    geom_point(mapping = aes(x = q50, y = area), pch = 3) +
    theme_forest() +
    scale_colour_gradientn(colors = my_cols()) +
    labs(
      title = "Fatality rate given observed road accidents",
      subtitle = paste("Bayesian estimate of the fatality rate for", toupper(names(area)), "areas in 2017"),
      x = "Fatality rate (%)")
}
fatality_plots <- list(sa3 = fatality_fn(areas[1]), sa4 = fatality_fn(areas[2]))
Cyclist plot
# cyclists
df <- accidents_raw %>%
  filter(
    Crash_Year >= 2014,
    Crash_Year <= 2016,
    Loc_ABS_Statistical_Area_3 != "Unknown"
  ) %>%
  group_by(Loc_ABS_Statistical_Area_3) %>%
  summarise(
    n_bicycles = sum(Count_Unit_Bicycle > 0),
    n_accidents = n()
  ) %>%
  mutate(p = n_bicycles/n_accidents) %>%
  arrange(desc(p))

# estimate hyperparameters
hyper_params(df$p)

# cyclists
cyclist_fn <- function(area){
  accidents_raw %>%
    group_by_(area) %>%
    filter(Crash_Year == 2017) %>%
    summarise(
      count = n(),
      n_bicycles = sum(Count_Unit_Bicycle > 0)
    ) %>%
    bind_cols(posterior_f(df = ., y = .$n_bicycles, n = .$count, a = 1.55, b = 25)) %>%
    arrange(q50) %>%
    mutate_(area = interp(~factor(v, levels = v), v = as.name(area))) %>%
    ggplot() +
    geom_segment(mapping = aes(x = q2.5, xend = q97.5, y = area, yend = area)) +
    geom_segment(mapping = aes(x = q10, xend = q90, y = area, yend = area, col = q50), size = 2) +
    geom_point(mapping = aes(x = q50, y = area), pch = 3) +
    theme_forest() +
    scale_colour_gradientn(colors = my_cols()) +
    labs(
      title = "Rate of cyclists involved in road accidents",
      subtitle = paste("Bayesian estimate of the rate of accidents involving cyclists for", toupper(names(area)), "areas in 2017"),
      x = "Accidents involving cyclists (%)"
    )
}
cyclist_plots <- list(sa3 = cyclist_fn(areas[1]), sa4 = cyclist_fn(areas[2]))

# Brisbane inner
accidents_raw %>%
  filter(Loc_ABS_Statistical_Area_4 == "Brisbane Inner City", Crash_Year < 2018) %>%
  group_by(Crash_Year) %>%
  summarise(
    n_bikes = sum(Count_Unit_Bicycle > 0),
    n_accidents = n(),
    p_bikes = n_bikes/n_accidents
  ) %>%
  bind_cols(posterior_f(df = ., y = .$n_bikes, n = .$n_accidents, a = 1.55, b = 25)/100) %>%
  ggplot(aes(x = Crash_Year, y = q50)) +
  geom_line(col = "darkmagenta") +
  geom_point(col = "darkmagenta") +
  theme_minimal() +
  theme_forest() +
  labs(
    x = "Year",
    title = "Accidents involving cyclists - Brisbane Inner City",
    subtitle = "Accidents involving cyclists are increasing, indicative of growth in popularity of cycling - Be careful on the roads"
  )
The post Queensland road accidents mapped with Shiny and leaflet in R appeared first on Daniel Oehm | Gradient Descending.
The Queensland government collects data on road accidents dating back to 1st January 2001 and details characteristics of the incident including,
Mapping this data highlights hot spots where car accidents occur more often, in particular the dangerous areas in wet conditions, problematic intersections and the areas of Queensland which are more dangerous than others in terms of fatality rates.
I developed a Shiny App utilising leaflet to easily explore the data (and just for fun). It features,
This data is of road accidents, so the estimate of fatality rate in this case is the fatality rate given the vehicle was involved in an accident, rather than the fatality rate by road accident in the population. It is a slightly different take on how this statistic is usually published, but a useful one.
The best way to view the app is to run the following code. Firstly, check to make sure you have the packages installed by running
check_packages <- function(packages){
  if(all(packages %in% rownames(installed.packages()))){
    TRUE
  }else{
    cat("Install the following packages before proceeding\n",
        packages[!(packages %in% rownames(installed.packages()))], "\n")
  }
}
packages_needed <- c("tidyverse", "shiny", "leaflet", "leaflet.extras", "magrittr",
                     "htmltools", "htmlwidgets", "showtext", "data.table")
check_packages(packages_needed)
If all is good, run the line below and it will load the app.
shiny::runGitHub("doehm/road-accidents/", "doehm", launch.browser = TRUE)
This will launch it directly on your machine. Or you can follow the link directly to the Shiny app.
There are a lot of neat things we can do with this data and I’ll be adding to the app over time.
A subset of the app focuses on the “Brisbane Inner” SA3 area to give a taste of what to expect. It shows car accidents in the city since 1st January 2013. When zooming in, hover over the marker to get a short description of the crash.
View the full screen map here.
Below is the underlying code for the example leaflet map above, but I strongly recommend running the code above to view the Shiny app. See GitHub for the full code.
# queensland road accident data
# libraries
library(tidyverse)
library(shiny)
library(leaflet)
library(leaflet.extras)
library(magrittr)
library(htmltools)
library(htmlwidgets)
library(showtext)
library(data.table)

# font
try({
  font_add_google(name = "Montserrat", family = "mont")
  showtext_auto()
}, TRUE)

# load data
# or if it doesn't work grab the Rdata file from Github - see link above
load_data <- function(){
  if(!file.exists("locations.csv")){
    cat('\n Download may take a few minutes...\n')
    url <- "http://www.tmr.qld.gov.au/~/media/aboutus/corpinfo/Open%20data/crash/locations.csv"
    download.file(url, destfile = "locations.csv", method = "libcurl")
  }
  accidents_raw <- read_csv("locations.csv")
  return(accidents_raw)
}
accidents_raw <- load_data() %>%
  filter(Crash_Severity != "Property damage only")

# sample of brisbane inner
accidents <- accidents_raw %>%
  filter(
    Loc_ABS_Statistical_Area_3 == "Brisbane Inner",
    Crash_Year > 2013
  ) %>%
  mutate(fatality = Count_Casualty_Fatality > 0)

# basic leaflet
m <- leaflet(accidents) %>%
  addProviderTiles(providers$Stamen.Toner, group = "Black and white") %>%
  addTiles(options = providerTileOptions(noWrap = TRUE), group = "Colour") %>%
  addMarkers(
    lng = ~Crash_Longitude_GDA94,
    lat = ~Crash_Latitude_GDA94,
    clusterOptions = markerClusterOptions(),
    label = ~htmlEscape(Crash_DCA_Description)
  ) %>%
  addCircleMarkers(
    lng = ~Crash_Longitude_GDA94[accidents$fatality],
    lat = ~Crash_Latitude_GDA94[accidents$fatality],
    color = "#8B0000",
    stroke = FALSE,
    fillOpacity = 0.8,
    group = "Fatalities"
  ) %>%
  addHeatmap(
    lng = ~Crash_Longitude_GDA94,
    lat = ~Crash_Latitude_GDA94,
    radius = 17,
    blur = 25,
    cellSize = 25
  ) %>%
  addLayersControl(
    overlayGroups = c("Fatalities"),
    baseGroups = c("Black and white", "Colour"),
    options = layersControlOptions(collapsed = FALSE)
  )
m  # print the map
The post Buffalo Stampede Ultra 75km 2019: Race report appeared first on Daniel Oehm | Gradient Descending.
The Buffalo Stampede is one of the most brutal races on the ultra calendar, covering 75km and 4545m of climbing. It’s an out-and-back from Bright to the top of Mount Buffalo, tapping Mystic and Clear Spot along the way. It’s iconic for its brutally steep ascents that are technical and demanding.
This was my second Buffalo Ultra. My first was in 2016, and it was one of the hardest I’ve done. In 2016 I finished 9th in 9:56:00, and that feeling of crossing the finish line is one of the greatest.
This year my goal was to do better than my 2016 effort, and given the work I’ve put in training I was confident heading into it. My goal time was 9:30:00, which was reasonable since Buffalo does its best to throw you into the pain cave and ruin your day.
Kicking off at 6am it was still a little dark. It doesn’t take long to hit the first climb up to Mystic, a 500m ascent to get warmed up. I took it very easy. Trying to run this first climb is a death sentence but people still try.
One sure way to destroy your race is attacking the descents too hard, too early. This is easily done on the infamous Mick’s track; at a -45% gradient it’s a monster. I took it very easy, saving the quads for later, and just got to the bottom safely.
It’s always a good feeling tapping Clear Spot (10km mark) since you know much of the hard climbing is over, for now. I got into a nice rhythm, smashed my mash potato, cruised up and over Keating Ridge and into the Eurobin aid station clocking up 26kms.
At this point the field was pretty spread out. I came in in 3rd place running with 4th. Refueled, left in 4th place and began the 1000m climb to Buffalo Summit. It wasn’t long until I caught and moved into 3rd place. I was feeling very strong on the climbs and took advantage of that.
At the Chalet aid station I again refueled and quickly began the loop, still in 3rd place and about 10 minutes behind second. The Galleries are always fun, scrambling through the rocks. I recall in 2016 I was cramping up massively at this point. This time, however, not even close. I was in good shape.
I came into the Chalet aid station at the 42km mark again, still feeling great. I had really closed the gap to 2nd place, who was only 2 minutes ahead, and I was apparently looking much fresher, which was a good sign!
On the descent I caught and passed 2nd place. I tried not to get too carried away and just kept the legs turning over and stayed comfortable.
About halfway down Buffalo I felt the first tweak in my ITB. I’ve been having ITB issues for the month leading into the race, which is just depressing when you’re in top form but can’t run due to an injury. I was well enough to start, but this is what I feared.
It did not feel good and brought me to a walk. I didn’t walk for long and tried to take the edge off on the steep downs hoping it wouldn’t do it again.
I entered Eurobin aid station at the 53km mark in 2nd place, 14 mins behind 1st and a few minutes in front of 3rd and 4th. It was a quick stop for me: I grabbed more mash, more watermelon and electrolyte, and got moving pretty quickly. I was very happy with my energy levels and felt I had heaps left in the tank. My ITB was hanging in there, but it did not feel good at all and I was just hoping it would hold together for the last 26kms. I smashed the anti-inflammatories hoping that would do something.
I took the climb over Keating Ridge fairly easy since there were two monsters waiting for me at the 65km mark. A few minutes into the descent it happened, my ITB locked up, I couldn’t straighten my leg and every step was very painful. Running was impossible.
I walked for 500m trying to get some movement and start running again. Limping along for 100m it again flared up, very painful to put any weight on it and forced into a walk. I knew deep down my race was over.
3rd and 4th place passed me and there was nothing I could do about it. I tried running after walking for about 2km, but again 100m down the road running became too painful.
At this point, 60km deep, I had to make a decision, do I continue even though the race I wanted is over and just finish by walking it in and hope I don’t do any more damage? Or tap out and not risk major damage? The thought of rupturing my ITB and turning a few weeks of recovery into a few months was not a risk I was willing to take. So, I made the tough decision to drop. No matter how right that decision was it never feels good to drop from a race.
I gave the word and got a lift back to the finish line for first aid. The medical officer felt my knee and said “Faaaark!”. It was in some ways comforting to hear, reassuring me I made the right decision to drop.
I haven’t had a good run in my last 3 races so the way it went bums me out even more. But I’m remaining positive and spending the next few weeks doing all the right things to fix it.
I was on track to crush my goal time and have one of my best races. I guess now this race goes onto the ‘I need redemption’ pile, so I’ll be back next year. For now, it’s preparing for UTA and hoping I don’t see a repeat.
View my race on Strava
The post Applying gradient descent – primer / refresher appeared first on Daniel Oehm | Gradient Descending.
Every so often a problem arises where it’s appropriate to use gradient descent, and it’s fun (and / or easier) to apply it manually. Recently I’ve applied it to problems ranging from optimising a basic recommender system to ‘unsuppressing’ suppressed tabular data. I thought I’d do a series of posts about how I’ve used gradient descent, but figured it was worthwhile starting with the basics as a primer / refresher.
To understand how gradient descent is applied we’ll use the classic example, linear regression.
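A simple linear regression model is of the form

$$y_i = \beta_0 + \beta_1 x_i + \epsilon_i$$

where $\epsilon_i \sim N(0, \sigma^2)$. The objective is to find the parameters $\beta = (\beta_0, \beta_1)$ that minimise the mean squared error

$$J(\beta) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \beta_0 - \beta_1 x_i\right)^2$$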
This is a good problem since we know the analytical solution and can check our results.
In practice you would never use gradient descent to solve a regression problem, but it is useful for learning the concepts.
Set up
library(ggplot2)
set.seed(241)
nobs <- 250
b0 <- 4
b1 <- 2

# simulate data
x <- rnorm(nobs)
y <- b0 + b1*x + rnorm(nobs, 0, 0.5)
df <- data.frame(x, y)

# plot data
g1 <- ggplot(df, aes(x = x, y = y)) +
  geom_point(size = 2) +
  theme_minimal()
The analytical solution is given by

$$\hat{\beta} = \left(X^\top X\right)^{-1} X^\top y$$
# set model matrix
X <- model.matrix(y ~ x, data = df)
beta <- solve(t(X) %*% X) %*% t(X) %*% y
beta
##                 [,1]
## (Intercept) 4.009283
## x           2.016444
And just to convince ourselves this is correct
# linear model formulation
lm1 <- lm(y ~ x, data = df)
coef(lm1)
## (Intercept)           x
##    4.009283    2.016444
g1 + geom_abline(slope = coef(lm1)[2], intercept = coef(lm1)[1], col = "darkmagenta", size = 1)
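The objective is to achieve the same result using gradient descent. It works by updating the parameters at each iteration in the direction of the negative gradient to minimise the mean squared error, i.e.

$$\beta^{(k+1)} = \beta^{(k)} - \gamma \nabla_{\beta} J\left(\beta^{(k)}\right)$$

where $\gamma$ is the learning rate and $J(\beta)$ is the MSE as a function of the regression parameters. Firstly, we find the partial derivatives of $J(\beta)$:

$$\frac{\partial J}{\partial \beta_0} = -\frac{2}{n}\sum_{i=1}^{n}\left(y_i - \beta_0 - \beta_1 x_i\right), \qquad \frac{\partial J}{\partial \beta_1} = -\frac{2}{n}\sum_{i=1}^{n} x_i\left(y_i - \beta_0 - \beta_1 x_i\right)$$

or in matrix form $\nabla_{\beta} J(\beta) = -\frac{2}{n} X^\top (y - X\beta)$, which is what loss.fun computes below.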
The learning rate ensures we don’t jump too far with each iteration, moving only some proportion of the gradient; otherwise we could overshoot the minimum and take much longer to converge, or not find the optimal solution at all.
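To make the overshooting point concrete, here is a toy one-dimensional example (an assumption for illustration, not the regression problem): minimising f(b) = (b - 3)², whose gradient is 2(b - 3). Each step multiplies the distance to the minimum by (1 - 2 × lr), so a small learning rate crawls, a moderate one converges quickly, and anything above 1 overshoots by more on each step and diverges. The function name descend() is hypothetical.

```r
# Toy problem: minimise f(b) = (b - 3)^2 starting from b = 0
descend <- function(lr, b = 0, iters = 100) {
  for (k in seq_len(iters)) {
    b <- b - lr * 2 * (b - 3)   # step against the gradient 2 * (b - 3)
  }
  b
}

descend(0.01)   # too small: about 2.6, still short of 3 after 100 steps
descend(0.1)    # converges to 3
descend(1.1)    # too large: overshoots further each step and diverges
```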
Applying this to the problem above, we’ll initialise our values for $\beta$ to something sensible, e.g. $\beta = (1, 1)$. I’ll choose a learning rate of $\gamma = 0.01$. This is a slow burn; a learning rate of 0.1-0.2 is more appropriate for this problem, but we’ll get to see the movement of the gradient better. It’s worth trying different values of $\gamma$ to see how it changes convergence. The algorithm is set up as
# gradient descent function
gradientDescent <- function(formula, data, par.init, loss.fun, lr, iters){
  formula <- as.formula(formula)
  X <- model.matrix(formula, data = data)
  y <- data[,all.vars(formula)[1]]
  par <- loss <- matrix(NA, nrow = iters+1, ncol = 2)
  par[1,] <- par.init
  for(k in 1:iters){
    loss[k,] <- loss.fun(X=X, y=y, par=par[k,])   # gradient at the current parameters
    par[k+1,] <- par[k,] - lr*loss[k,]             # step against the gradient
  }
  return(list(par = par))
}

# loss function
# note: this returns the gradient of the MSE, -2/n X'(y - X par), not the MSE itself
loss.fun <- function(X, y, par) return(-2/nrow(X)*(t(X) %*% (y - X %*% par)))

# gradient descent. not much to it really
beta <- gradientDescent(y ~ x, data = df, par.init = c(1, 1), loss.fun = loss.fun, lr = 0.01, iters = 1000)$par

# plotting results: fitted line every 10th iteration
z <- seq(1, 1001, 10)
g1 + geom_abline(slope = beta[z,2], intercept = beta[z,1], col = "darkmagenta", alpha = 0.2, size = 1)
tail(beta, 1)
##             [,1]     [,2]
## [1001,] 4.009283 2.016444
As expected we obtain the same result. The lines show the fitted line at every 10th iteration and how the parameters converge to the optimal values. A less reasonable set of starting values still converges quickly to the optimal solution, showing how well gradient descent works on linear regression.
beta <- gradientDescent(y ~ x, data = df, par.init = c(6, -1), loss.fun = loss.fun, lr = 0.01, iters = 1000)$par

# plotting results
z <- seq(1, 1001, 10)
beta.df <- data.frame(b0 = beta[z,1], b1 = beta[z,2])
g1 + geom_abline(data = beta.df, mapping = aes(slope = b1, intercept = b0), col = "darkmagenta", alpha = 0.2, size = 1)
tail(beta, 1)
##             [,1]     [,2]
## [1001,] 4.009283 2.016444
library(gganimate)
library(magrittr)
ggif_minimal <- df %>%
  ggplot(aes(x = x, y = y)) +
  geom_point(size = 2) +
  theme_minimal() +
  geom_abline(data = beta.df, mapping = aes(slope = b1, intercept = b0), col = "darkmagenta", size = 1) +
  geom_text(
    data = data.frame(z, b0 = beta[z,1], b1 = beta[z,2]),
    mapping = aes(x = -2.8, y = 9, label = paste("b0 = ", round(b0, 2), "\nb1 = ", round(b1, 2))),
    hjust = 0, size = 6
  ) +
  transition_reveal(z) +
  ease_aes("linear") +
  enter_appear() +
  exit_fade()
animate(ggif_minimal, width = 1920, height = 1080, fps = 80)
Those are the basics of applying gradient descent. In practice there is no need to use gradient descent to solve a regression problem, but once you know how to apply it you’ll find real-world applications elsewhere that are more complicated (and interesting). If you can define the objective function and it is differentiable, you can apply gradient descent. In later posts I’ll demonstrate how I’ve applied it to real-world problems. Stay tuned!
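As a sketch of that last point, here is a generic version that takes any gradient function rather than a model formula. Two choices differ from gradientDescent() above and are assumptions of this sketch: it stops once the gradient norm falls below a tolerance instead of running a fixed number of iterations, and it uses a learning rate of 0.1. The names gd() and mse_grad() are hypothetical. The check regenerates the same simulated regression data and compares against the least squares solution from qr.solve().

```r
# Generic gradient descent: grad is a function returning the gradient at par
gd <- function(grad, par, lr = 0.1, tol = 1e-8, max_iter = 1e5) {
  for (k in seq_len(max_iter)) {
    g <- grad(par)
    if (sqrt(sum(g^2)) < tol) break   # stop when the gradient is (nearly) zero
    par <- par - lr * g
  }
  list(par = par, iters = k)
}

# Same simulated data as the set up above
set.seed(241)
x <- rnorm(250)
y <- 4 + 2 * x + rnorm(250, 0, 0.5)
X <- cbind(1, x)

# Gradient of the MSE: -2/n X'(y - X b)
mse_grad <- function(b) as.vector(-2 / nrow(X) * t(X) %*% (y - X %*% b))

fit <- gd(mse_grad, par = c(1, 1))
fit$par
qr.solve(X, y)   # least squares solution for comparison
```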
The post Martingale strategies don’t work, but we knew that – Simulation analysis in R appeared first on Daniel Oehm | Gradient Descending.
It’s generally accepted that Martingale betting strategies don’t work. But people still gravitate towards them because they are intuitive. I was curious to find out how they actually perform.
Disclaimer: I do not encourage gambling, nor do I gamble myself, but the games are good examples of stochastic processes.
The Classic Martingale strategy is as follows. Make an initial bet of $b. If you win, bet $b again on the next round. If you lose, double your previous bet. In a nutshell you bet $b × 2^n, where n is the number of losses in a row. The intention is to make your money back on the next win.
Assume we start with $100 and our initial bet is $1. We keep playing until there is not enough money in the cash pool to make the next bet. Also assume we are playing roulette on an American table and place bets on red or black, which both have a win probability of 18/38. It only takes a streak of 6 losses to end the game, because we don’t have enough cash to double our bet for a 7th time. If we see a streak of 10 or more it really starts to get out of control.
losses in a row | bet ($) | cumulative loss ($) |
---|---|---|
1 | 1 | 1 |
2 | 2 | 3 |
3 | 4 | 7 |
4 | 8 | 15 |
5 | 16 | 31 |
6 | 32 | 63 |
7 | 64 | 127 |
8 | 128 | 255 |
9 | 256 | 511 |
10 | 512 | 1023 |
The probability of losing 6 in a row is (20/38)^6 ≈ 0.021. Sounds unlikely, but it will occur more often than you think. Each win nets $1, so once we have won 27 times we’ll have enough cash ($127) to afford a losing streak of 6 and still bet on the 7th.
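The arithmetic behind the table and the paragraph above can be checked in a few lines, assuming the $1 base bet and $100 starting cash stated earlier:

```r
k <- 1:10                        # number of losses in a row
bet <- 2^(k - 1)                 # size of the k-th bet in the streak
cumulative_loss <- 2^k - 1       # total lost after k straight losses
data.frame(k, bet, cumulative_loss)

max(k[cumulative_loss <= 100])   # longest losing streak $100 can cover: 6
(20/38)^6                        # probability of losing 6 in a row, about 0.021
sum(bet[1:7]) - 100              # $1 wins needed before a 7th bet is affordable: 27
```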
It’s more likely we’ll have a few wins and losses before observing a long losing streak that takes us out of the game. The question is how many trials (spins of the roulette wheel) will we place bets on before we lose our money and play stops? A slight variation I’ve applied is that if there is not enough money left to double the bet, we simply bet the remainder in the cash pool, in other words go all-in. This is more in line with what someone might actually do.
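A sketch of a single classic game under those assumptions: $100 start, $1 base bet, win probability 18/38, and going all-in when the doubled bet can’t be covered. The function play_classic() and its arguments are our own naming for illustration, not the code used for the plots in this post.

```r
play_classic <- function(cash = 100, base_bet = 1, p_win = 18/38) {
  bet <- base_bet
  path <- cash                     # cash held after each spin
  while (cash > 0) {
    stake <- min(bet, cash)        # go all-in if the full bet isn't affordable
    if (runif(1) < p_win) {
      cash <- cash + stake
      bet <- base_bet              # reset after a win
    } else {
      cash <- cash - stake
      bet <- 2 * bet               # double after a loss
    }
    path <- c(path, cash)
  }
  list(trials = length(path) - 1, cash = path)
}

set.seed(100)
game <- play_classic()
game$trials       # number of spins before going bust
max(game$cash)    # peak cash during the game
```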
This simulation was a relatively lucky one, winning over $170 and almost 400 trials, however one bad streak and it’s all over. You can plot as many simulations as you like, some are shorter and some are longer but they all end the same way.
This is a typical pattern for the classic strategy. If you wish to see more try running the code at the end of this post or view this set.
The Reverse Martingale strategy flips this around: instead of doubling the bet upon a loss, double the bet upon a win. Not surprisingly, this ends the same way as the classic strategy.
The player’s cash slowly decreases by $1 on each loss. Occasionally there is a big win, rather than a big loss. With this strategy you can watch your money dwindle away rather than vanish in front of your eyes.
This plot uses the same win-lose sequence as the one above. In this case the maximum cash held throughout the game is higher with the classic strategy than with the reverse.
These strategies were simulated 20,000 times. The distribution of the number of trials shows how long a typical game will last until bankruptcy. The classic strategy has a very long tail, so you could potentially be playing for a very long time; the maximum number of trials in this simulation was 64254. But you could also be playing for a very short time.
The reverse strategy has a slightly higher median number of trials (191 vs 170) and is much less variable than the classic strategy: 95% of games lasted between 167 and 213 spins.
strategy | min | 2.5% | 10% | 50% | 90% | 97.5% | max |
---|---|---|---|---|---|---|---|
classic | 7 | 10 | 22 | 170 | 1132 | 2759 | 64254 |
reverse | 152 | 167 | 172 | 191 | 207 | 213 | 226 |
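To get a feel for where these quantiles come from without running the full martingale() function at the end of the post, here is a stripped-down sketch. It is my own simplified simulator (play_martingale() is a hypothetical name), it follows the post’s all-in rule when a doubled bet can’t be covered, and it uses 1,000 rather than 20,000 simulations, so the numbers will differ somewhat from the table.

```r
# Simplified Martingale simulator (a sketch, not the post's martingale() function):
# returns the number of spins until the cash pool hits zero
play_martingale <- function(bet = 1, cash = 100, p = 18/38, reverse = FALSE) {
  trials <- 0
  next_bet <- bet
  while (cash > 0) {
    trials <- trials + 1
    stake <- min(next_bet, cash)          # go all-in if the bet can't be covered
    win <- runif(1) < p
    cash <- if (win) cash + stake else cash - stake
    # classic doubles after a loss, reverse doubles after a win; otherwise reset to bet
    double_next <- if (reverse) win else !win
    next_bet <- if (double_next) 2 * stake else bet
  }
  trials
}

set.seed(2019)
n_sims <- 1000
probs <- c(0, 0.025, 0.1, 0.5, 0.9, 0.975, 1)
classic_trials <- replicate(n_sims, play_martingale(reverse = FALSE))
reverse_trials <- replicate(n_sims, play_martingale(reverse = TRUE))
round(rbind(classic = quantile(classic_trials, probs),
            reverse = quantile(reverse_trials, probs)))
```

The sketch also suggests why the reverse strategy is so predictable. A run of k wins followed by a loss gains 1 + 2 + … + 2^(k-1) = 2^k - 1 and then loses 2^k, a net of exactly -$1. Starting with $100 and a $1 bet, every game therefore contains exactly 100 losses, plus on average 100 × 18/20 = 90 wins, so game lengths cluster tightly around 190 spins.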
Assume the goal is to double your money. What is the probability you’ll double your money before going bust? After 20,000 simulations for both strategies, the probability you will double your money using…
The Classic Martingale strategy tends to do better on average, but only slightly. Neither of these strategies is better than simply playing once and going all-in, which doubles your money with probability 18/38 ≈ 0.47.
The distribution of the maximum amount of cash held by the player at any point during the game shows the classic strategy has a slight edge over the reverse strategy on average, although the reverse strategy has the potential for massive wins if you score a good winning streak.
strategy | min | 2.5% | 10% | 50% | 90% | 97.5% | max |
---|---|---|---|---|---|---|---|
classic | 100 | 101 | 107 | 158 | 425 | 926 | 19894 |
reverse | 100 | 100 | 102 | 133 | 534 | 2062 | 131163 |
However, keep in mind these simulations all resulted in total loss at the end of the game. Perhaps the key is to know when to stop?
Rather than stopping once you reach a specified amount, which may never happen, stop when the bet reaches a specified amount.
We could specify a streak length, however a better idea is to specify a ratio of bet to total cash. This way the stopping condition is dynamic. For example, if there is a winning streak we’ll have more money with which to bet.
Essentially, by using this ratio we are fixing a certain level of risk rather than a betting amount. The ratio is calculated as the next bet divided by the total cash in hand before that bet is placed, i.e. ratio = bet / cash.
If we fix a stopping ratio of 0.1 and start with $100, we could place 4 bets in a losing streak before the ratio exceeds 0.1. If our initial cash pool was $200, we could place 5 bets before the ratio exceeds 0.1.
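The worked numbers above can be reproduced with a short helper. This is a sketch under the post’s assumptions (classic strategy, $1 initial bet, bets doubling on every loss); bets_before_stop() is a hypothetical name, and the ratio is the next bet divided by the cash in hand before that bet, matching the check bet_vec[trial]/cash_vec[trial] in the martingale() function.

```r
# Count the consecutive losing bets that can be placed (classic strategy) before
# the ratio bet / cash first exceeds stop_ratio. bets_before_stop() is hypothetical.
bets_before_stop <- function(cash, stop_ratio, bet = 1) {
  n_bets <- 0
  while (cash > 0 && bet / cash <= stop_ratio) {
    n_bets <- n_bets + 1
    cash <- cash - bet   # assume the bet is lost
    bet <- 2 * bet       # classic strategy doubles after a loss
  }
  n_bets
}

bets_before_stop(cash = 100, stop_ratio = 0.1)  # 4
bets_before_stop(cash = 200, stop_ratio = 0.1)  # 5
```

With stop_ratio = 1 the helper returns 6, which is the plain "can we still afford the doubled bet" rule from the start of the post.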
If we stop as soon as this ratio is reached, it means we’re on a losing streak, so in the case of the classic strategy it makes sense to keep betting until our next win and then walk away.
In the case of the reverse it makes sense to leave as soon as the ratio is met since we’re on a winning streak.
There are clear differences and similarities between the two strategies.
Overall the stopping strategies help to minimise loss rather than lock in wins, so on average you will still lose your money.
The code for the plots in this post can be found on github. The Martingale function is below.
# libraries
library(tidyverse)

# martingale function
martingale <- function(bet, cash, p, stop_condn = Inf, stop_factor = Inf, reverse = FALSE,
                       plot = TRUE, stop_on_next_win = TRUE){
  bet_vec <- vector(mode = "numeric")
  cash_vec <- vector(mode = "numeric")
  outcome <- vector(mode = "numeric")
  winnings <- vector(mode = "numeric")
  total_cash <- vector(mode = "numeric")
  trial <- 0
  total_cash[1] <- cash
  stop_trigger <- FALSE # set when the stop ratio is hit and we wait for the next win

  while(total_cash[max(trial, 1)] > 0){
    # iterate through trials
    trial <- trial + 1

    # update cash pool
    if(trial == 1){
      cash_vec[trial] <- cash
    }else{
      cash_vec[trial] <- total_cash[trial-1]
    }

    # set bet (trial == 1 is checked first so outcome[0] is never evaluated)
    if(!reverse){
      if(trial == 1 || outcome[trial - 1] == 1){
        bet_vec[trial] <- bet
      }else{
        bet_vec[trial] <- min(2*bet_vec[trial-1], cash_vec[trial]) # if there isn't enough to double the bet just bet what is left
      }
    }else{
      if(trial == 1 || outcome[trial - 1] == 0){
        bet_vec[trial] <- bet
      }else{
        bet_vec[trial] <- min(2*bet_vec[trial-1], cash_vec[trial]) # if there isn't enough to double the bet just bet what is left
      }
    }

    # stop condition: ratio of bet to cash in hand
    if(bet_vec[trial]/cash_vec[trial] > stop_condn){
      if(stop_on_next_win && !reverse){
        stop_trigger <- TRUE
      }else{
        outcome[trial] <- NA
        winnings[trial] <- NA
        total_cash[trial] <- cash_vec[trial]
        break
      }
    }

    outcome[trial] <- sample(c(0,1), 1, prob = c(1-p, p))
    winnings[trial] <- bet_vec[trial]*outcome[trial] - bet_vec[trial]*(1-outcome[trial])
    total_cash[trial] <- cash_vec[trial] + winnings[trial]

    # walk away on the first win after the stop ratio was hit (classic strategy)
    if(stop_trigger && outcome[trial] == 1) break

    # stop condition: target multiple of the starting cash reached
    if(total_cash[trial] >= stop_factor*cash) break
  }

  # make the plot
  gg <- NULL
  if(plot){
    df <- data.frame(trials = 1:trial, cash = total_cash)
    gg <- ggplot() +
      geom_line(data = df, mapping = aes(x = trials, y = cash), col = "darkmagenta", lty = 1, size = 1) +
      geom_hline(yintercept = cash_vec[1], col = "grey", lty = 2) +
      theme_minimal() +
      labs(
        x = "Number of spins",
        y = "Total cash in hand",
        title = ifelse(reverse, "Reverse Martingale strategy", "Martingale strategy"),
        subtitle = "The growth and decline of the gamblers cash pool - it always ends the same way"
      ) +
      ylim(0, NA)
    print(gg)
  }

  return(list(
    bet = bet_vec,
    cash = cash_vec,
    outcome = outcome,
    total_cash = total_cash,
    trials = trial,
    plot = gg))
}

# run the simulation and plot the output
# try different parameters to see the effect
martingale(1, 100, 18/38, reverse = FALSE, plot = TRUE, stop_condn = 1,
           stop_on_next_win = TRUE)
The post Adding Custom Fonts to ggplot in R appeared first on Daniel Oehm | Gradient Descending.
ggplot – You can spot one from a mile away, which is great! And when you do it’s a silent fist bump. But sometimes you want more than the standard theme.
Fonts can breathe new life into your plots, helping to match the theme of your presentation, poster or report. This is always an afterthought for me and I need to work out how to do it again each time, hence the post.
There are two main packages for managing fonts: extrafont and showtext.
extrafont is a relatively old package and unfortunately it’s not well supported. You can run into problems, however the base functions work well enough.
The fonts in the system directory are first imported into the extrafontdb using font_import(). This only needs to be run once in order to load the fonts into the right directory. Secondly, they are registered in R using loadfonts() for your specific device. The fonts need to be loaded each session.
library(tidyverse)
library(extrafont)
library(cowplot)

# import fonts - only once
font_import()

# load fonts - every session (device = "win" is for Windows)
loadfonts(device = "win", quiet = TRUE)
Below are all the fonts available on my system.
To use these fonts in a plot, change the text family using one of the names above. For demonstration I’ll use the Antigua corn data from the DAAG package.
library(DAAG)

# corn plot (the data is piped in, so ggplot() only needs the aesthetic mapping)
corn <- antigua %>%
  dplyr::filter(ears > 0) %>%
  ggplot(aes(x = ears, y = harvwt, col = site)) +
  geom_point(size = 4) +
  scale_colour_manual(values = colorRampPalette(c("orange", "darkmagenta", "turquoise"))(8)) +
  labs(title = "ANTIGUA CORN YIELDS",
       x = "Ears of corn harvested",
       y = "Harvest weight") +
  theme(
    text = element_text(family = "candara", size = 24),
    plot.title = element_text(size = 30),
    plot.caption = element_text(size = 28))
corn
It’s likely you’ll want more than what is available in the standard font set. You can add custom fonts with extrafont, however I’ve only had limited success. A better option is using showtext.
showtext is a package by Yixuan Qiu and it makes adding new fonts simple. There are a tonne of websites where you can download free fonts to suit pretty much any style you are going for. I’ll only touch on the key bits, so check out the vignette for more details.
The simplest way to add fonts is via font_add_google(). Find the font you like on Google Fonts and add it to R using the following.
library(showtext)
font_add_google(name = "Amatic SC", family = "amatic-sc")
Amatic SC can now be used by changing the font family to “amatic-sc”. For R to know how to properly render the text, we first need to run showtext_auto() prior to displaying the plot. One downside is that it currently does not display in RStudio. Either open a new graphics window with windows() (on Windows; x11() or quartz() elsewhere) or save the plot to an external file, e.g. a .png.
# turn on showtext
showtext_auto()
Custom fonts are added by first downloading the font’s .ttf file (unzipping it if needed), then using font_add() to register the font and showtext_auto() to load the fonts.
font_add(family = "docktrin", regular = "./fonts/docktrin/docktrin.ttf")
showtext_auto()
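If you have downloaded several fonts into one folder, registering them one at a time gets repetitive. Below is a small sketch that loops over every .ttf file under a directory and registers each with sysfonts::font_add(), the function showtext uses (sysfonts is attached along with showtext). The folder layout and the rule of deriving the family name from the file name are my own assumptions, and ttf_family_names() and register_ttf_fonts() are hypothetical helpers.

```r
# Derive a family name from each font file, e.g. "./fonts/docktrin/docktrin.ttf" -> "docktrin"
ttf_family_names <- function(paths) {
  tolower(tools::file_path_sans_ext(basename(paths)))
}

# Register every .ttf file found (recursively) under font_dir with showtext
register_ttf_fonts <- function(font_dir) {
  paths <- list.files(font_dir, pattern = "\\.ttf$", recursive = TRUE,
                      full.names = TRUE, ignore.case = TRUE)
  families <- ttf_family_names(paths)
  for (i in seq_along(paths)) {
    sysfonts::font_add(family = families[i], regular = paths[i])
  }
  invisible(families)
}

# usage (requires showtext and font files on disk):
# library(showtext)
# register_ttf_fonts("./fonts")
# showtext_auto()
```

Deriving names from file names keeps the family strings predictable, but check them against what you use in element_text(family = ...), since a file called "Swamp-Witch.ttf" would be registered as "swamp-witch".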
And that’s pretty much it. Given how effortless it is to add new fonts you can experiment with many different styles.
These fonts are outrageous but demonstrate that you really can go for any style, from something minimal and easy to read to something fit for a heavy metal band. For professional reports you’ll want to go for something sensible, but if you’re making a poster, website or infographic you may want to get creative.
The tools are there for you to be as creative as you want to be.
The last thing to note is you’ll need to play around with different sizes given the resolution of your screen.
To turn off showtext and use the standard fonts, run
showtext_auto(FALSE)
# load fonts
ft <- data.frame(x = sort(rep(1:4, 31))[1:nrow(fonttable())],
                 y = rep(31:1, 4)[1:nrow(fonttable())],
                 text_name = fonttable()$FullName)
font_plot <- ggplot(ft, aes(x = x, y = y)) +
  geom_text(aes(label = text_name, family = text_name), size = 20) +
  coord_cartesian(xlim = c(0.5, 4.5)) +
  theme_void()
font_plot

# amatic sc
scale <- 4.5 # scale is an adjustment for a 4k screen
corn <- ggplot(antigua %>% dplyr::filter(ears > 0), aes(x = ears, y = harvwt, col = site)) +
  geom_point(size = 4) +
  scale_colour_manual(values = colorRampPalette(c("orange", "darkmagenta", "turquoise"))(8)) +
  labs(title = "ANTIGUA CORN YIELDS",
       subtitle = "Study of different treatments and their effect on corn yields in Antigua",
       x = "Ears of corn harvested",
       y = "Harvest weight",
       caption = "@danoehm | gradientdescending.com") +
  theme_minimal() +
  theme(
    text = element_text(family = "amatic-sc", size = 22*scale),
    plot.title = element_text(size = 26*scale, hjust = 0.5),
    plot.subtitle = element_text(size = 14*scale, hjust = 0.5),
    plot.caption = element_text(size = 12*scale),
    legend.text = element_text(size = 16*scale))
png(file = "./fonts/corn yields amatic-sc.png", width = 3840, height = 2160, units = "px", res = 72*4)
print(corn)
dev.off()

# wanted dead or alive
# generate data - unfortunately I couldn't find data on actual outlaws
n <- 20
x <- 100*runif(n)
y <- (25 + 0.005*x^2 + rnorm(n, 0, 10))*10
wanted <- data.frame(x, y) %>%
  ggplot(aes(x = x, y = y)) +
  geom_smooth(col = "black", lty = 2) +
  geom_point(size = 4) +
  labs(title = "WANTED: DEAD OR ALIVE",
       subtitle = "Relationship of the crimes committed by outlaws and the bounty on their head",
       x = "CRIMES", y = "BOUNTY",
       caption = "@danoehm | gradientdescending.com") +
  theme_minimal() +
  theme(
    text = element_text(family = "docktrin", size = 16*scale),
    plot.title = element_text(size = 40*scale, hjust = 0.5),
    plot.subtitle = element_text(size = 14*scale, hjust = 0.5, margin = margin(t = 40)),
    plot.caption = element_text(size = 10*scale),
    legend.text = element_text(size = 12*scale),
    panel.grid = element_line(color = "black"),
    axis.title = element_text(size = 26*scale),
    axis.text = element_text(color = "black"))
png(file = "./fonts/wanted1.png", width = 3840, height = 2160, units = "px", res = 72*4)
print(ggdraw() +
  draw_image("./fonts/wanted_dead_or_alive_copped.png", scale = 1.62) + # a png of the background for the plot
  draw_plot(wanted))
dev.off()

# Horror movies
# assumes the "swamp-witch" and "montserrat" fonts have already been registered with font_add() / font_add_google()
horror <- read_csv("./fonts/imdb.csv")
gghorror <- horror %>%
  dplyr::filter(str_detect(keywords, "horror")) %>%
  dplyr::select(original_title, release_date, vote_average) %>%
  ggplot(aes(x = release_date, y = vote_average)) +
  geom_smooth(col = "#003300", lty = 2) +
  geom_point(size = 4, col = "white") +
  labs(title = "HORROR MOVIES!",
       subtitle = "Average critics ratings of horror films and relationship over time",
       x = "YEAR", y = "RATING",
       caption = "Data from IMDB.\n@danoehm | gradientdescending.com") +
  theme_minimal() +
  theme(
    text = element_text(family = "swamp-witch", size = 16*scale, color = "#006600"),
    plot.title = element_text(size = 48*scale, hjust = 0.5),
    plot.subtitle = element_text(family = "montserrat", size = 14*scale, hjust = 0.5, margin = margin(t = 30)),
    plot.caption = element_text(family = "montserrat", size = 8*scale),
    legend.text = element_text(size = 12*scale),
    panel.grid = element_line(color = "grey30"),
    axis.title = element_text(size = 26*scale),
    axis.title.y = element_text(margin = margin(r = 15)),
    axis.text = element_text(color = "#006600"),
    plot.background = element_rect(fill = "grey10"))
png(file = "./fonts/swamp.png", width = 3840, height = 2160, units = "px", res = 72*4)
print(gghorror)
dev.off()
The post The Most Amount of Rain over a 10 Day Period on Record appeared first on Daniel Oehm | Gradient Descending.
Townsville, Qld, has been inundated with torrential rain and has broken the record for the largest rainfall over a 10 day period. It has been devastating for the farmers and residents of Townsville. I looked at Townsville’s weather data to understand how significant this event was and whether there have been comparable events in the past.
Where this may interest the R community is in obtaining the data. The package ‘bomrang’ provides an API allowing R users to fetch weather data directly from the Australian Bureau of Meteorology (BOM) and have it returned in a tidy data frame.
Historical weather observations including rainfall, min/max temperatures and sun exposure are obtained via get_historical(). Either the ID or the lat-long coordinates of the weather station are needed to extract the data. The ID information can be found on the BOM website by navigating to the observations page.
Using the station ID the rainfall data is extracted with the following.
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(bomrang))
suppressPackageStartupMessages(library(gridExtra))
suppressPackageStartupMessages(library(magrittr))
suppressPackageStartupMessages(library(MCMCpack))
mycols <- c("darkmagenta", "turquoise")

# import data - simple as
townsville <- get_historical("032040")
And cleaning up the dates.
# fix date functions
fix.date <- function(y, m, d){
  s <- paste(c(y, m, d), collapse = "-")
  d <- as.Date(s, "%Y-%m-%d")
  return(d)
}
for(k in 1:nrow(townsville)){
  townsville$date[k] <- fix.date(townsville$Year[k], townsville$Month[k], townsville$Day[k])
}
townsville$date <- as.Date(townsville$date, origin = "1970-01-01")

# trim - for replication of when this was first run
townsville %<>% dplyr::filter(date < as.Date("2019-02-13"))

# take a look
head(townsville)
##   Product_code Station_number Year Month Day Rainfall Period Quality
## 1   IDCJAC0009          32040 1941     1   1      0.0     NA       Y
## 2   IDCJAC0009          32040 1941     1   2      6.6      1       Y
## 3   IDCJAC0009          32040 1941     1   3     16.5      1       Y
## 4   IDCJAC0009          32040 1941     1   4    205.5      1       Y
## 5   IDCJAC0009          32040 1941     1   5    175.0      1       Y
## 6   IDCJAC0009          32040 1941     1   6     72.9      1       Y
##         date
## 1 1941-01-01
## 2 1941-01-02
## 3 1941-01-03
## 4 1941-01-04
## 5 1941-01-05
## 6 1941-01-06
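The loop above works, but it calls paste() and as.Date() once per row. A vectorised alternative is a single as.Date() call on all rows at once. This is a sketch on a small made-up data frame with the same Year, Month and Day columns as the bomrang output; I haven’t benchmarked it on the full Townsville record.

```r
# toy data with the same date columns as the bomrang output
toy <- data.frame(Year = c(1941, 1941, 2019), Month = c(1, 1, 2), Day = c(1, 2, 8))

# vectorised: build all "YYYY-MM-DD" strings in one go, then parse them together
toy$date <- as.Date(sprintf("%d-%02d-%02d", toy$Year, toy$Month, toy$Day))
toy
```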
Applying a 10 day rolling window over the entire historical record it’s easy to see the significance of this rainfall event. The 8th February recorded 1259.8mm of rain in the 10 days prior. It dwarfs the previous record of 925.5mm set in 1953. It also highlights other significant events in the past, in particular 1968, 1998 and 2009 but these don’t come close to the 2019 event.
# get 10 day total
townsville$rolling10 <- 0
for(k in 10:nrow(townsville)){
  townsville$rolling10[k] <- sum(townsville$Rainfall[(k-9):k])
}

# plot
ggplot(
  townsville %>% dplyr::filter(date > as.Date("1940-01-01")),
  aes(x = date, y = rolling10, col = rolling10)) +
  geom_line() +
  scale_color_gradientn(colors = colorRampPalette(mycols)(32)) +
  labs(y = "Total rainfall in the last 10 days")
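The same 10 day rolling total can be computed without a loop using cumulative sums: the total over days k-9 to k is cs[k] - cs[k-10], where cs is the cumulative sum. This is a sketch on made-up rainfall rather than the BOM data, and it assumes there are no missing values: a single NA in Rainfall would propagate through every later cumulative sum, whereas the loop only affects the windows that contain it.

```r
set.seed(1)
rain <- rgamma(365, shape = 0.3, scale = 20)   # made-up daily rainfall in mm
n_days <- length(rain)

# loop version, as in the post (first 9 days left at 0)
rolling_loop <- numeric(n_days)
for (k in 10:n_days) rolling_loop[k] <- sum(rain[(k - 9):k])

# cumulative-sum version: sum over days (k-9):k equals cs[k] - cs[k-10]
cs <- cumsum(rain)
rolling_cs <- numeric(n_days)
rolling_cs[10:n_days] <- cs[10:n_days] - c(0, cs[1:(n_days - 10)])

all.equal(rolling_loop, rolling_cs)   # TRUE (up to floating point tolerance)
```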
It really was a phenomenal amount of rain. This was not the largest rainfall in a single day, however; that record was set in 1998 with a massive 548.8mm of rain. In fact the 2019 floods don’t feature in the top 10 wettest days, but the consistency over 10 days made it the wettest 10 day period on record.
townsville %>%
  arrange(desc(Rainfall)) %>%
  dplyr::select(date, Rainfall) %>%
  head(10)
##          date Rainfall
## 1  1998-01-11    548.8
## 2  1946-03-03    366.5
## 3  1953-01-16    346.7
## 4  1977-02-01    317.6
## 5  1997-03-24    302.8
## 6  1978-01-31    273.4
## 7  2000-04-04    271.6
## 8  1946-02-10    251.7
## 9  2009-02-03    236.8
## 10 1944-03-29    233.4
Townsville received over a year’s worth of rain in 10 days. The graph below shows the annual rainfall measurements and the average annual rainfall (dotted line) over the historical record. Even with only 5-6 weeks of the year gone, 2019 is already one of the wettest years on record.
# calculate the total annual rainfall and rainfall to date
annual.rainfall <- townsville %>%
  dplyr::filter(date > as.Date("1940-01-01")) %>%
  mutate(
    rainfall_to_date = as.numeric(as.POSIXlt(date)$yday < 40)*Rainfall,
    rainfall_after_wet = as.numeric(as.POSIXlt(date)$yday < 90)*Rainfall
  ) %>%
  group_by(Year) %>%
  summarise(
    annual = sum(Rainfall),
    Feb12th = sum(rainfall_to_date),
    april = sum(rainfall_after_wet),
    remaining = sum(Rainfall) - sum(rainfall_to_date)
  )

# bar plot
ggplot(annual.rainfall, aes(x = Year, y = annual, fill = annual)) +
  geom_bar(stat = "identity") +
  scale_fill_gradientn(colors = colorRampPalette(mycols)(32)) +
  geom_hline(yintercept = mean(annual.rainfall$annual, na.rm = TRUE), lty = 3, col = "grey20", lwd = 1) +
  labs(y = "Total annual rainfall")
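One detail worth knowing when reading the code above: as.POSIXlt()$yday counts from zero, so 1 January is day 0. A cutoff of yday < 40 therefore keeps 1 January to 9 February, and yday < 90 keeps 1 January to 31 March in a non-leap year (30 March in a leap year). A quick check:

```r
d <- as.Date(c("2019-01-01", "2019-02-09", "2019-02-10", "2019-02-12",
               "2019-03-31", "2020-03-31"))
# yday is zero-based: 1 January is day 0
data.frame(date = d, yday = as.POSIXlt(d)$yday)
```

If you want a cutoff tied to a specific calendar date, comparing dates directly (e.g. format(date, "%m-%d") <= "02-12") avoids the off-by-one and leap-year shifts.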
On close inspection the data suggests that the first 40 years of records are less variable than those from 1980 onwards; there appear to be both drier and wetter years in the latter half.
The current record of 2400mm was set in 2000, a year in which Townsville had a few heavy rainfall events in the months up to April and some lesser events near the end of the year. Comparing 2019 to these years, there is definitely potential for 2019 to be the wettest on record.
ggplot(townsville %>% dplyr::filter(Year %in% c(1950, 1956, 2000, 2019)),
       aes(x = as.POSIXlt(date)$yday, y = Rainfall)) +
  geom_line() +
  facet_grid(Year ~ .) +
  labs(x = "Day of the year", title = "Wettest 3 years on record vs 2019")
The plots below clearly show which years have had significant rainfall in the first part of the year. The years which received more than 700mm (dotted line) are quite distinct from the bulk of the data. Since the wet season ends in April, the other heavy years (like 2000) hadn’t had their major events by this point. This is shown in the April plot at the bottom, which has a much stronger relationship (obviously). In general, the years which experienced heavy rainfall at this time of year didn’t get much more afterwards.
grid.arrange(
  ggplot(annual.rainfall, aes(x = Feb12th, y = annual, col = annual)) +
    geom_point(size = c(rep(2, nrow(annual.rainfall)-1), 4)) +
    scale_color_gradientn(colors = colorRampPalette(mycols)(32)) +
    labs(y = "Annual rainfall", x = "Total rainfall as at 12th Feb") +
    geom_vline(xintercept = 700, lty = 3, col = "grey20", lwd = 1),
  ggplot(annual.rainfall, aes(x = april, y = annual, col = annual)) +
    geom_point(size = c(rep(2, nrow(annual.rainfall)-1), 4)) +
    scale_color_gradientn(colors = colorRampPalette(mycols)(32)) +
    labs(y = "Annual rainfall", x = "Total rainfall as at 1st April") +
    geom_vline(xintercept = 700, lty = 3, col = "grey20", lwd = 1)
)
For what it’s worth, I’ll offer a prediction for the expected annual rainfall and the probability of 2019 being the wettest year on record (which, to be honest, is a fool’s errand – tropical weather systems are pretty complex stuff).
# bayesian model
blm <- MCMCregress(annual ~ Feb12th,
                   data = annual.rainfall %>% dplyr::filter(Feb12th > 0, Year != 2019),
                   sigma.mu = 235, sigma.var = 35^2) # priors from exploration - details skipped here

# prediction
x <- matrix(c(1, 1444))
pred.annual.rainfall <- data.frame(
  annual = blm[,-3] %*% x + rnorm(10000, 0, sqrt(blm[,"sigma2"])), # posterior predictive distribution
  exp.val = blm[,-3] %*% x) # mean distribution
# c(min(pred.annual.rainfall$annual), table(pred.annual.rainfall$annual < 1444)/10000)

# first n posterior draws of the regression line, with and without residual noise
n <- 1000
xstart <- rep(0, n)
xend <- rep(1700, n)
ystart <- blm[1:n,1] + blm[1:n,2]*xstart
yend <- blm[1:n,1] + blm[1:n,2]*xend
ystartp <- blm[1:n,1] + blm[1:n,2]*xstart + rnorm(n, 0, sqrt(blm[1:n,3]))
yendp <- blm[1:n,1] + blm[1:n,2]*xend + rnorm(n, 0, sqrt(blm[1:n,3]))
df.seg <- data.frame(xstart, xend, ystart, yend, ystartp, yendp)
exp.val <- quantile(pred.annual.rainfall$exp.val, c(0.025, 0.5, 0.975))
post.pred <- quantile(pred.annual.rainfall$annual, c(0.025, 0.5, 0.975))
prob.record.rainfall <- sum(as.numeric(pred.annual.rainfall$annual > 2400))/10000

# drawing each line of the posterior draws takes time but it's a nice aesthetic
df.pred <- data.frame(x = c(1444, 1444), y = range(exp.val))
a <- 6
df.post <- data.frame(xs = rep(1444, 3) + c(0, -a, -a),
                      xe = rep(1444, 3) + c(0, a, a),
                      ys = c(post.pred[1], post.pred[1], post.pred[3]),
                      ye = c(post.pred[3], post.pred[1], post.pred[3]))
ggplot(annual.rainfall %>% dplyr::filter(Year != 2019, Feb12th > 0), aes(x = Feb12th, y = annual)) +
  geom_segment(mapping = aes(x = xstart, xend = xend, y = ystartp, yend = yendp), data = df.seg, col = "grey20", alpha = I(0.05)) +
  geom_segment(mapping = aes(x = xstart, xend = xend, y = ystart, yend = yend), data = df.seg, col = "darkmagenta", alpha = I(0.025)) +
  geom_point() +
  labs(y = "Annual rainfall", x = "Total rainfall as at 12th Feb") +
  geom_line(mapping = aes(x = x, y = y), data = df.pred, size = 2) +
  geom_segment(mapping = aes(x = xs, xend = xe, y = ys, yend = ye), data = df.post) +
  geom_hline(yintercept = 2400, col = "red", lty = 2)
Based on this (crude) model, the 95% credible interval for the expected annual rainfall is (2027, 2472) mm. Using the posterior predictive distribution for 2019, the probability that 2019 will experience record rainfall is 0.29.
But let’s all hope there’s not too much more rain. There has been enough.
The post Synthesising Multiple Linked Data Sets and Sequences in R appeared first on Daniel Oehm | Gradient Descending.
In my last post I looked at generating synthetic data sets with the ‘synthpop’ package, some of the challenges and neat things the package can do. It is simple to use which is great when you have a single data set with independent features.
This post will build on the last post by tackling other complications when attempting to synthesise data. These challenges occur regularly in practice and this post will offer a couple of solutions, but there are plenty more. If you haven’t read my previous post, check that out first and swing back here. I’ll detail how more complex synthesis can be done using synthpop.
When synthesising data, it is generally the finest grain which is synthesised. For population data the person is the finest grain, making synthesis relatively straightforward. But what happens when there are multiple tables to synthesise which are all linked by primary keys and have different grains finer than the person level? Let’s look at an example.
Consider a health database which has two tables: 1) a patient level table with demographic features and 2) an appointment level fact table consisting of the date of each appointment, including the diagnosis, treatment, prescriptions, etc., all the things you would expect on a health record. In the second case the grain is patient x appointment.
Preserving this structure is important, otherwise there could be flow-on complications, plus there could be a loss of information. The obvious solution is to join the patient table onto the appointment table, consolidating all the information, and then synthesise. Right? Not always, for two key reasons:
Some variables such as appointment dates should be treated as a sequence, where the current appointment is dependent on the previous appointment date. For example, if a patient has an initial consultation and is diagnosed with some ailment (e.g. a fracture or fatigue), it’s likely there are follow-up appointments with the doctor for treatment. Synthesis should reflect this sequence.
Other examples where sequences should be preserved (not necessarily in the health domain) may include:
or any other time series for that matter. The finest grain may not always be a sequence, but it is likely to be. The key point is to know your data and what structure you need to preserve.
With these constraints, the synthesised tables need to retain the properties of properly synthesised data, meaning no variable is synthesised using the original data.
Synthesising data at multiple grains can be done in a number of different ways. I’ll outline one particular way:
This requires some fancy footwork with synthpop but it can be done. There are many other ways to do this however where possible I opt for the path of least resistance.
Appointment level data is first generated to use as original data. Key points,
Appointment dates will be simulated to replicate the attendance pattern outlined above. This pattern is likely to be correlated with other variables such as age and sex, e.g. older people visit the doctor more. For this post, 2 patterns will be simulated which are loosely correlated with age. Hidden Markov Models (HMMs) are used for the simulation. More complicated models such as state space models, ARIMA models or RNNs could also be used. (As a side note, HMMs are essentially a discrete form of state space models.)
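To make the idea concrete, here is a minimal base-R sketch of a two-state HMM generating appointment days. It is a simplified stand-in for the mysimHMM() function further down, not a replacement: the state names match the post, but simulate_visit_days() is a hypothetical helper and the probabilities are only illustrative (they are the ones used for the first HMM in the post’s code).

```r
# Simulate daily states ("episode"/"no episode"), emitting "Y" (appointment) or "N"
# until n_visits appointments have occurred; return the day offsets of the visits
simulate_visit_days <- function(n_visits, trans, emis, start = c(0.5, 0.5)) {
  states <- rownames(trans)
  state <- sample(states, 1, prob = start)
  days <- integer(0)
  day <- 0L
  while (length(days) < n_visits) {
    # emit from the current state, then move to the next state
    if (sample(c("Y", "N"), 1, prob = emis[state, ]) == "Y") days <- c(days, day)
    state <- sample(states, 1, prob = trans[state, ])
    day <- day + 1L
  }
  days
}

trans <- rbind(episode = c(0.92, 0.08), "no episode" = c(0.02, 0.98))
colnames(trans) <- rownames(trans)
emis <- rbind(episode = c(0.25, 0.75), "no episode" = c(0.001, 0.999))
colnames(emis) <- c("Y", "N")

set.seed(20190201)
visit_days <- simulate_visit_days(5, trans, emis)
as.Date("2015-06-01") + visit_days   # appointments cluster within "episodes"
```

The sticky "episode" state (0.92 chance of staying) is what produces bursts of follow-up appointments, while the rare transition out of "no episode" spaces those bursts apart.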
The cost of the appointment will be simulated from a multivariate normal distribution with mean μ = (100, 500, 1000) and covariance Σ = diag(75², 100², 125²) (pulling some numbers out of the air), truncated to 0 where appropriate. The costs are not correlated with any other variable (in the real world they would be, but let’s keep it simple for now), so the assumption of independence is reasonable. This means we don’t have to do anything fancy with synthpop and can simply use the built-in functions.
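The costs in the post are drawn with rmvnorm() using mean c(100, 500, 1000) and the diagonal covariance diag(c(75, 100, 125)^2). Because the covariance is diagonal, that is the same as three independent normal draws, so the cost simulation can be sketched with base R’s rnorm() alone. simulate_costs() is a hypothetical helper; like the post’s code it keeps n of the 3n values at random (giving a mixture of the three cost levels) and truncates negatives to 0.

```r
simulate_costs <- function(n, mu = c(100, 500, 1000), sd = c(75, 100, 125)) {
  # one column per component; a diagonal covariance means the columns are independent
  draws <- matrix(rnorm(n * length(mu), mean = rep(mu, each = n), sd = rep(sd, each = n)),
                  nrow = n)
  # keep n of the n * 3 values at random, then truncate at 0
  pmax(sample(draws, n), 0)
}

set.seed(1)
costs <- simulate_costs(1000)
summary(costs)
```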
Dates are particularly annoying to deal with at the best of times and simulating them is no exception. Here is one way.
suppressPackageStartupMessages(library(synthpop))
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(HMM))
suppressPackageStartupMessages(library(cluster))
suppressPackageStartupMessages(library(doParallel))
suppressPackageStartupMessages(library(foreach))
suppressPackageStartupMessages(library(mvtnorm))
suppressPackageStartupMessages(library(magrittr))
suppressPackageStartupMessages(library(gridExtra))
mycols <- c("darkmagenta", "turquoise")
myseed <- 20190201

# set data
patient.df <- SD2011 %>%
  dplyr::select(sex, age, socprof, income, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi)

# set unique id
patient.df$patient_id <- 1:nrow(patient.df)
head(patient.df)
##      sex age          socprof income marital depress sport nofriend smoke
## 1 FEMALE  57          RETIRED    800 MARRIED       6    NO        6    NO
## 2   MALE  20 PUPIL OR STUDENT    350  SINGLE       0    NO        4    NO
## 3 FEMALE  18 PUPIL OR STUDENT     NA  SINGLE       0    NO       20    NO
## 4 FEMALE  78          RETIRED    900 WIDOWED      16   YES        0    NO
## 5 FEMALE  54    SELF-EMPLOYED   1500 MARRIED       4   YES        6   YES
## 6   MALE  20 PUPIL OR STUDENT     -8  SINGLE       5    NO       10    NO
##   nociga alcabuse      bmi patient_id
## 1     -8       NO 30.79585          1
## 2     -8       NO 23.44934          2
## 3     -8       NO 18.36547          3
## 4     -8       NO 30.46875          4
## 5     20       NO 20.02884          5
## 6     -8       NO 23.87511          6
# patient group
low <- sample(1:2, nrow(patient.df), replace = TRUE, prob = c(0.25, 0.75))
high <- sample(1:2, nrow(patient.df), replace = TRUE, prob = c(0.75, 0.25))
patient.df$group <- ifelse(patient.df$age < median(patient.df$age), low, high)

# set my HMM simulation function
mysimHMM <- function (hmm, num.events, which.event = NULL){
  hmm$transProbs[is.na(hmm$transProbs)] = 0
  hmm$emissionProbs[is.na(hmm$emissionProbs)] = 0
  states = c()
  emission = c()
  states = c(states, sample(hmm$States, 1, prob = hmm$startProbs))
  i <- 2
  if(is.null(which.event)) which.event <- hmm$Symbols[1]
  # keep simulating days until num.events occurrences of which.event have been emitted
  while(sum(emission == which.event) < num.events) {
    state = sample(hmm$States, 1, prob = hmm$transProbs[states[i-1], ])
    states = c(states, state)
    emi = sample(hmm$Symbols, 1, prob = hmm$emissionProbs[states[i-1], ])
    emission = c(emission, emi)
    i <- i + 1
  }
  return(list(states = states, observation = emission))
}

# initialise HMM (rows = from-state, columns = to-state / symbol)
hmm.list <- list(
  initHMM(States = c("episode", "no episode"),
          Symbols = c("Y", "N"),
          transProbs = rbind(c(0.92, 0.08), c(0.02, 0.98)),
          emissionProbs = rbind(c(0.25, 0.75), c(0.001, 0.999))),
  initHMM(States = c("episode", "no episode"),
          Symbols = c("Y", "N"),
          transProbs = rbind(c(0.92, 0.08), c(0.1, 0.9)),
          emissionProbs = rbind(c(0.25, 0.75), c(0.001, 0.999))))

# appointments start dates will be contained within a 3 year window
start.dates <- seq(as.Date("2014-01-01"), as.Date("2016-12-31"), 1)

# simulate the appointment history for patient k
simulate.appointment.dates <- function(k){
  n <- rpois(1, rgamma(1, 1)*12) + 1
  appointment.dates <- mysimHMM(hmm.list[[patient.df$group[k]]], n)$observation
  l <- length(appointment.dates)
  dts <- sample(start.dates, 1) + 0:(l-1)
  appointment_dates = dts[which(appointment.dates == "Y")]
  out <- data.frame(patient_id = rep(k, length.out = n),
                    appointment_dates,
                    cost = rmvnorm(n, mean = c(100, 500, 1000), sigma = diag(c(75, 100, 125)^2)) %>%
                      sample(n) %>%
                      pmax(0))
  return(out)
}

# load mvtnorm and magrittr on each worker so rmvnorm() and %>% are available
ncores <- min(detectCores(), 4)
registerDoParallel(ncores)
appointment.df <- foreach(k = 1:nrow(patient.df), .combine = rbind,
                          .packages = c("mvtnorm", "magrittr")) %dopar%
  simulate.appointment.dates(k)
appointment.df %<>% dplyr::filter(appointment_dates < as.Date("2019-01-31"))
Plotting the sequence of appointment dates, we can see the structure we want to preserve.
# look at the clustering of appointments
df <- appointment.df %>% dplyr::filter(patient_id < 7)
g.date.sequences <- ggplot(df, aes(x = appointment_dates, y = patient_id)) +
  geom_hline(yintercept = 1:6, col = "grey80") +
  geom_point(col = "turquoise", size = 4) +
  coord_cartesian(xlim = range(df$appointment_dates))

# costs
g.costs <- ggplot(appointment.df, aes(x = cost)) +
  geom_histogram(fill = "turquoise")

# plot them
grid.arrange(g.date.sequences, g.costs)
To create a link between the patient level data and the appointment level data the patient table is clustered using a simple hierarchical clustering procedure (it can be whatever method you think is suitable but I’m choosing hierarchical for this example).
Here 20 clusters are defined. In practice, more thought may need to go into how many clusters are defined given the problem. The clusters are joined to the appointment table to create the link.
# --------------- CLUSTER PATIENTS TO EXPAND TO APPOINTMENT LEVEL
dsim <- patient.df %>% dplyr::select(-patient_id) %>% dist()
clust <- hclust(dsim, method = "ward.D2")
nc <- 20
cut.clust <- cutree(clust, nc)
table(cut.clust)
## cut.clust
##   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18 
## 385 179 800 626 569 532 464 277 349  92  60  35 165 199  74   4 168  11 
##  19  20 
##   8   3 
# put cluster on patient table
patient.df$cluster <- cut.clust
appointment.df %<>% left_join(patient.df %>% dplyr::select(patient_id, cluster, group), by = "patient_id")
The patient table is now synthesised, including the cluster variable, and a new patient id is created.
# --------------- SYNTHESISE PATIENT DATA FRAME
p.df <- patient.df %>% dplyr::select(-patient_id, -group)

# apply rules to ensure consistency
rules.list <- list(marital = "age < 18", nociga = "smoke == 'NO'")
rules.value.list <- list(marital = "SINGLE", nociga = -8)

# setting continuous variable NA list
cont.na.list <- list(income = c(NA, -8), nofriend = c(NA, -8), nociga = c(NA, -8))

# synth
patient.synth <- syn(p.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list)$syn
## 
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
##  sex age socprof income marital depress sport nofriend smoke nociga
##  alcabuse bmi cluster
# --------------- NEW PATIENT ID
patient.synth$patient_id <- 1:nrow(patient.synth)
head(patient.synth)
##      sex age                     socprof income marital depress sport
## 1 FEMALE  31   EMPLOYED IN PUBLIC SECTOR   1240 MARRIED       4   YES
## 2 FEMALE  49 OTHER ECONOMICALLY INACTIVE    600 MARRIED       7   YES
## 3 FEMALE  20 OTHER ECONOMICALLY INACTIVE     -8  SINGLE       0   YES
## 4 FEMALE  26                        <NA>    400  SINGLE       4    NO
## 5   MALE  27 OTHER ECONOMICALLY INACTIVE    900  SINGLE       0   YES
## 6 FEMALE  53  EMPLOYED IN PRIVATE SECTOR     -8  SINGLE       5   YES
##   nofriend smoke nociga alcabuse      bmi cluster patient_id
## 1        3    NO     -8       NO 20.06920       7          1
## 2       20   YES     10       NO 25.14861       8          2
## 3       10    NO     -8       NO 18.02596       3          3
## 4        5    NO     -8       NO 22.49135       2          4
## 5        7   YES     20       NO 23.56663       4          5
## 6        2   YES     10       NO 26.53376       3          6
The next step is to randomly select a true patient id from within the same cluster. By doing so, each synthesised patient gets a real patient's appointment history. Since the history is randomly selected from within the same cluster, the structure of the data should be retained, e.g. a 50-year-old male who doesn't smoke gets a history which is realistic and supported by the data.
# --------------- SAMPLE PATIENT ID WITHIN CLUSTER IN ORDER TO LINK TO APPOINTMENT DATASET
sample.patient <- function(x){
  z <- which(patient.df$cluster == x)
  n <- length(which(patient.synth$cluster == x))
  # index with sample.int() so a cluster containing a single patient is handled correctly
  out <- patient.df$patient_id[z[sample.int(length(z), n, replace = TRUE)]]
  return(out)
}
patient.list <- lapply(as.list(1:nc), sample.patient)
patient.synth$link.patient <- NA
for(k in 1:nc){
  z <- which(patient.synth$cluster == k)
  patient.synth$link.patient[z] <- patient.list[[k]]
}

# -------------- LINK TO APPOINTMENT
patient.synth %<>%
  left_join(appointment.df %>% dplyr::select(-cluster), by = c("link.patient" = "patient_id")) %>%
  arrange(patient_id, appointment_dates)
head(patient.synth %>% dplyr::select(patient_id, sex, age, socprof, appointment_dates, cost, cluster)) # only displaying a few variables for convenience
##   patient_id    sex age                     socprof appointment_dates
## 1          1 FEMALE  31   EMPLOYED IN PUBLIC SECTOR        2015-06-18
## 2          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-04-26
## 3          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-05-30
## 4          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-06-02
## 5          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-07-02
## 6          2 FEMALE  49 OTHER ECONOMICALLY INACTIVE        2016-07-03
##         cost cluster
## 1 1063.82122       7
## 2  380.20829       8
## 3  492.43475       8
## 4 1085.44437       8
## 5 1038.01058       8
## 6   68.89496       8
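As a self-contained illustration of the cluster-then-link idea, here is a minimal base R sketch on made-up data. The toy `orig` and `synth` data frames and the choice of 5 clusters are hypothetical, not the patient data above. Each synthetic record draws a donor id from the original records that share its cluster. Indexing with `sample.int()` sidesteps a well-known R gotcha: `sample(x, n)` treats a single number `x` as `1:x`, which would silently break a cluster containing only one patient.

```r
# Hypothetical toy data standing in for patient.df (numeric columns only)
set.seed(1)
orig <- data.frame(patient_id = 1:200,
                   age = round(runif(200, 18, 80)),
                   bmi = rnorm(200, 25, 4))

# Cluster the original patients (Ward linkage on scaled variables)
hc <- hclust(dist(scale(orig[, c("age", "bmi")])), method = "ward.D2")
orig$cluster <- cutree(hc, k = 5)

# Hypothetical synthetic patients that already carry a synthesised cluster label
synth <- data.frame(patient_id = 1:150,
                    cluster = sample(orig$cluster, 150, replace = TRUE))

# For each synthetic patient, draw one donor id from the original patients in the same cluster
donor_ids <- split(orig$patient_id, orig$cluster)
synth$link.patient <- vapply(synth$cluster, function(cl) {
  ids <- donor_ids[[as.character(cl)]]
  ids[sample.int(length(ids), 1)]  # safe even when a cluster has one member
}, integer(1))

head(synth)
```

The donor's appointment history can then be joined on `link.patient`, exactly as in the code above.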
This data is suitable to be simulated using HMMs (we know this for sure since we created the data set!). A new function is written to be consumed by syn() so the appointment dates are synthesised in the same process as the other variables.
This is where things can get tricky and it is important to have a good handle on your data. To synthesise the sequence data, a model first needs to be trained on the original data. For the built-in methods, syn() takes care of the training; for a custom method, the function can either be written to train the model before synthesis or to consume pre-trained models. Which approach to take will depend on the problem at hand.
I am going to skip the training step here. Training HMMs (or any state space model) can be tricky and requires careful tuning. Instead, I am simply going to use the predefined HMMs above that were used to create the original data. The primary aim is to demonstrate how to synthesise the data rather than how to train HMMs.
syn.hmm() function
The function to pass to syn() is given below. The list of HMMs is passed directly to the function, which returns a vector of dates.
syn.hmm <- function(y, x, xp, id = patient.synth$link.patient, hmm_list = hmm.list, group = patient.synth$group, ...){
  start.dates <- seq(as.Date("2014-01-01"), as.Date("2016-12-31"), 1)
  uniq.id <- unique(id)
  n <- length(uniq.id)
  new.y <- vector(mode = "numeric", length = nrow(xp))
  for(k in 1:n){
    # rows belonging to the k-th linked patient
    ii <- which(id == uniq.id[k])
    # group is row-aligned with id, so look it up from this patient's own rows
    gr <- group[ii[1]]
    # simulate until length(ii) appointment ("Y") days have been emitted
    obs <- mysimHMM(hmm = hmm_list[[gr]], length(ii))$observation
    l <- length(obs)
    dts <- sample(start.dates, 1) + 0:(l-1)
    new.y[ii] <- dts[which(obs == "Y")]
  }
  return(list(res = as.Date(new.y, origin = "1970-01-01"), fit = "hmm"))
}
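To make the mechanics of syn.hmm() a little more concrete, here is a minimal, self-contained sketch of the emission-to-dates step with no HMM package dependency. It assumes uniform start probabilities and reuses the first set of transition and emission probabilities from above; `sim_appointment_dates()` is a hypothetical helper written in the same spirit as mysimHMM(). Each emission is treated as one calendar day, and only the "Y" days become appointments.

```r
# Hypothetical two-state HMM using the first set of probabilities above
states  <- c("episode", "no episode")
symbols <- c("Y", "N")
trans <- matrix(c(0.92, 0.08,
                  0.02, 0.98), nrow = 2, byrow = TRUE, dimnames = list(states, states))
emit  <- matrix(c(0.25,  0.75,
                  0.001, 0.999), nrow = 2, byrow = TRUE, dimnames = list(states, symbols))

# Simulate one emission per day until n "Y" days (appointments) have been observed
sim_appointment_dates <- function(n, start, trans, emit, start_probs = c(0.5, 0.5)) {
  state <- sample(rownames(trans), 1, prob = start_probs)  # assumed uniform start
  obs <- character(0)
  while (sum(obs == "Y") < n) {
    obs <- c(obs, sample(colnames(emit), 1, prob = emit[state, ]))  # today's emission
    state <- sample(colnames(trans), 1, prob = trans[state, ])      # tomorrow's state
  }
  # day i of the sequence is start + (i - 1); keep only appointment days
  start + which(obs == "Y") - 1
}

set.seed(42)
dts <- sim_appointment_dates(5, as.Date("2015-01-01"), trans, emit)
dts
```

Because the "episode" state is sticky (0.92) and emits "Y" far more often, appointments tend to arrive in clusters, which is the structure visible in the plots.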
To synthesise only the appointment dates and cost but ignore the remaining variables, we define the predictor matrix and methods vector.
Firstly, syn() will switch the method to "sample" if the variable does not have a predictor. To trick the function into doing what we want, we give the appointment dates variable at least one predictor, which it won't actually use because we haven't told it to in syn.hmm().
Secondly, cost can be synthesised as normal; for demonstration, it is synthesised using age, sex and income.
Thirdly, to tell the function to ignore all other variables, we set the visit sequence to be appointment dates followed by cost and nothing else. syn() will leave the other, already synthesised, variables unchanged and attach them to the output data frame.
# set data frame
df <- patient.synth %>% dplyr::select(-cluster, -group)
synth.obj <- syn(df, method = rep("sample", ncol(df)), m = 0)

# create new predictor matrix
new.pm <- synth.obj$predictor.matrix
new.pm["cost", c("sex", "age", "income")] <- 1
new.pm["appointment_dates", c("link.patient")] <- 1

# set visit sequences
vseq <- c(which(colnames(new.pm) == "appointment_dates"), which(colnames(new.pm) == "cost"))

# set methods array
mth <- rep("", ncol(new.pm))
mth[vseq] <- c("hmm", "cart")

# synthesise data
synth.obj <- syn(df, method = mth, visit.sequence = vseq, predictor.matrix = new.pm)
## 
## Variable(s): socprof, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi, patient_id not synthesised or used in prediction.
## CAUTION: The synthesised data will contain the variable(s) unchanged.
## 
## Synthesis
## -----------
##  appointment_dates cost
head(synth.obj$syn %>% dplyr::select(sex, age, income, appointment_dates, cost)) # only displaying a few variables for convenience
##      sex age income appointment_dates       cost
## 1 FEMALE  31   1240        2016-04-30   74.43372
## 2 FEMALE  49    600        2016-06-25   68.89496
## 3 FEMALE  49    600        2016-07-27  475.31540
## 4 FEMALE  49    600        2016-07-28  126.08026
## 5 FEMALE  49    600        2016-11-03 1009.78769
## 6 FEMALE  49    600        2016-11-04 1038.01058
synth.obj$method
##               sex               age           socprof            income 
##                ""                ""                ""                "" 
##           marital           depress             sport          nofriend 
##                ""                ""                ""                "" 
##             smoke            nociga          alcabuse               bmi 
##                ""                ""                ""                "" 
##        patient_id      link.patient appointment_dates              cost 
##                ""                ""             "hmm"            "cart" 
synth.obj$visit.sequence
## appointment_dates              cost 
##                15                16 
# compare cost
# look at the clustering of appointments
g.date.sequences <- ggplot(synth.obj$syn %>% dplyr::filter(patient_id < 7), aes(x = appointment_dates, y = patient_id)) +
  geom_hline(yintercept = 1:6, col = "grey80") +
  geom_point(col = "turquoise", size = 4) +
  coord_cartesian(xlim = range(df$appointment_dates))
grid.arrange(g.date.sequences, compare(synth.obj, df, vars = "cost", cols = mycols)$plot)
Again, a solid effort. Appointment dates display a similar pattern and the cost variable matches the original distribution.
The plots below show the comparison between the synthetic and original data for the mean and total cost, appointment count and mean appointment count per patient, all broken down by sex and marital status.
The outputs are very similar, with the exception of the mean number of appointments for those in de facto relationships. This is likely due to the small sample size in this group.
When synthesising data from any relational database these challenges will be present. There will be a business need to preserve the same relational form to ensure the synthetic data is used in the same way as the original.
There are many ways of dealing with the problem of different grains and synthesising sequences – this is only one. With respect to sequences and time series data, careful thought is needed to use the right model for synthesis and define the appropriate function.
This example is deliberately straightforward in order to highlight a couple of key challenges. The synthpop package is flexible enough to handle more complex real-world scenarios.
If you have solved this in other ways, let me know in the comments!
The post Synthesising Multiple Linked Data Sets and Sequences in R appeared first on Daniel Oehm | Gradient Descending.
The post Generating Synthetic Data Sets with ‘synthpop’ in R appeared first on Daniel Oehm | Gradient Descending.
Synthpop – A great music genre and an aptly named R package for synthesising population data. I recently came across this package while looking for an easy way to synthesise unit record data sets for public release. The goal is to generate a data set which contains no real units, and is therefore safe for public release, while retaining the structure of the data, so that any inference returns the same conclusion as the original. This will be a quick look into synthesising data, some challenges that can arise from common data structures and some things to watch out for.
This example will use the same data set as in the synthpop documentation and will cover similar ground, but perhaps an abridged version with a few other things that weren't mentioned. The SD2011 data set contains 5000 observations and 35 variables on social characteristics of Poland. A subset of 12 of these variables is considered.
suppressPackageStartupMessages(library(synthpop))
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(sampling))
suppressPackageStartupMessages(library(partykit))
suppressPackageStartupMessages(library(nnet)) # needed by the custom syn.nn() method later in the post
mycols <- c("darkmagenta", "turquoise")
options(xtable.floating = FALSE)
options(xtable.timestamp = "")
myseed <- 20190110
# filtering the dataset
original.df <- SD2011 %>% dplyr::select(sex, age, socprof, income, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi)
head(original.df)
##      sex age          socprof income marital depress sport nofriend smoke
## 1 FEMALE  57          RETIRED    800 MARRIED       6    NO        6    NO
## 2   MALE  20 PUPIL OR STUDENT    350  SINGLE       0    NO        4    NO
## 3 FEMALE  18 PUPIL OR STUDENT     NA  SINGLE       0    NO       20    NO
## 4 FEMALE  78          RETIRED    900 WIDOWED      16   YES        0    NO
## 5 FEMALE  54    SELF-EMPLOYED   1500 MARRIED       4   YES        6   YES
## 6   MALE  20 PUPIL OR STUDENT     -8  SINGLE       5    NO       10    NO
##   nociga alcabuse      bmi
## 1     -8       NO 30.79585
## 2     -8       NO 23.44934
## 3     -8       NO 18.36547
## 4     -8       NO 30.46875
## 5     20       NO 20.02884
## 6     -8       NO 23.87511
The objective of synthesising data is to generate a data set which resembles the original as closely as possible, warts and all, meaning also preserving the missing value structure. There are two ways to deal with missing values: 1) impute/treat missing values before synthesis, or 2) synthesise the missing values and deal with them later. The second option is generally better, since the purpose the data is supporting may influence how the missing values are treated.
Missing values can simply be NA or some numeric code specified by the collection. A useful feature of the syn function is that it allows for different NA types; for example, income, nofriend and nociga use -8 as a missing value code. A list is passed to the function in the following form.
# setting continuous variable NA list
cont.na.list <- list(income = c(NA, -8), nofriend = c(NA, -8), nociga = c(NA, -8))
If this is not included, the -8s will be treated as numeric values and may distort the synthesis.
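A small example with made-up income values shows why this matters: left in as a number, the -8 code drags summary statistics down.

```r
# Hypothetical income values where -8 is a "missing" code, as in SD2011
income <- c(1200, 800, -8, 1500, -8, 950)

mean(income)                                   # 739: -8 treated as a real income
income_na <- replace(income, income == -8, NA)
mean(income_na, na.rm = TRUE)                  # 1112.5: -8 treated as missing
```

A synthesis model fitted to the first version would learn a spurious cluster of tiny incomes, which is exactly what `cont.na` avoids.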
After synthesis, there is often a need to post-process the data to ensure it is logically consistent. For example, anyone who is married must be over 18 and anyone who doesn't smoke shouldn't have a value recorded for 'number of cigarettes consumed'. These rules can be applied during synthesis rather than needing ad hoc post-processing.
# apply rules to ensure consistency
rules.list <- list(
  marital = "age < 18",
  nociga = "smoke == 'NO'")
rules.value.list <- list(
  marital = "SINGLE",
  nociga = -8)
The variables in the condition need to be synthesised before applying the rule otherwise the function will throw an error. In this case age should be synthesised before marital and smoke should be synthesised before nociga.
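As a rough illustration of what the two rules enforce, here is the equivalent post hoc fix applied to a few made-up rows. This is only a sketch: syn() applies the rules during synthesis, so later variables are synthesised from rule-consistent values, which a post hoc fix cannot reproduce.

```r
# Hypothetical synthetic rows that break the two rules
syn_df <- data.frame(age     = c(16, 35, 17),
                     marital = c("MARRIED", "MARRIED", "SINGLE"),
                     smoke   = c("NO", "YES", "NO"),
                     nociga  = c(10, 20, -8),
                     stringsAsFactors = FALSE)

# Post hoc equivalent of rules.list / rules.value.list:
# wherever a rule's condition holds, overwrite the variable with the rule's value
syn_df$marital[syn_df$age < 18] <- "SINGLE"
syn_df$nociga[syn_df$smoke == "NO"] <- -8

syn_df
```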
There is one person with a bmi of 450.
SD2011[which.max(SD2011$bmi),]
##         sex age agegr            placesize    region       edu
## 3953 FEMALE  58 45-59 URBAN 20,000-100,000 Lubelskie SECONDARY
##                         eduspec                 socprof unempdur income
## 3953 economy and administration LONG-TERM SICK/DISABLED       -8   1300
##      marital mmarr ymarr msepdiv ysepdiv               ls depress
## 3953 MARRIED     4  1982      NA      NA MOSTLY SATISFIED       1
##                         trust trustfam trustneigh sport nofriend smoke
## 3953 ONE CAN`T BE TOO CAREFUL      YES        YES   YES        2    NO
##      nociga alcabuse alcsol workab wkabdur wkabint wkabintdur emcc englang
## 3953     -8       NO     NO     NO      -8      NO       <NA> <NA>    NONE
##      height weight      bmi
## 3953    149     NA 449.9797
Their weight is missing from the data set, yet it would be needed for this BMI to be accurate. I don't believe this value is correct! So any BMI over 75 (which is still very high) will be treated as a missing value and corrected before synthesis.
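A quick arithmetic check supports this. Since BMI = weight / height², the weight implied by this record is BMI × height². The result is almost exactly 999 kg, which suggests the BMI may have been computed from a placeholder weight of 999 rather than a real measurement. That is an inference from the arithmetic, not something confirmed here.

```r
# BMI = weight (kg) / height (m)^2, so the implied weight is BMI * height^2
bmi <- 449.9797
height_cm <- 149
implied_weight <- bmi * (height_cm / 100)^2
implied_weight   # approximately 999
```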
# getting around the error: synthesis needs to occur before the rules are applied
original.df$bmi <- ifelse(original.df$bmi > 75, NA, original.df$bmi)
Consider a data set with $p$ variables. In a nutshell, synthesis follows these steps:
The data can now be synthesised using the following code.
# synthesise data
synth.obj <- syn(original.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, seed = myseed)
## 
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
##  sex age socprof income marital depress sport nofriend smoke nociga
##  alcabuse bmi
synth.obj
## Call:
## ($call) syn(data = original.df, rules = rules.list, rvalues = rules.value.list, 
##     cont.na = cont.na.list, seed = myseed)
## 
## Number of synthesised data sets: 
## ($m)  1 
## 
## First rows of synthesised data set: 
## ($syn)
##      sex age                     socprof income marital depress sport
## 1 FEMALE  45   EMPLOYED IN PUBLIC SECTOR   1500  SINGLE       5   YES
## 2   MALE  65 OTHER ECONOMICALLY INACTIVE    500  SINGLE       5   YES
## 3 FEMALE  17            PUPIL OR STUDENT     NA  SINGLE       3    NO
## 4 FEMALE  48  EMPLOYED IN PRIVATE SECTOR   1000 MARRIED       0    NO
## 5   MALE  50  EMPLOYED IN PRIVATE SECTOR   1300 MARRIED       0    NO
## 6 FEMALE  65                     RETIRED   1200 WIDOWED       6    NO
##   nofriend smoke nociga alcabuse      bmi
## 1        3    NO     -8       NO 26.12245
## 2       30    NO     -8       NO 29.32099
## 3        7    NO     -8       NO 22.40588
## 4       10    NO     -8       NO 25.81663
## 5       12   YES     20      YES 27.17063
## 6       15    NO     -8       NO 27.51338
## ...
## 
## Synthesising methods: 
## ($method)
##      sex      age  socprof   income  marital  depress    sport nofriend 
## "sample"   "cart"   "cart"   "cart"   "cart"   "cart"   "cart"   "cart" 
##    smoke   nociga alcabuse      bmi 
##   "cart"   "cart"   "cart"   "cart" 
## 
## Order of synthesis: 
## ($visit.sequence)
##      sex      age  socprof   income  marital  depress    sport nofriend 
##        1        2        3        4        5        6        7        8 
##    smoke   nociga alcabuse      bmi 
##        9       10       11       12 
## 
## Matrix of predictors: 
## ($predictor.matrix)
##          sex age socprof income marital depress sport nofriend smoke
## sex        0   0       0      0       0       0     0        0     0
## age        1   0       0      0       0       0     0        0     0
## socprof    1   1       0      0       0       0     0        0     0
## income     1   1       1      0       0       0     0        0     0
## marital    1   1       1      1       0       0     0        0     0
## depress    1   1       1      1       1       0     0        0     0
## sport      1   1       1      1       1       1     0        0     0
## nofriend   1   1       1      1       1       1     1        0     0
## smoke      1   1       1      1       1       1     1        1     0
## nociga     1   1       1      1       1       1     1        1     1
## alcabuse   1   1       1      1       1       1     1        1     1
## bmi        1   1       1      1       1       1     1        1     1
##          nociga alcabuse bmi
## sex           0        0   0
## age           0        0   0
## socprof       0        0   0
## income        0        0   0
## marital       0        0   0
## depress       0        0   0
## sport         0        0   0
## nofriend      0        0   0
## smoke         0        0   0
## nociga        0        0   0
## alcabuse      1        0   0
## bmi           1        1   0
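The printout shows the two ingredients of sequential synthesis: the visit sequence and the predictor matrix. The first variable is sampled from its observed values; each later variable is modelled on the variables before it in the original data, and new values are drawn by feeding the already synthesised columns into that model. Below is a minimal base R sketch of that loop under simplifying assumptions: all variables are numeric, and a linear model plus normal noise stands in for CART (synthpop's choice here). `sequential_synth()` is a hypothetical helper, not part of synthpop.

```r
# Minimal sequential synthesis sketch (all-numeric data, linear model + noise)
sequential_synth <- function(data, n = nrow(data)) {
  vars <- names(data)  # visit sequence: column order
  # Step 1: sample the first variable from its observed values
  syn <- data.frame(sample(data[[vars[1]]], n, replace = TRUE))
  names(syn) <- vars[1]
  for (j in seq_along(vars)[-1]) {
    preds <- vars[seq_len(j - 1)]  # predictors: every earlier variable in the sequence
    fit <- lm(reformulate(preds, response = vars[j]), data = data)  # fit on ORIGINAL data
    mu <- predict(fit, newdata = syn[, preds, drop = FALSE])        # predict from SYNTHETIC data
    syn[[vars[j]]] <- mu + rnorm(n, 0, sigma(fit))                   # add noise, don't copy
  }
  syn
}

# Usage on made-up data
set.seed(1)
orig <- data.frame(age = runif(500, 18, 80))
orig$income <- 20 * orig$age + rnorm(500, 0, 100)
orig$bmi <- 20 + 0.05 * orig$age + rnorm(500, 0, 2)
synth <- sequential_synth(orig)
head(synth)
```

The noise term is what provides the uncertainty: without it, every synthetic value would sit exactly on the fitted model.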
The compare function allows for easy checking of the synthesised data.
# compare the synthetic and original data frames
compare(synth.obj, original.df, nrow = 3, ncol = 4, cols = mycols)$plot
Solid. The distributions are very well preserved. Did the rules work on the smoking variable?
# checking rules worked
table(synth.obj$syn[,c("smoke", "nociga")])
##      nociga
## smoke   -8    1    2    3    4    5    6    7    8   10   12   14   15
##   YES   13   13   13   32    5   49   17   12   28  308   26    3  135
##   NO  3777    0    0    0    0    0    0    0    0    0    0    0    0
##      nociga
## smoke   16   18   20   21   22   24   25   26   29   30   35   40   50
##   YES    5    7  418    2    1    2   31    2    2   51    1   33    1
##   NO     0    0    0    0    0    0    0    0    0    0    0    0    0
##      nociga
## smoke   60
##   YES    1
##   NO     0
They did. All non-smokers have missing values for the number of cigarettes consumed.
compare can also be used for model output checking. A logistic regression model will be fit to find the important predictors of depression. The depression variable ranges from 0 to 21. This will be converted to a binary variable: 1 if depress > 7 and 0 otherwise.
This split leaves 3822 zeros and 1089 ones for modelling.
# ------------ MODEL COMPARISON
glm1 <- glm.synds(ifelse(depress > 7, 1, 0) ~ sex + age + log(income) + sport + nofriend + smoke + alcabuse + bmi, data = synth.obj, family = "binomial")
## Warning in log(income): NaNs produced
summary(glm1)
## Warning: Note that all these results depend on the synthesis model being correct.
## 
## Fit to synthetic data set with a single synthesis.
## Inference to coefficients and standard errors that
## would be obtained from the observed data.
## 
## Call:
## glm.synds(formula = ifelse(depress > 7, 1, 0) ~ sex + age + log(income) + 
##     sport + nofriend + smoke + alcabuse + bmi, family = "binomial", 
##     data = synth.obj)
## 
## Combined estimates:
##              xpct(Beta) xpct(se.Beta)  xpct(z) Pr(>|xpct(z)|)    
## (Intercept) -1.00819811    0.71605764  -1.4080      0.1591356    
## sexFEMALE    0.35681507    0.10010909   3.5643      0.0003649 ***
## age          0.09004527    0.00384758  23.4031      < 2.2e-16 ***
## log(income) -0.68041355    0.08829602  -7.7060      1.298e-14 ***
## sportNO     -0.66446597    0.11880451  -5.5929      2.233e-08 ***
## nofriend     0.00028325    0.00679104   0.0417      0.9667305    
## smokeNO      0.08544511    0.11894243   0.7184      0.4725269    
## alcabuseNO  -0.72124437    0.21444108  -3.3634      0.0007700 ***
## bmi          0.00644972    0.01036016   0.6226      0.5335801    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# compare to the original data set
compare(glm1, original.df, lcol = mycols)
## Warning in log(income): NaNs produced
## 
## Call used to fit models to the data:
## glm.synds(formula = ifelse(depress > 7, 1, 0) ~ sex + age + log(income) + 
##     sport + nofriend + smoke + alcabuse + bmi, family = "binomial", 
##     data = synth.obj)
## 
## Differences between results based on synthetic and observed data:
##             Std. coef diff p value CI overlap
## (Intercept)    -1.26517500   0.206  0.6772453
## sexFEMALE       0.27373709   0.784  0.9301678
## age             0.85530291   0.392  0.7818065
## log(income)     0.98568572   0.324  0.7485449
## sportNO         0.16485706   0.869  0.9579439
## nofriend        1.31926470   0.187  0.6634467
## smokeNO        -0.07386025   0.941  0.9811578
## alcabuseNO      0.02501199   0.980  0.9936193
## bmi            -0.17971185   0.857  0.9541543
## 
## Measures for one synthesis and 9 coefficients
## Mean confidence interval overlap:  0.8542318
## Mean absolute std. coef diff:  0.5714007
## Lack-of-fit: 5.49732; p-value 0.789 for test that synthesis model is compatible 
## with a chi-squared test with 9 degrees of freedom
## 
## Confidence interval plot:
While the model needs more work, the same conclusions would be drawn from both the original and synthetic data sets, as can be seen from the confidence intervals. Occasionally there may be contradicting conclusions about a variable, for example accepting it in the observed data but not in the synthetic data. This scenario could be corrected by using different synthesis methods (see the documentation) or altering the visit sequence.
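The CI overlap column can be read with the help of a small sketch. A common definition of interval overlap averages the length of the intersection of the two confidence intervals relative to each interval's own length. I'm assuming this is the measure compare() reports, so treat `ci_overlap()` as a hypothetical helper rather than synthpop's own code. Identical intervals score 1, and disjoint intervals score below 0.

```r
# Interval overlap: mean of (intersection length / interval length) over both intervals
ci_overlap <- function(lo_obs, hi_obs, lo_syn, hi_syn) {
  lo_int <- max(lo_obs, lo_syn)
  hi_int <- min(hi_obs, hi_syn)   # hi_int < lo_int when the intervals are disjoint
  0.5 * ((hi_int - lo_int) / (hi_obs - lo_obs) + (hi_int - lo_int) / (hi_syn - lo_syn))
}

ci_overlap(0, 2, 0, 2)   # 1: identical intervals
ci_overlap(0, 2, 1, 3)   # 0.5: half of each interval is shared
ci_overlap(0, 1, 2, 3)   # -1: disjoint intervals
```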
Released population data are often counts of people in geographical areas by demographic variables (age, sex, etc). Some cells in the table can be very small, e.g. <5. For privacy reasons these cells are suppressed to protect people's identities. With synthetic data, suppression is not required given it contains no real people, assuming there is enough uncertainty in how the records are synthesised.
The existence of small cell counts opens a few questions,
To test this, 200 areas will be simulated to replicate possible real-world scenarios. Area size will be randomly allocated, ensuring a good mix of large and small population sizes. Population sizes are randomly drawn from a Poisson distribution with mean $\lambda$. If the area is large, $\lambda$ is drawn from a uniform distribution on the interval [20, 40]. If small, $\lambda$ is set to 1.
# ---------- AREA
# add area flag to the data frame
area.label <- paste0("area", 1:200)
a <- sample(0:1, 200, replace = TRUE, prob = c(0.5, 0.5))  # a = 1 marks a small area
lambda <- runif(200, 20, 40)*(1-a) + a                      # large: U[20, 40]; small: 1
prob.dist <- rpois(200, lambda)
area <- sample(area.label, 5000, replace = TRUE, prob = prob.dist)

# attached to original data frame
original.df <- SD2011 %>% dplyr::select(sex, age, socprof, income, marital, depress, sport, nofriend, smoke, nociga, alcabuse, bmi)
original.df$bmi <- ifelse(original.df$bmi > 75, NA, original.df$bmi)
original.df <- cbind(original.df, area) %>% arrange(area)
The order in which variables are synthesised, and the choice of predictors, is important when there are rare events or low-population areas. If a variable such as area is synthesised very early in the procedure and used as a predictor for the following variables, it's likely the subsequent models will over-fit. Synthetic data sets require a level of uncertainty to reduce the risk of statistical disclosure, so this is not ideal.
Fortunately, syn allows the predictor matrix to be modified. To avoid over-fitting, 'area' is the last variable to be synthesised and only uses sex and age as predictors. This is reasonable for capturing the key population characteristics. Additionally, syn throws an error unless maxfaclevels is raised to the number of areas (the default is 60). This guards against poorly synthesised factors with many levels, and a warning message suggests checking the results, which is good practice.
# synthesise data
# m is set to 0 as a hack to set the synds object and the predictor matrix
synth.obj.b <- syn(original.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, maxfaclevels = 200, seed = myseed, m = 0)
## 
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
# changing the predictor matrix to predict area with only age and sex
new.pred.mat <- synth.obj.b$predictor.matrix
new.pred.mat["area",] <- 0
new.pred.mat["area",c("age", "sex")] <- 1
new.pred.mat
##          sex age socprof income marital depress sport nofriend smoke
## sex        0   0       0      0       0       0     0        0     0
## age        1   0       0      0       0       0     0        0     0
## socprof    1   1       0      0       0       0     0        0     0
## income     1   1       1      0       0       0     0        0     0
## marital    1   1       1      1       0       0     0        0     0
## depress    1   1       1      1       1       0     0        0     0
## sport      1   1       1      1       1       1     0        0     0
## nofriend   1   1       1      1       1       1     1        0     0
## smoke      1   1       1      1       1       1     1        1     0
## nociga     1   1       1      1       1       1     1        1     1
## alcabuse   1   1       1      1       1       1     1        1     1
## bmi        1   1       1      1       1       1     1        1     1
## area       1   1       0      0       0       0     0        0     0
##          nociga alcabuse bmi area
## sex           0        0   0    0
## age           0        0   0    0
## socprof       0        0   0    0
## income        0        0   0    0
## marital       0        0   0    0
## depress       0        0   0    0
## sport         0        0   0    0
## nofriend      0        0   0    0
## smoke         0        0   0    0
## nociga        0        0   0    0
## alcabuse      1        0   0    0
## bmi           1        1   0    0
## area          0        0   0    0
# synthesising with new predictor
synth.obj.b <- syn(original.df, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, maxfaclevels = 200, seed = myseed, proper = TRUE, predictor.matrix = new.pred.mat)
## 
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
##  sex age socprof income marital depress sport nofriend smoke nociga
##  alcabuse bmi area
# compare the synthetic and original data frames
compare(synth.obj.b, original.df, vars = "area", nrow = 1, ncol = 1, cols = c("darkmagenta", "turquoise"), stat = "counts")$plot
The area variable is simulated fairly well using only age and sex. It captures the large and small areas; however, the large areas are relatively more variable. This could use some fine tuning, but I'll stick with this for now.
tab.syn <- synth.obj.b$syn %>% dplyr::select(area, sex) %>% table()
tab.orig <- original.df %>% dplyr::select(area, sex) %>% table()
## 
## synthetic
##          sex
## area      MALE FEMALE
##   area1      2      0
##   area10    15     19
##   area101    0      6
##   area103   35     22
##   area105    3      0
##   area106   30     31
##   area107    0      3
##   area108   28     11
##   area110   17     37
##   area112   23     24
##   area113    0      2
##   area114   21     52
##   area115   30     28
## 
## original
##          sex
## area      MALE FEMALE
##   area1      1      0
##   area10    19     27
##   area101    1      3
##   area103   29     26
##   area105    1      0
##   area106   22     33
##   area107    0      4
##   area108   28     18
##   area110   18     33
##   area112   23     25
##   area113    1      2
##   area114   28     39
##   area115   23     44
# as.numeric() on the table runs down the MALE column first, then FEMALE
d <- data.frame(difference = as.numeric(tab.syn - tab.orig),
                sex = rep(c("Male", "Female"), each = nrow(tab.syn)))
ggplot(d, aes(x = difference, fill = sex)) +
  geom_histogram() +
  facet_grid(sex ~ .) +
  scale_fill_manual(values = mycols)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
The method does a good job of preserving the structure of the areas. How much variability is acceptable is up to the user and the intended purpose. Using more predictors may provide a better fit. The errors are distributed around zero, a good sign that no systematic bias has been introduced by the synthesis.
tab.syn <- synth.obj.b$syn %>% dplyr::select(marital, sex) %>% table()
tab.orig <- original.df %>% dplyr::select(marital, sex) %>% table()
## 
## synthetic
##                     sex
## marital              MALE FEMALE
##   SINGLE              667    565
##   MARRIED            1352   1644
##   WIDOWED              76    398
##   DIVORCED             81    169
##   LEGALLY SEPARATED     2      0
##   DE FACTO SEPARATED    6     36
## 
## original
##                     sex
## marital              MALE FEMALE
##   SINGLE              657    596
##   MARRIED            1382   1597
##   WIDOWED              66    465
##   DIVORCED             62    137
##   LEGALLY SEPARATED     6      1
##   DE FACTO SEPARATED    4     18
At higher levels of aggregation the structure of the tables is better maintained.
Like the 'mice' package, 'synthpop' supports user-defined methods: a function named in the form syn.newmethod can be written and then used by the syn function by including "newmethod" in the method vector. To demonstrate this we'll build our own neural net method.
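As a minimum the function takes as input
y – observed variable to be synthesised
x – observed predictors
xp – synthesised predictors

syn.nn <- function(y, x, xp, smoothing, size = 6, ...) {
  # make sure the synthesised predictors have the same classes as the observed predictors
  for (i in which(sapply(x, class) != sapply(xp, class)))
    xp[, i] <- eval(parse(text = paste0("as.", class(x[, i]), "(xp[,i])", sep = "")))

  # model and prediction (nnet() comes from the nnet package)
  nn <- nnet(y ~ ., data = as.data.frame(cbind(y, x)), size = size, trace = FALSE)
  probs <- predict(nn, newdata = xp)

  # a two-level factor gives a single column, P(second level); expand it to both levels
  if (ncol(probs) == 1) probs <- cbind(1 - probs, probs)
  colnames(probs) <- levels(y)

  # fall back to equal probabilities where predictions are missing or all zero
  probs[is.na(probs)] <- 0
  for(k in 1:nrow(probs)){
    if(sum(probs[k,]) == 0){
      probs[k,] <- 1
    }
  }

  # draw each synthetic value from its predicted class probabilities
  new <- apply(probs, 1, function(x) colnames(probs)[sample(1:ncol(probs), 1, prob = x)]) %>% unname()
  new <- factor(new, levels = levels(y))

  # if smoothing
  if (!is.factor(y) & smoothing == "density") new <- syn.smooth(new, y)

  # return
  return(list(res = new, fit = nn))
}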
Set the method vector to apply the new neural net method to the factors and ctree to the others, then pass it to syn.
# methods vector: "sample" for the first variable, "nn" (i.e. syn.nn) for factors, "ctree" otherwise
meth.vec <- c("sample", ifelse(sapply(original.df[,-1], is.factor), "nn", "ctree"))
# column 13 is area; synthesise it with ctree rather than the neural net
meth.vec[13] <- "ctree"

# synthesise
synth.obj.c <- syn(original.df, method = meth.vec, cont.na = cont.na.list, rules = rules.list, rvalues = rules.value.list, maxfaclevels = 200, seed = myseed, predictor.matrix = new.pred.mat)
## 
## Unexpected values (not obeying the rules) found for variable(s): nociga.
## Rules have been applied but make sure they are correct.
## Synthesis
## -----------
##  sex age socprof income marital depress sport nofriend smoke nociga
##  alcabuse bmi area
# compare the synthetic and original data frames
compare(synth.obj.c, original.df, vars = colnames(original.df)[-13], nrow = 3, ncol = 4, cols = c("darkmagenta", "turquoise"), stat = "counts")$plot
The results are very similar to above with the exception of ‘alcabuse’, but this demonstrates how new methods can be applied.
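The same contract (take y, x and xp, return a list with res and fit) can host other methods. As a further sketch, here is a hypothetical nearest-neighbour donor method, syn.knndonor(), which for each synthetic predictor row copies the y value of one of its k nearest original rows, chosen at random. It assumes all predictors are numeric; factor predictors would need encoding first. Presumably it could then be selected with "knndonor" in the method vector, in the same way "nn" selects syn.nn(), but here it is called directly on toy data so it can be checked without synthpop.

```r
# Hypothetical custom method: k-nearest-neighbour donor (numeric predictors only)
syn.knndonor <- function(y, x, xp, k = 5, ...) {
  x  <- as.matrix(x)
  xp <- as.matrix(xp)
  # scale both predictor sets with the observed predictors' centre and spread
  ctr <- colMeans(x)
  scl <- apply(x, 2, sd)
  scl[scl == 0] <- 1
  xs  <- scale(x, ctr, scl)
  xps <- scale(xp, ctr, scl)
  k <- min(k, nrow(xs))
  donors <- vapply(seq_len(nrow(xps)), function(i) {
    d  <- colSums((t(xs) - xps[i, ])^2)  # squared distance to every observed row
    nn <- order(d)[seq_len(k)]           # the k nearest observed rows
    nn[sample.int(k, 1)]                 # pick one donor at random
  }, integer(1))
  list(res = y[donors], fit = "knndonor")
}

# Direct call on toy data
set.seed(3)
x  <- data.frame(age = runif(100, 18, 80))
y  <- factor(ifelse(x$age > 50, "old", "young"))
xp <- data.frame(age = c(20, 75))
out <- syn.knndonor(y, x, xp, k = 3)
out$res
```

Choosing randomly among the k nearest donors, rather than always the single nearest, keeps some uncertainty in the synthetic values.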
The ‘synthpop’ package is great for synthesising data for statistical disclosure control or creating training data for model development. Other things to note,
Following posts tackle complications that arise when there are multiple tables at different grains that are to be synthesised. Further complications arise when their relationships in the database also need to be preserved. Ideally the data is synthesised and stored alongside the original enabling any report or analysis to be conducted on either the original or synthesised data. This will require some trickery to get synthpop to do the right thing, but is possible.
The post Q-learning example with Liar’s Dice in R appeared first on Daniel Oehm | Gradient Descending.
In my last post I coded Liar's Dice in R and some brainless bots to play against. I build on that post by using Q-learning to train an agent to play Liar's Dice well.
Spoiler alert: The brainless bots aren’t actually that brainless! More on that later.
Note – I'll share enough code to run the simulations; however, the full code can be found on GitHub. Check out my previous post for the rules to Liar's Dice.
Firstly, some background. Q-learning is a reinforcement learning algorithm which trains an agent to make the right decisions given the environment it is in and the tasks it needs to complete. The task may be navigating a maze, playing a game, driving a car, flying a drone or learning which offers to make to increase customer retention. In the case of Liar's Dice, the task is deciding when to call or raise the bid. The agent learns through
This cycle is repeated until eventually it learns the best decisions to make given its situation.
The environment is formulated as a Markov Decision Process (MDP) defined by a set of states $s \in S$ and actions $a \in A$. By taking an action $a$ in state $s$, the agent will transition to a new state $s'$ with probability $P_a(s, s')$. After transitioning to the new state the agent receives a reward $R_a(s, s')$, which tells it whether this was a good move or a bad move. It may also tell the agent this was neither a good nor a bad move until it reaches a win/lose state.
Finding the optimal policy of an MDP can be done through value iteration and policy iteration. The optimal policy and state value functions are given by
$$\pi(s) = \arg\max_a \left\{ \sum_{s'} P_a(s, s') \left( R_a(s, s') + \gamma V(s') \right) \right\}$$
and
$$V(s) = \sum_{s'} P_{\pi(s)}(s, s') \left( R_{\pi(s)}(s, s') + \gamma V(s') \right)$$
where $\gamma$ is the discount parameter. The above equations rely on knowing (or, often, making crude approximations to) the transition probabilities. The benefit of Q-learning is that the transition probabilities are not required; instead their effect is learned through simulation. The Q function is updated by
$$Q(s_t, a_t) \leftarrow (1 - \alpha)\, Q(s_t, a_t) + \alpha \left( r_t + \gamma \max_a Q(s_{t+1}, a) \right)$$
where $\alpha$ is the learning rate and $\gamma$ is the discount parameter.
The cells in the Q matrix represent the 'quality' of taking action $a$ given state $s$. After each action the Q matrix is updated. After many iterations, the agent will have explored many states and determined which state and action pairs led to the best outcomes. Now it has the information needed to make the optimal choice, by taking the action which leads to the maximum overall reward, indicated by the largest Q value.
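To make the update rule concrete, here is a minimal tabular Q-learning sketch in base R. It uses a toy environment rather than Liar's Dice: a corridor of five states where moving right from state 4 reaches the goal (state 5) and earns a reward of 1. The values α = 0.1 and γ = 0.9, the uniformly random exploration policy and the number of episodes are assumptions chosen for illustration.

```r
set.seed(10)
n_states <- 5                       # states 1..5; state 5 is the terminal goal
actions  <- c("left", "right")
Q <- matrix(0, n_states, length(actions), dimnames = list(NULL, actions))
alpha <- 0.1                        # learning rate
gamma <- 0.9                        # discount parameter

# Deterministic toy environment: walls at both ends, reward 1 on reaching the goal
env_step <- function(s, a) {
  s_new <- if (a == "right") min(s + 1, n_states) else max(s - 1, 1)
  list(s = s_new, r = as.numeric(s_new == n_states))
}

for (episode in 1:2000) {
  s <- sample(1:(n_states - 1), 1)  # random non-terminal start
  while (s != n_states) {
    a <- sample(actions, 1)         # explore with uniformly random actions
    out <- env_step(s, a)
    # the terminal state has no future value
    target <- out$r + (if (out$s == n_states) 0 else gamma * max(Q[out$s, ]))
    Q[s, a] <- (1 - alpha) * Q[s, a] + alpha * target
    s <- out$s
  }
}

round(Q, 3)
greedy <- actions[apply(Q[1:4, ], 1, which.max)]  # best action in each non-terminal state
greedy
```

Even though the agent only ever acts randomly, Q-learning is off-policy, so the greedy policy read from the learned Q matrix still points right in every state.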
The key is to formulate the states the agent can be in at any point in the game and the reward for transitioning from one state to another. MDPs can become very large very quickly if every possible state is accounted for, so it’s important to identify the key information and the redundancies.
The key pieces of information needed to make a decision in Liar’s Dice are
Consider a player with 6 dice: this gives $6^6 = 46{,}656$ possible hands, and that hasn’t yet factored in the bid or the total number of dice on the table. You can see how the number of states blows out.
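To make the size of that space concrete, here is a small Python sketch (the post’s own code is R). It counts the ordered outcomes for 6 six-sided dice. It also counts the unordered hands, because a sorted hand only depends on how many of each face were rolled. The constant names are illustrative.

```python
from itertools import combinations_with_replacement
from math import comb

N_DICE, FACES = 6, 6

# Ordered outcomes: each of the 6 dice independently shows one of 6 faces.
ordered = FACES ** N_DICE

# Unordered hands (multisets): only the count of each face matters,
# which is how a sorted hand is read at the table.
unordered = comb(N_DICE + FACES - 1, N_DICE)

# Cross-check the multiset count by enumerating sorted hands directly.
enumerated = sum(1 for _ in combinations_with_replacement(range(1, FACES + 1), N_DICE))

print(ordered, unordered, enumerated)  # 46656 462 462
```

Even the smaller multiset count is per player and per table size, before the bid is added. This is why the state formulation below drops the hand itself.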
To decide whether to raise or call, the player only needs to know how many dice of the current bid value they have in hand and the chance that the remainder are among the opponents’ unseen dice. Essentially, the dice value itself isn’t required in the formulation of the game state.
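The states are given by 3 values: the total number of dice on the table, the number of dice in the player’s hand, and a probability bucket.
The last value combines the information given by the player’s dice and the bid. The probability that there is at least the bid quantity on the table is calculated and reduced to a bucket
$$\text{bucket} = \left\lfloor 20 \, P(X \ge x) \right\rfloor$$
where
$$P(X \ge x) = \sum_{k=x}^{n} \binom{n}{k} \left(\frac{1}{6}\right)^{k} \left(\frac{5}{6}\right)^{n-k}$$
and $x$ is the unknown quantity needed and $n$ is the number of unobserved dice on the table. If the player already holds at least the bid quantity, the bucket is 20. This reduces the state space down to something more manageable. For this example we’ll use a maximum of 20 buckets, i.e. (5%, 10%, …, 100%). Overkill for small numbers of dice, but it doesn’t hurt.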
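The bucket calculation can be sketched in Python as follows. This mirrors the idea of the R `calc.prob()` helper in the full code, but it is not a line-by-line port. One assumption here: a bid that needs more dice than remain unseen is given probability 0.

```python
from math import comb, floor


def prob_at_least(x, n, p=1/6):
    """P(X >= x) for X ~ Binomial(n, p): at least x matches among n unseen dice."""
    if x <= 0:
        return 1.0          # bid already covered by the player's own dice
    if x > n:
        return 0.0          # more dice needed than are hidden: impossible
    return sum(comb(n, k) * p**k * (1 - p)**(n - k) for k in range(x, n + 1))


def prob_bucket(total, p1, bid_quantity, in_hand, bin_size=20):
    """Map the bid to an integer bucket 0..bin_size (a sketch of the bucketing idea)."""
    x = bid_quantity - in_hand   # quantity still needed from the unseen dice
    n = total - p1               # number of unobserved dice
    return floor(prob_at_least(x, n) * bin_size)


# 25 dice in play, the player holds 5 (two of them 3's), the bid is seven 3's.
print(prob_bucket(total=25, p1=5, bid_quantity=7, in_hand=2))  # 4
```

Here the probability is about 0.23, which lands in bucket 4. Any bid the player can cover from their own hand maps to the top bucket, 20.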
The function below generates the complete game states given the number of dice and players.
# generate games states
library(dplyr)

generate.game.states <- function(nplayers, ndice, bin.size = 20){

  # create the basic data frame with total dice and player dice count
  total.dice <- nplayers*ndice
  states <- t(combn(rep(0:ndice, nplayers), nplayers)) %>%
    as.data.frame() %>%
    distinct()
  colnames(states) <- paste0("p", 1:nplayers)
  states$total <- rowSums(states)
  states <- states %>%
    dplyr::select(total, p1) %>%
    dplyr::arrange(total, p1) %>%
    dplyr::filter(total > 0) %>%
    distinct

  # add in the probability bucket state
  game.states <- data.frame()
  for(k in 1:nrow(states)){
    total <- states$total[k]
    p1 <- states$p1[k]
    for(j in 0:bin.size){
      game.states <- rbind(game.states, c(total, p1, j))
    }
  }
  colnames(game.states) <- c("total", "p1", "prob_cat")
  return(game.states)
}

gs <- generate.game.states(4, 6)
dim(gs)
## [1] 2772 3
The state space reduces to only 2772 states for a game with 4 players with 6 dice each where previously it would have been several orders of magnitude larger. There are still redundant states in this formulation (mostly because I’m lazy) but it’s been reduced enough to be viable and won’t significantly slow down training.
To simplify the problem, the agent only needs to decide on whether to call or raise. A more complicated problem would be to allow the agent to choose what the new bid should be (this is for a later post, for now we’ll keep it simple).
The agent will explore the states randomly and will eventually learn when it’s a good time to call and a good time to raise. For example if the bid is three 5’s and the agent has three 5’s in hand, the obvious action is to raise. The agent won’t know this at first but will soon work it out.
If the agent raises, it will first randomly select whether to bluff or play the numbers. By bluffing the agent randomly selects a dice value and increases the bid by 1. If the agent plays the numbers it selects the value it has the most of and raises the quantity by 1.
When the agent has been trained, it makes the optimal decision by selecting the action with the maximum Q value given the current state $s$, i.e. $a^* = \arg\max_a Q(s, a)$.
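Here is a minimal Python sketch of that decision rule. It follows the `Q.decide` branch in the full R code further down: if every Q value in the row is equal (an unvisited state) the agent picks at random; otherwise it explores at random 10% of the time and exploits the largest Q value the rest of the time. The function name `q_decide` and the `ACTIONS` ordering are illustrative.

```python
import random

ACTIONS = ("raise", "call")


def q_decide(q_row, epsilon=0.1, rng=random):
    """Choose 'raise' or 'call' from one row of the Q matrix (epsilon-greedy sketch).

    q_row holds the Q values in the same order as ACTIONS.
    """
    # Unvisited state: all Q values are (almost) equal, so there is nothing to exploit.
    if max(q_row) - min(q_row) < 1e-6:
        return rng.choice(ACTIONS)
    # Exploration: with probability epsilon, act at random.
    if rng.random() < epsilon:
        return rng.choice(ACTIONS)
    # Exploitation: the action with the largest Q value.
    best = max(range(len(q_row)), key=lambda i: q_row[i])
    return ACTIONS[best]


print(q_decide([0.4, 2.5], epsilon=0.0))  # call
```

Setting `epsilon=0.0` gives the purely greedy rule $a^* = \arg\max_a Q(s, a)$.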
The agent needs to know how good taking action $a$ was. The reward matrix $R_a$ is defined by rewarding or penalising the following transitions: +1 when another player loses a die, -1 when the agent loses a die, +10 when the agent wins the game and -10 when it loses the game.
The reward values are arbitrary but work well in this case. We want to emphasise that losing a die is bad but losing the game is worse. Apart from the terminal states, i.e. when the number of dice the player has is 0 (lose) or equal to the total number of dice on the table (win), no state is particularly good or bad; it is the transition from one state to another that triggers the reward or penalty. Therefore, each reward matrix will be an $N \times N$ square matrix where $N$ is the total number of states.
There is a reward matrix for each action, stored in a list. For Liar’s Dice this isn’t necessary, since the rewards and penalties are the same whether the player raises or calls and transitions to another state. However, the framework is there for actions to have different rewards.
# generate reward matrix
generate.reward.matrix <- function(game.states){
  reward <- matrix(0, nrow = nrow(game.states), ncol = nrow(game.states))
  for(i in 1:nrow(reward)){

    # which state are we in
    total <- game.states$total[i]
    p1 <- game.states$p1[i]

    # small penalty for losing a die
    reward[i, p1 - game.states$p1 == 1 & total - game.states$total == 1 & game.states$p1 != game.states$total] <- -1

    # small reward for others losing a die
    reward[i, p1 == game.states$p1 & total - game.states$total == 1 & game.states$p1 != game.states$total & p1 > 0] <- 1

    # fail states - when players dice count is 0
    if(p1 == 1){
      reward[i, which(total - game.states$total == 1 & game.states$p1 == 0)] <- -10
    }

    # win states when the player dice count equals the total dice count
    if(total - p1 == 1){
      reward[i, game.states$total == p1 & game.states$p1 == p1] <- 10
    }
  }
  return(reward)
}

rw <- generate.reward.matrix(gs)
reward <- list(raise = rw, call = rw)
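The same reward rules can be written as a function of a single transition, from (total, p1) to (total', p1'). Here is a Python sketch that mirrors the conditions in `generate.reward.matrix()` above. Filling the dense $N \times N$ matrix is then just a matter of evaluating every pair of states. The names `transition_reward` and `reward_matrix` are illustrative. States are plain tuples whose first two entries are the total and the player’s dice count.

```python
def transition_reward(state, next_state):
    """Reward for moving between (total, p1, ...) states, mirroring the R rules."""
    total, p1 = state[0], state[1]
    total2, p1_2 = next_state[0], next_state[1]
    one_die_lost = total - total2 == 1
    if one_die_lost and p1 == 1 and p1_2 == 0:
        return -10   # the agent lost its last die: game lost
    if one_die_lost and total - p1 == 1 and total2 == p1 and p1_2 == p1:
        return 10    # the last opponent die is gone: game won
    if one_die_lost and p1 - p1_2 == 1 and p1_2 != total2:
        return -1    # the agent lost a die
    if one_die_lost and p1 == p1_2 and p1_2 != total2 and p1 > 0:
        return 1     # an opponent lost a die
    return 0         # any other transition


def reward_matrix(states):
    """Dense N x N matrix: row = current state, column = next state."""
    return [[transition_reward(s, s2) for s2 in states] for s in states]


print(transition_reward((10, 3), (9, 2)), transition_reward((4, 3), (3, 3)))  # -1 10
```

Nearly every entry is 0, because most pairs of states can’t follow one another. This is why the rewards sit on a handful of transitions rather than on the states themselves.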
The process follows the steps below.
Assume player 1 raises on their turn. In a 4-player game, player 1 may pass through several other states before control returns to them, because each raise or call by the other players changes the game state for player 1. For the purposes of the model, the action player 1 took is considered to be the last action for all subsequent transitions.
Here is an example of the state transition table for 4 players each with 3 dice.
play.liars.dice(auto = TRUE, players = 4, num.dice = 3, verbose = 0, agents = agents, Q.mat = Q.mat, print.trans = TRUE)$winner
##    y.ctrl y.state y.action
## 1       3     808    raise
## 2       3     787    raise
## 3       3     778    raise
## 4       4     778    raise
## 5       4     736    raise
## 6       4     736    raise
## 7       1     736     call
## 8       1     673     call
## 9       1     678    raise
## 10      2     674    raise
## 11      3     673    raise
## 12      4     673    raise
## 13      4     589    raise
## 14      4     592    raise
## 15      1     589     call
## 16      1     505     call
## 17      1     507     call
## 18      1     423     call
## 19      2     422    raise
## 20      3     421    raise
## 21      1     421    raise
## 22      2     421    raise
## 23      2     316    raise
## 24      2     324    raise
## 25      2     240    raise
## 26      3     232    raise
## 27      3     148    raise
## 28      2     151    raise
## 29      2      88    raise
## [1] 1
This table is then passed to the update.Q() function.
# update Q matrix
update.Q <- function(play, Q.mat, reward, alpha = 0.1, discount = 0.9){
  for(k in 2:nrow(play)){
    curr.state <- play$y.state[k]
    prev.state <- play$y.state[k-1]
    # as.character() guards against y.action being a factor (the data.frame()
    # default before R 4.0), which would index by level code rather than by name
    action <- as.character(play$y.action[k])

    # Q update
    Q.mat[prev.state, action] <- (1-alpha)*Q.mat[prev.state, action] +
      alpha*(reward[[action]][prev.state, curr.state] + discount*max(Q.mat[curr.state,]))
  }
  return(Q.mat)
}
The learning rate and discount values have been initialised to 0.1 and 0.9 respectively.
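To see how one pass of `update.Q()` moves values through the table, here is a small Python sketch of the same update. It walks consecutive rows of the transition table. The Q matrix is assumed to be a dict of dicts, and rewards come from a function rather than a list of matrices. The toy states and the reward of 10 for reaching state 2 are made up for illustration.

```python
def update_q(play, q, reward_fn, alpha=0.1, discount=0.9):
    """One pass of Q-learning updates over a game's transition table.

    play: list of (state, action) rows in the order they were recorded.
    q: dict mapping state -> {"raise": value, "call": value}.
    reward_fn: function (prev_state, curr_state) -> reward.
    """
    for (prev_state, _), (curr_state, action) in zip(play, play[1:]):
        target = reward_fn(prev_state, curr_state) + discount * max(q[curr_state].values())
        q[prev_state][action] = (1 - alpha) * q[prev_state][action] + alpha * target
    return q


def reward_fn(prev, curr):
    # Toy reward: 10 for reaching state 2, nothing otherwise.
    return 10.0 if curr == 2 else 0.0


q = {s: {"raise": 0.0, "call": 0.0} for s in (0, 1, 2)}
play = [(0, "raise"), (1, "raise"), (2, "call")]

update_q(play, q, reward_fn)
print(q[1]["call"])    # 1.0 after the first pass
update_q(play, q, reward_fn)
print(q[0]["raise"])   # roughly 0.09: value starts flowing back to state 0
```

As in the R version, the update for each row uses the action recorded on the *later* row. It credits the transition from the previous state to that action.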
Liar’s Dice is now simulated 5000 times and Q value iteration is conducted with the above functions (see GitHub for the full code). The first agent will be the only one that uses the Q matrix to decide its actions, and therefore the only agent that is trained. It will bluff with a probability of 50% to add some more realism to the agent’s decisions. The other 3 will be random agents, bluffing 100% of the time and randomly deciding to call or raise at each decision point. We expect that after training, agent 1 will outperform the other 3 random agents.
# set the agents
agent1 <- build.agent(c(0.5,0.5), method = "Q.decide")
agent2 <- build.agent(c(1,0), method = "random")
agent3 <- build.agent(c(1,0), method = "random")
agent4 <- build.agent(c(1,0), method = "random")
agents <- list(agent1, agent2, agent3, agent4)
# training
its <- 5000

# setting the vector to store the winners
winners <- vector("numeric", length = its)

# start without a Q matrix; play.liars.dice() initialises it on the first call
Q.mat <- NULL

# loopy loop
pb <- txtProgressBar(min = 0, max = its, style = 3)
for(k in 1:its){
  out <- play.liars.dice(auto = TRUE, players = 4, num.dice = 6, verbose = 0,
                         agents = agents, Q.mat = Q.mat, train = TRUE)
  winners[k] <- out$winner
  Q.mat <- out$Q.mat
  setTxtProgressBar(pb, k)
}
close(pb)
# table of winners
table(winners)
## winners
##    1    2    3    4 
## 3288  434  496  782 
# agent win percentage
library(ggplot2)
x <- 1000:5000
y <- (cumsum(winners == 1)/(1:5000))[x]
qplot(x = x, y = y, geom = "line", xlab = "iterations", ylab = "Proportion of agent 1 wins")
After only 5000 iterations (which isn’t a lot given there are approximately 2000 valid states), the results show that agent 1 performs very well against the random agents. If the agents were equally matched, each would win 25% of games on average, whereas agent 1 won 3288 of the 5000 training games (about 66%).
The graph shows the percentage of wins for agent 1 continuing to increase as it is trained. Further training will improve the Q matrix and hence the performance of the agent. Given the stochastic nature of the game we wouldn’t expect a win percentage of 100%, so this is a great result.
Here is another 100 games with the trained agent.
library(pbapply)
sim <- pbsapply(1:100, function(x) play.liars.dice(auto = TRUE, players = 4, num.dice = 6, verbose = 0, agents = agents, Q.mat = Q.mat)[["winner"]])
table(sim)
## sim
##  1  2  3  4 
## 65  6 15 14 
sum(sim == 1)/100
## [1] 0.65
Solid effort.
What’s really happening here? The last variable in our state space formulation is the probability bucket, which is in essence an approximation of the actual probability that the bid quantity exists on the table. At first the agent doesn’t know what to do with that information and decides to call or raise randomly. Over time it learns how best to use that information to either call or raise. In my previous post we simply used the probability directly, randomly choosing to raise with probability $p$ and call with probability $1-p$, where $p$ is the probability that there is at least the bid quantity on the table. So in truth the original bots weren’t too bad.
The Q-learning algorithm has an advantage in being able to solve more complex scenarios. The original agents only had the probability on which to base a decision, whereas under an MDP framework the agent is free to also make decisions based on how many dice it has in hand and how many are on the table. It can vary its risk depending on how close it is to winning or losing.
There are ways we can expand the state space to allow for potentially more complex decisions, such as factoring in the remaining dice of the person to the left or right and allowing the agent to learn each player’s bluffing likelihood. The state space could also be reduced to whether the player has 0 dice or 1 or more, since whether the player has 2 or 6 dice may not matter too much. It’s worth an experiment to see if it performs just as well.
In short a few things to take away are,
# set dice value
set.dice.value <- function(note, max.val, prev.val = 0){
  good.val <- FALSE
  while(!good.val){
    val <- readline(note) %>% as.numeric()
    if(!is.na(val) && val > 0 && val <= max.val && val > prev.val){
      good.val <- TRUE
    }else{
      cat("please select a value between 1 and", max.val, "\n")
    }
  }
  return(val)
}

# roll table
roll.table.fn <- function(rolls){
  rt <- table(unlist(rolls))
  roll.table <- rep(0, 6)
  names(roll.table) <- 1:6
  roll.table[names(rt)] <- rt
  return(roll.table)
}

# calculates the probability that there is at least the bid quantity on the table and converts it to a bucket
# x = c(total dice, player's dice, bid quantity, player's count of the bid value)
calc.prob <- function(x, bin.size = 20) {
  if(x[3] <= x[4]){
    return(1*bin.size)
  }
  n <- x[1]-x[2]            # unseen dice
  needed <- x[3]-x[4]       # quantity still needed from the unseen dice
  if(needed > n) return(0)  # more dice needed than are unseen: impossible
  k <- seq(needed, n, 1)
  return(floor(sum(choose(n, k)*(1/6)^k*(5/6)^(n-k))*bin.size))
}

# agent function chooses the best action e.g. raise or call
# it needs to take in as input dice, total dice, dice value and dice quantity
# as output action (raise or call), if raised also new dice value and quantity
# dice, total.dice, dice.value, dice.quantity.
# this is wrapped by a building function to make it easier to change certain
# parameters and decisions an agent might make and be able to play them off against
# each other to see which is the better strategy
build.agent <- function(bluff.prob, method = "random"){
  return(
    function(pars, Q.mat){

      # bluff or truth
      bluff <- sample(c(TRUE, FALSE), 1, prob = bluff.prob)

      # probability table
      roll.table <- roll.table.fn(pars$dice)
      ptable <- roll.table/sum(roll.table)

      # if the initial bid do this
      if(is.null(pars$dice.value)){
        new.dice.value <- which.max(ptable[1:6]) %>% names() %>% as.numeric()
        new.dice.quantity <- max(roll.table) + 1
        return(list(dice.value = new.dice.value, dice.quantity = new.dice.quantity))
      }

      # are you gonna call?
      # use the Q matrix to make a decision
      if(method == "Q.decide"){
        if(abs(max(Q.mat[pars$p1.state,]) - min(Q.mat[pars$p1.state,])) < 1e-6){
          call <- sample(c(TRUE, FALSE), 1)
        }else{
          # exploration vs exploitation
          if(runif(1) < 0.1){
            call <- sample(c(TRUE, FALSE), 1)
          }else{
            call <- names(which.max(Q.mat[pars$p1.state,])) == "call"
          }
        }

      # the random agent
      }else if(method == "random"){
        prob <- 0.5
        call <- sample(c(TRUE, FALSE), 1, prob = c(1-prob, prob))

      # playing the actual numbers
      }else if(method == "true.prob"){
        # probability of at least `needed` more among the unseen dice: P(X >= needed)
        needed <- max(c(pars$dice.quantity - roll.table[pars$dice.value], 0))
        prob <- 1 - pbinom(needed - 1, pars$total.dice - length(pars$dice), prob = 1/6)
        call <- sample(c(TRUE, FALSE), 1, prob = c(1-prob, prob))
      }

      # if called return the values
      if(call){
        return(list(action = "call", dice.value = pars$dice.value, dice.quantity = pars$dice.quantity))
      }else{

        # raise
        # if choosing to bluff randomly select a number and increase by 1
        if(bluff){
          new.dice.value <- sample(1:6, 1)
          new.dice.quantity <- pars$dice.quantity + 1
        }else{

          # if not bluffing choose the maximum number in hand and increase by one
          # this should be made to be more flexible however in general raising by
          # 1 occurs 99% of the time
          new.dice.value <- which.max(ptable) %>% names() %>% as.numeric()
          new.dice.quantity <- pars$dice.quantity + 1
        }

        # return the new values and action
        return(list(action = "raise", dice.value = new.dice.value, dice.quantity = new.dice.quantity))
      }
    }
  )
}

#------ play a round of liars dice
liars.dice.round <- function(players, control, player.dice.count, agents, game.states, reward, Q.mat, a = 1, verbose = 1){

  # set array for recording results
  y.ctrl = c(); y.state = c(); y.action = c()

  # roll the dice for each player
  if(verbose > 0) cat("\n\n")
  rolls <- lapply(1:players, function(x) sort(sample(1:6, player.dice.count[[x]], replace = TRUE)))
  if(verbose > 1) lapply(rolls, function(x) cat("dice: ", x, "\n"))
  total.dice <- sum(unlist(player.dice.count))

  # set penalty
  penalty <- sapply(1:players, function(x) 0, simplify = FALSE)

  # print dice blocks
  if(verbose > 0) Dice(rolls[[1]])

  # set up roll table
  roll.table <- roll.table.fn(rolls)

  # initial bid
  if(verbose > 0) cat("place first bid\nPlayer", control, "has control\n")
  if(control == a){
    dice.value <- set.dice.value("dice value: ", 6)
    dice.quantity <- set.dice.value("quantity: ", sum(roll.table))
  }else{

    # agent plays
    p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == total.dice)
    pars <- list(dice = rolls[[control]], total.dice = total.dice, dice.value = NULL, dice.quantity = 0, p1.state = p1.state)
    agent.action <- agents[[control]](pars = pars, Q.mat = Q.mat)
    dice.value <- agent.action$dice.value
    dice.quantity <- agent.action$dice.quantity
  }

  # calculate probability cat and determine the game state
  # action set to raise because you can't call without an initial bid
  # this could be a 3rd action (initial bid) but it's not really necessary
  player.dice.qty <- table(rolls[[1]])[as.character(dice.value)]
  player.dice.qty <- ifelse(is.na(player.dice.qty), 0, player.dice.qty) %>% unname
  prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
  p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == prob.cat)
  p1.action <- "raise"

  # storing states for Q iteration
  y.ctrl = c(); y.state = c(); y.action = c()

  # moving control to the next player
  # storing the previous player since if the next player calls the previous player could lose a die
  prev <- control
  control <- control %% players + 1
  if(verbose > 0) cat("dice value ", dice.value, "; dice quantity ", dice.quantity, "\n")

  # loop through each player and continue until there is a winner and loser
  called <- FALSE
  while(!called){

    # check if the player with control is still in the game - if not skip
    if(player.dice.count[[control]] > 0){
      if(control == a){
        action <- readline("raise or call (r/c)? ")
      }else{

        # the agent makes a decision
        pars <- list(dice = rolls[[control]], total.dice = total.dice, dice.value = dice.value, dice.quantity = dice.quantity, p1.state = p1.state)
        agent.action <- agents[[control]](pars = pars, Q.mat = Q.mat)
        action <- agent.action$action
      }

      # storing states for reward iteration
      if(control == 1 & !is.null(agent.action$action)){
        player.dice.qty <- table(rolls[[1]])[as.character(dice.value)]
        player.dice.qty <- ifelse(is.na(player.dice.qty), 0, player.dice.qty) %>% unname
        p1.action <- agent.action$action
        prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
        p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == prob.cat)
      }

      # called
      if(action %in% c("call", "c")){
        if(verbose > 0) {
          cat("player", control, "called\nRoll table\n")
          print(roll.table)
        }

        # dice are revealed
        # if the bid quantity is more than the count of that value in the pool,
        # the previous player (who made the bid) loses a die; otherwise the caller does
        if(dice.quantity > roll.table[dice.value]){
          penalty[[prev]] <- penalty[[prev]] - 1
          if(verbose > 0) cat("player", prev, "lost a die\n")
        }else{
          penalty[[control]] <- penalty[[control]] - 1
          if(verbose > 0) cat("player", control, "lost a die\n")
        }

        # for Q iteration
        y.ctrl <- c(y.ctrl, control); y.state <- c(y.state, p1.state); y.action <- c(y.action, p1.action)

        # if called use the penalty array to change states
        prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
        p1.state <- which(game.states$total == total.dice-1 & game.states$p1 == player.dice.count[[1]]+penalty[[1]] & game.states$prob_cat == prob.cat)

        # break the loop
        called <- TRUE
      }else{
        if(verbose > 0) cat("player", control, "raised\n")
        if(control == a){

          # player sets next dice value
          dice.value <- set.dice.value("dice value: ", 6)
          dice.quantity <- set.dice.value("quantity: ", sum(roll.table))
        }else{
          dice.value <- agent.action$dice.value
          dice.quantity <- agent.action$dice.quantity
        }

        # p1 state after the raise
        prob.cat <- calc.prob(c(total.dice, player.dice.count[[1]], dice.quantity, player.dice.qty))
        p1.state <- which(game.states$total == total.dice & game.states$p1 == player.dice.count[[1]] & game.states$prob_cat == prob.cat)
        if(verbose > 0) cat("dice value", dice.value, "; dice quantity", dice.quantity, "\n")
      }

      # store info for Q update
      y.ctrl <- c(y.ctrl, control); y.state <- c(y.state, p1.state); y.action <- c(y.action, p1.action)

      # set the control player to now be the previous player
      prev <- control
    }

    # next player has control
    control <- control %% players + 1
  }

  # play results and return (keep the actions as character strings, not factors)
  play <- data.frame(y.ctrl, y.state, y.action, stringsAsFactors = FALSE)
  return(list(penalty = penalty, play = play))
}

# play a full game of liars dice
play.liars.dice <- function(players = 4, num.dice = 6, auto = FALSE, verbose = 1, agents, Q.mat = NULL, train = FALSE, print.trans = FALSE){

  # begin!
  if(verbose > 0) liars.dice.title()

  # setting the number of dice each player has
  ndice <- sapply(rep(num.dice, players), function(x) x, simplify = FALSE)
  players.left <- sum(unlist(ndice) > 0)

  # setting game states matrix
  game.states <- generate.game.states(players, num.dice)

  # set up reward matrix
  reward <- generate.reward.matrix(game.states)
  reward <- list(raise = reward, call = reward)

  # set Q matrix if null
  if(is.null(Q.mat)) Q.mat <- matrix(0, nrow = nrow(reward$raise), ncol = length(reward), dimnames = list(c(), names(reward)))

  # while there is at least 2 left in the game
  # who has control
  ctrl <- sample(1:players, 1)
  play.df <- data.frame()
  while(players.left > 1){

    # play a round
    results <- liars.dice.round(
      players = players,
      control = ctrl,
      player.dice.count = ndice,
      game.states = game.states,
      reward = reward,
      Q.mat = Q.mat,
      agents = agents,
      a = as.numeric(!auto),
      verbose = verbose
    )

    # update how many dice the players are left with given the
    # outcomes of the round
    for(k in seq_along(ndice)){
      ndice[[k]] <- ndice[[k]] + results$penalty[[k]]
      if(ndice[[k]] == 0 & results$penalty[[k]] == -1){
        if(verbose > 0) cat("player", k, "is out of the game\n")
      }

      # update who has control so they can start the bidding
      if(results$penalty[[k]] == -1){
        ctrl <- k
        while(ndice[[ctrl]] == 0){
          ctrl <- ctrl %% players + 1
        }
      }
    }

    # checking how many are left and if anyone won the game
    players.left <- sum(unlist(ndice) > 0)
    if(players.left == 1){
      if(verbose > 0) cat("player", which(unlist(ndice) > 0), "won the game\n")
    }

    # appending play
    play.df <- rbind(play.df, results$play)
  }

  if(print.trans) print(play.df)

  # update Q
  # rather than training after each action, training at the
  # end of each game in bulk
  # just easier this way
  if(train) Q.mat <- update.Q(play.df, Q.mat, reward)

  # return the winner and Q matrix
  return(list(winner = which(unlist(ndice) > 0), Q.mat = Q.mat))
}
The post Q-learning example with Liar’s Dice in R appeared first on Daniel Oehm | Gradient Descending.
The post Secret Santa List Using Python appeared first on Daniel Oehm | Gradient Descending.
It’s that time of year again. Christmas. Buying presents is a chore and stressful, so a popular option is Secret Santa. Buying for one is better than buying for several! But there’s always the problem of who does the selection so that it stays secret for everyone.
In this post I’ll share how I make Secret Santa allocations using Python. The program randomises the allocation and automatically sends the emails to each person informing them who they are buying for from your designated email account. It ensures nobody is allocated themselves and the master list is stored just in case it is needed. Sure, there are a plethora of apps that can do this but it’s fun to do it yourself. Plus, you’ll probably use the app once, forget about it for the following year, have to re-learn how to use it and need to input all the data again. With your own script in Python you’ll always know you have it and know how to use it. Each following year it will only take a minute or two and the allocations are done.
The full code template is on Github or at the bottom of this post.
Arguably the most work goes into making sure your email list is up to date. Everyone participating needs to be listed in the dictionary. There are no checks that the email addresses entered are valid, so it’s important to double-check them.
# EMAIL DICTIONARY
# replace the names and emails with those participating
# example emails only
fam = {
    'Alice' : 'alice@gmail.com',
    'Bob' : 'bob@gmail.com',
    'Chris' : 'chris@gmail.com',
    'Daniel' : 'daniel@hotmail.com',
    'Evelyn' : 'evelyn@gmail.com',
    'Fred' : 'fred@gmail.com'
}
The allocation is made simply by first shuffling the names list then pairing the list with itself at lag 1. This way you can be sure no one will be allocated themselves. For example, assume the names list was shuffled [Evelyn, Chris, Fred, Bob, Alice, Daniel]. This list is labelled Santa. The Santa list is then shifted by 1 to create the receiver list, [Daniel, Evelyn, Chris, Fred, Bob, Alice]. Done.
# SELECT SECRET SANTAS
santa = list(np.random.choice(list(fam.keys()), len(list(fam.keys())), replace = False))
recvr = []
for k in range(-1, len(santa)-1):
    recvr.append(santa[k])
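The shuffle-and-shift idea can also be written with only the standard library. Here is a minimal sketch, assuming `random.sample` in place of `np.random.choice` and a list slice in place of the append loop. The function name `allocate` is illustrative.

```python
import random


def allocate(names, rng=random):
    """Shuffle the names, then give each Santa the person before them in the shuffled list."""
    names = list(names)
    if len(names) < 2:
        raise ValueError("Secret Santa needs at least two people")
    santa = rng.sample(names, len(names))   # random order, no repeats
    recvr = santa[-1:] + santa[:-1]         # rotate right by one (lag 1)
    return list(zip(santa, recvr))


fam = ["Alice", "Bob", "Chris", "Daniel", "Evelyn", "Fred"]
for giver, receiver in allocate(fam, random.Random(2018)):
    print(giver, "->", receiver)
```

Shifting a shuffled list by one always forms a single loop through everyone. Nobody can draw themselves, but pairs who buy for each other can never come up either. For a family gift exchange that is arguably a feature.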
Once the allocation is done, the emails are sent to each Santa using the smtplib module from Python’s standard library. It allows you to send emails from your designated account directly from Python. It is best done using a Gmail account (pretty sure Hotmail is fine but not sure about others).
The connection is first established, then Python will ask for your email and account password. These can be hard coded for ease, but that’s up to you if you want to take on the extra risk.
Once Python has logged in it will send an email to each Santa individually. This is where you can get creative with your message in the email. Write a poem or send a picture or gif, just make sure you input who they are buying for!
import getpass
import smtplib

# STARTING EMAILS
smtpObj = smtplib.SMTP('smtp.gmail.com', 587)
smtpObj.ehlo()
smtpObj.starttls()
eml = input('Enter your email: ')
pwd = getpass.getpass('Enter login details: ')   # not echoed to the screen
smtpObj.login(eml, pwd)

# SENDING EMAILS
# the blank line after the Subject header separates the headers from the body
for k in range(len(fam)):
    smtpObj.sendmail(eml, fam[santa[k]],
                     'Subject: Secret Santa 2018\n\n'
                     'Ho Ho Ho %s!\n\n'
                     'Christmas is almost here,\n'
                     'Time to find that Christmas cheer!\n\n'
                     'This year you are buying for....... %s!\n\n'
                     'This was sent on %s.' % (santa[k], recvr[k], date))

# QUIT SERVER
smtpObj.quit()
Before any emails can be sent from your account, you’ll need to allow less secure apps to access it. To do this, go to your Google account, select ‘Sign-in & Security’, scroll to the bottom and toggle ‘allow less secure apps’ to ON. You’ll get an email straight away from Google saying your account is not secure, which is fair enough, so it’s best to do this just before you send the emails and switch it back off afterwards. This can be done with other email providers, but I’m only familiar with Googs. (Google has since retired the ‘less secure apps’ setting. An app password, available once 2-Step Verification is turned on, can be used instead.)
Finally, it is good practice to save the master list in case it is needed.
# STORE SELECTIONS FOR SAFE KEEPING
santadf = pd.DataFrame({'santa':santa, 'recvr':recvr})
santadf.to_csv('Secret Santa list ' + datem + '.csv')
And that’s it, Secret Santa allocations made. This is a simple template for sending a list but should be easy to add in constraints e.g. spouses don’t get each other, get someone different from last year, etc.
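One simple way to add such constraints is rejection sampling: keep re-shuffling until no forbidden (Santa, receiver) pair appears. Here is a minimal sketch of that approach. The function name and the example constraints (a couple, and last year’s pairing) are hypothetical. For couples, both directions are listed.

```python
import random


def allocate_with_constraints(names, forbidden, rng=random, max_tries=10_000):
    """Rotation-based allocation, re-shuffled until no forbidden (santa, receiver) pair appears.

    forbidden: set of (santa, receiver) tuples; list both directions for couples.
    """
    names = list(names)
    for _ in range(max_tries):
        santa = rng.sample(names, len(names))
        recvr = santa[-1:] + santa[:-1]
        pairs = list(zip(santa, recvr))
        if not any(pair in forbidden for pair in pairs):
            return pairs
    raise RuntimeError("no valid allocation found; the constraints may be too strict")


fam = ["Alice", "Bob", "Chris", "Daniel", "Evelyn", "Fred"]
# Hypothetical constraints: Alice and Bob are a couple; Chris had Daniel last year.
forbidden = {("Alice", "Bob"), ("Bob", "Alice"), ("Chris", "Daniel")}
print(allocate_with_constraints(fam, forbidden, random.Random(7)))
```

With a handful of constraints on a family-sized list, a valid allocation usually turns up within a few tries. The `max_tries` cap only matters if the constraints can’t be satisfied at all.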
The only other thing to remember is you will have a number of sent emails from your account which will have the Santa and receiver so you could accidentally find out who your Santa is. Also, if a family member replies to the email it could also reveal your Santa. Maybe it would be wise to look into do-not-reply tags!
# SECRET SANTA

# LIBRARIES
import datetime
import getpass
import smtplib

import numpy as np
import pandas as pd

# DATE AND TIME
dt = datetime.datetime.now()
date = dt.strftime('%Y-%m-%d at %I:%M%p')
datem = dt.strftime('%Y-%m-%d')

# EMAIL DICTIONARY
# replace the names and emails with those participating
# example emails only
fam = {
    'Alice' : 'alice@gmail.com',
    'Bob' : 'bob@gmail.com',
    'Chris' : 'chris@gmail.com',
    'Daniel' : 'daniel@hotmail.com',
    'Evelyn' : 'evelyn@gmail.com',
    'Fred' : 'fred@gmail.com'
}

# SELECT SECRET SANTAS
santa = list(np.random.choice(list(fam.keys()), len(fam), replace = False))
recvr = []
for k in range(-1, len(santa)-1):
    recvr.append(santa[k])

# STARTING EMAILS
smtpObj = smtplib.SMTP('smtp.gmail.com', 587)
smtpObj.ehlo()
smtpObj.starttls()
eml = input('Enter your email: ')
pwd = getpass.getpass('Enter login details: ')   # not echoed to the screen
smtpObj.login(eml, pwd)

# SENDING EMAILS
# the blank line after the Subject header separates the headers from the body
for k in range(len(fam)):
    smtpObj.sendmail(eml, fam[santa[k]],
                     'Subject: Secret Santa 2018\n\n'
                     'Ho Ho Ho %s!\n\n'
                     'Christmas is almost here,\n'
                     'Time to find that Christmas cheer!\n\n'
                     'This year you are buying for....... %s!\n\n'
                     'This was sent on %s.' % (santa[k], recvr[k], date))

# QUIT SERVER
smtpObj.quit()

# STORE SELECTIONS FOR SAFE KEEPING
santadf = pd.DataFrame({'santa':santa, 'recvr':recvr})
santadf.to_csv('Secret Santa list ' + datem + '.csv')
The post Liar’s Dice in R appeared first on Daniel Oehm | Gradient Descending.
I have been playing Red Dead Redemption 2, immersing myself in the Old West as I did with the first game. It’s an incredibly impressive game, and there are many side activities to keep you entertained, such as playing Poker in the saloon, Five Finger Fillet and Dominoes. I was disappointed to find that, unlike the first game, RDR2 doesn’t include Liar’s Dice (at least, I haven’t found it yet). Instead, I decided to code it up in R and play some outlaws off against each other.
The rules of Liar’s Dice are relatively straightforward, but there are many variants. For this project I’ll begin with the simpler variants. The object of the game is to be the last person standing with at least one die. The game is played in the following steps:
Other rules:
A popular variant of the game, and the one I tend to play, is that 1’s are wild and are included in the bid quantity. On reveal, both the 1’s and the dice showing the bid value are counted. This makes the program a little more complicated, so for this post I won’t treat 1’s as wild.
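For reference, here is a small Python sketch of how the count on reveal would change with wild 1’s. The post’s own code doesn’t implement this. The function name is illustrative, and it assumes a bid on 1’s counts only the 1’s.

```python
from collections import Counter


def count_bid(all_dice, value, ones_wild=False):
    """Count how many dice match a bid value on reveal, optionally treating 1's as wild."""
    counts = Counter(all_dice)
    if ones_wild and value != 1:
        return counts[value] + counts[1]
    return counts[value]


table = [1, 3, 3, 5, 1, 6, 2, 3]
print(count_bid(table, 3))                  # 3
print(count_bid(table, 3, ones_wild=True))  # 5
```

With 1’s wild, each unseen die matches a bid on any other value with probability 2/6 = 1/3 rather than 1/6. Any probability-based agent would have to account for that.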
There are plenty more variants listed on the wiki page.
The first step is to code a function to play one full round of the game, from the initial bid to one player losing a die. It begins with each player rolling their dice, depending on how many they have left. The agent in control places an initial bid, setting the dice value and dice quantity. Control is then handed to the next player, who decides whether to call or raise the bid given the dice value, the dice quantity and the dice they rolled (more on this later). If they raise the bid, control moves to the next player, and so on until an agent calls and the dice are revealed. Whoever loses gets a penalty of -1 (i.e. loses 1 die). The function returns this penalty list.
This has been set up to run on autopilot for simulation or manually against bots for fun (point of debate, is playing manually or running the simulations more fun?). If manual, the program will ask the user for input. Code can be found on Github.
# play a round of liars dice
liars.dice.round(players, control, player.dice.count, agents, a = 1, verbose = 1)
The next step is to play as many rounds as needed until there is a winner. The inputs to this function are,
The function begins by initialising the number of dice for each player and storing in a list. Control is randomly given to a player (there is actually an advantage for the starting player given the number of players, dice and how the agents make decisions which isn’t very sophisticated mind you). Now the above function is run until there is only 1 player left. The output of this function is the number of the winning player.
# play the game of liars dice
play.liars.dice(players = 4, num.dice = 6, auto = FALSE, verbose = 1, agents)
Now that the game is set up we get to the fun part which is how the agents make decisions. For this first post the agents are going to be very basic but in later posts will become more sophisticated.
As in Poker, players try to infer what dice the others have rolled based on their bids (dice value and quantity) and body language. Arguably, picking up on the cues of someone’s body language is the most important part of the game for knowing whether they’re bluffing or have a good hand, but it is impossible to replicate here. Instead, the agent will make a decision based solely on the bid and the dice in its hand. How an agent makes a decision gets complicated quickly, so we’ll start small.
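The good thing about this game is that the probabilities are easy to calculate, so it’s easy to play the numbers game and choose the most likely scenario. The dice quantity follows a binomial distribution. The probability that there are at least $x$ dice with value $v$ among $n$ dice on the table is given by
$$P(X \ge x) = \sum_{k=x}^{n} \binom{n}{k} \left(\frac{1}{6}\right)^{k} \left(\frac{5}{6}\right)^{n-k}$$
For example, the probability there are at least five 3’s in a pool of 20 dice is
$$P(X \ge 5) = 1 - \sum_{k=0}^{4} \binom{20}{k} \left(\frac{1}{6}\right)^{k} \left(\frac{5}{6}\right)^{20-k} \approx 0.23$$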
Given its hand, the agent calculates this probability and randomly chooses whether to call or raise based on it. Take 25 dice in play among 5 players, a bid of seven 3’s, and a player holding two 3’s. The player is only really concerned with the probability there are at least another five 3’s in the 20 dice they can’t see. In this case there is a 77% chance the agent will ‘call’, since it is more likely there are fewer than five 3’s. We could instead set a threshold, for example calling whenever the probability is below 0.2. However, using the probability itself to choose between ‘call’ and ‘raise’ works well and adds some realism.
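The worked example can be checked in a few lines of Python (the post’s own code is R). This sketch computes $P(X \ge x)$ as one minus the lower tail, exactly as in the equation above. The function name is illustrative.

```python
from math import comb


def prob_at_least(x, n, p=1/6):
    """P(X >= x) for X ~ Binomial(n, p), computed as 1 - P(X <= x - 1)."""
    return 1 - sum(comb(n, k) * p**k * (1 - p)**(n - k) for k in range(x))


# 25 dice in play; the player holds 5 dice, two of them 3's; the bid is seven 3's.
needed = 7 - 2    # 3's still needed from the dice the player can't see
unseen = 25 - 5   # dice the player can't see
p_raise = prob_at_least(needed, unseen)

print(round(p_raise, 3), round(1 - p_raise, 3))  # 0.231 0.769
```

The agent raises with probability of about 0.23 and calls with probability of about 0.77, matching the 77% above.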
If the agent chooses to raise, it then randomly decides whether to bluff. The bluffing probability can be set for each agent; for example, an agent may bluff 100% of the time, 50% of the time or not at all. When bluffing, the agent randomly selects a die value and raises the bid quantity by 1, disregarding the dice they have rolled. If the agent chooses not to bluff, they select the value they have the most of in their hand and raise the quantity by 1.
Ideally the agent would use the information from the previous players’ bids to better decide the next bid. For example, a player that always bids on the same dice value is probably selecting the value they have most of in their hand, which in turn sheds light on the other dice on the table. At the end of each round, when the dice are revealed, each player (if they have a good memory) can determine who was bluffing and who was playing the numbers. This information can be used as prior knowledge and incorporated into the probability calculations.
Assume we know a player never bluffs. Then we know for sure that there is at least 1 die of the bid value in their hand, so the probability that there are at least the bid quantity will be higher. For example, in the case above, since 1 of the unseen dice is confirmed, the problem becomes: what is the probability there are at least four 3’s in 19 dice?
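A quick way to see how much that extra knowledge matters is to compute both probabilities for the running example: a bid of seven 3’s, 25 dice in play, and two 3’s in our own 5-die hand. The sketch uses only base R’s pbinom and assumes each die independently shows a 3 with probability 1/6.

```r
# Baseline: need at least five more 3's among the 20 unseen dice
p_baseline <- 1 - pbinom(4, size = 20, prob = 1/6)

# Bidder known never to bluff: one unseen die is assumed to be a 3,
# so we need at least four more 3's among the remaining 19 dice
p_known_honest <- 1 - pbinom(3, size = 19, prob = 1/6)

round(c(baseline = p_baseline, honest_bidder = p_known_honest), 3)
```

The probability the bid is true rises from roughly 0.23 to roughly 0.39, so knowing the bidder’s style can flip the agent’s decision.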
If a player always bluffs, then the bid is effectively random. This addition to the probability calculation isn’t too difficult but for now we’ll keep it simple. In the future we could allow the agents to learn the bluffing parameters of each player and refine their decisions.
To summarise, if a player chooses to bluff, the dice value is chosen at random and the quantity is raised by 1. If they choose not to bluff, they choose the value they have the most of in their hand and raise the quantity by 1.
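Here’s a minimal sketch of that bidding rule, assuming a bid is stored as a quantity and a face value. The function next_bid and its arguments are hypothetical names for illustration, not the author’s build.agent or liars.dice.round code.

```r
# Hypothetical sketch of how an agent raises a bid.
# hand: integer vector of the agent's dice; bid: c(quantity, value);
# bluff_prob: probability the agent ignores its hand when choosing the value.
next_bid <- function(hand, bid, bluff_prob) {
  if (runif(1) < bluff_prob) {
    value <- sample(1:6, 1)                  # bluff: any face at random
  } else {
    counts <- tabulate(hand, nbins = 6)      # how many of each face we hold
    value <- which.max(counts)               # most common face (ties go to the lowest)
  }
  c(quantity = bid[[1]] + 1, value = value)  # the quantity always goes up by 1
}

set.seed(42)
next_bid(c(3, 3, 5, 1, 3), bid = c(4, 5), bluff_prob = 0)   # honest: four 5's -> five 3's
```

Setting bluff_prob to 1 or 0 gives the “always bluff” and “play the numbers” agents used in the simulations.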
Perhaps the most important part of this project is the aesthetics: a title and randomised dice blocks. Rather than just outputting numbers between 1 and 6, this function will build a blank dice block and then convert it into the value of the dice that is rolled. It’s best to see it in action.
liars.dice.title()
##
##
## __ _______ __ ____ _____
## / /| /__ __/| / | / _ \ / __/|
## / / / |_/ /|_|/ / | / /_| | \ \__|/
## / / / / / / / /| | / _ /| \ \
## / /_/_ __ / /_/ / __ | / / | |/ __ / /|
## /_____/| /_______/| /_/|_|_| /_/ /|_| /____/ /
## |_____|/ |_______|/ |_|/ |_| |_|/ |_| |____|/
##
## ____ _______ _____ ______
## / _ \ /__ __/| / ___/| / ____/|
## / / | | |_/ /|_|/ / /|__|/ / /___ |/
## / / / /| / / / / / / / ____/|
## / /_/ / / __ / /_/ / /_/_ / /____|/
## /_____/ / /_______/| |____/| /______/|
## |_____|/ |_______|/ |____|/ |______|/
Dice(sample(1:6, 5, replace = TRUE))
## _________ _________ _________ _________ _________
## / /| / /| / /| / /| / /|
## /________/ | /________/ | /________/ | /________/ | /________/ |
## | o | | | o | | | o | | | o | | | o | |
## | | | | | | | o | | | o | | | | |
## | o | / | o | / | o | / | o | / | o | /
## |________|/ |________|/ |________|/ |________|/ |________|/
Much better than numbers!
To play a game of liars dice simply input the parameters into the following function.
# set the agents
# even if playing a manual game, input the same number of agents as there are players;
# the human player will override one of them
agent1 <- build.agent(c(0.5,0.5))
agent2 <- build.agent(c(0.5,0.5))
agent3 <- build.agent(c(0.5,0.5))
agent4 <- build.agent(c(0.5,0.5))
agents <- list(agent1, agent2, agent3, agent4)

# play the game
play.liars.dice(auto = FALSE, players = 4, num.dice = 6, verbose = 1, agents = agents)
The game starts with player 1 in control, who sets the bid at four 5’s. Player 2 calls, the dice are revealed and, since there are four 5’s, player 2 loses a die. Player 2 now has control and starts the bidding.
Player 2 bids three 3’s. Player 3 raises the bid to four 5’s. Player 4 calls and player 3 loses a die, since there are only three 5’s on the table. This continues for a few more rounds until there is a winner.
And there you have it: you can now play Liar’s Dice just like in Red Dead Redemption, only with worse graphics. The agent could definitely use some more brains, but to be fair it’s still probably better than the NPCs in Red Dead Redemption. At some point I may turn this into a shiny app just for fun.
As mentioned, for this early version bluffing essentially means playing randomly. For validation we can simulate many games and check that the numbers strategy defeats the random strategy. The game will be simulated 10,000 times with 2 agents, where one bluffs all the time and the other plays the numbers. We expect the “numbers guy” to win more than half the games, even if only slightly.
# 2 agents, exactly the same except that
# one bluffs all the time and the other plays the numbers
agent1 <- build.agent(c(1,0))
agent2 <- build.agent(c(0,1))
agents <- list(agent1, agent2)

# parallelise compute
library(parallel)
strt <- Sys.time()
n.cores <- detectCores()
clust <- makeCluster(n.cores)
# seed the workers' RNG streams (set.seed() on the master does not reach the workers)
clusterSetRNGStream(clust, iseed = 20181120)
clusterExport(clust, c("play.liars.dice", "liars.dice.round", "agents", "set.dice.value", "liars.dice.title"))
a <- parSapply(clust, 1:1e4, function(x) play.liars.dice(verbose = 0, auto = TRUE, players = 2, num.dice = 6, agents = agents))
stopCluster(clust)
end <- Sys.time()
end-strt
## Time difference of 37.1997 secs
# win results
table(a)
## a
##    1    2
## 4934 5066
# distribution of agent 2's win rate: Beta(wins, losses) from the 10,000 games
library(ggplot2)
wins <- table(a)[2]
ggplot(data.frame(z = rbeta(1e5, wins, 1e4-wins)), aes(x = z)) +
  geom_histogram(fill = "darkturquoise", col = "grey20")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
To be fair, for any given game both agents are almost equally likely to win. That’s not too surprising given they aren’t very sophisticated and the game has a high degree of randomness. However, playing the numbers will win just over 50% of the matches in the long run (that’s a long long run though!).
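To put a number on “just over 50%”, we can summarise the same Beta(wins, losses) distribution used for the histogram and run an exact binomial test as a frequentist cross-check. The sketch hard-codes agent 2’s 5066 wins from the table above and uses only base R (pbeta and binom.test).

```r
# Wins for agent 2 (the "numbers" player) out of 10,000 simulated games
wins  <- 5066
games <- 10000

# Probability that agent 2's long-run win rate exceeds 50%, under the same
# Beta(wins, games - wins) distribution as the histogram
p_better <- 1 - pbeta(0.5, wins, games - wins)

# Frequentist check: exact binomial test against a fair 50/50 split
bt <- binom.test(wins, games, p = 0.5)

round(c(posterior_p_better = p_better, binom_p_value = bt$p.value), 3)
```

This gives roughly a 0.9 probability that the numbers player is better. The binomial p-value is around 0.19, so the data lean towards the numbers strategy, but 10,000 games are not enough to rule out a 50/50 split at conventional significance levels.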
The simulation will be run again for 4 agents with different probabilities of bluffing.
agent1 <- build.agent(c(1, 0))
agent2 <- build.agent(c(0.75, 0.25))
agent3 <- build.agent(c(0.25, 0.75))
agent4 <- build.agent(c(0, 1))
agents <- list(agent1, agent2, agent3, agent4)

strt <- Sys.time()
clust <- makeCluster(n.cores)
# seed the workers' RNG streams (set.seed() on the master does not reach the workers)
clusterSetRNGStream(clust, iseed = 20181120)
clusterExport(clust, c("play.liars.dice", "liars.dice.round", "agents", "set.dice.value", "liars.dice.title"))
a <- parSapply(clust, 1:1e4, function(x) play.liars.dice(verbose = 0, auto = TRUE, players = 4, num.dice = 6, agents = agents))
stopCluster(clust)
end <- Sys.time()
end-strt
## Time difference of 1.164514 mins
# win results
table(a)
## a
##    1    2    3    4
## 2379 2465 2540 2616
It’s nice to see the numbers in general going up as the bluffing probability goes to 0.
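A simple check that these differences aren’t just noise is a chi-squared goodness-of-fit test against all four agents being equally likely to win. This is a sketch using base R’s chisq.test on the counts from the table above. It only tests whether the win rates differ, not whether they are ordered by bluffing probability.

```r
# Win counts for agents 1-4 (bluff probability 1, 0.75, 0.25, 0)
wins <- c(agent1 = 2379, agent2 = 2465, agent3 = 2540, agent4 = 2616)

# Goodness-of-fit test against a uniform 25% win rate for each agent
ct <- chisq.test(wins, p = rep(1/4, 4))
ct$statistic   # about 12.37 on 3 degrees of freedom
ct$p.value
```

The p-value comes out well below 0.01, so the spread in wins across the four agents is unlikely to be chance alone.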
This project was more about simulating the game than building a high-performing agent. But now that we have this, we can start to give the agent more brains and play around with different learning methods. Given the random and probabilistic nature of the game, it becomes an interesting and challenging RL problem. In a game like Tic Tac Toe all states of the game are known by the player, but in Liar’s Dice the opponents’ hands are unknown, so the player doesn’t know for sure which state they are in. The challenge is to reduce the problem down to something more manageable. The results we have seen above are essentially the baseline that we can now improve on with more advanced methods.
The post Liar’s Dice in R appeared first on Daniel Oehm | Gradient Descending.
]]>Recently I developed a solution using a Hidden Markov Model and was quickly asked to explain myself. What are they […]
The post Hidden Markov Model example in R with the depmixS4 package appeared first on Daniel Oehm | Gradient Descending.
]]>Recently I developed a solution using a Hidden Markov Model and was quickly asked to explain myself. What are they and why do they work so well? I can answer the first part, the second we just have to take for granted.
HMM’s are for modelling sequences of data whether they are derived from continuous or discrete probability distributions. They are related to state space and Gaussian mixture models in the sense they aim to estimate the state which gave rise to the observation. The states are unknown or ‘hidden’ and HMM’s attempt to estimate the states similar to an unsupervised clustering procedure.
Before getting into the basic theory behind HMM’s, here’s a (silly) toy example which will help to understand the core concepts. There are 2 dice and a jar of jelly beans. Bob rolls the dice; if the total is greater than 4 he takes a handful of jelly beans and rolls again. If the total is 4 or less he takes a handful of jelly beans, then hands the dice to Alice. It’s now Alice’s turn to roll the dice. If she rolls greater than 4 she takes a handful of jelly beans; however, she isn’t a fan of any colour other than the black ones (a polarizing opinion), so she puts the others back. Therefore we would expect Bob to take more than Alice. They do this until the jar is empty.
Now assume Alice and Bob are in a different room and we can’t see who is rolling the dice. Instead we only know how many jelly beans were taken after the roll. We don’t know the colour, simply the final number of jelly beans that were removed from the jar on that turn. How could we know who rolled the dice? HMM’s.
In this example the state is the person who rolled the dice, Alice or Bob. The observation is how many jelly beans were removed on that turn. The roll of the dice, together with the rule that the dice are passed on when the total is 4 or less, gives the transition probability. Since we made up this example we can calculate the transition probability exactly: 6 of the 36 equally likely totals are 4 or less, so the probability is 6/36 = 1/6. There is no condition saying the transition probabilities need to be the same; for example, Bob could hand the dice over only when he rolls a 2, meaning a probability of 1/36.
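The transition probabilities can be checked by enumerating all 36 outcomes of two dice. The rule below follows the simulate() function further down, which passes the dice when the total is below switch.val + 1 (4 or less with switch.val = 4). The matrix layout is just one convenient way to write it down.

```r
# All 36 equally likely totals from rolling two six-sided dice
totals <- outer(1:6, 1:6, "+")

# The dice change hands when the total is 4 or less
p_switch <- mean(totals <= 4)     # 6/36 = 1/6
# Alternative rule: pass only on a total of 2
p_switch_2 <- mean(totals == 2)   # 1/36

# Transition matrix for the hidden state (rows: current roller, cols: next roller)
trans <- matrix(c(1 - p_switch, p_switch,
                  p_switch,     1 - p_switch),
                nrow = 2, byrow = TRUE,
                dimnames = list(c("alice", "bob"), c("alice", "bob")))
trans
```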
Firstly, we’ll simulate the example. On average Bob takes 12 jelly beans and Alice takes 4.
# libraries
library(depmixS4)
library(ggplot2)
library(gridExtra)
library(reshape2)
library(magrittr)  # provides %>%, used in the plotting function further down

# the setup
# functions
simulate <- function(N, dice.val = 6, jbns, switch.val = 4){
  # simulate variables
  # could just use one dice sample but having both alice and bob makes it simple to try
  # different mechanics e.g. bob only throws 1 die, or whatever other probability distribution
  # you want to set.
  bob.dice <- sample(1:dice.val, N, replace = TRUE) + sample(1:dice.val, N, replace = TRUE)
  alice.dice <- sample(1:dice.val, N, replace = TRUE) + sample(1:dice.val, N, replace = TRUE)
  bob.jbns <- rpois(N, jbns[1])
  alice.jbns <- rpois(N, jbns[2])

  # states
  draws <- data.frame(state = rep(NA, N), obs = rep(NA, N), dice = rep(NA, N))
  draws$state[1] <- "alice"
  draws$obs[1] <- alice.jbns[1]
  draws$dice[1] <- alice.dice[1]
  for(k in 2:N){
    if(draws$state[k-1] == "alice"){
      if(draws$dice[k-1] < switch.val+1){
        draws$state[k] <- "bob"
        draws$obs[k] <- bob.jbns[k]
        draws$dice[k] <- bob.dice[k]
      }else{
        draws$state[k] <- "alice"
        draws$obs[k] <- alice.jbns[k]
        draws$dice[k] <- alice.dice[k]
      }
    }else if(draws$state[k-1] == "bob"){
      if(draws$dice[k-1] < switch.val+1){
        draws$state[k] <- "alice"
        draws$obs[k] <- alice.jbns[k]
        draws$dice[k] <- alice.dice[k]
      }else{
        draws$state[k] <- "bob"
        draws$obs[k] <- bob.jbns[k]
        draws$dice[k] <- bob.dice[k]
      }
    }
  }

  # return
  return(cbind(roll = 1:N, draws))
}

# simulate scenario
set.seed(20181031)
N <- 100
draws <- simulate(N, jbns = c(12, 4), switch.val = 4)

# observe results
mycols <- c("darkmagenta", "turquoise")
cols <- ifelse(draws$state == "alice", mycols[1], mycols[2])
ggplot(draws, aes(x = roll, y = obs)) + geom_line()
As you can see, it’s difficult to determine who rolled the dice simply by inspecting the series of counts. Using the depmixS4 package we’ll fit a HMM. Since we are dealing with count data, the observations are assumed to be drawn from a Poisson distribution.
fit.hmm <- function(draws){
  # HMM with depmix
  mod <- depmix(obs ~ 1, data = draws, nstates = 2, family = poisson()) # use gaussian() for normally distributed data
  fit.mod <- fit(mod)

  # predict the states by estimating the posterior
  est.states <- posterior(fit.mod)

  # results: map each estimated state to the actual label it overlaps most
  tbl <- table(est.states$state, draws$state)
  draws$est.state.labels <- c(colnames(tbl)[which.max(tbl[1,])], colnames(tbl)[which.max(tbl[2,])])[est.states$state]
  est.states$roll <- seq_len(nrow(draws))
  colnames(est.states)[2:3] <- c(colnames(tbl)[which.max(tbl[1,])], colnames(tbl)[which.max(tbl[2,])])
  hmm.post.df <- melt(est.states, measure.vars = c("alice", "bob"))

  # print the table
  print(table(draws[,c("state", "est.state.labels")]))

  # return it
  return(list(draws = draws, hmm.post.df = hmm.post.df))
}
hmm1 <- fit.hmm(draws)
## iteration 0 logLik: -346.2084
## iteration 5 logLik: -274.2033
## converged at iteration 7 with logLik: -274.2033
##        est.state.labels
## state   alice bob
##   alice    49   2
##   bob       3  46
The model converges quickly. Using the posterior probabilities we estimate which state the process is in, i.e. who has the dice, Alice or Bob. To answer that question specifically we need to know more about the process. In this case we do: we know Alice only likes the black jelly beans. Otherwise we can only say the process is in state 1 or 2 (or however many states you believe there are). The plots below show how well the HMM fits the data and estimates the hidden states.
# plot output
plot.hmm.output <- function(model.output){
  g0 <- (ggplot(model.output$draws, aes(x = roll, y = obs)) + geom_line() +
    theme(axis.ticks = element_blank(), axis.title.y = element_blank())) %>% ggplotGrob
  g1 <- (ggplot(model.output$draws, aes(x = roll, y = state, fill = state, col = state)) +
    geom_bar(stat = "identity", alpha = I(0.7)) +
    scale_fill_manual(values = mycols, name = "State:\nPerson that\nrolled the\ndice", labels = c("Alice", "Bob")) +
    scale_color_manual(values = mycols, name = "State:\nPerson that\nrolled the\ndice", labels = c("Alice", "Bob")) +
    theme(axis.ticks = element_blank(), axis.text.y = element_blank()) +
    labs(y = "Actual State")) %>% ggplotGrob
  g2 <- (ggplot(model.output$draws, aes(x = roll, y = est.state.labels, fill = est.state.labels, col = est.state.labels)) +
    geom_bar(stat = "identity", alpha = I(0.7)) +
    scale_fill_manual(values = mycols, name = "State:\nPerson that\nrolled the\ndice", labels = c("Alice", "Bob")) +
    scale_color_manual(values = mycols, name = "State:\nPerson that\nrolled the\ndice", labels = c("Alice", "Bob")) +
    theme(axis.ticks = element_blank(), axis.text.y = element_blank()) +
    labs(y = "Estimated State")) %>% ggplotGrob
  g3 <- (ggplot(model.output$hmm.post.df, aes(x = roll, y = value, col = variable)) + geom_line() +
    scale_color_manual(values = mycols, name = "State:\nPerson that\nrolled the\ndice", labels = c("Alice", "Bob")) +
    theme(axis.ticks = element_blank(), axis.text.y = element_blank()) +
    labs(y = "Posterior Prob.")) %>% ggplotGrob()
  g0$widths <- g1$widths
  return(grid.arrange(g0, g1, g2, g3, widths = 1, nrow = 4))
}
plot.hmm.output(hmm1)
It’s impressive how well the model fits the data and filters out the noise to estimate the states. To be fair, the states could be estimated by ignoring the time component and using the EM algorithm to fit a two-component Poisson mixture. However, because we know the data forms a sequence, there is more information at our disposal, since the probability of observing the next draw is conditional on the previous one, i.e. $P(X_t \mid X_{t-1})$, where $X_t$ is the number of jelly beans removed on roll $t$.
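This may have been a relatively easy case given we constructed the problem. What if the transition probabilities were much greater? With switch.val = 7 the dice change hands whenever the total is 7 or less, which happens with probability 21/36.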
draws <- simulate(100, jbns = c(12, 4), switch.val = 7)
hmm2 <- fit.hmm(draws)
## iteration 0 logLik: -354.2707
## iteration 5 logLik: -282.4679
## iteration 10 logLik: -282.3879
## iteration 15 logLik: -282.3764
## iteration 20 logLik: -282.3748
## iteration 25 logLik: -282.3745
## converged at iteration 30 with logLik: -282.3745
##        est.state.labels
## state   alice bob
##   alice    54   2
##   bob       5  39
plot.hmm.output(hmm2)
The data is much noisier, but the HMM still does a great job. The performance is partly due to our choice of means for the number of jelly beans removed from the jar: the more distinct the distributions are, the easier it is for the model to pick up the transitions. To be fair, we could calculate the median and take all observations below it to be one state and all those above to be the other, which, as you can see from the plot, would do quite well. This is because the transition probabilities are very high, so we would expect to observe a similar number of observations from each state. When the transition probabilities are not the same we see the HMM perform better.
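Here is a rough sketch of that median-split baseline. It is a self-contained re-simulation that does not reuse simulate(): the hidden roller is modelled directly as a two-state Markov chain that switches with probability 21/36 (the switch.val = 7 case), with Poisson means of 12 for Bob and 4 for Alice. The sample size of 1000 and the seed are arbitrary choices for illustration.

```r
set.seed(2018)
n <- 1000
# switch.val = 7: the dice change hands when the two-dice total is 7 or less
p_switch <- mean(outer(1:6, 1:6, "+") <= 7)   # 21/36

# Hidden state: who holds the dice, as a two-state Markov chain
state <- character(n)
state[1] <- "alice"
for (t in 2:n) {
  if (runif(1) < p_switch) {
    state[t] <- if (state[t - 1] == "alice") "bob" else "alice"
  } else {
    state[t] <- state[t - 1]
  }
}

# Observations: Poisson counts with mean 12 for Bob and 4 for Alice
obs <- rpois(n, ifelse(state == "bob", 12, 4))

# Median-split baseline: ignore the ordering, label counts above the median as Bob
guess <- ifelse(obs > median(obs), "bob", "alice")
mean(guess == state)   # proportion of rolls labelled correctly
```

With means this far apart the split labels most rolls correctly without using the sequence at all, which is why this case flatters any method.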
What if the observations are drawn from the same distribution i.e. Alice and Bob take the same amount of jelly beans?
draws <- simulate(100, jbns = c(12, 12), switch.val = 4)
hmm3 <- fit.hmm(draws)
plot.hmm.output(hmm3)
Not so great, but that’s to be expected. If there is no difference between the distributions from which the observations are drawn then there may as well be only 1 state. Feel free to play around with different values to see their impact.
Firstly, the number of states and how they are distributed are inherently unknown. With knowledge of the system being modelled, a reasonable number of states can be chosen by the user. In our example we knew there were two states, making things easier. It’s possible to know the exact number of states, but it is uncommon. It is often reasonable to assume the observations are normally distributed, again through knowledge of the system.
From here the Baum-Welch algorithm, a variant of the EM algorithm which leverages the sequence of observations and the Markov property, is applied to estimate the parameters. In addition to estimating the parameters of the states, it also needs to estimate the transition probabilities. The Baum-Welch algorithm first makes a forward pass over the data, followed by a backward pass. The state parameters and transition probabilities are then updated. This process is repeated until convergence. See the link for an example walkthrough.
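The forward pass is the piece that makes the sequence matter. Below is a minimal sketch of it for a two-state Poisson HMM whose parameters are already known. It computes the sequence log-likelihood, the same kind of quantity the logLik lines in the depmixS4 output track during fitting. This is an illustration written from the standard recursion, not depmixS4’s implementation, and the parameter values are assumptions chosen to match the jelly-bean setup.

```r
# Forward pass for a 2-state Poisson HMM, rescaled at each step to avoid underflow.
# obs: vector of counts; lambda: Poisson mean for each state;
# trans: 2 x 2 transition matrix (rows sum to 1); init: initial state probabilities.
forward_loglik <- function(obs, lambda, trans, init) {
  # alpha[j] is proportional to P(obs[1..t], state_t = j)
  alpha <- init * dpois(obs[1], lambda)
  loglik <- log(sum(alpha))
  alpha <- alpha / sum(alpha)
  for (t in seq_along(obs)[-1]) {
    # move through the transition matrix, then weight by the emission probability
    alpha <- as.vector(alpha %*% trans) * dpois(obs[t], lambda)
    loglik <- loglik + log(sum(alpha))   # accumulate the log scaling constants
    alpha <- alpha / sum(alpha)
  }
  loglik
}

# Assumed parameters: Alice (mean 4), Bob (mean 12), dice change hands w.p. 1/6
trans <- matrix(c(5/6, 1/6,
                  1/6, 5/6), nrow = 2, byrow = TRUE)
forward_loglik(c(3, 5, 2, 14, 11, 13), lambda = c(4, 12), trans = trans, init = c(0.5, 0.5))
```

Baum-Welch pairs this with a matching backward pass to get each state’s posterior probability at every step, then re-estimates lambda and the transition matrix from those posteriors.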
In the real world it’s unlikely you’ll ever be predicting who took the jelly beans from the jelly bean jar. Hopefully you are working on more interesting problems, however this example breaks it down into understandable components. Often HMM’s are used for
to name a few. Whenever there is a sequence of observations HMM’s can be used which also holds true for discrete cases.
]]>Bayesian Networks are probabilistic graphical models and they have some neat features which make them very useful for many problems. […]
The post Bayesian Network Example with the bnlearn Package appeared first on Daniel Oehm | Gradient Descending.
]]>Bayesian Networks are probabilistic graphical models and they have some neat features which make them very useful for many problems. They are structured in a way which allows you to calculate the conditional probability of an event given the evidence. The graphical representation makes it easy to understand the relationships between the variables, and they are used in many AI solutions where decisions need to be automated in a range of contexts such as medical diagnosis, risk modelling and mitigation. Bayesian networks are great where there is a complex system of many causal relationships.
Some key benefits of Bayesian Networks include:
In this post I’ll build a Bayesian Network with the AIS dataset found in the DAAG package. This dataset was used to determine if there was a difference in mean hemoglobin levels for different sport disciplines. To begin with we’ll quickly look at a box plot comparing the distribution of hemoglobin levels for the different sports just to get a feel for the data.
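library(DAAG)     # provides the ais data set
library(ggplot2)

# king.yna is a custom colour palette that isn't defined in this post;
# any vector of colours works, e.g. king.yna <- c("darkmagenta", "turquoise")
ggplot(ais, aes(x = sport, y = hg, fill = sport)) +
  geom_boxplot() +
  scale_fill_manual(values = colorRampPalette(king.yna)(10))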
The box plots would suggest there are some differences. We can use this to direct our Bayesian Network construction.
We’ll start off by building a simple network using 3 variables: hematocrit (hc), which is the volume percentage of red blood cells in the blood; sport; and hemoglobin concentration (hg). Hematocrit and hemoglobin measurements are continuous variables. For simplicity of the first example, these will be transformed to binary variables and we’ll subset the data to only 3 sports: netball, tennis and water polo. These sports were chosen since there is a clear difference between their hemoglobin levels, as shown by the box plots above. An empty graph will be created, followed by inputting the structure manually.
library(bnlearn)
library(visNetwork)

# set boolean variables
ais$high_hc <- as.factor(ais$hc > median(ais$hc))
ais$high_hg <- as.factor(ais$hg > median(ais$hg))

# create an empty graph
structure <- empty.graph(c("high_hc", "high_hg", "sport"))

# set relationships manually
modelstring(structure) <- "[high_hc][sport][high_hg|sport:high_hc]"

# plot network func
# using the visNetwork package to plot the network because it looks very nice.
plot.network <- function(structure, ht = "400px"){
  nodes.uniq <- unique(c(structure$arcs[,1], structure$arcs[,2]))
  nodes <- data.frame(id = nodes.uniq, label = nodes.uniq, color = "darkturquoise", shadow = TRUE)
  edges <- data.frame(from = structure$arcs[,1], to = structure$arcs[,2], arrows = "to", smooth = TRUE, shadow = TRUE, color = "black")
  return(visNetwork(nodes, edges, height = ht, width = "100%"))
}

# observe structure
plot.network(structure)
There are algorithms to learn the structure from data, which I’ll briefly go into later, but for now we’ll explicitly state the relationships. Manually creating the structure is often a good way to go, since you are required to understand the system you are trying to model and are not relying on a black box to do it for you. Having said that, once constructed, whether manually or using an algorithm, the Bayesian Network is easily understood through the graphical representation and each variable can be explained.
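The relationships in this network say that whether an athlete has high hemoglobin depends on their sport and on whether they have a high hematocrit, while sport and hematocrit have no parents and are independent of each other.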
We’ll now fit the model and output the conditional probabilities for each node.
ais.sub <- ais[ais$sport %in% c("Netball", "Tennis", "W_Polo"), c("high_hc", "high_hg", "sport")]
ais.sub$sport <- factor(ais.sub$sport)
bn.mod <- bn.fit(structure, data = ais.sub)
bn.mod
##
##   Bayesian network parameters
##
##   Parameters of node high_hc (multinomial distribution)
##
## Conditional probability table:
##      FALSE      TRUE
## 0.6078431 0.3921569
##
##   Parameters of node high_hg (multinomial distribution)
##
## Conditional probability table:
##
## , , sport = Netball
##
##        high_hc
## high_hg     FALSE      TRUE
##   FALSE 1.0000000
##   TRUE  0.0000000
##
## , , sport = Tennis
##
##        high_hc
## high_hg     FALSE      TRUE
##   FALSE 0.8571429 0.2500000
##   TRUE  0.1428571 0.7500000
##
## , , sport = W_Polo
##
##        high_hc
## high_hg     FALSE      TRUE
##   FALSE 1.0000000 0.0625000
##   TRUE  0.0000000 0.9375000
##
##
##   Parameters of node sport (multinomial distribution)
##
## Conditional probability table:
##   Netball    Tennis    W_Polo
## 0.4509804 0.2156863 0.3333333
cat("P(high hemoglobin levels) =", cpquery(bn.mod, (high_hg=="TRUE"), TRUE), "\n")
## P(high hemoglobin levels) = 0.2136
cat("P(high hemoglobin levels | play water polo and have high hematocrit ratio) =", cpquery(bn.mod, (high_hg=="TRUE"), (sport == "W_Polo" & high_hc == "TRUE")), "\n")
## P(high hemoglobin levels | play water polo and have high hematocrit ratio) = 0.9399076
One of the main benefits of Bayes nets is that we can reverse the direction. Unlike a regression, where there are response and explanatory variables, a Bayes net is not ‘fixed’ (for lack of a better word) in the same way, and each node can be made the subject of the query for inference. With the same model we can query the probability that an athlete plays water polo given we observe their high hemoglobin levels, or the probability of having high hemoglobin levels given they play water polo.
cat("P(they play water polo | high hemoglobin levels and have high hematocrit ratio) =", cpquery(bn.mod, (sport=="W_Polo"), (high_hg == "TRUE" & high_hc == "TRUE")), "\n")
## P(they play water polo | high hemoglobin levels and have high hematocrit ratio) = 0.6351064
Let’s say that we didn’t know the athlete’s hematocrit ratio. Could we still calculate the probability they have high hemoglobin levels given they play water polo? Sure, we just sum over the hematocrit probability distribution. Fortunately the cpquery function takes care of this for us.
cat("P(high hemoglobin levels | play water polo) =", cpquery(bn.mod, (high_hg=="TRUE"), (sport == "W_Polo")), "\n")
## P(high hemoglobin levels | play water polo) = 0.3623018
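To see what that sum looks like, here is the same calculation by hand, using the conditional probability tables printed by bn.fit above. Because high_hc and sport are both root nodes, knowing the sport tells us nothing about hematocrit, so P(high_hc | W_Polo) = P(high_hc). cpquery estimates this by simulation, so its answer (about 0.36 above) will differ slightly from the exact value.

```r
# Conditional probability tables copied from the bn.fit output above
p_high_hc <- c("FALSE" = 0.6078431, "TRUE" = 0.3921569)        # P(high_hc)

# P(high_hg = TRUE | sport = W_Polo, high_hc), one entry per high_hc level
p_hg_true_wpolo <- c("FALSE" = 0.0000000, "TRUE" = 0.9375000)

# Sum over the unobserved high_hc to get P(high_hg = TRUE | W_Polo)
p_hg_given_wpolo <- sum(p_hg_true_wpolo[names(p_high_hc)] * p_high_hc)
p_hg_given_wpolo   # about 0.368
```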
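Let’s redefine our simple network with the actual continuous variables. Again, bnlearn handles the hard work. For the continuous case the probability densities are estimated: each continuous node is modelled as a Gaussian whose mean is a linear function of its continuous parents, with a separate regression for each configuration of its discrete parents.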
# create an empty graph
structure <- empty.graph(c("hc", "hg", "sport"))

# set relationships manually
modelstring(structure) <- "[hc][sport][hg|sport:hc]"

# subset and fit
ais.sub <- ais[ais$sport %in% c("Netball", "Tennis", "W_Polo"), c("hc", "hg", "sport")]
ais.sub$sport <- factor(ais.sub$sport)
bn.mod <- bn.fit(structure, data = ais.sub)
bn.mod
##
##   Bayesian network parameters
##
##   Parameters of node hc (Gaussian distribution)
##
## Conditional density: hc
## Coefficients:
## (Intercept)
##    41.82353
## Standard deviation of the residuals: 4.092363
##
##   Parameters of node hg (conditional Gaussian distribution)
##
## Conditional density: hg | hc + sport
## Coefficients:
##                       0           1           2
## (Intercept)   1.5550754  -2.7611358  -0.1173597
## hc            0.2929909   0.4019544   0.3398915
## Standard deviation of the residuals:
##         0          1          2
## 0.2726074  0.3383277  0.3091150
## Discrete parents' configurations:
##      sport
## 0  Netball
## 1   Tennis
## 2   W_Polo
##
##   Parameters of node sport (multinomial distribution)
##
## Conditional probability table:
##   Netball    Tennis    W_Polo
## 0.4509804 0.2156863 0.3333333
Now when querying the model we need to be a little more specific than in the discrete case by specifying a range.
cat("P(hemoglobin levels > 14 | play water polo and have high hematocrit ratio) =", cpquery(bn.mod, (hg > 14), (sport == "W_Polo" & hc > 42 )), "\n")
## P(hemoglobin levels > 14 | play water polo and have high hematocrit ratio) = 0.9495798
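By default cpquery uses logic sampling. It simulates from the network, keeps the samples that match the evidence and reports how often the event occurs among them. The sketch below does the same thing by hand, using the parameters printed by bn.fit above. To keep it short, hg is simulated with the W_Polo regression for every sample. That doesn’t affect the answer because samples from other sports are discarded by the evidence. Since both this and cpquery are Monte Carlo estimates, expect a value close to, but not identical with, the one above.

```r
set.seed(20181120)
n <- 1e5
# Root nodes, parameters copied from the bn.fit output above
hc <- rnorm(n, mean = 41.82353, sd = 4.092363)
sport <- sample(c("Netball", "Tennis", "W_Polo"), n, replace = TRUE,
                prob = c(0.4509804, 0.2156863, 0.3333333))
# hg | hc for the W_Polo configuration (configuration 2 in the output)
hg <- -0.1173597 + 0.3398915 * hc + rnorm(n, sd = 0.3091150)

# Keep samples consistent with the evidence, then count the event
keep <- sport == "W_Polo" & hc > 42
mean(hg[keep] > 14)
```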
Another key benefit of Bayes nets is that variables can be chained together. In other words, two nodes don’t need to be directly connected to make inference from one about the other. We’ll add another variable to our simple model: lean body mass, which is calculated as body weight minus body fat in kg, so the higher the number the leaner the athlete.
# create an empty graph
structure <- empty.graph(c("hc", "hg", "sport", "lbm"))

# set relationships manually
modelstring(structure) <- "[lbm][hc|lbm][sport][hg|sport:hc]"
plot.network(structure)

# subset and fit
ais.sub <- ais[ais$sport %in% c("Netball", "Tennis", "W_Polo"), c("hc", "hg", "sport", "lbm")]
ais.sub$sport <- factor(ais.sub$sport)
bn.mod <- bn.fit(structure, data = ais.sub)
bn.mod
##
##   Bayesian network parameters
##
##   Parameters of node hc (Gaussian distribution)
##
## Conditional density: hc | lbm
## Coefficients:
## (Intercept)          lbm
##  26.5212185    0.2471436
## Standard deviation of the residuals: 2.846647
##
##   Parameters of node hg (conditional Gaussian distribution)
##
## Conditional density: hg | hc + sport
## Coefficients:
##                       0           1           2
## (Intercept)   1.5550754  -2.7611358  -0.1173597
## hc            0.2929909   0.4019544   0.3398915
## Standard deviation of the residuals:
##         0          1          2
## 0.2726074  0.3383277  0.3091150
## Discrete parents' configurations:
##      sport
## 0  Netball
## 1   Tennis
## 2   W_Polo
##
##   Parameters of node sport (multinomial distribution)
##
## Conditional probability table:
##   Netball    Tennis    W_Polo
## 0.4509804 0.2156863 0.3333333
##
##   Parameters of node lbm (Gaussian distribution)
##
## Conditional density: lbm
## Coefficients:
## (Intercept)
##    61.91667
## Standard deviation of the residuals: 12.00722
Now we can query the model and calculate the probability that athletes have hemoglobin levels greater than 14 given they play water polo and have an LBM of greater than 65 kg without having any knowledge of their hematocrit ratio.
cat("P(hemoglobin levels > 14 | play water polo and have LBM > 65 kg) =", cpquery(bn.mod, (hg > 14), (sport == "W_Polo" & lbm > 65 )), "\n")
## P(hemoglobin levels > 14 | play water polo and have LBM > 65 kg) = 0.8123028
For large cases you’ll want to use an algorithm to define the structure of the Bayes net, and then add other user-defined relationships on top of that if required. bnlearn includes the hill climbing algorithm, which is suitable for the job. The default score it uses to optimise the model is the BIC, which is appropriate; for mixed discrete and continuous data, as here, that is the conditional Gaussian version, "bic-cg". There are many others, such as AIC, the Bayesian Dirichlet score and K2, to name a few, that may be more appropriate for your problem.
# learn the structure using the hill climbing algorithm and the BIC
structure <- hc(ais.sub, score = "bic-cg")
plot.network(structure)
As you can see, the learned structure is different to the one we defined. This structure best fits the data by maximising the BIC, but if we understand the system well enough we can input the relationships that we know are important. This is more the case when sample sizes are small; when they are large, we can put more trust in the algorithm to find the correct relationships. Having said that, there are often biases in the data, and if those mechanisms are well understood, the right relationships can be put into the model as well.
bn.mod <- bn.fit(structure, data = ais.sub)
cat("P(hemoglobin levels > 14 | play water polo and have LBM > 65 kg) =", cpquery(bn.mod, (hg > 14), (sport == "W_Polo" & lbm > 65 )), "\n")
## P(hemoglobin levels > 14 | play water polo and have LBM > 65 kg) = 0.9833866
Now we will fit the full model using all the available data, after removing variables which are a function of others, e.g. $BMI = \frac{wt}{ht^2}$.
ais.sub <- ais[, c("hc", "hg", "sport", "lbm", "rcc", "wcc", "ferr", "ht", "wt", "sex", "ssf")]
structure <- hc(ais.sub, score = "bic-cg")
bn.mod <- bn.fit(structure, data = ais.sub)
plot.network(structure, ht = "600px")
Bayes Nets can get complex quite quickly (for example, check out a few from the bnlearn doco); however, the graphical representation makes it easy to visualise the relationships, and the package makes it easy to query the graph.
Fitting the network and querying the model is only the first part of the practice. Where Bayes nets really shine is in how they are used to make actionable decisions. In our example we fit a model to help explain the influencing factors on hemoglobin concentration in an athlete. But let’s assume that high hemoglobin levels are correlated with better performance, which is likely to be true for endurance sports such as running or cycling but less so for skill-based sports like basketball. The athlete could take appropriate action to ensure their hemoglobin concentrations are at optimal levels. Decisions need to be made around
It is when ‘interventions’ such as these can be accounted for in the model that the user can implement ‘what if’ scenarios to help make the best decision. Some of these variables can easily be observed, but others cannot, such as red cell count. This might be a measurement that is taken only once every 2-3 months, in which case decisions will need to be made without knowledge of the athlete’s current red cell count. Fortunately, a Bayesian network can handle this type of uncertainty and missing information.
The outputs of a Bayesian network are conditional probabilities. Often these are used as input for an overarching optimisation problem. For example, an insurance company may construct a Bayesian network to predict the probability of signing up a new customer to a premium plan in the next marketing campaign. This probability is then used to calculate the expected revenue from new sales. In turn, the model could tell the company that if they took actions A and B they could increase their revenue by $x, or if they advertised in these other locations for some cost, the revenue is expected to be $y. Using this information they can make the best decision to maximise their profits.
Bayesian networks are hugely flexible, and an extension of the theory is the Dynamic Bayesian Network, which brings in a time component. As new data is collected, it is added to the model and the probabilities are updated. This is homework for another day.
]]>