import numpy as np
import cv2
import matplotlib.pyplot as plt
from ultralytics import YOLO


class KalmanFilter:
    def __init__(self):
        # State vector [x, y, vx, vy]
        self.x = np.zeros((4, 1))
        # State covariance matrix (large initial uncertainty)
        self.P = np.eye(4) * 1000
        # State transition matrix (constant-velocity model, dt = 1 frame)
        self.F = np.array([[1, 0, 1, 0],
                           [0, 1, 0, 1],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]])
        # Measurement matrix (we only observe position)
        self.H = np.array([[1, 0, 0, 0],
                           [0, 1, 0, 0]])
        self.Q = np.eye(4) * 0.01  # Process noise covariance
        self.R = np.eye(2) * 1.0   # Measurement noise covariance
        # Identity matrix
        self.I = np.eye(4)
        # Innovation magnitudes recorded at each update, for the diagnostic plot
        self.innovation_stats = []

    def predict(self):
        # Predict next state
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return self.x[:2].flatten()

    def update(self, z):
        # Compute Kalman gain
        S = self.H @ self.P @ self.H.T + self.R
        K = self.P @ self.H.T @ np.linalg.inv(S)
        # Compute innovation
        innovation = z.reshape(2, 1) - (self.H @ self.x)
        innovation_norm = np.linalg.norm(innovation)
        self.innovation_stats.append(innovation_norm)
        # Update state and covariance
        self.x = self.x + K @ innovation
        self.P = (self.I - K @ self.H) @ self.P
        return self.x[:2].flatten()
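

# A minimal sanity check for the KalmanFilter class on synthetic data, independent of
# the webcam pipeline below. The constant-velocity trajectory, the noise level, and the
# function name demo_kalman_synthetic are illustrative assumptions, not part of the
# original tracking script.
def demo_kalman_synthetic(num_steps=50):
    rng = np.random.default_rng(0)
    kf = KalmanFilter()
    errors = []
    for t in range(num_steps):
        # Ground-truth position moving with constant velocity (2, 1) pixels per step
        true_pos = np.array([2.0 * t, 1.0 * t])
        # Noisy position measurement fed to the filter
        z = true_pos + rng.normal(0.0, 1.0, size=2)
        kf.predict()
        estimate = kf.update(z)
        errors.append(np.linalg.norm(estimate - true_pos))
    # After a short transient, the estimate should track the true position closely
    print(f"Mean error over last 10 steps: {np.mean(errors[-10:]):.2f}")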


# Real-time object tracking using YOLO
def track_object():
    cap = cv2.VideoCapture(0)  # Open webcam
    kf = KalmanFilter()
    model = YOLO("yolov8n.pt")  # Load YOLO model
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Run YOLO detection
        results = model(frame)
        detections = results[0].boxes.xyxy.cpu().numpy()  # Extract bounding boxes
        if len(detections) > 0:
            # Assume the first detection is the target object
            # (see select_nearest_detection below for a nearest-to-prediction alternative)
            x1, y1, x2, y2 = detections[0]
            centroid = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
            # Predict, then correct with the measured centroid
            predicted = kf.predict()
            estimated = kf.update(centroid)
            # Draw detection and tracking
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            cv2.circle(frame, (int(centroid[0]), int(centroid[1])), 5, (0, 0, 255), -1)    # Measured position
            cv2.circle(frame, (int(estimated[0]), int(estimated[1])), 5, (255, 0, 0), -1)  # Estimated position
        cv2.imshow('Tracking', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()
    # Plot innovation statistics collected during tracking
    plt.figure(figsize=(8, 5))
    plt.plot(kf.innovation_stats, label='Innovation Magnitude')
    plt.xlabel('Frame')
    plt.ylabel('Innovation Norm')
    plt.title('Kalman Filter Innovation Over Time')
    plt.legend()
    plt.show()
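

# A hedged alternative to taking the first detection: pick the box whose center is
# closest to the filter's predicted position. This helper and its name
# (select_nearest_detection) are illustrative and are not called by track_object();
# a fuller tracker would also gate the match using the innovation covariance S.
def select_nearest_detection(detections, predicted_pos):
    # Centers of all boxes, shape (N, 2)
    centers = np.column_stack([(detections[:, 0] + detections[:, 2]) / 2,
                               (detections[:, 1] + detections[:, 3]) / 2])
    # Euclidean distance from each center to the predicted position
    distances = np.linalg.norm(centers - predicted_pos, axis=1)
    return detections[np.argmin(distances)]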


if __name__ == "__main__":
    track_object()