Future work - igheyas/WaveletTransformation GitHub Wiki
```python
import numpy as np
import pywt
import tensorflow as tf


class WaveletConv2D(tf.keras.layers.Layer):
    """Convolve the input with a fixed 2D low-pass kernel built from one wavelet."""

    def __init__(self, wavelet_name, trainable=False, **kwargs):
        super().__init__(trainable=trainable, **kwargs)
        self.wavelet_name = wavelet_name

    def build(self, input_shape):
        # fetch the PyWavelets object (discrete wavelets only)
        w = pywt.Wavelet(self.wavelet_name)
        # here we use the low-pass decomposition filter
        lp = np.array(w.dec_lo, dtype=np.float32)
        # form a 2D separable kernel
        filt2d = np.outer(lp, lp)  # shape (k, k)
        # expand to (k, k, in_channels, 1): one output map per wavelet
        in_ch = int(input_shape[-1])
        kernel = np.tile(filt2d[:, :, None, None], (1, 1, in_ch, 1))
        # store as a layer weight so it can be frozen or fine-tuned
        self.kernel = self.add_weight(
            name=f'{self.wavelet_name}_kernel',
            shape=kernel.shape,
            initializer=tf.constant_initializer(kernel),
            trainable=self.trainable,
        )

    def call(self, inputs):
        # plain 2D convolution; SAME padding keeps the H×W size unchanged
        return tf.nn.conv2d(inputs, self.kernel, strides=1, padding='SAME')


def wavelet_block(x, wavelet_names, trainable=False):
    """Apply one fixed-wavelet conv for each name, then concatenate + ReLU."""
    convs = [WaveletConv2D(name, trainable=trainable)(x) for name in wavelet_names]
    x = tf.keras.layers.Concatenate()(convs)
    return tf.keras.layers.ReLU()(x)


# 1) load your bird/fish data (one sub-folder per class inside 'training')
train_ds = tf.keras.utils.image_dataset_from_directory(
    'training',
    labels='inferred',
    label_mode='binary',
    image_size=(128, 128),
    batch_size=32,
    color_mode='grayscale',
)

# 2) build the model
inputs = tf.keras.Input(shape=(128, 128, 1))
x = tf.keras.layers.Rescaling(1.0 / 255)(inputs)  # pixels 0..255 -> 0..1

# layer 1: order-1 discrete wavelets
# (symlets start at sym2, so 'sym1' does not exist; 'gaus'/'cgau' are
#  continuous wavelets with no dec_lo filter, so pywt.Wavelet() rejects them)
layer1 = ['db1', 'bior1.1', 'coif1', 'rbio1.1']
x = wavelet_block(x, layer1, trainable=False)

# layer 2: order-2 discrete wavelets
layer2 = ['db2', 'bior1.3', 'coif2', 'sym2', 'rbio1.3']
x = wavelet_block(x, layer2, trainable=False)

# classification head
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# 3) train!
model.fit(train_ds, epochs=10)
```
A few notes:
- Freezing vs. fine-tuning: the code above freezes all wavelet filters (trainable=False), so the first two layers act as a fixed wavelet filter bank. To let the network adapt them, set trainable=True.
- Varying kernel sizes: higher-order wavelets have longer filters (db1 has 2 taps, db2 has 4). The code above does not pad the kernels. Because every branch uses padding='SAME', each output keeps the input's H×W, so the branches can still be concatenated. Alternatively, zero-pad all kernels to a common size and stack them into a single Conv2D.
- Extensions: you can stack as many orders as you like (db3, db4, …), combine low-pass and high-pass filters (dec_lo and dec_hi), or switch to a full scattering transform via Kymatio for an off-the-shelf wavelet-CNN hybrid.
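The padding option from the notes can be combined with an early check on which names PyWavelets accepts. The sketch below assumes only NumPy and PyWavelets, and its helper `padded_wavelet_bank` is hypothetical. It rejects any name missing from `pywt.wavelist(kind='discrete')`, zero-pads every 2D low-pass kernel to the largest size, and stacks the results in Conv2D layout.

```python
import numpy as np
import pywt


def padded_wavelet_bank(names):
    """Hypothetical helper: stack zero-padded 2D low-pass kernels.

    Returns an array of shape (K, K, 1, len(names)), where K is the
    longest filter length among the requested wavelets.
    """
    discrete = set(pywt.wavelist(kind='discrete'))
    bad = [n for n in names if n not in discrete]
    if bad:
        raise ValueError(f"not discrete PyWavelets names: {bad}")

    kernels = []
    for n in names:
        lp = np.asarray(pywt.Wavelet(n).dec_lo, dtype=np.float32)
        kernels.append(np.outer(lp, lp))

    K = max(k.shape[0] for k in kernels)
    padded = []
    for k in kernels:
        extra = K - k.shape[0]
        # split the zero padding as evenly as possible on both sides
        padded.append(np.pad(k, (extra // 2, extra - extra // 2)))
    return np.stack(padded, axis=-1)[:, :, None, :]


bank = padded_wavelet_bank(['db1', 'db2', 'coif1'])
print(bank.shape)  # (6, 6, 1, 3)
```

With this check, a name such as 'sym1' or 'gaus1' raises a ValueError immediately. Without it, the error would only surface inside the layer's build().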
In short, you can build a “wavelet bank” CNN with several convolutional layers, each drawing on different wavelet families and orders, and train it on your bird vs. fish images.
Here’s a step‑by‑step recipe for extracting the Daubechies‑3 (db3) low‑pass filter and turning it into a 2D convolutional kernel in Python using PyWavelets:
- Install PyWavelets if you haven’t already:

```bash
pip install PyWavelets
```

- Import the libraries:

```python
import pywt
import numpy as np
```

- Load the db3 wavelet:

```python
w = pywt.Wavelet('db3')
```

- Inspect the 1D filter coefficients: w.dec_lo is the analysis low-pass filter (length 6 for db3), and w.dec_hi is the analysis high-pass filter.
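Putting the steps together and building the 2D separable low-pass kernel with np.outer:

```python
import numpy as np
import pywt

w = pywt.Wavelet('db3')
lp = np.array(w.dec_lo, dtype=np.float32)  # analysis low-pass, 6 taps
hp = np.array(w.dec_hi, dtype=np.float32)  # analysis high-pass, 6 taps
print("Low-pass (dec_lo):", lp)
print("High-pass (dec_hi):", hp)

# 2D separable low-pass kernel: filt2d[i, j] = lp[i] * lp[j]
filt2d = np.outer(lp, lp)  # shape (6, 6)
print("2D low-pass kernel shape:", filt2d.shape)
```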
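The printed coefficients can be sanity-checked with a small sketch that uses only NumPy and PyWavelets. It verifies standard properties of an orthogonal wavelet filter pair:

- the low-pass taps sum to √2 and have unit energy;
- the high-pass taps sum to zero;
- the two filters are orthogonal to each other.

The separable 2D kernel therefore sums to (√2)² = 2.

```python
import numpy as np
import pywt

w = pywt.Wavelet('db3')
lp = np.array(w.dec_lo)  # keep float64 for tight checks
hp = np.array(w.dec_hi)

print(len(lp), len(hp))   # 6 6
print(lp.sum())           # ≈ 1.4142 (sqrt(2)): constant signals pass through
print((lp ** 2).sum())    # ≈ 1.0: unit energy
print(hp.sum())           # ≈ 0.0: constant signals are blocked
print(np.dot(lp, hp))     # ≈ 0.0: the two filters are orthogonal

filt2d = np.outer(lp, lp)
print(filt2d.sum())       # ≈ 2.0 = sqrt(2) ** 2
```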
Summary
- Fetch the db3 analysis filters via pywt.Wavelet('db3').
- Pick dec_lo (low-pass) or dec_hi (high-pass) as your 1D vector.
- Build a 2D separable kernel with np.outer.
- Pad or reshape it to your framework’s Conv2D kernel layout (in TensorFlow, (k, k, in_channels, out_channels)), then initialise your layer with those fixed weights.
That’s all there is to it! You can repeat the same steps for any other discrete wavelet family and order (e.g. db4, sym2).
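The same recipe extends to all four 2D analysis kernels by mixing dec_lo and dec_hi in the outer product. The sketch below assumes you want them in TensorFlow's Conv2D layout (k, k, in_channels, out_channels) with one input channel. The function name `wavelet_kernels_2d` is hypothetical, and the LH/HL labels follow one of several common conventions.

```python
import numpy as np
import pywt


def wavelet_kernels_2d(name):
    """Hypothetical helper: the four separable 2D analysis kernels of a
    discrete wavelet, stacked in Conv2D layout (k, k, 1, 4)."""
    w = pywt.Wavelet(name)
    lo = np.array(w.dec_lo, dtype=np.float32)
    hi = np.array(w.dec_hi, dtype=np.float32)
    # np.outer(a, b)[i, j] = a[i] * b[j]: a acts along axis 0, b along axis 1
    bank = [np.outer(lo, lo),   # LL: low-pass in both directions
            np.outer(lo, hi),   # LH: low along axis 0, high along axis 1
            np.outer(hi, lo),   # HL: high along axis 0, low along axis 1
            np.outer(hi, hi)]   # HH: high-pass in both directions
    return np.stack(bank, axis=-1)[:, :, None, :]


bank4 = wavelet_kernels_2d('db3')
print(bank4.shape)  # (6, 6, 1, 4)
```

For inputs with more channels, tile along axis 2 as the WaveletConv2D layer does. Conv2D layers compute a cross-correlation, so these kernels are applied without the flip a textbook convolution would use. For low-pass or energy-style features this rarely matters.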
You can’t simply reshape a 6×6 Daubechies-3 kernel into a 2×2 one: the element counts differ, and any way of shrinking it changes its frequency response. You can, however, approximate or downsample it to a 2-tap filter and then build a 2×2 separable kernel. Here are two simple strategies:
- Decimation (straight subsampling): keep every 3rd coefficient of the 1D low-pass filter to get 2 taps:

```python
import numpy as np
import pywt

# 1D db3 low-pass filter
w = pywt.Wavelet('db3')
lp = np.array(w.dec_lo, dtype=np.float32)  # length 6

# subsample: keep the taps at positions 0 and 3
lp2 = lp[::3]  # [lp[0], lp[3]] -> length 2
print("Subsampled taps:", lp2)

# make the 2×2 separable kernel
filt2x2 = np.outer(lp2, lp2)
print("2×2 kernel:\n", filt2x2)
```
This gives a very crude approximation of db3, since you’re only keeping two of its six coefficients.
- Grouped averaging: average the 6 taps in two groups of three, so that all six coefficients contribute (the sum of the taps, and hence the DC gain, is scaled by exactly 1/3):

```python
# continues from the snippet above (lp = db3 dec_lo, length 6)
# group into lp[0:3] and lp[3:6], then mean-pool each group
g1 = lp[0:3].mean()
g2 = lp[3:6].mean()
lp2_avg = np.array([g1, g2], dtype=np.float32)
print("Averaged taps:", lp2_avg)

# build the 2×2 kernel
filt2x2_avg = np.outer(lp2_avg, lp2_avg)
print("2×2 averaged kernel:\n", filt2x2_avg)
```
Caveat
- Both methods heavily alter the wavelet’s pass-band characteristics.
- If you truly need a 2×2 low-pass filter, the canonical choice is Haar (db1), whose dec_lo already has length 2.
- Approximating a longer wavelet with a shorter kernel always trades fidelity for size.
Pick the approach (decimation, averaging, or simply switching to a 2‑tap wavelet) that best matches your application’s accuracy vs. performance needs.
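To make the caveat concrete, this sketch compares the magnitude response of each candidate filter at DC (ω = 0) and at Nyquist (ω = π). The two values are simply |Σ h[n]| and |Σ (−1)ⁿ h[n]|. It uses only NumPy and PyWavelets, and the Nyquist-to-DC ratio is just an illustrative summary: 0 means the highest frequency is fully blocked.

```python
import numpy as np
import pywt


def dc_and_nyquist(h):
    """Magnitude response |H| at DC (w=0) and at Nyquist (w=pi)."""
    h = np.asarray(h, dtype=np.float64)
    n = np.arange(len(h))
    return abs(h.sum()), abs((h * (-1.0) ** n).sum())


lp = np.array(pywt.Wavelet('db3').dec_lo)
candidates = {
    'db3 (6 taps)': lp,
    'decimated lp[::3]': lp[::3],
    'averaged groups': np.array([lp[0:3].mean(), lp[3:6].mean()]),
    'haar (db1)': np.array(pywt.Wavelet('db1').dec_lo),
}
for name, h in candidates.items():
    dc, nyq = dc_and_nyquist(h)
    print(f"{name:18s} DC={dc:.4f}  Nyquist={nyq:.4f}  ratio={nyq / dc:.3f}")
```

db3 and Haar both have a zero at Nyquist. Neither 2-tap approximation keeps it, and the averaged taps respond more strongly at Nyquist than at DC. This is one concrete reason to prefer Haar when you need a genuine 2-tap low-pass filter.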
Yes, wavelet transforms can give your CNN richer, multiscale feature maps in several ways:
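- Precompute wavelet subbands as input channels:

```python
import numpy as np
import pywt
from tensorflow.keras import layers, models


def make_wavelet_channels(img, wavelet='db1', levels=2):
    """Stack the multilevel DWT subbands of one image as channels.

    img: H×W grayscale float32 array, with H and W divisible by 2**levels.
    """
    coeffs = pywt.wavedec2(img, wavelet, level=levels, mode='periodization')
    cA, details = coeffs[0], coeffs[1:]
    # collect channels: approximation + (horizontal, vertical, diagonal) per level
    chs = [cA]
    for (cH, cV, cD) in details:
        chs += [cH, cV, cD]
    # PyWavelets has no resize function, so bring every band back to H×W
    # by nearest-neighbour repetition (periodization halves sizes exactly)
    chs_up = []
    for c in chs:
        fy, fx = img.shape[0] // c.shape[0], img.shape[1] // c.shape[1]
        chs_up.append(np.repeat(np.repeat(c, fy, axis=0), fx, axis=1))
    return np.stack(chs_up, axis=-1)  # shape (H, W, 1 + 3*levels)


# placeholder data: replace with your own images (N, H, W) and integer labels (N,)
H, W, num_classes = 128, 128, 2
images = np.random.rand(8, H, W).astype('float32')
y = np.random.randint(0, num_classes, size=8)

X_wave = np.stack([make_wavelet_channels(img, wavelet='db2', levels=2)
                   for img in images])  # shape (N, H, W, 7)

# build a simple CNN that ingests 7 channels (1 + 3*2)
model = models.Sequential([
    layers.Input((H, W, X_wave.shape[-1])),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPool2D(),
    layers.Conv2D(64, 3, activation='relu'),
    layers.GlobalAvgPool2D(),
    layers.Dense(num_classes, activation='softmax'),
])

# compile & train as usual (keyword arguments work across Keras versions)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(X_wave, y, epochs=10)
```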
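The subbands have to be resized before they can be stacked as channels because they come out at different resolutions. This sketch prints the shapes `pywt.wavedec2` returns for a 128×128 image with db2, two levels and periodization mode; the image is a placeholder array of zeros. Each level halves both dimensions, and the coefficient list runs from the coarsest level to the finest.

```python
import numpy as np
import pywt

img = np.zeros((128, 128), dtype=np.float32)  # placeholder image
coeffs = pywt.wavedec2(img, 'db2', level=2, mode='periodization')

# coeffs = [cA2, (cH2, cV2, cD2), (cH1, cV1, cD1)]
print("cA2:", coeffs[0].shape)
for level, (cH, cV, cD) in zip((2, 1), coeffs[1:]):
    print(f"level {level} details:", cH.shape, cV.shape, cD.shape)
```

With this layout, the level-2 bands (32×32) must be upsampled by 4, and the level-1 bands (64×64) by 2, to match the 128×128 input.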
- Use fixed wavelet filters in the first convolutional layer: you can initialise the first Conv2D layer’s kernels with 2D wavelet filters (e.g. Haar, Daubechies) and then freeze them:

```python
import numpy as np
from tensorflow.keras import initializers, layers, models

H, W = 128, 128  # placeholder input size; use your image size

# 2×2 Haar analysis filters
haar_ll = np.array([[1, 1], [1, 1]]) / 2
haar_lh = np.array([[1, 1], [-1, -1]]) / 2
haar_hl = np.array([[1, -1], [1, -1]]) / 2
haar_hh = np.array([[1, -1], [-1, 1]]) / 2

kernels = np.stack([haar_ll, haar_lh, haar_hl, haar_hh], axis=-1)  # (2, 2, 4)
# Conv2D layout: (2, 2, in_channels=1, out_channels=4)
kernels = kernels.reshape(2, 2, 1, 4).astype('float32')

model = models.Sequential([
    layers.Input((H, W, 1)),
    layers.Conv2D(
        filters=4,
        kernel_size=2,
        # stride 1 gives an undecimated transform; strides=2 with
        # padding='valid' gives one level of the usual decimated Haar DWT
        padding='same',
        use_bias=False,
        kernel_initializer=initializers.Constant(kernels),
        trainable=False,  # keep the wavelet filters fixed
    ),
    # ... follow with your usual trainable Conv/Pooling layers etc.
])
```
These four feature maps (LL, LH, HL, HH) now take the place of learned first-layer activations, except that their filters are fixed wavelets.
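The fixed Haar filters really do compute a wavelet transform, which can be checked directly. With stride 2 instead of padding='same', the 2×2 kernels tile the image in non-overlapping blocks and reproduce one level of `pywt.dwt2` with the 'haar' wavelet. To keep it dependency-light, the sketch applies the kernels with NumPy instead of a Conv2D layer, and the helper `strided_2x2` is hypothetical. It checks the LL and HH bands. LH and HL correspond to PyWavelets' cH and cV, up to sign and naming conventions.

```python
import numpy as np
import pywt

haar_ll = np.array([[1, 1], [1, 1]]) / 2
haar_hh = np.array([[1, -1], [-1, 1]]) / 2


def strided_2x2(x, kernel):
    """Hypothetical helper: a 2x2 kernel applied with stride 2 ('valid')."""
    H, W = x.shape
    blocks = x.reshape(H // 2, 2, W // 2, 2)  # blocks[i, p, j, q] = x[2i+p, 2j+q]
    return np.einsum('ipjq,pq->ij', blocks, kernel)


rng = np.random.default_rng(0)
img = rng.random((8, 8))
cA, (cH, cV, cD) = pywt.dwt2(img, 'haar')

print(np.allclose(strided_2x2(img, haar_ll), cA))  # True
print(np.allclose(strided_2x2(img, haar_hh), cD))  # True
```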
- Wavelet scattering networks: a scattering network is a cascade of fixed wavelet, modulus and averaging operations whose output feeds a small learned network. Its features have been shown to be stable to translations and deformations, and they combine well with a shallow CNN head.
  - See Mallat’s Scattering Networks.
  - The Kymatio Python library provides ready-to-use scattering layers.
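To illustrate the wavelet → modulus → averaging cascade, here is a deliberately simplified NumPy-only toy. It is not Kymatio and not Mallat's construction: it uses the 2×2 Haar detail filters at a single scale, circular boundaries and global averaging. Real scattering transforms use Morlet wavelets over several scales and orientations, so use Kymatio's scattering layers for actual work. The names `circ_filter` and `toy_scattering` are hypothetical. The final check shows the translation invariance that motivates the approach.

```python
import numpy as np

# 2x2 Haar detail filters (as in the Conv2D example), single scale
PSI = {
    'lh': np.array([[1, 1], [-1, -1]]) / 2,
    'hl': np.array([[1, -1], [1, -1]]) / 2,
    'hh': np.array([[1, -1], [-1, 1]]) / 2,
}


def circ_filter(x, k):
    """Circular 2D cross-correlation of x with a small kernel k."""
    out = np.zeros_like(x, dtype=float)
    for p in range(k.shape[0]):
        for q in range(k.shape[1]):
            out += k[p, q] * np.roll(x, shift=(-p, -q), axis=(0, 1))
    return out


def toy_scattering(x):
    """Order-0, 1 and 2 features from wavelet -> modulus -> global average."""
    feats = [x.mean()]                       # order 0: plain average
    for k1 in PSI.values():
        u1 = np.abs(circ_filter(x, k1))      # wavelet + modulus
        feats.append(u1.mean())              # order 1
        for k2 in PSI.values():
            u2 = np.abs(circ_filter(u1, k2))
            feats.append(u2.mean())          # order 2
    return np.array(feats)                   # 1 + 3 + 9 = 13 features


rng = np.random.default_rng(0)
img = rng.random((32, 32))
shifted = np.roll(img, shift=(5, 7), axis=(0, 1))  # circular translation
print(np.allclose(toy_scattering(img), toy_scattering(shifted)))  # True
```

Global averaging makes the features exactly invariant to circular shifts but discards all spatial layout. Scattering networks instead average locally with a low-pass window, which trades some invariance for spatial detail.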
Summary
- Preprocessing: compute DWT subbands and stack them as a multi-channel input.
- Architectural: bake wavelet filters directly into the first Conv layer.
- Advanced: use a scattering transform (fixed wavelet + modulus + pooling) as your feature extractor.
All of these effectively use wavelets as feature maps inside a CNN.