Building a Simple Aircraft Image Classifier From Scratch

From a folder-of-images to a working classifier with fast iteration

computer vision
image classification
machine learning
transfer learning
MLOps
Python
Author

A. Srikanth

Published

January 16, 2026

Sandbox

Context

Late December last year I was in Seattle, Washington. One afternoon, I did what every aviation nerd eventually does: I went to see the Boeing factory.

I’ve photographed aircraft for fun for years, but standing that close to the real machines shifted the way I look at them. Planes don’t show up as neat labels in your head. They show up as geometry and hardware: the curve of an engine intake, the angle where the wing meets the fuselage, the shape of a nose cone, the height and sweep of a tail. I could tell two aircraft were different, but when it came to explaining the difference or naming what I was looking at, I kept coming up short.

On the flight back to Toronto, that gap wouldn’t leave me alone. If my eyes struggle to hold onto these distinctions, could a model learn them? And maybe more importantly, could I build the whole thing end to end without treating deep learning like magic?

That’s what this demo became: a small, personal sandbox. Images go in, predictions come out, and there’s enough evaluation to tell the difference between real signal and the model just learning backgrounds and vibes. It’s not research-grade and it’s not a polished piece, but it was an ideal first attempt: something I could actually complete under a time crunch while learning a thing or two.

Objectives

I kept the goals deliberately simple. The one thing I wanted to avoid was a weekend-long chase for tiny performance bumps; the point was to build something that forced me to learn the basics cleanly.

So the aim was straightforward: train a multi-class model that could recognize a handful of aircraft types from photos, and build the whole workflow myself from start to finish. That meant loading the images, splitting the data, training the model, evaluating it, and sanity-checking the results (we’re not just celebrating a single accuracy number).

Data Sources

Dataset

This project uses the Aircraft Image Dataset hosted on Mendeley Data (Version 1.0). The dataset contains 4,520 labeled images spanning eight aircraft classes, distributed as folder-based categories: Airbus, ATR, Boeing, C130, F16, Grob, KAI, and Sukhoi.

Provenance

The dataset documentation notes that roughly 70% of the images come from the creator’s personal aircraft photography, with the remaining ~30% contributed by members of the Indonesian Aviation Photographers Community (KFAI). That origin matters because it implies a curated, photographer-driven collection process rather than fully automated scraping, which can shape composition choices (angles, lighting, framing), background context, and the amount of variation within each class.

You can access the complete dataset here: https://doi.org/10.17632/mdmczsr5fy.1

Working Subset

For this demo, I narrowed the working set to six main classes to keep training and evaluation computationally manageable: Boeing, C130, F16, Grob, KAI, and Sukhoi.

Citation

Putra, R. D., & Ihsan, A. F. (2025). Aircraft Image Dataset (Version 1) [Data set]. Mendeley Data. https://doi.org/10.17632/mdmczsr5fy.1

Analysis (Part I) - Building a Predictive Model

Here’s the thing: I didn’t want this to turn into a weeks-long architecture rabbit hole. I wanted a solid baseline that let me learn the pipeline and still get meaningful results, so I went with ResNet18 and transfer learning.

In practice, that meant starting with a model that already understands general visual features (ImageNet-pretrained weights), freezing most of it, and training only the final classifier layer plus the last residual stage (layer4). It’s a good middle ground: you still learn what matters (data loading, splits, training, evaluation), without turning the whole thing into from-scratch suffering or copy-paste magic.

The training curves behaved the way you hope they will in an early project. Loss dropped quickly, then flattened. Accuracy climbed fast. Train and test stayed fairly close, which is usually a decent sign we’re not immediately overfitting.

The first time I ran a fixed example image through the model, the confidence chart looked almost comically decisive: one bar near 1.0 and everything else basically zero. It felt like the model was saying: “this is a Boeing, and I will not be taking any further questions.”

I did a bit of digging, and it turns out that kind of spike can be totally normal. Softmax outputs from modern neural networks are well known for being overconfident, especially when the image is clean and the class shows up a lot in training. Still, it drove the point home for me: confidence isn’t the same thing as correctness. A model can sound certain and still be wrong, and the way to find out is to stress-test it with harder cases: weird angles, harsher lighting, similar silhouettes, near-duplicates.
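Here’s a tiny numpy sketch of why that spike isn’t alarming on its own. The logits below are made-up numbers, not outputs from my model: softmax turns a lead of a few logit units into a probability very close to 1, and scaling every logit down keeps the ranking while softening the confidence.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)  # shift so exp() cannot overflow
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Hypothetical logits for six classes; the top class leads the runner-up by 7 units.
logits = np.array([9.0, 2.0, 1.0, 0.5, 0.0, -1.0])
probs = softmax(logits)
print(f"top-class probability:      {probs.max():.4f}")

# Halving every logit keeps the same winner but makes the output far less decisive.
print(f"same logits divided by two: {softmax(logits / 2).max():.4f}")
```

Dividing every logit by the same constant never changes which class wins, only how sure the output sounds. A peaked bar tells you about the logit margin, not about whether the label is right.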

In a perfect world, I’d spend an afternoon at a (safe) viewing area near an international runway, take a fresh batch of shots, and feed the model images that actually fight back (believe me, I’d love that). But for this project, simplicity remains the mantra. The goal is to learn the pipeline cleanly, not turn it into a months-long data collection mission.

Analysis (Part II) - What the Confusion Matrix Taught Me

The confusion matrix ended up being the most useful plot in the whole demo, mostly because it doesn’t just tell you if the model is wrong. It tells you how it’s wrong. And honestly, that’s probably why it got the name (someone should fact-check this). It’s the one chart that reduces confusion by showing you where the confusion lives.

Here’s what it is: a confusion matrix is a grid that compares the true label (what the image actually is) to the predicted label (what the model guessed). The diagonal is the happy path. Big numbers there mean the model is landing the right class. Everything off the diagonal is where things get interesting, because those cells are the model admitting, “Yeah… I mixed those two up.”
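The script at the end of this post builds that grid with a plain Python loop. Here is a minimal numpy sketch of the same idea on made-up labels for three hypothetical classes: rows are true labels, columns are predictions, the diagonal divided by each row sum gives per-class recall, and the largest off-diagonal cell points at the pair the model mixes up most.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs: rows = true class, columns = predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when the same (i, j) pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


# Made-up labels for three hypothetical classes (0, 1, 2).
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 1, 2, 0, 0, 2]

cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)

# Diagonal = correct predictions; diagonal / row sum = recall for each true class.
recall = cm.diagonal() / cm.sum(axis=1)
print("per-class recall:", np.round(recall, 3))

# Zero the diagonal, then the largest remaining cell is the most common mix-up.
off = cm.copy()
np.fill_diagonal(off, 0)
i, j = np.unravel_index(off.argmax(), off.shape)
print(f"most common mix-up: true {i} predicted as {j} ({off[i, j]} times)")
```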

With aircraft photos, those mistakes usually have a story. Some planes genuinely share silhouettes. A lot of photos are shot from the same side-on runway angle. And backgrounds can quietly become a crutch. If most Boeing photos happen to be taken on grassy runways in a certain kind of light, the model can learn the scene along with the aircraft. It’s not malicious. It’s just doing what models do: grabbing whatever patterns are easiest to exploit.

There’s also a trap that matters a lot in photography datasets: near-duplicates. If I take a burst of the same aircraft and keep five frames that are basically identical, a random split can accidentally put one frame in training and another in test. The model looks like a genius, but it’s partly because it’s seeing the same moment twice with a slightly different timestamp. The confusion matrix helps catch that too. If things look too clean, it’s a nudge to ask whether the model learned aircraft features… or just memorized familiar shots in disguise.
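Here’s a rough sketch of one way to keep burst frames from straddling the split: fingerprint every image with a simple average hash, group images whose fingerprints differ by only a few bits, and then split whole groups instead of individual files. It works on grayscale numpy arrays so it stays self-contained. The function names, the 8×8 hash, and the 5-bit threshold are my own illustrative assumptions, not part of the original script; a real run would load each photo with PIL and convert it to grayscale first.

```python
import numpy as np


def average_hash(gray: np.ndarray, size: int = 8) -> np.ndarray:
    """Coarse fingerprint: block-average to size x size, then threshold at the mean."""
    h, w = gray.shape
    # Crop so the image divides evenly into size x size blocks.
    gray = gray[: h - h % size, : w - w % size].astype(float)
    bh, bw = gray.shape[0] // size, gray.shape[1] // size
    blocks = gray.reshape(size, bh, size, bw).mean(axis=(1, 3))
    return (blocks > blocks.mean()).ravel()


def group_near_duplicates(hashes, max_bits: int = 5):
    """Greedy grouping: an image joins the first group whose seed hash is within max_bits."""
    seeds, groups = [], []
    for hsh in hashes:
        for g, seed in enumerate(seeds):
            if np.count_nonzero(hsh != seed) <= max_bits:
                groups.append(g)
                break
        else:  # no close seed found: this image starts a new group
            seeds.append(hsh)
            groups.append(len(seeds) - 1)
    return np.array(groups)


def group_split(groups, test_frac=0.1, seed=42):
    """Assign whole groups to train or test so near-duplicates never straddle the split."""
    rng = np.random.default_rng(seed)
    unique = rng.permutation(np.unique(groups))
    n_test = max(1, int(round(len(unique) * test_frac)))
    test_groups = set(unique[:n_test].tolist())
    is_test = np.array([g in test_groups for g in groups])
    return np.where(~is_test)[0], np.where(is_test)[0]


# Synthetic stand-ins for photos: two random "scenes", each shot as a burst
# of frames that differ only by faint noise.
rng = np.random.default_rng(0)
scene_a, scene_b = rng.random((64, 96)), rng.random((64, 96))
frames = [scene_a + rng.normal(0, 0.001, scene_a.shape) for _ in range(3)]
frames += [scene_b + rng.normal(0, 0.001, scene_b.shape) for _ in range(2)]

groups = group_near_duplicates([average_hash(f) for f in frames])
train_idx, test_idx = group_split(groups, test_frac=0.5)
print("group per frame:", groups.tolist())
print("train frames:", train_idx.tolist(), "| test frames:", test_idx.tolist())
```

Greedy grouping is order-dependent and average hashing is crude (it can miss heavier crops or exposure changes), so treat this as a first pass at reducing leakage, not a guarantee.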

Methodology

Here’s the castle I built in this sandbox:

  1. The Python script first loaded the images using a simple ImageFolder setup: I had one folder per class, and those became the labels automatically. That kept the data plumbing boring in the best way, so I could focus on the model and evaluation sections.

  2. I used a stratified 90/10 train/test split by class, so each aircraft type was represented on both sides, then capped the subsets at 1,200 training and 450 test images to keep runs fast. During training, I used light augmentations (random square crops covering 80 to 100% of the image area, horizontal flips, and rotations of up to 7°) to make the model less fragile to angle and framing. The sample images shown here are the original photos, just center-cropped into squares to make the grid cleaner and easier to compare at a glance.

  3. For the model, I used ResNet18 with transfer learning. I froze most of the network and only trained the final classification layer plus the last residual stage (layer4). That gave me a baseline that learns quickly without pretending I’m doing research-grade architecture work.

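  4. Optimization was straightforward: cross-entropy loss with AdamW (weight decay 1e-4), using a smaller learning rate for layer4 (0.3 × 2e-4) than for the new classification head (2e-4). It’s a clean default for this kind of multi-class classification problem.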

  5. I made sure the outputs were interpretable, not just a single accuracy number: loss and accuracy curves, a confusion matrix, a random 3×3 prediction grid, a 2×2 sample of misclassifications, and one fixed example image with a probability bar chart.
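Step 1 leans on torchvision’s ImageFolder convention: each subfolder of the data root is a class, and class indices follow the sorted folder names, which also fixes the row and column order of the confusion matrix. The standard-library sketch below mimics that convention so you can see which index each aircraft class gets. It is not torchvision’s code; `list_image_folder` is a hypothetical helper, and the extension set is a simplified assumption.

```python
import tempfile
from pathlib import Path

# Simplified extension set for illustration; torchvision accepts a few more.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def list_image_folder(root):
    """Hypothetical helper mimicking the <root>/<Class>/<image> layout.

    Returns (classes, samples): classes are the sorted folder names and
    samples are (path, class_index) pairs.
    """
    root = Path(root)
    classes = sorted(d.name for d in root.iterdir() if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        for f in sorted((root / name).rglob("*")):
            if f.is_file() and f.suffix.lower() in IMG_EXTS:
                samples.append((str(f), class_to_idx[name]))
    return classes, samples


with tempfile.TemporaryDirectory() as tmp:
    # Mock tree with the six classes used in this post, one empty file each.
    for name in ["Sukhoi", "Boeing", "KAI", "C130", "Grob", "F16"]:
        (Path(tmp) / name).mkdir()
        (Path(tmp) / name / f"{name}001.jpg").touch()
    classes, samples = list_image_folder(tmp)

print(classes)
print([(Path(p).name, y) for p, y in samples])
```

Because the order is alphabetical (and case-sensitive), renaming a folder can silently shift every label index, so it’s worth printing the class list once before trusting any plot.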

Results & Next Steps

The outputs were coherent. Training looked stable and the confusion matrix was mostly diagonal.

But I wouldn’t oversell the performance yet. Aircraft photo datasets can quietly inflate results if you don’t control for near-duplicates and shared context. Same aircraft. Same airport. Same photography style. Similar angles. That’s enough to make accuracy look “unreal” without real generalization.

If I were continuing this, I’d enforce a stricter split so near-identical shots can’t land in both train and test, hold out a hard set from different airports and conditions, and use something like Grad-CAM to sanity-check that the model is actually looking at the aircraft and not the runway vibes. And if I wanted the confidence chart to mean what it looks like it means, I’d also calibrate the probabilities so the numbers behave less like a gut feeling and more like actual likelihoods.
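For the calibration idea, one common technique is temperature scaling: divide the logits by a single scalar T fitted on held-out data to minimize negative log-likelihood, which changes confidence but never the predicted class. The numpy sketch below fits T with a crude grid search and reports expected calibration error (ECE) before and after. The synthetic logits, the grid, and the 10-bin ECE are illustrative assumptions; in practice T would be fitted on a validation split, not the test set.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # stable: shift each row by its max
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def nll(logits, labels, T):
    """Mean negative log-likelihood of the true labels at temperature T."""
    p = softmax(logits / T)
    return -np.mean(np.log(p[np.arange(len(labels)), labels] + 1e-12))


def ece(probs, labels, n_bins=10):
    """Expected calibration error: |accuracy - confidence| weighted over confidence bins."""
    conf = probs.max(axis=1)
    correct = probs.argmax(axis=1) == labels
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            total += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return total


def fit_temperature(logits, labels, grid=np.linspace(0.5, 5.0, 91)):
    """Pick the T on the grid with the lowest NLL (a stand-in for a proper 1-D optimizer)."""
    losses = [nll(logits, labels, T) for T in grid]
    return float(grid[int(np.argmin(losses))])


# Synthetic, deliberately overconfident logits: roughly 83% of predictions are
# correct, but the winning logit gets a +8 boost, so softmax reports ~99%.
rng = np.random.default_rng(0)
n, k = 2000, 6
labels = rng.integers(0, k, size=n)
pred = np.where(rng.random(n) < 0.8, labels, rng.integers(0, k, size=n))
logits = rng.normal(0.0, 1.0, size=(n, k))
logits[np.arange(n), pred] += 8.0

T = fit_temperature(logits, labels)
print(f"fitted T = {T:.2f}")
print(f"ECE before: {ece(softmax(logits), labels):.3f}")
print(f"ECE after:  {ece(softmax(logits / T), labels):.3f}")
```

A single scalar is about the least invasive fix there is: it can’t reorder classes, so accuracy and the confusion matrix stay exactly the same while the probability bars start to mean something closer to “how often this is right”.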

I’ll take a step back and say one more thing: the real takeaway for me wasn’t “look, I built a perfect classifier.” It was that curiosity travels well. A trip, a factory visit, and a small moment of “can I actually tell these apart?” turned into a compact project that taught me more than any generic tutorial could. If you want to learn a new method, start with a problem you already care about. That’s usually the difference between abandoning a Jupyter Notebook halfway and actually shipping something you can look back on and say, “yeah, I learned that.” That’s all for now, folks. I’ll see you in the next post. ✌️

Code
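import os

# OpenMP/MKL typically read these when numpy and torch initialize their native
# runtimes, so set them before those imports rather than later in the script.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

import contextlib
import io
import random
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models


class _MuteAll:
    """Silence stdout/stderr inside `muted()` while `enabled` is True.

    `enabled` is checked on every write, so setting `MUTE.enabled = False`
    inside a muted block (as done before `print(model)`) lets output through.
    """

    def __init__(self):
        self.enabled = True
        self._real_out = sys.stdout
        self._real_err = sys.stderr

    def _gate(self, which):
        mute = self

        class _Gate(io.TextIOBase):
            def writable(self):
                return True

            def write(self, s):
                real = mute._real_out if which == "out" else mute._real_err
                # Drop the text while muted; otherwise forward it to the real stream.
                return len(s) if mute.enabled else real.write(s)

            def flush(self):
                (mute._real_out if which == "out" else mute._real_err).flush()

        return _Gate()

    @contextlib.contextmanager
    def muted(self):
        self._real_out, self._real_err = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = self._gate("out"), self._gate("err")
        try:
            yield
        finally:
            sys.stdout, sys.stderr = self._real_out, self._real_err


MUTE = _MuteAll()

VERBOSE = False
def log(*args, **kwargs):
    """Print diagnostics only when VERBOSE, briefly lifting MUTE so they show."""
    if VERBOSE:
        prev, MUTE.enabled = MUTE.enabled, False
        try:
            print(*args, **kwargs)
        finally:
            MUTE.enabled = prev

with MUTE.muted():

  torch.set_num_threads(1)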
  
  FIG_W, FIG_H, FIG_DPI = 9, 5, 110
  color_blue   = "#033c73"
  color_indigo = "#6610f2"
  color_purple = "#6f42c1"
  color_red    = "#751F2C"
  
  plt.rcParams.update(
      {
          "figure.figsize": (FIG_W, FIG_H),
          "figure.dpi": FIG_DPI,
          "font.family": "Ramabhadra",
          "font.weight": "bold",
          "text.color": "black",
          "axes.labelcolor": "black",
          "axes.titlecolor": "black",
          "xtick.color": "black",
          "ytick.color": "black",
          "axes.titlesize": 14,
          "axes.titleweight": "bold",
          "axes.labelsize": 12,
          "axes.labelweight": "bold",
          "xtick.labelsize": 10,
          "ytick.labelsize": 10,
          "legend.fontsize": 10,
          "axes.grid": True,
          "grid.alpha": 0.25,
      }
  )
  
  def find_data_root() -> Path:
      env_proj = os.environ.get("QUARTO_PROJECT_DIR")
      candidates = []
      if env_proj:
          candidates.append(Path(env_proj))
      candidates.append(Path.cwd())
      candidates.extend([Path.cwd().parent, Path.cwd().parent.parent])
  
      exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
      data_names = ["data", "Data", "DATA"]
  
      tried = []
      for base in candidates:
          for dn in data_names:
              p = (base / dn).resolve()
              tried.append(str(p))
              if not p.exists() or not p.is_dir():
                  continue
              subdirs = [d for d in p.iterdir() if d.is_dir() and not d.name.startswith(".")]
              if not subdirs:
                  continue
              ok = False
              for sd in subdirs:
                  for e in exts:
                      if any(sd.glob(f"*{e}")):
                          ok = True
                          break
                  if ok:
                      break
              if ok:
                  return p
  
      raise FileNotFoundError(
          "Couldn't find class folders under a data directory.\n"
          "Tried:\n  - " + "\n  - ".join(tried) + "\n\n"
          "Expected: <project>/data/<ClassName>/*.jpg"
      )
  
  DATA_ROOT = find_data_root()
  
  SEED = 42
  IMG_SIZE = 224
  BATCH_SIZE = 32
  EPOCHS = 6
  LR = 2e-4
  TRAIN_FRAC = 0.90
  NUM_WORKERS = 0
  
  MAX_TRAIN = 1200
  MAX_TEST  = 450
  
  random.seed(SEED)
  np.random.seed(SEED)
  torch.manual_seed(SEED)
  
  device = (
      "cuda" if torch.cuda.is_available()
      else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
      else "cpu"
  )
  device = torch.device(device)
  
  train_tfm = transforms.Compose([
      transforms.RandomResizedCrop(IMG_SIZE, scale=(0.80, 1.0), ratio=(1.0, 1.0)),
      transforms.RandomHorizontalFlip(0.5),
      transforms.RandomRotation(7),
      transforms.ToTensor(),
      transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]),
  ])
  
  test_tfm = transforms.Compose([
      transforms.Resize(IMG_SIZE + 32),
      transforms.CenterCrop(IMG_SIZE),
      transforms.ToTensor(),
      transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]),
  ])
  
  base_ds = datasets.ImageFolder(str(DATA_ROOT))
  classes = base_ds.classes
  num_classes = len(classes)
  n = len(base_ds)
  
  log(f"DATA_ROOT: {DATA_ROOT}")
  log(f"Device: {device.type} | Images: {n} | Classes: {classes}")
  
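  labels = np.array([y for _, y in base_ds.samples], dtype=int)
  rng = np.random.default_rng(SEED)

  # Stratified split: shuffle and cut each class on its own, so every class
  # lands on both sides in roughly a TRAIN_FRAC : (1 - TRAIN_FRAC) ratio.
  train_by_class, test_by_class = [], []
  for c in range(num_classes):
      idx_c = np.where(labels == c)[0]
      rng.shuffle(idx_c)
      cut = int(len(idx_c) * TRAIN_FRAC)
      train_by_class.append(idx_c[:cut])
      test_by_class.append(idx_c[cut:])

  def cap_per_class(groups, cap):
      # Shrink every class by the same factor to respect MAX_TRAIN / MAX_TEST.
      # Truncating one globally shuffled list would keep class proportions
      # only on average; this keeps them within each subset.
      total = sum(len(g) for g in groups)
      frac = min(1.0, cap / max(total, 1))
      out = []
      for g in groups:
          out.extend(g[: int(len(g) * frac)].tolist())
      rng.shuffle(out)
      return out

  train_idx = cap_per_class(train_by_class, MAX_TRAIN)
  test_idx  = cap_per_class(test_by_class,  MAX_TEST)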
  
  ds_train_view = datasets.ImageFolder(str(DATA_ROOT), transform=train_tfm)
  ds_test_view  = datasets.ImageFolder(str(DATA_ROOT), transform=test_tfm)
  
  train_ds = Subset(ds_train_view, train_idx)
  test_ds  = Subset(ds_test_view,  test_idx)
  
  pin = (device.type == "cuda")
  train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=pin)
  test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)
  
  model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
  
  model.fc = nn.Linear(model.fc.in_features, num_classes)
  
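  # Freeze the whole backbone, then unfreeze only the last residual stage
  # (layer4) and the new classification head (fc).
  # Caveat: requires_grad=False stops weight updates, but BatchNorm layers in
  # the frozen stages still refresh their running mean/variance in train() mode.
  for p in model.parameters():
      p.requires_grad = False
  for p in model.layer4.parameters():
      p.requires_grad = True
  for p in model.fc.parameters():
      p.requires_grad = True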
  
  model.to(device)
  
  MUTE.enabled = False
  print(model)
  MUTE.enabled = True
  
  criterion = nn.CrossEntropyLoss()
  
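  # Two parameter groups: the pretrained layer4 moves more gently (0.3 x LR)
  # than the freshly initialized head (LR). Frozen parameters are not passed in.
  opt = torch.optim.AdamW(
      [
          {"params": model.layer4.parameters(), "lr": LR * 0.3},
          {"params": model.fc.parameters(),     "lr": LR},
      ],
      weight_decay=1e-4
  )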
  
  history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}
  
  def run_epoch(loader, train: bool):
      model.train(train)
      total_loss, correct, total = 0.0, 0, 0
  
      for xb, yb in loader:
          xb = xb.to(device)
          yb = yb.to(device)
  
          if train:
              opt.zero_grad(set_to_none=True)
  
          logits = model(xb)
          loss = criterion(logits, yb)
  
          if train:
              loss.backward()
              opt.step()
  
          bs = xb.size(0)
          total_loss += loss.item() * bs
          correct += (logits.argmax(1) == yb).sum().item()
          total += bs
  
      return total_loss / max(total, 1), correct / max(total, 1)
  
  for ep in range(1, EPOCHS + 1):
      tr_loss, tr_acc = run_epoch(train_loader, train=True)
      with torch.inference_mode():
          te_loss, te_acc = run_epoch(test_loader, train=False)
  
      history["train_loss"].append(tr_loss)
      history["train_acc"].append(tr_acc)
      history["test_loss"].append(te_loss)
      history["test_acc"].append(te_acc)
  
  fig, ax = plt.subplots()
  ax.plot(range(1, EPOCHS + 1), history["train_loss"], marker="o", linewidth=2.0, color=color_indigo, alpha=0.9, label="Train")
  ax.plot(range(1, EPOCHS + 1), history["test_loss"],  marker="o", linewidth=2.0, color=color_purple, alpha=0.9, label="Test")
  ax.set_title("LOSS BY EPOCH")
  ax.set_xlabel("EPOCH")
  ax.set_ylabel("CROSS-ENTROPY LOSS")
  ax.legend(loc="upper right")
  fig.tight_layout()
  plt.show()
  plt.close(fig)
  
  fig, ax = plt.subplots()
  ax.plot(range(1, EPOCHS + 1), np.array(history["train_acc"]) * 100, marker="o", linewidth=2.0, color=color_indigo, alpha=0.9, label="Train")
  ax.plot(range(1, EPOCHS + 1), np.array(history["test_acc"])  * 100, marker="o", linewidth=2.0, color=color_purple, alpha=0.9, label="Test")
  ax.set_title("ACCURACY BY EPOCH")
  ax.set_xlabel("EPOCH")
  ax.set_ylabel("ACCURACY (%)")
  ax.set_ylim(0, 100)
  ax.legend(loc="lower right")
  fig.tight_layout()
  plt.show()
  plt.close(fig)
  
  @torch.inference_mode()
  def predict_test():
      ys, ps, confs = [], [], []
      for xb, yb in test_loader:
          xb = xb.to(device)
          logits = model(xb)
          probs = torch.softmax(logits, dim=1)
          p = probs.argmax(1).cpu().numpy()
          c = probs.max(1).values.cpu().numpy()
          ys.extend(yb.numpy().tolist())
          ps.extend(p.tolist())
          confs.extend(c.tolist())
      return np.array(ys), np.array(ps), np.array(confs)
  
  y_true, y_pred, y_conf = predict_test()
  
  cm = np.zeros((num_classes, num_classes), dtype=int)
  for t, p in zip(y_true, y_pred):
      cm[t, p] += 1
  
  from mpl_toolkits.axes_grid1 import make_axes_locatable
  
  fig, ax = plt.subplots(figsize=(FIG_W, FIG_H), dpi=FIG_DPI, constrained_layout=True)
  
  im = ax.imshow(cm, interpolation="nearest")
  ax.grid(False)  # rcParams turn grids on globally; they would cut through the cells
  
  ax.set_title("CONFUSION MATRIX (TEST SUBSET)")
  ax.set_xlabel("PREDICTED CLASS")
  ax.set_ylabel("TRUE CLASS")
  
  ax.set_xticks(np.arange(num_classes))
  ax.set_yticks(np.arange(num_classes))
  ax.set_xticklabels([c.upper() for c in classes], rotation=30, ha="right")
  ax.set_yticklabels([c.upper() for c in classes])
  
  ax.set_aspect("equal", adjustable="box")
  ax.set_anchor("C")
  
  divider = make_axes_locatable(ax)
  cax = divider.append_axes("right", size="4.5%", pad=0.12)
  cbar = fig.colorbar(im, cax=cax)
  cbar.ax.set_ylabel("COUNT", rotation=90, labelpad=12, fontweight="bold")
  
  mx = float(cm.max()) if cm.size else 0.0
  thr = 0.55
  
  for i in range(num_classes):
      for j in range(num_classes):
          val = int(cm[i, j])
          frac = (val / mx) if mx > 0 else 0.0
          txt_color = "black" if frac >= thr else "white"
          ax.text(j, i, str(val), ha="center", va="center",
                  fontsize=10, fontweight="bold", color=txt_color)
  
  plt.show()
  plt.close(fig)
  
  def center_crop_square_pil(img: Image.Image) -> Image.Image:
      return ImageOps.fit(
          img, (min(img.size), min(img.size)),
          method=Image.Resampling.LANCZOS, centering=(0.5, 0.5)
      )
  
  # Keep test_samples in exactly the same order as test_loader (a Subset over
  # test_idx with shuffle=False) so test_samples[i] lines up with y_pred[i].
  # Filtering this list would shift that alignment.
  test_samples = [base_ds.samples[i] for i in test_idx]  # (path, y)
  
  def plot_collage_3x3(items, title):
      fig, axes = plt.subplots(3, 3, figsize=(FIG_W, FIG_W), dpi=FIG_DPI)
      axes = axes.ravel()
  
      for ax in axes:
          ax.axis("off")
  
      for i, (p, y, pred, conf) in enumerate(items[:9]):
          img = Image.open(p).convert("RGB")
          img = center_crop_square_pil(img)
  
          axes[i].imshow(img)
          axes[i].set_title(
              (f"{Path(p).name}\nTRUE: {classes[y]} | PRED: {classes[pred]} | {conf*100:.1f}%").upper(),
              fontsize=9,
              fontweight="bold",
          )
          axes[i].axis("off")
  
      fig.suptitle(title.upper(), fontsize=14, fontweight="bold")
      fig.tight_layout()
      plt.show()
      plt.close(fig)
      
  def plot_collage_2x2(items, title):
      fig, axes = plt.subplots(2, 2, figsize=(FIG_W, FIG_W), dpi=FIG_DPI)
      axes = axes.ravel()

      for ax in axes:
          ax.axis("off")

      for i, (p, y, pred, conf) in enumerate(items[:4]):
          img = Image.open(p).convert("RGB")
          img = center_crop_square_pil(img)

          axes[i].imshow(img)
          axes[i].set_title(
              (f"{Path(p).name}\nT:{classes[y]} | P:{classes[pred]} | {conf*100:.1f}%").upper(),
              fontsize=9,
              fontweight="bold",
          )
          axes[i].axis("off")

      fig.suptitle(title.upper(), fontsize=14, fontweight="bold")
      fig.tight_layout()
      plt.show()
      plt.close(fig)
  
  pick = np.random.choice(len(test_samples), size=min(9, len(test_samples)), replace=False)
  collage_items = []
  for i in pick:
      p, y = test_samples[i]
      collage_items.append((p, y, int(y_pred[i]), float(y_conf[i])))
  
  plot_collage_3x3(collage_items, "Test set sample (3×3): true vs predicted")
  
  wrong = np.where(y_true != y_pred)[0]
  if len(wrong) >= 1:
      k = min(4, len(wrong))
      pick_wrong = np.random.choice(wrong, size=k, replace=False)
      wrong_items = []
      for i in pick_wrong:
          p, y = test_samples[i]
          wrong_items.append((p, y, int(y_pred[i]), float(y_conf[i])))
      plot_collage_2x2(wrong_items, "Misclassifications (2×2 sample)")
  else:
      fig, ax = plt.subplots(figsize=(FIG_W, 2.2), dpi=FIG_DPI)
      ax.axis("off")
      ax.text(0.5, 0.5, "NO MISCLASSIFICATIONS IN THIS TEST SUBSET", ha="center", va="center",
              fontsize=14, fontweight="bold")
      fig.tight_layout()
      plt.show()
      plt.close(fig)
  
  TARGET_EXAMPLE = "Boeing360.jpg"
  
  cand = [
      Path.cwd() / TARGET_EXAMPLE,
      Path.cwd() / "images" / TARGET_EXAMPLE,
      Path(os.environ.get("QUARTO_PROJECT_DIR", Path.cwd())) / TARGET_EXAMPLE,
      Path(os.environ.get("QUARTO_PROJECT_DIR", Path.cwd())) / "images" / TARGET_EXAMPLE,
  ]
  
  ex_path = None
  for p in cand:
      if p.exists():
          ex_path = str(p)
          break
  
  if ex_path is None:
      hits = list(Path(DATA_ROOT).rglob(TARGET_EXAMPLE))
      if hits:
          ex_path = str(hits[0])
  
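  if ex_path is None:
      ex_i = int(wrong[0]) if len(wrong) else int(np.random.randint(0, len(test_samples)))
      ex_path, ex_y = test_samples[ex_i]
  else:
      # Recover the label by matching the file against the dataset. A file that
      # lives outside DATA_ROOT has no known label, so keep None rather than
      # silently calling it class 0.
      ex_y = None
      ex_rp = str(Path(ex_path).resolve())
      for p, y in base_ds.samples:
          if str(Path(p).resolve()) == ex_rp:
              ex_y = int(y)
              break

  # If the fixed example was part of the training subset, a near-certain
  # prediction says little about generalization, so flag it in the title.
  train_paths = {str(Path(base_ds.samples[i][0]).resolve()) for i in train_idx}
  seen_in_train = str(Path(ex_path).resolve()) in train_paths

  img = Image.open(ex_path).convert("RGB")
  img_sq = center_crop_square_pil(img)

  x = test_tfm(img_sq).unsqueeze(0).to(device)
  with torch.inference_mode():
      probs = torch.softmax(model(x), dim=1).cpu().numpy().ravel()

  # Sanity checks: one finite probability per class, summing to ~1.
  probs = np.asarray(probs, dtype=float)
  assert probs.ndim == 1 and probs.shape[0] == len(classes)
  assert np.isfinite(probs).all()
  assert abs(probs.sum() - 1.0) < 1e-3

  pred_i = int(np.argmax(probs))
  true_name = classes[ex_y] if ex_y is not None else "unknown"
  note = " | seen in training" if seen_in_train else ""

  fig, ax = plt.subplots(figsize=(FIG_W, 4.8), dpi=FIG_DPI)
  ax.imshow(img_sq)
  ax.set_title((f"EXAMPLE IMAGE | TRUE: {true_name} | PRED: {classes[pred_i]}{note}").upper())
  ax.axis("off")
  fig.tight_layout()
  plt.show()
  plt.close(fig)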
  
  fig, ax = plt.subplots(figsize=(FIG_W, 4.8), dpi=FIG_DPI)
  
  bars = ax.bar(classes, probs, edgecolor="black", linewidth=0.8, color=color_blue, alpha=0.9)
  
  ax.set_ylim(0, 1.20)
  ax.set_title("MODEL CONFIDENCE BY CLASS")
  ax.set_xlabel("CLASS")
  ax.set_ylabel("PROBABILITY")
  ax.tick_params(axis="x", rotation=30)
  
  y_pad = 0.03
  for b, p in zip(bars, probs):
      ax.text(
          b.get_x() + b.get_width() / 2,
          min(max(p, 0.01) + y_pad, 1.04),
          f"{p*100:.3f}%",
          ha="center",
          va="bottom",
          fontsize=9,
          fontweight="bold",
          rotation=0
      )
  
  fig.tight_layout()
  plt.show()
  plt.close(fig)