[Diagram: the multi-head attention computation, step by step.]

Each word is projected into query, key and value vectors, each head with its own learned weights.

Then, for every pair of words, the query and key components are multiplied and summed into a single score. (Notice where the source text's 3rd word vs. the target text's 3rd word would be.) Each score is divided by √2 (√d_k, since every head is 2 features wide here).

Masked positions are set to -inf before the softmax.

In the softmax, e^x gives a denominator != 0 (unless all inputs are negative infinity); e = 2.718281828459045…

The resulting probabilities weight the value vectors: if one word has a high probability, the result will be mostly that word's components (with small noise added by the other words).
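To see the masking and softmax step in isolation, here is a tiny standalone sketch (the raw scores are made up for illustration); the full example below then does the same thing with nn.MultiheadAttention.

import torch
import torch.nn.functional as F

# Hypothetical raw scores of one source word against three target words;
# the third target word is masked with -inf.
scores = torch.tensor([2.0, 0.5, -torch.inf])
weights = F.softmax(scores, dim=-1)
weights  # tensor([0.8176, 0.1824, 0.0000])
# e^-inf = 0, so the masked word gets probability 0, while e^x > 0 everywhere
# else keeps the softmax denominator from ever being 0.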
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor
import math
# Helper: builds an (h, w) matrix counting 1..h*w row by row, so the
# projection weights below are easy to follow by hand.
def arange2d(w, h):
    return (torch.arange(0, h)*w).unsqueeze(1) + torch.arange(1, w+1)
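# e.g. arange2d(4, 3) -> tensor([[ 1,  2,  3,  4],
#                                [ 5,  6,  7,  8],
#                                [ 9, 10, 11, 12]])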
embed_dim = 4
num_heads = 2  # 2 heads of 2 features each
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, bias=False)
# Weights to match the computation in the diagram; in_proj_weight stacks
# W_q, W_k and W_v, so it is (3*embed_dim, embed_dim):
multihead_attn.in_proj_weight = nn.Parameter(arange2d(embed_dim, embed_dim*3) * 0.1)
multihead_attn.out_proj.weight = nn.Parameter(arange2d(embed_dim, embed_dim) * 0.1)
source_text = tensor([  # two source "words", four features each
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
])
target_text = tensor([  # two target "words", four features each
    [0.9, 1.0, 1.1, 1.2],
    [1.3, 1.4, 1.5, 1.6],
])
query = source_text        # queries come from the source text
key = value = target_text  # keys and values come from the target text
attn_output, _ = multihead_attn(query, key, value)
attn_output # tensor([[ 24.5519, 61.8296, 99.1073, 136.3850],
# [ 24.6951, 62.3463, 99.9974, 137.6486]])
# Mask out the second target word (-inf) for every source word:
attn_mask = tensor([[0, -torch.inf], [0, -torch.inf]])
attn_output, _ = multihead_attn(query, key, value, attn_mask=attn_mask)
attn_output # tensor([[17.9000, 45.1960, 72.4920, 99.7880],
# [17.9000, 45.1960, 72.4920, 99.7880]])
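Both output rows are identical because the mask lets each source word attend only to the first target word, so each row is that word's value vector pushed through the output projection. To make the diagram's steps visible, here is a minimal sketch (not part of the original example) that recomputes the masked call by hand, reusing multihead_attn, query, key, value and attn_mask from above; it relies on PyTorch storing in_proj_weight as W_q, W_k and W_v stacked along the first dimension.

head_dim = embed_dim // num_heads                       # 2 features per head
w_q, w_k, w_v = multihead_attn.in_proj_weight.chunk(3, dim=0)
q = query @ w_q.T                                       # project the source words
k = key @ w_k.T                                         # project the target words
v = value @ w_v.T
# Split the 4 features into 2 heads of 2: (heads, words, head_dim)
q = q.view(-1, num_heads, head_dim).transpose(0, 1)
k = k.view(-1, num_heads, head_dim).transpose(0, 1)
v = v.view(-1, num_heads, head_dim).transpose(0, 1)
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # dot products, divided by √2
scores = scores + attn_mask                             # masked positions become -inf
weights = F.softmax(scores, dim=-1)                     # -inf turns into probability 0
heads_out = weights @ v                                 # probability-weighted values
concat = heads_out.transpose(0, 1).reshape(-1, embed_dim)
manual_output = concat @ multihead_attn.out_proj.weight.T
manual_output  # should reproduce the masked attn_output above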