 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -76,13 +76,12 @@ def forward(self, x):
 
 class QWenAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        max_position_embeddings: int,
-        rope_theta: float = 10000,
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 max_position_embeddings: int,
+                 rope_theta: float = 10000,
+                 rope_scaling: Optional[Dict[str, Any]] = None):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -116,7 +115,7 @@ def __init__(
             rotary_dim=self.head_dim,
             base=rope_theta,
             max_position=max_position_embeddings,
-        )
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -144,10 +143,12 @@ def __init__(self, config: QWenConfig):
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
-                                  rope_theta=rope_theta,
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling)
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
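For reference, a minimal sketch of how the new `rope_scaling` plumbing behaves, outside of vLLM itself. The `QWenConfigStub` class and the `{"type": "dynamic", "factor": 2.0}` payload are illustrative assumptions, not part of this change; the exact keys accepted depend on the rotary embedding implementation that consumes the dict. The point is that configs without the field keep the previous behavior via the `None` default, while configs that define it have the dict forwarded unchanged from `getattr(config, "rope_scaling", None)` through `QWenAttention` to the rotary embedding.

```python
# Sketch only: QWenConfigStub and the example scaling dict are hypothetical;
# they stand in for the HF config object to show the getattr fallback.
from typing import Any, Dict, Optional


class QWenConfigStub:
    """Stand-in config exposing only the fields used by this change."""

    def __init__(self, rope_scaling: Optional[Dict[str, Any]] = None):
        if rope_scaling is not None:
            self.rope_scaling = rope_scaling
        self.rope_theta = 10000


# Config without rope_scaling: getattr falls back to None, so the attention
# layer is built exactly as before this change.
plain = QWenConfigStub()
assert getattr(plain, "rope_scaling", None) is None

# Config with a scaling dict: the dict is passed through QWenAttention to the
# rotary embedding helper without modification.
scaled = QWenConfigStub(rope_scaling={"type": "dynamic", "factor": 2.0})
assert getattr(scaled, "rope_scaling", None) == {"type": "dynamic", "factor": 2.0}
```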