## Solution of lab session on nonlinear state estimation

import numpy as np
import matplotlib.pyplot as plt

# Loading data
# The `with` statement closes the .npz archive once all arrays have been read;
# indexing an NpzFile returns an in-memory ndarray, so the arrays remain usable
# after the archive is closed.
with np.load('data1_ekf_mr.npz') as datafile:  # try also data2_ekf_mr.npz
    U = datafile['U']    # inputs: column 0 -> ur, column 1 -> ua (motion model)
    Y = datafile['Y']    # measurements: column 0 distance, column 1 angle (rad)
    Q = datafile['Q']    # covariance added in the prediction step (process noise)
    R = datafile['R']    # covariance added to the innovation covariance (measurement noise)
    Ts = datafile['Ts']  # sampling period of the discretized motion model
    t = datafile['t']    # time stamps of the samples
    X = datafile['X']    # true state trajectory [x, y, theta], used only for evaluation


# Data plot: raw radar measurements, distance on the left and angle
# (converted from radians to degrees) on the right.
_measurement_panels = (
    (121, Y[:, 0], '$d(t)$', 'distance measurements'),
    (122, Y[:, 1] * 180 / np.pi, r'$\alpha(t)$', 'angle measurements'),
)
plt.figure(figsize=(10,4))
for _panel_pos, _series, _ylab, _panel_title in _measurement_panels:
    plt.subplot(_panel_pos)
    plt.plot(t, _series)
    plt.xlabel('$t$')
    plt.ylabel(_ylab)
    plt.title(_panel_title)
plt.show()


# Initialize variables: preallocate one row/entry per sample for the filter
# outputs and split the input matrix into its two channels.
N = len(t)
Xest, D = np.zeros((N, 3)), np.zeros((N, 3))  # state estimates / diag(P(i|i))
actMSE, estMSE = np.zeros(N), np.zeros(N)      # actual vs. EKF-predicted error norms
ur, ua = U[:, 0], U[:, 1]                      # speed input / heading-rate input

# Extended Kalman Filter
# State z = [x, y, theta]; measurement = [distance, angle] of the point (x, y)
# seen from the origin; motion model driven by the inputs (ur, ua).


def _wrap_angle(a):
    """Map an angle in radians (scalar or array) into the interval [-pi, pi)."""
    return (a + np.pi) % (2 * np.pi) - np.pi


z0 = np.array([10.0, 0.0, 0.0])  # z(0|-1); float dtype so the state is never integer-typed
P0 = np.diag([1000, 1000, (5 / 180 * np.pi) ** 2])  # P(0|-1)
zkp = z0
Pp = P0

for i in range(N):
    # Correction: linearize the measurement function at the prediction z(i|i-1)
    rho = np.hypot(zkp[0], zkp[1])                       # predicted distance
    yhat = np.array([rho, np.arctan2(zkp[1], zkp[0])])  # predicted measurement
    H = np.array([
        [zkp[0] / rho, zkp[1] / rho, 0],
        [-zkp[1] / rho ** 2, zkp[0] / rho ** 2, 0]
    ])
    K = Pp @ H.T @ np.linalg.inv(H @ Pp @ H.T + R)
    P = Pp - K @ H @ Pp
    innov = Y[i, :] - yhat
    # The angle residual must be wrapped: a measured and a predicted angle lying
    # on opposite sides of the +/-pi cut would otherwise differ by ~2*pi and
    # trigger a huge spurious correction.
    innov[1] = _wrap_angle(innov[1])
    zk = zkp + K @ innov

    # Prediction: Euler-discretized motion model and its Jacobian F at z(i|i)
    zkp = zk + Ts * np.array([ur[i] * np.cos(zk[2]), ur[i] * np.sin(zk[2]), ua[i]])
    F = np.array([
        [1, 0, -Ts * ur[i] * np.sin(zk[2])],
        [0, 1, Ts * ur[i] * np.cos(zk[2])],
        [0, 0, 1]
    ])
    Pp = F @ P @ F.T + Q

    Xest[i, :] = zk
    D[i, :] = np.diag(P)
    # Angular distance in [0, pi] between true and estimated heading
    diffth = abs(_wrap_angle(X[i, 2] - Xest[i, 2]))
    actMSE[i] = np.sqrt((X[i, 0] - Xest[i, 0]) ** 2 + (X[i, 1] - Xest[i, 1]) ** 2 + diffth ** 2)
    estMSE[i] = np.sqrt(np.trace(P))


## Plots and comparisons

