Hi, I read your paper and found the following confusing. The description of the ablations that culminate in ReBased starts with
x^2 – substituting the original kernel function with a simple element-wise squaring operation, ϕ(x) = x^2.
but this doesn't seem to be what happens in your code. See these lines:

    b_s = tl.dot(b_q, (b_k), allow_tf32=False)
    b_s = b_s * b_s

My understanding of Linear Attention is the following. We need two functions: a similarity function (called sim or s), which takes two vectors and returns a scalar, and a feature map (typically called phi), which takes a single vector and returns another vector (possibly of a different dimension). Ignoring normalization by 1/sqrt(d) for simplicity, Linear Attention requires that
s(q, k) = dot(phi(q), phi(k))
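To make that requirement concrete, here is a minimal NumPy sketch of why the factorization matters, as I understand it. It is non-causal and unnormalized, and the feature map `phi` below is an arbitrary placeholder I picked for illustration, not anything from the paper or the repository. When s factors through phi, the n x n similarity matrix never has to be built, because the matrix products can be reassociated.

```python
import numpy as np

def phi(X):
    # Arbitrary illustrative feature map (an assumption, not the paper's):
    # maps each row of length d to a row of length 2 * d.
    return np.concatenate([X, X ** 2], axis=-1)

rng = np.random.default_rng(0)
n, d, d_v = 5, 3, 2
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d_v))

# Quadratic form: S[i, j] = dot(phi(q_i), phi(k_j)), an n x n matrix.
S = phi(Q) @ phi(K).T
out_quadratic = S @ V

# Linear form: same result by associativity, without the n x n matrix.
kv_state = phi(K).T @ V          # shape (2 * d, d_v)
out_linear = phi(Q) @ kv_state   # shape (n, d_v)

print(np.allclose(out_quadratic, out_linear))
```

The two outputs agree only because the similarity is exactly a dot product of feature vectors, which is why it matters which phi actually corresponds to a given s.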
Those lines of code I linked to correspond to defining
s(q, k) = dot(q, k)**2
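The feature map phi that corresponds to this similarity function is not elementwise squaring. I.e., phi(x) = x**2 is not the corresponding feature map for that similarity function. The correct corresponding feature map is phi(x) = flatten(outer(x, x)), because dot(q, k)**2 = sum over i, j of q_i * q_j * k_i * k_j = dot(flatten(outer(q, q)), flatten(outer(k, k))). Elementwise squaring keeps only the i = j terms of that sum and drops all the cross terms.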
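A quick numerical check of this claim, as a small NumPy sketch. The function names are mine and purely illustrative; nothing here is taken from the repository.

```python
import numpy as np

def sim_squared_dot(q, k):
    # Similarity computed by the linked lines: dot(q, k) ** 2
    return np.dot(q, k) ** 2

def phi_elementwise(x):
    # Feature map as described in the paper's ablation: phi(x) = x ** 2
    return x ** 2

def phi_outer(x):
    # Candidate feature map: flattened outer product, length d * d
    return np.outer(x, x).ravel()

q = np.array([1.0, 2.0])
k = np.array([3.0, 4.0])

target = sim_squared_dot(q, k)
via_outer = np.dot(phi_outer(q), phi_outer(k))
via_elementwise = np.dot(phi_elementwise(q), phi_elementwise(k))

print(target)           # 121.0
print(via_outer)        # 121.0
print(via_elementwise)  # 73.0
```

The gap of 48 between the last two values is exactly the cross terms, 2 * q_1 * q_2 * k_1 * k_2 = 2 * 1 * 2 * 3 * 4.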
One could say similar things about the other variants in the ablations, including ReBased.
Am I missing something?
Code reference: rebased/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py, lines 68 to 69 in 7a085b4