We use `nn.Embedding` in PyTorch to map tokens to dense vectors.
```dot
digraph nlp_vs_vision {
    rankdir=LR;
    node [fontsize=11, shape=box, style=rounded];
    img [label="Image\n(64x64x3 pixels)"];
    txt [label="Text\n(\"I really liked this movie\")"];
    cnn [label="CNN / VAE"];
    nlp [label="Embedding + NLP model"];
    img -> cnn;
    txt -> nlp;
}
```
High-level steps to go from raw text to model-ready tensors.
```dot
digraph nlp_pipeline {
    rankdir=LR;
    node [fontsize=12, shape=box, style=rounded, height=0.7];
    edge [penwidth=1.5];
    graph [nodesep=0.5, ranksep=0.8];
    text [label="Raw text\n\"I loved the movie!\""];
    tokens [label="Tokenization\n['I', 'loved', 'the', 'movie']", style="filled,rounded", fillcolor="#e3f2fd"];
    ids [label="Vocabulary lookup\n[12, 57, 3, 98]", style="filled,rounded", fillcolor="#fff3e0"];
    emb [label="Embedding layer\nlookup vectors", style="filled,rounded", fillcolor="#e8f5e9"];
    model [label="NLP model\nclassifier / LM"];
    text -> tokens -> ids -> emb -> model;
}
```
Tokenization splits text into units that the model will see.
```python
text = "I really liked this movie!"
tokens = ["I", "really", "liked", "this", "movie", "!"]
```
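A toy tokenizer along these lines can be sketched with a regular expression (an illustration only, not the tokenizer of any particular library; real tokenizers handle many more cases):

```python
import re

def simple_tokenize(text):
    # Match runs of word characters, or single punctuation marks
    return re.findall(r"\w+|[^\w\s]", text)

simple_tokenize("I really liked this movie!")
# -> ['I', 'really', 'liked', 'this', 'movie', '!']
```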
Models operate on integers, not strings. We build a vocabulary mapping each token to an id.
```python
tokens = ["i", "really", "liked", "this", "movie", "!"]
vocab = {
    "<pad>": 0,
    "<unk>": 1,
    "i": 2,
    "really": 3,
    "liked": 4,
    "this": 5,
    "movie": 6,
    "!": 7,
}
ids = [vocab.get(t, vocab["<unk>"]) for t in tokens]
# ids: [2, 3, 4, 5, 6, 7]
```
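A vocabulary like this can be built from a corpus by counting tokens (a minimal sketch; `build_vocab` and its `min_freq` parameter are illustrative names, not a library API):

```python
from collections import Counter

def build_vocab(token_lists, min_freq=1):
    # Count tokens across all tokenized texts
    counts = Counter(t for toks in token_lists for t in toks)
    vocab = {"<pad>": 0, "<unk>": 1}  # reserved ids
    for tok, c in counts.most_common():
        if c >= min_freq:
            vocab[tok] = len(vocab)
    return vocab

vocab = build_vocab([["i", "really", "liked", "this", "movie", "!"]])
# 8 entries: 2 special tokens + 6 corpus tokens
```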
Let \(|V|\) be vocabulary size and \(d\) embedding dimension.
Toy vocabulary with \(|V| = 4\) and \(d = 2\).
| Token | Id | One-hot \(\mathbf{e}_i\) | Embedding \(\mathbf{x}_i\) |
|---|---|---|---|
| good | 0 | \([1, 0, 0, 0]\) | \([0.8,\ 0.6]\) |
| great | 1 | \([0, 1, 0, 0]\) | \([0.9,\ 0.7]\) |
| bad | 2 | \([0, 0, 1, 0]\) | \([-0.7,\ -0.6]\) |
| terrible | 3 | \([0, 0, 0, 1]\) | \([-0.9,\ -0.8]\) |
Vectors for “good” and “great” are close; “bad” and “terrible” are close but far from the positives.
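This closeness can be checked numerically with cosine similarity, using the toy vectors from the table (a minimal sketch in plain Python):

```python
import math

def cos(u, v):
    # Cosine similarity: dot product divided by the product of norms
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

good, great, bad = [0.8, 0.6], [0.9, 0.7], [-0.7, -0.6]
cos(good, great)  # close to +1: similar meanings
cos(good, bad)    # close to -1: opposite meanings
```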
The embedding layer, `nn.Embedding`, is just a learnable matrix of shape \((|V|, d)\).
```python
import torch
import torch.nn as nn

vocab_size = len(vocab)
embedding_dim = 64
emb = nn.Embedding(num_embeddings=vocab_size,
                   embedding_dim=embedding_dim)

ids = torch.tensor([[2, 3, 4, 5, 6, 7]])  # shape: (batch=1, seq_len=6)
embedded = emb(ids)                       # shape: (1, 6, 64)
```
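To see that the layer really is just a matrix, a lookup can be compared against a one-hot matrix product (a small sanity check, using a 4-token toy vocabulary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

emb = nn.Embedding(num_embeddings=4, embedding_dim=2)
ids = torch.tensor([1, 3])

# Looking up row i is the same as multiplying a one-hot vector by the weights
one_hot = F.one_hot(ids, num_classes=4).float()  # shape: (2, 4)
assert torch.allclose(emb(ids), one_hot @ emb.weight)
```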
Core idea: “You shall know a word by the company it keeps.” (Firth, 1957)
```python
# Pseudocode: Skip-gram training example
center = "movie"
context = ["great", "funny", "exciting"]
# Objective: the embedding of "movie" should be
# good at predicting its context words.
```
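Generating (center, context) training pairs from a token list can be sketched as follows (`skipgram_pairs` is an illustrative helper, not a library function):

```python
def skipgram_pairs(tokens, window=2):
    """Generate (center, context) training pairs for skip-gram."""
    pairs = []
    for i, center in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        for j in range(lo, hi):
            if j != i:  # every neighbor within the window, excluding the center
                pairs.append((center, tokens[j]))
    return pairs

skipgram_pairs(["a", "great", "movie"], window=1)
# -> [('a', 'great'), ('great', 'a'), ('great', 'movie'), ('movie', 'great')]
```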
```dot
digraph cbow {
    rankdir=LR;
    node [fontsize=11, shape=box, style=rounded];
    c1 [label="Context word\n'I'"];
    c2 [label="Context word\n'the'"];
    c3 [label="Context word\n'movie'"];
    emb [label="Embedding\nlookup"];
    avg [label="Average\ncontext vectors"];
    out [label="Linear + Softmax\np(center | context)"];
    ctr [label="Predicted\ncenter word\n'loved'"];
    {c1 c2 c3} -> emb -> avg -> out -> ctr;
}
```
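The CBOW architecture in the diagram can be sketched as a small PyTorch module (illustrative and untrained; class and variable names are our own):

```python
import torch
import torch.nn as nn

class CBOW(nn.Module):
    def __init__(self, vocab_size, emb_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.out = nn.Linear(emb_dim, vocab_size)

    def forward(self, context_ids):
        # context_ids: (batch, window) -> average the context vectors
        v = self.emb(context_ids).mean(dim=1)  # (batch, emb_dim)
        return self.out(v)                     # logits over the vocabulary

model = CBOW(vocab_size=8, emb_dim=16)
logits = model(torch.tensor([[2, 5, 6]]))  # one example, 3 context ids
# logits.shape: (1, 8)
```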
```dot
digraph skipgram {
    rankdir=LR;
    node [fontsize=11, shape=box, style=rounded];
    ctr [label="Center word\n'movie'"];
    emb [label="Embedding\nlookup"];
    out [label="Linear + Softmax\np(context | center)"];
    ctx1 [label="Context\n'great'"];
    ctx2 [label="Context\n'funny'"];
    ctx3 [label="Context\n'exciting'"];
    ctr -> emb -> out;
    out -> {ctx1 ctx2 ctx3};
}
```
We can visualize embeddings by projecting high-dimensional vectors (e.g., 300-D Word2Vec) down to 3D with PCA.
PCA finds directions of maximum variance; 3D PCA gives an intuitive, interactive view of the global structure of the embedding space.
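A minimal PCA projection can be written with NumPy's SVD (a sketch on random data standing in for word vectors; real visualizations often use scikit-learn or an interactive embedding projector):

```python
import numpy as np

def pca_project(X, k=3):
    """Project rows of X onto the top-k principal components."""
    Xc = X - X.mean(axis=0)                           # center the data
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False) # rows of Vt = components
    return Xc @ Vt[:k].T                              # (n, k) coordinates

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 300))  # e.g., 100 word vectors, 300-D
coords = pca_project(X, k=3)     # shape: (100, 3), ready for a 3D scatter plot
```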
```python
class SimpleSentimentModel(nn.Module):
    def __init__(self, vocab_size, emb_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.fc = nn.Linear(emb_dim, 1)  # binary

    def forward(self, ids):
        # ids: (batch, seq_len)
        x = self.embedding(ids)  # (batch, seq_len, emb_dim)
        x = x.mean(dim=1)        # (batch, emb_dim)
        logits = self.fc(x)      # (batch, 1)
        return logits.squeeze(-1)
```
Sequences in a batch must have the same length, so we truncate or pad with `<pad>` up to a fixed length.

```python
max_len = 10

def pad(ids, max_len, pad_id=0):
    return ids[:max_len] + [pad_id] * max(0, max_len - len(ids))

batch_ids = [
    pad([2, 3, 4, 5, 6, 7], max_len),
    pad([2, 4, 6], max_len),
]
```
Putting it together: `nn.Embedding` + mean pooling + a linear classifier to predict sentiment.
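A self-contained version of the whole pipeline, restating the model and `pad` helper from above so the snippet runs on its own (the weights are random and untrained, so the outputs are arbitrary logits):

```python
import torch
import torch.nn as nn

class SimpleSentimentModel(nn.Module):
    def __init__(self, vocab_size, emb_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.fc = nn.Linear(emb_dim, 1)  # binary sentiment

    def forward(self, ids):
        x = self.embedding(ids)        # (batch, seq_len, emb_dim)
        x = x.mean(dim=1)              # (batch, emb_dim)
        return self.fc(x).squeeze(-1)  # (batch,)

def pad(ids, max_len, pad_id=0):
    return ids[:max_len] + [pad_id] * max(0, max_len - len(ids))

batch = torch.tensor([
    pad([2, 3, 4, 5, 6, 7], 10),
    pad([2, 4, 6], 10),
])
model = SimpleSentimentModel(vocab_size=8)
logits = model(batch)  # shape: (2,) -- one logit per example
```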