Detection and classification demo workflow#
import os
import json
import glob
from random import shuffle
import cv2
import numpy as np
import torch
from torchmetrics import detection, F1Score
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultPredictor
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import rasterio
import pandas as pd
import detectron2  # torch is already imported above
!nvcc --version
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
torch: 2.1 ; cuda: cu121
detectron2: 0.6
Set up#
data_dir = "/home/ubuntu/data/"
model_dir = "/home/ubuntu/model/"
# Register the partitioned COCO datasets for building detection
register_coco_instances("hp_train", {}, f"{data_dir}coco/instances_default_train.json", f"{data_dir}mapillary_images")
register_coco_instances("hp_val", {}, f"{data_dir}coco/instances_default_val.json", f"{data_dir}mapillary_images")
register_coco_instances("hp_test", {}, f"{data_dir}coco/instances_default_test.json", f"{data_dir}mapillary_images")
# Read the partitioned COCO datasets into dictionaries for later use
with open(f"{data_dir}coco/instances_default_train.json") as f:
    train_json = json.load(f)
with open(f"{data_dir}coco/instances_default_val.json") as f:
    val_json = json.load(f)
with open(f"{data_dir}coco/instances_default_test.json") as f:
    test_json = json.load(f)
Check some images and labels#
for gt in train_json['annotations'][0:10]:
    for im in train_json['images']:
        if gt['image_id'] == im['id']:
            img = cv2.imread(f"{data_dir}mapillary_images/" + im['file_name'])
            fig, ax = plt.subplots()
            # cv2 reads BGR; flip to RGB so the colors display correctly
            ax.imshow(img[:, :, ::-1])
            # COCO bboxes are [x, y, width, height]: draw the ground truth box in red
            rect = patches.Rectangle(
                (gt['bbox'][0], gt['bbox'][1]),
                gt['bbox'][2],
                gt['bbox'][3],
                linewidth=1,
                edgecolor='r',
                facecolor='none'
            )
            ax.add_patch(rect)
plt.show()
Load the model#
os.chdir(model_dir)
detector_ckpt_path = f"{model_dir}output/50kiter/model_final.pth"
# Set up configuration details for inference with the trained Detectron2 model
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/retinanet_R_50_FPN_3x.yaml"))  # alternative: "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
cfg.MODEL.WEIGHTS = detector_ckpt_path
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5
cfg.DATASETS.TEST = ("hp_test",)
predictor = DefaultPredictor(cfg)
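As a quick smoke test that the checkpoint loaded and the predictor runs end to end, we can score a single image and print the raw boxes and scores (this cell is optional and simply uses the first image in the validation split):
# Optional smoke test on one validation image
sample_file = val_json['images'][0]['file_name']
sample_im = cv2.imread(f"{data_dir}mapillary_images/{sample_file}")
sample_out = predictor(sample_im)
print(sample_out["instances"].pred_boxes)
print(sample_out["instances"].scores)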
Check the logs of the trained model#
# Look at training curves in tensorboard:
%load_ext tensorboard
#%reload_ext tensorboard
%tensorboard --logdir output/50kiter/
Run inference#
# Directory for saving test images with predicted boxes drawn on them
inference_image_output = "./output/test_inference_images/"
os.makedirs(inference_image_output, exist_ok=True)
Generate predictions and plot them for a subset of the test dataset#
outputs_test = []
for gt in test_json['annotations'][0:10]:
    for img_info in test_json['images']:
        if gt['image_id'] == img_info['id']:
            im = cv2.imread(f"{data_dir}mapillary_images/" + img_info['file_name'])
            outputs = predictor(im)
            outputs_test.append(outputs)
            # Use `Visualizer` to draw the predictions on the image.
            v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TEST[0]), scale=1.2)
            out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
            # Also draw the ground truth box (convert COCO xywh to xyxy) in a contrasting color
            bb = gt['bbox']
            bbxyxy = (bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3])
            out = v.draw_box(box_coord=bbxyxy, edge_color='dodgerblue')
            # Save the visualization to disk (flip back to BGR for cv2.imwrite)
            cv2.imwrite(f"{inference_image_output}{img_info['file_name'].replace('/', '_')}",
                        out.get_image()[:, :, ::-1])
ls ./output/test_inference_images/
Visualize one of the resulting images#
import matplotlib.image as mpimg
# Specify the path to your image
image_path = './output/test_inference_images/r7tGTdge9ZQXVPAcNbiEnu_left_1966801483681588_left.jpg'
# Load the image using matplotlib.image.imread
img = mpimg.imread(image_path)
# Plot the image using matplotlib.pyplot.imshow
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7fd691df7280>

Run evaluation of the trained model on the test dataset#
We’ll calculate mean average precision (mAP) as our score.
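torchmetrics expects predictions and targets as lists of per-image dictionaries (boxes, scores, and labels for predictions; boxes and labels for targets), with boxes in the layout named by box_format. A minimal illustration of that structure, using made-up coordinates, looks like this; the real lists are built in the next cell:
# Illustrative only: the input structure torchmetrics expects for detection metrics
example_preds = [{
    "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),  # xyxy
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
example_target = [{
    "boxes": torch.tensor([[12.0, 18.0, 115.0, 225.0]]),
    "labels": torch.tensor([1]),
}]
metric = detection.mean_ap.MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
metric.update(preds=example_preds, target=example_target)
print(metric.compute()["map"])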
# Use the test dataset
dataset_name = "hp_test"
# Load metadata and dataset
metadata = MetadataCatalog.get(dataset_name)
dataset_dicts = DatasetCatalog.get(dataset_name)
# Index the dataset records by image id so ground truth can be looked up directly
dataset_by_image_id = {d["image_id"]: d for d in dataset_dicts}
# Lists to store predicted and ground truth bounding boxes
outputs_test_preds = []
outputs_test_gt = []
# Iterate over annotations in the test dataset
for t in test_json['annotations']:
    # Get the image path from the test dataset
    image_path = [i['file_name'] for i in test_json['images'] if i['id'] == t['image_id']]
    # Read the image using OpenCV
    im = cv2.imread(f"{data_dir}mapillary_images/{image_path[0]}")
    # Make predictions using the trained model
    outputs = predictor(im)
    try:
        # Get ground truth annotations for the current image
        annotations = dataset_by_image_id[t['image_id']]["annotations"]
        gt_boxes = [ann["bbox"] for ann in annotations]
        # Check if there are both predicted and ground truth bounding boxes
        if gt_boxes and len(outputs["instances"].pred_boxes):
            # Convert ground truth bounding boxes from xywh to xyxy
            gt_boxes = [[bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3]] for bb in gt_boxes]
            # Move boxes, labels, and scores to CPU tensors in the format torchmetrics expects
            pred_boxes = outputs["instances"].pred_boxes.tensor.cpu()
            gt_boxes = torch.tensor(gt_boxes, dtype=torch.float32)
            labels_preds = torch.ones(len(pred_boxes), dtype=torch.int64)
            labels_gts = torch.ones(len(gt_boxes), dtype=torch.int64)
            # Create dictionaries for ground truth and predicted bounding boxes
            gt = {'image_id': t['image_id'], 'labels': labels_gts, 'boxes': gt_boxes}
            prd = {'image_id': t['image_id'], 'labels': labels_preds, 'boxes': pred_boxes,
                   'scores': outputs["instances"].scores.cpu()}
            # Append the dictionaries to the respective lists
            outputs_test_preds.append(prd)
            outputs_test_gt.append(gt)
    except Exception as e:
        # Handle exceptions, e.g., when ground truth annotations are not available
        print(f"Exception: {e}")
# Compute mAP over the whole test set.
# Note: a single, very lenient IoU threshold of 0.0 is used here, so matching is extremely permissive.
m = detection.mean_ap.MeanAveragePrecision(box_format='xyxy', iou_type='bbox', iou_thresholds=[0.0],
                                            rec_thresholds=None, max_detection_thresholds=None,
                                            class_metrics=False)
m.update(preds=outputs_test_preds, target=outputs_test_gt)
from pprint import pprint
pprint(m.compute()['map'])
tensor(0.8115)
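Because the score above is computed at an IoU threshold of 0.0, it is an optimistic figure. For a stricter, COCO-style number, the same prediction and target lists can be re-scored with the metric's default IoU thresholds (0.50:0.95):
# Re-score the same predictions at the default COCO IoU thresholds (0.50:0.95)
m_strict = detection.mean_ap.MeanAveragePrecision(box_format='xyxy', iou_type='bbox')
m_strict.update(preds=outputs_test_preds, target=outputs_test_gt)
pprint(m_strict.compute()['map'])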
Run detection, clipping and classification#
# Change into the housing-passports-v2 repo so its modules can be imported
os.chdir(f"{model_dir}housing-passports-v2/")
from src.datamodule import HouseDataModule
from src.model import HPClassifier
from detect_clip_classify import evaluate_classification_model, clip_image_around_bbox_buffer, load_classification_model
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, ToPILImage
# load the trained model
classification_ckpt_path = "/home/ubuntu/.aim/63869a95ae444b3d87f78331/checkpoints/epoch:48-step:12838-loss:2.419-f1:4.245.ckpt"
classification_model = load_classification_model(classification_ckpt_path)
classification_model.to(classification_model.device)
# Iterate through the first 100 annotations in the test dataset
cumulative_predictions = []
# The classifier expects 512x512, ImageNet-normalized inputs
transform = Compose([
    Resize((512, 512), antialias=True),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
for gt in test_json['annotations'][0:100]:
    for img_info in test_json['images']:
        if gt['image_id'] == img_info['id']:
            im = cv2.imread(f"{data_dir}mapillary_images/" + img_info['file_name'])
            outputs = predictor(im)
            box_num = 0
            if outputs["instances"].pred_boxes:
                for bb, sc in zip(outputs["instances"].pred_boxes, outputs["instances"].scores):
                    # Clip the detected building out of the frame (with a buffer) and prepare it for the classifier
                    bbox_data = bb.tolist()
                    clipped_img = clip_image_around_bbox_buffer(im, bbox_data)
                    img = ToPILImage()(torch.tensor(clipped_img.transpose(2, 0, 1)).to(torch.float32))
                    img = transform(img).unsqueeze(0)
                    # Mirror the source directory structure in the output folder
                    clip_dir = os.path.join("./output/test_inference_images", os.path.dirname(img_info['file_name']))
                    os.makedirs(clip_dir, exist_ok=True)
                    # Run the classifier on the clipped building
                    categories = evaluate_classification_model(classification_model, img, classification_model.device)
                    # Name the clip: suffix a box index when an image has more than one detection
                    if len(outputs["instances"].pred_boxes) > 1:
                        box_num += 1
                        clip_name = f"{img_info['file_name'][:-4]}_{box_num}.jpg"
                    else:
                        clip_name = img_info['file_name']
                    # Record predictions alongside the ground truth attributes for this annotation
                    cumulative_predictions.append({
                        "image_name": img_info['file_name'], "boxes": bbox_data, "box_scores": sc,
                        "complete": categories["complete"], "condition": categories["condition"],
                        "material": categories["material"], "security": categories["security"],
                        "use": categories["use"], "image_name_clip": clip_name,
                        "ground_truth": gt['attributes'],
                    })
                    cv2.imwrite(os.path.join("./output/test_inference_images/", clip_name), clipped_img)
# make a dataframe of the predictions
df_preds = pd.DataFrame(cumulative_predictions)
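If you want to keep these predictions around as a CSV (the table previewed at the end of this notebook), the dataframe can be written straight to disk; the filename below is just an example, and the tensor scores are cast to plain floats on a copy first:
# Save the predictions table to CSV (illustrative filename)
df_out = df_preds.copy()
df_out["box_scores"] = df_out["box_scores"].apply(float)  # tensors -> plain floats
df_out.to_csv("./output/test_inference_predictions.csv", index=False)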
Visualize some results#
df_preds_ = df_preds.sample(frac=1).reset_index(drop=True)
# Display images and bounding boxes for a random selection of tiles
for _, p in df_preds_.head(10).iterrows():
    if os.path.exists(f"output/test_inference_images/{p.image_name_clip}"):
        img = rasterio.open(f"./output/test_inference_images/{p.image_name_clip}")
        imgr = img.read()
        print("File ", p.image_name_clip, " with ground truth attributes ", p.ground_truth)
        fig, ax = plt.subplots()
        ax.imshow(imgr.transpose(1, 2, 0))
        # Overlay the predicted attributes on the figure
        plt.text(0.5, -0.1, f"[{p.complete}, {p.condition}, {p.material}, {p.security}, {p.use}]",
                 ha='center', fontsize=12, color='white', fontweight='bold',
                 bbox=dict(facecolor='black', alpha=0.5))
        plt.axis('off')
        plt.show()
File rb3QV25WsP0NdylOHZSzxE/left/995792421751752_left.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'poor', 'building_material': 'mix-other-unclear', 'building_security': 'unsecured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

File sVDfTtZgioxYQJrk6clSeI/left/317396310690052_left.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'fair', 'building_material': 'mix-other-unclear', 'building_security': 'unsecured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

File rb3QV25WsP0NdylOHZSzxE/left/330498719422552_left.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'poor', 'building_material': 'plaster', 'building_security': 'unsecured', 'building_use': 'mixed', 'occluded': False, 'rotation': 0.0}

File rb3QV25WsP0NdylOHZSzxE/left/830052638707374_left.jpg with ground truth attributes {'building_completeness': 'incomplete', 'building_condition': 'poor', 'building_material': 'corrugated_metal', 'building_security': 'unsecured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

File rb3QV25WsP0NdylOHZSzxE/left/295172426598155_left_2.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'poor', 'building_material': 'plaster', 'building_security': 'secured', 'building_use': 'mixed', 'occluded': False, 'rotation': 0.0}

File sVDfTtZgioxYQJrk6clSeI/left/1464783370936895_left_2.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'poor', 'building_material': 'wood_polished', 'building_security': 'unsecured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

File sVDfTtZgioxYQJrk6clSeI/left/1354190961832569_left_2.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'poor', 'building_material': 'mix-other-unclear', 'building_security': 'secured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

File r7tGTdge9ZQXVPAcNbiEnu/left/2648202768667523_left_2.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'fair', 'building_material': 'plaster', 'building_security': 'secured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

File rb3QV25WsP0NdylOHZSzxE/left/1498491417358362_left_2.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'fair', 'building_material': 'mix-other-unclear', 'building_security': 'secured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

File sVDfTtZgioxYQJrk6clSeI/left/345064584539357_left.jpg with ground truth attributes {'building_completeness': 'complete', 'building_condition': 'fair', 'building_material': 'mix-other-unclear', 'building_security': 'unsecured', 'building_use': 'residential', 'occluded': False, 'rotation': 0.0}

A peek at what the predictions CSV output looks like#
df_preds.head(10)
 | image_name | boxes | box_scores | complete | condition | material | security | use | image_name_clip |
---|---|---|---|---|---|---|---|---|---
0 | r7tGTdge9ZQXVPAcNbiEnu/left/1746355092495649_l... | [713.345458984375, 253.937255859375, 1022.2023... | tensor(0.9842, device='cuda:0') | complete | fair | plaster | unsecured | residential | r7tGTdge9ZQXVPAcNbiEnu/left/1746355092495649_l... |
1 | r7tGTdge9ZQXVPAcNbiEnu/left/1796054077478724_l... | [482.8565368652344, 392.7989501953125, 1019.79... | tensor(0.9918, device='cuda:0') | complete | good | plaster | unsecured | residential | r7tGTdge9ZQXVPAcNbiEnu/left/1796054077478724_l... |
2 | r7tGTdge9ZQXVPAcNbiEnu/left/1797722660681962_l... | [107.59165954589844, 127.55589294433594, 682.8... | tensor(0.9996, device='cuda:0') | complete | poor | mix-other-unclear | unsecured | residential | r7tGTdge9ZQXVPAcNbiEnu/left/1797722660681962_l... |
3 | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... | [0.5353906154632568, 246.79296875, 391.3089904... | tensor(0.9970, device='cuda:0') | complete | fair | plaster | unsecured | residential | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... |
4 | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... | [378.7174072265625, 15.589765548706055, 827.27... | tensor(0.9717, device='cuda:0') | incomplete | poor | mix-other-unclear | unsecured | mixed | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... |
5 | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... | [0.5353906154632568, 246.79296875, 391.3089904... | tensor(0.9970, device='cuda:0') | complete | poor | plaster | unsecured | commercial | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... |
6 | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... | [378.7174072265625, 15.589765548706055, 827.27... | tensor(0.9717, device='cuda:0') | incomplete | poor | mix-other-unclear | unsecured | mixed | r7tGTdge9ZQXVPAcNbiEnu/left/1939599736440639_l... |
7 | r7tGTdge9ZQXVPAcNbiEnu/left/1966801483681588_l... | [271.408203125, 437.2467041015625, 1021.005737... | tensor(0.9931, device='cuda:0') | incomplete | poor | corrugated_metal | unsecured | residential | r7tGTdge9ZQXVPAcNbiEnu/left/1966801483681588_l... |
8 | r7tGTdge9ZQXVPAcNbiEnu/left/1988227758197966_l... | [76.96910095214844, 54.42814254760742, 836.309... | tensor(0.9789, device='cuda:0') | complete | good | plaster | secured | residential | r7tGTdge9ZQXVPAcNbiEnu/left/1988227758197966_l... |
9 | r7tGTdge9ZQXVPAcNbiEnu/left/2012497842450451_l... | [631.1368408203125, 273.9692077636719, 1022.17... | tensor(0.9783, device='cuda:0') | complete | poor | mix-other-unclear | secured | residential | r7tGTdge9ZQXVPAcNbiEnu/left/2012497842450451_l... |
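Beyond eyeballing individual clips, a rough way to summarize the classification results is to compare each predicted attribute against the ground-truth attributes carried in the prediction records. The sketch below assumes df_preds retained its ground_truth column and that the predicted labels share the annotation vocabulary; note that it pairs every predicted box with the annotation it was looped over, not necessarily the building the box actually covers, so treat the numbers as indicative only:
# Rough per-attribute agreement between predictions and the stored ground-truth attributes
attr_map = {
    "complete": "building_completeness",
    "condition": "building_condition",
    "material": "building_material",
    "security": "building_security",
    "use": "building_use",
}
for pred_col, gt_key in attr_map.items():
    gt_vals = df_preds["ground_truth"].apply(lambda d: d.get(gt_key))
    agreement = (df_preds[pred_col] == gt_vals).mean()
    print(f"{pred_col}: {agreement:.1%} agreement with ground truth")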