import numpy as np
import cv2
import matplotlib.pyplot as plt
from collections import deque
import torch
from ultralytics import YOLO

class KalmanFilter:
    """Constant-velocity Kalman filter for tracking a single 2-D point.

    The state is the column vector ``[x, y, vx, vy]`` (shape (4, 1)) and only
    the position ``[x, y]`` is observed. Every call to :meth:`update` appends
    the Euclidean norm of the innovation (measurement minus predicted
    measurement) to ``innovation_stats`` so filter behaviour can be inspected
    afterwards.

    Args:
        dt: Time step assumed between consecutive :meth:`predict` calls
            (1.0 means velocity is expressed per step, e.g. per frame).
        process_noise: Scale of the diagonal process noise covariance Q.
        measurement_noise: Scale of the diagonal measurement noise
            covariance R.
        initial_uncertainty: Scale of the diagonal initial state covariance
            P; large values let the first measurements dominate the estimate.
    """

    def __init__(self, dt=1.0, process_noise=0.01, measurement_noise=1.0,
                 initial_uncertainty=1000.0):
        # State vector [x, y, vx, vy], float so no arithmetic is truncated.
        self.x = np.zeros((4, 1))

        # State covariance matrix
        self.P = np.eye(4) * initial_uncertainty

        # State transition matrix (constant-velocity model)
        self.F = np.array([[1.0, 0.0, dt, 0.0],
                           [0.0, 1.0, 0.0, dt],
                           [0.0, 0.0, 1.0, 0.0],
                           [0.0, 0.0, 0.0, 1.0]])

        # Measurement matrix (only position is observed)
        self.H = np.array([[1.0, 0.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0, 0.0]])

        self.Q = np.eye(4) * process_noise      # Process noise covariance
        self.R = np.eye(2) * measurement_noise  # Measurement noise covariance

        # Identity matrix
        self.I = np.eye(4)

        # Innovation norm of every update, in call order.
        self.innovation_stats = []

    def predict(self):
        """Propagate state and covariance one step ahead.

        Returns:
            The predicted position ``[x, y]`` as a 1-D array of length 2.
        """
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return self.x[:2].flatten()

    def update(self, z):
        """Correct the state with a position measurement.

        Args:
            z: Measured position; any array-like with exactly two elements.

        Returns:
            The corrected position ``[x, y]`` as a 1-D array of length 2.
        """
        z = np.asarray(z, dtype=float).reshape(2, 1)

        # Innovation (residual) and its magnitude for later diagnostics.
        innovation = z - self.H @ self.x
        self.innovation_stats.append(float(np.linalg.norm(innovation)))

        # Kalman gain K = P H^T S^-1, obtained by solving K S = P H^T rather
        # than forming an explicit inverse (better conditioned).
        S = self.H @ self.P @ self.H.T + self.R
        K = np.linalg.solve(S.T, (self.P @ self.H.T).T).T

        # Update state and covariance. The Joseph form keeps P symmetric and
        # positive semi-definite under rounding, unlike (I - KH) P.
        self.x = self.x + K @ innovation
        I_KH = self.I - K @ self.H
        self.P = I_KH @ self.P @ I_KH.T + K @ self.R @ K.T

        return self.x[:2].flatten()

# Real-time object tracking using YOLO
def track_object(source=0, model_path="yolov8n.pt"):
    """Track the first YOLO detection in a video stream with a Kalman filter.

    Once the filter has received its first measurement it is advanced with
    ``predict`` on every frame. Whenever YOLO reports at least one box, the
    centre of the first box is fed to ``update``. Detected boxes (green),
    measured centroids (red), filtered estimates (blue) and, on frames
    without a detection, the coasting prediction (yellow) are drawn in an
    OpenCV window. Press ``q`` to quit. When the stream ends, the recorded
    innovation norms are plotted with matplotlib.

    Args:
        source: Value passed to ``cv2.VideoCapture`` (default: webcam 0).
        model_path: YOLO weights to load.
    """
    kf = KalmanFilter()
    # Load the model before opening the camera so a load failure leaks nothing.
    model = YOLO(model_path)
    cap = cv2.VideoCapture(source)
    # The filter's state is meaningless until it has seen one measurement.
    has_measurement = False

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Predict on every frame (not only on detection frames) so the
            # unit time step in F matches the real frame spacing.
            predicted = kf.predict() if has_measurement else None

            # Run YOLO detection
            results = model(frame)
            detections = results[0].boxes.xyxy.cpu().numpy()  # (N, 4) boxes

            if len(detections) > 0:
                # Assume the first detection is the target object
                x1, y1, x2, y2 = detections[0]
                centroid = np.array([(x1 + x2) / 2, (y1 + y2) / 2])

                estimated = kf.update(centroid)
                has_measurement = True

                # Draw detection and tracking
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                cv2.circle(frame, (int(centroid[0]), int(centroid[1])), 5, (0, 0, 255), -1)  # Measured position
                cv2.circle(frame, (int(estimated[0]), int(estimated[1])), 5, (255, 0, 0), -1)  # Estimated position
            elif predicted is not None:
                # No detection this frame: show where the filter expects the target.
                cv2.circle(frame, (int(predicted[0]), int(predicted[1])), 5, (0, 255, 255), -1)

            cv2.imshow('Tracking', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release the device and close windows, even if inference raises.
        cap.release()
        cv2.destroyAllWindows()

    _plot_innovation(kf.innovation_stats)


def _plot_innovation(innovation_stats):
    """Plot a sequence of Kalman innovation norms against update index."""
    plt.figure(figsize=(8, 5))
    plt.plot(innovation_stats, label='Innovation Magnitude')
    plt.xlabel('Frame')
    plt.ylabel('Innovation Norm')
    plt.title('Kalman Filter Innovation Over Time')
    plt.legend()
    plt.show()

if __name__ == "__main__":
    # Start the interactive tracker only when run as a script, not on import.
    track_object()