Back to Article
NMS Example
Download Notebook

NMS Example

In [1]:
import torch
from torchvision.ops import nms

def non_max_suppression(boxes, scores, labels, iou_threshold=0.5):
    """
    Perform Non-Maximum Suppression (NMS) independently for each class.

    Boxes of different classes never suppress one another. Within a class,
    torchvision's ``nms`` keeps high-scoring boxes and discards any box whose
    IoU with an already-kept, higher-scoring box exceeds ``iou_threshold``.

    Args:
        boxes (Tensor[N, 4]): Bounding boxes (x_min, y_min, x_max, y_max)
        scores (Tensor[N]): Confidence score for each box
        labels (Tensor[N]): Class label for each box
        iou_threshold (float): IoU threshold for suppression; boxes with
            IoU > iou_threshold against a kept box of the same class are dropped

    Returns:
        final_boxes (Tensor[M, 4]): Filtered bounding boxes after NMS
        final_scores (Tensor[M]): Filtered scores
        final_labels (Tensor[M]): Filtered labels

        Results are grouped by class in ascending label order (``Tensor.unique``
        sorts by default). Within a class they follow the order returned by
        ``nms``, which is decreasing score. If nothing is kept (e.g. N == 0),
        each output is an empty tensor with the same dtype and device as the
        corresponding input.
    """
    kept_boxes, kept_scores, kept_labels = [], [], []

    for cls in labels.unique():  # sorted ascending by default
        # Boolean mask selects this class's detections, preserving input order.
        class_mask = labels == cls
        boxes_c = boxes[class_mask]
        scores_c = scores[class_mask]

        # Indices into boxes_c / scores_c of the detections that survive NMS.
        keep = nms(boxes_c, scores_c, iou_threshold)

        kept_boxes.append(boxes_c[keep])
        kept_scores.append(scores_c[keep])
        kept_labels.append(labels[class_mask][keep])

    if not kept_boxes:
        # Match the inputs' dtype/device rather than torch's global defaults,
        # so callers get consistent tensors whether or not anything was kept.
        return (
            boxes.new_empty((0, 4)),
            scores.new_empty((0,)),
            labels.new_empty((0,)),
        )

    # Concatenate results across all classes.
    final_boxes = torch.cat(kept_boxes, dim=0)
    final_scores = torch.cat(kept_scores, dim=0)
    final_labels = torch.cat(kept_labels, dim=0)
    return final_boxes, final_scores, final_labels
In [2]:
# Example usage: two heavily overlapping class-1 boxes and one isolated class-2 box.
boxes = torch.tensor(
    [[10, 20, 40, 60], [12, 22, 42, 62], [100, 200, 150, 250]],
    dtype=torch.float32,
)
scores = torch.tensor([0.9, 0.8, 0.85])
labels = torch.tensor([1, 1, 2])  # Class 1 for first two, Class 2 for last one

filtered_boxes, filtered_scores, filtered_labels = non_max_suppression(boxes, scores, labels, iou_threshold=0.5)

print("Filtered Boxes:", filtered_boxes)
print("Filtered Scores:", filtered_scores)
print("Filtered Labels:", filtered_labels)

# The lower-scoring class-1 box (IoU ~0.80 with the first) is suppressed;
# the class-2 box is never compared against class 1, so it survives.
assert torch.equal(filtered_boxes, boxes[[0, 2]])
assert filtered_labels.tolist() == [1, 2]
Filtered Boxes: tensor([[ 10.,  20.,  40.,  60.],
        [100., 200., 150., 250.]])
Filtered Scores: tensor([0.9000, 0.8500])
Filtered Labels: tensor([1, 2])
Back to top