
Fine-tuning Clay Foundation Model for Land Cover Segmentation

Welcome to Tutorial 3! In this hands-on session, you'll learn how to fine-tune the Clay foundation model for land cover segmentation using the Chesapeake Bay dataset.


Learning Objectives

By the end of this tutorial, you will:

  • Understand what foundation models are and why they're powerful for Earth observation
  • Learn how to fine-tune a pre-trained model for semantic segmentation
  • Apply transfer learning techniques to land cover classification
  • Work with real satellite imagery and ground truth labels
  • Evaluate model performance on geospatial data

What You'll Build

You'll create a land cover segmentation model that can classify different types of land use (water, forest, urban areas, etc.) from satellite imagery.

Background for Different Audiences

For GIS Professionals 📐

  • Foundation models are like having a universal "base map" that understands Earth's features
  • Segmentation is similar to creating detailed land use polygons, but at the pixel level
  • Think of this as automated land cover classification that can replace manual digitization
  • The output is similar to creating a detailed land use/land cover (LULC) raster

For Data Analysts 📊

  • We're using transfer learning - starting with a model already trained on lots of Earth imagery
  • Fine-tuning means adapting this pre-trained model to our specific classification task
  • This is like taking a general-purpose tool and customizing it for your specific needs
  • The model learns patterns in pixel values to predict land cover categories

For ML Engineers 🤖

  • Clay is a Vision Transformer (ViT) trained on massive Earth observation datasets
  • We're doing semantic segmentation - predicting a class for every pixel
  • The architecture uses a frozen encoder (Clay) + trainable segmentation head
  • We'll use PyTorch Lightning for training orchestration

Dataset Overview

The Chesapeake Bay Land Cover dataset contains:

  • High-resolution aerial imagery (NAIP - National Agriculture Imagery Program)
  • 7 land cover classes: Water, Tree Canopy, Low Vegetation, Barren, Impervious (Other), Impervious (Roads), No Data
  • Pixel-level annotations for supervised learning
  • Real-world complexity with mixed land uses and seasonal variations

How the Clay Segmentation Architecture Works

The Segmentor class combines two key components:

1. Frozen Clay Encoder 🧊

  • Pre-trained on millions of Earth observation images
  • Extracts rich feature representations from input imagery
  • Frozen = weights don't change during fine-tuning (saves compute!)
  • Acts like a "universal feature extractor" for Earth imagery

2. Trainable Segmentation Head 🎯

  • Takes Clay's feature maps and upsamples them to original image size
  • Uses convolution + pixel shuffle operations for efficient upsampling
  • Only this part gets trained - much faster than training from scratch!

Key Parameters:

  • num_classes (int): Number of land cover classes to predict (7 for Chesapeake)
  • ckpt_path (str): Path to the pre-trained Clay model weights

Why This Approach Works:

  • ✅ Faster training: Only train the small segmentation head
  • ✅ Less data needed: Clay already understands Earth imagery patterns
  • ✅ Better performance: Foundation model knowledge transfers well
  • ✅ Cost effective: Requires fewer computational resources
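
To make the frozen-encoder plus trainable-head idea concrete, here is a minimal, self-contained sketch in plain PyTorch. It is not the actual Clay Segmentor code - the dummy encoder below merely stands in for the real Clay ViT encoder loaded from the checkpoint - but it shows the pattern: freeze the feature extractor, train only a small convolution + pixel-shuffle head.

import torch
import torch.nn as nn

class ToySegmentationHead(nn.Module):
    """A 1x1 conv projects features to class logits; PixelShuffle turns channels into pixels."""
    def __init__(self, embed_dim, num_classes, patch_size=8):
        super().__init__()
        # Each low-resolution feature cell expands into a patch_size x patch_size block of logits
        self.conv = nn.Conv2d(embed_dim, num_classes * patch_size**2, kernel_size=1)
        self.upsample = nn.PixelShuffle(patch_size)

    def forward(self, feats):                      # feats: (B, embed_dim, H/8, W/8)
        return self.upsample(self.conv(feats))     # logits: (B, num_classes, H, W)

# Stand-in for the frozen Clay encoder (the real one is a ViT loaded from clay-v1.5.ckpt)
encoder = nn.Sequential(nn.Conv2d(4, 64, kernel_size=8, stride=8), nn.GELU())
for p in encoder.parameters():
    p.requires_grad = False                        # frozen: no gradients, no weight updates

head = ToySegmentationHead(embed_dim=64, num_classes=7)

x = torch.randn(2, 4, 224, 224)                    # a batch of 4-band NAIP-like chips
with torch.no_grad():
    feats = encoder(x)                             # (2, 64, 28, 28)
logits = head(feats)                               # (2, 7, 224, 224) - one score per class per pixel
print(logits.shape)

In the real setup, only the head's parameters are handed to the optimizer, which is why fine-tuning is fast and light on GPU memory.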

About the Chesapeake Bay Dataset 🦀

We'll use the Chesapeake Bay Land Cover dataset - a high-quality dataset perfect for learning land cover segmentation.

Dataset Citation

If you use this dataset in your work, please cite:

Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N.
Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data.
Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition (CVPR 2019).

Why This Dataset is Great for Learning:

  • High Resolution: 1-meter pixel resolution aerial imagery
  • Multiple Regions: Covers diverse landscapes in the Chesapeake Bay area
  • Expert Annotations: Ground truth labels created by domain experts
  • Real-world Complexity: Mixed land uses, seasonal variations, and edge cases
  • Well-documented: Extensively used in research with known baselines

Land Cover Classes (7 total):

  1. Water 💧 - Rivers, lakes, bays, coastal areas
  2. Tree Canopy/Forest 🌳 - Dense forest areas, large trees
  3. Low Vegetation/Fields 🌱 - Grass, crops, shrubs, sparse vegetation
  4. Barren Land 🏔️ - Exposed soil, construction sites, beaches
  5. Impervious (Other) 🏢 - Buildings, rooftops, other built structures
  6. Impervious (Roads) 🛣️ - Paved roads, highways, parking lots
  7. No Data ⬜ - Areas with missing or invalid data

More information: Chesapeake Bay Dataset
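
For reference, here is the class ordering as a plain Python list, matching the color map and legend used by the visualization code at the end of this notebook:

CHESAPEAKE_CLASSES = [
    "Water",                 # 💧
    "Tree Canopy/Forest",    # 🌳
    "Low Vegetation/Fields", # 🌱
    "Barren Land",           # 🏔️
    "Impervious (Other)",    # 🏢
    "Impervious (Roads)",    # 🛣️
    "No Data",               # ⬜
]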

🚀 Setup and Installation

We'll install all required packages for fine-tuning the Clay model. This notebook is optimized for Google Colab but works in any Jupyter environment.

What Each Package Does:

  • torch: PyTorch deep learning framework
  • lightning: PyTorch Lightning for training orchestration
  • segmentation_models_pytorch: Pre-built segmentation architectures
  • rasterio: Reading/writing geospatial raster data (GeoTIFF files)
  • s5cmd: Fast, parallel S3 data transfers

Installation Options

Option 1: All at once (recommended for Colab)

pip install torch lightning segmentation_models_pytorch rasterio s5cmd

Option 2: Individual packages (if you encounter conflicts)

pip install torch
pip install lightning  
pip install segmentation_models_pytorch
pip install rasterio
pip install s5cmd

For uv users (local environments):

uv sync --locked
source .venv/bin/activate

Let's install everything we need:

📦 Install Required Packages

Run this cell to install all dependencies. This may take 2-3 minutes in Colab.

# Install packages (this may take a few minutes)
!pip install torch lightning segmentation_models_pytorch rasterio s5cmd -q

📂 Clone the Clay Model Repository

We need the Clay model code for training. This downloads the latest version:

# Clone the Clay model repository
!git clone --depth=1 https://github.com/clay-foundation/model.git
# Navigate to the model directory and check contents
%cd model
!ls -la

🐍 Add Clay Model to Python Path

This makes the Clay model modules available for import:

# Add the repository root to the Python path so the claymodel package can be imported
import sys
sys.path.append(".")

# Import key modules we'll use for training
from claymodel.finetune.segment.chesapeake_datamodule import ChesapeakeDataModule
from claymodel.finetune.segment.chesapeake_model import ChesapeakeSegmentor

print("✅ Clay model modules imported successfully!")

📥 Download Training Data

We'll download a subset of the Chesapeake Bay dataset for training. The full dataset is ~100GB, so we're using a small sample for this tutorial.

What We're Downloading:

  • *_lc.tif: Land cover label images (ground truth)
  • *_naip-new.tif: NAIP aerial imagery (input images)
  • Training data: From New York region, 2013
  • Validation data: Separate set for evaluating model performance

About s5cmd:

s5cmd is a high-performance tool for transferring data from AWS S3. It's much faster than the standard AWS CLI for large datasets.

Note: Download may take 5-10 minutes depending on your internet connection.

# Create directory structure for our data
!mkdir -p data/cvpr/files/train data/cvpr/files/val

# Download training data (subset from NY region)
print("📥 Downloading training data...")
!s5cmd \
    --no-sign-request \
    cp \
    --include "m_42076*_lc.tif" \
    --include "m_42076*_naip-new.tif" \
    "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" \
    data/cvpr/files/train/

print("✅ Training data downloaded!")
# Download validation data (complete validation set)
print("📥 Downloading validation data...")
!s5cmd \
    --no-sign-request \
    cp \
    --include "*_lc.tif" \
    --include "*_naip-new.tif" \
    "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" \
    data/cvpr/files/val/

print("✅ Validation data downloaded!")

✅ Verify Downloaded Data

Let's check what we downloaded:

# Check what files we downloaded
print("📊 Train data files:")
!ls data/cvpr/files/train | head -10
print("📊 Validation data files:")
!ls data/cvpr/files/val | head -10

🔄 Data Preprocessing

The downloaded GeoTIFF files are large (typically 1000x1000 pixels or more). For efficient training, we need to:

  1. Split into smaller chips: Break large images into 224x224 pixel tiles
  2. Organize directory structure: Separate images and labels into proper folders
  3. Create train/val splits: Ensure no data leakage between training and validation

Why 224x224 chips?

  • Memory efficiency: Fits in GPU memory for training
  • Standard size: Common input size for vision models
  • Balanced coverage: Good trade-off between context and computational efficiency

What the preprocessing script does:

  • Reads large GeoTIFF files
  • Splits them into 224x224 pixel chips
  • Saves chips as individual image files
  • Maintains spatial alignment between imagery and labels
  • Creates proper directory structure for PyTorch Lightning

Note: This step may take 5-10 minutes to process all the data.
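
For intuition, the core of the chipping step looks roughly like the sketch below: read fixed-size windows from each GeoTIFF with rasterio and write them out as small, georeferenced tiles. This is a simplified illustration, not the repository's preprocess_data.py (which also pairs imagery with labels and builds the train/val folder layout), and the file name in the usage comment is hypothetical.

import os
import rasterio
from rasterio.windows import Window

def chip_geotiff(src_path, out_dir, chip_size=224):
    """Cut a large GeoTIFF into chip_size x chip_size tiles (edge remainders are dropped)."""
    os.makedirs(out_dir, exist_ok=True)
    with rasterio.open(src_path) as src:
        stem = os.path.splitext(os.path.basename(src_path))[0]
        for row in range(0, src.height - chip_size + 1, chip_size):
            for col in range(0, src.width - chip_size + 1, chip_size):
                window = Window(col, row, chip_size, chip_size)
                chip = src.read(window=window)           # (bands, chip_size, chip_size)
                meta = {
                    "driver": "GTiff",
                    "height": chip_size,
                    "width": chip_size,
                    "count": src.count,
                    "dtype": chip.dtype,
                    "crs": src.crs,
                    "transform": src.window_transform(window),  # keep georeferencing per chip
                }
                out_path = os.path.join(out_dir, f"{stem}_{row}_{col}.tif")
                with rasterio.open(out_path, "w", **meta) as dst:
                    dst.write(chip)

# Hypothetical usage on a single downloaded NAIP tile:
# chip_geotiff("data/cvpr/files/train/<some_tile>_naip-new.tif", "data/cvpr/ny/train/chips")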

# Clean up any existing processed data to ensure fresh start
!rm -rf data/cvpr/ny/
print("🧹 Cleaned up existing processed data")

🔧 Run Data Preprocessing

This converts the large GeoTIFF files into training-ready 224x224 image chips:

# Run the preprocessing script
# Args: input_dir output_dir chip_size
print("🔄 Processing data into 224x224 chips...")
!python claymodel/finetune/segment/preprocess_data.py data/cvpr/files data/cvpr/ny 224
print("✅ Data preprocessing complete!")

📊 Check Processed Data

Let's verify our preprocessing worked correctly:

# Check the directory structure and count files
!echo "📁 Directory structure:"
!ls -la data/cvpr/ny/

!echo -e "\n📊 Data counts:"
!echo "Validation labels: $(ls data/cvpr/ny/val/labels | wc -l) files"  
!echo "Validation chips: $(ls data/cvpr/ny/val/chips | wc -l) files"
!echo "Training labels: $(ls data/cvpr/ny/train/labels | wc -l) files"
!echo "Training chips: $(ls data/cvpr/ny/train/chips | wc -l) files"

🏗️ Download Pre-trained Clay Model

Now we need the pre-trained Clay foundation model. Think of this as downloading a "universal Earth imagery expert" that already understands features like vegetation, water, and built structures.

About the Clay Model:

  • Version 1.5: Latest stable version
  • Size: ~400MB (this is normal for foundation models!)
  • Training: Trained on millions of satellite/aerial images
  • Format: PyTorch Lightning checkpoint (.ckpt file)

What Makes Clay Special:

  • 🌍 Global coverage: Trained on imagery from around the world
  • 🛰️ Multi-sensor: Works with different satellite/aerial platforms
  • 🎯 Transfer learning ready: Designed to be fine-tuned for specific tasks
  • ⚡ Efficient: Optimized for both training and inference
# Create checkpoints directory and download Clay model
!mkdir -p checkpoints

print("⬇️ Downloading Clay v1.5 model (this may take a few minutes)...")
!wget -O checkpoints/clay-v1.5.ckpt https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt

print("✅ Clay model downloaded successfully!")

✅ Verify Model Download

Let's check the downloaded model file:

# Verify the model was downloaded correctly
!ls -lh checkpoints/
!du -h checkpoints/clay-v1.5.ckpt
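
If you are curious, you can also peek inside the checkpoint with plain PyTorch. This is an optional sanity check, not part of the pipeline; the exact keys can differ between Clay releases, so treat the printed names as examples only.

import torch

# Load onto CPU; weights_only=False because Lightning checkpoints store extra metadata
ckpt = torch.load("checkpoints/clay-v1.5.ckpt", map_location="cpu", weights_only=False)

print(type(ckpt))                        # typically a dict
print(list(ckpt.keys())[:8])             # e.g. "state_dict", "hyper_parameters", ...
if "state_dict" in ckpt:
    print(len(ckpt["state_dict"]), "tensors in state_dict")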

⚙️ Training Configuration

Before training, let's examine the configuration file that controls all the training parameters. Understanding these settings helps you adapt the model for your own projects.

What's in the Config File:

  • Data paths: Where to find training/validation data
  • Model settings: Architecture choices and hyperparameters
  • Training params: Learning rate, batch size, number of epochs
  • Hardware settings: GPU usage, mixed precision training
  • Logging: Where to save results and checkpoints
# Let's look at the training configuration
print("📋 Training Configuration:")
!cat configs/segment_chesapeake.yaml

print("\n💡 Key Settings Explained:")
print("- lr: 1e-5 (learning rate - how fast the model learns)")  
print("- batch_size: 16 (number of images processed together)")
print("- max_epochs: 50 (maximum training iterations)")
print("- precision: bf16-mixed (faster training with minimal accuracy loss)")

📝 Understanding the Configuration

The config file uses YAML format - a human-readable way to specify settings. Here's what each section does:

  • data: Paths to training and validation data
  • model: Architecture and learning parameters
  • trainer: Hardware settings and training duration
  • callbacks: When to save models and how to monitor progress
  • logger: Where to save training logs and metrics
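
If you prefer to inspect the configuration programmatically, you can load it with PyYAML (installed as a Lightning dependency). The exact keys depend on the file, so the printout below is just a quick way to explore it:

import yaml

# Load the YAML config into a nested Python dict and list its top-level sections
with open("configs/segment_chesapeake.yaml") as f:
    cfg = yaml.safe_load(f)

print("Top-level sections:", list(cfg.keys()))
print("Data section:", cfg.get("data"))    # e.g. chip/label directories, batch size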

🚀 Model Training Setup

Now we'll set up the training pipeline using PyTorch Lightning. This approach separates data handling from model training, making the code cleaner and more maintainable.

Training Components:

  1. DataModule: Handles loading and preprocessing of images
  2. Model: The Clay encoder + segmentation head
  3. Trainer: Orchestrates the training process

Key Benefits of This Approach

  • ✅ Reproducible: Same setup works across different environments
  • ✅ Scalable: Easy to train on single GPU or multiple GPUs
  • ✅ Maintainable: Clean separation of concerns
  • ✅ Flexible: Easy to modify individual components

📊 Initialize Data Module

The DataModule handles all data operations - loading images, applying transforms, creating batches:

# Initialize the data module with our processed data
print("📊 Setting up data module...")

dm = ChesapeakeDataModule(
    train_chip_dir="data/cvpr/ny/train/chips/",      # Training images
    train_label_dir="data/cvpr/ny/train/labels/",    # Training labels  
    val_chip_dir="data/cvpr/ny/val/chips/",          # Validation images
    val_label_dir="data/cvpr/ny/val/labels/",        # Validation labels
    metadata_path="configs/metadata.yaml",           # Data normalization info
    batch_size=16,                                   # Images per training batch
    num_workers=8,                                   # Parallel data loading processes  
    platform="naip",                                 # Image type (NAIP aerial imagery)
)

# Prepare the data loaders
dm.setup()
print("✅ Data module ready!")
print(f"📈 Training batches: {len(dm.train_dataloader())}")
print(f"📊 Validation batches: {len(dm.val_dataloader())}")

🤖 Initialize the Model

Now we create our segmentation model - Clay encoder + segmentation head:

# Initialize the segmentation model
print("🤖 Setting up segmentation model...")

model = ChesapeakeSegmentor(
    num_classes=7,                              # 7 land cover classes
    ckpt_path="checkpoints/clay-v1.5.ckpt",    # Pre-trained Clay model
    lr=1e-5,                                    # Learning rate (conservative for fine-tuning)
    wd=0.05,                                    # Weight decay (regularization)
    b1=0.9,                                     # Adam optimizer beta1  
    b2=0.95,                                    # Adam optimizer beta2
)

print("✅ Model initialized!")
print(f"🧊 Clay encoder: FROZEN (saves compute)")
print(f"🎯 Segmentation head: TRAINABLE (learns land cover patterns)")

⚡ Setup the Trainer

The Trainer handles the training loop, GPU usage, and checkpointing:

# Import the Trainer
from lightning import Trainer

print("⚡ Setting up trainer...")
# Configure the trainer for our training session
trainer = Trainer(
    accelerator="auto",                    # Automatically detect GPU/CPU
    devices=1,                            # Use 1 device (GPU if available)  
    num_nodes=1,                          # Single machine training
    precision="bf16-mixed",               # Mixed precision (faster training)
    log_every_n_steps=5,                  # Log metrics every 5 training steps
    max_epochs=1,                         # Train for 1 epoch (demo purposes)
    accumulate_grad_batches=1,            # No gradient accumulation
    default_root_dir="checkpoints/segment", # Where to save checkpoints
    fast_dev_run=False,                   # Full training (not debugging mode)
    num_sanity_val_steps=0,               # Skip validation sanity check
)

print("✅ Trainer configured!")
print(f"🎯 Will train for {trainer.max_epochs} epoch(s)")
print(f"💾 Checkpoints saved to: {trainer.default_root_dir}")

🏁 Start Training!

Everything is set up - let's train the model! This will:

  1. Load batches of images and labels
  2. Forward pass: Run images through Clay encoder + segmentation head
  3. Compute loss: Compare predictions to ground truth labels
  4. Backward pass: Calculate gradients for the segmentation head
  5. Update weights: Improve the segmentation head parameters
  6. Validate: Test performance on validation data
  7. Save checkpoint: Store the trained model

Expected time: ~5-10 minutes for 1 epoch (depending on hardware)

What to watch for:

  • Training loss should decrease over time
  • Validation metrics should improve
  • No out-of-memory errors
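
Under the hood, each training step boils down to something like the sketch below. This is conceptual only - the real logic lives inside ChesapeakeSegmentor's training_step and Lightning runs the loop for you - but it shows where the loss comes from and why only the head's weights change.

import torch.nn.functional as F

def conceptual_training_step(model, batch, optimizer):
    logits = model(batch)                              # (B, num_classes, H, W) class scores
    loss = F.cross_entropy(logits, batch["label"])     # pixel-wise loss vs. ground truth
    optimizer.zero_grad()
    loss.backward()                                    # gradients stop at the frozen Clay encoder
    optimizer.step()                                   # updates only the trainable segmentation head
    return loss.item()
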
# Start the training process!
print("🚀 Starting training...")
print("📊 Watch the progress below:")

trainer.fit(model, dm)

print("\n🎉 Training complete!")
print("📁 Check the checkpoints directory for your trained model")
print("➡️ Next: Let's run inference to see predictions!")

Model Inference and Results Visualization

Welcome to the inference and visualization part of the tutorial. Here you'll see your fine-tuned Clay model in action, making predictions on real satellite imagery.

What You'll Learn

  • How to load a trained segmentation model for inference
  • Techniques for visualizing model predictions
  • How to interpret land cover segmentation results
  • Methods for comparing predictions with ground truth
  • Best practices for model evaluation

What We'll Do

  1. Load the trained model from the previous notebook
  2. Prepare validation data for testing
  3. Run inference to generate predictions
  4. Visualize results with color-coded land cover maps
  5. Compare predictions with ground truth labels

Key Concepts

For GIS Professionals 📐

  • Inference: Using your trained model to classify new imagery
  • Visualization: Creating interpretable land cover maps from model outputs
  • Think of this as automated feature extraction and classification
  • Results can be exported as GeoTIFF files for use in GIS software
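
As a pointer for that last bullet: once you have a predicted class map as a NumPy array, writing it out as a georeferenced GeoTIFF takes only a few lines of rasterio. This is a minimal sketch, not part of the notebook's pipeline; it assumes you reuse the georeferencing of the source chip, and the paths in the usage comment are hypothetical.

import numpy as np
import rasterio

def save_prediction_as_geotiff(pred_2d, source_chip_path, out_path):
    """Write an (H, W) array of class indices as a single-band GeoTIFF aligned with the source chip."""
    with rasterio.open(source_chip_path) as src:
        meta = src.meta.copy()                 # copy CRS, transform, and size from the source chip
    meta.update(count=1, dtype="uint8", nodata=None)
    with rasterio.open(out_path, "w", **meta) as dst:
        dst.write(pred_2d.astype(np.uint8), 1)

# Hypothetical usage with the preds array produced later in this notebook:
# save_prediction_as_geotiff(preds[0], "data/cvpr/ny/val/chips/<some_chip>.tif", "prediction_0.tif")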

For Data Analysts 📊

  • Model evaluation: Assessing how well our model performs
  • Visual validation: Checking predictions against known ground truth
  • Pattern recognition: Understanding what the model learned vs. missed
  • Quality assessment: Identifying areas for model improvement

For ML Engineers 🤖

  • Inference pipeline: Loading checkpoints and running forward passes
  • Post-processing: Converting logits to class predictions
  • Batch processing: Efficient handling of multiple images
  • Model interpretation: Understanding model behavior through visualization

Let's make sure we have access to our trained model:

# Verify we have the necessary files
print("📁 Current directory contents:")
!ls -la

print("\n🔍 Checking for trained model...")
!ls -la checkpoints/segment/lightning_logs/*/checkpoints/ 2>/dev/null || echo "❌ No trained model found - please run the training notebook first!"

📚 Import Required Libraries

Let's import all the tools we need for inference and visualization:

# Core Python libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# PyTorch for deep learning
import torch
import torch.nn.functional as F

# Additional utilities  
from einops import rearrange

print("✅ All libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🎮 GPU available: {'Yes' if torch.cuda.is_available() else 'No (using CPU)'}")
✅ All libraries imported successfully!
🔧 PyTorch version: 2.7.0
🎮 GPU available: No (using CPU)

⚙️ Configuration and File Paths

Let's define all the paths and parameters we'll need. These should match what you used in the training notebook:

# File paths and configuration
# Note: the exact "epoch=...-step=..." filename depends on your training run - check the checkpoints folder and adjust if needed
CHESAPEAKE_CHECKPOINT_PATH = "checkpoints/segment/lightning_logs/version_0/checkpoints/epoch=0-step=63.ckpt"
CLAY_CHECKPOINT_PATH = "checkpoints/clay-v1.5.ckpt"
METADATA_PATH = "configs/metadata.yaml"

# Data directories
TRAIN_CHIP_DIR = "data/cvpr/ny/train/chips/"
TRAIN_LABEL_DIR = "data/cvpr/ny/train/labels/"
VAL_CHIP_DIR = "data/cvpr/ny/val/chips/"
VAL_LABEL_DIR = "data/cvpr/ny/val/labels/"

# Data loading parameters
BATCH_SIZE = 32          # Process 32 images at once (larger batch for inference)
NUM_WORKERS = 1          # Single worker to avoid issues in Colab
PLATFORM = "naip"        # NAIP aerial imagery platform

print("📋 Configuration loaded:")
print(f"   🎯 Model checkpoint: {CHESAPEAKE_CHECKPOINT_PATH}")
print(f"   🧠 Clay model: {CLAY_CHECKPOINT_PATH}")
print(f"   📊 Batch size: {BATCH_SIZE}")
print(f"   📷 Platform: {PLATFORM}")

🔧 Helper Functions

Let's define functions to handle model loading, data preparation, inference, and visualization. Breaking these into functions makes the code more organized and reusable:

def get_model(chesapeake_checkpoint_path, clay_checkpoint_path, metadata_path):
    """
    Load the trained segmentation model from checkpoint.
    
    Args:
        chesapeake_checkpoint_path: Path to our trained model
        clay_checkpoint_path: Path to the Clay foundation model  
        metadata_path: Path to data normalization metadata
        
    Returns:
        model: Loaded model in evaluation mode
    """
    print("🤖 Loading trained model...")
    
    model = ChesapeakeSegmentor.load_from_checkpoint(
        checkpoint_path=chesapeake_checkpoint_path,
        metadata_path=metadata_path,
        ckpt_path=clay_checkpoint_path,
    )
    
    # Set to evaluation mode (disables dropout, batch norm training mode, etc.)
    model.eval()
    
    print("✅ Model loaded successfully!")
    return model
def get_data(train_chip_dir, train_label_dir, val_chip_dir, val_label_dir, 
             metadata_path, batch_size, num_workers, platform):
    """
    Set up data loading for inference.
    
    Args:
        Various paths and parameters for data loading
        
    Returns:
        batch: A batch of validation data
        metadata: Data normalization and class information
    """
    print("📊 Setting up data loader...")
    
    # Create data module (same as training, but we only need validation data)
    dm = ChesapeakeDataModule(
        train_chip_dir=train_chip_dir,
        train_label_dir=train_label_dir,  
        val_chip_dir=val_chip_dir,
        val_label_dir=val_label_dir,
        metadata_path=metadata_path,
        batch_size=batch_size,
        num_workers=num_workers,
        platform=platform,
    )
    
    # Setup the data module
    dm.setup(stage="fit")
    
    # Get one batch of validation data for visualization
    val_dl = iter(dm.val_dataloader())
    batch = next(val_dl)
    
    print(f"✅ Data loaded - batch contains {batch['pixels'].shape[0]} images")
    print(f"📏 Image shape: {list(batch['pixels'].shape[1:])}")
    
    return batch, dm.metadata
def run_prediction(model, batch):
    """
    Run inference on a batch of images.
    
    Args:
        model: Trained segmentation model
        batch: Batch of input images
        
    Returns:
        outputs: Model predictions (probabilities for each class)
    """
    print("🔮 Running inference...")
    
    # Disable gradient computation for faster inference
    with torch.no_grad():
        # Forward pass through the model
        outputs = model(batch)
    
    # Upsample predictions to match the chip size used in preprocessing (224x224)
    # The model outputs smaller feature maps that need to be upsampled
    outputs = F.interpolate(
        outputs,
        size=(224, 224),           # Target size (same as the 224x224 chips and labels)
        mode="bilinear",           # Smooth interpolation
        align_corners=False        # PyTorch default
    )
    
    print(f"✅ Inference complete - predictions shape: {list(outputs.shape)}")
    return outputs
def denormalize_images(normalized_images, means, stds):
    """
    Convert normalized images back to viewable format.
    
    During training, images are normalized (mean=0, std=1) for better model performance.  
    For visualization, we need to reverse this normalization.
    
    Args:
        normalized_images: Normalized image tensors
        means: Mean values used for normalization
        stds: Standard deviation values used for normalization
        
    Returns:
        denormalized_images: Images in 0-255 range for display
    """
    means = np.array(means).reshape(1, -1, 1, 1)
    stds = np.array(stds).reshape(1, -1, 1, 1)
    
    # Reverse normalization: multiply by std, then add mean
    denormalized_images = normalized_images * stds + means
    
    # Clip to the valid display range and convert to 0-255 uint8
    return denormalized_images.clip(0, 255).astype(np.uint8)


def post_process(batch, outputs, metadata):
    """
    Convert model outputs and inputs into visualization-ready format.
    
    Args:
        batch: Original batch of data
        outputs: Model prediction probabilities
        metadata: Data normalization info
        
    Returns:
        images: RGB images ready for display
        labels: Ground truth segmentation maps
        preds: Predicted segmentation maps
    """
    print("🔄 Post-processing results...")
    
    # Convert prediction probabilities to class predictions
    # argmax selects the class with highest probability for each pixel
    preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
    
    # Extract ground truth labels
    labels = batch["label"].detach().cpu().numpy()
    
    # Extract normalized pixel values
    pixels = batch["pixels"].detach().cpu().numpy()
    
    # Get normalization parameters for this platform (NAIP)
    means = list(metadata["naip"].bands.mean.values())
    stds = list(metadata["naip"].bands.std.values())
    
    # Denormalize images for display
    norm_pixels = denormalize_images(pixels, means, stds)
    
    # Rearrange from (batch, channels, height, width) to (batch, height, width, channels)
    # This is the format matplotlib expects for RGB images
    images = rearrange(norm_pixels[:, :3, :, :], "b c h w -> b h w c")
    
    print(f"✅ Post-processing complete")
    print(f"📊 Processed {len(images)} images")
    
    return images, labels, preds
def plot_predictions(images, labels, preds):
    """
    Create a comprehensive visualization of results.
    
    Shows original images, ground truth labels, and model predictions
    in an easy-to-compare grid format.
    
    Args:
        images: RGB aerial images
        labels: Ground truth segmentation maps  
        preds: Model predicted segmentation maps
    """
    print("🎨 Creating visualization...")
    
    # Define colors for each land cover class
    # These colors are chosen to be intuitive and visually distinct
    colors = [
        (0/255, 0/255, 255/255, 1),         # Deep Blue for water 💧
        (34/255, 139/255, 34/255, 1),       # Forest Green for tree canopy 🌳
        (154/255, 205/255, 50/255, 1),      # Yellow Green for low vegetation 🌱
        (210/255, 180/255, 140/255, 1),     # Tan for barren land 🏔️
        (169/255, 169/255, 169/255, 1),     # Dark Gray for impervious (other) 🏢
        (105/255, 105/255, 105/255, 1),     # Dim Gray for impervious (road) 🛣️
        (255/255, 255/255, 255/255, 1),     # White for no data ⬜
    ]
    cmap = ListedColormap(colors)
    
    # Create a large figure: 24 chips shown in repeating groups of three rows
    # (image / ground truth / prediction), 8 columns wide -> 9 rows in total
    fig, axes = plt.subplots(9, 8, figsize=(16, 18))
    fig.suptitle('🌍 Land Cover Segmentation Results', fontsize=16, fontweight='bold')
    
    # Plot repeating groups of three rows: Images, Ground Truth, Predictions
    plot_data(axes, images, row_offset=0, title="📷 Original Image")
    plot_data(axes, labels, row_offset=1, title="🎯 Ground Truth", cmap=cmap, vmin=0, vmax=6)
    plot_data(axes, preds, row_offset=2, title="🤖 Model Prediction", cmap=cmap, vmin=0, vmax=6)
    
    # Add a legend explaining the color scheme
    add_legend(fig, cmap)
    
    plt.tight_layout()
    plt.show()
    
    print("✅ Visualization complete!")


def plot_data(axes, data, row_offset, title=None, cmap=None, vmin=None, vmax=None):
    """Helper function to plot a row of data in the grid."""
    for i, item in enumerate(data):
        if i >= 24:  # Only show first 24 images (3 rows of 8)
            break
            
        row = row_offset + (i // 8) * 3
        col = i % 8
        
        axes[row, col].imshow(item, cmap=cmap, vmin=vmin, vmax=vmax)
        axes[row, col].axis("off")
        
        # Add row titles on the left-most column (set_ylabel is hidden by axis("off"),
        # so draw the title as text just outside the axes instead)
        if title and col == 0:
            axes[row, col].text(-0.08, 0.5, title, transform=axes[row, col].transAxes,
                                fontsize=12, fontweight='bold', ha='right', va='center')


def add_legend(fig, cmap):
    """Add a color legend explaining the land cover classes."""
    class_names = [
        "💧 Water",
        "🌳 Tree Canopy",
        "🌱 Low Vegetation",
        "🏔️ Barren Land",
        "🏢 Impervious (Other)",
        "🛣️ Impervious (Roads)",
        "⬜ No Data"
    ]
    
    # Create legend patches
    import matplotlib.patches as mpatches
    patches = [mpatches.Patch(color=cmap.colors[i], label=class_names[i]) 
               for i in range(len(class_names))]
    
    # Add legend to the figure
    fig.legend(handles=patches, loc='center', bbox_to_anchor=(0.5, 0.02), 
               ncol=4, fontsize=10)

🚀 Run the Complete Inference Pipeline

Now let's put it all together! We'll load the model, prepare data, run inference, and visualize results:

# Load the trained model
model = get_model(CHESAPEAKE_CHECKPOINT_PATH, CLAY_CHECKPOINT_PATH, METADATA_PATH)
# Get validation data for testing
batch, metadata = get_data(
    TRAIN_CHIP_DIR,
    TRAIN_LABEL_DIR,
    VAL_CHIP_DIR,
    VAL_LABEL_DIR,
    METADATA_PATH,
    BATCH_SIZE,
    NUM_WORKERS,
    PLATFORM,
)

# Move data to GPU if available (same device as model)
device = next(model.parameters()).device
batch = {k: v.to(device) for k, v in batch.items()}
print(f"📱 Using device: {device}")
# Run inference on the batch
outputs = run_prediction(model, batch)
# Post-process results for visualization
images, labels, preds = post_process(batch, outputs, metadata)
# Create the final visualization
plot_predictions(images, labels, preds)

print("\n🎉 Inference and visualization complete!")
print("\n🔍 What to Look For:")
print("   • How well does the model identify water bodies?")
print("   • Are forest areas correctly classified?")
print("   • Does the model distinguish between different types of impervious surfaces?")
print("   • Where does the model struggle or make mistakes?")
print("\n💡 Next Steps:")
print("   • Try running on more batches to see consistency")
print("   • Consider additional training epochs for better performance")
print("   • Experiment with different learning rates or data augmentation")
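
If you want a quick numeric check to go along with the visual inspection, per-class intersection-over-union (IoU) is easy to compute from the preds and labels arrays returned by post_process. This is an optional sketch, not part of the original notebook:

import numpy as np

def per_class_iou(preds, labels, num_classes=7):
    """Intersection-over-union for each class, computed over the whole batch."""
    ious = []
    for c in range(num_classes):
        pred_c = preds == c
        label_c = labels == c
        union = np.logical_or(pred_c, label_c).sum()
        intersection = np.logical_and(pred_c, label_c).sum()
        ious.append(intersection / union if union > 0 else float("nan"))
    return ious

class_names = ["Water", "Tree Canopy", "Low Vegetation", "Barren Land",
               "Impervious (Other)", "Impervious (Roads)", "No Data"]
for name, iou in zip(class_names, per_class_iou(preds, labels)):
    print(f"{name:20s} IoU: {iou:.3f}")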