Module 4‐Discrete Wavelet - igheyas/WaveletTransformation GitHub Wiki

# 📋 Week 4: Discrete Wavelet Transform (DWT) Basics

## 📝 Outline

Goals

From CWT to DWT: Dyadic Sampling

Scaling & Wavelet Functions

Approximation & Detail Coefficients

Perfect Reconstruction Formula

Example: Haar Wavelet

Exercises

import pywt

# Example 1-D input signal (any 1-D array-like works here)
f = [4.0, 6.0, 10.0, 12.0, 8.0, 6.0, 5.0, 5.0]

# Single-level Haar DWT: returns approximation (cA) and detail (cD) coefficients
cA, cD = pywt.dwt(f, 'haar')

# Inverse single-level DWT: rebuild the signal from the two coefficient arrays
f_rec = pywt.idwt(cA, cD, 'haar')

Dyadic Sampling

Here’s a quick snippet using PyWavelets to list all built-in discrete wavelets:
import pywt

# 1. One flat list containing every wavelet name usable with the DWT.
all_discrete = pywt.wavelist(kind='discrete')
print("Available discrete wavelets:\n", all_discrete)

# 2. (Optional) The same names, broken down family by family.
#    Families that have no discrete members are printed with an empty list.
print("\nWavelets by family:")
by_family = {
    family: pywt.wavelist(family=family, kind='discrete')
    for family in pywt.families()
}
for family, members in by_family.items():
    print(f"  {family:6s}: {members}")

Output

Below is a concise guide to the wavelet families you saw in your listing, with three parts:
  1. **Morlet is not a discrete wavelet.** When you run
pywt.wavelist(family=None, kind='discrete')  # all families, discrete (DWT) wavelets only

you only get orthogonal/biorthogonal (DWT) wavelets: ['haar', 'db1', …, 'db38', 'sym2', …, 'sym20', 'coif1', …, 'coif17', 'bior1.1', …, 'bior6.8', 'rbio1.1', …, 'rbio6.8', 'dmey']

Morlet ('morl') appears only in the continuous list:

pywt.wavelist(kind='continuous')
# ['cgau1', …, 'cgau8', 'cmor', 'fbsp', 'gaus1', …, 'gaus8', 'mexh', 'morl', 'shan']

# dwt_energy_reconstruction.py
"""Multi-level DWT demo: energy bookkeeping and reconstruction quality.

Builds a noisy linear up/down chirp, decomposes it with ``pywt.wavedec``,
compares time-domain and wavelet-domain energy, reconstructs it with
``pywt.waverec``, reports MSE/SNR, and plots original vs. reconstruction.
"""

import numpy as np
import matplotlib.pyplot as plt
import pywt
from scipy.signal import chirp

# 1) Create a non-stationary signal: two chirps + noise
fs = 1000                     # sampling rate (Hz)
T  = 1.0                      # duration (s)
t  = np.linspace(0, T, int(fs*T), endpoint=False)
rng = np.random.default_rng(0)  # seeded so the printed numbers are reproducible

signal = np.zeros_like(t)
half = len(t)//2
# up-chirp  50→200 Hz over first half
signal[:half] = chirp(t[:half], f0=50,  f1=200, t1=T/2, method='linear')
# down-chirp 200→50 Hz over second half.  chirp() measures time from 0, so the
# segment's time axis is shifted to start at 0; passing t[half:] unshifted
# would sweep the instantaneous frequency from 50 Hz down to -100 Hz instead.
signal[half:] = chirp(t[half:] - t[half], f0=200, f1=50, t1=T/2, method='linear')
# add white noise
signal += 0.2 * rng.standard_normal(len(t))

# 2) Estimate time-domain energy
E_time = np.sum(signal**2)
print(f"Time-domain energy: {E_time:.6f}")

# 3) Discrete Wavelet Transform (multi-level)
wavelet = 'db4'
max_level = pywt.dwt_max_level(len(signal), pywt.Wavelet(wavelet).dec_len)
level = min(4, max_level)     # choose e.g. 4 levels (<= max_level)

# 'symmetric' is wavedec's default extension mode, written out explicitly.
# It pads the signal at the boundaries, so it produces more coefficients
# than there are samples.
coeffs = pywt.wavedec(signal, wavelet, mode='symmetric', level=level)
cA, cDs = coeffs[0], coeffs[1:]  # cA = approximation @ level, cDs = list of detail arrays

# 4) Estimate wavelet-domain energy (sum of squares of all coeffs)
E_wav = np.sum(cA**2) + sum(np.sum(d**2) for d in cDs)
print(f"Wavelet-domain energy: {E_wav:.6f}")

