TorchGeo: An Introductory Object Detection Example
parmarjatin4911@gmail.com
Posted on August 23, 2024
TorchGeo is a PyTorch domain library similar to torchvision, specialized for geospatial data. It offers datasets, samplers, transformations, and pre-trained models tailored for geospatial information. This tutorial will introduce an object detection example in TorchGeo.
This example is based on the demo notebook provided by Caleb Robinson from Microsoft AI for Good, with added explanations.
GPU Selection
Before starting the training in Colab, select Runtime > Change runtime type from the notebook’s top menu, choose GPU under Hardware accelerator, and save. I use Colab Pro and ran this example on an A100 GPU.
Installing TorchGeo
By default, pip install torchgeo installs only the essential dependencies, which keeps the installation relatively lightweight. To also install the optional dependencies that some datasets need, use pip install torchgeo[datasets].
pip install torchgeo: Installs the "Required" set of dependencies.
pip install torchgeo[datasets]: Installs the "Required" dependencies plus the "Optional" dataset dependencies.
%pip install -q -U torchgeo[datasets]
Installing PyTorch Lightning
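PyTorch Lightning offers a high-level interface for PyTorch that simplifies and streamlines the model training process. This tutorial imports it as lightning.pytorch, which is provided by the lightning package (the older pytorch-lightning package only provides the pytorch_lightning module).
%pip install -q -U lightning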
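To confirm what the installs above actually put in the environment, a quick check with the standard library's importlib.metadata is enough. This is a minimal sketch; the helper name installed_version is hypothetical, and only version and PackageNotFoundError come from the standard library.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Distribution names to check; "lightning" is the package that provides lightning.pytorch
for name in ("torchgeo", "torch", "lightning"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```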
Downloading the VHR-10 Dataset
The VHR-10 dataset, provided by Northwestern Polytechnical University (NWPU) in China, is a Very High Resolution (VHR) remote sensing image dataset encompassing 10 classes.
The dataset comprises 800 VHR optical remote sensing images. Of these, 715 are color images acquired from Google Earth with spatial resolutions ranging from 0.5 to 2 meters; the remaining 85 are pan-sharpened color infrared (CIR) images derived from the Vaihingen dataset, with a spatial resolution of 0.08 meters.
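To get a feel for these resolutions, the ground area an image covers is its size in pixels multiplied by its ground sample distance (GSD, meters per pixel). The sketch below uses a 958 x 808 pixel image (the size of the first positive image, whose shape is printed later in this tutorial) with GSD values taken from the ranges above; which GSD that particular image has is not stated, so the pairing is hypothetical, as is the function name ground_footprint.

```python
def ground_footprint(width_px, height_px, gsd_m):
    """Ground extent (width_m, height_m, area_m2) of an image with the given
    size in pixels and ground sample distance in meters per pixel."""
    width_m = width_px * gsd_m
    height_m = height_px * gsd_m
    return width_m, height_m, width_m * height_m


# Hypothetical pairing: one image size evaluated at three resolutions from the text
for gsd in (2.0, 0.5, 0.08):
    w, h, area = ground_footprint(958, 808, gsd)
    print(f"GSD {gsd} m: {w:.1f} m x {h:.1f} m ({area / 1e6:.3f} km^2)")
```

The same pixel grid covers about 625 times less ground at 0.08 m than at 2 m (a factor of 25 per axis), which is why objects such as vehicles span far more pixels in the Vaihingen-derived images.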
Note: Pan-sharpening is a technique that combines the high-resolution detail of the panchromatic band with the lower resolution color information of other bands.
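One common family of pan-sharpening methods scales each (upsampled) color band by the ratio of the panchromatic value to the pixel's overall intensity. The sketch below shows a Brovey-style variant of that idea for a single pixel; it is an illustration of the technique only, not necessarily the method used to produce the Vaihingen CIR images, and brovey_pixel is a hypothetical name.

```python
def brovey_pixel(ms_pixel, pan_value, eps=1e-12):
    """Brovey-style pan-sharpening of one pixel.

    ms_pixel: multispectral band values, already resampled to the pan grid.
    pan_value: high-resolution panchromatic value at the same location.
    Each band is multiplied by pan / intensity (intensity = mean of the bands),
    so band ratios (color) are kept while brightness detail comes from pan.
    """
    intensity = sum(ms_pixel) / len(ms_pixel)
    scale = pan_value / max(intensity, eps)  # eps guards against division by zero
    return tuple(band * scale for band in ms_pixel)


sharpened = brovey_pixel((0.2, 0.4, 0.6), 0.8)
print(sharpened)
```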
The dataset is divided into two sets:
Positive image set (650 images): Images containing at least one object.
Negative image set (150 images): Images without any objects.
The positive image set includes objects from the following ten classes:
Airplanes (757 instances)
Ships (302 instances)
Storage tanks (655 instances)
Baseball diamonds (390 instances)
Tennis courts (524 instances)
Basketball courts (159 instances)
Ground track fields (163 instances)
Harbors (224 instances)
Bridges (124 instances)
Vehicles (477 instances)
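The counts above are quite uneven, which matters for detection because rare classes contribute few training examples. A short sketch summarizing the listed numbers (the dictionary simply copies the list above):

```python
# Instance counts per class in the positive image set, copied from the list above
instance_counts = {
    "Airplanes": 757,
    "Ships": 302,
    "Storage tanks": 655,
    "Baseball diamonds": 390,
    "Tennis courts": 524,
    "Basketball courts": 159,
    "Ground track fields": 163,
    "Harbors": 224,
    "Bridges": 124,
    "Vehicles": 477,
}

total = sum(instance_counts.values())
most = max(instance_counts, key=instance_counts.get)
least = min(instance_counts, key=instance_counts.get)

print(f"Total instances: {total}")
for name, count in sorted(instance_counts.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:<20} {count:>4} ({100 * count / total:.1f}%)")
ratio = instance_counts[most] / instance_counts[least]
print(f"Most vs least frequent: {most} vs {least}, ratio {ratio:.1f}")
```

In total there are 3,775 instances, about 5.8 per positive image on average (3775 / 650), and roughly six times as many airplanes as bridges.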
The dataset includes object detection bounding boxes and instance segmentation masks.
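Boxes and masks describe the same objects at different levels of detail: the tightest box around an instance mask is given by the minimum and maximum coordinates of its nonzero pixels. The pure-Python sketch below illustrates that relationship (torchvision.ops.masks_to_boxes performs the same computation on tensors); it is not how TorchGeo loads VHR-10 boxes, which are stored separately in the annotations, and mask_to_box is a hypothetical name.

```python
def mask_to_box(mask):
    """Tight box (x_min, y_min, x_max, y_max) around the nonzero pixels of a
    binary mask given as a list of rows. Max coordinates are inclusive pixel
    indices. Returns None for an empty mask."""
    ys = [y for y, row in enumerate(mask) for v in row if v]
    xs = [x for row in mask for x, v in enumerate(row) if v]
    if not xs:
        return None
    return min(xs), min(ys), max(xs), max(ys)


mask = [
    [0, 0, 0, 0, 0, 0],
    [0, 0, 1, 1, 1, 0],
    [0, 0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0, 0],
]
print(mask_to_box(mask))
```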
When using this dataset for research, please cite the following papers:
https://doi.org/10.3390/rs12060989
Import the necessary libraries and download the VHR-10 dataset.
import torchgeo
from torchgeo.datasets import VHR10
from torchgeo.trainers import ObjectDetectionTask
import torch
from torch import Tensor  # Used as a type annotation in training_step below
from torch.utils.data import DataLoader
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import os
import gdown
os.makedirs('data/VHR10/', exist_ok=True)
url = 'https://drive.google.com/uc?id=1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE'
output_path = 'data/VHR10/NWPU VHR-10 dataset.rar'
gdown.download(url, output_path, quiet=False)
The code below loads the 'positive' split of VHR10 from data/VHR10/ and applies the preprocess function to every sample. With download=True, TorchGeo downloads the data if it is not already present, and checksum=True verifies the integrity of the downloaded files. The preprocess function converts each image to floating point and scales its values to the range 0 to 1 by dividing by 255, which is the input range torchvision's detection models (including the Faster R-CNN used below) expect.
def preprocess(sample):
    sample["image"] = sample["image"].float() / 255.0
    return sample
ds = VHR10(
    root="data/VHR10/",
    split="positive",
    transforms=preprocess,
    download=True,
    checksum=True,
)
Exploring the VHR-10 Dataset
Let’s now take a look at the positive image set. There are a total of 650 images.
print(f"VHR-10 dataset: {len(ds)}")
Access the first item in the ds dataset and check the dimensions of the image corresponding to the "image" key. The shape attribute returns the dimensions of the image, which are represented in the format (number of channels, height, width).
ds[0]["image"].shape
torch.Size([3, 808, 958])
Let’s display the image of the sixth item (index 5, an arbitrary choice). Before displaying it, permute the image dimensions from (number of channels, height, width) to (height, width, number of channels), the layout Matplotlib's imshow expects.
image = ds[5]["image"].permute(1, 2, 0)
plt.imshow(image)
plt.show()
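What permute(1, 2, 0) does to the data can be spelled out with plain nested lists: output position [h][w][c] takes the value from input position [c][h][w]. The sketch below is only an illustration (chw_to_hwc is a hypothetical name); on tensors, permute does this without copying by changing strides.

```python
def chw_to_hwc(image):
    """Reorder a nested-list image from (channels, height, width) to
    (height, width, channels), mirroring tensor.permute(1, 2, 0)."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[c][h][w] for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# A tiny 3-channel, 2 x 2 "image": one 2 x 2 grid per channel
chw = [
    [[1, 2], [3, 4]],     # channel 0 (e.g. red)
    [[5, 6], [7, 8]],     # channel 1 (e.g. green)
    [[9, 10], [11, 12]],  # channel 2 (e.g. blue)
]
print(chw_to_hwc(chw))
```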
The VHR10 dataset in TorchGeo includes a plot method that visualizes annotations contained in annotations.json. These annotations provide information for object detection and instance segmentation, including visual elements like bounding boxes and masks. This allows for direct visualization of images with ground truth annotations marked on them.
ds.plot(ds[5])
plt.savefig('ground_truth.png', bbox_inches='tight')
plt.show()
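Detection annotations come in more than one box convention. If annotations.json follows the COCO convention (an assumption here), boxes are stored as [x, y, width, height], whereas torchvision's detection models take [x_min, y_min, x_max, y_max]. The sketch below shows the conversion plus a validity check; both function names are hypothetical. torchvision's Faster R-CNN raises an error during training if a target box has zero or negative width or height.

```python
def xywh_to_xyxy(box):
    """Convert [x, y, width, height] (top-left corner plus size)
    to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def has_positive_size(box_xyxy):
    """True if an [x_min, y_min, x_max, y_max] box has positive width and height."""
    x_min, y_min, x_max, y_max = box_xyxy
    return x_max > x_min and y_max > y_min


box = xywh_to_xyxy([10, 20, 30, 40])
print(box, has_positive_size(box))
```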
Model Training
The following defines a custom collate_fn and passes it to PyTorch's DataLoader.
PyTorch's default collate function stacks the images of a batch into a single tensor, which fails here because VHR-10 images have different sizes and each image has a different number of boxes. The collate_fn below instead extracts the image, boxes, labels, and masks from each item and gathers each field into a Python list, producing a new batch dictionary that can be passed directly to the model for training or evaluation. The DataLoader calls this function to assemble every batch. The option shuffle=True reshuffles the dataset at every epoch so that training does not depend on the order of the data.
def collate_fn(batch):
    new_batch = {
        "image": [item["image"] for item in batch],  # Images
        "boxes": [item["boxes"] for item in batch],  # Bounding boxes
        "labels": [item["labels"] for item in batch],  # Labels
        "masks": [item["masks"] for item in batch],  # Masks
    }
    return new_batch  # Return the new batch
# Data loader
dl = DataLoader(
    ds,  # Dataset
    batch_size=32,  # Number of samples to load at one time
    num_workers=2,  # Number of processes to use for data loading
    shuffle=True,  # Whether to shuffle the dataset before loading
    collate_fn=collate_fn,  # collate_fn function for batch processing
)
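The effect of the custom collate_fn can be seen without any tensors. The sketch below uses stand-in samples whose "image" entries are just shape tuples (the second shape is hypothetical; 3 x 808 x 958 is the first VHR-10 image's shape shown earlier). Gathering fields into lists works for any mix of sizes, whereas stacking, which default collation does with tensors, requires identical shapes. The helper same_shape is hypothetical.

```python
def collate_fn(batch):
    # Same idea as the collate_fn above: gather each field into a list, no stacking
    return {key: [item[key] for item in batch] for key in ("image", "boxes", "labels", "masks")}


def same_shape(shapes):
    # Stacking tensors into one batch tensor needs every shape to be identical
    return len(set(shapes)) <= 1


# Stand-in samples: shape tuples instead of real image tensors
fake_batch = [
    {"image": (3, 808, 958), "boxes": [[10, 20, 40, 60]], "labels": [1], "masks": None},
    {"image": (3, 600, 900), "boxes": [[0, 0, 5, 5], [7, 7, 9, 9]], "labels": [2, 10], "masks": None},
]
out = collate_fn(fake_batch)
print(out["image"])              # both shapes kept side by side in a list
print(same_shape(out["image"]))  # False: these images could not be stacked
```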
This code defines a training task for object detection that can handle variable-sized inputs and then creates an instance of it.
The VariableSizeInputObjectDetectionTask class inherits from TorchGeo's ObjectDetectionTask and overrides the training_step method. Instead of expecting a single stacked image tensor, it passes the model a list of image tensors (one per image, each with its own size) together with a list of per-image target dictionaries containing "boxes" and "labels". This is the input format torchvision's Faster R-CNN accepts; the model resizes and batches the images internally. The instance created afterward configures a Faster R-CNN model with the settings shown.
class VariableSizeInputObjectDetectionTask(ObjectDetectionTask):
    # Define the training step
    def training_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch["image"]  # List of image tensors, possibly of different sizes
        batch_size = len(x)  # Batch size (number of images)
        y = [
            {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
            for i in range(batch_size)
        ]  # Bounding box and label targets for each image
        loss_dict = self(x, y)  # In training mode, the model returns a dict of losses
        train_loss: Tensor = sum(loss_dict.values())  # Training loss (sum of all loss terms)
        self.log_dict(loss_dict, batch_size=batch_size)  # Log each loss term; batch size given explicitly
        return train_loss  # Return training loss
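In training mode, torchvision's Faster R-CNN returns a dictionary with four loss terms: loss_objectness and loss_rpn_box_reg from the region proposal network, and loss_classifier and loss_box_reg from the detection head. The training_step above optimizes their unweighted sum. The sketch below mimics that step with made-up loss values (plain floats instead of tensors).

```python
# Made-up per-batch loss values; the keys match torchvision's Faster R-CNN losses
loss_dict = {
    "loss_classifier": 0.5,    # detection head: class prediction
    "loss_box_reg": 0.25,      # detection head: box refinement
    "loss_objectness": 0.125,  # region proposal network: object vs background
    "loss_rpn_box_reg": 0.125, # region proposal network: proposal box regression
}

# Same reduction as training_step: unweighted sum of all loss terms
train_loss = sum(loss_dict.values())
print(f"train_loss = {train_loss}")
```

Because every term is logged separately, loss_classifier is available as a metric to monitor even though no validation set is used.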
task = VariableSizeInputObjectDetectionTask(
    model="faster-rcnn",  # Faster R-CNN model
    backbone="resnet18",  # ResNet18 backbone architecture
    weights=True,  # Use pretrained weights
    in_channels=3,  # Number of channels in the input image (RGB images)
    num_classes=11,  # Number of classes (10 object classes + background)
    trainable_layers=3,  # Number of trainable backbone layers, counted from the final block
    lr=1e-3,  # Learning rate
    patience=10,  # Patience of the learning rate scheduler (epochs without improvement before the LR is reduced)
    freeze_backbone=False,  # False: backbone weights are also updated during training
)
task.monitor = "loss_classifier"  # Metric to monitor; with no validation set, a logged training loss is used
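Why num_classes=11 for a 10-class dataset: torchvision's detection models reserve label 0 for the background, so object labels run from 1 to 10. The sketch below builds such a mapping; it ASSUMES the dataset's labels 1 to 10 follow the order of the class list earlier in this tutorial, and the names CLASS_NAMES and LABEL_TO_NAME are hypothetical.

```python
# ASSUMPTION: dataset labels 1..10 follow the order of the class list above;
# label 0 is reserved for background, as torchvision's detection models expect
CLASS_NAMES = [
    "Airplanes", "Ships", "Storage tanks", "Baseball diamonds", "Tennis courts",
    "Basketball courts", "Ground track fields", "Harbors", "Bridges", "Vehicles",
]
LABEL_TO_NAME = {0: "background", **{i + 1: name for i, name in enumerate(CLASS_NAMES)}}
NUM_CLASSES = len(LABEL_TO_NAME)  # 10 object classes + background

print(NUM_CLASSES)
print(LABEL_TO_NAME)
```

A mapping like this is also handy for turning the integer prediction_labels returned at inference time into readable names.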
Prepare the configuration for model training using the PyTorch Lightning library. The settings below specify training the model with GPU, saving training logs and checkpoints to the ‘logs/’ directory, and setting the training to run for a minimum of 6 epochs and a maximum of 100 epochs.
trainer = pl.Trainer(
    default_root_dir="logs/",  # Set the default directory
    accelerator="gpu",  # Set the type of hardware accelerator for training (using GPU)
    devices=[0],  # List of device IDs to use ([0] means the first GPU)
    min_epochs=6,  # Set the minimum number of training epochs
    max_epochs=100,  # Set the maximum number of training epochs
    log_every_n_steps=20,  # Set how often to log after a number of steps
)
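With 650 positive images and batch_size=32, the number of optimizer steps per epoch follows from a ceiling division, because DataLoader keeps the final partial batch by default (drop_last=False). The sketch below works this out; steps_per_epoch is a hypothetical helper.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_samples // batch_size  # the final partial batch is discarded
    return math.ceil(num_samples / batch_size)  # the final partial batch is kept


n_steps = steps_per_epoch(650, 32)  # positive split, batch_size=32 as above
print(n_steps, "steps per epoch;", n_steps * 100, "steps for max_epochs=100")
```

That gives 21 steps per epoch (20 full batches plus one batch of 10), so with log_every_n_steps=20 the losses are logged roughly once per epoch.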
%%time
# Model training
trainer.fit(task, train_dataloaders=dl)
Model Inference Example
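Retrieve a batch from the data loader (dl). Because shuffle=True, each call to next(iter(dl)) starts a new shuffled pass and returns a different random batch.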
batch = next(iter(dl))
Obtain the model from the task (task) and set it to evaluation mode. model.eval() switches layers such as Dropout and BatchNorm to their inference behavior; for torchvision detection models such as Faster R-CNN, it also makes the forward pass return predictions (boxes, labels, and scores) instead of a dictionary of training losses.
Using torch.no_grad() disables gradient calculations, reducing memory usage and speeding up computation; use it whenever the model is not being updated, as during evaluation or inference. Then pass the image batch through the model to obtain the predictions.
model = task.model
model.eval()
with torch.no_grad():
    out = model(batch["image"])
Define a helper that assembles the ground truth and the model's predictions for one image, selected by its index within the batch (batch_idx).
def create_sample(batch, out, batch_idx):
    return {
        "image": batch["image"][batch_idx],  # Image
        "boxes": batch["boxes"][batch_idx],  # Actual bounding boxes
        "labels": batch["labels"][batch_idx],  # Actual labels
        "masks": batch["masks"][batch_idx],  # Actual masks
        "prediction_labels": out[batch_idx]["labels"],  # Labels predicted by the model
        "prediction_boxes": out[batch_idx]["boxes"],  # Bounding boxes predicted by the model
        "prediction_scores": out[batch_idx]["scores"],  # Confidence scores for each prediction
    }
batch_idx = 0
sample = create_sample(batch, out, batch_idx)
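The predictions include a confidence score per box. torchvision's FasterRCNN already discards detections below box_score_thresh (0.05 by default), so many low-confidence boxes can remain; keeping only those above a higher threshold often makes plots easier to read. The sketch below filters plain lists with a hypothetical threshold of 0.5; with tensors, indexing by the boolean mask scores >= threshold does the same. filter_by_score is a hypothetical name.

```python
def filter_by_score(boxes, labels, scores, threshold=0.5):
    """Keep only the predictions whose confidence score is at least `threshold`."""
    keep = [i for i, s in enumerate(scores) if s >= threshold]
    return (
        [boxes[i] for i in keep],
        [labels[i] for i in keep],
        [scores[i] for i in keep],
    )


# Made-up predictions for one image
pred_boxes = [[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 2, 2]]
pred_labels = [1, 2, 10]
pred_scores = [0.9, 0.3, 0.5]
print(filter_by_score(pred_boxes, pred_labels, pred_scores))
```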
Now, visualize the sample. The plot method displays the image together with the ground-truth labels and bounding boxes and the predicted labels and bounding boxes.
ds.plot(sample)
plt.savefig('inference.png', bbox_inches='tight')
plt.show()
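To go beyond eyeballing the plot, the standard way to measure how well a predicted box matches a ground-truth box is intersection over union (IoU): the overlap area divided by the combined area. A match is commonly counted as correct when IoU is at least 0.5, as in the PASCAL VOC benchmark. The pure-Python sketch below handles one pair of boxes in (x_min, y_min, x_max, y_max) format; torchvision.ops.box_iou computes the same quantity for all pairs of boxes in two tensors. box_iou here is a hypothetical stand-alone helper.

```python
def box_iou(a, b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    # Corners of the intersection rectangle
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)  # zero if the boxes do not overlap
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


print(box_iou([0, 0, 2, 2], [1, 1, 3, 3]))  # overlap 1, union 4 + 4 - 1 = 7
```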
Visualizing Sample for Batch Index 3
batch_idx = 3
sample = create_sample(batch, out, batch_idx)
ds.plot(sample)
plt.show()
Visualizing Sample for Batch Index 5
batch_idx = 5
sample = create_sample(batch, out, batch_idx)
ds.plot(sample)
plt.show()