From f07746d2f46bca0a72b3df0444c905a52d234dc2 Mon Sep 17 00:00:00 2001 From: David Robert Verelst <dave@dtu.dk> Date: Wed, 12 Jul 2017 09:50:26 +0200 Subject: [PATCH] prepost.mplutils: simple helper function for creating a PSD plot --- wetb/prepost/mplutils.py | 61 +++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/wetb/prepost/mplutils.py b/wetb/prepost/mplutils.py index 08dcf1f9..96f7ae63 100644 --- a/wetb/prepost/mplutils.py +++ b/wetb/prepost/mplutils.py @@ -278,6 +278,38 @@ def match_yticks(ax1, ax2, nr_ticks_forced=None, extend=False): return ax1, ax2 +def psd(ax, time, sig, nfft=None, res_param=250, f0=0, f1=None, nr_peaks=10, + min_h=15, mark_peaks=False, col='r-', label=None, alpha=1.0, + ypos_peaks=0.9, ypos_peaks_delta=0.12): + """Only plot the psd on a given axis and optionally mark the peaks. + """ + + sps = int(round(1.0/np.diff(time).mean(), 0)) + if f1 is None: + f1 = sps/2.0 + + if nfft is None: + nfft = int(round(res_param * sps / (f1-f0), 0)) + if nfft > len(sig): + nfft = len(sig) + + # calculate the PSD + Pxx, freqs = mpl.mlab.psd(sig, NFFT=nfft, Fs=sps) + + i0 = np.abs(freqs - f0).argmin() + i1 = np.abs(freqs - f1).argmin() + + # plotting psd, marking peaks + ax.plot(freqs[i0:i1], Pxx[i0:i1], col, label=label, alpha=alpha) + if mark_peaks: + ax = peaks(ax, freqs[i0:i1], Pxx[i0:i1], fn_max=f1, + nr_peaks=nr_peaks, col_line=col[:1], + ypos_delta=ypos_peaks_delta, bbox_alpha=0.5, + ypos_mean=ypos_peaks, min_h=min_h, col_text='w') + + return ax + + def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'], NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15, mark_peaks=False, xlabels=['frequency [Hz]', 'time [s]'], @@ -317,32 +349,15 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'], label = labels[i] col = colors[i] alpha = alphas[i] - sps = int(round(1.0/np.diff(time).mean(), 0)) - if f1 is None: - f1 = sps/2.0 - - if NFFT is None: - nfft = int(round(res_param * sps / (f1-f0), 0)) - elif isinstance(NFFT, list): + if isinstance(NFFT, list): nfft = NFFT[i] else: nfft = NFFT - if nfft > len(data): - nfft = len(data) - - # calculate the PSD - Pxx, freqs = mpl.mlab.psd(data, NFFT=nfft, Fs=sps) - - i0 = np.abs(freqs - f0).argmin() - i1 = np.abs(freqs - f1).argmin() - - # plotting psd, marking peaks - axes[0].plot(freqs[i0:i1], Pxx[i0:i1], col, label=label, alpha=alpha) - if mark_peaks: - axes[0] = peaks(axes[0], freqs[i0:i1], Pxx[i0:i1], fn_max=f1, - nr_peaks=nr_peaks, col_line=col[:1], - ypos_delta=ypos_peaks_delta, bbox_alpha=0.5, - ypos_mean=ypos_peaks[i], min_h=min_h, col_text='w') + axes[0] = psd(axes[0], time, data, nfft=nfft, res_param=res_param, + f0=f0, f1=f1, nr_peaks=nr_peaks, min_h=min_h, + mark_peaks=mark_peaks, col=col, label=label, alpha=alpha, + ypos_peaks=ypos_peaks, ypos_peaks_delta=ypos_peaks_delta) + # plotting time series axes[1].plot(time, data, col, label=label, alpha=alpha) -- GitLab