#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 27 13:06:18 2025

@author: andreagarulli
"""

import numpy as np
import matplotlib.pyplot as plt
import control as clt
import scipy
import types

# Close any figures left over from a previous interactive run.
plt.close('all')

# Load the experiment data and expose each MATLAB variable as an attribute.
mat_contents = scipy.io.loadmat('data_kf_cattle.mat')
d = types.SimpleNamespace(**mat_contents)

# Time vector, measurement/input sequences for the different scenarios,
# and the corresponding true state trajectories.
t = d.t
yA, uA = d.yA, d.uA
yC = d.yC
yD, uD = d.yD, d.uD
xtrueA = d.xtrueA
xtrueD = d.xtrueD

# --- Model definition -------------------------------------------------------
# State-transition matrix assembled from the identified coefficients
# stored in the data file (each d.* entry is a 1x1 array; np.block
# stitches them together with the zero scalars).
blocks = [
    [d.a1, d.a2, d.a3, 0,    0,    0],
    [d.b1, 0,    0,    0,    0,    0],
    [0,    d.b2, d.b3, 0,    0,    0],
    [d.c1, d.c2, d.c3, 0,    0,    0],
    [0,    0,    0,    d.d1, 0,    0],
    [0,    0,    0,    0,    d.d2, d.d3],
]
A = np.block(blocks)

# Every state is directly driven by its own input and its own noise.
B = np.eye(6)
G = np.eye(6)

# Each of the two outputs is the sum of one group of three states.
C = np.array([[1, 1, 1, 0, 0, 0],
              [0, 0, 0, 1, 1, 1]])

Q = d.sw2 * np.eye(6)  # process-noise covariance
R = d.sv2 * np.eye(2)  # measurement-noise covariance

N = len(t)  # number of samples

###
# A
###

# Scenario A: standard (time-varying gain) Kalman filter on dataset A,
# storing the filtered state and its variance at every step.

# Prior at time 0: zero mean with a very large (uninformative) covariance.
x0 = np.zeros((6, 1))
P0 = (100**2) * np.eye(6)

x_pred = x0.copy()   # x(k | k-1)
P_pred = P0.copy()   # P(k | k-1)

XestA = np.zeros((N, 6))   # filtered states x(k|k)
PkfA = np.zeros((N, 6))    # diagonal of P(k|k)

# Kalman Filter Recursion
I6 = np.eye(6)
for k in range(N):
    # Measurement update (correction).
    gain = P_pred @ C.T @ np.linalg.inv(C @ P_pred @ C.T + R)
    innov = yA[k, :].reshape(-1, 1) - C @ x_pred
    x_filt = x_pred + gain @ innov
    # Equivalent to (I - K C) P_pred since P_pred is symmetric.
    P_filt = P_pred @ (I6 - C.T @ gain.T)

    # Time update (prediction).
    x_pred = A @ x_filt + B @ uA[k, :].reshape(-1, 1)
    P_pred = A @ P_filt @ A.T + G @ Q @ G.T

    XestA[k, :] = x_filt.T
    PkfA[k, :] = np.diag(P_filt)

# Empirical squared error vs. the filter's own expected squared error
# (trace of the filtered covariance).
errorsA = xtrueA - XestA
errsumA = np.sum(errorsA**2, axis=1)
MSE_A = np.sum(PkfA, axis=1)

# Plot State Estimates
fig, axs = plt.subplots(3, 2, figsize=(10, 8))
for i, ax in enumerate(axs.ravel()):
    ax.plot(t, xtrueA[:, i], 'g', label='True')
    ax.plot(t, XestA[:, i], 'r--', label='KF')
    ax.set_ylabel(f'$x{i+1}$')
fig.suptitle('A: True State (Green); KF (Red)')
plt.show()

# Mean Square Error Plot
plt.figure()
plt.plot(t, np.sqrt(errsumA), 'g', label='True Error')
plt.plot(t, np.sqrt(MSE_A), 'r', label='Expected Error')
plt.title('A: Square Error')
plt.legend()
plt.show()

# Confidence Intervals (3-sigma bands from the filtered variances)
fig, axs = plt.subplots(3, 2, figsize=(10, 8))
for i, ax in enumerate(axs.ravel()):
    band = 3 * np.sqrt(PkfA[:, i])
    ax.plot(t, errorsA[:, i], 'g', label='Error')
    ax.plot(t, -band, 'b--')
    ax.plot(t, band, 'b--')
    ax.set_ylabel(f'x{i+1}')
fig.suptitle('A: Estimation Errors (Green); Confidence Intervals (Blue)')
plt.legend()
plt.show()

###
# B
###

# Scenario B: asymptotic (steady-state) Kalman filter — a single constant
# gain is used at every step, so no covariance recursion is needed.

# Steady-state estimator gain from the discrete-time Riccati equation.
# NOTE(review): python-control's dlqe documents its first return value as
# the gain of the one-step-ahead predictor form
#   xhat[n+1] = A xhat[n] + B u[n] + L (y[n] - C xhat[n]),
# i.e. L = A @ M where M is the measurement-update (filter) gain. Below the
# gain is applied in the correction step as if it were M — confirm against
# the python-control dlqe documentation that this is the intended gain.
outdlqe=clt.dlqe(A,G,C,Q,R)
Kinf=outdlqe[0]


# Initialization (same prior as in section A)
xp = x0.copy()
Pp = P0.copy()
XestB = np.zeros((N, 6))

# Kalman filter loop
for k in range(N):
    # Correction step (constant gain)
    x = xp + Kinf @ (yA[k, :].reshape(-1, 1) - C @ xp)

    # Prediction step
    xp = A @ x + B @ uA[k, :].reshape(-1, 1)

    XestB[k, :] = x.T

# Plot state estimates: time-varying KF (section A) vs. steady-state KF
fig, axes = plt.subplots(3, 2, figsize=(10, 8))
axes = axes.ravel()

for i in range(6):
    axes[i].plot(t, xtrueA[:, i], 'g', label='True State')
    axes[i].plot(t, XestA[:, i], 'r--', label='KF Estimate')
    axes[i].plot(t, XestB[:, i], 'b-', label='Asymptotic KF Estimate')
    axes[i].set_ylabel(f'x{i+1}')

plt.suptitle('B: true state (green); KF (red); asymptotic KF (blue)')
plt.show()


###
# C
###

# Scenario C: the sensor noise decays over time, sv(k) = 300 * 0.95**k.
# Compare a Kalman filter that (pessimistically) assumes the initial
# std of 300 forever against one using the true time-varying variance.

R_fixed = 300**2 * np.eye(2)

# Initial conditions
x0C = np.zeros((6, 1))     # x(0|-1)
P0C = (10**2) * np.eye(6)  # P(0|-1)

N = len(yC)  # Number of time steps
xp = x0C.copy()
Pp = P0C.copy()

XestC_fixedR = np.zeros((N, 6))
PkfC_fixedR = np.zeros((N, 6))

# Kalman filter with fixed R
for k in range(N):
    # Correction step
    K = Pp @ C.T @ np.linalg.inv(C @ Pp @ C.T + R_fixed)
    x = xp + K @ (yC[k, :].reshape(-1, 1) - C @ xp)
    P = Pp @ (np.eye(6) - C.T @ K.T)

    # Prediction step
    # NOTE(review): the inputs uA are applied to the yC measurements —
    # presumably scenario C reuses scenario A's inputs (no uC is loaded
    # from the data file); confirm against the data-generation script.
    xp = A @ x + B @ uA[k, :].reshape(-1, 1)
    Pp = A @ P @ A.T + G @ Q @ G.T

    XestC_fixedR[k, :] = x.T
    PkfC_fixedR[k, :] = np.diag(P)

# Kalman filter with time-varying R
xp = x0C.copy()
Pp = P0C.copy()
XestC_trueR = np.zeros((N, 6))
PkfC_trueR = np.zeros((N, 6))
sv = np.zeros(N)

for k in range(N):
    # True measurement-noise std at step k. Bug fix: the original used
    # 0.95 ** (k - 1), a leftover of MATLAB's 1-based loop index, which
    # starts the sequence at 300/0.95 instead of 300. With Python's
    # 0-based k the exponent is simply k.
    sv[k] = 300 * (0.95 ** k)
    # Local name so the global R (reused in section D) is not clobbered
    # by the last time-varying value.
    Rk = (sv[k] ** 2) * np.eye(2)

    # Correction step
    K = Pp @ C.T @ np.linalg.inv(C @ Pp @ C.T + Rk)
    x = xp + K @ (yC[k, :].reshape(-1, 1) - C @ xp)
    P = Pp @ (np.eye(6) - C.T @ K.T)

    # Prediction step
    xp = A @ x + B @ uA[k, :].reshape(-1, 1)
    Pp = A @ P @ A.T + G @ Q @ G.T

    XestC_trueR[k, :] = x.T
    PkfC_trueR[k, :] = np.diag(P)

# Error computations (the true state for scenario C is xtrueA)
errorsC_fR = xtrueA - XestC_fixedR
errsumC_fR = np.sum(errorsC_fR ** 2, axis=1)
MSE_C_fR = np.sum(PkfC_fixedR, axis=1)

errorsC_tR = xtrueA - XestC_trueR
errsumC_tR = np.sum(errorsC_tR ** 2, axis=1)
MSE_C_tR = np.sum(PkfC_trueR, axis=1)

# State Estimates Plot
fig, axes = plt.subplots(3, 2, figsize=(10, 8))
axes = axes.ravel()

for i in range(6):
    axes[i].plot(t, xtrueA[:, i], 'g', label='True State')
    axes[i].plot(t, XestC_trueR[:, i], 'r--', label='KF True R')
    axes[i].plot(t, XestC_fixedR[:, i], 'b--', label='KF Fixed R')
    axes[i].set_ylabel(f'x{i+1}')

plt.suptitle('C: true state (green); KF true R (red); KF fixed R (blue)')
plt.show()

# Error Comparison Plot
plt.figure()
plt.plot(t, np.sqrt(errsumC_fR), 'b', label='Fixed R Error')
plt.plot(t, np.sqrt(errsumC_tR), 'g', label='True R Error')
plt.plot(t, np.sqrt(MSE_C_fR), 'b--', label='Fixed R MSE')
plt.plot(t, np.sqrt(MSE_C_tR), 'g--', label='True R MSE')
plt.title('C: Error Comparison (blue: fixed R; green: true R)')
plt.legend()
plt.show()

# Confidence Intervals Plot (3-sigma bands)
fig, axes = plt.subplots(3, 2, figsize=(10, 8))
axes = axes.ravel()

for i in range(6):
    axes[i].plot(t, errorsC_fR[:, i], 'b', label='Fixed R Error')
    axes[i].plot(t, -3 * np.sqrt(PkfC_fixedR[:, i]), 'b--')
    axes[i].plot(t, 3 * np.sqrt(PkfC_fixedR[:, i]), 'b--')
    axes[i].plot(t, errorsC_tR[:, i], 'g', label='True R Error')
    axes[i].plot(t, -3 * np.sqrt(PkfC_trueR[:, i]), 'g--')
    axes[i].plot(t, 3 * np.sqrt(PkfC_trueR[:, i]), 'g--')
    axes[i].set_ylabel(f'x{i+1}')

plt.suptitle('C: Estimation Errors (blue: fixed R; green: true R)')
plt.show()

###
# D
###

# Scenario D: the true system has a time-varying coefficient A[0, 1].
# Compare a filter that uses the nominal constant A with one that uses
# the true time-varying A.

R = d.sv2 * np.eye(2)

# Kalman filter with constant (nominal) A
# Note: N still holds len(yC) from section C — assumed equal to len(yD).
xp = x0.copy()
Pp = P0.copy()
XestD = np.zeros((N, 6))
PkfD = np.zeros((N, 6))

for k in range(N):
    # Correction step
    K = Pp @ C.T @ np.linalg.inv(C @ Pp @ C.T + R)
    x = xp + K @ (yD[k, :].reshape(-1, 1) - C @ xp)
    P = Pp @ (np.eye(6) - C.T @ K.T)

    # Prediction step
    xp = A @ x + B @ uD[k, :].reshape(-1, 1)
    Pp = A @ P @ A.T + G @ Q @ G.T

    XestD[k, :] = x.T
    PkfD[k, :] = np.diag(P)

errorsD = xtrueD - XestD
errsumD = np.sum(errorsD**2, axis=1)
MSE_D = np.sum(PkfD, axis=1)

# Kalman filter with the true time-varying A
xp = x0.copy()
Pp = P0.copy()
XestD_trueA = np.zeros((N, 6))
PkfD_trueA = np.zeros((N, 6))

# Bug fix: work on a copy so the nominal global A is not mutated in
# place (the original code overwrote A[0, 1] permanently, corrupting
# the nominal model for any later use).
A_tv = A.copy()

for k in range(N):
    # Correction step
    K = Pp @ C.T @ np.linalg.inv(C @ Pp @ C.T + R)
    x = xp + K @ (yD[k, :].reshape(-1, 1) - C @ xp)
    P = Pp @ (np.eye(6) - C.T @ K.T)

    # True time-varying coefficient (sinusoid with period 20 samples).
    # NOTE(review): the phase uses the 0-based index k — confirm this
    # matches the index convention used when the data were generated.
    A_tv[0, 1] = 0.45 + 0.3 * np.sin(2 * np.pi / 20 * k)

    # Prediction step
    xp = A_tv @ x + B @ uD[k, :].reshape(-1, 1)
    Pp = A_tv @ P @ A_tv.T + G @ Q @ G.T

    XestD_trueA[k, :] = x.T
    PkfD_trueA[k, :] = np.diag(P)

errorsD_tA = xtrueD - XestD_trueA
errsumD_tA = np.sum(errorsD_tA**2, axis=1)
MSE_D_tA = np.sum(PkfD_trueA, axis=1)

# Plot results
plt.figure(figsize=(10, 8))
for i in range(6):
    plt.subplot(3, 2, i+1)
    plt.plot(t, xtrueD[:, i], 'g', label='True State')
    plt.plot(t, XestD_trueA[:, i], 'r--', label='KF True A')
    plt.plot(t, XestD[:, i], 'b--', label='KF Fixed A')
    plt.ylabel(f'x{i+1}')

plt.suptitle('D: True state (green); KF true A (red); KF fixed A (blue)')
plt.show()

plt.figure()
plt.plot(t, np.sqrt(errsumD), 'b', label='Fixed A')
plt.plot(t, np.sqrt(errsumD_tA), 'g', label='True A')
plt.plot(t, np.sqrt(MSE_D), 'b--', label='MSE Fixed A')
plt.plot(t, np.sqrt(MSE_D_tA), 'g--', label='MSE True A')
plt.title('D: Error Comparison')
plt.legend()
plt.show()

# Confidence Intervals (3-sigma bands)
plt.figure(figsize=(10, 8))
for i in range(6):
    plt.subplot(3, 2, i+1)
    plt.plot(t, errorsD[:, i], 'b', label='Error Fixed A')
    plt.plot(t, -3 * np.sqrt(PkfD[:, i]), 'b--')
    plt.plot(t, 3 * np.sqrt(PkfD[:, i]), 'b--')
    plt.plot(t, errorsD_tA[:, i], 'g', label='Error True A')
    plt.plot(t, -3 * np.sqrt(PkfD_trueA[:, i]), 'g--')
    plt.plot(t, 3 * np.sqrt(PkfD_trueA[:, i]), 'g--')
    plt.ylabel(f'x{i+1}')

plt.suptitle('D: Estimation Errors (Blue: Fixed A; Green: True A)')
plt.show()
