Welcome to Tutorial 3! In this hands-on session, you'll learn how to fine-tune the Clay foundation model for land cover segmentation using the Chesapeake Bay dataset.
Learning Objectives
By the end of this tutorial, you will:
- Understand what foundation models are and why they're powerful for Earth observation
- Learn how to fine-tune a pre-trained model for semantic segmentation
- Apply transfer learning techniques to land cover classification
- Work with real satellite imagery and ground truth labels
- Evaluate model performance on geospatial data
What You'll Build
You'll create a land cover segmentation model that can classify different types of land use (water, forest, urban areas, etc.) from satellite imagery.
Background for Different Audiences
For GIS Professionals
- Foundation models are like having a universal "base map" that understands Earth's features
- Segmentation is similar to creating detailed land use polygons, but at the pixel level
- Think of this as automated land cover classification that can replace manual digitization
- The output is similar to a detailed land use/land cover (LULC) raster
For Data Analysts
- We're using transfer learning - starting with a model already trained on lots of Earth imagery
- Fine-tuning means adapting this pre-trained model to our specific classification task
- This is like taking a general-purpose tool and customizing it for your specific needs
- The model learns patterns in pixel values to predict land cover categories
For ML Engineers
- Clay is a Vision Transformer (ViT) trained on massive Earth observation datasets
- We're doing semantic segmentation - predicting a class for every pixel
- The architecture uses a frozen encoder (Clay) + a trainable segmentation head
- We'll use PyTorch Lightning for training orchestration
Dataset Overview
The Chesapeake Bay Land Cover dataset contains:
- High-resolution aerial imagery (NAIP - National Agriculture Imagery Program)
- 7 land cover classes: Water, Tree Canopy, Low Vegetation, Barren, Impervious (Roads), Impervious (Other), No Data
- Pixel-level annotations for supervised learning
- Real-world complexity with mixed land uses and seasonal variations
How the Clay Segmentation Architecture Works
The Segmentor class combines two key components:
1. Frozen Clay Encoder
- Pre-trained on millions of Earth observation images
- Extracts rich feature representations from input imagery
- Frozen = weights don't change during fine-tuning (saves compute!)
- Acts like a "universal feature extractor" for Earth imagery
2. Trainable Segmentation Head
- Takes Clay's feature maps and upsamples them to the original image size
- Uses convolution + pixel shuffle operations for efficient upsampling
- Only this part gets trained - much faster than training from scratch!
Key Parameters:
- num_classes (int): Number of land cover classes to predict (7 for Chesapeake)
- ckpt_path (str): Path to the pre-trained Clay model weights
Why This Approach Works:
- Faster training: Only the small segmentation head is trained
- Less data needed: Clay already understands Earth imagery patterns
- Better performance: Foundation model knowledge transfers well
- Cost effective: Requires fewer computational resources
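To make the frozen-encoder + trainable-head idea concrete, here is a minimal, illustrative sketch - not the actual Segmentor code from claymodel; the channel count and upscale factor are assumptions - showing how a small convolution + pixel-shuffle head can sit on top of a frozen encoder:
# Illustrative sketch only -- the real implementation lives in the claymodel package.
import torch
import torch.nn as nn

class ToySegmentationHead(nn.Module):
    """Conv + PixelShuffle head that maps encoder features to a per-pixel class map."""
    def __init__(self, in_channels=768, num_classes=7, upscale=16):
        super().__init__()
        # Project features to num_classes * upscale^2 channels...
        self.project = nn.Conv2d(in_channels, num_classes * upscale**2, kernel_size=1)
        # ...then PixelShuffle rearranges those channels into a map upscaled by `upscale`.
        self.upsample = nn.PixelShuffle(upscale)

    def forward(self, feats):                      # feats: (B, C, H/16, W/16)
        return self.upsample(self.project(feats))  # -> (B, num_classes, H, W)

# "Frozen" simply means the encoder's weights receive no gradient updates:
# for p in clay_encoder.parameters():
#     p.requires_grad = False
Only the head's parameters are handed to the optimizer, which is why this style of fine-tuning is fast and light on data.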
About the Chesapeake Bay Dataset
We'll use the Chesapeake Bay Land Cover dataset - a high-quality dataset perfect for learning land cover segmentation.
Dataset Citation
If you use this dataset in your work, please cite:
Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N.
Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data.
Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR 2019).
Why This Dataset is Great for Learning:
- High Resolution: 1-meter pixel resolution aerial imagery
- Multiple Regions: Covers diverse landscapes in the Chesapeake Bay area
- Expert Annotations: Ground truth labels created by domain experts
- Real-world Complexity: Mixed land uses, seasonal variations, and edge cases
- Well-documented: Extensively used in research with known baselines
Land Cover Classes (7 total):
- Water 💧 - Rivers, lakes, bays, coastal areas
- Tree Canopy/Forest 🌳 - Dense forest areas, large trees
- Low Vegetation/Fields 🌱 - Grass, crops, shrubs, sparse vegetation
- Barren Land 🏜️ - Exposed soil, construction sites, beaches
- Impervious (Roads) 🛣️ - Paved roads, highways, parking lots
- Impervious (Other) 🏢 - Buildings, rooftops, other built structures
- No Data ⬜ - Areas with missing or invalid data
More information: Chesapeake Bay Dataset
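For later reference (for example, when inspecting prediction arrays), it can be handy to keep a class-index lookup. The 0-6 order below mirrors the colormap used in the visualization code later in this notebook; treat it as an assumption and verify it against the dataset documentation and preprocessing output:
# Assumed index-to-class mapping (verify against the dataset docs before relying on it)
CHESAPEAKE_CLASSES = {
    0: "Water",
    1: "Tree Canopy / Forest",
    2: "Low Vegetation / Fields",
    3: "Barren Land",
    4: "Impervious (Other)",
    5: "Impervious (Roads)",
    6: "No Data",
}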
Setup and Installation
We'll install all required packages for fine-tuning the Clay model. This notebook is optimized for Google Colab but works in any Jupyter environment.
What Each Package Does:
- torch: PyTorch deep learning framework
- lightning: PyTorch Lightning for training orchestration
- segmentation_models_pytorch: Pre-built segmentation architectures
- rasterio: Reading/writing geospatial raster data (GeoTIFF files)
- s5cmd: Fast, parallel S3 data transfers
Installation Options
Option 1: All at once (recommended for Colab)
pip install torch lightning segmentation_models_pytorch rasterio s5cmd
Option 2: Individual packages (if you encounter conflicts)
pip install torch
pip install lightning
pip install segmentation_models_pytorch
pip install rasterio
pip install s5cmd
For uv users (local environments):
uv sync --locked
source .venv/bin/activate
Let's install everything we need:
Install Required Packages
Run this cell to install all dependencies. This may take 2-3 minutes in Colab.
# Install packages (this may take a few minutes)
!pip install torch lightning segmentation_models_pytorch rasterio s5cmd -q
Clone the Clay Model Repository
We need the Clay model code for training. This downloads the latest version:
# Clone the Clay model repository
!git clone --depth=1 https://github.com/clay-foundation/model.git
# Navigate to the model directory and check contents
%cd model
!ls -la
Add the Clay Model to the Python Path
This makes the Clay model modules available for import:
# Add the claymodel directory to the Python path so we can import modules
import sys
sys.path.append("./claymodel")

# Import the key modules we'll use for training
from claymodel.finetune.segment.chesapeake_datamodule import ChesapeakeDataModule
from claymodel.finetune.segment.chesapeake_model import ChesapeakeSegmentor

print("✅ Clay model modules imported successfully!")
Download Training Data
We'll download a subset of the Chesapeake Bay dataset for training. The full dataset is ~100 GB, so we're using a small sample for this tutorial.
What We're Downloading:
- *_lc.tif: Land cover label images (ground truth)
- *_naip-new.tif: NAIP aerial imagery (input images)
- Training data: from the New York region, 2013
- Validation data: a separate set for evaluating model performance
About s5cmd:
s5cmd is a high-performance tool for transferring data from AWS S3. It's much faster than the standard AWS CLI for large datasets.
Note: Download may take 5-10 minutes depending on your internet connection.
# Create the directory structure for our data
!mkdir -p data/cvpr/files/train data/cvpr/files/val

# Download training data (subset from the NY region)
print("Downloading training data...")
!s5cmd \
  --no-sign-request \
  cp \
  --include "m_42076*_lc.tif" \
  --include "m_42076*_naip-new.tif" \
  "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" \
  data/cvpr/files/train/
print("✅ Training data downloaded!")

# Download validation data (complete validation set)
print("Downloading validation data...")
!s5cmd \
  --no-sign-request \
  cp \
  --include "*_lc.tif" \
  --include "*_naip-new.tif" \
  "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" \
  data/cvpr/files/val/
print("✅ Validation data downloaded!")
Verify the Downloaded Data
Let's check what we downloaded:
# Check which files we downloaded
print("Train data files:")
!ls data/cvpr/files/train | head -10

print("Validation data files:")
!ls data/cvpr/files/val | head -10
Data Preprocessing
The downloaded GeoTIFF files are large (typically 1000x1000 pixels or more). For efficient training, we need to:
- Split into smaller chips: Break large images into 224x224 pixel tiles
- Organize the directory structure: Separate images and labels into the proper folders
- Create train/val splits: Ensure no data leakage between training and validation
Why 224x224 chips?
- Memory efficiency: Fits in GPU memory for training
- Standard size: Common input size for vision models
- Balanced coverage: Good trade-off between context and computational efficiency
What the preprocessing script does:
- Reads the large GeoTIFF files
- Splits them into 224x224 pixel chips
- Saves chips as individual image files
- Maintains spatial alignment between imagery and labels
- Creates the directory structure expected by PyTorch Lightning
Note: This step may take 5-10 minutes to process all the data.
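The repository's preprocess_data.py script handles all of this for you. Purely as an illustration of what "chipping" means, here is a minimal sketch of cutting one GeoTIFF into fixed-size tiles with rasterio windows (the function name and output naming are made up for this example):
# Illustrative only -- the tutorial uses claymodel/finetune/segment/preprocess_data.py.
import os
import rasterio
from rasterio.windows import Window

def chip_geotiff(src_path, out_dir, chip_size=224):
    """Cut a large GeoTIFF into non-overlapping chip_size x chip_size tiles."""
    os.makedirs(out_dir, exist_ok=True)
    with rasterio.open(src_path) as src:
        for row in range(0, src.height - chip_size + 1, chip_size):
            for col in range(0, src.width - chip_size + 1, chip_size):
                window = Window(col, row, chip_size, chip_size)
                chip = src.read(window=window)  # shape: (bands, chip_size, chip_size)
                profile = src.profile.copy()
                profile.update(height=chip_size, width=chip_size,
                               transform=src.window_transform(window))
                out_path = os.path.join(out_dir, f"chip_{row}_{col}.tif")
                with rasterio.open(out_path, "w", **profile) as dst:
                    dst.write(chip)
Applying the same window grid to the imagery and to the label raster is what keeps the chips and labels spatially aligned.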
# Clean up any existing processed data to ensure a fresh start
!rm -rf data/cvpr/ny/
print("Cleaned up existing processed data")
Run the Data Preprocessing
This converts the large GeoTIFF files into training-ready 224x224 image chips:
# Run the preprocessing script
# Args: input_dir output_dir chip_size
print("Processing data into 224x224 chips...")
!python claymodel/finetune/segment/preprocess_data.py data/cvpr/files data/cvpr/ny 224
print("✅ Data preprocessing complete!")
Check the Processed Data
Let's verify that our preprocessing worked correctly:
# Check the directory structure and count files
!echo "Directory structure:"
!ls -la data/cvpr/ny/
!echo -e "\nData counts:"
!echo "Validation labels: $(ls data/cvpr/ny/val/labels | wc -l) files"
!echo "Validation chips: $(ls data/cvpr/ny/val/chips | wc -l) files"
!echo "Training labels: $(ls data/cvpr/ny/train/labels | wc -l) files"
!echo "Training chips: $(ls data/cvpr/ny/train/chips | wc -l) files"
Download the Pre-trained Clay Model
Now we need the pre-trained Clay foundation model. Think of this as downloading a "universal Earth imagery expert" that already understands features like vegetation, water, and built structures.
About the Clay Model:
- Version 1.5: Latest stable version
- Size: ~400 MB (this is normal for foundation models!)
- Training: Trained on millions of satellite/aerial images
- Format: PyTorch Lightning checkpoint (.ckpt file)
What Makes Clay Special:
- Global coverage: Trained on imagery from around the world
- Multi-sensor: Works with different satellite/aerial platforms
- Transfer learning ready: Designed to be fine-tuned for specific tasks
- Efficient: Optimized for both training and inference
# Create a checkpoints directory and download the Clay model
!mkdir -p checkpoints
print("Downloading Clay v1.5 model (this may take a few minutes)...")
!wget -O checkpoints/clay-v1.5.ckpt https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt
print("✅ Clay model downloaded successfully!")
Verify the Model Download
Let's check the downloaded model file:
# Verify the model was downloaded correctly
!ls -lh checkpoints/
!du -h checkpoints/clay-v1.5.ckpt
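Optionally, you can peek inside the checkpoint with torch.load to confirm it contains a model state dict. This is just a sanity check; the exact key layout can vary between Lightning versions:
# Optional sanity check: inspect the checkpoint contents
import torch

ckpt = torch.load("checkpoints/clay-v1.5.ckpt", map_location="cpu", weights_only=False)
print("Top-level keys:", list(ckpt.keys()))
if "state_dict" in ckpt:
    n_params = sum(t.numel() for t in ckpt["state_dict"].values())
    print(f"Parameters in state_dict: {n_params:,}")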
Training Configuration
Before training, let's examine the configuration file that controls all the training parameters. Understanding these settings helps you adapt the model for your own projects.
What's in the Config File:
- Data paths: Where to find the training/validation data
- Model settings: Architecture choices and hyperparameters
- Training params: Learning rate, batch size, number of epochs
- Hardware settings: GPU usage, mixed precision training
- Logging: Where to save results and checkpoints
# Let's look at the training configuration
print("Training Configuration:")
!cat configs/segment_chesapeake.yaml

print("\nKey Settings Explained:")
print("- lr: 1e-5 (learning rate - how fast the model learns)")
print("- batch_size: 16 (number of images processed together)")
print("- max_epochs: 50 (maximum number of passes over the training data)")
print("- precision: bf16-mixed (faster training with minimal accuracy loss)")
Understanding the Configuration
The config file uses YAML format - a human-readable way to specify settings. Here's what each section does:
- data: Paths to training and validation data
- model: Architecture and learning parameters
- trainer: Hardware settings and training duration
- callbacks: When to save models and how to monitor progress
- logger: Where to save training logs and metrics
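If you prefer to inspect the config programmatically rather than reading raw YAML, a short snippet like this works (it only assumes the file path shown in the cell above):
# Load the training config and list its top-level sections
import yaml

with open("configs/segment_chesapeake.yaml") as f:
    cfg = yaml.safe_load(f)

for section, settings in cfg.items():
    print(f"{section}: {settings}")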
Model Training Setup
Now we'll set up the training pipeline using PyTorch Lightning. This approach separates data handling from model training, making the code cleaner and more maintainable.
Training Components:
- DataModule: Handles loading and preprocessing of images
- Model: The Clay encoder + segmentation head
- Trainer: Orchestrates the training process
Key Benefits of This Approach:
- Reproducible: The same setup works across different environments
- Scalable: Easy to train on a single GPU or multiple GPUs
- Maintainable: Clean separation of concerns
- Flexible: Easy to modify individual components
Initialize the Data Module
The DataModule handles all data operations - loading images, applying transforms, creating batches:
# Initialize the data module with our processed data
print("Setting up data module...")
dm = ChesapeakeDataModule(
    train_chip_dir="data/cvpr/ny/train/chips/",    # Training images
    train_label_dir="data/cvpr/ny/train/labels/",  # Training labels
    val_chip_dir="data/cvpr/ny/val/chips/",        # Validation images
    val_label_dir="data/cvpr/ny/val/labels/",      # Validation labels
    metadata_path="configs/metadata.yaml",         # Data normalization info
    batch_size=16,                                 # Images per training batch
    num_workers=8,                                 # Parallel data loading processes
    platform="naip",                               # Image type (NAIP aerial imagery)
)

# Prepare the data loaders
dm.setup()
print("✅ Data module ready!")
print(f"Training batches: {len(dm.train_dataloader())}")
print(f"Validation batches: {len(dm.val_dataloader())}")
Initialize the Model
Now we create our segmentation model - the Clay encoder + segmentation head:
# Initialize the segmentation model
print("Setting up segmentation model...")
model = ChesapeakeSegmentor(
    num_classes=7,                           # 7 land cover classes
    ckpt_path="checkpoints/clay-v1.5.ckpt",  # Pre-trained Clay model
    lr=1e-5,                                 # Learning rate (conservative for fine-tuning)
    wd=0.05,                                 # Weight decay (regularization)
    b1=0.9,                                  # Adam optimizer beta1
    b2=0.95,                                 # Adam optimizer beta2
)
print("✅ Model initialized!")
print("Clay encoder: FROZEN (saves compute)")
print("Segmentation head: TRAINABLE (learns land cover patterns)")
Set Up the Trainer
The Trainer handles the training loop, GPU usage, and checkpointing:
# Import the Trainer
from lightning import Trainer

print("Setting up trainer...")
# Configure the trainer for our training session
trainer = Trainer(
    accelerator="auto",                      # Automatically detect GPU/CPU
    devices=1,                               # Use 1 device (GPU if available)
    num_nodes=1,                             # Single-machine training
    precision="bf16-mixed",                  # Mixed precision (faster training)
    log_every_n_steps=5,                     # Log metrics every 5 training steps
    max_epochs=1,                            # Train for 1 epoch (demo purposes)
    accumulate_grad_batches=1,               # No gradient accumulation
    default_root_dir="checkpoints/segment",  # Where to save checkpoints
    fast_dev_run=False,                      # Full training (not debugging mode)
    num_sanity_val_steps=0,                  # Skip the validation sanity check
)
print("✅ Trainer configured!")
print(f"Will train for {trainer.max_epochs} epoch(s)")
print(f"Checkpoints saved to: {trainer.default_root_dir}")
Start Training!
Everything is set up - let's train the model! This will:
- Load batches of images and labels
- Forward pass: Run images through the Clay encoder + segmentation head
- Compute loss: Compare predictions to ground truth labels
- Backward pass: Calculate gradients for the segmentation head
- Update weights: Improve the segmentation head parameters
- Validate: Test performance on validation data
- Save checkpoint: Store the trained model
(A conceptual sketch of a single training step follows after this list.)
Expected time: ~5-10 minutes for 1 epoch (depending on hardware)
What to watch for:
- Training loss should decrease over time
- Validation metrics should improve
- No out-of-memory errors
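All of these steps are handled by trainer.fit() below. Purely for intuition, here is what a single training step looks like in plain PyTorch; the model call, batch format, and loss shown here are simplified placeholders, not the actual ChesapeakeSegmentor API:
# Conceptual sketch of one training step (Lightning runs the real loop for you).
import torch.nn.functional as F

def toy_training_step(model, batch, optimizer):
    optimizer.zero_grad()                           # clear old gradients
    logits = model(batch)                           # forward pass: frozen encoder + head
    loss = F.cross_entropy(logits, batch["label"])  # compare predictions to ground truth
    loss.backward()                                 # gradients only reach the trainable head
    optimizer.step()                                # update segmentation-head weights
    return loss.item()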
# Start the training process!
print("Starting training...")
print("Watch the progress below:")

trainer.fit(model, dm)

print("\nTraining complete!")
print("Check the checkpoints directory for your trained model")
print("Next: let's run inference to see predictions!")
Model Inference and Results Visualization
Welcome to the inference and visualization part of the tutorial. Here you'll see your fine-tuned Clay model in action, making predictions on real satellite imagery.
What You'll Learn
- How to load a trained segmentation model for inference
- Techniques for visualizing model predictions
- How to interpret land cover segmentation results
- Methods for comparing predictions with ground truth
- Best practices for model evaluation
What We'll Do
- Load the trained model from the previous notebook
- Prepare validation data for testing
- Run inference to generate predictions
- Visualize results with color-coded land cover maps
- Compare predictions with ground truth labels
Key Concepts
For GIS Professionals
- Inference: Using your trained model to classify new imagery
- Visualization: Creating interpretable land cover maps from model outputs
- Think of this as automated feature extraction and classification
- Results can be exported as GeoTIFF files for use in GIS software (see the sketch after this list)
For Data Analysts
- Model evaluation: Assessing how well our model performs
- Visual validation: Checking predictions against known ground truth
- Pattern recognition: Understanding what the model learned vs. missed
- Quality assessment: Identifying areas for model improvement
For ML Engineers
- Inference pipeline: Loading checkpoints and running forward passes
- Post-processing: Converting logits to class predictions
- Batch processing: Efficient handling of multiple images
- Model interpretation: Understanding model behavior through visualization
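As mentioned above for GIS users, predictions can be written back out as GeoTIFFs. Here is a rough sketch; the function and paths are illustrative, `pred` would be a 2D array of class indices produced by the inference steps further below, and the georeferencing is copied from the source chip:
# Illustrative sketch: save a predicted class map as a GeoTIFF for use in QGIS/ArcGIS.
import numpy as np
import rasterio

def save_prediction_as_geotiff(pred, source_chip_path, out_path):
    """Write a 2D array of class indices using the source chip's georeferencing."""
    with rasterio.open(source_chip_path) as src:
        profile = src.profile.copy()
    profile.update(count=1, dtype="uint8")
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(pred.astype(np.uint8), 1)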
Let's make sure we have access to our trained model:
# Verify we have the necessary files
print("Current directory contents:")
!ls -la

print("\nChecking for trained model...")
!ls -la checkpoints/segment/lightning_logs/*/checkpoints/ 2>/dev/null || echo "❌ No trained model found - please run the training notebook first!"
Import Required Libraries
Let's import all the tools we need for inference and visualization:
# Core Python libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# PyTorch for deep learning
import torch
import torch.nn.functional as F

# Additional utilities
from einops import rearrange

print("✅ All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {'Yes' if torch.cuda.is_available() else 'No (using CPU)'}")
✅ All libraries imported successfully!
PyTorch version: 2.7.0
GPU available: No (using CPU)
Configuration and File Paths
Let's define all the paths and parameters we'll need. These should match what you used in the training notebook:
# File paths and configuration
# Note: the exact epoch/step filename may differ on your run - check the directory listed above
CHESAPEAKE_CHECKPOINT_PATH = "checkpoints/segment/lightning_logs/version_0/checkpoints/epoch=0-step=63.ckpt"
CLAY_CHECKPOINT_PATH = "checkpoints/clay-v1.5.ckpt"
METADATA_PATH = "configs/metadata.yaml"

# Data directories
TRAIN_CHIP_DIR = "data/cvpr/ny/train/chips/"
TRAIN_LABEL_DIR = "data/cvpr/ny/train/labels/"
VAL_CHIP_DIR = "data/cvpr/ny/val/chips/"
VAL_LABEL_DIR = "data/cvpr/ny/val/labels/"

# Data loading parameters
BATCH_SIZE = 32    # Process 32 images at once (larger batch for inference)
NUM_WORKERS = 1    # Single worker to avoid issues in Colab
PLATFORM = "naip"  # NAIP aerial imagery platform

print("Configuration loaded:")
print(f"  Model checkpoint: {CHESAPEAKE_CHECKPOINT_PATH}")
print(f"  Clay model: {CLAY_CHECKPOINT_PATH}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Platform: {PLATFORM}")
Helper Functions
Let's define functions to handle model loading, data preparation, inference, and visualization. Breaking these into functions makes the code more organized and reusable:
def get_model(chesapeake_checkpoint_path, clay_checkpoint_path, metadata_path):
    """
    Load the trained segmentation model from a checkpoint.

    Args:
        chesapeake_checkpoint_path: Path to our trained model
        clay_checkpoint_path: Path to the Clay foundation model
        metadata_path: Path to data normalization metadata

    Returns:
        model: Loaded model in evaluation mode
    """
    print("Loading trained model...")
    model = ChesapeakeSegmentor.load_from_checkpoint(
        checkpoint_path=chesapeake_checkpoint_path,
        metadata_path=metadata_path,
        ckpt_path=clay_checkpoint_path,
    )
    # Set to evaluation mode (disables dropout, switches batch norm to inference mode, etc.)
    model.eval()
    print("✅ Model loaded successfully!")
    return model
def get_data(train_chip_dir, train_label_dir, val_chip_dir, val_label_dir,
             metadata_path, batch_size, num_workers, platform):
    """
    Set up data loading for inference.

    Args:
        Various paths and parameters for data loading

    Returns:
        batch: A batch of validation data
        metadata: Data normalization and class information
    """
    print("Setting up data loader...")
    # Create the data module (same as training, but we only need validation data)
    dm = ChesapeakeDataModule(
        train_chip_dir=train_chip_dir,
        train_label_dir=train_label_dir,
        val_chip_dir=val_chip_dir,
        val_label_dir=val_label_dir,
        metadata_path=metadata_path,
        batch_size=batch_size,
        num_workers=num_workers,
        platform=platform,
    )
    # Set up the data module
    dm.setup(stage="fit")

    # Get one batch of validation data for visualization
    val_dl = iter(dm.val_dataloader())
    batch = next(val_dl)

    print(f"✅ Data loaded - batch contains {batch['pixels'].shape[0]} images")
    print(f"Image shape: {list(batch['pixels'].shape[1:])}")
    return batch, dm.metadata
def run_prediction(model, batch):
    """
    Run inference on a batch of images.

    Args:
        model: Trained segmentation model
        batch: Batch of input images

    Returns:
        outputs: Model predictions (per-class scores for each pixel)
    """
    print("Running inference...")
    # Disable gradient computation for faster inference
    with torch.no_grad():
        # Forward pass through the model
        outputs = model(batch)

    # Upsample predictions to match the original image size (256x256);
    # the model outputs smaller feature maps that need to be upsampled
    outputs = F.interpolate(
        outputs,
        size=(256, 256),      # Target size
        mode="bilinear",      # Smooth interpolation
        align_corners=False,  # PyTorch default
    )
    print(f"✅ Inference complete - predictions shape: {list(outputs.shape)}")
    return outputs
def denormalize_images(normalized_images, means, stds):
    """
    Convert normalized images back to a viewable format.

    During training, images are normalized (roughly zero mean, unit variance)
    for better model performance. For visualization, we reverse that normalization.

    Args:
        normalized_images: Normalized image arrays
        means: Mean values used for normalization
        stds: Standard deviation values used for normalization

    Returns:
        denormalized_images: Images in the 0-255 range for display
    """
    means = np.array(means).reshape(1, -1, 1, 1)
    stds = np.array(stds).reshape(1, -1, 1, 1)
    # Reverse the normalization: multiply by std, then add the mean
    denormalized_images = normalized_images * stds + means
    # Convert to uint8 (0-255) for display
    return denormalized_images.astype(np.uint8)
def post_process(batch, outputs, metadata):
    """
    Convert model outputs and inputs into a visualization-ready format.

    Args:
        batch: Original batch of data
        outputs: Model prediction scores
        metadata: Data normalization info

    Returns:
        images: RGB images ready for display
        labels: Ground truth segmentation maps
        preds: Predicted segmentation maps
    """
    print("Post-processing results...")
    # Convert per-class scores to class predictions:
    # argmax selects the class with the highest score for each pixel
    preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()

    # Extract ground truth labels
    labels = batch["label"].detach().cpu().numpy()

    # Extract normalized pixel values
    pixels = batch["pixels"].detach().cpu().numpy()

    # Get normalization parameters for this platform (NAIP)
    means = list(metadata["naip"].bands.mean.values())
    stds = list(metadata["naip"].bands.std.values())

    # Denormalize images for display
    norm_pixels = denormalize_images(pixels, means, stds)

    # Rearrange from (batch, channels, height, width) to (batch, height, width, channels),
    # the format matplotlib expects for RGB images
    images = rearrange(norm_pixels[:, :3, :, :], "b c h w -> b h w c")

    print("✅ Post-processing complete")
    print(f"Processed {len(images)} images")
    return images, labels, preds
def plot_predictions(images, labels, preds):
    """
    Create a comprehensive visualization of the results.

    Shows original images, ground truth labels, and model predictions
    in an easy-to-compare grid format.

    Args:
        images: RGB aerial images
        labels: Ground truth segmentation maps
        preds: Model predicted segmentation maps
    """
    print("Creating visualization...")
    # Define colors for each land cover class.
    # These colors are chosen to be intuitive and visually distinct.
    colors = [
        (0/255, 0/255, 255/255, 1),      # Deep Blue for water
        (34/255, 139/255, 34/255, 1),    # Forest Green for tree canopy
        (154/255, 205/255, 50/255, 1),   # Yellow Green for low vegetation
        (210/255, 180/255, 140/255, 1),  # Tan for barren land
        (169/255, 169/255, 169/255, 1),  # Dark Gray for impervious (other)
        (105/255, 105/255, 105/255, 1),  # Dim Gray for impervious (roads)
        (255/255, 255/255, 255/255, 1),  # White for no data
    ]
    cmap = ListedColormap(colors)

    # Create a large figure to show all comparisons
    fig, axes = plt.subplots(12, 8, figsize=(16, 24))
    fig.suptitle('Land Cover Segmentation Results', fontsize=16, fontweight='bold')

    # Plot in alternating row groups: images, ground truth, predictions
    plot_data(axes, images, row_offset=0, title="Original Image")
    plot_data(axes, labels, row_offset=1, title="Ground Truth", cmap=cmap, vmin=0, vmax=6)
    plot_data(axes, preds, row_offset=2, title="Model Prediction", cmap=cmap, vmin=0, vmax=6)

    # Add a legend explaining the color scheme
    add_legend(fig, cmap)

    plt.tight_layout()
    plt.show()
    print("✅ Visualization complete!")
def plot_data(axes, data, row_offset, title=None, cmap=None, vmin=None, vmax=None):
    """Helper function to plot one kind of data (images, labels, or predictions) into the grid."""
    for i, item in enumerate(data):
        if i >= 24:  # Only show the first 24 images (3 groups of 8)
            break
        row = row_offset + (i // 8) * 3
        col = i % 8
        axes[row, col].imshow(item, cmap=cmap, vmin=vmin, vmax=vmax)
        axes[row, col].axis("off")
        # Label the first column of each row group
        # (set_title still renders when the axis is turned off, unlike set_ylabel)
        if title and col == 0:
            axes[row, col].set_title(title, fontsize=10, fontweight='bold', loc='left')

def add_legend(fig, cmap):
    """Add a color legend explaining the land cover classes."""
    class_names = [
        "Water",
        "Tree Canopy",
        "Low Vegetation",
        "Barren Land",
        "Impervious (Other)",
        "Impervious (Roads)",
        "No Data",
    ]
    # Create legend patches
    import matplotlib.patches as mpatches
    patches = [mpatches.Patch(color=cmap.colors[i], label=class_names[i])
               for i in range(len(class_names))]
    # Add the legend to the figure
    fig.legend(handles=patches, loc='center', bbox_to_anchor=(0.5, 0.02),
               ncol=4, fontsize=10)
Run the Complete Inference Pipeline
Now let's put it all together! We'll load the model, prepare the data, run inference, and visualize the results:
# Load the trained model
model = get_model(CHESAPEAKE_CHECKPOINT_PATH, CLAY_CHECKPOINT_PATH, METADATA_PATH)

# Get validation data for testing
batch, metadata = get_data(
    TRAIN_CHIP_DIR,
    TRAIN_LABEL_DIR,
    VAL_CHIP_DIR,
    VAL_LABEL_DIR,
    METADATA_PATH,
    BATCH_SIZE,
    NUM_WORKERS,
    PLATFORM,
)

# Move the data to the same device as the model (GPU if available)
device = next(model.parameters()).device
batch = {k: v.to(device) for k, v in batch.items()}
print(f"Using device: {device}")

# Run inference on the batch
outputs = run_prediction(model, batch)

# Post-process results for visualization
images, labels, preds = post_process(batch, outputs, metadata)

# Create the final visualization
plot_predictions(images, labels, preds)
print("\n๐ Inference and visualization complete!")
print("\n๐ What to Look For:")
print(" โข How well does the model identify water bodies?")
print(" โข Are forest areas correctly classified?")
print(" โข Does the model distinguish between different types of impervious surfaces?")
print(" โข Where does the model struggle or make mistakes?")
print("\n๐ก Next Steps:")
print(" โข Try running on more batches to see consistency")
print(" โข Consider additional training epochs for better performance")
print(" โข Experiment with different learning rates or data augmentation")