Predicting Diamond Prices with LASSO Regression

Automatic variable selection on gem quality data

machine learning
regression
diamonds
variable selection
lasso
glmnet
predictive modeling
feature engineering
R (Programming Language)
Author

A. Srikanth

Published

September 26, 2025

Data Byte

Important

The work was originally implemented in Python and adapted for R.

Context

Diamonds are priced individually, and their value is shaped by a mix of measurable traits. Carat weight is numeric, but most other descriptors, like cut, clarity, and color, are categorical. A traditional linear regression can handle this data once each category is converted into dummy variables, but the number of candidate predictors grows quickly, especially once interactions between attributes are considered. The challenge is building an accurate model without manually testing hundreds of feature combinations.

Objectives

The goal is to build a model that predicts the price of an individual diamond while automatically selecting the most important variables. Instead of manually choosing which attributes to include, we’ll use a method that balances accuracy with simplicity: LASSO regression.

Data Sources

We use Diamonds.csv, a dataset of about 6,000 past diamond sales. The file includes the following variables:

# A tibble: 8 × 2
  Variable Type              
  <chr>    <chr>             
1 CARAT    Numeric           
2 CUT      Categorical       
3 COLOR    Categorical       
4 CLARITY  Categorical       
5 POLISH   Categorical       
6 SYMMETRY Categorical       
7 REPORT   Categorical       
8 PRICE    Numeric (Response)

This mix of numeric and categorical data is well suited to regression once the categorical variables are encoded into dummy variables.
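As a small illustration of that encoding, the sketch below runs base R's model.matrix() on a made-up cut factor (the values are hypothetical, not rows from Diamonds.csv). Under R's default treatment contrasts, a factor with k levels becomes k − 1 indicator columns, and the first level is absorbed into the intercept as the baseline.

```r
# Hypothetical toy data: three cut grades, not taken from Diamonds.csv
toy <- data.frame(
  carat = c(0.8, 1.2, 1.5, 0.9),
  cut   = factor(c("Good", "Ideal", "Very Good", "Ideal"),
                 levels = c("Good", "Ideal", "Very Good"))
)

# model.matrix() expands the factor into 0/1 indicator (dummy) columns.
# With default treatment contrasts, "Good" is the baseline and gets no column.
X <- model.matrix(~ carat + cut, data = toy)
print(X)
```

This is why the design matrix built later has columns Color (E) through Color (I) but no Color (D): D is the baseline level.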

Code
suppressPackageStartupMessages({
  library(readr)      # read_csv()
  library(dplyr)
  library(scales)
  library(DT)
  library(htmltools)
})

# Load the raw sales data (one row per diamond)
diamond_data <- read_csv("data/Diamonds.csv")

bg_col      <- "#FAF8F1"
font_family <- "Ramabhadra"
.fixed_fig_width  <- 1100L
.wrapper_max_w    <- 790L

htmltools::browsable(
  htmltools::tagList(
    htmltools::tags$style(htmltools::HTML(sprintf("
      #temp-table-wrap {
        border-radius: 6px;
        overflow: hidden;
        overflow-x: auto;
        -webkit-overflow-scrolling: touch;
        background: %s;
        padding: 0.75rem;
        width: 100%%;
        max-width: %dpx;
        margin: 0 auto 1rem;
      }
      #temp-table-wrap > div { min-width: %dpx; }

      @media (max-width: %dpx) {
        #temp-table-wrap { padding: 0.5rem; }
      }

      #temp-table table.dataTable thead th { text-align: left !important; }
      #temp-table table.dataTable tbody td { text-align: right !important; }
      #temp-table table.dataTable thead th:first-child,
      #temp-table table.dataTable tbody td:first-child { text-align: left !important; }
    ",
      bg_col, .wrapper_max_w, .fixed_fig_width, .wrapper_max_w
    ))),

    htmltools::tags$div(
      id = "temp-table-wrap",
      htmltools::tags$div(
        id = "temp-table",
        DT::datatable(
          diamond_data %>%
            transmute(
              CARAT       = carat_weight,
              CUT         = cut,
              COLOR       = color,
              CLARITY     = clarity,
              POLISH      = polish,
              SYMMETRY    = symmetry,
              REPORT      = report,
              PRICE       = price
            ) %>%
            slice(1:9),
          rownames = TRUE,
          options = list(
            autoWidth = TRUE,
            dom = "tip",
            scrollY = "135px",
            scrollX = TRUE,
            scrollCollapse = TRUE,
            ordering = FALSE,
            paging = FALSE
          )
        )
      )
    )
  )
)

Data Preparation

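To stabilize variance, we model the natural log of price rather than price itself. We then split the data 75/25 into training and test sets, stratifying on clarity so that each clarity grade is split in roughly the same proportion. Dummy variables are created automatically when we build the design matrix, which also adds a square-root carat term and a color × cut interaction to allow for more flexible relationships.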

Below is a preview of the first few rows, showing the core variables and their log-transformed price.

Code
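# Model log(price) rather than price to stabilize variance
diamond_data <- diamond_data %>%
  mutate(log_price = log(price))

# 75/25 train/test split; createDataPartition() samples within each
# clarity level, so every grade is split in roughly the same proportion
set.seed(5)
idx <- caret::createDataPartition(diamond_data$clarity, p = 0.75, list = FALSE)
train_data <- diamond_data[idx, ]
test_data  <- diamond_data[-idx, ]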

suppressPackageStartupMessages({
  library(dplyr)
  library(DT)
  library(htmltools)
})

bg_col           <- "#FAF8F1"
.fixed_fig_width <- 1100L
.wrapper_max_w   <- 790L

htmltools::browsable(
  htmltools::tagList(
    htmltools::tags$style(htmltools::HTML(sprintf("
      #temp-table-wrap {
        border-radius: 6px;
        overflow: hidden;
        overflow-x: auto;
        -webkit-overflow-scrolling: touch;
        background: %s;
        padding: 0.75rem;
        width: 100%%;
        max-width: %dpx;
        margin: 0 auto 1rem;
      }
      #temp-table-wrap > div { min-width: %dpx; }

      @media (max-width: %dpx) {
        #temp-table-wrap { padding: 0.5rem; }
      }

      #temp-table table.dataTable thead th { text-align: left !important; }
      #temp-table table.dataTable tbody td { text-align: right !important; }
      #temp-table table.dataTable thead th:first-child,
      #temp-table table.dataTable tbody td:first-child { text-align: left !important; }
    ",
      bg_col, .wrapper_max_w, .fixed_fig_width, .wrapper_max_w
    ))),

    htmltools::tags$div(
      id = "temp-table-wrap",
      htmltools::tags$div(
        id = "temp-table",
        DT::datatable(
          diamond_data %>%
            transmute(
              CARAT      = carat_weight,
              CUT        = cut,
              COLOR      = color,
              CLARITY    = clarity,
              POLISH     = polish,
              SYMMETRY   = symmetry,
              REPORT     = report,
              PRICE      = price,
              `LOG PRICE` = sprintf("%.2f", log_price)
            ) %>%
            slice(1:9),
          rownames = TRUE,
          options = list(
            autoWidth = TRUE,
            dom = "tip",
            scrollY = "135px",
            scrollX = TRUE,
            scrollCollapse = TRUE,
            ordering = FALSE,
            paging = FALSE
          )
        )
      )
    )
  )
)
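caret::createDataPartition() stratifies on the variable it is given: for a factor such as clarity, it samples within each level so the 75/25 ratio holds grade by grade. The sketch below reproduces that idea in base R under stated assumptions: stratified_split() is a hypothetical helper (not part of caret), the clarity labels are made up, and its rounding will not necessarily match caret's.

```r
# Hypothetical helper: stratified train/test split in base R.
# Draws round(p * n_k) row indices within each level k of `strata`.
stratified_split <- function(strata, p = 0.75, seed = 5) {
  set.seed(seed)                       # note: resets the global RNG
  strata <- as.factor(strata)
  idx_by_level <- split(seq_along(strata), strata)
  train_idx <- unlist(lapply(idx_by_level, function(ix) {
    n_train <- round(p * length(ix))
    # index into ix rather than calling sample(ix, ...), which would
    # misbehave for a level with a single row
    ix[sample.int(length(ix), n_train)]
  }), use.names = FALSE)
  sort(train_idx)
}

# Toy example with hypothetical clarity labels
clarity <- rep(c("IF", "VS1", "SI1"), times = c(8, 20, 40))
idx <- stratified_split(clarity, p = 0.75)

# Share of each grade that lands in the training set
print(table(clarity[idx]) / table(clarity))
```

Stratifying matters most for rare grades: a purely random split could leave a sparse level, such as IF here, with very few rows in one of the sets.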

Methodology

At its core, linear regression estimates coefficients by minimizing the sum of squared errors. With too many variables, however, the model risks overfitting: it captures noise instead of general patterns. Dropping too many variables, on the other hand, leads to underfitting.

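LASSO regression (Least Absolute Shrinkage and Selection Operator) addresses this by modifying the training objective. It adds a penalty proportional to the sum of the absolute values of the coefficients (their L1 norm), which pulls coefficients toward zero and can set some of them exactly to zero.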

The objective function for LASSO is:

\[ \hat{\beta} = \underset{\beta}{\arg\min} \left\{ \underbrace{\sum_{i=1}^n \big(y_i - \beta_0 - \sum_{j=1}^p x_{ij}\beta_j\big)^2}_{\text{Training Error}} \;+\; \underbrace{\lambda \sum_{j=1}^p |\beta_j|}_{\text{Model Complexity}} \right\} \]

  • \(y_i\): the response for observation \(i\) (here, log price)
  • \(x_{ij}\): the value of predictor \(j\) for observation \(i\) (carat, or a dummy column for cut, clarity, and so on)
  • \(\beta_0\): the intercept, which is not penalized
  • \(\beta_j\): the coefficient to estimate for predictor \(j\)
  • \(n\), \(p\): the number of observations and the number of predictors
  • \(\sum_{j=1}^p |\beta_j|\): the L1 norm \(\|\beta\|_1\), the sum of the absolute values of the coefficients
  • \(\lambda \ge 0\): the tuning parameter controlling the tradeoff between fit and simplicity
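One practical detail when moving from this formula to code: glmnet scales the squared-error term differently. For the Gaussian family with alpha = 1 (pure LASSO), glmnet minimizes the averaged, halved squared error:

\[ \min_{\beta_0,\,\beta} \left\{ \frac{1}{2n}\sum_{i=1}^n \Big(y_i - \beta_0 - \sum_{j=1}^p x_{ij}\beta_j\Big)^2 \;+\; \lambda_{\text{glmnet}} \sum_{j=1}^p |\beta_j| \right\} \]

Multiplying through by \(2n\) recovers the objective above with \(\lambda = 2n\,\lambda_{\text{glmnet}}\). The fitted model is the same; only the numeric scale of the penalty differs, so \(\lambda\) values reported by glmnet should not be compared directly with the \(\lambda\) in the formula. scikit-learn's Lasso uses the same \(1/(2n)\) scaling but calls the penalty "alpha", which likely explains the ALPHA labels carried over into the R code below.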

Effect of \(\lambda\) on Model Complexity

The role of \(\lambda\) is easy to interpret. When \(\lambda\) is small, the penalty is light and the model keeps most variables, which raises the risk of overfitting. At \(\lambda = 0\) the penalty vanishes and LASSO reduces to ordinary least squares.

Code
suppressPackageStartupMessages({
  library(dplyr)
  library(plotly)
  library(htmltools)
})

bg_col      <- "#FAF8F1"
point_color <- "#751F2C"
font_family <- "Ramabhadra"

.fixed_fig_width  <- 1100L
.fixed_fig_height <- 390L
.wrapper_max_w    <- 790L

spike_style <- list(
  spikecolor     = "#000",
  spikedash      = "dash",
  spikethickness = 1.5,
  spikemode      = "across",
  spikesnap      = "cursor",
  showspikes     = TRUE
)

set.seed(42)
n  <- 100
x  <- seq_len(n)
y  <- sin(x/8) + rnorm(n, 0, 0.3)
df <- tibble(x = x, y = y)

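# Conceptual stand-in for lambda = 0: an unpenalized degree-8 polynomial
# that chases the noise (this illustration does not fit a LASSO model)
fit_over <- lm(y ~ poly(x, 8), data = df)
df_over_line <- tibble(x = x, yhat = as.numeric(predict(fit_over, newdata = tibble(x = x))))

# Conceptual stand-in for a very large lambda: every slope is zeroed,
# leaving an intercept-only model that predicts the mean of y
fit_under <- lm(y ~ 1, data = df)
df_under_line <- tibble(x = x, yhat = as.numeric(predict(fit_under, newdata = tibble(x = x))))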

hover_pts_over  <- paste0("<b>Index:</b> ", df$x, "<br><b>Y:</b> ", signif(df$y, 5))
hover_line_over <- paste0("<b>Index:</b> ", df_over_line$x, "<br><b>Fit:</b> ", signif(df_over_line$yhat, 5))

fig_over <- plot_ly() %>%
  add_trace(
    data = df, x = ~x, y = ~y,
    type = "scatter", mode = "markers",
    marker = list(size = 7, color = point_color, opacity = 0.5),
    hoverinfo = "text", text = hover_pts_over,
    name = "Observed"
  ) %>%
  add_trace(
    data = df_over_line, x = ~x, y = ~yhat,
    type = "scatter", mode = "lines",
    line = list(color = point_color, width = 2),
    hoverinfo = "text", text = hover_line_over,
    name = "Flexible fit"
  ) %>%
  layout(
    title = list(text = "λ = 0 (Overfitting)", x = 0.03, y = 0.95, xanchor = "left"),
    width  = .fixed_fig_width,
    height = .fixed_fig_height,
    font = list(family = font_family, size = 12),
    paper_bgcolor = bg_col, plot_bgcolor = bg_col,
    margin = list(l = 18, r = 18, t = 36, b = 36),
    legend = list(orientation = "v", x = 1.05, y = 1, xanchor = "left", yanchor = "top",
                  font = list(size = 16), title = list(text = "")),
    xaxis = c(list(
      title = list(text = "OBSERVATION INDEX", standoff = 20),
      tickfont = list(size = 16), gridcolor = "#E8E8E8",
      zeroline = FALSE, fixedrange = TRUE
    ), spike_style),
    yaxis = c(list(
      title = list(text = "RESPONSE", standoff = 20),
      tickfont = list(size = 16), automargin = TRUE, fixedrange = TRUE
    ), spike_style),
    hovermode = "closest",
    hoverlabel = list(
      font = list(family = font_family, size = 14, color = "#313131"),
      bgcolor = "#FFF", namelength = -1
    )
  ) %>%
  config(
    responsive = FALSE,
    scrollZoom = TRUE, doubleClick = FALSE,
    modeBarButtonsToRemove = list(
      "zoom2d","pan2d","select2d","lasso2d","zoomIn2d","zoomOut2d",
      "autoScale2d","resetScale2d","toggleSpikelines","toImage"
    ),
    displaylogo = FALSE, displayModeBar = TRUE, showTips = FALSE
  )

htmltools::browsable(
  htmltools::tagList(
    htmltools::tags$style(htmltools::HTML(sprintf("
      #temp-plot-wrap-over, #temp-plot-wrap-under {
        border-radius: 6px;
        overflow: hidden;
        overflow-x: auto;
        -webkit-overflow-scrolling: touch;
        background: %s;
        padding: 0.75rem;
        width: 100%%;
        max-width: %dpx;
        margin: 0 auto 1rem;
      }
      #temp-plot-wrap-over > div, #temp-plot-wrap-under > div { min-width: %dpx; }
      @media (max-width: %dpx) {
        #temp-plot-wrap-over, #temp-plot-wrap-under { padding: 0.5rem; }
      }
    ", bg_col, .wrapper_max_w, .fixed_fig_width, .wrapper_max_w))),
    htmltools::tags$div(id = "temp-plot-wrap-over",  fig_over)
  )
)

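When \(\lambda\) is large, the penalty dominates and many coefficients are shrunk to zero, which raises the risk of underfitting. Once \(\lambda\) is large enough, every coefficient is zero and the model predicts the mean response for every observation, as in the intercept-only fit below.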

Code
hover_pts_under  <- paste0("<b>Index:</b> ", df$x, "<br><b>Y:</b> ", signif(df$y, 5))
hover_line_under <- paste0("<b>Index:</b> ", df_under_line$x, "<br><b>Fit:</b> ", signif(df_under_line$yhat, 5))

fig_under <- plot_ly() %>%
  add_trace(
    data = df, x = ~x, y = ~y,
    type = "scatter", mode = "markers",
    marker = list(size = 7, color = point_color, opacity = 0.5),
    hoverinfo = "text", text = hover_pts_under,
    name = "Observed"
  ) %>%
  add_trace(
    data = df_under_line, x = ~x, y = ~yhat,
    type = "scatter", mode = "lines",
    line = list(color = point_color, width = 2),
    hoverinfo = "text", text = hover_line_under,
    name = "Simple fit"
  ) %>%
  layout(
    title = list(text = "λ Large (Underfitting)", x = 0.03, y = 0.95, xanchor = "left"),
    width  = .fixed_fig_width,
    height = .fixed_fig_height,
    font = list(family = font_family, size = 12),
    paper_bgcolor = bg_col, plot_bgcolor = bg_col,
    margin = list(l = 18, r = 18, t = 36, b = 36),
    legend = list(orientation = "v", x = 1.05, y = 1, xanchor = "left", yanchor = "top",
                  font = list(size = 16), title = list(text = "")),
    xaxis = c(list(
      title = list(text = "OBSERVATION INDEX", standoff = 20),
      tickfont = list(size = 16), gridcolor = "#E8E8E8",
      zeroline = FALSE, fixedrange = TRUE
    ), spike_style),
    yaxis = c(list(
      title = list(text = "RESPONSE", standoff = 20),
      tickfont = list(size = 16), automargin = TRUE, fixedrange = TRUE
    ), spike_style),
    hovermode = "closest",
    hoverlabel = list(
      font = list(family = font_family, size = 14, color = "#313131"),
      bgcolor = "#FFF", namelength = -1
    )
  ) %>%
  config(
    responsive = FALSE,
    scrollZoom = TRUE, doubleClick = FALSE,
    modeBarButtonsToRemove = list(
      "zoom2d","pan2d","select2d","lasso2d","zoomIn2d","zoomOut2d",
      "autoScale2d","resetScale2d","toggleSpikelines","toImage"
    ),
    displaylogo = FALSE, displayModeBar = TRUE, showTips = FALSE
  )

htmltools::browsable(
  htmltools::tagList(
    htmltools::tags$style(htmltools::HTML(sprintf("
      #temp-plot-wrap-over, #temp-plot-wrap-under {
        border-radius: 6px;
        overflow: hidden;
        overflow-x: auto;
        -webkit-overflow-scrolling: touch;
        background: %s;
        padding: 0.75rem;
        width: 100%%;
        max-width: %dpx;
        margin: 0 auto 1rem;
      }
      #temp-plot-wrap-over > div, #temp-plot-wrap-under > div { min-width: %dpx; }
      @media (max-width: %dpx) {
        #temp-plot-wrap-over, #temp-plot-wrap-under { padding: 0.5rem; }
      }
    ", bg_col, .wrapper_max_w, .fixed_fig_width, .wrapper_max_w))),
    htmltools::tags$div(id = "temp-plot-wrap-under", fig_under)
  )
)

The key is to choose a value of \(\lambda\) that strikes the right balance, and cross-validation provides a systematic way to find the value with the lowest estimated out-of-sample error. Because the L1 penalty sets some coefficients exactly to zero, it also acts as an automatic variable selection tool: unimportant predictors are dropped, leaving only the features that meaningfully contribute to the model.
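Why does the L1 penalty drop predictors outright instead of merely shrinking them? A standard special case makes this concrete: if the predictor columns are orthonormal (\(X^\top X = I\)), the LASSO solution to the objective above is the least-squares estimate passed through a soft threshold at \(\lambda/2\). The sketch below assumes that idealized design and uses made-up coefficient values; real diamond dummies are correlated, so glmnet solves the general case iteratively.

```r
# Soft-thresholding operator: S(z, g) = sign(z) * max(|z| - g, 0)
soft_threshold <- function(z, g) sign(z) * pmax(abs(z) - g, 0)

# Hypothetical least-squares coefficients under an orthonormal design
beta_ols <- c(carat = 1.80, clarity_vs1 = 0.35, cut_good = -0.08, polish_vg = 0.02)

# With X'X = I, RSS + lambda * sum(|beta|) separates by coordinate and is
# minimized by soft-thresholding each OLS estimate at lambda / 2
lasso_orthonormal <- function(beta_ols, lambda) soft_threshold(beta_ols, lambda / 2)

for (lambda in c(0, 0.1, 0.5, 4)) {
  cat(sprintf("lambda = %-4g -> ", lambda))
  print(round(lasso_orthonormal(beta_ols, lambda), 3))
}
```

Small coefficients hit zero first; at a large enough \(\lambda\) every coefficient is zero, which is the intercept-only model from the underfitting example above.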

Code
suppressPackageStartupMessages(library(stats))

# Add a log_Price column if it is absent (the source may spell it Price or price)
add_log_price <- function(df) {
  if (!"log_Price" %in% names(df)) {
    if ("Price" %in% names(df))       df$log_Price <- log(df$Price)
    else if ("price" %in% names(df))  df$log_Price <- log(df$price)
  }
  df
}
train_data <- add_log_price(train_data)
test_data  <- add_log_price(test_data)

has_caps <- all(c("Carat","Color","Clarity","Cut","Polish","Symmetry","Report") %in% names(train_data))
nm <- if (has_caps) list(
  carat      = "Carat",
  color      = "Color",
  clarity    = "Clarity",
  cut        = "Cut",
  polish     = "Polish",
  symmetry   = "Symmetry",
  report     = "Report"
) else list(
  carat      = "carat_weight",
  color      = "color",
  clarity    = "clarity",
  cut        = "cut",
  polish     = "polish",
  symmetry   = "symmetry",
  report     = "report"
)

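# Model formula: carat plus sqrt(carat) for curvature, the categorical grades,
# and a color x cut interaction (color:cut)
form <- as.formula(sprintf(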
  "log_Price ~ %1$s + I(sqrt(%1$s)) + %2$s + %3$s + %4$s + %2$s:%4$s + %5$s + %6$s + %7$s",
  nm$carat, nm$color, nm$clarity, nm$cut, nm$polish, nm$symmetry, nm$report
))

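# Build the dummy-encoded predictor matrix (intercept column dropped because
# glmnet fits its own intercept) and the response vector.
# Train and test are encoded separately, so their columns match only if
# every factor level appears in both splits.
build_mm <- function(form, data) {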
  mf <- model.frame(form, data = data)
  y  <- model.response(mf)
  X  <- model.matrix(form, data = mf)[, -1, drop = FALSE]
  list(X = X, y = y)
}
mm_train <- build_mm(form, train_data)
mm_test  <- build_mm(form,  test_data)

x_train <- mm_train$X; y_train <- mm_train$y
x_test  <- mm_test$X;  y_test  <- mm_test$y

# Upper-case a single word of a level label (e.g. "good" -> "GOOD")
.title_word <- function(w) toupper(w)

.clean_level <- function(lvl) {
  lvl <- gsub("[._]+", " ", lvl)
  lvl <- gsub("\\s+", " ", lvl)
  lvl <- trimws(lvl)
  if (lvl == "") return(lvl)
  paste(vapply(strsplit(lvl, " +")[[1]], .title_word, character(1)), collapse = " ")
}

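# Convert one raw model.matrix() column name (e.g. "colorE" or
# "I(sqrt(carat_weight))") into a readable label such as "Color (E)"
.pretty_token <- function(tok) {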
  tok <- sub(sprintf("^%s$", nm$carat),                  "Carat (ct)", tok, ignore.case = FALSE)
  tok <- sub(sprintf("^I\\(sqrt\\(%s\\)\\)$", nm$carat), "√Carat (√ct)", tok, ignore.case = FALSE)

  if (grepl("^.+\\s\\(.+\\)$", tok)) return(tok)

  disp_map <- c(
    color    = "Color",
    clarity  = "Clarity",
    cut      = "Cut",
    polish   = "Polish",
    symmetry = "Symmetry",
    report   = "Report"
  )

  cand <- c(nm$color, nm$clarity, nm$cut, nm$polish, nm$symmetry, nm$report)

  for (cat in cand) {
    if (grepl(paste0("^", cat), tok, ignore.case = TRUE)) {
      lvl_raw <- sub(paste0("^", cat), "", tok, ignore.case = TRUE, perl = TRUE)
      lvl     <- .clean_level(lvl_raw)
      cat_key <- names(which(tolower(unlist(nm)) == tolower(cat)))[1]
      cat_lab <- disp_map[cat_key]
      return(sprintf("%s (%s)", cat_lab, lvl))
    }
  }

  m <- regexec("^([A-Za-z]+)(.+)$", tok)
  hit <- regmatches(tok, m)[[1]]
  if (length(hit) == 3) {
    return(sprintf("%s (%s)", .clean_level(hit[2]), .clean_level(hit[3])))
  }
  tok
}

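# Label one column name; interaction columns ("a:b") are labeled piece by
# piece and joined with "×". The argument is named `col` so it is not
# confused with the global `nm` lookup list used by .pretty_token().
.pretty_name <- function(col) {
  if (grepl(":", col, fixed = TRUE)) {
    toks <- strsplit(col, ":", fixed = TRUE)[[1]]
    return(paste(vapply(toks, .pretty_token, character(1)), collapse = " × "))
  }
  .pretty_token(col)
}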

pretty_all <- function(nms) vapply(nms, .pretty_name, character(1))

colnames(x_train) <- pretty_all(colnames(x_train))
colnames(x_test)  <- pretty_all(colnames(x_test))

head(as.data.frame(x_train))
  Carat (ct) √Carat (√ct) Color (E) Color (F) Color (G) Color (H) Color (I)
1       0.83    0.9110434         0         0         0         1         0
2       1.53    1.2369317         1         0         0         0         0
3       1.00    1.0000000         0         0         0         0         0
4       1.50    1.2247449         0         1         0         0         0
5       2.11    1.4525839         0         0         0         1         0
6       0.91    0.9539392         0         0         0         0         0
  Clarity (IF) Clarity (SI1) Clarity (VS1) Clarity (VS2) Clarity (VVS1)
1            0             0             1             0              0
2            0             1             0             0              0
3            0             1             0             0              0
4            0             1             0             0              0
5            0             1             0             0              0
6            0             0             0             1              0
  Clarity (VVS2) Cut (GOOD) Cut (IDEAL) Cut (SIGNATURE-IDEAL) Cut (VERY GOOD)
1              0          0           1                     0               0
2              0          0           1                     0               0
3              0          0           0                     0               1
4              0          0           0                     0               0
5              0          0           1                     0               0
6              0          0           1                     0               0
  Polish (G) Polish (ID) Polish (VG) Symmetry (G) Symmetry (ID) Symmetry (VG)
1          0           1           0            0             1             0
2          0           1           0            0             1             0
3          0           0           1            1             0             0
4          0           0           1            0             0             1
5          0           0           1            0             0             1
6          0           0           1            0             0             1
  Report (GIA) Color (E) × Cut (GOOD) Color (F) × Cut (GOOD)
1            0                      0                      0
2            0                      0                      0
3            1                      0                      0
4            1                      0                      0
5            1                      0                      0
6            1                      0                      0
  Color (G) × Cut (GOOD) Color (H) × Cut (GOOD) Color (I) × Cut (GOOD)
1                      0                      0                      0
2                      0                      0                      0
3                      0                      0                      0
4                      0                      0                      0
5                      0                      0                      0
6                      0                      0                      0
  Color (E) × Cut (IDEAL) Color (F) × Cut (IDEAL) Color (G) × Cut (IDEAL)
1                       0                       0                       0
2                       1                       0                       0
3                       0                       0                       0
4                       0                       0                       0
5                       0                       0                       0
6                       0                       0                       0
  Color (H) × Cut (IDEAL) Color (I) × Cut (IDEAL)
1                       1                       0
2                       0                       0
3                       0                       0
4                       0                       0
5                       1                       0
6                       0                       0
  Color (E) × Cut (SIGNATURE-IDEAL) Color (F) × Cut (SIGNATURE-IDEAL)
1                                 0                                 0
2                                 0                                 0
3                                 0                                 0
4                                 0                                 0
5                                 0                                 0
6                                 0                                 0
  Color (G) × Cut (SIGNATURE-IDEAL) Color (H) × Cut (SIGNATURE-IDEAL)
1                                 0                                 0
2                                 0                                 0
3                                 0                                 0
4                                 0                                 0
5                                 0                                 0
6                                 0                                 0
  Color (I) × Cut (SIGNATURE-IDEAL) Color (E) × Cut (VERY GOOD)
1                                 0                           0
2                                 0                           0
3                                 0                           0
4                                 0                           0
5                                 0                           0
6                                 0                           0
  Color (F) × Cut (VERY GOOD) Color (G) × Cut (VERY GOOD)
1                           0                           0
2                           0                           0
3                           0                           0
4                           0                           0
5                           0                           0
6                           0                           0
  Color (H) × Cut (VERY GOOD) Color (I) × Cut (VERY GOOD)
1                           0                           0
2                           0                           0
3                           0                           0
4                           0                           0
5                           0                           0
6                           0                           0
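One caveat with build_mm(): it calls model.matrix() on the training and test data separately, so if a factor level appears in only one split, the two matrices get different columns and no longer line up for prediction. Fixing each factor's levels to the training set's levels before encoding keeps them aligned. A minimal sketch, with a hypothetical helper align_levels() and made-up data:

```r
# Hypothetical helper: give each listed column in `test` the same factor
# levels as in `train`, so model.matrix() yields identical columns.
# Test values never seen in training become NA (and model.frame drops those rows).
align_levels <- function(train, test, cols) {
  for (col in cols) {
    lv <- levels(factor(train[[col]]))
    train[[col]] <- factor(train[[col]], levels = lv)
    test[[col]]  <- factor(test[[col]],  levels = lv)
  }
  list(train = train, test = test)
}

train <- data.frame(carat = c(0.9, 1.1, 1.4), color = c("D", "E", "F"))
test  <- data.frame(carat = c(1.0, 1.3),      color = c("D", "F"))  # no "E"

naive_cols <- colnames(model.matrix(~ carat + color, test))  # lacks colorE
aligned <- align_levels(train, test, "color")
x_tr <- model.matrix(~ carat + color, aligned$train)
x_te <- model.matrix(~ carat + color, aligned$test)

print(naive_cols)
print(identical(colnames(x_tr), colnames(x_te)))
```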

Analysis

Let’s start with a quick baseline fit (using glmnet) before tuning. We’ll train a LASSO model on the design matrix x_train and response y_train with a fixed penalty of \(\lambda = 1\). This isn’t the optimal penalty; it’s a sanity check to confirm that the pipeline is wired correctly (matrix construction, factor encoding, and prediction). We then score the model on the held-out test set using Mean Squared Error (MSE) and Mean Absolute Percentage Error (MAPE). MSE penalizes large errors more heavily, while MAPE reports the average error as a percentage of the true value, which is useful when communicating accuracy to a non-technical audience. After this baseline, we’ll replace the fixed penalty with a cross-validated \(\lambda\) to balance bias and variance and improve out-of-sample performance.
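Because the model is fit to log price, its predictions have to be exponentiated before they are compared with dollar prices; otherwise MSE and MAPE mix two scales. A minimal sketch of both metrics on made-up numbers (mse() and mape() are hypothetical helper names, not glmnet functions):

```r
# Hypothetical helpers for scoring predictions on the dollar scale
mse  <- function(actual, pred) mean((actual - pred)^2)
mape <- function(actual, pred) {
  ok <- is.finite(actual) & actual != 0   # avoid dividing by zero
  100 * mean(abs((actual[ok] - pred[ok]) / actual[ok]))
}

actual_price <- c(4000, 8000, 12000)        # made-up test prices ($)
log_pred     <- log(c(4400, 7600, 12000))   # made-up log-scale predictions

pred_price <- exp(log_pred)                 # back-transform to dollars
print(mse(actual_price, pred_price))        # units: dollars squared
print(mape(actual_price, pred_price))       # percent

# Scoring log-scale predictions against dollar prices gives a MAPE near 100%
print(mape(actual_price, log_pred))
```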

Code
suppressPackageStartupMessages(library(glmnet))

# Baseline LASSO (alpha = 1) with a deliberately arbitrary penalty, lambda = 1.
# On the log-price scale this may shrink most or all coefficients to zero.
# (The cross-validated fits below use standardize = FALSE instead.)
fit <- glmnet(
  x = as.matrix(x_train),
  y = as.numeric(y_train),
  alpha = 1,
  lambda = 1,
  standardize = TRUE,
  family = "gaussian"
)

# Predictions come back on the log-price scale
log_pred <- as.numeric(predict(fit, newx = as.matrix(x_test), s = 1))

# Back-transform so predictions and actual prices are both in dollars.
# y_test holds log prices aligned row-for-row with x_test.
pred_price   <- exp(log_pred)
actual_price <- exp(y_test)

mse <- mean((actual_price - pred_price)^2, na.rm = TRUE)

mape <- {
  ok <- is.finite(actual_price) & actual_price != 0
  mean(abs((actual_price[ok] - pred_price[ok]) / actual_price[ok])) * 100
}

cat(sprintf("LASSO Testing MSE:  %.6f\n", mse))
cat(sprintf("LASSO Testing MAPE: %.6f%%\n", mape))
Code
suppressPackageStartupMessages({
  library(glmnet)
  library(dplyr)
  library(tidyr)
  library(plotly)
  library(htmltools)
})

bg_col      <- "#FAF8F1"
point_color <- "#751F2C"
font_family <- "Ramabhadra"

.fixed_fig_width  <- 1100L
.fixed_fig_height <- 390L
.wrapper_max_w    <- 790L

spike_style <- list(
  spikecolor     = "#000",
  spikedash      = "dash",
  spikethickness = 1.5,
  spikemode      = "across",
  spikesnap      = "cursor",
  showspikes     = TRUE
)

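# 10-fold cross-validation over glmnet's default lambda path.
# Naming note: "alpha" in the labels below is the penalty strength, which
# glmnet calls lambda (scikit-learn's Lasso calls it alpha). glmnet's own
# `alpha` argument is the elastic-net mix; alpha = 1 means pure LASSO.
set.seed(0)
lasso_cv <- cv.glmnet(
  x = as.matrix(x_train),
  y = as.numeric(y_train),
  alpha = 1, nfolds = 10,
  standardize = FALSE, intercept = TRUE,
  type.measure = "mse",
  family = "gaussian"
)
best_alpha <- lasso_cv$lambda.min   # lambda with the lowest mean CV error

cat(sprintf("BEST ALPHA: %.6f\n", best_alpha))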
BEST ALPHA: 0.000031
Code
cat(sprintf("BEST ALPHA (SCIENTIFIC): %.3e", best_alpha))
BEST ALPHA (SCIENTIFIC): 3.098e-05
Code
df_cv <- tibble(alpha = lasso_cv$lambda, mse = lasso_cv$cvm)
hover_cv <- paste0(
  "<b>ALPHA:</b> ", signif(df_cv$alpha, 6),
  "<br><b>MSE:</b> ", signif(df_cv$mse, 6)
)

cv_fig <- plot_ly(
  df_cv, x = ~alpha, y = ~mse,
  type = "scatter", mode = "markers+lines",
  text = hover_cv, hoverinfo = "text",
  marker = list(size = 9, color = point_color, line = list(width = 0), opacity = 0.9,
                sizemode = "diameter", sizemin = 4),
  line   = list(dash = "solid", color = point_color)
) %>%
  layout(
    width  = .fixed_fig_width,
    height = .fixed_fig_height,
    font = list(family = font_family, size = 12),
    paper_bgcolor = bg_col, plot_bgcolor = bg_col,
    margin = list(l = 18, r = 18, t = 36, b = 36),
    legend = list(orientation = "v", x = 1.05, y = 1, xanchor = "left", yanchor = "top",
                  font = list(size = 16), title = list(text = "")),
    xaxis = c(list(
      title = list(text = "ALPHA", standoff = 20),
      type = "log",
      tickfont = list(size = 16), gridcolor = "#E8E8E8",
      zeroline = FALSE, fixedrange = TRUE
    ), spike_style),
    yaxis = c(list(
      title = list(text = "MEAN SQUARED ERROR", standoff = 20),
      tickfont = list(size = 16), automargin = TRUE, fixedrange = TRUE
    ), spike_style),
    uniformtext = list(minsize = 16),
    hovermode = "x unified",
    hoverlabel = list(
      font = list(family = font_family, size = 14, color = "#313131"),
      bgcolor = "#FFF", namelength = -1
    ),
    shapes = list(
      list(
        type = "line", x0 = best_alpha, x1 = best_alpha, xref = "x", yref = "paper",
        y0 = 0, y1 = 1, line = list(dash = "dash", color = "slategray", width = 1.5)
      )
    ),
    annotations = list(list(
      x = best_alpha, y = 1, xref = "x", yref = "paper",
      text = sprintf("BEST ALPHA: %.3e", best_alpha),
      showarrow = FALSE, xanchor = "left", yanchor = "bottom",
      font = list(color = "slategray", family = font_family, size = 12)
    )),
    dragmode = FALSE
  ) %>%
  config(
    responsive = FALSE,
    scrollZoom = TRUE, doubleClick = FALSE,
    modeBarButtonsToRemove = list(
      "zoom2d","pan2d","select2d","lasso2d","zoomIn2d","zoomOut2d",
      "autoScale2d","resetScale2d","toggleSpikelines","toImage"
    ),
    displaylogo = FALSE, displayModeBar = TRUE, showTips = FALSE
  )

suppressPackageStartupMessages(library(glmnet))

fit_path <- glmnet(
  x = as.matrix(x_train),
  y = as.numeric(y_train),
  alpha = 1,
  standardize = FALSE, intercept = TRUE,
  family = "gaussian"
)

beta_mat   <- as.matrix(fit_path$beta)
alpha_path <- fit_path$lambda

coef_df <- as.data.frame(t(beta_mat)) %>%
  dplyr::mutate(alpha = alpha_path) %>%
  tidyr::pivot_longer(cols = -alpha, names_to = "feature", values_to = "coef")

if (!exists("best_alpha")) {
  set.seed(0)
  lasso_cv <- cv.glmnet(
    x = as.matrix(x_train),
    y = as.numeric(y_train),
    alpha = 1, nfolds = 10,
    standardize = FALSE, intercept = TRUE,
    type.measure = "mse",
    family = "gaussian"
  )
  best_alpha <- lasso_cv$lambda.min
}
nearest_idx <- which.min(abs(alpha_path - best_alpha))
ord <- order(abs(beta_mat[, nearest_idx]), decreasing = TRUE)
coef_df$feature <- factor(coef_df$feature, levels = rownames(beta_mat)[ord])

n_ser <- nlevels(coef_df$feature)
series_cols <- setNames(
  grDevices::hcl(h = seq(0, 360, length.out = n_ser + 1)[1:n_ser], c = 70, l = 50),
  levels(coef_df$feature)
)

coef_fig <- plotly::plot_ly()
for (f in levels(coef_df$feature)) {
  df_f <- dplyr::filter(coef_df, feature == f)
  hover_txt <- paste0(
    "<b>", f, "</b>",
    "<br><b>ALPHA:</b> ", signif(df_f$alpha, 6),
    "<br><b>COEF:</b> ", signif(df_f$coef, 6)
  )
  coef_fig <- add_trace(
    coef_fig,
    data = df_f,
    x = ~alpha, y = ~coef,
    type = "scatter", mode = "lines",
    name = f,
    text = hover_txt,
    hoverinfo = "text",
    hovertemplate = "%{text}<extra></extra>",
    line = list(dash = "solid", color = series_cols[[f]], width = 1.8)
  )
}

features <- levels(coef_df$feature)

vis_all   <- rep(TRUE, length(features))
vis_carat <- grepl("carat|sqrt\\(carat", features, ignore.case = TRUE)
vis_color <- grepl("color",             features, ignore.case = TRUE)
vis_clar  <- grepl("clarity",           features, ignore.case = TRUE)
vis_cut   <- grepl("cut",               features, ignore.case = TRUE)

buttons_fourCs <- list(
  list(method = "update", label = "ALL",
       args = list(list(visible = vis_all))),
  list(method = "update", label = "CARAT",
       args = list(list(visible = vis_carat))),
  list(method = "update", label = "COLOR",
       args = list(list(visible = vis_color))),
  list(method = "update", label = "CLARITY",
       args = list(list(visible = vis_clar))),
  list(method = "update", label = "CUT",
       args = list(list(visible = vis_cut)))
)

spike_style <- list(
  spikemode      = "across",
  spikecolor     = "#000",
  spikedash      = "dash",
  spikethickness = 1.5,
  spikesnap      = "cursor",
  showspikes     = TRUE
)

coef_fig <- layout(
  coef_fig,
  updatemenus = list(list(
    type = "buttons",
    direction = "right",
    xref = "paper", yref = "paper",
    x = 0,          xanchor = "left",
    y = 1.25,       yanchor = "middle",
    buttons = buttons_fourCs,
    showactive = TRUE
  )),
  width  = .fixed_fig_width,
  height = .fixed_fig_height,
  font = list(family = font_family, size = 12),
  paper_bgcolor = bg_col, plot_bgcolor = bg_col,
  margin = list(l = 18, r = 18, t = 36, b = 36),
  legend = list(orientation = "v", x = 1.05, y = 1,
                xanchor = "left", yanchor = "top",
                font = list(size = 16), title = list(text = "")),
  xaxis = c(list(
    title = list(text = "ALPHA", standoff = 20),
    type = "log",
    tickfont = list(size = 16), gridcolor = "#E8E8E8",
    zeroline = FALSE, fixedrange = TRUE
  ), spike_style),
  yaxis = c(list(
    title = list(text = "COEFFICIENTS", standoff = 20),
    tickfont = list(size = 16), automargin = TRUE, fixedrange = TRUE
  ), spike_style),
  uniformtext = list(minsize = 16),
  hovermode = "closest",
  hoverlabel = list(
    font = list(family = font_family, size = 14, color = "#313131"),
    bgcolor = "#FFF", namelength = -1
  ),
  shapes = list(
    list(type = "line", x0 = best_alpha, x1 = best_alpha, xref = "x", yref = "paper",
         y0 = 0, y1 = 1, line = list(dash = "dash", color = "slategray", width = 1.5)),
    list(type = "line", x0 = 0, x1 = 1, xref = "paper", yref = "y",
         y0 = 0, y1 = 0, line = list(dash = "dash", color = "#000", width = 1.5))
  ),
  annotations = list(list(
    x = best_alpha, y = 1, xref = "x", yref = "paper",
    text = sprintf("BEST ALPHA: %.3e", best_alpha),
    showarrow = FALSE, xanchor = "left", yanchor = "bottom",
    font = list(color = "slategray", family = font_family, size = 12)
  )),
  dragmode = FALSE
) %>%
  config(
    responsive = FALSE,
    scrollZoom = TRUE, doubleClick = FALSE,
    modeBarButtonsToRemove = list(
      "zoom2d","pan2d","select2d","lasso2d","zoomIn2d","zoomOut2d",
      "autoScale2d","resetScale2d","toggleSpikelines","toImage"
    ),
    displaylogo = FALSE, displayModeBar = TRUE, showTips = FALSE
  )

Here, we turn to cross-validation to choose the best value of \(\lambda\). By fitting the model across a grid of penalty strengths and scoring each fit on held-out folds, we can see how prediction error changes as the model becomes more or less complex. The plot highlights the tradeoff between flexibility and accuracy: too small a penalty lets the model overfit, too large a penalty makes it underfit, and the vertical dashed line marks the \(\lambda\) with the lowest cross-validated error.

Code
htmltools::browsable(
  htmltools::tagList(
    htmltools::tags$style(htmltools::HTML(sprintf("
      #temp-plot-wrap-cv, #temp-plot-wrap-coef {
        border-radius: 6px;
        overflow: hidden;
        overflow-x: auto;
        -webkit-overflow-scrolling: touch;
        background: %s;
        padding: 0.75rem;
        width: 100%%;
        max-width: %dpx;
        margin: 0 auto 1rem;
      }
      #temp-plot-wrap-cv > div, #temp-plot-wrap-coef > div { min-width: %dpx; }

      @media (max-width: %dpx) {
        #temp-plot-wrap-cv, #temp-plot-wrap-coef { padding: 0.5rem; }
      }
    ",
      bg_col, .wrapper_max_w, .fixed_fig_width, .wrapper_max_w
    ))),
    htmltools::tags$div(id = "temp-plot-wrap-cv",   cv_fig)
  )
)
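The cross-validation step behind the plot above can be sketched from scratch in base R. This is a minimal illustration under stated assumptions, not the code that produced the plot: the helper names (`lasso_cd`, `fit_lasso`, `cv_lasso`) are hypothetical, the data are simulated rather than read from Diamonds.csv, and the solver is a plain coordinate-descent LASSO. In practice, `glmnet::cv.glmnet()` performs the same search and reports both `lambda.min` and the more conservative `lambda.1se`.

```r
# Hypothetical helpers for illustration; simulated data, not Diamonds.csv.
soft_threshold <- function(z, g) sign(z) * pmax(abs(z) - g, 0)

# Coordinate-descent LASSO for standardized X and centered y; minimizes
# (1 / (2n)) * sum((y - X %*% b)^2) + lambda * sum(abs(b))
lasso_cd <- function(X, y, lambda, max_iter = 500, tol = 1e-8) {
  n <- nrow(X)
  b <- numeric(ncol(X))
  r <- y                            # residuals, since b starts at zero
  col_ss <- colSums(X^2) / n
  for (it in seq_len(max_iter)) {
    max_step <- 0
    for (j in seq_along(b)) {
      # correlation of column j with the residual that excludes column j
      rho <- sum(X[, j] * r) / n + col_ss[j] * b[j]
      b_new <- soft_threshold(rho, lambda) / col_ss[j]
      step <- b_new - b[j]
      if (step != 0) {
        r <- r - X[, j] * step      # keep residuals in sync with b
        b[j] <- b_new
        max_step <- max(max_step, abs(step))
      }
    }
    if (max_step < tol) break
  }
  b
}

# Fit on raw X: standardize with training statistics, then map coefficients back
fit_lasso <- function(X, y, lambda) {
  mu <- colMeans(X)
  s <- apply(X, 2, sd)
  b <- lasso_cd(scale(X, mu, s), y - mean(y), lambda) / s
  list(intercept = mean(y) - sum(b * mu), beta = b)
}

predict_lasso <- function(fit, X) as.vector(fit$intercept + X %*% fit$beta)

# K-fold CV: average held-out MSE for every lambda on the grid
cv_lasso <- function(X, y, lambdas, k = 5, seed = 1) {
  set.seed(seed)
  folds <- sample(rep(seq_len(k), length.out = nrow(X)))
  cv_mse <- sapply(lambdas, function(lam) {
    mean(sapply(seq_len(k), function(f) {
      train <- folds != f
      fit <- fit_lasso(X[train, , drop = FALSE], y[train], lam)
      pred <- predict_lasso(fit, X[!train, , drop = FALSE])
      mean((y[!train] - pred)^2)
    }))
  })
  list(cv_mse = cv_mse, best_lambda = lambdas[which.min(cv_mse)])
}

# Simulated example: 3 real signals and 5 pure-noise columns
set.seed(42)
X <- matrix(rnorm(300 * 8), ncol = 8)
y <- as.vector(X %*% c(3, -2, 1.5, rep(0, 5)) + rnorm(300))
lambdas <- 10^seq(0.5, -3, length.out = 30)   # log-spaced, largest first
cv <- cv_lasso(X, y, lambdas)
cv$best_lambda
```

Standardizing with training-fold statistics inside each fold keeps information from the held-out fold out of the fit, which is what makes the averaged error an honest estimate of out-of-sample performance.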

This visualization plots the coefficient paths across different values of \(\lambda\); the horizontal axis, labeled ALPHA, shows this penalty strength on a log scale. As the penalty increases, coefficients are pulled toward zero, and the weaker ones reach exactly zero and drop out of the model while the strongest predictors persist. This view shows not just which features are selected, but also how their influence grows or fades as the model becomes more or less complex.

Code
htmltools::browsable(
  htmltools::tagList(
    htmltools::tags$style(htmltools::HTML(sprintf("
      #temp-plot-wrap-cv, #temp-plot-wrap-coef {
        border-radius: 6px;
        overflow: hidden;
        overflow-x: auto;
        -webkit-overflow-scrolling: touch;
        background: %s;
        padding: 0.75rem;
        width: 100%%;
        max-width: %dpx;
        margin: 0 auto 1rem;
      }
      #temp-plot-wrap-cv > div, #temp-plot-wrap-coef > div { min-width: %dpx; }

      @media (max-width: %dpx) {
        #temp-plot-wrap-cv, #temp-plot-wrap-coef { padding: 0.5rem; }
      }
    ",
      bg_col, .wrapper_max_w, .fixed_fig_width, .wrapper_max_w
    ))),
    htmltools::tags$div(id = "temp-plot-wrap-coef", coef_fig)
  )
)
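Why coefficients reach exactly zero, rather than merely getting small, is easiest to see in a simplified case. Assuming an orthonormal design (standardized, uncorrelated predictors, so \(X^\top X / n = I\)), each LASSO coefficient is the soft-thresholded least-squares estimate, \(\hat\beta_j(\lambda) = \operatorname{sign}(z_j)\max(|z_j| - \lambda, 0)\), so predictor \(j\) leaves the model once \(\lambda \ge |z_j|\). The sketch below traces that path; the coefficient values and the names `x1` to `x4` are made up for illustration and are not results from this analysis.

```r
# Soft-thresholding: shrink z toward zero by g, snapping to exactly zero
soft_threshold <- function(z, g) sign(z) * pmax(abs(z) - g, 0)

# Hypothetical least-squares coefficients (illustrative values only)
beta_ols <- c(x1 = 3.0, x2 = 1.2, x3 = -0.8, x4 = 0.3)

# Log-spaced penalty grid, like the x-axis of the coefficient path plot
lambdas <- 10^seq(-2, 0.7, length.out = 50)

# Orthonormal-design assumption: each LASSO coefficient is the
# soft-thresholded OLS value, so the whole path is one matrix
path <- t(sapply(lambdas, function(lam) soft_threshold(beta_ols, lam)))
# rows = penalty values, columns = predictors

# Smallest grid penalty at which each predictor has been dropped;
# analytically, each predictor leaves once lambda >= |beta_ols|
drop_out <- sapply(colnames(path), function(f) min(lambdas[path[, f] == 0]))
drop_out
```

With correlated predictors, such as the dummy columns for a single categorical variable, the real path is computed iteratively and need not be monotone, but the same thresholding step drives each update.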

Results & Next Steps

The LASSO regression confirms that CARAT is the strongest predictor of diamond prices, with quality factors such as CLARITY, CUT, and COLOR also contributing meaningfully. By shrinking weaker coefficients all the way to zero, the model filtered out predictors with little explanatory power, leaving a smaller and more interpretable set of variables. This improves transparency and highlights the attributes that matter most in valuation.

Performance on the held-out test set was strong, with prediction error low enough to suggest that the model generalizes well beyond the training data. Cross-validation played a critical role here, guiding the choice of penalty strength so that the model avoided both underfitting and overfitting. The coefficient path plot complements this by showing how predictors entered or dropped out as the penalty changed.
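One way to make "low enough" concrete is to report test-set error next to a naive baseline that always predicts the training-set mean. The helper below is a hypothetical sketch, not part of the original analysis, and the toy numbers only demonstrate the arithmetic; no actual test-set figures are implied.

```r
# Hypothetical helper: hold-out metrics for test-set predictions, compared
# with a naive baseline that always predicts the training-set mean
holdout_metrics <- function(y_test, pred, train_mean) {
  err      <- y_test - pred
  base_err <- y_test - train_mean
  c(rmse           = sqrt(mean(err^2)),
    mae            = mean(abs(err)),
    baseline_rmse  = sqrt(mean(base_err^2)),
    # share of the baseline's squared error that the model removes
    r2_vs_baseline = 1 - sum(err^2) / sum(base_err^2))
}

# Toy numbers only, to show the calculation:
# rmse = 0.5, mae = 0.25, r2_vs_baseline = 0.8
holdout_metrics(y_test = c(1, 2, 3, 4), pred = c(1, 2, 3, 5), train_mean = 2.5)
```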

Looking ahead, there are several directions to build on this work. Ridge regression and Elastic Net are natural alternatives that may yield more stable results when predictors are correlated. Additional feature engineering, such as higher-order terms or new interaction effects, could uncover subtler relationships in the data. Robustness should also be validated on different datasets or with resampling to confirm consistency. Finally, these results can be framed for applied contexts, helping buyers, sellers, and appraisers better understand which diamond attributes drive price.
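Assuming the R port uses glmnet, as the tags suggest, trying these alternatives mostly means changing one argument, because glmnet fits Ridge, LASSO, and Elastic Net with a single objective. For a Gaussian response, the glmnet documentation states the problem as

\[
\min_{\beta_0,\,\beta}\; \frac{1}{2N}\sum_{i=1}^{N}\left(y_i - \beta_0 - x_i^{\top}\beta\right)^2 + \lambda\left[\frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\right]
\]

Setting \(\alpha = 1\) gives the LASSO, \(\alpha = 0\) gives Ridge, and values in between give Elastic Net. In glmnet's notation, `alpha` is this mixing weight rather than the penalty strength, so the quantity on the ALPHA axis of the coefficient path plot corresponds to \(\lambda\) in this formula.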