# State estimates: true state (green) vs. EKF estimate (red, dashed) for each
# component; the heading theta is displayed in degrees.
_state_panels = (
    (311, X[:, 0], Xest[:, 0], '$x(t)$', '$x(t)$: true (green); estimate (red)'),
    (312, X[:, 1], Xest[:, 1], '$y(t)$', '$y(t)$: true (green); estimate (red)'),
    (313, X[:, 2]*180/np.pi, Xest[:, 2]*180/np.pi, r'$\theta(t)$',
     r'$\theta(t)$: true (green); estimate (red)'),
)
plt.figure(figsize=(6,12))
for _panel_pos, _true_curve, _est_curve, _ylab, _panel_title in _state_panels:
    plt.subplot(_panel_pos)
    plt.plot(t, _true_curve, 'g', t, _est_curve, 'r--')
    if _panel_pos == 313:
        # only the bottom panel carries the shared time label
        plt.xlabel('$t$')
    plt.ylabel(_ylab)
    plt.title(_panel_title)
plt.show()

# Trajectories in the plane: radar fixes converted from polar to Cartesian
# coordinates (light blue, dotted), the true path (blue) and the EKF path (red,
# dashed) prefixed with the initial guess z0; dots mark the starting points.
_radar_x = Y[:, 0] * np.cos(Y[:, 1])
_radar_y = Y[:, 0] * np.sin(Y[:, 1])
_ekf_path_x = np.concatenate(([z0[0]], Xest[:, 0]))
_ekf_path_y = np.concatenate(([z0[1]], Xest[:, 1]))

plt.figure(figsize=(6,6))
plt.plot(_radar_x, _radar_y, 'c:')  # Radar
plt.plot(X[:, 0], X[:, 1], 'b')
plt.plot(X[0, 0], X[0, 1], 'bo')
plt.plot(_ekf_path_x, _ekf_path_y, 'r--')
plt.plot(z0[0], z0[1], 'ro')
plt.title('Trajectories: true (blue), EKF (red), radar (light blue)')
plt.axis('equal')
plt.show()

# Comparison between the actual estimation error and the one predicted by the EKF.
# Despite the "MSE" naming both series are root quantities: actMSE is the
# Euclidean norm of the state error, estMSE is sqrt(trace(P)).
plt.figure(figsize=(6,4))
plt.plot(t, actMSE, 'g')   # actual error
plt.plot(t, estMSE, 'r')   # EKF-predicted error
plt.xlabel('$t$')
plt.ylabel('MSE')
plt.title('True MSE (green) and EKF MSE estimate (red)')
plt.show()

# Confidence intervals: estimation error (red) of each state component against
# the +/- 3-sigma band (blue) derived from the diagonal of P(i|i) stored in D.
plt.figure(figsize=(6,12))
_xy_panels = (
    (311, 0, r'$x(t)-\hat{x}(t)$',
     r'Estimation error (red) and $3\sigma$-confidence intervals (blue) for $x(t)$'),
    (312, 1, r'$y(t)-\hat{y}(t)$',
     r'Estimation error (red) and $3\sigma$-confidence intervals (blue) for $y(t)$'),
)
for _panel_pos, _k, _ylab, _panel_title in _xy_panels:
    _sigma = np.sqrt(D[:, _k])
    plt.subplot(_panel_pos)
    plt.plot(t, X[:, _k] - Xest[:, _k], 'r', t, 3 * _sigma, 'b--', t, -3 * _sigma, 'b--')
    plt.ylabel(_ylab)
    plt.title(_panel_title)
    # Zoom on +/- 20 final standard deviations so the converged band stays readable.
    _lim = np.sqrt(D[-1, _k]) * 20
    plt.axis([t[0], t[-1], -_lim, _lim])

# Heading error, wrapped sample-by-sample into [-pi, pi). np.unwrap would only
# remove jumps and keep whatever 2*pi offset the first sample has (the estimate
# is integrated without wrapping), pushing the curve far outside the 3-sigma
# band. Per-sample wrapping also matches the angular distance used for actMSE.
# Everything is displayed in degrees.
_rad2deg = 180 / np.pi
_theta_err = (X[:, 2] - Xest[:, 2] + np.pi) % (2 * np.pi) - np.pi
_theta_sigma = np.sqrt(D[:, 2])
plt.subplot(313)
plt.plot(t, _rad2deg * _theta_err, 'r',
         t, 3 * _rad2deg * _theta_sigma, 'b--',
         t, -3 * _rad2deg * _theta_sigma, 'b--')
plt.xlabel('$t$')
plt.ylabel(r'$\theta(t)-\hat{\theta}(t)$')
plt.title(r'Estimation error (red) and $3\sigma$-confidence intervals (blue) for $\theta(t)$')
# Optional zoom, expressed in degrees to match the plotted data:
# plt.axis([t[0], t[-1], -_rad2deg * np.sqrt(D[-1, 2]) * 20, _rad2deg * np.sqrt(D[-1, 2]) * 20])
plt.show()

