Skip to content
Snippets Groups Projects
Commit f07746d2 authored by David Verelst's avatar David Verelst
Browse files

prepost.mplutils: simple helper function for creating a PSD plot

parent 996ef63b
No related branches found
No related tags found
No related merge requests found
...@@ -278,6 +278,38 @@ def match_yticks(ax1, ax2, nr_ticks_forced=None, extend=False): ...@@ -278,6 +278,38 @@ def match_yticks(ax1, ax2, nr_ticks_forced=None, extend=False):
return ax1, ax2 return ax1, ax2
def psd(ax, time, sig, nfft=None, res_param=250, f0=0, f1=None, nr_peaks=10,
min_h=15, mark_peaks=False, col='r-', label=None, alpha=1.0,
ypos_peaks=0.9, ypos_peaks_delta=0.12):
"""Only plot the psd on a given axis and optionally mark the peaks.
"""
sps = int(round(1.0/np.diff(time).mean(), 0))
if f1 is None:
f1 = sps/2.0
if nfft is None:
nfft = int(round(res_param * sps / (f1-f0), 0))
if nfft > len(sig):
nfft = len(sig)
# calculate the PSD
Pxx, freqs = mpl.mlab.psd(sig, NFFT=nfft, Fs=sps)
i0 = np.abs(freqs - f0).argmin()
i1 = np.abs(freqs - f1).argmin()
# plotting psd, marking peaks
ax.plot(freqs[i0:i1], Pxx[i0:i1], col, label=label, alpha=alpha)
if mark_peaks:
ax = peaks(ax, freqs[i0:i1], Pxx[i0:i1], fn_max=f1,
nr_peaks=nr_peaks, col_line=col[:1],
ypos_delta=ypos_peaks_delta, bbox_alpha=0.5,
ypos_mean=ypos_peaks, min_h=min_h, col_text='w')
return ax
def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'], def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15, NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15,
mark_peaks=False, xlabels=['frequency [Hz]', 'time [s]'], mark_peaks=False, xlabels=['frequency [Hz]', 'time [s]'],
...@@ -317,32 +349,15 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'], ...@@ -317,32 +349,15 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
label = labels[i] label = labels[i]
col = colors[i] col = colors[i]
alpha = alphas[i] alpha = alphas[i]
sps = int(round(1.0/np.diff(time).mean(), 0)) if isinstance(NFFT, list):
if f1 is None:
f1 = sps/2.0
if NFFT is None:
nfft = int(round(res_param * sps / (f1-f0), 0))
elif isinstance(NFFT, list):
nfft = NFFT[i] nfft = NFFT[i]
else: else:
nfft = NFFT nfft = NFFT
if nfft > len(data): axes[0] = psd(axes[0], time, data, nfft=nfft, res_param=res_param,
nfft = len(data) f0=f0, f1=f1, nr_peaks=nr_peaks, min_h=min_h,
mark_peaks=mark_peaks, col=col, label=label, alpha=alpha,
# calculate the PSD ypos_peaks=ypos_peaks, ypos_peaks_delta=ypos_peaks_delta)
Pxx, freqs = mpl.mlab.psd(data, NFFT=nfft, Fs=sps)
i0 = np.abs(freqs - f0).argmin()
i1 = np.abs(freqs - f1).argmin()
# plotting psd, marking peaks
axes[0].plot(freqs[i0:i1], Pxx[i0:i1], col, label=label, alpha=alpha)
if mark_peaks:
axes[0] = peaks(axes[0], freqs[i0:i1], Pxx[i0:i1], fn_max=f1,
nr_peaks=nr_peaks, col_line=col[:1],
ypos_delta=ypos_peaks_delta, bbox_alpha=0.5,
ypos_mean=ypos_peaks[i], min_h=min_h, col_text='w')
# plotting time series # plotting time series
axes[1].plot(time, data, col, label=label, alpha=alpha) axes[1].plot(time, data, col, label=label, alpha=alpha)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment