SkipGram
- Goal: Learn vector embeddings for nodes by training a neural network to predict surrounding nodes in a Random Walk, given a center node.
- Originally, the Word2vec paper did this for words and sentences instead of nodes and random walks.
- Key Idea: For a given center node in a random walk, the model predicts the nodes within a specified window around it. Each (center, context) pair forms a training example (see the sketch below).
- Model: The model consists of only one hidden layer.
*(Figure: illustration of the model's layers.)*
See NegativeSampler for our implementation of the sampling mechanism.
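To make the Key Idea above concrete, here is a minimal sketch of how (center, context) training pairs can be generated from a single random walk. The function name, the example walk and the window size are made up for illustration and are not taken from the project's code.

```python
# Illustrative sketch: turning one random walk into SkipGram training pairs
# using a symmetric context window.

def training_pairs(walk, window_size):
    """Yield (center, context) node pairs from one random walk."""
    for i, center in enumerate(walk):
        lo = max(0, i - window_size)
        hi = min(len(walk), i + window_size + 1)
        for j in range(lo, hi):
            if j != i:
                yield center, walk[j]

# Example: a short walk over node IDs with a window of 2.
walk = [4, 7, 1, 7, 3]
pairs = list(training_pairs(walk, window_size=2))
# pairs[:5] == [(4, 7), (4, 1), (7, 4), (7, 1), (7, 7)]
```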
- Problem:
- Theoretically, for each training pair we would have to perform two matrix-vector multiplications and then apply a softmax function over the entire output vector: $y = \text{softmax}(W' \cdot (W \cdot x))$. Since the input vector $x$ is one-hot encoded, the first multiplication actually only selects a single column from $W$. But the second multiplication would require $\text{embedding\_size} \cdot \text{num\_nodes}$ double (floating-point) multiplications. Then we would have to apply the softmax to this vector, which requires $O(\text{num\_nodes})$ exponentiation, summation and division operations. We would then have to propagate back and update all $\text{embedding\_size} \cdot \text{num\_nodes}$ weights in $W'$.
- This would be computationally prohibitive (remember: we're doing this for every training pair); see the worked numbers below.
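To get a feel for the numbers (illustrative sizes, not taken from the project): with $\text{num\_nodes} = 10^6$ and $\text{embedding\_size} = 128$, the second matrix-vector product alone costs $128 \cdot 10^6 = 1.28 \times 10^8$ multiplications, the softmax then normalizes over $10^6$ outputs, and the backward pass touches all $1.28 \times 10^8$ weights of $W'$, and all of this for every single (center, context) pair.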
- Solution:
- The DeepWalk paper proposes using a hierarchical softmax but later methods (e.g. node2vec) use negative sampling instead.
- Update only:
  - the weights corresponding to the output component of the true ("positive") result, and
  - the weights corresponding to a small number ($k$) of nodes that should not have been predicted by the model ("negative") per training instance.
- This means we only have to calculate $k+1$ dot products (see the Forward Pass section for why), apply the sigmoid function to each of those ($O(1)$ per application) and then propagate back, where we also only need to update $k+1$ rows in $W'$, as sketched below.
  - Our $k$ is typically much smaller than $\text{num\_nodes}$ (typically something like 5-20).
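A minimal sketch of this forward pass follows; the array names (`W`, `W_prime`) and shapes are assumptions for illustration, not the project's API. The point is that only $k+1$ rows of $W'$ are ever read, so the cost per pair is independent of $\text{num\_nodes}$.

```python
# Forward pass with negative sampling: k+1 dot products and k+1 sigmoids per pair.
# Illustrative sketch; names and shapes are assumptions, not the project's code.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward(W, W_prime, center, positive, negatives):
    """Return the k+1 sigmoid scores for one (center, context) pair.

    W, W_prime : (num_nodes, embedding_size) arrays, one embedding per row.
    The one-hot input just selects row `center` of W, so no full
    matrix-vector product is ever carried out.
    """
    v_c = W[center]                                 # the center node's embedding
    rows = np.concatenate(([positive], negatives))  # the only k+1 rows of W' we read
    scores = sigmoid(W_prime[rows] @ v_c)           # (k+1) * embedding_size multiplications
    return scores                                   # scores[0] belongs to the positive node
```

With the full softmax, the same step would require all $\text{num\_nodes}$ dot products plus the normalization; here it is only $(k+1) \cdot \text{embedding\_size}$ multiplications regardless of graph size.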
- Loss Function:
- See the Algorithm section below for an explanation of where these symbols and operations come from.
- The training objective for a single pair $(c, o)$ with negative samples is

$$
J(c, o) = -\log \sigma\left({v'_o}^{\top} v_c\right) - \sum_{n \in \mathcal{N}} \log \sigma\left(-{v'_n}^{\top} v_c\right)
$$

where $\sigma$ is the sigmoid function (defined below), $v_c$ is the embedding of the center node $c$ (the column of $W$ selected by the one-hot input), $v'_o$ and $v'_n$ are the output vectors (rows of $W'$) of the true context node $o$ and of a negative sample $n$, and $\mathcal{N}$ is the set of $k$ negative samples drawn by the NegativeSampler (a small numeric example follows below).
Sources: original paper, simple explanation
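To see the loss in action, here is a tiny self-contained numeric example; all names and sizes are made up for illustration and are not the project's code.

```python
# Evaluate the SGNS loss for one (center, context) pair plus k negative samples.
import numpy as np

rng = np.random.default_rng(0)
num_nodes, dim, k = 100, 16, 5
W = rng.normal(scale=0.1, size=(num_nodes, dim))        # input embeddings, one node per row
W_prime = rng.normal(scale=0.1, size=(num_nodes, dim))  # output embeddings, one node per row

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sgns_loss(c, o, negatives):
    v_c = W[c]
    positive_term = -np.log(sigmoid(W_prime[o] @ v_c))
    negative_terms = -np.sum(np.log(sigmoid(-(W_prime[negatives] @ v_c))))
    return positive_term + negative_terms

print(sgns_loss(c=3, o=17, negatives=rng.integers(0, num_nodes, size=k)))
```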
We start from the objective for a given center word vector $v_c$:

$$
J = -\log \sigma\left({v'_o}^{\top} v_c\right) - \sum_{n \in \mathcal{N}} \log \sigma\left(-{v'_n}^{\top} v_c\right)
$$

where the sigmoid function is defined as

$$
\sigma(x) = \frac{1}{1 + e^{-x}} .
$$
- Gradient with Respect to $v'_o$

Focus on the positive term:

$$
J_{+} = -\log \sigma\left({v'_o}^{\top} v_c\right)
$$

Let

$$
z_o = {v'_o}^{\top} v_c .
$$

Then, using

$$
\sigma'(z) = \sigma(z)\,\bigl(1 - \sigma(z)\bigr),
$$

the chain rule gives:

$$
\frac{\partial J_{+}}{\partial z_o} = -\frac{\sigma'(z_o)}{\sigma(z_o)} = -\bigl(1 - \sigma(z_o)\bigr) = \sigma(z_o) - 1 .
$$

Thus, with $\frac{\partial z_o}{\partial v'_o} = v_c$,

$$
\frac{\partial J}{\partial v'_o} = \bigl(\sigma({v'_o}^{\top} v_c) - 1\bigr)\, v_c .
$$
- Gradient with Respect to Each Negative Sample $v'_n$

For a negative sample, consider:

$$
J_{n} = -\log \sigma\left(-{v'_n}^{\top} v_c\right)
$$

Let

$$
z_n = {v'_n}^{\top} v_c .
$$

Then,

$$
J_{n} = -\log \sigma(-z_n).
$$

Applying the chain rule:

$$
\frac{\partial J_{n}}{\partial z_n} = -\frac{\sigma'(-z_n) \cdot (-1)}{\sigma(-z_n)} = 1 - \sigma(-z_n).
$$

Since

$$
1 - \sigma(-z) = \sigma(z),
$$

we have:

$$
\frac{\partial J_{n}}{\partial z_n} = \sigma(z_n)
\qquad\text{and, with } \frac{\partial z_n}{\partial v'_n} = v_c, \qquad
\frac{\partial J}{\partial v'_n} = \sigma\left({v'_n}^{\top} v_c\right) v_c .
$$
- Gradient with Respect to $v_c$

The center word vector $v_c$ appears in the positive term and in every negative term, so each of them contributes to its gradient.

- From the positive term:

$$
\frac{\partial}{\partial v_c}\left(-\log \sigma\left({v'_o}^{\top} v_c\right)\right) = \bigl(\sigma({v'_o}^{\top} v_c) - 1\bigr)\, v'_o
$$

- From each negative term:

$$
\frac{\partial}{\partial v_c}\left(-\log \sigma\left(-{v'_n}^{\top} v_c\right)\right) = \bigl(1 - \sigma(-{v'_n}^{\top} v_c)\bigr)\, v'_n
$$

Since

$$
1 - \sigma(-x) = \sigma(x)
$$

and using the expression above, it follows that each negative term contributes:

$$
\sigma\left({v'_n}^{\top} v_c\right) v'_n .
$$

Summing over all terms:

$$
\frac{\partial J}{\partial v_c} = \bigl(\sigma({v'_o}^{\top} v_c) - 1\bigr)\, v'_o + \sum_{n \in \mathcal{N}} \sigma\left({v'_n}^{\top} v_c\right) v'_n .
$$
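Putting the three gradients together, a single SGD step only has to touch $v_c$, $v'_o$ and the $k$ negative rows of $W'$. The sketch below shows such a step; the names (`sgns_step`, `W`, `W_prime`) and the learning rate are illustrative assumptions, not the project's actual implementation, and embeddings are stored one per row for convenience.

```python
# One SGD step using the gradients derived above (illustrative sketch).
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sgns_step(W, W_prime, c, o, negatives, lr=0.025):
    """Update the embeddings for one (center c, context o) pair plus negative samples."""
    v_c = W[c]

    # Positive term: dJ/dv'_o = (sigma(v'_o . v_c) - 1) * v_c
    g_o = sigmoid(W_prime[o] @ v_c) - 1.0
    grad_vc = g_o * W_prime[o]          # contribution to dJ/dv_c (uses the old row)
    W_prime[o] -= lr * g_o * v_c

    # Negative terms: dJ/dv'_n = sigma(v'_n . v_c) * v_c
    for n in negatives:
        g_n = sigmoid(W_prime[n] @ v_c)
        grad_vc += g_n * W_prime[n]     # contribution to dJ/dv_c (uses the old row)
        W_prime[n] -= lr * g_n * v_c

    # dJ/dv_c = (sigma(v'_o . v_c) - 1) * v'_o + sum_n sigma(v'_n . v_c) * v'_n
    W[c] -= lr * grad_vc
```

Note that only the row $W[c]$ and the $k+1$ rows of $W'$ are modified, which is exactly the sparse update promised in the Solution section.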