Newer
Older
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 23 11:22:50 2011
@author: dave
"""
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import range
from builtins import int
from builtins import dict
from builtins import round
from future import standard_library
standard_library.install_aliases()
# external libraries
import numpy as np
import matplotlib as mpl
# use a headless backend
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigCanvas

David Verelst
committed
#import wafo
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def make_fig(nrows=1, ncols=1, figsize=(12,8), dpi=120):
"""
Equivalent function of pyplot.subplots(). The difference is that this one
is not interactive and is used with backend plotting only.
Parameters
----------
nrows=1, ncols=1, figsize=(12,8), dpi=120
Returns
-------
fig, canvas, axes
"""
return subplots(nrows=nrows, ncols=ncols, figsize=figsize, dpi=dpi)
def subplots(nrows=1, ncols=1, figsize=(12,8), dpi=120, num=0):
"""
Equivalent function of pyplot.subplots(). The difference is that this one
is not interactive and is used with backend plotting only.
Parameters
----------
nrows=1, ncols=1, figsize=(12,8), dpi=120
num : dummy variable for compatibility
Returns
-------
fig, axes
"""
fig = mpl.figure.Figure(figsize=figsize, dpi=dpi)
canvas = FigCanvas(fig)
fig.set_canvas(canvas)
axes = np.ndarray((nrows, ncols), dtype=np.object)
plt_nr = 1
for row in range(nrows):
for col in range(ncols):
axes[row,col] = fig.add_subplot(nrows, ncols, plt_nr)
plt_nr += 1
return fig, axes
def match_axis_ticks(ax1, ax2, ax1_format=None, ax2_format=None):
"""
Match ticks of ax2 to ax1
=========================
ax1_format: '%1.1f'
Parameters
----------
ax1, ax2
Returns
-------
ax1, ax2
"""
# match the ticks of ax2 to ax1
yticks1 = len(ax1.get_yticks())
ylim2 = ax2.get_ylim()
yticks2 = np.linspace(ylim2[0], ylim2[1], num=yticks1).tolist()
ax2.yaxis.set_ticks(yticks2)
# give the tick labels a given precision
if ax1_format:
majorFormatter = mpl.ticker.FormatStrFormatter(ax1_format)
ax1.yaxis.set_major_formatter(majorFormatter)
if ax2_format:
majorFormatter = mpl.ticker.FormatStrFormatter(ax2_format)
ax2.yaxis.set_major_formatter(majorFormatter)
return ax1, ax2
def one_legend(*args, **kwargs):
"""First list all the axes as arguments. Any keyword arguments will be
passed on to ax.legend(). Legend will be placed on the last axes that was
passed as an argument.
"""
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# or more general: not only simple line plots (bars, hist, ...)
objs = []
for ax in args:
objs += ax.get_legend_handles_labels()[0]
# objs = [ax.get_legend_handles_labels()[0] for ax in args]
labels = [obj.get_label() for obj in objs]
# place the legend on the last axes
leg = ax.legend(objs, labels, **kwargs)
return leg
def p4psd(ax, rpm_mean, p_max=17, y_pos_rel=0.25, color='g', ls='--',
col_text='w'):
"""
Add the P's on a PSD
fn_max is the maximum value on the plot (ax.xlim). This only works when
setting the xlim of the plot before calling p4psd.
Parameters
----------
ax
rpm_mean
p_max : int, default=17
y_pos_rel : int or list, default=0.25
"""
if isinstance(y_pos_rel, float) or isinstance(y_pos_rel, int):
y_pos_rel = [y_pos_rel]*p_max
f_min = ax.get_xlim()[0]
f_max = ax.get_xlim()[1]
# add the P's
bbox = dict(boxstyle="round", edgecolor=color, facecolor=color)
for i, p in enumerate(range(1, p_max)):
p_freq = p * rpm_mean / 60.0
if p_freq > f_max:
break
if p%3 == 0:
alpha=0.5
ax.axvline(x=p_freq, linewidth=1, color=color, alpha=0.7, ls=ls)
else:
alpha = 0.2
ax.axvline(x=p_freq, linewidth=1, color=color, alpha=0.7, ls=ls)
x = (p_freq - f_min) / (f_max - f_min)
y = y_pos_rel[i]
p_str = '%iP' % p
bbox['alpha'] = alpha
ax.text(x, y, p_str, fontsize=8, verticalalignment='bottom',
horizontalalignment='center', bbox=bbox, color=col_text,
transform=ax.transAxes)
return ax
def peaks(ax, freqs, Pxx, fn_max, min_h, nr_peaks=15, col_line='k',
ypos_mean=0.14, col_text='w', ypos_delta=0.06, bbox_alpha=0.5):
"""
indicate the peaks
"""
i_fn_max = np.abs(freqs - fn_max).argmin()
# ignore everything above fn_max
freqs = freqs[:i_fn_max]
Pxx = Pxx[:i_fn_max]
Pxx_log = 10.*np.log10(Pxx)
try:
pi = wafo.misc.findpeaks(Pxx_log, n=len(Pxx), min_h=min_h)
print('len Pxx', len(Pxx_log), 'nr of peaks:', len(pi))
except Exception as e:
print('len Pxx', len(Pxx_log))
print('*** wafo.misc.findpeaks FAILED ***')
print(e)
return ax
# only take the nr_peaks most significant heights
pi = pi[:nr_peaks]
# and sort them accoriding to frequency (or index)
pi.sort()
# mark the peaks with a circle
# ax.plot(freqs[pi], Pxx[:xlim][pi], 'o')
# and mark all peaks
switch = True
yrange_plot = Pxx_log.max() - Pxx_log.min()
for peak_nr, ii in enumerate(pi):
freq_peak = freqs[ii]
# Pxx_peak = Pxx_log[ii]
# take the average frequency and plot vertical line
ax.axvline(x=freq_peak, linewidth=1, color=col_line, alpha=0.6)
# and the value in a text box
textbox = '%2.2f' % freq_peak
if switch:
# locate at the min value (down the plot), but a little
# lower so it does not interfere with the plot itself
# if ax.set_yscale('log') True, set log values as coordinates!
text_ypos = Pxx_log.min() + yrange_plot*0.1
text_ypos = ypos_mean + ypos_delta
switch = False
else:
# put it a little lower than the max value so it does
# not mess with the title (up the plot)
text_ypos = Pxx_log.min() - yrange_plot*0.4
text_ypos = ypos_mean - ypos_delta
switch = True
# print('%2.2e %2.2e %2.2e' % (yrange_plot, Pxx[:xlim].max(), Pxx[:xlim].min())
# print peak, text_ypos
# print textbox
# print yrange_plot
xrel = freq_peak/fn_max
ax.text(xrel, text_ypos, textbox, fontsize=10, transform=ax.transAxes,
va='bottom', color=col_text, bbox=dict(boxstyle="round",
ec=col_line, fc=col_line, alpha=bbox_alpha,))
return ax
def match_yticks(ax1, ax2, nr_ticks_forced=None, extend=False):
"""
"""
if nr_ticks_forced is None:
nr_yticks1 = len(ax1.get_yticks())
else:
nr_yticks1 = nr_ticks_forced
ylim1 = ax1.get_ylim()
yticks1 = np.linspace(ylim1[0], ylim1[1], num=nr_yticks1).tolist()
ax1.yaxis.set_ticks(yticks1)
ylim2 = ax2.get_ylim()
yticks2 = np.linspace(ylim2[0], ylim2[1], num=nr_yticks1).tolist()
ax2.yaxis.set_ticks(yticks2)
if extend:
offset1 = (ylim1[1] - ylim1[0])*0.1
ax1.set_ylim(ylim1[0]-offset1, ylim1[1]+offset1)
offset2 = (ylim2[1] - ylim2[0])*0.1
ax2.set_ylim(ylim2[0]-offset2, ylim2[1]+offset2)
return ax1, ax2
def time_psd(results, labels, axes, alphas=[1.0, 0.7], colors=['k-', 'r-'],
NFFT=None, res_param=250, f0=0, f1=None, nr_peaks=10, min_h=15,
mark_peaks=True):
"""
Plot time series and the corresponding PSD of the channel.
The frequency range depends on the sample rate: fn_max = sps/2.
The number of points is: NFFT/2
Consequently, the frequency resolution is: NFFT/SPS (points/Hz)
len(data) > NFFT otherwise the results do not make any sense
res_param = NFFT*frequency_range*sps
good range for nondim_resolution is: 200-300
Frequency range is based on f0 and f1
for the PSD powers of 2 are faster, but the difference in performance with
any other random number fluctuates on average and is not that relevant.
Requires two axes!
With the PSD peaks this only works nicely for 2 results
Parameters
----------
results : list
list of time,data pairs (ndarrays)
"""
axes = axes.ravel()
ypos = [0.04, 0.90]
for i, res in enumerate(results):
time, data = res
label = labels[i]
col = colors[i]
alpha = alphas[i]
sps = int(round(1.0/np.diff(time).mean(), 0))
if f1 is None:
f1 = sps/2.0
if NFFT is None:
nfft = int(round(res_param * sps / (f1-f0), 0))
elif isinstance(NFFT, list):
nfft = NFFT[i]
else:
nfft = NFFT
if nfft > len(data):
nfft = len(data)
# calculate the PSD
Pxx, freqs = mpl.mlab.psd(data, NFFT=nfft, Fs=sps)
i0 = np.abs(freqs - f0).argmin()
i1 = np.abs(freqs - f1).argmin()
# plotting psd, marking peaks
axes[0].plot(freqs[i0:i1], Pxx[i0:i1], col, label=label, alpha=alpha)
if mark_peaks:
axes[0] = peaks(axes[0], freqs[i0:i1], Pxx[i0:i1], fn_max=f1,
nr_peaks=nr_peaks, col_line=col[:1],
ypos_delta=0.04, bbox_alpha=0.5, col_text='w',
ypos_mean=ypos[i], min_h=min_h)
# plotting time series
axes[1].plot(time, data, col, label=label, alpha=alpha)
axes[0].set_yscale('log')
axes[0].set_xlabel('frequency [Hz]')
axes[1].set_xlabel('time [s]')
for ax in axes:
leg = ax.legend(loc='best', borderaxespad=0)
# leg is None when no labels have been defined
if leg is not None:
leg.get_frame().set_alpha(0.7)
ax.grid(True)
return axes
if __name__ == '__main__':
pass