# Parseval's identity (E_wav == E_time) holds exactly only for a
# non-redundant orthogonal transform (e.g. mode='periodization').
# With 'symmetric' padding the extra boundary coefficients typically make
# the mismatch much larger than rounding error.
print(f"Energy mismatch: {E_wav - E_time:.6e}")

# 5) Reconstruct signal from the coefficients (same mode as the decomposition)
signal_rec = pywt.waverec(coeffs, wavelet, mode='symmetric')
# waverec may return array slightly longer due to padding → truncate
signal_rec = signal_rec[:len(signal)]

# 6) Reconstruction performance metrics
err_energy = np.sum((signal - signal_rec)**2)
mse = np.mean((signal - signal_rec)**2)
# Guard the log: an exactly-zero reconstruction error would divide by zero.
snr = np.inf if err_energy == 0 else 10 * np.log10(np.sum(signal**2) / err_energy)
print(f"Reconstruction MSE: {mse:.6e}")
print(f"Reconstruction SNR: {snr:.2f} dB")

# 7) Optional: plot original vs reconstructed
plt.figure(figsize=(10, 3))
plt.plot(t, signal,     label='Original',    alpha=0.7)
plt.plot(t, signal_rec, label='Reconstructed', alpha=0.7)
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.title(f'DWT Reconstruction (wavelet={wavelet}, levels={level})')
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()

Output


import numpy as np
import pywt

# ... (steps 1&2: generate `signal` and compute E_time as before) ...
# NOTE: this snippet relies on `signal` and `E_time` from the previous script.

wavelet   = 'db4'
max_level = pywt.dwt_max_level(len(signal), pywt.Wavelet(wavelet).dec_len)
level     = min(4, max_level)

# 3) DWT in periodization mode: for an orthogonal wavelet such as db4 this
#    gives a non-redundant transform, so signal energy is preserved.
coeffs = pywt.wavedec(
    signal,
    wavelet,
    mode='periodization',    # <— needed for energy preservation
    level=level
)
# coeffs = [cA_level, cD_level, ..., cD1]

# Exact equality also requires len(signal) to be divisible by 2**level.
# In 'periodization' mode an odd-length input at any level is extended by
# one sample, and that extra sample contributes energy.
if len(signal) % 2**level:
    print(f"Note: len(signal)={len(signal)} is not divisible by 2**{level}; "
          "the energies will generally not match exactly.")

# 4) Wavelet-domain energy
E_wav = sum(np.sum(c**2) for c in coeffs)
print(f"Time-domain energy:   {E_time:.12f}")
print(f"Wavelet-domain energy: {E_wav:.12f}")
print(f"Energy mismatch:      {E_wav - E_time:.2e}")

# 5) Reconstruction (also periodization)
signal_rec = pywt.waverec(
    coeffs,
    wavelet,
    mode='periodization'
)[:len(signal)]

# now proceed with MSE/SNR etc.
err_energy = np.sum((signal - signal_rec)**2)
mse = np.mean((signal - signal_rec)**2)
# Guard the log: an exactly-zero reconstruction error would divide by zero.
snr = np.inf if err_energy == 0 else 10*np.log10(np.sum(signal**2)/err_energy)
print(f"Reconstruction MSE: {mse:.2e}")
print(f"Reconstruction SNR: {snr:.2f} dB")

Total number of coefficients

import pywt
import numpy as np

# --- expects `signal` to exist already, for instance: ---
# fs = 1000
# T  = 1.0
# t  = np.linspace(0, T, int(fs*T), endpoint=False)
# signal = np.sin(2*np.pi*50*t)  # example signal

# 1) Multilevel DWT, capped at 4 levels (or fewer if the signal is too short)
wavelet = 'db4'
filter_len = pywt.Wavelet(wavelet).dec_len
max_level = pywt.dwt_max_level(len(signal), filter_len)
level = min(4, max_level)
coeffs = pywt.wavedec(signal, wavelet, mode='periodization', level=level)

# 2) How many coefficient arrays came back (one approximation + `level` details)
n_arrays = len(coeffs)
print(f"Number of coefficient arrays (approx + details): {n_arrays}\n")

# 3) Shape of every array:
#    first entry  -> approximation at the coarsest level
#    the rest     -> detail arrays, coarsest level first, level 1 last
approx, *details = coeffs
print(f"Approximation  cA (level {level}) shape: {approx.shape}")
for lvl, cD in zip(range(level, 0, -1), details):
    print(f"Detail coeffs d{lvl} (level {lvl}) shape: {cD.shape}")

# Grand total of coefficients across all arrays
total_coeffs = sum(map(np.size, coeffs))
print(f"\nTotal number of DWT coefficients: {total_coeffs}")

Output

Orthogonal vs. Biorthogonal Wavelets

[Future Work](Future-Work)