diff --git a/wetb/prepost/mplutils.py b/wetb/prepost/mplutils.py
index 8e8693c1de2890b03bde76ba796cea3c1eafea6f..08dcf1f990065ac67c932b4d1fc1de792c1ef374 100644
--- a/wetb/prepost/mplutils.py
+++ b/wetb/prepost/mplutils.py
@@ -121,6 +121,15 @@ def one_legend(*args, **kwargs):
     """First list all the axes as arguments. Any keyword arguments will be
     passed on to ax.legend(). Legend will be placed on the last axes that was
     passed as an argument.
+
+    Parameters
+    ----------
+
+    Returns
+    -------
+
+    legend
+
     """
     # or more general: not only simple line plots (bars, hist, ...)
     objs = []
@@ -271,7 +280,8 @@ def match_yticks(ax1, ax2, nr_ticks_forced=None, extend=False):
 
 def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
              NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15,
-             mark_peaks=False):
+             mark_peaks=False, xlabels=['frequency [Hz]', 'time [s]'],
+             ypos_peaks=[0.04, 0.9], ypos_peaks_delta=0.12):
     """
     Plot time series and the corresponding PSD of the channel.
 
@@ -301,7 +311,6 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
     """
 
     axes = axes.ravel()
-    ypos = [0.04, 0.90]
 
     for i, res in enumerate(results):
         time, data = res
@@ -332,14 +341,15 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
         if mark_peaks:
             axes[0] = peaks(axes[0], freqs[i0:i1], Pxx[i0:i1], fn_max=f1,
                             nr_peaks=nr_peaks, col_line=col[:1],
-                            ypos_delta=0.04, bbox_alpha=0.5, col_text='w',
-                            ypos_mean=ypos[i], min_h=min_h)
+                            ypos_delta=ypos_peaks_delta, bbox_alpha=0.5,
+                            ypos_mean=ypos_peaks[i], min_h=min_h, col_text='w')
         # plotting time series
         axes[1].plot(time, data, col, label=label, alpha=alpha)
 
     axes[0].set_yscale('log')
-    axes[0].set_xlabel('frequency [Hz]')
-    axes[1].set_xlabel('time [s]')
+    if isinstance(xlabels, list):
+        axes[0].set_xlabel(xlabels[0])
+        axes[1].set_xlabel(xlabels[1])
     for ax in axes:
         leg = ax.legend(loc='best', borderaxespad=0)
         # leg is None when no labels have been defined