Lack of clarity about sim function vs feature map for paper/code #1

@deklanw

Description

Hi, I read your paper and found the following confusing. When you describe the ablations that culminate in ReBased, the paper starts with

x^2 – substituting the original kernel function with a simple element-wise squaring operation, ϕ(x) = x^2.

but this doesn't seem to be what happens in your code. See these lines:

b_s = tl.dot(b_q, (b_k), allow_tf32=False)
b_s = b_s * b_s

My understanding of Linear Attention is the following. We need two functions: a similarity function (called sim or s), which takes two vectors and returns a scalar, and a feature map (typically called phi), which takes a single vector and returns another vector (possibly of a different dimension). Ignoring the 1/sqrt(d) normalization for simplicity, Linear Attention requires that

s(q, k) = dot(phi(q), phi(k))

Those lines of code I linked to correspond to defining

s(q, k) = dot(q, k)**2

The feature map phi corresponding to this similarity function is not elementwise squaring: phi(x) = x**2 does not satisfy s(q, k) = dot(phi(q), phi(k)) for s(q, k) = dot(q, k)**2. The correct feature map is phi(x) = flatten(outer(x, x)).
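To make the mismatch concrete, here is a small numerical check (a sketch in NumPy, with illustrative variable names, not code from the repo): the squared dot product equals the inner product of the outer-product feature maps, while elementwise squaring gives a different value in general.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
q = rng.standard_normal(d)
k = rng.standard_normal(d)

# Similarity as computed in the kernel: square the dot product.
s = np.dot(q, k) ** 2

# Feature map that actually realizes this similarity:
# phi(x) = flatten(outer(x, x)), since
# sum_{i,j} q_i q_j k_i k_j = (sum_i q_i k_i)^2.
def phi(x):
    return np.outer(x, x).ravel()

assert np.isclose(s, np.dot(phi(q), phi(k)))

# Elementwise squaring does NOT reproduce the similarity in general:
# dot(q**2, k**2) = sum_i q_i^2 k_i^2, which drops the cross terms.
s_elementwise = np.dot(q ** 2, k ** 2)
print(s, np.dot(phi(q), phi(k)), s_elementwise)
```

So whenever the cross terms q_i q_j k_i k_j (i != j) are nonzero, the two definitions disagree.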

One could say similar things about the other variants in the ablations, including ReBased.

Am I missing something?
