An Introduction to Triplet Loss and Unsupervised Learning - 180D-FW-2023/Knowledge-Base-Wiki GitHub Wiki

An Introduction to Triplet Loss and Unsupervised Learning

Overview

A traditional computer vision classifier learns by training on a dataset divided into categories, each with a distinct label. On each iteration of the training loop, the model makes predictions that are compared against the ground-truth labels. This is called supervised learning because every training example comes with a label that tells the model whether its prediction was correct.

Unsupervised learning, on the other hand, is when the model tries to learn the underlying patterns and structure in data without being explicitly told what to look for. Imagine you have a large pile of blocks of different colors and shapes, and you are asked to organize them. No one tells you how to do it, and there are no rules for sorting by color, shape, or size. You look at the blocks, start to notice patterns, and decide to group them in a way that makes sense to you: maybe all the red blocks together, the square blocks in another group, and so on. You are learning to categorize the blocks based on the features you observe. Most importantly, unsupervised training requires no labels, and labeling is often the most time-intensive step in creating a large dataset.

Triplet Loss

The triplet loss function takes three input examples: an anchor, a positive example, and a negative example.

  • Anchor: A reference example.
  • Positive: An example that is similar to the anchor.
  • Negative: An example that is different from the anchor.

Visual Representation of Triplet Loss (https://arxiv.org/abs/1503.03832)

The goal of the loss function is to ensure the distance between the anchor and the positive example is smaller than the distance between the anchor and the negative example by at least a margin α. The triplet loss equation for a triplet (anchor a, positive p, and negative n) is:

L(a, p, n) = max(d(a, p) − d(a, n) + α, 0)

where d(x, y) is the distance between the embeddings of x and y (FaceNet, and the code later on this page, use the squared Euclidean distance), and α is the margin enforced between positive and negative pairs.
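
A small worked example makes the margin concrete. The sketch below evaluates the equation for single triplets of hypothetical 2-D embeddings, taking d to be the squared Euclidean distance; `triplet_loss_single` is an illustrative name, not a library function.

```python
import numpy as np

def triplet_loss_single(a, p, n, alpha=0.2):
    """L(a, p, n) = max(d(a, p) - d(a, n) + alpha, 0) for one triplet.

    Assumption: d is the squared Euclidean distance between embeddings.
    """
    d_ap = np.sum((a - p) ** 2)  # anchor-to-positive distance
    d_an = np.sum((a - n) ** 2)  # anchor-to-negative distance
    return max(d_ap - d_an + alpha, 0.0)

anchor = np.array([1.0, 0.0])
positive = np.array([0.9, 0.1])       # close to the anchor: d(a, p) = 0.02
far_negative = np.array([0.0, 1.0])   # far from the anchor: d(a, n) = 2.0
near_negative = np.array([0.8, 0.2])  # nearly as close as the positive: d(a, n) = 0.08

# Margin already satisfied, so the loss is 0
print(triplet_loss_single(anchor, positive, far_negative))
# Negative is inside the margin: 0.02 - 0.08 + 0.2, approximately 0.14
print(triplet_loss_single(anchor, positive, near_negative))
```

Only triplets whose negative sits inside the margin contribute to the loss, so easy triplets produce no gradient.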

Data Augmentation

If there are no labels in the dataset, how does the model know which images should be considered similar and which should be different? The trick is to derive both the anchor and the positive from the same example. By applying two different random data augmentations to that example, we obtain an anchor and a positive that share the same content but look visually distinct from each other and from the original. The negative is a randomly chosen different image from the dataset.
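
A minimal numpy sketch of this label-free recipe follows. The `random_shift` augmentation (a small circular shift via `np.roll`) and the `make_triplet` helper are hypothetical stand-ins chosen for brevity; they are not the transforms used later in the tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_shift(image, max_shift=2):
    """Hypothetical augmentation: circularly shift the image by a few pixels."""
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    return np.roll(image, shift=(dy, dx), axis=(0, 1))

def make_triplet(images, i):
    """Build (anchor, positive, negative, negative_index) for image i, using no labels."""
    anchor = random_shift(images[i])    # two independent augmentations
    positive = random_shift(images[i])  # of the SAME image
    # Any other image can serve as the negative; an offset of 1..n-1 (mod n) ensures j != i
    j = (i + rng.integers(1, len(images))) % len(images)
    return anchor, positive, images[j], j

images = rng.random((10, 28, 28))  # toy unlabeled "dataset" of 10 images
a, p, n, j = make_triplet(images, 3)
print(a.shape, p.shape, n.shape, j)
```

Note that a random negative can still depict the same kind of object as the anchor; without labels this cannot be ruled out, so such triplets act as occasional noise.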

Example data augmentations (https://arxiv.org/pdf/2002.05709)

The figure shows different types of data augmentation applied to an original image. These augmentations train the model to learn the defining features of the dog regardless of color, orientation, and other incidental properties. For example, comparing examples (a) and (b), the model is taught that two views of the dog, one showing its rear and one showing its front, depict the same dog. This pushes the model to learn the underlying features that make a dog a dog, such as its leg shape and head shape.

TensorFlow Code Example

Now we will walk through a TensorFlow tutorial that applies triplet loss to the MNIST dataset. You can follow along in a Google Colab environment.

Necessary Imports

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Loading the MNIST dataset

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train / 255.0
X_test = X_test / 255.0
X_train = np.expand_dims(X_train, axis=-1)  # Add channel dimension
X_test = np.expand_dims(X_test, axis=-1)  # Add channel dimension

Creating the embedding model

# Simple CNN model
input_shape = (28, 28, 1)
input_layer = Input(shape=input_shape)

# Define batch size
batch_size = 64

# Conv Block 1
x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same')(input_layer)
x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.25)(x)

# Conv Block 2
x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.25)(x)

# Conv Block 3
x = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.25)(x)

# Linear Layers
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(128, activation='relu')(x)

# Model
model = Model(inputs=input_layer, outputs=x)

This model serves as the backbone that generates the embeddings for the images in the dataset.
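
The backbone above ends in a 128-unit ReLU layer with no normalization. FaceNet instead constrains its embeddings to the unit hypersphere (‖f(x)‖₂ = 1), which bounds the squared distance between any two embeddings to [0, 4] and keeps the network from satisfying the margin simply by scaling its outputs up. The numpy sketch below shows that normalization; in the Keras model it could be added as a final layer such as `Lambda(lambda t: tf.math.l2_normalize(t, axis=-1))`, which is a suggested variant rather than part of the original tutorial. `l2_normalize` here is an illustrative helper.

```python
import numpy as np

def l2_normalize(embeddings, eps=1e-12):
    """Scale each row (one embedding per row) to unit Euclidean length."""
    norms = np.linalg.norm(embeddings, axis=-1, keepdims=True)
    return embeddings / np.maximum(norms, eps)  # eps avoids division by zero

rng = np.random.default_rng(0)
raw = rng.random((5, 128)) * 10.0  # stand-in for unnormalized backbone outputs
unit = l2_normalize(raw)

print(np.linalg.norm(unit, axis=-1))  # every row now has length 1

# For unit vectors, squared distance = 2 - 2 * cos(angle), so it lies in [0, 4]
sq_dist = np.sum((unit[:, None, :] - unit[None, :, :]) ** 2, axis=-1)
print(sq_dist.min(), sq_dist.max())
```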

Data augmentation function

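# Build the augmentation pipeline once rather than on every call
datagen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    # Note: an image that receives both flips is rotated by 180 degrees,
    # which can make some digits (e.g. 6 and 9) look alike
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest'
)

def augment(image):
    # image has shape (28, 28, 1); random_transform applies one random
    # combination of the transforms above and returns the same shape
    return datagen.random_transform(image)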

The augment function applies a random combination of rotation, shifting, shearing, zooming, and flipping to an example. Calling it twice on the same image produces the anchor and the positive.

Triplet generating function

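def create_triplets(batch, model):
    # `model` is not used by this random-negative strategy
    anchors = []
    positives = []
    negatives = []

    # Choose a random negative for each example. Adding an offset of 1..n-1
    # (mod n) guarantees the negative is never the anchor's own image,
    # although it may still show the same digit (labels are not available).
    num_examples = len(batch)
    offsets = np.random.randint(1, num_examples, size=num_examples)
    negative_indices = (np.arange(num_examples) + offsets) % num_examples

    for i in range(num_examples):
        # Two independent augmentations of the same image form the anchor and positive
        anchor = augment(batch[i])
        positive = augment(batch[i])
        negative = batch[negative_indices[i]]

        anchors.append(anchor)
        positives.append(positive)
        negatives.append(negative)

    return np.array(anchors), np.array(positives), np.array(negatives)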

Displaying a sample triplet

import matplotlib.pyplot as plt

# Sample a random batch and generate triplets from it
indices = np.random.randint(0, X_train.shape[0], batch_size)
batch = X_train[indices]
a, p, n = create_triplets(batch, model)

# One row, three columns: anchor, positive, negative
fig, axs = plt.subplots(1, 3, figsize=(10, 5))
for ax, image, title in zip(axs, [a[0], p[0], n[0]], ["Anchor", "Positive", "Negative"]):
    ax.set_title(title)
    ax.imshow(image.squeeze(), cmap='gray')  # drop the channel axis and show in grayscale
    ax.axis('off')  # hide axes

fig.suptitle('Sample Triplet', fontsize=16)
plt.show()

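The anchor and positive are both generated from the same image but are rotated, shifted, zoomed, and flipped differently. These augmentations encourage the model to learn the features of a digit regardless of its exact position and orientation. Flips deserve some care with digits, however, because a mirrored or upside-down digit can resemble a different one.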
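
In the `augment` function, `horizontal_flip=True` and `vertical_flip=True` are each applied at random, so some augmented images receive both flips. The short numpy check below shows that applying both flips is exactly a 180-degree rotation, which is why a 6 can end up looking like a 9. The 3×3 array is a hypothetical stand-in for an image.

```python
import numpy as np

# Any non-symmetric pattern works as a stand-in for an image
image = np.arange(9).reshape(3, 3)

both_flips = np.flipud(np.fliplr(image))  # horizontal flip followed by vertical flip
rotated = np.rot90(image, 2)              # rotation by 180 degrees

print(np.array_equal(both_flips, rotated))        # True: the two are identical
print(np.array_equal(np.fliplr(image), rotated))  # False: one flip alone is a mirror image
```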

Triplet loss function

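embedding_dim = 128  # size of the embedding model's final Dense layer

def triplet_loss(alpha=0.2):
    def loss(y_true, y_pred):
        # y_pred is [anchor | positive | negative] along the last axis, so each slice
        # is embedding_dim wide (not batch_size); y_true is a dummy and is ignored
        anchor = y_pred[:, 0:embedding_dim]
        positive = y_pred[:, embedding_dim:2 * embedding_dim]
        negative = y_pred[:, 2 * embedding_dim:3 * embedding_dim]
        ap_distance = tf.reduce_sum(tf.square(anchor - positive), axis=-1)  # squared Euclidean
        an_distance = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
        return tf.maximum(ap_distance - an_distance + alpha, 0.0)
    return loss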

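This function implements the triplet loss equation defined earlier, using the squared Euclidean distance for d and alpha as the margin. The three embeddings arrive concatenated in a single tensor, so the function separates them by slicing along the last axis.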
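
The combined model concatenates the anchor, positive, and negative embeddings along the last axis, so each row of `y_pred` has width 3 × embedding size and each of the three parts is one embedding wide. The numpy sketch below reproduces that layout with hypothetical random embeddings to show the slicing and the per-sample loss; `triplet_loss_concat` is an illustrative name, not part of TensorFlow.

```python
import numpy as np

def triplet_loss_concat(y_pred, embedding_dim, alpha=0.2):
    """Per-sample triplet loss for rows laid out as [anchor | positive | negative]."""
    anchor = y_pred[:, 0:embedding_dim]
    positive = y_pred[:, embedding_dim:2 * embedding_dim]
    negative = y_pred[:, 2 * embedding_dim:3 * embedding_dim]
    ap = np.sum((anchor - positive) ** 2, axis=-1)  # squared Euclidean distances
    an = np.sum((anchor - negative) ** 2, axis=-1)
    return np.maximum(ap - an + alpha, 0.0)

rng = np.random.default_rng(0)
batch, dim = 4, 128
a = rng.normal(size=(batch, dim))
p = a + 0.01 * rng.normal(size=(batch, dim))  # positives very close to their anchors
n = rng.normal(size=(batch, dim))             # unrelated negatives
y_pred = np.concatenate([a, p, n], axis=-1)   # shape (4, 384), like the Lambda output

print(y_pred.shape, triplet_loss_concat(y_pred, dim))
```

Because the slices are taken by embedding width, the batch size never enters the slicing, and the same loss works for any batch size.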

Combine models for training

input_anchor = Input(shape=input_shape)
input_positive = Input(shape=input_shape)
input_negative = Input(shape=input_shape)

anchor_embedding = model(input_anchor)
positive_embedding = model(input_positive)
negative_embedding = model(input_negative)

merged_output = Lambda(lambda tensors: tf.concat(tensors, axis=-1))([anchor_embedding, positive_embedding, negative_embedding])
triplet_model = Model(inputs=[input_anchor, input_positive, input_negative], outputs=merged_output)
triplet_model.compile(loss=triplet_loss(alpha=0.2), optimizer=Adam(0.0001))

Training Loop

# Each step samples one batch of 64 images and performs a single gradient update;
# a step is not a full pass (epoch) over the training set
steps = 20
for step in range(steps):
    # Randomly sample a batch of images from the training set
    indices = np.random.randint(0, X_train.shape[0], batch_size)
    batch = X_train[indices]
    # Generate triplets from batch
    a, p, n = create_triplets(batch, model)
    # Train on the triplets; the zero targets are dummies that the loss ignores
    loss = triplet_model.train_on_batch([a, p, n], np.zeros((len(a), 1)))
    print(f'Step {step+1}/{steps}, Loss: {loss}')

Embedding Outputs

from sklearn.manifold import TSNE

# Embed the first 1000 examples in the training dataset
embeddings = model.predict(X_train[:1000])

# Use t-SNE to reduce dimensionality for visualization
tsne = TSNE(n_components=2, verbose=1, perplexity=40, max_iter=300)  # named n_iter before scikit-learn 1.5
tsne_results = tsne.fit_transform(embeddings)

# Plot the data colored by label
plt.figure(figsize=(10, 7))
for i in range(10):
    indices = np.where(y_train[:1000] == i)[0]
    plt.scatter(tsne_results[indices, 0], tsne_results[indices, 1], label=str(i))

plt.title('t-SNE visualization of MNIST training data colored by label')
plt.legend()
plt.show()

As the visualized embeddings show, the model groups images of the same digit together without ever seeing the labels during training; the labels are used only to color the points in the plot.
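
A t-SNE plot is a qualitative check. A simple numerical complement, offered here as a suggestion rather than part of the original tutorial, is leave-one-out 1-nearest-neighbor agreement: the fraction of images whose closest other embedding has the same label. As in the plot, the labels are used only for scoring. `one_nn_agreement` is a hypothetical helper; with the tutorial's variables it could be called as `one_nn_agreement(embeddings, y_train[:1000])`.

```python
import numpy as np

def one_nn_agreement(embeddings, labels):
    """Fraction of points whose nearest other embedding shares their label."""
    # Pairwise squared Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y, shape (N, N)
    sq = np.sum(embeddings ** 2, axis=1)
    dists = sq[:, None] + sq[None, :] - 2.0 * embeddings @ embeddings.T
    np.fill_diagonal(dists, np.inf)  # a point may not count as its own neighbor
    nearest = np.argmin(dists, axis=1)
    return float(np.mean(labels[nearest] == labels))

# Toy check: two well-separated clusters should score 1.0
rng = np.random.default_rng(0)
emb = np.vstack([rng.normal(0.0, 0.1, (20, 8)), rng.normal(5.0, 0.1, (20, 8))])
lab = np.array([0] * 20 + [1] * 20)
print(one_nn_agreement(emb, lab))
```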

Test Embedding Outputs

embeddings = model.predict(X_test[:1000])

# Use t-SNE to reduce dimensionality for visualization
tsne = TSNE(n_components=2, verbose=1, perplexity=40, max_iter=300)  # named n_iter before scikit-learn 1.5
tsne_results = tsne.fit_transform(embeddings)  # Using a subset for faster computation

# Plot the data colored by label
plt.figure(figsize=(10, 7))
for i in range(10):
    indices = np.where(y_test[:1000] == i)[0]
    plt.scatter(tsne_results[indices, 0], tsne_results[indices, 1], label=str(i))

plt.title('t-SNE visualization of MNIST testing data colored by label')
plt.legend()
plt.show()

Even on the test set, which consists of images the model has never seen, the model still embeds images of the same digit near each other. This suggests that the model is learning the underlying features that distinguish each digit rather than memorizing the training examples.

Practical Applications

Unsupervised learning and triplet loss can be applied to many real-world use cases, including face recognition, image search engines, and satellite surveillance. All of these involve large datasets that would be costly and time-consuming to label manually. Instead, the models can learn the underlying features for each task and produce clusters of embeddings similar to the MNIST clusters.

One well-known example is FaceNet, a face recognition system trained with triplet loss whose embeddings are used for face verification, recognition, and clustering.

FaceNet Diagram (https://sh-tsang.medium.com/review-facenet-a-unified-embedding-for-face-recognition-and-clustering-7b360d2a85e4)

The FaceNet model relies on the same triplet loss described above, with one important change: the anchor and positive are now different images of the same person rather than two augmentations of one image. Training therefore requires knowing which images show the same person, but it lets the model learn an individual's distinctive features regardless of lighting, expression, and, most importantly, angle. The strength of the model becomes clear when a new face is introduced: without retraining on the new person's face, it can cluster images of that person together, because it has learned the underlying features that make each face unique.
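
The FaceNet paper describes verification as thresholding the distance between two embeddings, and adding a new person only requires storing an embedding for them. The numpy sketch below illustrates that idea with toy 2-D vectors standing in for network outputs; the names, the gallery, the nearest-neighbor identification with rejection, and the threshold of 1.0 are all hypothetical, and a real threshold would be tuned on held-out pairs.

```python
import numpy as np

def same_person(emb_a, emb_b, threshold=1.0):
    """Verification: declare a match if the squared distance is below the threshold."""
    return float(np.sum((emb_a - emb_b) ** 2)) < threshold

def identify(query, gallery, names, threshold=1.0):
    """Identification: nearest enrolled name, or None if no one is close enough."""
    dists = np.sum((gallery - query) ** 2, axis=1)
    best = int(np.argmin(dists))
    return names[best] if dists[best] < threshold else None

# Hypothetical enrolled embeddings
gallery = np.array([[1.0, 0.0], [0.0, 1.0]])
names = ["alice", "bob"]

# Enrolling a new person only appends an embedding; the network is not retrained
gallery = np.vstack([gallery, [[-1.0, 0.0]]])
names.append("carol")

print(identify(np.array([-0.9, 0.1]), gallery, names))  # carol (squared distance 0.02)
print(identify(np.array([0.0, -1.0]), gallery, names))  # None (closest is 2.0 away)
```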

Conclusion

In conclusion, triplet loss offers a distinctive way to train a neural network: by comparing anchor, positive, and negative examples, the network learns the underlying features in the data. It is especially useful when datasets are too large to label manually. Its use cases range from face recognition to recommendation systems, where new users can be added without retraining the model. The code tutorial demonstrates how triplet loss can be implemented in TensorFlow and shows the power of unsupervised learning. As large datasets become more prevalent, unsupervised learning will likely continue to grow because it can handle large amounts of data without human labels.

References: