library(tidymodels) # for the parsnip package, along with the rest of tidymodels
# Helper packages
library(broom.mixed) # for converting bayesian models to tidy tibbles
library(dotwhisker) # for visualizing regression results
library(RCurl) # for reading in text files from a URL
# Set global theme for ggplots
theme_set(theme_minimal(base_size = 14)) # minimal theme, 14pt font
theme_update(legend.position = "right") # legend to the right
Tidymodels: Build a Model
Description
Python has taken most of the hype as the most widely used programming language for machine learning. This has left R users with a choice: (1) learn Python, or (2) don’t do machine learning very well. Enter tidymodels, Posit’s solution to machine learning in R, which uses a framework similar to the tidyverse.
In this session of the Classical Machine Learning workshop series, we will cover how to build a model using the tidymodels framework in R. This is the first necessary step towards the more sophisticated models that we will deploy later in this series.
This workshop borrows heavily from the open source materials hosted on tidymodels.org. The author replaced the original urchins data with data from Meredith, Ladd, and Werner (2021), which is described below.
Objectives:
- Load and examine data
- Build and fit a model
- Use a model to predict
- Model with different engines
Introduction
How do you create a statistical model using tidymodels? In this article, we will walk you through the steps. We start with data for modeling, learn how to specify and train models with different engines using the parsnip package, and understand why these functions are designed this way.
To use the code in this article, you will need to install the following packages: broom.mixed, dotwhisker, readr, rstanarm, RCurl, and tidymodels.
Canopy rain forest drought tolerance data
Let’s use the data from Meredith, Ladd, and Werner (2021), which come from a study investigating the effects of climate change on canopy trees and understory trees (those that grow below the canopy and are shade tolerant). The data examine water flux in trees across four groups based on drought susceptibility (Group):
- Drought-sens-canopy: drought sensitive canopy trees
- Drought-tol-canopy: drought tolerant canopy trees
- Drought-sens-under: drought sensitive understory trees
- Drought-tol-under: drought tolerant understory trees
# Read in csv from the web
<- getURL("https://raw.githubusercontent.com/Gchism94/Data7_EDA_In_R_Workshops/main/Data7_EDA_In_R_Book/data/Data_Fig2_Repo.csv")
data <- read.csv(text = data) data
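Since readr is among the required packages, note that readr::read_csv() can read a CSV straight from a URL, making the RCurl step optional (read_csv() returns a tibble rather than a base data frame):
# Alternative: read the CSV directly from the URL with readr
data <- readr::read_csv("https://raw.githubusercontent.com/Gchism94/Data7_EDA_In_R_Workshops/main/Data7_EDA_In_R_Book/data/Data_Fig2_Repo.csv")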
Examine the data
data %>%
  head()
Date Group Sap_Flow TWaterFlux pLWP mLWP
1 10/4/19 Drought-sens-canopy 184.040975 82.243292 -0.2633781 -0.6797690
2 10/4/19 Drought-sens-under 2.475989 1.258050 -0.2996688 -0.7613264
3 10/4/19 Drought-tol-canopy 10.598949 4.405479 -0.4375563 -0.7225572
4 10/4/19 Drought-tol-under 4.399854 2.055276 -0.2052237 -0.7028581
5 10/5/19 Drought-sens-canopy 182.905444 95.865255 -0.2769280 -0.7082610
6 10/5/19 Drought-sens-under 2.459209 1.225792 -0.3205980 -0.7928576
For each of the drought sensitivity groups (Group), we know their:
- Sap_Flow: Sap flow rate \(V_s\) in cm hr\(^{-1}\).
- pLWP: Pre-dawn water potential (megapascals, MPa), representing the potential (energy) of water flow throughout the tree.
- mLWP: Midday water potential (megapascals, MPa), representing the potential (energy) of water flow throughout the tree.
- TWaterFlux: Normalized total water flux (sap flow Sap_Flow, pre-dawn pLWP, and midday mLWP water potential) in each treatment group. Values are normalized based on pre-drought levels.
Plot the data before modeling:
ggplot(data,
aes(x = pLWP,
y = mLWP,
group = Group,
col = Group)) +
geom_point() +
geom_smooth(method = lm, se = FALSE) +
scale_color_viridis_d(option = "plasma", end = 0.7)
`geom_smooth()` using formula = 'y ~ x'
Build and Fit a Model
A standard three-way analysis of variance (ANOVA) model makes sense for this dataset because we have both a continuous and a categorical predictor variables. Since the slopes appear to be different for at least two of the drought treatments, let’s build a model that allows for two-way interactions. Specifying an R formula with our variables in this way:
~ pLWP * Group mLWP
allows our regression model depending on pre-dawn water potential pLWP
to have separate slopes and intercepts for each drought sensitivity Group
.
For this kind of model, ordinary least squares is a good initial approach. With tidymodels, we start by specifying the functional form of the model that we want using the parsnip package. Since there is a numeric outcome and the model should be linear with slopes and intercepts, the model type is “linear regression”. We can declare this with:
linear_reg()
Linear Regression Model Specification (regression)
Computational engine: lm
That is pretty underwhelming since, on its own, it doesn’t really do much. However, now that the type of model has been specified, we can think about a method for fitting or training the model: the model engine. The engine value is often a mash-up of the software that can be used to fit or train the model and the estimation method. The default for linear_reg() is "lm" for ordinary least squares, as you can see above. We could set a non-default option instead (keras):
linear_reg() %>%
set_engine("keras")
Linear Regression Model Specification (regression)
Computational engine: keras
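If you would rather query the engines from the console, parsnip provides show_engines() for exactly this purpose:
# List every engine registered for the linear regression model type
show_engines("linear_reg")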
The documentation page for linear_reg() also lists all the possible engines. We’ll save our model object, using the default engine, as lm_mod.
lm_mod <- linear_reg()
From here, the model can be estimated or trained using the fit() function:
lm_fit <-
  lm_mod %>%
  fit(mLWP ~ pLWP * Group, data = data)
lm_fit
parsnip model object
Call:
stats::lm(formula = mLWP ~ pLWP * Group, data = data)
Coefficients:
(Intercept) pLWP
-0.566246 1.046276
GroupDrought-sens-under GroupDrought-tol-canopy
0.008669 0.099530
GroupDrought-tol-under pLWP:GroupDrought-sens-under
-0.487055 -0.310337
pLWP:GroupDrought-tol-canopy pLWP:GroupDrought-tol-under
-0.373874 -1.585690
Perhaps our analysis requires a description of the model parameter estimates and their statistical properties. Although the summary() function for lm objects can provide that, it gives the results back in an unwieldy format. Many models have a tidy() method that provides the summary results in a more predictable and useful format (e.g., a data frame with standard column names):
tidy(lm_fit)
# A tibble: 8 × 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) -0.566 0.0448 -12.6 4.63e-29
2 pLWP 1.05 0.0628 16.7 2.81e-43
3 GroupDrought-sens-under 0.00867 0.0607 0.143 8.86e- 1
4 GroupDrought-tol-canopy 0.0995 0.112 0.885 3.77e- 1
5 GroupDrought-tol-under -0.487 0.0701 -6.95 2.73e-11
6 pLWP:GroupDrought-sens-under -0.310 0.0831 -3.73 2.31e- 4
7 pLWP:GroupDrought-tol-canopy -0.374 0.174 -2.15 3.22e- 2
8 pLWP:GroupDrought-tol-under -1.59 0.133 -11.9 1.46e-26
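Related to tidy(), the broom generic glance() reports model-level statistics (R-squared, AIC, and so on) as a one-row tibble; recent parsnip versions forward glance() to the engine fit, so it can be called on the parsnip object directly:
# One-row tibble of model-level fit statistics
glance(lm_fit)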
This tidy() output can be used to generate a dot-and-whisker plot of our regression results using the dotwhisker package:
tidy(lm_fit) %>%
dwplot(dot_args = list(size = 2, color = "black"),
whisker_args = list(color = "black"),
vline = geom_vline(xintercept = 0, colour = "grey50", linetype = 2))
Use a Model to Predict
This fitted object lm_fit
has the lm
model output built-in, which you can access with lm_fit$fit
, but there are some benefits to using the fitted parsnip model object when it comes to predicting.
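For instance, when you do need the raw lm object (to pass to base R diagnostics such as summary() or plot()), extract_fit_engine() is a tidier route than reaching into lm_fit$fit:
# Pull the underlying lm object out of the parsnip wrapper
lm_fit %>%
  extract_fit_engine() %>%
  summary()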
Suppose that, for a publication, it would be particularly interesting to make a plot of the mean midday water potential (mLWP) for trees that have a pre-dawn water potential (pLWP) of -1.75 MPa. To create such a graph, we start with some new example data that we will make predictions for, to show in our graph:
new_points <- expand.grid(pLWP = -1.75,
                          Group = c("Drought-sens-canopy",
                                    "Drought-sens-under",
                                    "Drought-tol-canopy",
                                    "Drought-tol-under"))
new_points
pLWP Group
1 -1.75 Drought-sens-canopy
2 -1.75 Drought-sens-under
3 -1.75 Drought-tol-canopy
4 -1.75 Drought-tol-under
To get our predicted results, we can use the predict() function to find the mean mLWP values at a pLWP of -1.75 MPa.
It is also important to communicate the variability, so we also need to find the predicted confidence intervals. If we had used lm() to fit the model directly, a few minutes of reading the documentation page for predict.lm() would explain how to do this. However, if we decide to use a different model to estimate mLWP (spoiler: we will!), it is likely that a completely different syntax would be required.
Instead, with tidymodels, the types of predicted values are standardized so that we can use the same syntax to get these values.
First, let’s generate the mean mLWP values:
mean_pred <- predict(lm_fit, new_data = new_points)
mean_pred
# A tibble: 4 × 1
.pred
<dbl>
1 -2.40
2 -1.85
3 -1.64
4 -0.109
When making predictions, the tidymodels convention is to always produce a tibble of results with standardized column names. This makes it easy to combine the original data and the predictions in a usable format:
conf_int_pred <- predict(lm_fit,
                         new_data = new_points,
                         type = "conf_int")
conf_int_pred
# A tibble: 4 × 2
.pred_lower .pred_upper
<dbl> <dbl>
1 -2.53 -2.26
2 -1.96 -1.73
3 -2.00 -1.29
4 -0.413 0.195
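Confidence intervals describe uncertainty in the mean prediction. If you instead want intervals expected to cover individual new observations, the same standardized syntax applies with type = "pred_int", which the "lm" engine also supports:
# Prediction intervals for individual observations (wider than conf_int)
predict(lm_fit, new_data = new_points, type = "pred_int")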
# Now combine:
plot_data <-
  new_points %>%
  bind_cols(mean_pred) %>%
  bind_cols(conf_int_pred)
# and plot:
ggplot(plot_data, aes(x = Group)) +
geom_point(aes(y = .pred)) +
geom_errorbar(aes(ymin = .pred_lower,
ymax = .pred_upper),
width = .2) +
labs(y = "mLWP")
Model with a Different Engine
Everyone on your team is happy with that plot except that one person who just read their first book on Bayesian analysis. They are interested in knowing whether the results would be different if the model were estimated using a Bayesian approach. In such an analysis, a prior distribution needs to be declared for each model parameter that represents the possible values of the parameters (before being exposed to the observed data). After some discussion, the group agrees that the priors should be bell-shaped but, since no one has any idea what the range of values should be, to take a conservative approach and make the priors wide using a Cauchy distribution (which is the same as a t-distribution with a single degree of freedom).
The documentation on the rstanarm package shows us that the stan_glm() function can be used to estimate this model, and that the function arguments that need to be specified are called prior and prior_intercept. It turns out that linear_reg() has a stan engine. Since these prior distribution arguments are specific to the Stan software, they are passed as arguments to parsnip::set_engine(). After that, the same exact fit() call is used:
# set the prior distribution
prior_dist <- rstanarm::student_t(df = 1)

set.seed(123)

# make the parsnip model
bayes_mod <-
  linear_reg() %>%
  set_engine("stan",
             prior_intercept = prior_dist,
             prior = prior_dist)

# train the model
bayes_fit <-
  bayes_mod %>%
  fit(mLWP ~ pLWP * Group, data = data)
print(bayes_fit, digits = 5)
parsnip model object
stan_glm
family: gaussian [identity]
formula: mLWP ~ pLWP * Group
observations: 276
predictors: 8
------
Median MAD_SD
(Intercept) -0.56813 0.04422
pLWP 1.04504 0.06296
GroupDrought-sens-under 0.00835 0.06143
GroupDrought-tol-canopy 0.10172 0.10937
GroupDrought-tol-under -0.48327 0.06872
pLWP:GroupDrought-sens-under -0.31005 0.08285
pLWP:GroupDrought-tol-canopy -0.37154 0.16722
pLWP:GroupDrought-tol-under -1.58075 0.13124
Auxiliary parameter(s):
Median MAD_SD
sigma 0.12762 0.00564
------
* For help interpreting the printed output see ?print.stanreg
* For info on the priors used see ?prior_summary.stanreg
This kind of Bayesian analysis (like many models) involves randomly generated numbers in its fitting procedure. We can use set.seed() to ensure that the same (pseudo-)random numbers are generated each time we run this code. The number 123 isn’t special or related to our data; it is just a “seed” used to choose random numbers.
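A quick illustration of what set.seed() buys us: resetting the seed reproduces exactly the same draws.
set.seed(123)
rnorm(3) # three random draws
set.seed(123)
rnorm(3) # the same three draws again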
To update the parameter table, the tidy() method is once again used:
tidy(bayes_fit, conf.int = TRUE)
# A tibble: 8 × 5
term estimate std.error conf.low conf.high
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) -0.568 0.0442 -0.640 -0.495
2 pLWP 1.05 0.0630 0.940 1.15
3 GroupDrought-sens-under 0.00835 0.0614 -0.0870 0.113
4 GroupDrought-tol-canopy 0.102 0.109 -0.0797 0.286
5 GroupDrought-tol-under -0.483 0.0687 -0.595 -0.368
6 pLWP:GroupDrought-sens-under -0.310 0.0829 -0.441 -0.167
7 pLWP:GroupDrought-tol-canopy -0.372 0.167 -0.651 -0.0909
8 pLWP:GroupDrought-tol-under -1.58 0.131 -1.79 -1.36
A goal of the tidymodels packages is that the interfaces to common tasks are standardized (as seen in the tidy() results above). The same is true for getting predictions; we can use the same code even though the underlying packages use very different syntax:
bayes_plot_data <-
  new_points %>%
  bind_cols(predict(bayes_fit, new_data = new_points)) %>%
  bind_cols(predict(bayes_fit, new_data = new_points, type = "conf_int"))
ggplot(bayes_plot_data, aes(x = Group)) +
geom_point(aes(y = .pred)) +
geom_errorbar(aes(ymin = .pred_lower, ymax = .pred_upper), width = .2) +
labs(y = "mLWP") +
ggtitle("Bayesian model with t(1) prior distribution")
This isn’t very different from the non-Bayesian results (except in interpretation).
Note: The parsnip package can work with many model types, engines, and arguments. Check out tidymodels.org/find/parsnip/ to see what is available.
Why does it work that way?
The extra step of defining the model using a function like linear_reg() might seem superfluous since a call to lm() is much more succinct. However, the problem with standard modeling functions is that they don’t separate what you want to do from the execution. For example, the process of executing a formula has to happen repeatedly across model calls even when the formula does not change; we can’t recycle those computations.
Also, using the tidymodels framework, we can do some interesting things by incrementally creating a model (instead of using a single function call). Model tuning with tidymodels uses the specification of the model to declare what parts of the model should be tuned. That would be very difficult to do if linear_reg() immediately fit the model.
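As a preview, here is a minimal sketch of that idea (assuming the glmnet engine is installed; we won’t run this here): a parameter is marked with the tune() placeholder inside the specification, and nothing is fit until a tuning function such as tune_grid() resolves it later.
# A sketch: flag the penalty for tuning rather than fixing a value;
# tune() is a placeholder that tuning functions fill in later
penalized_mod <- linear_reg(penalty = tune()) %>%
  set_engine("glmnet")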
If you are familiar with the tidyverse, you may have noticed that our modeling code uses the magrittr pipe (%>%). With dplyr and other tidyverse packages, the pipe works well because all of the functions take the data as the first argument. For example:
data %>%
  group_by(Group) %>%
  drop_na() %>% # note that NAs need to be removed
  summarize(med_pLWP = median(pLWP))
# A tibble: 4 × 2
Group med_pLWP
<chr> <dbl>
1 Drought-sens-canopy -0.706
2 Drought-sens-under -0.592
3 Drought-tol-canopy -0.603
4 Drought-tol-under -0.406
whereas the modeling code uses the pipe to pass around the model object:
bayes_mod %>% fit(mLWP ~ pLWP * Group, data = data)
This may seem jarring if you have used dplyr a lot, but it is extremely similar to how ggplot2 operates:
ggplot(data,
aes(pLWP, mLWP)) + # returns a ggplot object
geom_jitter() + # same
geom_smooth(method = lm, se = FALSE) + # same
labs(x = "pLWP", y = "mLWP") # etc.
Session Information
─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.2.1 (2022-06-23)
os macOS Monterey 12.2
system aarch64, darwin20
ui X11
language (EN)
collate en_US.UTF-8
ctype en_US.UTF-8
tz America/Phoenix
date 2023-02-01
pandoc 2.19.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
─ Packages ───────────────────────────────────────────────────────────────────
package * version date (UTC) lib source
broom * 1.0.1 2022-08-29 [1] CRAN (R 4.2.0)
broom.mixed * 0.2.9.4 2022-04-17 [1] CRAN (R 4.2.0)
dials * 1.1.0 2022-11-04 [1] CRAN (R 4.2.0)
dotwhisker * 0.7.4 2021-09-02 [1] CRAN (R 4.2.0)
dplyr * 1.0.10 2022-09-01 [1] CRAN (R 4.2.1)
ggplot2 * 3.4.0 2022-11-04 [1] CRAN (R 4.2.0)
infer * 1.0.3 2022-08-22 [1] CRAN (R 4.2.0)
modeldata * 1.0.1 2022-09-06 [1] CRAN (R 4.2.0)
parsnip * 1.0.3 2022-11-11 [1] CRAN (R 4.2.0)
purrr * 0.3.5 2022-10-06 [1] CRAN (R 4.2.0)
RCurl * 1.98-1.9 2022-10-03 [1] CRAN (R 4.2.0)
recipes * 1.0.3 2022-11-09 [1] CRAN (R 4.2.0)
rsample * 1.1.0 2022-08-08 [1] CRAN (R 4.2.0)
scales * 1.2.1 2022-08-20 [1] CRAN (R 4.2.0)
tibble * 3.1.8 2022-07-22 [1] CRAN (R 4.2.0)
tidymodels * 1.0.0 2022-07-13 [1] CRAN (R 4.2.0)
tidyr * 1.2.1 2022-09-08 [1] CRAN (R 4.2.0)
tune * 1.0.1 2022-10-09 [1] CRAN (R 4.2.0)
workflows * 1.1.2 2022-11-16 [1] CRAN (R 4.2.0)
workflowsets * 1.0.0 2022-07-12 [1] CRAN (R 4.2.0)
yardstick * 1.1.0 2022-09-07 [1] CRAN (R 4.2.0)
[1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
──────────────────────────────────────────────────────────────────────────────