Home - beekbin/RQ-VAE-Recommender GitHub Wiki

Welcome to the RQ-VAE-Recommender wiki!

Overview

This project involves the following related techniques:

  • Variational Autoencoder (VAE)
  • Residual-Quantized Variational AutoEncoder (RQ-VAE)

Variational Autoencoder (VAE)

Jim Wang's blog post on VAEs gives an excellent explanation of the model. The following content is from Jim's blog.

A Variational Autoencoder (VAE) is a type of generative model in machine learning that is used to learn a compressed representation of data while also being able to generate new, synthetic data points. VAEs are a form of autoencoder, but they incorporate probabilistic inference, allowing them to generate new samples by sampling from a learned latent distribution.
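To make "generate new samples by sampling from a learned latent distribution" concrete, here is a minimal sketch of the generation step. It assumes a trained decoder. The `decoder` below is a hypothetical, untrained stand-in with the same layer sizes as the decoder in the class that follows (20 → 400 → 784), so its outputs are noise. The mechanics are the point: draw z from the standard normal prior and decode it.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a trained decoder; mirrors the 20 -> 400 -> 784
# decoder layers (fc3, fc4) of the VAE class on this page. Untrained here.
decoder = nn.Sequential(
    nn.Linear(20, 400),
    nn.ReLU(),
    nn.Linear(400, 28 * 28),
    nn.Sigmoid(),
)

def generate(decoder: nn.Module, num_samples: int, latent_dim: int = 20) -> torch.Tensor:
    """Draw z ~ N(0, I) from the prior and decode it into new images."""
    decoder.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim)  # sample from the prior
        images = decoder(z)                       # (num_samples, 784), pixel values in [0, 1]
    return images.view(num_samples, 28, 28)

samples = generate(decoder, num_samples=4)
print(samples.shape)  # torch.Size([4, 28, 28])
```

Sampling from N(0, I) works for generation only because the KL term in the loss pushes the encoder's latent distribution toward that same prior.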

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder: Input -> Hidden layers
        self.fc1 = nn.Linear(28*28, 400)  # Flattened input
        self.fc21 = nn.Linear(400, 20)    # Mean of the latent space
        self.fc22 = nn.Linear(400, 20)    # Log-variance of the latent space

        # Decoder: Latent space -> Output
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 28*28)  # Output size

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)  # Mean and log-variance

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # Standard deviation
        eps = torch.randn_like(std)    # Sample epsilon ~ N(0, I)
        return mu + eps * std          # Reparameterization trick

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))  # Pixel probabilities in [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 28*28))  # Flatten input
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# VAE Loss Function
def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: binary cross-entropy summed over all pixels
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 28*28), reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
```

Explanation:

  • Encoder: Takes the input data (28x28 images) and passes it through fully connected layers to produce mean (mu) and log-variance (logvar) for the latent space representation.
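  • Reparameterization: Random noise eps is sampled from a standard normal distribution and combined with the mean and the standard deviation (computed from the log-variance) to produce the latent vector z = mu + eps * std. This reparameterization trick keeps sampling differentiable with respect to mu and logvar.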
  • Decoder: The latent vector z is passed through the decoder to reconstruct the image.
  • Loss function: The loss consists of a reconstruction loss (how well the VAE reconstructs the input) and a KL divergence term that regularizes the latent distribution toward a standard normal (unit Gaussian) prior.
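A small sketch shows why the reparameterization trick matters. Because eps is drawn independently of the parameters, z = mu + eps * std is an ordinary differentiable function of mu and logvar, and gradients flow back to the encoder outputs. The tensors below are hypothetical encoder outputs for a batch of 2 with a 3-dimensional latent space.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 2, latent dimension 3.
mu = torch.zeros(2, 3, requires_grad=True)
logvar = torch.zeros(2, 3, requires_grad=True)

# Reparameterization trick: the randomness lives in eps, not in mu/logvar.
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

# Any scalar objective of z can be backpropagated to the encoder outputs.
z.sum().backward()
print(mu.grad)      # all ones, since dz/dmu = 1
print(logvar.grad)  # equals 0.5 * eps * std, since dz/dlogvar = 0.5 * eps * exp(0.5 * logvar)
```

Sampling z directly from N(mu, std²) would give the same distribution, but that sampling step has no gradient path back to mu and logvar.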

Figure: KL divergence between two univariate Gaussians
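For reference, the standard closed form of this KL divergence is written out below. Setting the second Gaussian to the standard normal prior and writing logvar = log σ² recovers the per-dimension term of the KLD line in loss_function. Summing over latent dimensions gives the full term.

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
  = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}

% With \mu_2 = 0, \sigma_2 = 1, \mu_1 = \mu, \sigma_1^2 = \sigma^2:
D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big)
  = -\tfrac{1}{2}\log\sigma^2 + \tfrac{1}{2}\left(\sigma^2 + \mu^2\right) - \tfrac{1}{2}
  = -\tfrac{1}{2}\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)

% Summed over the J latent dimensions (matches KLD in loss_function):
\mathrm{KLD} = -\tfrac{1}{2}\sum_{j=1}^{J}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```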

  • Training: The model is trained using backpropagation, with Adam as the optimizer.
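A minimal training loop with Adam might look like the sketch below. Several things here are assumptions rather than the blog's exact setup. The model is a condensed copy of the VAE class above (forward inlines encode, reparameterize, and decode). The data is random "images" with pixel values in [0, 1], used as a hypothetical placeholder for MNIST. The learning rate (1e-3) and epoch count are illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class VAE(nn.Module):
    """Same architecture as the VAE class above (784 -> 400 -> 20 -> 400 -> 784)."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 28 * 28)

    def forward(self, x):
        h1 = F.relu(self.fc1(x.view(-1, 28 * 28)))
        mu, logvar = self.fc21(h1), self.fc22(h1)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z)))), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 28 * 28), reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

def train(model, batches, epochs=2, lr=1e-3):
    """Forward pass, loss, backpropagation, Adam update; returns mean loss per sample per epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    history = []
    for _ in range(epochs):
        total = 0.0
        for x in batches:
            optimizer.zero_grad()
            recon_x, mu, logvar = model(x)
            loss = loss_function(recon_x, x, mu, logvar)
            loss.backward()
            optimizer.step()
            total += loss.item()
        history.append(total / sum(len(b) for b in batches))
    return history

# Hypothetical stand-in for MNIST batches: 4 batches of 32 random 28x28 images in [0, 1].
batches = [torch.rand(32, 1, 28, 28) for _ in range(4)]
history = train(VAE(), batches)
print(history)
```

In practice the batches would come from a `torch.utils.data.DataLoader` over the real dataset. The loop itself stays the same.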

Residual-Quantized Variational AutoEncoder (RQ-VAE)


According to Google's paper "Recommender Systems with Generative Retrieval", the Residual-Quantized Variational AutoEncoder (RQ-VAE) is a multi-level vector quantizer that applies quantization to residuals to generate a tuple of codewords (called Semantic IDs in the paper). The autoencoder is trained jointly, updating both the quantization codebooks and the DNN encoder-decoder parameters.
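The residual quantization step can be sketched as follows. At each level, the nearest codeword to the current residual is selected, its index is recorded, and it is subtracted, so the next level quantizes what is left. The per-level indices form the Semantic ID tuple. This is a minimal sketch, not the repository's or the paper's implementation. The function name `residual_quantize`, the 3 levels, the 256 codewords per level, and the 32-dimensional latent are all illustrative assumptions. The codebooks are random here, and the training losses and gradient tricks used to update the codebooks and encoder jointly are not shown.

```python
import torch

def residual_quantize(x, codebooks):
    """Multi-level residual quantization (hypothetical helper, illustration only).

    x:         (batch, dim) encoder outputs
    codebooks: list of (codebook_size, dim) tensors, one per level
    Returns (codes, quantized): codes is (batch, num_levels) codeword indices,
    i.e. one Semantic ID tuple per item; quantized is the sum of selected codewords.
    """
    residual = x
    quantized = torch.zeros_like(x)
    codes = []
    for codebook in codebooks:
        # Squared Euclidean distance from each residual to every codeword: (batch, codebook_size)
        dists = ((residual.unsqueeze(1) - codebook.unsqueeze(0)) ** 2).sum(dim=-1)
        idx = dists.argmin(dim=1)        # nearest codeword per item, shape (batch,)
        selected = codebook[idx]         # (batch, dim)
        codes.append(idx)
        quantized = quantized + selected
        residual = residual - selected   # next level quantizes what is left over
    return torch.stack(codes, dim=1), quantized

torch.manual_seed(0)
latent = torch.randn(5, 32)                           # hypothetical encoder outputs for 5 items
codebooks = [torch.randn(256, 32) for _ in range(3)]  # 3 levels x 256 codewords (assumed sizes)
semantic_ids, quantized = residual_quantize(latent, codebooks)
print(semantic_ids.shape)  # torch.Size([5, 3]): one 3-tuple of codes per item
```

Because each level only encodes the residual left by the previous levels, the first code captures coarse structure and later codes refine it. Items with similar embeddings therefore tend to share leading codes.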