From f07746d2f46bca0a72b3df0444c905a52d234dc2 Mon Sep 17 00:00:00 2001
From: David Robert Verelst <dave@dtu.dk>
Date: Wed, 12 Jul 2017 09:50:26 +0200
Subject: [PATCH] prepost.mplutils: simple helper function for creating a PSD
 plot

---
 wetb/prepost/mplutils.py | 61 +++++++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 23 deletions(-)

diff --git a/wetb/prepost/mplutils.py b/wetb/prepost/mplutils.py
index 08dcf1f9..96f7ae63 100644
--- a/wetb/prepost/mplutils.py
+++ b/wetb/prepost/mplutils.py
@@ -278,6 +278,38 @@ def match_yticks(ax1, ax2, nr_ticks_forced=None, extend=False):
     return ax1, ax2
 
 
+def psd(ax, time, sig, nfft=None, res_param=250, f0=0, f1=None, nr_peaks=10,
+        min_h=15, mark_peaks=False, col='r-', label=None, alpha=1.0,
+        ypos_peaks=0.9, ypos_peaks_delta=0.12):
+    """Only plot the psd on a given axis and optionally mark the peaks.
+    """
+
+    sps = int(round(1.0/np.diff(time).mean(), 0))
+    if f1 is None:
+        f1 = sps/2.0
+
+    if nfft is None:
+        nfft = int(round(res_param * sps / (f1-f0), 0))
+    if nfft > len(sig):
+        nfft = len(sig)
+
+    # calculate the PSD
+    Pxx, freqs = mpl.mlab.psd(sig, NFFT=nfft, Fs=sps)
+
+    i0 = np.abs(freqs - f0).argmin()
+    i1 = np.abs(freqs - f1).argmin()
+
+    # plotting psd, marking peaks
+    ax.plot(freqs[i0:i1], Pxx[i0:i1], col, label=label, alpha=alpha)
+    if mark_peaks:
+        ax = peaks(ax, freqs[i0:i1], Pxx[i0:i1], fn_max=f1,
+                   nr_peaks=nr_peaks, col_line=col[:1],
+                   ypos_delta=ypos_peaks_delta, bbox_alpha=0.5,
+                   ypos_mean=ypos_peaks, min_h=min_h, col_text='w')
+
+    return ax
+
+
 def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
              NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15,
              mark_peaks=False, xlabels=['frequency [Hz]', 'time [s]'],
@@ -317,32 +349,15 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
         label = labels[i]
         col = colors[i]
         alpha = alphas[i]
-        sps = int(round(1.0/np.diff(time).mean(), 0))
-        if f1 is None:
-            f1 = sps/2.0
-
-        if NFFT is None:
-            nfft = int(round(res_param * sps / (f1-f0), 0))
-        elif isinstance(NFFT, list):
+        if isinstance(NFFT, list):
             nfft = NFFT[i]
         else:
             nfft = NFFT
-        if nfft > len(data):
-            nfft = len(data)
-
-        # calculate the PSD
-        Pxx, freqs = mpl.mlab.psd(data, NFFT=nfft, Fs=sps)
-
-        i0 = np.abs(freqs - f0).argmin()
-        i1 = np.abs(freqs - f1).argmin()
-
-        # plotting psd, marking peaks
-        axes[0].plot(freqs[i0:i1], Pxx[i0:i1], col, label=label, alpha=alpha)
-        if mark_peaks:
-            axes[0] = peaks(axes[0], freqs[i0:i1], Pxx[i0:i1], fn_max=f1,
-                            nr_peaks=nr_peaks, col_line=col[:1],
-                            ypos_delta=ypos_peaks_delta, bbox_alpha=0.5,
-                            ypos_mean=ypos_peaks[i], min_h=min_h, col_text='w')
+        axes[0] = psd(axes[0], time, data, nfft=nfft, res_param=res_param,
+                      f0=f0, f1=f1, nr_peaks=nr_peaks, min_h=min_h,
+                      mark_peaks=mark_peaks, col=col, label=label, alpha=alpha,
+                      ypos_peaks=ypos_peaks, ypos_peaks_delta=ypos_peaks_delta)
+
         # plotting time series
         axes[1].plot(time, data, col, label=label, alpha=alpha)
 
-- 
GitLab