diff --git a/wetb/prepost/mplutils.py b/wetb/prepost/mplutils.py index 8e8693c1de2890b03bde76ba796cea3c1eafea6f..08dcf1f990065ac67c932b4d1fc1de792c1ef374 100644 --- a/wetb/prepost/mplutils.py +++ b/wetb/prepost/mplutils.py @@ -121,6 +121,15 @@ def one_legend(*args, **kwargs): """First list all the axes as arguments. Any keyword arguments will be passed on to ax.legend(). Legend will be placed on the last axes that was passed as an argument. + + Parameters + ---------- + + Returns + ------- + + legend + """ # or more general: not only simple line plots (bars, hist, ...) objs = [] @@ -271,7 +280,8 @@ def match_yticks(ax1, ax2, nr_ticks_forced=None, extend=False): def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'], NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15, - mark_peaks=False): + mark_peaks=False, xlabels=['frequency [Hz]', 'time [s]'], + ypos_peaks=[0.04, 0.9], ypos_peaks_delta=0.12): """ Plot time series and the corresponding PSD of the channel. @@ -301,7 +311,6 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'], """ axes = axes.ravel() - ypos = [0.04, 0.90] for i, res in enumerate(results): time, data = res @@ -332,14 +341,15 @@ def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'], if mark_peaks: axes[0] = peaks(axes[0], freqs[i0:i1], Pxx[i0:i1], fn_max=f1, nr_peaks=nr_peaks, col_line=col[:1], - ypos_delta=0.04, bbox_alpha=0.5, col_text='w', - ypos_mean=ypos[i], min_h=min_h) + ypos_delta=ypos_peaks_delta, bbox_alpha=0.5, + ypos_mean=ypos_peaks[i], min_h=min_h, col_text='w') # plotting time series axes[1].plot(time, data, col, label=label, alpha=alpha) axes[0].set_yscale('log') - axes[0].set_xlabel('frequency [Hz]') - axes[1].set_xlabel('time [s]') + if isinstance(xlabels, list): + axes[0].set_xlabel(xlabels[0]) + axes[1].set_xlabel(xlabels[1]) for ax in axes: leg = ax.legend(loc='best', borderaxespad=0) # leg is None when no labels have been defined