[Diagram: one pass through an attention head — per-head projections, score, √2 scaling, -inf mask, softmax.]

Each head projects the words with its own learned weights, and the components are summed into a simple score (notice where the source text's 3rd word versus the target text's 3rd word would be). The scores are scaled down by √2, masked positions are set to -inf, and softmax turns the scores into probabilities. Because e^x is used, the denominator is never 0 (unless all inputs are negative infinity); e = 2.718281828459045…
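A quick numeric check of that last point (the scores here are made up): e^x of any finite score is positive, and e^(-inf) is exactly 0, so a masked word contributes nothing while the denominator stays above zero.

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 0.5, -torch.inf])   # third word masked out
F.softmax(scores, dim=-1)
# tensor([0.8176, 0.1824, 0.0000])  <- the masked word gets weight 0; the rest still sum to 1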
If one word has a high probability, then the result will be mostly that word's components (with a small amount of noise added by the other words).
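For example (made-up value vectors and weights), when softmax puts almost all of the probability on the first word, the weighted sum of value vectors lands very close to that word's own values:

import torch

values  = torch.tensor([[1.0, 2.0, 3.0, 4.0],    # word A's value vector
                        [5.0, 6.0, 7.0, 8.0]])   # word B's value vector
weights = torch.tensor([0.95, 0.05])             # softmax output, almost all on word A
weights @ values
# tensor([1.2000, 2.2000, 3.2000, 4.2000])  <- mostly word A, lightly nudged by word B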
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor
import math

def arange2d(w, h):
    return (torch.arange(0, h)*w).unsqueeze(1) + torch.arange(1, w+1)

embed_dim = 4
num_heads = 2
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, bias=False)

# Weights to match computation in the diagram:
multihead_attn.in_proj_weight = nn.Parameter(arange2d(embed_dim, embed_dim*3) * 0.1)
multihead_attn.out_proj.weight = nn.Parameter(arange2d(embed_dim, embed_dim) * 0.1)

# Two source words and two target words, each a 4-dimensional embedding:
source_text = tensor([
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
])
target_text = tensor([
    [0.9, 1.0, 1.1, 1.2],
    [1.3, 1.4, 1.5, 1.6],
])

query = source_text
key = value = target_text

# Unmasked attention: every query word can look at every target word.
attn_output, _ = multihead_attn(query, key, value)
attn_output
# tensor([[ 24.5519,  61.8296,  99.1073, 136.3850],
#         [ 24.6951,  62.3463,  99.9974, 137.6486]])

# Mask out the second target word (-inf) for both query words.
attn_mask = tensor([[0, -torch.inf],
                    [0, -torch.inf]])
attn_output, _ = multihead_attn(query, key, value, attn_mask=attn_mask)
attn_output
# tensor([[17.9000, 45.1960, 72.4920, 99.7880],
#         [17.9000, 45.1960, 72.4920, 99.7880]])
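To double-check what the module is doing, here is a minimal sketch that recomputes the masked result by hand with plain tensor ops (it reuses the tensors and weights defined above; the intermediate variable names are mine, and it is not the library's internals verbatim). With bias disabled and no dropout, it should reproduce the masked attn_output up to floating-point rounding.

# Manual re-computation of the masked call above (a sketch).
W = multihead_attn.in_proj_weight            # (3*embed_dim, embed_dim), stacked Q/K/V weights
q = query @ W[:embed_dim].T                  # query projection
k = key   @ W[embed_dim:2*embed_dim].T       # key projection
v = value @ W[2*embed_dim:].T                # value projection

head_dim = embed_dim // num_heads            # 4 // 2 = 2
q = q.reshape(-1, num_heads, head_dim).transpose(0, 1)   # (heads, words, head_dim)
k = k.reshape(-1, num_heads, head_dim).transpose(0, 1)
v = v.reshape(-1, num_heads, head_dim).transpose(0, 1)

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)   # scaled dot products per head
scores = scores + attn_mask                               # -inf removes the second target word
probs = F.softmax(scores, dim=-1)                         # masked column gets probability 0
out = (probs @ v).transpose(0, 1).reshape(-1, embed_dim)  # weighted values, heads concatenated
out @ multihead_attn.out_proj.weight.T
# both rows match the masked attn_output above: 17.9000, 45.1960, 72.4920, 99.7880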