Attention Mechanisms
Imagine you’re reading a sentence. Your brain doesn’t give equal weight to every word. Instead, it focuses on the most relevant words, depending on the context. Attention mechanisms in AI mimic this behavior: they allow models to weigh the importance of different parts of the input dynamically, leading to more accurate and nuanced results.
The Scaled Dot-Product Attention
Scaled dot-product attention is a popular form of attention. It computes a weighted sum of value vectors (V), where the weights are determined by the similarity between query (Q) and key (K) vectors. This similarity is typically calculated with a dot product and then scaled down by the square root of the key dimension to stabilize training.
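Written out, the standard formulation is

Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V

where d_k is the dimension of the key vectors; the softmax term alone is the matrix of attention weights.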
The following code is a simplified reimplementation of Flax's dot_product_attention_weights function, illustrating the key steps.
import math
import jax
from jax import numpy as jnp
from flax import linen as nn
def dot_product_attention_weights(query, key, mask=None):
    d_query = query.shape[-1]
    d_key = key.shape[-1]
    assert d_query == d_key, "query and key must have the same dimension"
    # Calculate dot products between query and key. 'q' and 'k' in the
    # annotation highlight the dimensions involved in the dot product, not the
    # overall shape. Both query and key have shape [..., seq_len, latent_dim].
    attn_logits = jnp.einsum('...qd,...kd->...qk', query, key)
    # Normalize by sqrt(d_k).
    attn_logits = attn_logits / math.sqrt(d_key)
    if mask is not None:
        # Where the mask is False, set the logit to a very large negative
        # number so it becomes (near) zero after the softmax.
        big_neg = jnp.finfo(attn_logits.dtype).min
        attn_logits = jnp.where(mask, attn_logits, big_neg)
    # Softmax over the key dimension.
    attention = nn.softmax(attn_logits, axis=-1)
    # Attention weights have shape [..., query, key].
    return attention

seq_len = 5
# Dimension of query and key, which must have the same value.
d_k = 3
qkv = jax.random.normal(jax.random.PRNGKey(123), (3, seq_len, d_k))
q, k, v = qkv[0], qkv[1], qkv[2]
print(f"Query of shape {q.shape}:\n {q}")
print(f"Key of shape {k.shape}:\n {k}")
print(f"Value of shape {v.shape}:\n {v}")
attention_weight = dot_product_attention_weights(q, k)
print(f"Attention weight of shape {attention_weight.shape}:\n {attention_weight}")

Query of shape (5, 3):
[[ 0.9321981 -1.3114095 -2.2122283 ]
[-1.0379515 -1.092678 0.21755463]
[-0.9908671 -0.9149708 -0.5571356 ]
[ 1.0599949 1.0308888 0.46293053]
[ 0.07368055 0.43742564 -1.6710967 ]]
Key of shape (5, 3):
[[-0.18015677 -3.1749253 -1.1962503 ]
[ 0.97549754 -0.17032507 -0.48637325]
[-1.32757 -0.13875987 -0.6687496 ]
[-0.59773946 1.4172351 -0.96903974]
[-0.7421036 0.9139938 -1.0052445 ]]
Value of shape (5, 3):
[[ 1.1325437e+00 1.5071929e+00 6.8808295e-02]
[-1.9742130e+00 2.3285030e-04 -2.0178690e+00]
[ 1.5945306e-01 7.7342010e-01 8.3827895e-01]
[-1.5689812e+00 1.6362010e+00 -1.4837370e+00]
[-1.5369723e+00 8.9348221e-01 -4.8948497e-01]]
Attention weight of shape (5, 5):
[[0.8698768 0.0672719 0.02400489 0.01606247 0.02278393]
[0.6341718 0.05211585 0.19849986 0.04625175 0.06896075]
[0.60054535 0.05045933 0.19650535 0.06266445 0.08982551]
[0.02481197 0.36379632 0.08624452 0.3140895 0.21105762]
[0.12254417 0.13859314 0.15103342 0.30834934 0.27947986]]
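The weights above are not yet the full attention output: as described earlier, the output is the weighted sum of the value vectors. A minimal sketch of that final step, reusing attention_weight and v from above (attn_output is just an illustrative name):

# Mix the value vectors: each output row is a weighted sum of the rows of v,
# with the weights taken from the corresponding row of attention_weight.
attn_output = jnp.einsum('...qk,...kd->...qd', attention_weight, v)
print(f"Attention output of shape {attn_output.shape}")  # (5, 3)

The mask argument excludes positions where the mask is False from the softmax. For example, a causal mask (sketched here, not reflected in the output above) keeps each query position from attending to later key positions:

# Lower-triangular boolean mask: entry [q, k] is True only when k <= q.
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
causal_weights = dot_product_attention_weights(q, k, mask=causal_mask)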
TODO
- Write about mask
- Use the attention weights in the attention output
- Multi-head attention