34 | 34 | try: |
35 | 35 | # TODO: remove this after we figure out where in torchtune an `evaluate` module |
36 | 36 | # is being imported, which is being confused with huggingface's `evaluate`. |
37 | | - import lm_eval # noqa |
| 37 | + import lm_eval # noqa |
38 | 38 | except Exception: |
39 | 39 | pass |
40 | 40 |
@@ -278,6 +278,9 @@ class TransformerArgs: |
278 | 278 | # For pipeline parallel |
279 | 279 | n_stages: int = 1 |
280 | 280 | stage_idx: int = 0 |
| 281 | + # Optional biases |
| 282 | + attention_bias: bool = False |
| 283 | + feed_forward_bias: bool = False |
281 | 284 |
282 | 285 | def __post_init__(self): |
283 | 286 | if self.n_local_heads == -1: |
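The two new dataclass fields default to False, so existing model configurations keep their bias-free projections. A model family whose checkpoints do carry projection biases could opt in from its params dict; a minimal sketch, assuming `TransformerArgs.from_params` maps dict keys onto these fields the same way it is used in `build_model` below (the keys' values here are illustrative, not taken from any real params file):

params = {
    "dim": 4096,
    "n_heads": 32,
    "n_local_heads": 8,
    "attention_bias": True,       # wq/wk/wv/wo gain a learnable bias
    "feed_forward_bias": False,   # w1/w2/w3 stay bias-free (the default)
}
args = TransformerArgs.from_params(params)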
@@ -394,7 +397,7 @@ def from_name(cls, name: str): |
394 | 397 | config = [ |
395 | 398 | config |
396 | 399 | for config in known_model_params |
397 | | - if config in str(name).upper() or config in str(name) |
| 400 | + if config.upper() in str(name).upper() or config in str(name) |
398 | 401 | ] |
399 | 402 |
400 | 403 | # We may have two or more configs matched (e.g., "7B" and |
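The predicate change makes the lookup case-insensitive on the config key as well: previously the key was tested verbatim against the upper-cased name, so any key containing lowercase characters could never match. A standalone illustration of the old and new predicates (the keys and name below are made up):

known_model_params = ["Meta-Llama-3-8B", "7B"]   # hypothetical keys
name = "meta-llama-3-8b"

# Old predicate: the mixed-case key is searched inside the upper-cased
# name without being upper-cased itself, so it never matches.
old = [k for k in known_model_params if k in str(name).upper() or k in str(name)]

# New predicate: both sides are upper-cased, so matching is case-insensitive.
new = [k for k in known_model_params if k.upper() in str(name).upper() or k in str(name)]

print(old)  # []
print(new)  # ['Meta-Llama-3-8B']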
@@ -471,7 +474,7 @@ def build_model(self) -> nn.Module: |
471 | 474 | modules[name] = module_class(TransformerArgs.from_params(config_args)) |
472 | 475 | else: |
473 | 476 | modules[name] = module_class(**config_args) |
474 | | - |
| 477 | + |
475 | 478 | # Temporary add extra params to the DeepFusionModel. |
476 | 479 | # TODO: Remove it once we can make fusion model configurable in model_param. |
477 | 480 | if recipe.fusion_class == DeepFusionModel: |
@@ -730,16 +733,16 @@ def __init__(self, config: TransformerArgs): |
730 | 733 |
731 | 734 | # key, query, value projections for all heads, but in a batch |
732 | 735 | # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim |
733 | | - # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) |
734 | | - self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False) |
| 736 | + # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias) |
| 737 | + self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias) |
735 | 738 | self.wk = nn.Linear( |
736 | | - config.dim, config.n_local_heads * config.head_dim, bias=False |
| 739 | + config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias |
737 | 740 | ) |
738 | 741 | self.wv = nn.Linear( |
739 | | - config.dim, config.n_local_heads * config.head_dim, bias=False |
| 742 | + config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias |
740 | 743 | ) |
741 | 744 |
742 | | - self.wo = nn.Linear(config.dim, config.dim, bias=False) |
| 745 | + self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias) |
743 | 746 | self.kv_cache = None |
744 | 747 |
745 | 748 | self.n_heads = config.n_heads |
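With the flag threaded through, all four attention projections gain a learnable bias vector when `attention_bias=True`, while the default `False` reproduces the previous layout exactly. A small sketch of the effect on a single projection, using `torch.nn.Linear` directly rather than the `Attention` class (sizes are illustrative):

import torch.nn as nn

dim, n_heads, head_dim = 4096, 32, 128  # illustrative sizes

# bias=False (the default) keeps the old state_dict layout: weight only.
wq_old = nn.Linear(dim, n_heads * head_dim, bias=False)
assert wq_old.bias is None

# bias=True adds a (n_heads * head_dim,) bias vector, which is what
# checkpoints trained with biased QKV projections expect to load into.
wq_new = nn.Linear(dim, n_heads * head_dim, bias=True)
assert wq_new.bias.shape == (n_heads * head_dim,)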
@@ -852,9 +855,9 @@ def forward( |
852 | 855 | class FeedForward(nn.Module): |
853 | 856 | def __init__(self, config: TransformerArgs) -> None: |
854 | 857 | super().__init__() |
855 | | - self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False) |
856 | | - self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False) |
857 | | - self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False) |
| 858 | + self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias) |
| 859 | + self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias) |
| 860 | + self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias) |
858 | 861 |
859 | 862 | def distribute(self, device_mesh: DeviceMesh): |
860 | 863 | parallelize_module(self.w1, device_mesh, ColwiseParallel()) |
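The same wiring applies to the MLP: with `feed_forward_bias=True` the three projections carry bias vectors, so matching checkpoints must provide the corresponding keys, while the default leaves loading behavior unchanged. A sketch of the resulting state_dict keys, built with plain `torch.nn` modules instead of the `FeedForward` class (sizes are illustrative):

import torch.nn as nn

dim, hidden_dim = 4096, 14336  # illustrative sizes

# Layers equivalent to what FeedForward builds when feed_forward_bias=True.
ffn = nn.ModuleDict({
    "w1": nn.Linear(dim, hidden_dim, bias=True),
    "w2": nn.Linear(hidden_dim, dim, bias=True),
    "w3": nn.Linear(dim, hidden_dim, bias=True),
})

print(sorted(ffn.state_dict().keys()))
# ['w1.bias', 'w1.weight', 'w2.bias', 'w2.weight', 'w3.bias', 'w3.weight']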