Limitations of RNNs¶

  • An RNN encoder must compress the entire input into a single hidden state before passing it to the decoder
    • The decoder can’t directly access earlier hidden states from the encoder during the decoding phase.
    • Consequently, it relies solely on the final hidden state, which must encapsulate all relevant information.
    • This can lead to a loss of context, especially in long or complex sentences where dependencies span long distances.
In [1]:
# example of computing attention scores
import torch
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts (x^3)
    [0.22, 0.58, 0.33], # with (x^4)
    [0.77, 0.25, 0.10], # one (x^5)
    [0.05, 0.80, 0.55]] # step (x^6)
)
print(inputs.shape)
torch.Size([6, 3])
In [2]:
query = inputs[1] # use the 2nd input token as the query
print(query)
tensor([0.5500, 0.8700, 0.6600])
In [3]:
attn_scores_2 = torch.empty(inputs.shape[0])
# compute the dot product of the query with every input vector (including itself)
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
In [4]:
# normalize attention scores so that they sum to 1
attn_weights_2_tmp = attn_scores_2/attn_scores_2.sum()
print('Attention weights:', attn_weights_2_tmp)
print("sum:", attn_weights_2_tmp.sum())
Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
sum: tensor(1.0000)
In [5]:
# use softmax to normalize
def softmax_naive(x):
    # x is expected to be a 1-D vector, so we apply softmax over dim=0
    return torch.exp(x) / torch.exp(x).sum(dim = 0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
In [6]:
attn_weights_2 = torch.softmax(attn_scores_2, dim = 0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
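
An aside (not from the original notebook): one practical reason to prefer torch.softmax over the naive version is numerical stability. A minimal sketch, reusing the softmax_naive defined above:

# exp() overflows for large scores in the naive version;
# torch.softmax uses a numerically stable formulation (equivalent to subtracting the max first)
large_scores = torch.tensor([1000.0, 1000.1, 999.9])
print(softmax_naive(large_scores))        # tensor([nan, nan, nan])
print(torch.softmax(large_scores, dim=0)) # finite, well-behaved weights
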
In [7]:
# compute context vector using attn weights
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)
tensor([0.4419, 0.6515, 0.5683])

How to calculate the context vector:¶

  1. Turn the input sequence (text) into an embedding representation:

"Your journey starts with one step"

$\downarrow$

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts (x^3)
    [0.22, 0.58, 0.33], # with (x^4)
    [0.77, 0.25, 0.10], # one (x^5)
    [0.05, 0.80, 0.55]] # step (x^6)
)
  2. Define the query vector; here, we choose the second input as the query:

    $x^{(2)}\downarrow$

    query = inputs[1] = torch.tensor([0.55, 0.87, 0.66])

  3. Calculate attention scores (dot products of the query with every input), then normalize them with softmax to obtain the attention weights:

    [0.55, 0.87, 0.66] dot [0.43, 0.15, 0.89] = 0.9544
    [0.55, 0.87, 0.66] dot [0.55, 0.87, 0.66] = 1.4950
    [0.55, 0.87, 0.66] dot [0.57, 0.85, 0.64] = 1.4754
    [0.55, 0.87, 0.66] dot [0.22, 0.58, 0.33] = 0.8434
    [0.55, 0.87, 0.66] dot [0.77, 0.25, 0.10] = 0.7070
    [0.55, 0.87, 0.66] dot [0.05, 0.80, 0.55] = 1.0865

    $\downarrow$ softmax

    attn_weights = [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581]

  4. Calculate the context vector as the attention-weighted sum of the inputs:

    [0.43, 0.15, 0.89] * 0.1385 + 
    [0.55, 0.87, 0.66] * 0.2379 +
    [0.57, 0.85, 0.64] * 0.2333 +
    [0.22, 0.58, 0.33] * 0.1240 +
    [0.77, 0.25, 0.10] * 0.1082 +
    [0.05, 0.80, 0.55] * 0.1581
    

    $\downarrow$

    context_vec = tensor([0.4419, 0.6515, 0.5683])
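
The four numbered steps above can be condensed into a couple of lines; a minimal sketch reusing the tensors already defined:

# steps 2-4 in compact form: scores -> softmax -> weighted sum
scores_2 = inputs @ inputs[1]               # dot product of every input with the query x^(2)
weights_2 = torch.softmax(scores_2, dim=0)  # normalize scores into attention weights
print(weights_2 @ inputs)                   # context vector: tensor([0.4419, 0.6515, 0.5683])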
    
In [8]:
# Computing attention weights for all input tokens
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
In [9]:
# achieve the same with matrix multiplication
attn_scores = inputs @ inputs.T
print(attn_scores)
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
In [10]:
# normalization
# shape = (6, 6) = [rows, columns]
# dim = -1 -> softmax across columns -> rows sum = 1
attn_weights = torch.softmax(attn_scores, dim = -1)
print(attn_weights)
print(attn_weights.shape)
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
torch.Size([6, 6])
In [11]:
# calculate all context vectors:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
  • Note that in GPT-like models, the input and output dimensions are usually the same, but to better follow the computation, we’ll use different input ($d_{in}=3$) and output ($d_{out}=2$) dimensions here.

In [12]:
# self-attn with trainable weights
# start with 1 query vector
# first set params and initialize
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
In [13]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)
tensor([0.4306, 1.4551])
  • Even though our immediate goal is only to compute the single context vector $z^{(2)}$, we still require the key and value vectors for all input elements, as they are involved in computing the attention weights with respect to the query $q^{(2)}$.
In [14]:
# obtain all keys and values via matrix multiplication
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
In [15]:
# calculate attn scores
# start with example: q2 k2
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)
tensor(1.8524)
In [16]:
# get all attn scores via matrix multiplication
# attn scores for query 2
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
In [17]:
# get attn weights via scaling and softmax
d_k = keys.shape[-1] # which is d_out
attn_weights_2 = torch.softmax(attn_scores_2/d_k**0.5, dim = -1)
print(attn_weights_2)
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
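
A brief aside on why the scores are divided by $\sqrt{d_k}$ (not spelled out in the cell above): dot products of higher-dimensional vectors tend to have larger variance, which pushes softmax toward near one-hot distributions; dividing by $\sqrt{d_k}$ keeps the variance roughly constant. A small sketch:

# the variance of dot products of random vectors grows with the dimension d;
# scaling by sqrt(d) keeps it roughly constant
for d in (2, 64, 512):
    u, v = torch.randn(10000, d), torch.randn(10000, d)
    dots = (u * v).sum(dim=-1)
    print(d, dots.var().item(), (dots / d**0.5).var().item())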
In [18]:
# compute context vector for vector 2 in our inputs
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
tensor([0.3061, 0.8210])
In [19]:
# generalize the above process to calculate context vectors for all input vectors
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)
        context_vec = attn_weights @ values
        return context_vec
In [20]:
# test
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))
print(f'context vector for input2: {context_vec_2}')
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
context vector for input2: tensor([0.3061, 0.8210])
  • We can improve the SelfAttention_v1 implementation further by utilizing PyTorch’s nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled.

  • Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.
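
As a brief aside (a sketch, not from the original notebook): nn.Linear stores its weight as a (d_out, d_in) matrix and, with bias=False, computes x @ W.T, so it performs the same matrix multiplication as the manual x @ W above, just with a transposed weight layout and a better initialization:

lin = nn.Linear(d_in, d_out, bias=False)    # weight shape: (d_out, d_in)
manual = inputs @ lin.weight.T              # the same multiplication written out by hand
print(torch.allclose(lin(inputs), manual))  # True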

In [21]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec
In [22]:
# test sa_v2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
  • Exercise 3.1 Comparing SelfAttention_v1 and SelfAttention_v2

*Hiding future words with causal attention (masked attention)*¶

In [23]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)
print(attn_weights)
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
In [24]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
In [25]:
masked_simple = attn_weights*mask_simple
print(masked_simple)
tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)
In [26]:
# normalize rows to sum up to 1
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

procedure summary:¶

Calculate attn_scores ($QK^T$) $\rightarrow$ apply softmax across each row to get attn_weights = softmax($\frac{QK^T}{\sqrt{d_{out}}}$) $\rightarrow$ zero out the upper-triangular (future) positions with the mask $\rightarrow$ re-normalize each row by its row sum to obtain the masked attention weights.

*Information leakage*¶

When we apply a mask and then renormalize the attention weights, it might initially appear that information from future tokens (which we intend to mask) could still influence the current token because their values are part of the softmax calculation. However, the key insight is that when we renormalize the attention weights after masking, what we’re essentially doing is recalculating the softmax over a smaller subset (since masked positions don’t contribute to the softmax value).

The mathematical elegance of softmax is that despite initially including all positions in the denominator, after masking and renormalizing, the effect of the masked positions is nullified—they don’t contribute to the softmax score in any meaningful way.

In simpler terms, after masking and renormalization, the distribution of attention weights is as if it was calculated only among the unmasked positions to begin with. This ensures there’s no information leakage from future (or otherwise masked) tokens as we intended.
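
A quick numerical check of this claim (a sketch, not part of the original notebook): setting the future scores to -inf before the softmax, as the next cells do, yields exactly the same weights as masking the softmax output and renormalizing it, as done above:

# approach 1 (cells above): softmax, zero out future positions, renormalize -> masked_simple_norm
# approach 2 (cells below): set future scores to -inf, then softmax
inf_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
weights_via_inf = torch.softmax(
    attn_scores.masked_fill(inf_mask, -torch.inf) / keys.shape[-1]**0.5, dim=-1
)
print(torch.allclose(masked_simple_norm, weights_via_inf))  # True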

In [27]:
# a more efficient way of masking
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
In [28]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim = -1)
print(attn_weights)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
In [29]:
# masked context vector:
context_vec = attn_weights @ values
print(context_vec)
tensor([[0.1855, 0.8812],
        [0.2795, 0.9361],
        [0.3133, 0.9508],
        [0.2994, 0.8595],
        [0.2702, 0.7554],
        [0.2772, 0.7618]], grad_fn=<MmBackward0>)
In [30]:
# dropout
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))
tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])

*compensate for dropout*¶

To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 = 2. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.

Compensation is done automatically by the nn.Dropout module.

dropout outputs may look different depending on your operating system: https://github.com/pytorch/pytorch/issues/121595
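
A small sketch (not from the original notebook) illustrating the 1/(1-p) compensation: the surviving elements are scaled by 2, so the expected value of each element is preserved, and in eval mode dropout becomes a no-op:

torch.manual_seed(123)
drop = torch.nn.Dropout(0.5)
ones = torch.ones(1000, 1000)
out = drop(ones)
print(out[out > 0].unique())  # surviving entries are all scaled to 1 / (1 - 0.5) = 2
print(out.mean())             # close to 1.0, so the expected value is preserved
drop.eval()
print(drop(ones).mean())      # exactly 1.0: dropout is disabled at inference time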

In [31]:
torch.manual_seed(123)
print(dropout(attn_weights))
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)

*Implementing a compact causal attention class*¶

In [32]:
# adapting to batched inputs
# to simulate a batch, we stack two copies of inputs along a new batch dimension (dim = 0)
batch = torch.stack((inputs, inputs), dim = 0)
print(batch.shape)
torch.Size([2, 6, 3])
In [33]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # new: add dropout layer
        self.dropout = nn.Dropout(dropout)
        # buffers are automatically moved to the appropriate device (CPU or GPU) along with our model
        # note: the triangular mask is a plain tensor we create ourselves, not part of nn.Module's
        # parameters, so without register_buffer it would not follow the model across devices
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)

        # In PyTorch, operations with a trailing underscore are performed
        # in-place, avoiding unnecessary memory copies.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values

        return context_vec
  • Why do we need mask[:num_tokens, :num_tokens]? In a real model we always process data batch by batch, and the sequences in a batch are padded to a common length, which may be shorter than context_length and may differ from batch to batch, so the mask must be sliced down to the actual sequence length.
In [34]:
# test
torch.manual_seed(123)
context_length = batch.shape[1]
print("context length:", context_length)
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
print(context_vecs)
context length: 6
context_vecs.shape: torch.Size([2, 6, 2])
tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
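
To illustrate the mask[:num_tokens, :num_tokens] slicing mentioned above (a sketch, not from the original notebook), the same ca module also accepts batches whose sequences are shorter than the context_length it was constructed with:

short_batch = batch[:, :3, :]    # keep only the first 3 tokens of each sequence
short_context = ca(short_batch)  # the 6x6 mask buffer is sliced down to 3x3 internally
print(short_context.shape)       # torch.Size([2, 3, 2]) = (batch, num_tokens, d_out)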
In [38]:
# stacking multiple single-head attention layers (multi-head)
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(
                    d_in, d_out, context_length, dropout, qkv_bias
                )
                for _ in range(num_heads)
            ]
        )
    def forward(self, x):
        # concatenate the head outputs along the last (feature) dimension
        return torch.cat([head(x) for head in self.heads], dim = -1)
In [39]:
# test multihead output
torch.manual_seed(123)
# the number of tokens/sequence length
context_length = batch.shape[1]
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])
In [49]:
# a more efficient way to implement multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), 'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal = 1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # split the matrices by adding a num_heads dimension, unrolling the last dim:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # transpose from (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim = -1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        # linear projection that mixes information across heads (d_out -> d_out); in GPT-like models
        # d_out equals the embedding dimension, which keeps the output compatible with the residual connection
        context_vec = self.out_proj(context_vec)
        return context_vec

Why does causal self-attention need a projection layer?¶

In causal self-attention, the projection layer plays a crucial role. With a concrete example, its necessity can be explained from the following angles:


1. Combining information from multiple heads¶

Self-attention is usually applied as multi-head attention, where each head learns features in a different subspace. For example:

  • Input sequence: ["I", "love", "music"]
  • What different attention heads might focus on:
    • Head 1: syntactic relations (e.g., subject-verb structure)
    • Head 2: sentiment polarity (e.g., "love" is a positive word)
    • Head 3: long-range dependencies (e.g., cross-sentence coreference)

Role of the projection layer:
the outputs of the heads are concatenated and then combined into one unified representation via a linear transformation (the projection). For example:

  • each head outputs 64 dimensions, so 3 concatenated heads give 192 dimensions
  • the projection layer maps this back to 64 dimensions, avoiding redundancy and unifying the feature space.

2. Dimension alignment and feature selection¶

Suppose the input dimension is d_model=512. After the self-attention computation:

# input dimensions
input_shape = [batch_size, seq_len, 512]

# multi-head output (assuming 3 heads, each of dimension 64)
multi_head_output = [batch_size, seq_len, 3 * 64 = 192]

Role of the projection layer: a learnable weight matrix maps the dimension back to 512 to support the residual connection

projected_output = dense_layer(multi_head_output)  # 192 → 512
output = input + projected_output  # residual connection

3. Adding expressive power¶

Self-attention itself is a linear weighted sum; the projection layer can:

  • introduce nonlinearity (via the activation functions that follow)
  • learn more complex feature combinations through its parameters

Example:
in a generation task, the representation of the word “music” needs to combine:

  • syntax (head 1)
  • sentiment (head 2)
  • context (head 3)

4. Compatibility with subsequent layers¶

The Transformer's modular design requires each layer's input and output dimensions to match:

Causal Self-Attention → Projection Layer → Feed-Forward Network
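
A minimal PyTorch sketch of the shape flow described above, using the hypothetical dimensions from the example (d_model=512, 3 heads of size 64; not part of the original notebook):

d_model_ex, n_heads_ex, head_dim_ex = 512, 3, 64
x_ex = torch.randn(4, 10, d_model_ex)                      # (batch, seq_len, d_model)
heads_ex = torch.randn(4, 10, n_heads_ex * head_dim_ex)    # concatenated head outputs: (4, 10, 192)

proj_ex = nn.Linear(n_heads_ex * head_dim_ex, d_model_ex)  # learnable projection: 192 -> 512
out_ex = x_ex + proj_ex(heads_ex)                          # residual connection now has matching shapes
print(out_ex.shape)                                        # torch.Size([4, 10, 512])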
In [50]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads = 2)

context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape", context_vecs.shape)
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape torch.Size([2, 6, 2])
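
To make the view/transpose bookkeeping inside MultiHeadAttention easier to follow, here is a small shape walkthrough (a sketch under the same settings as the test above: b=2, num_tokens=6, d_out=2, num_heads=2, head_dim=1):

b_ex, t_ex, _ = batch.shape
k_ex = mha.W_key(batch)                                    # (2, 6, 2) = (b, num_tokens, d_out)
k_ex = k_ex.view(b_ex, t_ex, mha.num_heads, mha.head_dim)  # (2, 6, 2, 1): split d_out into heads
k_ex = k_ex.transpose(1, 2)                                # (2, 2, 6, 1) = (b, num_heads, num_tokens, head_dim)

q_ex = mha.W_query(batch).view(b_ex, t_ex, mha.num_heads, mha.head_dim).transpose(1, 2)
scores_ex = q_ex @ k_ex.transpose(2, 3)                    # per-head scores: (b, num_heads, num_tokens, num_tokens)
print(scores_ex.shape)                                     # torch.Size([2, 2, 6, 6])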

Why different heads attend to different information: how the loss function drives the heads to learn complementary features¶

In the Transformer's multi-head attention, every head has the same structure, yet the model forces different heads to learn complementary features through backpropagation of the loss and independent parameter updates. The following walks through this process with concrete examples:


1. Initial state: random parameters lead to diverging attention¶

Suppose a 3-head attention model processes the sentence:
"The cat sat on the mat because it was tired."

Initial conditions:

  • Each head's parameter matrices $W_Q^h, W_K^h, W_V^h$ are randomly initialized.
  • Head 1 may happen to focus on local syntax (e.g., cat → sat),
  • Head 2 may happen to pick up coreference (e.g., it → cat),
  • Head 3 may capture semantic roles (e.g., tired → cat).

2. The loss function drives the parameter updates¶

Suppose the task is next-word prediction (e.g., predicting tired):
the model computes a total loss (e.g., cross-entropy) and updates the parameters via backpropagation.

Key mechanisms:¶

  • Gradient competition:
    If several heads learn the same feature (e.g., all focus on local syntax), their gradient updates point in overlapping directions, so their contributions to the loss are redundant.

  • Pressure to minimize the loss:
    To minimize the loss, the model pushes the heads to diverge in parameter space and capture different features. For example:

    • If head 1 already captures subject-verb agreement well (cat → sat), the gradient updates for head 1 weaken,
    • and the gradients push heads 2 and 3 more strongly toward other useful features (such as coreference or semantics).

3. A concrete example: division of labor in machine translation¶

Take translating "He ate an apple" → "他吃了一个苹果" as an example:

Division of labor among the heads:¶

  • Head 1: learns the subject-verb relation (He → ate)
  • Head 2: captures tense (ate → past tense of eat)
  • Head 3: links number agreement (an → 一个)

How the loss function forces specialization:¶

  • If heads 1 and 2 both attend to tense, the model repeatedly corrects the same tense errors during translation, which leads to conflicting gradients.
  • Through backpropagation, the loss pushes the redundant head to shift its parameters toward other features (e.g., head 2 switches to learning article choice).

4. A mathematical view: divergence in parameter space¶

Suppose the query matrices $W_Q^1$ and $W_Q^2$ of two heads $h_1$ and $h_2$ are initialized randomly:

  • Before training:
    the directions spanned by $W_Q^1$ and $W_Q^2$ are close to randomly distributed.

  • After training:
    the loss pushes $W_Q^1$ toward capturing syntax (e.g., subject-verb relations)
    and $W_Q^2$ toward capturing semantics (e.g., action-object relations),
    so the two gradually become close to orthogonal in parameter space.


5. Empirical evidence¶

  • Visualization studies:
    Visualizations of BERT show that different heads indeed attend to distinct patterns such as syntax, semantics, and coreference (Clark et al., 2019).

  • Pruning experiments:
    Randomly disabling some attention heads degrades performance far less than disabling all heads of the same kind, indicating that the heads are functionally complementary.


Summary¶

Mechanism | Effect
--- | ---
Random parameter initialization | gives the heads an initial basis for diverging
Gradient competition | gradient updates for redundant heads are suppressed, pushing their parameters in new directions
Pressure to minimize the loss | the model prioritizes the feature combinations most useful for the current task
Subspace orthogonalization | the heads' parameters gradually separate in vector space, covering a more complete set of features

In the end, multi-head attention achieves a division-of-labor effect through the “invisible hand” of the loss function, allowing it to capture complex linguistic phenomena efficiently.