diff --git a/wetb/signal/fit/_spline_fit.py b/wetb/signal/fit/_spline_fit.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa074f9f359861797b42f3652fe8947fd938645 --- /dev/null +++ b/wetb/signal/fit/_spline_fit.py @@ -0,0 +1,74 @@ +import numpy as np +def spline_fit(xp,yp): + + def akima(x, y): + n = len(x) + var = np.zeros((n + 3)) + z = np.zeros((n)) + co = np.zeros((n, 4)) + for i in range(n - 1): + var[i + 2] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) + var[n + 1] = 2 * var[n] - var[n - 1] + var[n + 2] = 2 * var[n + 1] - var[n] + var[1] = 2 * var[2] - var[3] + var[0] = 2 * var[1] - var[2] + + for i in range(n): + wi1 = abs(var[i + 3] - var[i + 2]) + wi = abs(var[i + 1] - var[i]) + if (wi1 + wi) == 0: + z[i] = (var[i + 2] + var[i + 1]) / 2 + else: + z[i] = (wi1 * var[i + 1] + wi * var[i + 2]) / (wi1 + wi) + + for i in range(n - 1): + dx = x[i + 1] - x[i] + a = (z[i + 1] - z[i]) * dx + b = y[i + 1] - y[i] - z[i] * dx + co[i, 0] = y[i] + co[i, 1] = z[i] + co[i, 2] = (3 * var[i + 2] - 2 * z[i] - z[i + 1]) / dx + co[i, 3] = (z[i] + z[i + 1] - 2 * var[i + 2]) / dx ** 2 + co[n - 1, 0] = y[n - 1] + co[n - 1, 1] = z[n - 1] + co[n - 1, 2] = 0 + co[n - 1, 3] = 0 + return co + + p_lst = [lambda x_, c=c, x0=x0: np.poly1d(c[::-1])(x_-x0) for c,x0 in zip(akima(xp,yp), xp)] + + def spline(x): + y = np.empty_like(x)+np.nan + segment = np.searchsorted(xp,x, 'right')-1 + for i in np.unique(segment): + m = segment==i + y[m] = p_lst[i](x[m]) + return y +# def coef2spline(x, xp, co): +# +# print (np.searchsorted(xp,x)-1) +# x, y = [], [] +# for i, c in enumerate(co.tolist()[:-1]): +# p = np.poly1d(c[::-1]) +# z = np.linspace(0, s[i + 1] - s[i ], 10, endpoint=i >= co.shape[0] - 2) +# x.extend(s[i] + z) +# y.extend(p(z)) +# return y +# + return spline + #x, y, z = [coef2spline(curve_z_nd, akima(curve_z_nd, self.c2def[:, i])) for i in range(3)] + #return x, y, z + +if __name__=="__main__": + import matplotlib.pyplot as plt + x = np.random.randint(0,100,10) + t = np.arange(0,100,10) + plt.plot(t,x,'.',label='points') + + t_ = np.arange(100) + spline = spline_fit(t,x) + print (np.abs(np.diff(np.diff(np.interp(t_, t,x)))).max()) + print (np.abs(np.diff(np.diff(spline(t_)))).max()) + plt.plot(t_, np.interp(t_, t,x)) + plt.plot(t_, spline(t_),label='spline') + plt.show() diff --git a/wetb/signal/tests/test_fit.py b/wetb/signal/tests/test_fit.py index 859c0cb95c31111385db25728bab055c0add7653..92c2c269a7759a9bfaeea492ff5e111c3d8a215d 100644 --- a/wetb/signal/tests/test_fit.py +++ b/wetb/signal/tests/test_fit.py @@ -10,6 +10,7 @@ import os import unittest from wetb.signal.fit import fourier_fit from wetb.signal.error_measures import rms +from wetb.signal.fit import spline_fit tfp = os.path.join(os.path.dirname(__file__), 'test_files/') class TestFit(unittest.TestCase): @@ -143,6 +144,24 @@ class TestFit(unittest.TestCase): # plt.plot(fourier_fit.F2x(np.fft.fft(y) / len(y)), label='fft') # plt.legend() # plt.show() + + def test_spline(self): + + x = np.random.randint(0,100,10) + t = np.arange(0,100,10) + + t_ = np.arange(100) + spline = spline_fit(t,x) + acc_lin = np.diff(np.diff(np.interp(t_, t,x))) + acc_spline = np.diff(np.diff(spline(t_))) + self.assertLess(np.abs(acc_spline).max(), np.abs(acc_lin).max()) + if 0: + import matplotlib.pyplot as plt + plt.plot(t,x,'.',label='points') + plt.plot(t_, spline(t_),label='spline') + plt.legend() + plt.show() + if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.testName'] unittest.main() \ No newline at end of file