From e0ae2e549eba65c7e6ec00829b576228cf1a8720 Mon Sep 17 00:00:00 2001
From: David Robert Verelst <dave@dtu.dk>
Date: Sun, 1 Jul 2018 13:20:39 +0200
Subject: [PATCH] prepost.mplutils: update psd peak detection printing and text
 box pos

---
 wetb/prepost/mplutils.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/wetb/prepost/mplutils.py b/wetb/prepost/mplutils.py
index 96f7ae6..94bda4e 100644
--- a/wetb/prepost/mplutils.py
+++ b/wetb/prepost/mplutils.py
@@ -50,7 +50,7 @@ def make_fig(nrows=1, ncols=1, figsize=(12,8), dpi=120):
     return subplots(nrows=nrows, ncols=ncols, figsize=figsize, dpi=dpi)
 
 
-def subplots(nrows=1, ncols=1, figsize=(12,8), dpi=120, num=0):
+def subplots(nrows=1, ncols=1, figsize=(12,8), dpi=120, num=0, subplot_kw={}):
     """
 
     Equivalent function of pyplot.subplots(). The difference is that this one
@@ -76,7 +76,7 @@ def subplots(nrows=1, ncols=1, figsize=(12,8), dpi=120, num=0):
     plt_nr = 1
     for row in range(nrows):
         for col in range(ncols):
-            axes[row,col] = fig.add_subplot(nrows, ncols, plt_nr)
+            axes[row,col] = fig.add_subplot(nrows, ncols, plt_nr, **subplot_kw)
             plt_nr += 1
     return fig, axes
 
@@ -193,7 +193,8 @@ def p4psd(ax, rpm_mean, p_max=17, y_pos_rel=0.25, color='g', ls='--',
 
 
 def peaks(ax, freqs, Pxx, fn_max, min_h, nr_peaks=15, col_line='k',
-          ypos_mean=0.14, col_text='w', ypos_delta=0.06, bbox_alpha=0.5):
+          ypos_mean=0.14, col_text='w', ypos_delta=0.06, bbox_alpha=0.5,
+          verbose=False):
     """
     indicate the peaks
     """
@@ -204,11 +205,13 @@ def peaks(ax, freqs, Pxx, fn_max, min_h, nr_peaks=15, col_line='k',
     Pxx_log = 10.*np.log10(Pxx)
     try:
         pi = wafo.misc.findpeaks(Pxx_log, n=len(Pxx), min_h=min_h)
-        print('len Pxx', len(Pxx_log), 'nr of peaks:', len(pi))
+        if verbose:
+            print('len Pxx', len(Pxx_log), 'nr of peaks:', len(pi))
     except Exception as e:
-        print('len Pxx', len(Pxx_log))
-        print('*** wafo.misc.findpeaks FAILED ***')
-        print(e)
+        if verbose:
+            print('len Pxx', len(Pxx_log))
+            print('*** wafo.misc.findpeaks FAILED ***')
+            print(e)
         return ax
 
     # only take the nr_peaks most significant heights
@@ -220,7 +223,7 @@ def peaks(ax, freqs, Pxx, fn_max, min_h, nr_peaks=15, col_line='k',
 #    ax.plot(freqs[pi], Pxx[:xlim][pi], 'o')
     # and mark all peaks
     switch = True
-    yrange_plot = Pxx_log.max() - Pxx_log.min()
+    # yrange_plot = Pxx_log.max() - Pxx_log.min()
     for peak_nr, ii in enumerate(pi):
         freq_peak = freqs[ii]
 #        Pxx_peak = Pxx_log[ii]
@@ -232,13 +235,13 @@ def peaks(ax, freqs, Pxx, fn_max, min_h, nr_peaks=15, col_line='k',
             # locate at the min value (down the plot), but a little
             # lower so it does not interfere with the plot itself
             # if ax.set_yscale('log') True, set log values as coordinates!
-            text_ypos = Pxx_log.min() + yrange_plot*0.1
+            # text_ypos = Pxx_log.min() + yrange_plot*0.1
             text_ypos = ypos_mean + ypos_delta
             switch = False
         else:
             # put it a little lower than the max value so it does
             # not mess with the title (up the plot)
-            text_ypos = Pxx_log.min() - yrange_plot*0.4
+            # text_ypos = Pxx_log.min() - yrange_plot*0.4
             text_ypos = ypos_mean - ypos_delta
             switch = True
 #        print('%2.2e %2.2e %2.2e' % (yrange_plot, Pxx[:xlim].max(), Pxx[:xlim].min())
@@ -313,7 +316,7 @@ def psd(ax, time, sig, nfft=None, res_param=250, f0=0, f1=None, nr_peaks=10,
 def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
              NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15,
              mark_peaks=False, xlabels=['frequency [Hz]', 'time [s]'],
-             ypos_peaks=[0.04, 0.9], ypos_peaks_delta=0.12):
+             ypos_peaks=[0.3, 0.9], ypos_peaks_delta=[0.12, 0.12]):
     """
     Plot time series and the corresponding PSD of the channel.
 
@@ -356,7 +359,8 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
         axes[0] = psd(axes[0], time, data, nfft=nfft, res_param=res_param,
                       f0=f0, f1=f1, nr_peaks=nr_peaks, min_h=min_h,
                       mark_peaks=mark_peaks, col=col, label=label, alpha=alpha,
-                      ypos_peaks=ypos_peaks, ypos_peaks_delta=ypos_peaks_delta)
+                      ypos_peaks_delta=ypos_peaks_delta[i],
+                      ypos_peaks=ypos_peaks[i])
 
         # plotting time series
         axes[1].plot(time, data, col, label=label, alpha=alpha)
-- 
GitLab