# Wavelet transform visualization Python script (CoBrALab/documentation GitHub wiki)

# Sampling and frequency-band settings for the wavelet analysis.
dt = 1.0  # sampling interval, in seconds
N = all_timeseries.shape[1]  # number of time points (axis 1 of all_timeseries)
t = dt * np.arange(N)  # time axis in seconds: 0, dt, 2*dt, ...
high_f = 0.2  # upper frequency cutoff, in Hz
low_f = 0.01  # lower frequency cutoff, in Hz

import pycwt
mother = pycwt.Morlet(6) #setting a Morlet wavelet with omega=6, which sets the time-frequency resolution of the wavelet and is recommended in Chang and Glover
s0 = 1/high_f
dj = 1 / 12  # Twelve sub-octaves per octaves
num_powers=0
while((1/high_f)*(2**num_powers)<1/low_f):
    num_powers+=1
J = int(num_powers / dj)

# Pick two example timeseries from all_timeseries. Axis 1 is time (see N
# above). The meaning of axes 0 and 2 is not visible here; presumably they
# are scan and region. TODO confirm.
y1 = all_timeseries[0, :, 20]
y2 = all_timeseries[0, :, 10]

# Quick visual check of samples 30-99 of both series on the current figure.
plt.plot(y1[30:100])
plt.plot(y2[30:100])

# Continuous wavelet transform of y1 on the scale grid (s0, dj, J) defined above.
W, scales, freq, coi, fft, fftfreqs = pycwt.cwt(
    y1, dt,
    dj=dj, s0=s0, J=J,
    wavelet=mother,
    freqs=None,
)

# Cross-wavelet transform of y1 and y2 on the same scale grid.
# NOTE(review): `wavelet_analysis` is not imported anywhere in this script.
# It is presumably a project-local module; confirm it is available.
# This call also rebinds `coi` and `freq`, so the plotting code below uses
# these values even when it displays W.
xWT, coi, freq = wavelet_analysis.cross_wavelet_transform(
    y1, y2, dt,
    dj=dj, s0=s0, J=J,
    wavelet=mother,
    normalize=True,
)

# Quantity to display. Keep exactly one `wave` assignment active:
#   - wavelet amplitude of y1 (active option)
wave = np.abs(W)
#   - cross-wavelet amplitude of y1 and y2
#wave=np.abs(xWT)
#   - cosine of the cross-wavelet phase (signed, so a diverging map suits it)
#wave=np.cos(np.angle(xWT))

# Colormap. 'viridis' suits amplitudes. 'coolwarm' switches the colour
# limits below to a range symmetric about zero.
cmap = 'viridis'
#cmap='coolwarm'

# Express the cone of influence on the frequency axis. `coi` is treated as a
# period, so its reciprocal is comparable with `freq`.
freqs_coi = 1/coi
# Clamp the COI curve to the frequency range covered by the wavelet map, so
# the shaded region drawn later stays inside the plotted extent. np.clip gives
# the same result as the former element-wise masking loop, including leaving
# NaNs untouched.
freqs_coi = np.clip(freqs_coi, np.min(freq), np.max(freq))

# Colour limits. Use a range symmetric about zero for the diverging
# 'coolwarm' map; otherwise span 0 up to the largest displayed value.
if cmap == 'coolwarm':
    vmax = np.max(np.abs(wave))
    vmin = -vmax
else:
    vmax = np.max(wave)
    vmin = 0

fig, ax = plt.subplots(1, 1)
# Draw on `ax` explicitly rather than through pyplot's implicit current axes.
# shading='auto' centres each cell on its (t, freq) sample when the grids
# match wave's shape. Under the Matplotlib 3.3-3.4 default ('flat'), such
# same-size grids drop the last row and column with a deprecation warning.
mesh = ax.pcolormesh(t, freq, wave, cmap=cmap, vmax=vmax, vmin=vmin,
                     shading='auto')
cbar = fig.colorbar(mesh, ax=ax)
# Hatch the area below the cone-of-influence curve, down to the lowest
# plotted frequency.
ax.fill_between(t, freqs_coi, np.min(freq), color='k', alpha=0.7, hatch='x')
# Log-2 frequency axis. `base` replaces the `basey` keyword, which Matplotlib
# deprecated in 3.3 and removed in 3.5.
ax.set_yscale("log", base=2)
# Tick positions chosen for the 0.01-0.2 Hz band and an approximately 300 s
# recording. Adjust these if the band or the recording length changes.
ax.set_yticks([0.1, 0.05, 0.02, 0.01])
ax.set_yticklabels([0.1, 0.05, 0.02, 0.01], fontsize=15)
ax.set_xticks([0, 100, 200, 300])
ax.set_xticklabels([0, 100, 200, 300], fontsize=15)
ax.set_ylabel("Frequency (Hz)", fontsize=20)
ax.set_xlabel("Time (sec)", fontsize=20)
fig.tight_layout()