Module 4‐Discrete Wavelet - igheyas/WaveletTransformation GitHub Wiki
# Week 4: Discrete Wavelet Transform (DWT) Basics

## 📝 Outline

- Goals
- From CWT to DWT: Dyadic Sampling
- Scaling & Wavelet Functions
- Approximation & Detail Coefficients
- Perfect Reconstruction Formula
- Example: Haar Wavelet
- Exercises
import numpy as np
import pywt

# Example input so the snippet runs on its own; substitute any 1-D signal for `f`.
t = np.linspace(0, 1, 256, endpoint=False)
f = np.sin(2 * np.pi * 8 * t)

# One level of Haar DWT: approximation (cA) and detail (cD) coefficients.
cA, cD = pywt.dwt(f, 'haar')
# Inverse single-level DWT from both coefficient sets.
f_rec = pywt.idwt(cA, cD, 'haar')
Dyadic Sampling
Here’s a quick snippet using PyWavelets to list all built-in discrete wavelets:
import pywt

# 1. Every discrete (DWT-capable) wavelet name, as one flat list
discrete_wavs = pywt.wavelist(kind='discrete')
print("Available discrete wavelets:\n", discrete_wavs)

# 2. (Optional) The same names broken down family by family;
#    families with no discrete members print an empty list.
print("\nWavelets by family:")
by_family = {
    family: pywt.wavelist(family=family, kind='discrete')
    for family in pywt.families()
}
for family, members in by_family.items():
    print(f" {family:6s}: {members}")
Output
Below is a concise guide to the wavelet families you saw in your listing, with three parts:
- **Morlet is not a discrete wavelet.** When you run
# lists only the wavelets usable with the DWT (orthogonal / biorthogonal families)
pywt.wavelist(kind='discrete')
you only get the orthogonal/biorthogonal wavelets that can be used with the DWT: haar, db1–db38, sym2–sym20, coif1–coif17, bior1.1–bior6.8, rbio1.1–rbio6.8 and dmey.
Morlet ('morl') appears only in the continuous list, returned by:
pywt.wavelist(kind='continuous')
# contains: cgau1…cgau8, cmor, fbsp, gaus1…gaus8, mexh, morl, shan
# dwt_energy_reconstruction.py
# Multi-level DWT of a noisy two-chirp signal: compare time-domain and
# wavelet-domain energy, reconstruct with waverec, report MSE/SNR and plot.
import numpy as np
import matplotlib.pyplot as plt
import pywt
from scipy.signal import chirp

# 1) Create a non-stationary signal: two chirps + noise
fs = 1000  # sampling rate (Hz)
T = 1.0    # duration (s)
t = np.linspace(0, T, int(fs * T), endpoint=False)
signal = np.zeros_like(t)
half = len(t) // 2
# up-chirp 50→200 Hz over first half
signal[:half] = chirp(t[:half], f0=50, f1=200, t1=T/2, method='linear')
# down-chirp 200→50 Hz over second half.
# chirp() is at f0 when its time argument is 0 and reaches f1 at t1, so the
# second segment's time axis is shifted to start at 0. Passing t[half:]
# unshifted (0.5…1 s) would start at 50 Hz and keep sweeping down through 0 Hz.
signal[half:] = chirp(t[half:] - t[half], f0=200, f1=50, t1=T/2, method='linear')
# add white noise (seeded generator so the printed numbers are reproducible)
rng = np.random.default_rng(0)
signal += 0.2 * rng.standard_normal(len(t))

# 2) Estimate time-domain energy
E_time = np.sum(signal**2)
print(f"Time-domain energy: {E_time:.6f}")

# 3) Discrete Wavelet Transform (multi-level), default mode='symmetric'
wavelet = 'db4'
max_level = pywt.dwt_max_level(len(signal), pywt.Wavelet(wavelet).dec_len)
level = min(4, max_level)  # choose e.g. 4 levels (<= max_level)
coeffs = pywt.wavedec(signal, wavelet, level=level)
# coeffs = [cA_level, cD_level, ..., cD1]
cA, cDs = coeffs[0], coeffs[1:]  # cA = approximation @ level, cDs = list of detail arrays

# 4) Estimate wavelet-domain energy (sum of squares of all coeffs)
E_wav = np.sum(cA**2) + sum(np.sum(d**2) for d in cDs)
print(f"Wavelet-domain energy: {E_wav:.6f}")
# The energies do NOT match exactly here: the default 'symmetric' mode extends
# the signal at both ends, so wavedec returns more coefficients than there are
# samples. Exact energy preservation needs an orthogonal wavelet used with
# mode='periodization' in both wavedec and waverec.
print(f"Energy mismatch: {E_wav - E_time:.6e}")

