This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit f52007e

llava e2e 1/n
1 parent 3b162e2 commit f52007e

File tree

8 files changed (+1029, -14 lines)

freq_compare.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import torch
from typing import Any, Dict, Optional, Tuple
from torchchat.utils.build_utils import find_multiple, get_precision


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
    freqs = 1.0 / (
        theta
        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
    )
    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
        freqs  # pyre-ignore
    )
    freqs = torch.outer(t, freqs).float()  # pyre-ignore
    emb = torch.cat((freqs, freqs), dim=-1)
    freqs_cos = torch.cos(emb)
    freqs_sin = torch.sin(emb)
    return freqs_cos, freqs_sin

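For orientation (not part of the commit): because the half-width frequency table is concatenated with itself, hf_precompute_freqs_cis returns full-width cos/sin tables of shape [end, dim]. A minimal sanity-check sketch, assuming the function above is in scope:

# Hypothetical check, not part of the commit.
cos, sin = hf_precompute_freqs_cis(dim=128, end=16, theta=10000.0)
assert cos.shape == (16, 128) and sin.shape == (16, 128)
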
def precompute_freqs_cis(
    n_elem: int,
    seq_len: int,
    base: int = 10000,
    dtype=None,
    rope_scaling: Optional[Dict[str, Any]] = None,
):
    if not dtype:
        dtype = get_precision()
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    if rope_scaling is not None:
        # Note: apply_scaling is not defined or imported in this file; the
        # branch is never taken below because rope_scaling stays None.
        freqs = apply_scaling(freqs, rope_scaling)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)

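By contrast, precompute_freqs_cis stores each rotation as an interleaved (real, imag) pair, giving a cache of shape [seq_len, n_elem // 2, 2]. A hedged sketch, passing dtype explicitly to sidestep the torchchat get_precision dependency:

# Hypothetical check, not part of the commit.
cache = precompute_freqs_cis(n_elem=128, seq_len=16, dtype=torch.float32)
assert cache.shape == (16, 64, 2)
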
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

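A note on layout: the HF path above pairs channel i with channel i + dim/2 (rotate_half), while apply_rotary_emb below pairs adjacent channels (2i, 2i + 1). The two conventions differ only by a channel permutation, which is likely why the raw outputs printed by compare_methods disagree. A hypothetical helper (not in the commit) mapping the interleaved layout into the half-split layout:

# Hypothetical: [x0, x1, x2, x3, ...] -> [x0, x2, ..., x1, x3, ...]
def interleaved_to_half_split(x):
    return torch.cat((x[..., 0::2], x[..., 1::2]), dim=-1)

With this mapping, hf_apply_rotary_emb applied to the permuted input should match the permuted output of apply_rotary_emb.
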
def apply_rotary_emb(x, freqs_cis):
    # Treat the last dim as interleaved (real, imag) pairs and rotate each pair.
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)

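The stacked multiply above is complex multiplication written out by hand: each (real, imag) pair is rotated by the matching unit complex number from the cache. A hedged equivalent using torch.view_as_complex (not part of the commit):

# Hypothetical rewrite of apply_rotary_emb via complex arithmetic;
# assumes a float32 (or float64) freqs_cis cache, since view_as_complex
# does not cover all reduced-precision dtypes.
def apply_rotary_emb_complex(x, freqs_cis):
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    fc = torch.view_as_complex(freqs_cis.view(1, xc.size(1), 1, xc.size(3), 2))
    return torch.view_as_real(xc * fc).flatten(3).type_as(x)
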
# Comparison function
def compare_methods():
    torch.manual_seed(0)
    x = torch.randn(1, 636, 32, 128)

    # Set up parameters
    n_elem = 128
    seq_len = 1536
    base = 10000
    dtype = None
    rope_scaling = None

    all_freq_cis = precompute_freqs_cis(n_elem, seq_len, base, dtype, rope_scaling)
    input_pos = torch.arange(
        x.shape[1],
        device=x.device,
        dtype=torch.int,
    )
    freq_cis = all_freq_cis[input_pos]
    x_out1 = apply_rotary_emb(x, freq_cis)


    dim = 128
    end = 1536
    theta = 10000.0
    freqs_cos, freqs_sin = hf_precompute_freqs_cis(dim, end, theta)
    fc, fs = freqs_cos[: x.shape[1]], freqs_sin[: x.shape[1]]
    x_out2, _ = hf_apply_rotary_emb(x, x, fc, fs)

    print(x_out1)
    print("************************")
    print(x_out2)


if __name__ == "__main__":
    compare_methods()
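
Since eyeballing two printed tensors is error-prone, a hedged follow-up inside compare_methods (not in the commit) could quantify the gap instead:

# Hypothetical numeric comparison to append inside compare_methods().
print("max abs diff:", (x_out1 - x_out2).abs().max().item())
print("allclose:", torch.allclose(x_out1, x_out2, atol=1e-5))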

0 commit comments
