Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions endaq/calc/kalman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pandas as pd
from typing import List, Optional, Iterable, Union, Callable
import numpy as np

"""
A mutable class that performs a Kalman filter, implementation based on
https://thekalmanfilter.com/kalman-filter-explained-simply/.
A new `Kalman` class should be instantialized for filter operations on different inputs.
"""
class Kalman:

"""
Initialization method for the Kalman class.
:param inputs: dataset to be smoothed out by the Kalman filter with an expected form
of `endaq.ide.to_pandas`. All parameters in the dataset will be assumed to be used.
:param name: the name of the row that will hold the data generated by the kalman filter.
Set to "filtered data" by default.
:param A: the initial state transition matrix.
Set to None (default) if no preset matrix has been previously calculated.
:param H: the initial state to measurement matrix.
Set to None (default) if no preset matrix has been previously calculated.
:param Q: the initial process noise covariance matrix.
Set to None (default) if no preset matrix has been previously calculated.
"""
def __init__(
self,
inputs: pd.DataFrame,
name: str = "filtered data",
/,
A: Optional[np.array] = None,
H: Optional[np.array] = None,
Q: Optional[np.array] = None,
):
raise NotImplementedError()

#============= Optional Getters =============#

"""
Getter for the weight parameters.
:return: a 3-tuple composed of the immutable copies of A, H, and Q.
"""
def get_weights(self) -> tuple[np.array, np.array, np.array]:
raise NotImplementedError()

"""
Getter for the Kalman Gain associated with this Object.
:return: an immutable float representing the Kalman gain.
"""
def get_gain(self) -> float:
raise NotImplementedError()

#============= Public Methods =============#

"""
Runs a Kalman filter on all steps for `inputs` set in :py:func:`__init__`.
:return: a pandas `Series` object, with original timesteps and the computed values.
If insufficient enough data, the original Dataset will be returned instead.
"""
def run(self) -> Union[pd.Series | pd.DataFrame]:
raise NotImplementedError()

"""
Predicts the filtered value generated by the Kalman filter at the given point.
:param z: the measurement vector to predict on.
"""
def predict(self, z: np.array) -> pd.array:
raise NotImplementedError()

#============= Helper Methods =============#

"""
Helper method for :py:func:`run`. Updates the internal weights used by the Kalman filter.
:param i: the step that the filter is on.
:return: None. The method mutates data internal to the Kalman class.
"""
def _update(self, i: int) -> None:
raise NotImplementedError()


"""
Helper method for :py:func:`_update`. Performs the definition of a derivaitive
on a point in a dataset and it's right neighbor.
:param iloc: location of the point to take a derivative at.
:return: a float representation of the derivative.
"""
def _derivative(self, iloc: int) -> float:
raise NotImplementedError()

"""
Helper method for :py:func:`_update`. Computes the Kalman gain based on
the most recent covariance matrices.
:return: the computed Kalman Gain.
"""
def _compute_gain(self) -> np.array:
raise NotImplementedError()