Processing earth observation data for semantic segmentation with deep learning#

A guide to processing raster data and labels into an ML-ready format for training a deep-learning-based semantic segmentation model.

Setup Notebook#

# install required libraries
!pip install -q rasterio==1.3.8
!pip install -q geopandas==0.13.2
!pip install -q radiant_mlhub # for dataset access, see: https://mlhub.earth/
# import required libraries
import os, glob, tarfile, json
from itertools import product
from pathlib import Path

import numpy as np
from fractions import Fraction  
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
mpl.rcParams['figure.figsize'] = (12,12)

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import matplotlib.image as mpimg
import pandas as pd
from PIL import Image

import rasterio
from rasterio import features, windows

import geopandas as gpd
import cv2

from tqdm.notebook import tqdm

from radiant_mlhub import Dataset, Collection
# configure Radiant Earth MLHub access
!mlhub configure
if 'google.colab' in str(get_ipython()):
    # import and mount google drive (only available on Colab)
    from google.colab import drive
    drive.mount('/content/gdrive')
    processed_outputs_dir = '/content/gdrive/My Drive/tf-eo-devseed-2-processed-outputs/'
    user_outputs_dir = '/content/gdrive/My Drive/tf-eo-devseed-2-user_outputs_dir'
    if not os.path.exists(user_outputs_dir):
        os.makedirs(user_outputs_dir)
    print('Running on Colab')
else:
    processed_outputs_dir = os.path.abspath("./data/tf-eo-devseed-2-processed-outputs")
    user_outputs_dir = os.path.abspath('./tf-eo-devseed-2-user_outputs_dir')
    if not os.path.exists(user_outputs_dir):
        os.makedirs(user_outputs_dir)
    if not os.path.exists(processed_outputs_dir):
        os.makedirs(processed_outputs_dir)
    print(f'Not running on Colab, data needs to be downloaded locally at {os.path.abspath(processed_outputs_dir)}')
%cd $user_outputs_dir

Enabling GPU#

Tip

This notebook can utilize a GPU and works better if you use one. We can check whether a GPU is available with the following code.

If it’s not using a GPU you can change your session/notebook to use a GPU. See Instructions.

%tensorflow_version 2.x
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Access the dataset#

We will use a crop type classification dataset from Radiant Earth MLHub: https://mlhub.earth/data/dlr_fusion_competition_germany

This dataset contains radar data from Sentinel-1, 3-meter resolution optical imagery from Planet Labs, and 10-20 meter resolution optical imagery from Sentinel-2. There is one train set and one test set with corresponding agriculture field polygon labels.

ds = Dataset.fetch('dlr_fusion_competition_germany')
for c in ds.collections:
    print(c.id)
dlr_fusion_competition_germany_train_source_planet
dlr_fusion_competition_germany_train_source_planet_5day
dlr_fusion_competition_germany_train_source_sentinel_1
dlr_fusion_competition_germany_train_source_sentinel_2
dlr_fusion_competition_germany_test_source_planet
dlr_fusion_competition_germany_test_source_planet_5day
dlr_fusion_competition_germany_test_source_sentinel_1
dlr_fusion_competition_germany_test_source_sentinel_2
dlr_fusion_competition_germany_train_labels
dlr_fusion_competition_germany_test_labels
collections = [
    'dlr_fusion_competition_germany_train_source_planet_5day',
    'dlr_fusion_competition_germany_test_source_planet_5day',
    'dlr_fusion_competition_germany_train_labels',
    'dlr_fusion_competition_germany_test_labels'
]

def download(collection_id):
    print(f'Downloading {collection_id}...')
    collection = Collection.fetch(collection_id)
    path = collection.download('.')
    tar = tarfile.open(path, "r:gz")
    tar.extractall()
    tar.close()
    os.remove(path)
    
def resolve_path(base, path):
    return Path(os.path.join(base, path)).resolve()
    
def load_df(collection_id):
    collection = json.load(open(f'{collection_id}/collection.json', 'r'))
    rows = []
    item_links = []
    for link in collection['links']:
        if link['rel'] != 'item':
            continue
        item_links.append(link['href'])
    for item_link in item_links:
        item_path = f'{collection_id}/{item_link}'
        current_path = os.path.dirname(item_path)
        item = json.load(open(item_path, 'r'))
        tile_id = item['id'].split('_')[-1]
        for asset_key, asset in item['assets'].items():
            rows.append([
                tile_id,
                None,
                None,
                asset_key,
                str(resolve_path(current_path, asset['href']))
            ])
            
        for link in item['links']:
            if link['rel'] != 'source':
                continue
            link_path = resolve_path(current_path, link['href'])
            source_path = os.path.dirname(link_path)
            try:
                source_item = json.load(open(link_path, 'r'))
            except FileNotFoundError:
                continue
            datetime = source_item['properties']['datetime']
            satellite_platform = source_item['collection'].split('_')[-1]
            for asset_key, asset in source_item['assets'].items():
                rows.append([
                    tile_id,
                    datetime,
                    satellite_platform,
                    asset_key,
                    str(resolve_path(source_path, asset['href']))
                ])
    return pd.DataFrame(rows, columns=['tile_id', 'datetime', 'satellite_platform', 'asset', 'file_path'])

for c in collections:
    download(c)

train_df = load_df('dlr_fusion_competition_germany_train_labels')
test_df = load_df('dlr_fusion_competition_germany_test_labels')

Check out the labels#

We’ll inspect the class labels that are stored in geojson by loading it as a GeoDataFrame. Class names and identifiers extracted from the documentation provided here: https://radiantearth.blob.core.windows.net/mlhub/esa-food-security-challenge/Crops_GT_Brandenburg_Doc.pdf

# Read the classes
pd.set_option('display.max_colwidth', None)

data = {'class_names':  ['Background', 'Wheat', 'Rye', 'Barley', 'Oats', 'Corn', 'Oil Seeds', 'Root Crops', 'Meadows', 'Forage Crops'],
        'class_ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        }

classes = pd.DataFrame(data)
print(classes) 

classes.to_csv('lulc_classes.csv')

# Let's check the class labels
labels_geo = gpd.read_file('dlr_fusion_competition_germany_train_labels/dlr_fusion_competition_germany_train_labels_33N_18E_242N/labels.geojson')
classes = labels_geo.crop_id.unique()
classes.sort()
print("classes in labels geojson: ", classes)
    class_names  class_ids
0    Background          0
1         Wheat          1
2           Rye          2
3        Barley          3
4          Oats          4
5          Corn          5
6     Oil Seeds          6
7    Root Crops          7
8       Meadows          8
9  Forage Crops          9
classes in labels geojson:  [1 2 3 4 5 6 7 8 9]

Raster processing#

IMPORTANT

This section contains helper functions for processing the raw raster composites.

We’ll normalize each PlanetScope image with its mean and standard deviation after we read it, and rescale all values to 8-bit integers. Rescaling to 8-bit integers keeps the training data as small as possible so that we can fit larger batch sizes into GPU memory. Normalizing each image helps the model train with more numerical stability and brings the data into a distribution that reflects the pretraining data of the pretrained model.
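As a minimal sketch of the per-image normalization idea (applied to an already-loaded array; the actual 8-bit rescaling in this notebook is done with cv2.normalize in raster_read() below, and this helper is hypothetical, for illustration only):

import numpy as np

def standardize_image(arr, eps=1e-8):
    # arr: array of shape (bands, height, width), e.g. from rasterio's read()
    # standardize each band to zero mean and unit variance
    mean = arr.mean(axis=(1, 2), keepdims=True)
    std = arr.std(axis=(1, 2), keepdims=True)
    return (arr - mean) / (std + eps)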

def raster_read(raster_dir):
    print(raster_dir)

    # Read band metadata and arrays
    # metadata
    rgbn = rasterio.open(os.path.join(raster_dir,'sr.tif')) #rgbn
    rgbn_src = rgbn
    target_crs = rgbn_src.crs
    print("rgbn: ", rgbn)

    # arrays
    # Read and re-scale the original 16 bit image to 8 bit.
    scale = True
    if scale:
      rgbn_norm = cv2.normalize(rgbn.read(), None, 0, 255, cv2.NORM_MINMAX)
      rgbn_norm_out=rasterio.open(os.path.join(raster_dir,'sr_byte_scaled.tif'), 'w', driver='Gtiff',
                                  width=rgbn_src.width, height=rgbn_src.height,
                                  count=4,
                                  crs=rgbn_src.crs,
                                  transform=rgbn_src.transform,
                                  dtype='uint8')

      rgbn_norm_out.write(rgbn_norm)
      rgbn_norm_out.close()
      rgbn = rasterio.open(os.path.join(raster_dir,'sr_byte_scaled.tif')) #rgbn
    else:
      rgbn = rasterio.open(os.path.join(raster_dir,'sr_byte_scaled.tif')) #rgbn
    print("Scaled to 8bit.")
    return raster_dir, rgbn, rgbn_src, target_crs

Next we will calculate relevant spectral indices. Using spectral indices when fine-tuning a pre-trained model on 3-band imagery is a quick way to make use of information from multiple multi-spectral bands while still getting the benefits of pre-training.

The following are the indices we will compute (formulas shown below the list):

WDRVI: Wide Dynamic Range Vegetation Index
NDVI: Normalized Difference Vegetation Index
SI: Shadow Index
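
For reference, the formulas as implemented in the helper function below, using the weighting coefficient a = 0.15 for WDRVI (note that the SI formula assumes band values scaled to the 0-1 range):

WDRVI = (a * NIR - Red) / (a * NIR + Red)
NDVI = (NIR - Red) / (NIR + Red)
SI = ((1 - Red) * (1 - Green) * (1 - Blue))^(1/3)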

# calculate spectral indices and concatenate them into one 3 channel image
def indexnormstack(red, green, blue, nir):

    def WDRVIcalc(nir, red): 
        a = 0.15
        wdrvi =  (a * nir-red)/(a * nir+red)
        return wdrvi

    def NPCRIcalc(red,blue):
        npcri =  (red-blue)/(red+blue)
        return npcri

    def NDVIcalc(nir, red): 
        ndvi = (nir - red) / (nir  + red + 1e-5) 
        return ndvi


    def SIcalc(red, green, blue):
        expo = Fraction('1/3') 
        si = (((1-red)*(1-green)*(1-blue))**expo)
        return si

    def norm(arr):
        scaler = MinMaxScaler(feature_range=(0, 255))
        scaler = scaler.fit(arr)
        arr_norm = scaler.transform(arr)
        # Checking reconstruction
        #arr_norm = scaler.inverse_transform(arr_norm)

        return arr_norm

    wdrvi = WDRVIcalc(nir,red) 

    #npcri = NPCRIcalc(red,blue)

    ndi = NDVIcalc(nir, red)

    si = SIcalc(red,green,blue) 

    print("wdrvi: ", wdrvi.min(), wdrvi.max(), "ndi: ", ndi.min(), ndi.max(), "si: ", si.min(), si.max())

    wdrvi = norm(wdrvi)
    ndi = norm(ndi)
    si = norm(si)

    index_stack = np.dstack((wdrvi, ndi, si))

    return index_stack
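
For example, assuming the standard 4-band Planet surface reflectance band order (band 1 = blue, 2 = green, 3 = red, 4 = near-infrared), the helper is called like this (mirroring the call used later in the processing loop):

# rgbn is an open rasterio dataset for a 4-band Planet SR image
index_stack = indexnormstack(rgbn.read(3),  # red
                             rgbn.read(2),  # green
                             rgbn.read(1),  # blue
                             rgbn.read(4))  # near-infrared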

We can also stack specific bands of interest, and train the model with those.

def bandstack(red, green, blue, nir):

    stack = np.dstack((red, green, blue))

    return stack

Below is an (optional) color correction for the optical composite. We typically apply data augmentations during training that change brightness values and create other synthetic variations within the TensorFlow data pipeline. But it can be helpful to save image inputs with color corrections applied ahead of time, so that it is easier to visualize and compare imagery to labels and predictions during training.

# function to increase the brightness in an image
def change_brightness(img, value=30):
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    v = cv2.add(v,value)
    v[v > 255] = 255
    v[v < 0] = 0
    final_hsv = cv2.merge((h, s, v))
    img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
    return img
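
For comparison, here is a minimal sketch of how a random brightness augmentation might instead be applied on the fly inside a tf.data pipeline (the function and dataset names are hypothetical, for illustration only):

import tensorflow as tf

def augment_brightness(image, label, max_delta=0.2):
    # randomly shift the brightness of the image tensor; leave the label mask untouched
    image = tf.image.random_brightness(image, max_delta=max_delta)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

# example usage, assuming dataset is a tf.data.Dataset of (image, label) pairs scaled to [0, 1]:
# dataset = dataset.map(augment_brightness, num_parallel_calls=tf.data.AUTOTUNE)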

The next step applies if you are rasterizing the labels from a vector file (e.g. GeoJSON or Shapefile).

Below we will rasterize our vector labels. This is a common step in image segmentation, since we need to pass a pixel image representing our labels to Tensorflow for training in order to compute the loss function.

We’ll read the label vector file into a GeoDataFrame, check for invalid geometries and reproject it to the local CRS. Then we rasterize the labeled polygons using the metadata from their corresponding images.

In this function, geo_1 is used when two vector files supply the labels. The second file takes precedence over the first because it is burned in afterwards and overwrites wherever geometries intersect. This branch of the code was originally written for a different dataset that had conflicting labels for the same location and date, where we chose to keep one label for demonstration purposes.

def label(geos, labels_src):
    geo_0 = gpd.read_file(geos[0])
    # check for and remove invalid geometries
    geo_0 = geo_0.loc[geo_0.is_valid]
    # reproject training data into local coordinate reference system
    geo_0 = geo_0.to_crs(target_crs)
    #convert the class identifier column to type integer
    geo_0['landcover_int']  = geo_0.crop_id.astype(int)
    # pair the geometries and their integer class values
    shapes_0 = ((geom,value) for geom, value in zip(geo_0.geometry, geo_0.landcover_int)) 
    if len(geos) > 1:
      geo_1 = gpd.read_file(geos[1])
      geo_1 = geo_1.loc[geo_1.is_valid]
      geo_1 = geo_1.to_crs(target_crs)
      geo_1['landcover_int']  = geo_1.crop_id.astype(int)
      shapes_1 = ((geom,value) for geom, value in zip(geo_1.geometry, geo_1.landcover_int)) 
    else:
      print("Only one source of vector labels.") #continue

    # get the metadata (height, width, channels, transform, CRS) to use in constructing the labeled image array
    labels_src_prf = labels_src.profile
    # construct a blank array from the metadata and burn the labels in
    labels = features.rasterize(shapes=shapes_0, out_shape=(labels_src_prf['height'], labels_src_prf['width']), fill=0, all_touched=True, transform=labels_src_prf['transform'], dtype=labels_src_prf['dtype'])
    if len(geos) > 1:
      labels = features.rasterize(shapes=shapes_1, fill=0, all_touched=True, out=labels, transform=labels_src_prf['transform'])
    else:
      print("Only one source of vector labels.") #continue

    print("Values in labeled image: ", np.unique(labels))


    return labels

Below is a single function to write all of the processed rasters to files so we can use them for training later.

It’s often more efficient to save out processed intermediates before model training, rather than to keep all image processing on the fly.

def save_images(raster_dir, rgb_norm, stack, index_stack, labels, rgb_src):

    stack_computed = True # set to False if you skipped the bandstack helper function above

    if stack_computed:
      stack_t = stack.transpose(2,0,1)
    else:
      stack_t = stack

    stack_out=rasterio.open(os.path.join(raster_dir,'stack.tif'), 'w', driver='Gtiff',
                              width=rgb_src.width, height=rgb_src.height,
                              count=3,
                              crs=rgb_src.crs,
                              transform=rgb_src.transform,
                              dtype='uint8')

    stack_out.write(stack_t)

    indices_computed = True # set to False if you skipped the spectral index helper function above
    if indices_computed:
      index_stack_t = index_stack.transpose(2,0,1)
    else:
      index_stack_t = index_stack

    index_stack_out=rasterio.open(os.path.join(raster_dir,'index_stack.tif'), 'w', driver='Gtiff',
                              width=rgb_src.width, height=rgb_src.height,
                              count=3,
                              crs=rgb_src.crs,
                              transform=rgb_src.transform,
                              dtype='uint8')

    index_stack_out.write(index_stack_t)
    #index_stack_out.close()

    labels = labels.astype(np.uint8)
    labels_out=rasterio.open(os.path.join(raster_dir,'labels.tif'), 'w', driver='Gtiff',
                              width=rgb_src.width, height=rgb_src.height,
                              count=1,
                              crs=rgb_src.crs,
                              transform=rgb_src.transform,
                              dtype='uint8')

    labels_out.write(labels, 1)
    #labels_out.close()

    print("written")

    return os.path.join(raster_dir,'stack.tif'), os.path.join(raster_dir,'index_stack.tif'), os.path.join(raster_dir,'labels.tif')

Now let’s divide the optical/index stack and labeled image into 224x224 pixel tiles.

def tile(index_stack, labels, prefix, width, height, raster_dir, output_dir, brighten=False):
    tiles_dir = os.path.join(output_dir,'tiled/')
    img_dir = os.path.join(output_dir,'tiled/stacks_brightened/')
    label_dir = os.path.join(output_dir,'tiled/labels/')
    dirs = [tiles_dir, img_dir, label_dir]
    for d in dirs:
        if not os.path.exists(d):
            os.makedirs(d)

    def get_tiles(ds):
        # get the number of columns and rows (pixels) in the entire input image
        ncols, nrows = ds.meta['width'], ds.meta['height']
        # get the grid from which tiles will be made
        offsets = product(range(0, ncols, width), range(0, nrows, height))
        # get the window of the entire input image
        big_window = windows.Window(col_off=0, row_off=0, width=ncols, height=nrows)
        # tile the big window by mini-windows per grid cell
        for col_off, row_off in offsets:
            window = windows.Window(col_off=col_off, row_off=row_off, width=width, height=height).intersection(big_window)
            transform = windows.transform(window, ds.transform)
            yield window, transform

    tile_width, tile_height = width, height

    def crop(inpath, outpath, c):
        # read input image
        image = rasterio.open(inpath)
        # get the metadata 
        meta = image.meta.copy()
        print("meta: ", meta)
        # set the number of channels to 3 or 1, depending on whether it's the image stack or the labels image
        meta['count'] = int(c)
        # set the tile output file format to PNG (saves spatial metadata unlike JPG)
        meta['driver']='PNG'
        meta['dtype']='uint8'
        # tile the input image by the mini-windows
        i = 0
        for window, transform in get_tiles(image):
            meta['transform'] = transform
            meta['width'], meta['height'] = window.width, window.height
            outfile = os.path.join(outpath,"tile_%s_%s.png" % (prefix, str(i)))
            with rasterio.open(outfile, 'w', **meta) as outds:
                if brighten:
                  imw = image.read(window=window)
                  imw = imw.transpose(1,2,0)
                  imwb = change_brightness(imw, value=50)
                  imwb = imwb.transpose(2,0,1)
                  outds.write(imwb)
                else:
                  outds.write(image.read(window=window))
            i = i+1

    def process_tiles(index_flag):
        # tile the input images, when index_flag == True, we are tiling the spectral index image, 
        # when False we are tiling the labels image
        if index_flag==True:
            inpath = os.path.join(raster_dir,'stack.tif')
            outpath=img_dir
            crop(inpath, outpath, 3)
        else:
            inpath = os.path.join(raster_dir,'labels.tif')
            outpath=label_dir
            crop(inpath, outpath, 1)

    process_tiles(index_flag=True) # tile stack
    process_tiles(index_flag=False) # tile labels
    return tiles_dir, img_dir, label_dir

Run the image processing workflow.

Long running code ⚠ Google Drive based workflows can incur timeouts and latency issues. If this happens, try running the affected cell again. A VM with a mounted SSD would go a long way toward avoiding the latency incurred by I/O against data hosted in Google Drive.
train_images_dir = 'dlr_fusion_competition_germany_train_source_planet_5day'
%cd $train_images_dir 
/content/gdrive/My Drive/tf-eo-devseed/dlr_fusion_competition_germany_train_source_planet_5day
train_images_dirs = [f.path for f in os.scandir('./') if f.is_dir()]
train_images_dirs = [x.replace('./', '') if type(x) is str else x for x in train_images_dirs]
%cd $processed_outputs_dir 
/content/gdrive/My Drive/tf-eo-devseed

If you want to write the files out to your personal drive, set write_out = True. We recommend trying that in your free time, however, because it takes about 2 hours or more for all composites when using Google Colab with Google Drive for storage.

write_out = False
if write_out:
  raster_out_dir = os.path.join(user_outputs_dir,'rasters/')
  if not os.path.exists(raster_out_dir):
    os.makedirs(raster_out_dir)
  for train_image_dir in train_images_dirs: #[0:1]:
    # read the rasters and scale to 8bit
    print("reading and scaling rasters...")
    raster_dir, rgbn, rgbn_src, target_crs = raster_read(os.path.join(train_images_dir,train_image_dir))

    # Calculate indices and combine the indices into one single 3 channel image
    print("calculating spectral indices...")
    index_stack = indexnormstack(rgbn.read(3), rgbn.read(2), rgbn.read(1), rgbn.read(4))

    # Stack channels of interest (RGB) into one single 3 channel image
    print("Stacking channels of interest...")
    stack = bandstack(rgbn.read(3), rgbn.read(2), rgbn.read(1), rgbn.read(4))

    # Color correct the RGB image
    print("Color correcting a RGB image...")
    cc_stack = change_brightness(stack)

    # Rasterize labels
    labels = label([os.path.join(processed_outputs_dir,'dlr_fusion_competition_germany_train_labels/dlr_fusion_competition_germany_train_labels_33N_18E_242N/labels.geojson')], rgbn_src)

    # Save index stack and labels to geotiff
    print("writing scaled rasters and labels to file...")
    stack_file, index_stack_file, labels_file = save_images(raster_dir, rgbn, cc_stack, index_stack, labels, rgbn_src)

    # Tile images into 224x224
    print("tiling the indices and labels...")
    tiles_dir, img_dir, label_dir = tile(stack, labels, str(train_image_dir), 224, 224, raster_dir, raster_out_dir, brighten=False)
else:
  print("Not writing to file; using preprocessed dataset in shared drive.")

Read the data into memory#

Getting set up with the data#

Important

The tiled imagery is available at the following path that is accessible with the google.colab drive module: '/content/gdrive/My Drive/tf-eo-devseed-processed-outputs/'

We’ll be working with the following folders and files in the tf-eo-devseed-processed-outputs/ folder:

tf-eo-devseed-processed-outputs/
├── stacks/
├── stacks_brightened/
├── indices/
├── labels/
├── background_list_train.txt
├── train_list_clean.txt
└── lulc_classes.csv
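
Before moving on, a quick sanity check (assuming the directory layout above) to confirm the preprocessed data is reachable:

expected = ['stacks', 'stacks_brightened', 'indices', 'labels',
            'background_list_train.txt', 'train_list_clean.txt', 'lulc_classes.csv']
for name in expected:
    path = os.path.join(processed_outputs_dir, name)
    print(f"{'OK      ' if os.path.exists(path) else 'MISSING '}{path}")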

Get lists of image and label tile pairs for training and testing.

def get_train_test_lists(imdir, lbldir):
  imgs = glob.glob(os.path.join(imdir,"*.png"))
  #print(imgs[0:1])
  dset_list = []
  for img in imgs:
    filename_split = os.path.splitext(img) 
    filename_zero, fileext = filename_split 
    basename = os.path.basename(filename_zero) 
    dset_list.append(basename)

  x_filenames = []
  y_filenames = []
  for img_id in dset_list:
    x_filenames.append(os.path.join(imdir, "{}.png".format(img_id)))
    y_filenames.append(os.path.join(lbldir, "{}.png".format(img_id)))

  print("number of images: ", len(dset_list))
  return dset_list, x_filenames, y_filenames

train_list, x_train_filenames, y_train_filenames = get_train_test_lists(img_dir, label_dir)

When training with satellite imagery, our detection targets of interest can be very sparse, which results in a lot of images that don’t have any classes of interest. In computer vision, we call these images background images. It’s good to know the count and proportion of background to non-background images, so that you can control for class imbalance. It’s even better to understand the geographic spread of your background and non-background images and control for it with a sampling strategy.

Here, we check the proportion of background tiles. This takes a while, so after running it once you can skip this step (set skip = True) and load the saved results instead.

skip = False

if not skip:
  background_list_train = []
  for i in train_list: 
      # read in each labeled image
      # print(os.path.join(label_dir,"{}.png".format(i))) 
      img = np.array(Image.open(os.path.join(label_dir,"{}.png".format(i))))  
      # check if no values in image are greater than zero (background value)
      if img.max()==0:
          background_list_train.append(i)

  print("Number of background images: ", len(background_list_train))

  with open(os.path.join(processed_outputs_dir,'background_list_train.txt'), 'w') as f:
    for item in background_list_train:
        f.write("%s\n" % item)

else:
  background_list_train = [line.strip() for line in open("background_list_train.txt", 'r')]
  print("Number of background images: ", len(background_list_train))

We will keep only 10% of the background tiles. Too many background tiles can cause a form of class imbalance, since the background class, which can contain many different phenomena and false positives, becomes heavily overrepresented.

background_removal = len(background_list_train) * 0.9
train_list_clean = [y for y in train_list if y not in background_list_train[0:int(background_removal)]]

x_train_filenames = []
y_train_filenames = []

for img_id in tqdm(train_list_clean):
  x_train_filenames.append(os.path.join(img_dir, "{}.png".format(img_id)))
  y_train_filenames.append(os.path.join(label_dir, "{}.png".format(img_id)))

print("Number of background tiles: ", background_removal)
print("Remaining number of tiles after 90% background removal: ", len(train_list_clean))

Now that we have our set of files we want to use for developing our model, we need to split them into three sets:

  • the training set for the model to learn from

  • the validation set that allows us to evaluate models and make decisions to change models

  • and the test set that we will use to communicate the results of the best performing model (as determined by the validation set)

We will split the index tiles and label tiles into train, validation and test sets of roughly 70%, 20% and 10%, respectively. We do this in two stages: first we hold out 30% of the data, then we split that held-out portion again, sending 33% of it (about 10% of the total) to the test set.

x_train_filenames, x_val_filenames, y_train_filenames, y_val_filenames = train_test_split(x_train_filenames, y_train_filenames, test_size=0.3, random_state=42)
x_val_filenames, x_test_filenames, y_val_filenames, y_test_filenames = train_test_split(x_val_filenames, y_val_filenames, test_size=0.33, random_state=42)

num_train_examples = len(x_train_filenames)
num_val_examples = len(x_val_filenames)
num_test_examples = len(x_test_filenames)

print("Number of training examples: {}".format(num_train_examples))
print("Number of validation examples: {}".format(num_val_examples))
print("Number of test examples: {}".format(num_test_examples))

Warning

Long running cell
The code below checks which class values are present in the train, val, and test partitions. We won’t run this here since it takes over 10 minutes on Colab due to slow I/O.

vals_train = []
vals_val = []
vals_test = []

def get_vals_in_partition(partition_list, x_filenames, y_filenames):
  for x, y in zip(x_filenames, tqdm(y_filenames)):
      try:
        # read the label mask and record which class values it contains
        img = np.array(Image.open(y))
        vals = np.unique(img)
        partition_list.append(vals)
      except Exception:
        # skip any unreadable files
        continue

def flatten(partition_list):
    return [item for sublist in partition_list for item in sublist]

get_vals_in_partition(vals_train, x_train_filenames, y_train_filenames)
get_vals_in_partition(vals_val, x_val_filenames, y_val_filenames)
get_vals_in_partition(vals_test, x_test_filenames, y_test_filenames)
print("Values in training partition: ", set(flatten(vals_train)))
print("Values in validation partition: ", set(flatten(vals_val)))
print("Values in test partition: ", set(flatten(vals_test)))
Values in training partition:  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
Values in validation partition:  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
Values in test partition:  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

Visualize the data#

Warning

Long running cell
The code below loads foreground examples randomly. It’s always a good sanity check to plot your images next to their label masks for many samples, to make sure that labels are being matched to images correctly, that labels are high quality and represent the detection targets of interest, and that the images themselves look correct.

display_num = 3

background_list_train = [line.strip() for line in open("background_list_train.txt", 'r')]

# select only for tiles with foreground labels present
foreground_list_x = []
foreground_list_y = []
for x,y in zip(x_train_filenames, y_train_filenames): 
    try:
      filename_split = os.path.splitext(y) 
      filename_zero, fileext = filename_split 
      basename = os.path.basename(filename_zero) 
      if basename not in background_list_train:
        foreground_list_x.append(x)
        foreground_list_y.append(y)
      else:
        continue
    except:
      continue

num_foreground_examples = len(foreground_list_y)

# randomize the choice of image and label pairs
r_choices = np.random.choice(num_foreground_examples, display_num)

plt.figure(figsize=(10, 15))
for i in range(0, display_num * 2, 2):
  img_num = r_choices[i // 2]
  x_pathname = foreground_list_x[img_num]
  y_pathname = foreground_list_y[img_num]
  
  plt.subplot(display_num, 2, i + 1)
  plt.imshow(mpimg.imread(x_pathname))
  plt.title("Original Image")
  
  example_labels = Image.open(y_pathname)
  label_vals = np.unique(np.array(example_labels))
  
  plt.subplot(display_num, 2, i + 2)
  plt.imshow(example_labels)
  plt.title("Masked Image")  
  
plt.suptitle("Examples of Images and their Masks")
plt.show()