import torch
from torchvision.ops import nms
def non_max_suppression(boxes, scores, labels, iou_threshold=0.5):
    """
    Perform Non-Maximum Suppression (NMS) per class.

    Boxes are grouped by label and ``nms`` is run on each group separately,
    so boxes with different labels never suppress one another.

    Args:
        boxes (Tensor[N, 4]): Bounding boxes (x_min, y_min, x_max, y_max)
        scores (Tensor[N]): Confidence scores
        labels (Tensor[N]): Class labels
        iou_threshold (float): IoU threshold for suppression

    Returns:
        final_boxes (Tensor[M, 4]): Filtered bounding boxes after NMS
        final_scores (Tensor[M]): Filtered scores
        final_labels (Tensor[M]): Filtered labels

        Results are concatenated class by class, following the order of
        ``labels.unique()``; within a class they follow the order of the
        indices returned by ``nms``. When nothing is kept, the three
        tensors are empty and share the dtype and device of the
        corresponding inputs.
    """
    kept_boxes = []
    kept_scores = []
    kept_labels = []

    for cls in labels.unique():  # One NMS pass per distinct class label
        # Select boxes, scores and labels for the current class
        class_indices = (labels == cls).nonzero(as_tuple=True)[0]
        boxes_c = boxes[class_indices]
        scores_c = scores[class_indices]
        labels_c = labels[class_indices]

        # Perform NMS for this class; indices are relative to boxes_c
        keep_indices = nms(boxes_c, scores_c, iou_threshold)

        # Collect the surviving boxes, scores, and labels
        kept_boxes.append(boxes_c[keep_indices])
        kept_scores.append(scores_c[keep_indices])
        kept_labels.append(labels_c[keep_indices])

    if not kept_boxes:
        # No class produced any output (e.g. empty input). Build the empty
        # results from the inputs so dtype and device stay consistent with
        # the non-empty case instead of defaulting to CPU float32 / int64.
        return (
            boxes.new_empty((0, 4)),
            scores.new_empty((0,)),
            labels.new_empty((0,)),
        )

    # Concatenate results across all classes
    final_boxes = torch.cat(kept_boxes, dim=0)
    final_scores = torch.cat(kept_scores, dim=0)
    final_labels = torch.cat(kept_labels, dim=0)
    return final_boxes, final_scores, final_labels
# NMS Example (notebook cells In [1] and In [2])
# Example usage: two heavily overlapping boxes of class 1 and one box of class 2.
boxes = torch.tensor(
    [[10, 20, 40, 60], [12, 22, 42, 62], [100, 200, 150, 250]],
    dtype=torch.float32,
)
scores = torch.tensor([0.9, 0.8, 0.85])
labels = torch.tensor([1, 1, 2])  # Class 1 for first two, Class 2 for last one

filtered_boxes, filtered_scores, filtered_labels = non_max_suppression(
    boxes, scores, labels, iou_threshold=0.5
)

print("Filtered Boxes:", filtered_boxes)
print("Filtered Scores:", filtered_scores)
print("Filtered Labels:", filtered_labels)
# Expected output:
# Filtered Boxes: tensor([[ 10.,  20.,  40.,  60.],
#         [100., 200., 150., 250.]])
# Filtered Scores: tensor([0.9000, 0.8500])
# Filtered Labels: tensor([1, 2])