# 5) Reconstruct signal from the coefficients (same default mode)
signal_rec = pywt.waverec(coeffs, wavelet)
# waverec may return an array slightly longer than the input → truncate
signal_rec = signal_rec[:len(signal)]

# 6) Reconstruction performance metrics
err = signal - signal_rec
mse = np.mean(err**2)
err_energy = np.sum(err**2)
# An exactly-zero error (perfect reconstruction) means infinite SNR, not a
# divide-by-zero warning.
snr = 10 * np.log10(E_time / err_energy) if err_energy > 0 else np.inf
print(f"Reconstruction MSE: {mse:.6e}")
print(f"Reconstruction SNR: {snr:.2f} dB")

# 7) Optional: plot original vs reconstructed
plt.figure(figsize=(10, 3))
plt.plot(t, signal, label='Original', alpha=0.7)
plt.plot(t, signal_rec, label='Reconstructed', alpha=0.7)
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.title(f'DWT Reconstruction (wavelet={wavelet}, levels={level})')
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()
Output
import numpy as np
import pywt
# ... (steps 1&2: generate `signal` and compute E_time as before) ...
wavelet = 'db4'
max_level = pywt.dwt_max_level(len(signal), pywt.Wavelet(wavelet).dec_len)
level = min(4, max_level)
# Periodization is exactly energy-preserving for an orthogonal wavelet only
# when every intermediate approximation has even length, i.e. when
# len(signal) % 2**level == 0. An odd length N produces ceil(N/2) coefficients
# per band (one more coefficient than samples in total), which breaks the
# energy equality. Example: 1000 samples -> 1000 % 16 != 0, so use 3 levels.
while level > 1 and len(signal) % 2**level:
    level -= 1
print(f"Decomposition level: {level}")

# 3) DWT in periodization mode
coeffs = pywt.wavedec(
    signal,
    wavelet,
    mode='periodization',  # <— important for perfect energy preservation
    level=level,
)
# coeffs = [cA_level, cD_level, ..., cD1]

# 4) Wavelet-domain energy
E_wav = sum(np.sum(c**2) for c in coeffs)
print(f"Time-domain energy: {E_time:.12f}")
print(f"Wavelet-domain energy: {E_wav:.12f}")
print(f"Energy mismatch: {E_wav - E_time:.2e}")

# 5) Reconstruction (must use the same periodization mode)
signal_rec = pywt.waverec(
    coeffs,
    wavelet,
    mode='periodization',
)[:len(signal)]

# now proceed with MSE/SNR etc.
err = signal - signal_rec
mse = np.mean(err**2)
err_energy = np.sum(err**2)
# Exact reconstruction gives zero error -> report infinite SNR instead of
# dividing by zero.
snr = 10 * np.log10(np.sum(signal**2) / err_energy) if err_energy > 0 else np.inf
print(f"Reconstruction MSE: {mse:.2e}")
print(f"Reconstruction SNR: {snr:.2f} dB")
### Total number of coefficients
import pywt
import numpy as np
# --- assume you already have `signal` defined, e.g.: ---
# fs = 1000
# T = 1.0
# t = np.linspace(0, T, int(fs*T), endpoint=False)
# signal = np.sin(2*np.pi*50*t)  # example signal

# 1) Multilevel DWT: db4, periodization, at most 4 levels
wavelet = 'db4'
filter_len = pywt.Wavelet(wavelet).dec_len
max_level = pywt.dwt_max_level(len(signal), filter_len)
level = min(4, max_level)
coeffs = pywt.wavedec(signal, wavelet, mode='periodization', level=level)

# 2) wavedec returns level + 1 arrays: [cA_level, cD_level, ..., cD1]
print(f"Number of coefficient arrays (approx + details): {len(coeffs)}\n")

# 3) Shape of every array, coarsest level first
approx, *details = coeffs
print(f"Approximation cA (level {level}) shape: {approx.shape}")
for lvl, detail in zip(range(level, 0, -1), details):
    print(f"Detail coeffs d{lvl} (level {lvl}) shape: {detail.shape}")

# Total coefficient count across all arrays
total_coeffs = sum(map(np.size, coeffs))
print(f"\nTotal number of DWT coefficients: {total_coeffs}")
Output
Orthogonal vs. Biorthogonal Wavelets
[Future Works](Future Work)