✅ Lab 3: Linear regression as a machine learning algorithm - Solutions

Author

The DS202 Team

Instead of just reading the solutions, you might want to download the .qmd file and run the code yourself.

Loading libraries and functions

library("ggsci")       # For pretty colour schemes
library("MASS")        # For simulating data  
library("scales")      # For number formatting in ggplot2
library("tidymodels")  # For train / test splits
library("tidyverse")   # For data wrangling / visualisation

# Let's use some of the theme functions we created in lab 2

theme_histogram <- function() {
  
  theme_minimal() +
  theme(panel.grid.minor = element_blank(),
        panel.grid.major.x = element_blank())
  
}

theme_boxplot <- function() {
  
  theme_minimal() +
  theme(panel.grid.minor = element_blank(),
        panel.grid.major.x = element_blank(),
        legend.position = "none")
  
}

theme_bar <- function() {
  
  theme_minimal() +
  theme(panel.grid.minor = element_blank(),
        panel.grid.major.y = element_blank())
  
}

theme_scatter <- function() {
  
  theme_minimal() +
  theme(panel.grid.minor = element_blank(),
        legend.position = "bottom")
  
}

Before we do anything more

Please create a folder called data to store all the different data sets used in this course.

The World Values Survey


wvs <- read_csv("data/WVS_Wave_7.csv")

We start our machine learning journey with Wave 7 of the World Values Survey (wvs), which contains information on 86,867 respondents (nrow(wvs)). You can count the countries covered with length(unique(wvs$iso3c)). We have cleaned the data to only include non-missing values. The columns include:

  • iso3c 3-letter country iso code
  • satisfaction 1 to 10 rating of life satisfaction (the outcome)
  • social_trust TRUE / FALSE as to whether or not someone expresses social trust
  • male respondent is male (reference: female)
  • age age of respondent
  • post_second_edu respondent has post-secondary education
  • rural respondent lives in a rural area (reference: urban)
  • employment categorical employment variable (try count(wvs, employment))
  • financial_situ 1 to 10 rating of respondent’s financial situation
  • married the respondent is married (reference: other marital status)
  • relig_import 1 to 10 rating of how much importance the respondent attaches to religion.
  • better_living trichotomous rating of whether or not the respondent sees their lives as better off than their parents (try count(wvs, better_living))
  • no_food In the last 12 months, how often have you or your family gone without enough food to eat? (Sometimes / Often = TRUE, Rarely / Never = FALSE)
  • no_safety In the last 12 months, how often have you or your family felt unsafe from crime in your home? (Sometimes / Often = TRUE, Rarely / Never = FALSE)
  • no_medical In the last 12 months, how often have you or your family gone without medicine or medical treatment that you needed? (Sometimes / Often = TRUE, Rarely / Never = FALSE)
  • no_cash In the last 12 months, how often have you or your family gone without a cash income? (Sometimes / Often = TRUE, Rarely / Never = FALSE)
  • no_shelter In the last 12 months, how often have you or your family gone without a safe shelter over your head? (Sometimes / Often = TRUE, Rarely / Never = FALSE)

👉 NOTE: With variables such as no_food, we recoded the original question mainly for the sake of simplicity. This simplification can be useful as it reduces the number of parameters in our model. However, there may be distinct differences between each level that may produce different results.

If you find that your laptop is unable to handle the full data set without running slowly, try experimenting with the following code. This takes a random sample of the data, stratifying by country so the sampling algorithm doesn’t take more data from one country and less from another.


# Set a seed for reproducibility

set.seed(123)
  
wvs <-
  # Load the .csv
  read_csv("data/WVS_Wave_7.csv") %>%
  # Sample a proportion of the data set for each country.
  # We have used 25% but you can experiment depending on
  # the capability of your machine.
  group_by(iso3c) %>% 
  slice_sample(prop = 0.25) %>% 
  ungroup()

Understanding life satisfaction: some exploratory data analysis (EDA) (5 minutes)

Here are a few graphs that tell some stories about wvs.

Median life satisfaction appears to be relatively high in the sample (7 out of 10)


med_satisfaction <- median(wvs$satisfaction)

wvs %>% 
  ggplot(aes(satisfaction)) +
  geom_histogram(fill = "midnightblue", colour = "black", bins = 10, alpha = 0.5) +
  geom_vline(xintercept = med_satisfaction, linetype = "dashed", linewidth = 2, colour = "red") +
  theme_histogram() +
  scale_x_continuous(breaks = c(1, med_satisfaction, 10)) +
  scale_y_continuous(labels = comma) +
  labs(x = "Life satisfaction", y = "Number of respondents",
       caption = "Note: dashed line represents median life satisfaction")

Life satisfaction tracks positively with financial situation


wvs %>% 
  count(financial_situ, satisfaction) %>% 
  ggplot(aes(financial_situ, satisfaction, size = n)) +
  geom_point() +
  scale_x_continuous(breaks = seq(2, 10, 2)) +
  scale_y_continuous(breaks = seq(2, 10, 2)) +
  scale_size_continuous(labels = comma) +
  theme(panel.grid = element_blank(),
        panel.background = element_rect(fill = "white"),
        legend.position = "bottom") +
  labs(x = "Financial situation", y = "Life satisfaction")

The median individual with post-secondary education has a higher life satisfaction than the median individual without


wvs %>% 
  ggplot(aes(post_second_edu, satisfaction, fill = post_second_edu)) +
  geom_boxplot() +
  theme_boxplot() +
  scale_fill_jco() +
  labs(x = "Post-secondary education?", y = "Life satisfaction")

Understanding life satisfaction: the hypothesis-testing approach (5 minutes)

Why do some people have a higher life satisfaction than others? This is one question that a quantitative social scientist might answer by examining the magnitude and precision of the coefficient estimates for a series of variables. Suppose we hypothesised that individuals with post-secondary level education have greater life satisfaction. We can estimate a linear regression model by using satisfaction as the dependent variable and post_second_edu as the independent variable.

To run a linear regression, we can use the lm function, which requires two things:

  • A model formula (a.k.a. equation)
  • The data used to estimate the model

Let’s do this now. We can call the summary function to get information on the coefficient estimate for post_second_edu.


lm(satisfaction ~ post_second_edu, data = wvs) %>% 
  summary()
Call:
lm(formula = satisfaction ~ post_second_edu, data = wvs)

Residuals:
    Min      1Q  Median      3Q     Max 
-6.2477 -1.2477 -0.0458  1.7523  2.9542 

Coefficients:
                    Estimate Std. Error t value Pr(>|t|)    
(Intercept)         7.045795   0.009161  769.10   <2e-16 ***
post_second_eduTRUE 0.201856   0.015928   12.67   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 2.209 on 86865 degrees of freedom
Multiple R-squared:  0.001846,  Adjusted R-squared:  0.001834 
F-statistic: 160.6 on 1 and 86865 DF,  p-value: < 2.2e-16

We see that having post-secondary education is associated with a positive and statistically significant (p < 0.001) difference in life satisfaction of about 0.2 points.
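With a single TRUE / FALSE predictor, the regression reduces to a comparison of group means: the intercept is the mean outcome for the reference group and the coefficient is the difference between the two group means. The sketch below checks this on simulated data. The data frame toy and its values are invented for illustration and are not drawn from wvs.

```r
# Simulated data (not the WVS): a binary predictor and a numeric outcome
set.seed(1)
toy <- data.frame(edu = rep(c(FALSE, TRUE), each = 50))
toy$y <- 7 + 0.2 * toy$edu + rnorm(100)

fit <- lm(y ~ edu, data = toy)

# Group means computed directly from the data
mean_false <- mean(toy$y[!toy$edu])
mean_true  <- mean(toy$y[toy$edu])

# The intercept equals the mean of the reference group (edu == FALSE)
intercept_matches <- isTRUE(all.equal(unname(coef(fit)["(Intercept)"]), mean_false))

# The coefficient equals the difference between the two group means
slope_matches <- isTRUE(all.equal(unname(coef(fit)["eduTRUE"]), mean_true - mean_false))

c(intercept_matches = intercept_matches, slope_matches = slope_matches)
```

Read this way, the 0.2 estimate above is simply the gap between the average satisfaction of the two education groups.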

👉 NOTE: The process of hypothesis testing is obviously more involved when using observational data than is portrayed by this simple example. Control variables will almost always be incorporated and, increasingly, identification strategies will be used to uncover causal effects. The end result, however, will involve as rigorous an attempt at falsifying a hypothesis as can be provided with the data.

For an example of how multivariate regression is used, we can run the following code.


lm(satisfaction ~ . -iso3c, data = wvs) %>% 
  summary()
Call:
lm(formula = satisfaction ~ . - iso3c, data = wvs)

Residuals:
    Min      1Q  Median      3Q     Max 
-8.3926 -1.0385  0.0969  1.0341  6.4011 

Coefficients:
                               Estimate Std. Error t value Pr(>|t|)    
(Intercept)                   3.768e+00  3.594e-02 104.842  < 2e-16 ***
social_trustTRUE              7.328e-02  1.514e-02   4.839 1.31e-06 ***
maleTRUE                     -4.916e-02  1.327e-02  -3.705 0.000211 ***
age                          -3.842e-05  4.901e-04  -0.078 0.937514    
post_second_eduTRUE          -3.345e-02  1.388e-02  -2.409 0.015989 *  
ruralTRUE                    -4.186e-02  1.378e-02  -3.039 0.002377 ** 
employmentHouse wife/husband  1.200e-03  2.194e-02   0.055 0.956404    
employmentOther               1.059e-01  5.846e-02   1.811 0.070203 .  
employmentPart time           1.039e-01  2.362e-02   4.400 1.09e-05 ***
employmentRetired            -2.247e-02  2.410e-02  -0.932 0.351095    
employmentSelf employed      -1.072e-02  1.969e-02  -0.545 0.586018    
employmentStudent            -1.647e-02  3.032e-02  -0.543 0.587003    
employmentUnemployed         -1.378e-01  2.529e-02  -5.450 5.05e-08 ***
financial_situ                4.790e-01  2.753e-03 173.978  < 2e-16 ***
marriedTRUE                   1.908e-01  1.383e-02  13.798  < 2e-16 ***
relig_import                  5.722e-02  2.049e-03  27.921  < 2e-16 ***
better_livingBetter off      -2.524e-03  1.449e-02  -0.174 0.861716    
better_livingWorse off       -4.222e-01  1.925e-02 -21.932  < 2e-16 ***
no_foodTRUE                  -1.663e-01  2.066e-02  -8.050 8.41e-16 ***
no_safetyTRUE                -5.067e-02  1.821e-02  -2.783 0.005385 ** 
no_medicalTRUE               -1.286e-01  1.827e-02  -7.037 1.98e-12 ***
no_cashTRUE                   2.262e-02  1.697e-02   1.333 0.182516    
no_shelterTRUE               -2.384e-01  2.588e-02  -9.214  < 2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.811 on 86844 degrees of freedom
Multiple R-squared:  0.3289,    Adjusted R-squared:  0.3287 
F-statistic:  1935 on 22 and 86844 DF,  p-value: < 2.2e-16

The . placeholder indicates that we want to use all other variables in the data set as predictors. -iso3c removes the country code from that set, so no country dummy variables are included. Interestingly, the coefficient estimate for post_second_edu is now negative (about -0.03) and still significant, albeit at a weaker level than before (p < 0.05).
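To see how the . and - shorthand expands, it can help to try it on a tiny data frame. The sketch below uses hypothetical columns (y, x1, x2, group) that are not part of wvs.

```r
# Toy data frame with hypothetical columns (not the WVS)
toy <- data.frame(
  y     = c(1.0, 2.1, 2.9, 4.2, 5.1, 5.8),
  x1    = c(1, 2, 3, 4, 5, 6),
  x2    = c(0, 1, 0, 1, 1, 0),
  group = c("a", "a", "b", "b", "c", "c")
)

# `.` expands to every column except the outcome; `- group` then drops one
expanded   <- terms(y ~ . - group, data = toy)
predictors <- attr(expanded, "term.labels")
predictors

# The fitted model therefore has an intercept plus one coefficient per
# remaining predictor, and nothing for `group`
fit <- lm(y ~ . - group, data = toy)
names(coef(fit))
```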

👉 NOTE: p-values are useful to machine learning scientists as they indicate which variables may yield a significant increase in model performance. However, p-hacking, where researchers manipulate data or analyses to find results that support their hypothesis, makes it hard to tell whether or not a relationship held up after honest attempts at falsification. This can range from reporting only the modelling approach that produces statistically significant findings (while failing to report others that do not) to outright manipulation of the data. For a recent egregious case of the latter, we recommend the Data Falsificada series.

Predicting life satisfaction: the machine learning approach (30 minutes)

Machine learning scientists take a different approach. Our aim, in this context, is to build a model that can be used to accurately predict how happy a person is using a mixture of features and, for some models, hyperparameters (which we will address in Lab 5).

Thus, rather than attempting to falsify the effects of causes, we are more concerned about the fit of the model in the aggregate when applied to unforeseen data.

To achieve this, we do the following:

  • Split the data into training and test sets
  • Build a model using the training set
  • Evaluate the model on the test set

Let’s look at each of these in turn.

Split the data into training and test sets

It is worth considering what a training and test set is and why we might split the data this way.

A training set is data that we use to build (or “train”) a model. In the case of multivariate linear regression, we are using the training data to estimate a series of coefficients. Here is a made-up multivariate linear model with three coefficients derived from (non-existent) data to illustrate things.


sim_model_preds <- function(x1, x2, x3) {
  
  y <- 1.1*x1 + 2.2*x2 + 3.3*x3
  y
  
}

A test set is data that the model has not yet seen. We then apply the model to this data set and use an evaluation metric to find out how accurate our predictions are. For example, suppose we had a new observation where x1 = 10, x2 = 20, x3 = 30 and y = 150. We can use the above model to develop a prediction.


sim_model_preds(10, 20, 30)
[1] 154

We get a prediction of 154 points!

We can also calculate the amount of error we make by calculating residuals (actual value - predicted value).


150 - sim_model_preds(10, 20, 30)
[1] -4

The residual is negative, so our model overpredicts the real answer by 4 points!

Why do we evaluate our models using different data? Because, as stated earlier, machine learning scientists care about how well a model applies to unforeseen data. If we evaluate the model on the training data, we learn nothing about this: we cannot tell whether the model generalises to other data sets or has simply learned the idiosyncrasies of the data it was trained on. The latter problem is known as overfitting, a concept we will discuss throughout this course.
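A small simulation can make this concrete. The sketch below is an illustration on invented data, not part of the lab's analysis: the true relationship is linear, and we compare a simple linear model with a deliberately over-flexible degree-15 polynomial. The names sim, simple, flexible and rmse_by_hand are hypothetical.

```r
set.seed(42)

# Simulated data in which the true relationship is linear
n   <- 60
x   <- runif(n, -2, 2)
sim <- data.frame(x = x, y = 1 + 2 * x + rnorm(n))

# First half for training, second half for testing
train <- sim[1:30, ]
test  <- sim[31:60, ]

rmse_by_hand <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

simple   <- lm(y ~ x, data = train)            # matches the true structure
flexible <- lm(y ~ poly(x, 15), data = train)  # far more flexible than needed

results <- data.frame(
  model      = c("simple", "flexible"),
  train_rmse = c(rmse_by_hand(train$y, fitted(simple)),
                 rmse_by_hand(train$y, fitted(flexible))),
  test_rmse  = c(rmse_by_hand(test$y, predict(simple, newdata = test)),
                 rmse_by_hand(test$y, predict(flexible, newdata = test)))
)
results
```

The flexible model cannot do worse than the simple one on the training data, because the simple model is a special case of it. Whether it does better on the test rows is exactly what training error cannot tell us.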

We can use the rsample package in the tidymodels ecosystem to split the data into training and test sets.

# Set a seed for reproducibility

set.seed(123)

# Split the data with 75% being used to train the model

wvs_split <- initial_split(wvs, prop = 0.75)

# Create tibbles of the training and test set

wvs_train <- training(wvs_split)
wvs_test <- testing(wvs_split)

👉 NOTE: Our data are purely cross-sectional, so we can use this approach. However, when working with more complex data structures (e.g. time series cross sectional), different approaches to splitting the data will need to be used.
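Conceptually, a random 75 / 25 split assigns every row to exactly one of the two sets. The base R sketch below illustrates the idea on an invented data frame. It is not how rsample implements initial_split, and the names toy, train_rows, toy_train and toy_test are hypothetical.

```r
set.seed(123)

# Invented data: 20 rows with an id and an outcome
toy <- data.frame(id = 1:20, y = rnorm(20))

# Draw 75% of the row numbers at random, without replacement
train_rows <- sample(nrow(toy), size = floor(0.75 * nrow(toy)))

toy_train <- toy[train_rows, ]   # rows used to fit the model
toy_test  <- toy[-train_rows, ]  # held-out rows used only for evaluation

c(train = nrow(toy_train), test = nrow(toy_test))
```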

Build a model using the training set

This is remarkably simple. We will use almost exactly the same code we used to build a multivariate linear model, but with one exception. Instead of using the whole of the data, we will only use wvs_train. We will also only create a model object (mv_model, short for multivariate model).


mv_model <- lm(satisfaction ~ . -iso3c, data = wvs_train)

Evaluate the model using the test set

Now that we have trained a model, we can then evaluate its performance on the test set. We will look at two evaluation metrics:

  • R-squared: the proportion of variance in the outcome explained by the model.
  • Root mean squared error (RMSE): the typical size of a prediction error, expressed in the same units as the outcome.
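Both metrics are easy to compute by hand, which helps when interpreting them. The sketch below uses a handful of invented actual and predicted values. It relies on the fact that yardstick's rsq is the squared correlation between the truth and the estimate. The vectors and object names are hypothetical.

```r
# Invented actual and predicted values, for illustration only
actual    <- c(7, 5, 8, 6, 9, 4)
predicted <- c(6.5, 5.5, 7.0, 6.0, 8.0, 5.0)

# RMSE: square the residuals, average them, then take the square root
residuals    <- actual - predicted
rmse_by_hand <- sqrt(mean(residuals^2))

# R-squared as the squared correlation between actual and predicted values
rsq_by_hand <- cor(actual, predicted)^2

c(rmse = rmse_by_hand, rsq = rsq_by_hand)
```

Because the RMSE is on the outcome's scale, an RMSE of 1 here would mean predictions that are typically about one point away from the reported satisfaction score.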

reg_metrics <- metric_set(rsq, rmse)

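# broom's augment() method for lm objects expects `newdata` (parsnip fits
# expect `new_data`). An lm fit ignores `new_data` and returns the
# training rows instead.
mv_model %>% 
  augment(newdata = wvs_test) %>% 
  reg_metrics(truth = satisfaction, estimate = .fitted)
# A tibble: 2 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rsq     standard       0.332
2 rmse    standard       1.81 

🗣️ CLASSROOM DISCUSSION:

How can we interpret these results?

The printed metrics suggest that the model explains ~ 33% of the variance in life satisfaction and that its predictions are typically off by ~ 1.8 points. Note that this printout was generated with new_data = wvs_test, so it describes the training rows. Rerunning with newdata = wvs_test gives the test set figures, which may differ slightly.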

Graphically exploring where we make errors

We are going to build some residual scatter plots which look at the relationship between the values fitted by the model for each observation and the residuals (actual - predicted values). Before we do this for our data, let’s take a look at an example where there is a near perfect relationship between two variables. As this very rarely exists in the social world, we will rely upon simulated data.

We adapted this code from here.

# Set a seed for reproducibility

set.seed(123)

# Create the variance covariance matrix

sigma <- rbind(c(1,0.99), c(0.99,1))

# Create the mean vector

mu <- c(10, 5) 

# Draw 1,000 samples from a multivariate normal distribution

sim_data <-
  mvrnorm(n=1000, mu=mu, Sigma=sigma) %>% 
  as.data.frame() %>% 
  as_tibble()

Plot the correlation


# Plot the correlation

ggplot(sim_data, aes(V1, V2)) +
  geom_point() +
  theme_scatter() +
  labs(x = "Variable 1", y = "Variable 2")

Residual plots


# Build a linear model and plot the fitted versus residual values

lm(V2 ~ V1, data = sim_data) %>% 
  augment() %>% 
  ggplot(aes(.fitted, .resid)) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_point() +
  theme_scatter() +
  labs(x = "Fitted values", y = "Residuals")

Now let’s run this code for our model.


mv_model %>% 
  # lm fits take `newdata` (parsnip fits take `new_data`)
  augment(newdata = wvs_test) %>% 
  ggplot(aes(.fitted, .resid)) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_point() +
  theme_scatter() +
  labs(x = "Fitted values", y = "Residuals")

🎯 ACTION POINTS: Why does the graph of the simulated data illustrate a better-fitting model than the graph of our actual data?

The residuals for our actual data are spread relatively widely around zero. Furthermore, we see that as our fitted values become larger, we go from underpredicting to overpredicting life satisfaction.

Introduction to using nested tibbles to aid feature selection (30 minutes)

👨🏻‍🏫 TEACHING MOMENT: Your tutor will take you through the code, so sit back, relax and enjoy!

Remember our univariate model earlier? We are going to do the same for all features to see which ones show the best improvements in predictive power.

We could build 15 different model objects, but this would be very inefficient. Instead, we are going to take advantage of a unique feature of tibbles, the list column.

So far, we have looked at columns that are of numeric, integer, Boolean, factor and character class. However, with list columns, we can nest and unnest any class of object within a single cell. This means we can do things like apply functions over list columns simply by adding another list column to our tibble.
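As a minimal illustration of the idea, the sketch below builds a two-row tibble with a list column of small data frames, then adds a second list column of models and an ordinary numeric column of slopes. The tibble toy and its contents are invented for this example.

```r
library(tibble)
library(dplyr)
library(purrr)

# A toy tibble (invented for illustration) whose `data` column holds
# one small data frame per row
toy <- tibble(
  group = c("a", "b"),
  data  = list(
    tibble(x = 1:3, y = c(2, 4, 6)),
    tibble(x = 1:4, y = c(1, 3, 5, 7))
  )
)

results <- toy %>%
  mutate(
    # map() returns a list, so `model` is itself a list column of lm objects
    model = map(data, ~ lm(y ~ x, data = .x)),
    # map_dbl() returns a plain numeric vector: one slope per row
    slope = map_dbl(model, ~ unname(coef(.x)["x"]))
  )

results$slope
```

Each row carries its own data set and its own fitted model, which is the same pattern used below with one row per formula.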

We are going to leverage this by creating a series of formulas and build a linear model using a combination of each formula and the nested training data.

Create a series of univariate regression formulas


formulas <- paste("satisfaction ~", colnames(wvs)[3:17])
formulas
 [1] "satisfaction ~ social_trust"    "satisfaction ~ male"            "satisfaction ~ age"             "satisfaction ~ post_second_edu"
 [5] "satisfaction ~ rural"           "satisfaction ~ employment"      "satisfaction ~ financial_situ"  "satisfaction ~ married"        
 [9] "satisfaction ~ relig_import"    "satisfaction ~ better_living"   "satisfaction ~ no_food"         "satisfaction ~ no_safety"      
[13] "satisfaction ~ no_medical"      "satisfaction ~ no_cash"         "satisfaction ~ no_shelter"  

Create a tibble that combines these formulas with the training and test sets


wvs_tbl <-
  # Use crossing to find all key combinations between the
  # formulas and the training and test sets
  crossing(formula = formulas,
           # We use the nest function to create a list column
           # for both the training and test sets
           nest(wvs_train, .key = "train_set"),
           nest(wvs_test, .key = "test_set")) 

wvs_tbl
# A tibble: 15 × 3
   formula                        train_set              test_set              
   <chr>                          <list>                 <list>                
 1 satisfaction ~ age             <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 2 satisfaction ~ better_living   <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 3 satisfaction ~ employment      <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 4 satisfaction ~ financial_situ  <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 5 satisfaction ~ male            <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 6 satisfaction ~ married         <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 7 satisfaction ~ no_cash         <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 8 satisfaction ~ no_food         <tibble [65,150 × 17]> <tibble [21,717 × 17]>
 9 satisfaction ~ no_medical      <tibble [65,150 × 17]> <tibble [21,717 × 17]>
10 satisfaction ~ no_safety       <tibble [65,150 × 17]> <tibble [21,717 × 17]>
11 satisfaction ~ no_shelter      <tibble [65,150 × 17]> <tibble [21,717 × 17]>
12 satisfaction ~ post_second_edu <tibble [65,150 × 17]> <tibble [21,717 × 17]>
13 satisfaction ~ relig_import    <tibble [65,150 × 17]> <tibble [21,717 × 17]>
14 satisfaction ~ rural           <tibble [65,150 × 17]> <tibble [21,717 × 17]>
15 satisfaction ~ social_trust    <tibble [65,150 × 17]> <tibble [21,717 × 17]>

Build a linear model using the training set and apply it to the test set


models <-
  wvs_tbl %>% 
         # Add the features (this will help with plotting!)
  mutate(feature = sort(colnames(wvs)[3:17]),
         # Build a linear model using the training set
         model = map2(formula, train_set, ~ lm(.x, data = .y)),
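         # Apply the predictions to the test set. augment() for lm fits
         # takes `newdata`; with `new_data` it falls back to the training
         # rows, as in the 65,150-row augmented column printed below
         augmented = map2(model, test_set, ~ augment(.x, newdata = .y)))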

models
# A tibble: 15 × 6
   formula                        train_set              test_set               feature         model  augmented            
   <chr>                          <list>                 <list>                 <chr>           <list> <list>               
 1 satisfaction ~ age             <tibble [65,150 × 17]> <tibble [21,717 × 17]> age             <lm>   <tibble [65,150 × 8]>
 2 satisfaction ~ better_living   <tibble [65,150 × 17]> <tibble [21,717 × 17]> better_living   <lm>   <tibble [65,150 × 8]>
 3 satisfaction ~ employment      <tibble [65,150 × 17]> <tibble [21,717 × 17]> employment      <lm>   <tibble [65,150 × 8]>
 4 satisfaction ~ financial_situ  <tibble [65,150 × 17]> <tibble [21,717 × 17]> financial_situ  <lm>   <tibble [65,150 × 8]>
 5 satisfaction ~ male            <tibble [65,150 × 17]> <tibble [21,717 × 17]> male            <lm>   <tibble [65,150 × 8]>
 6 satisfaction ~ married         <tibble [65,150 × 17]> <tibble [21,717 × 17]> married         <lm>   <tibble [65,150 × 8]>
 7 satisfaction ~ no_cash         <tibble [65,150 × 17]> <tibble [21,717 × 17]> no_cash         <lm>   <tibble [65,150 × 8]>
 8 satisfaction ~ no_food         <tibble [65,150 × 17]> <tibble [21,717 × 17]> no_food         <lm>   <tibble [65,150 × 8]>
 9 satisfaction ~ no_medical      <tibble [65,150 × 17]> <tibble [21,717 × 17]> no_medical      <lm>   <tibble [65,150 × 8]>
10 satisfaction ~ no_safety       <tibble [65,150 × 17]> <tibble [21,717 × 17]> no_safety       <lm>   <tibble [65,150 × 8]>
11 satisfaction ~ no_shelter      <tibble [65,150 × 17]> <tibble [21,717 × 17]> no_shelter      <lm>   <tibble [65,150 × 8]>
12 satisfaction ~ post_second_edu <tibble [65,150 × 17]> <tibble [21,717 × 17]> post_second_edu <lm>   <tibble [65,150 × 8]>
13 satisfaction ~ relig_import    <tibble [65,150 × 17]> <tibble [21,717 × 17]> relig_import    <lm>   <tibble [65,150 × 8]>
14 satisfaction ~ rural           <tibble [65,150 × 17]> <tibble [21,717 × 17]> rural           <lm>   <tibble [65,150 × 8]>
15 satisfaction ~ social_trust    <tibble [65,150 × 17]> <tibble [21,717 × 17]> social_trust    <lm>   <tibble [65,150 × 8]>

Unnest the data frame and calculate the r-squared value for each univariate regression


preds <- 
  models %>% 
  # To unnest a list column, we use the unnest function
  unnest(augmented) %>% 
  # We use group_by to perform grouped calculations of the r-squared
  # by feature.
  group_by(feature) %>% 
  rsq(truth = satisfaction, estimate = .fitted) %>% 
  # We reorder the features so the "best" / "worst" is on the top / 
  # bottom
  mutate(feature = fct_reorder(feature, .estimate))

preds
# A tibble: 15 × 4
   feature         .metric .estimator  .estimate
   <fct>           <chr>   <chr>           <dbl>
 1 age             rsq     standard   0.000334  
 2 better_living   rsq     standard   0.0509    
 3 employment      rsq     standard   0.00800   
 4 financial_situ  rsq     standard   0.314     
 5 male            rsq     standard   0.0000515 
 6 married         rsq     standard   0.00530   
 7 no_cash         rsq     standard   0.0295    
 8 no_food         rsq     standard   0.0266    
 9 no_medical      rsq     standard   0.0226    
10 no_safety       rsq     standard   0.00863   
11 no_shelter      rsq     standard   0.0111    
12 post_second_edu rsq     standard   0.00186   
13 relig_import    rsq     standard   0.00277   
14 rural           rsq     standard   0.00000335
15 social_trust    rsq     standard   0.00462   

Plot the results


preds %>% 
  # We then use a bar plot (see Lab 2)
  ggplot(aes(.estimate, feature)) +
  geom_col() +
  theme_bar() +
  labs(x = "Test set r-squared", y = NULL)

Using penalised linear regression to perform feature selection (20 minutes)

We are now going to experiment with a lasso regression which, in this case, is a linear regression that uses a so-called hyperparameter - a “dial” built into a given model that can be experimented with to improve model performance. The hyperparameter in this case is a regularisation penalty which takes the value of a non-negative number. This penalty can shrink the magnitude of coefficients down to zero and the larger the penalty, the more shrinkage occurs.
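One way to build intuition for how the penalty shrinks coefficients is soft-thresholding. In the special case of standardised, uncorrelated predictors, the lasso estimate is the unpenalised estimate pulled towards zero by the penalty and set to exactly zero once its magnitude falls below the penalty. The sketch below illustrates that special case only. glmnet's fitting routine is more general, and the coefficient values and the helper soft_threshold are invented for illustration.

```r
# Soft-thresholding: move a coefficient towards zero by `penalty` and
# clamp it at exactly zero once its magnitude falls below the penalty
soft_threshold <- function(beta, penalty) {
  sign(beta) * pmax(abs(beta) - penalty, 0)
}

# Hypothetical unpenalised coefficients (not taken from the lab's model)
beta_ols <- c(large = 0.48, medium = -0.14, small = 0.02)

# Apply a range of penalties: one column per penalty value
penalties <- c(0, 0.01, 0.1, 0.5)
shrunk <- sapply(penalties, function(p) soft_threshold(beta_ols, p))
colnames(shrunk) <- paste0("penalty_", penalties)
shrunk
```

Small coefficients are eliminated first and, with a large enough penalty, every coefficient reaches zero. This is why the lasso can double as a feature selection tool.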

Step 1: Create a lasso model

Run the following code. This builds a lasso model with the penalty parameter set to 0.01.


lasso_model <-
  linear_reg(penalty = 0.01, mixture = 1) %>%
  set_engine("glmnet") %>% 
  fit(satisfaction ~ . -iso3c, data = wvs_train)

Step 2: Extract lasso coefficients

Use the tidy function on the lasso model to get the coefficients.


tidy(lasso_model) 
Attaching package: ‘Matrix’

The following objects are masked from ‘package:tidyr’:

    expand, pack, unpack

Loaded glmnet 4.1-8
# A tibble: 23 × 3
   term                         estimate penalty
   <chr>                           <dbl>   <dbl>
 1 (Intercept)                   3.77       0.01
 2 social_trustTRUE              0.0438     0.01
 3 maleTRUE                     -0.0307     0.01
 4 age                           0          0.01
 5 post_second_eduTRUE          -0.00487    0.01
 6 ruralTRUE                    -0.0148     0.01
 7 employmentHouse wife/husband  0          0.01
 8 employmentOther               0.00857    0.01
 9 employmentPart time           0.0802     0.01
10 employmentRetired             0          0.01
# ℹ 13 more rows
# ℹ Use `print(n = ...)` to see more rows

🎯 ACTION POINTS What is the output? Which coefficients have been shrunk to zero? What is the most important feature?

Step 3: Create a bar plot


tidy(lasso_model) %>% 
  filter(term != "(Intercept)") %>% 
  ggplot(aes(abs(estimate), fct_reorder(term, abs(estimate)), fill = estimate > 0)) +
  geom_col() +
  theme_bar() +
  ggsci::scale_fill_jama() +
  labs(x = "Lasso coefficient", y = "Feature",
       fill = "Positive?")

Step 4: Evaluate on the test set

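Although a different model is used, the code for evaluating it on the test set is almost the same as earlier. Because lasso_model is a parsnip fit, augment() takes new_data and returns its predictions in .pred rather than .fitted.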


lasso_model %>% 
  augment(new_data = wvs_test) %>% 
  reg_metrics(truth = satisfaction, estimate = .pred)
# A tibble: 2 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rsq     standard       0.319
2 rmse    standard       1.81 

🗣️ CLASSROOM DISCUSSION:

Does this model represent an improvement on the linear model?

No. Judging by the printed metrics, the R-squared is slightly lower (0.319 versus 0.332) and the RMSE is the same (1.81).

(Bonus) Step 5: Experiment with different penalties

This is your chance to try out different penalties. Can you find a penalty that improves test set performance?

Let’s employ a nested tibbles approach to find better penalty values.

lasso <- partial(linear_reg, engine = "glmnet", mixture = 1)

preds <-
  crossing(penalty = c(0.0001, 0.001, 0.01, 0.1),
           nest(wvs_train, .key = "train"),
           nest(wvs_test, .key = "test")
           ) %>% 
  mutate(model = map(penalty, ~ lasso(penalty = .x)),
         fit = map2(model, train, ~ fit(.x, satisfaction ~ . -iso3c, data = .y)),
         augmented = map2(fit, test, ~ augment(.x, new_data = .y)),
         rmses = map(augmented, ~ rmse(.x, truth = satisfaction, estimate = .pred)))

preds %>% 
  unnest(rmses) %>% 
  ggplot(aes(x = as.factor(penalty), y = .estimate, group = 1)) +
  geom_point() +
  geom_line(linetype = "dashed") +
  theme_minimal() +
  theme(panel.grid.minor = element_blank(),
        panel.grid.major.x = element_blank()) +
  labs(x = "Penalty", y = "Test set RMSE")

It looks like a penalty of 0.01 works just as well as the other penalties we have tried out.

👉 NOTE: In labs 4 and 5, we are going to use a method called k-fold cross validation to systematically test different combinations of hyperparameters for models such as the lasso.