|
38 | 38 | import sys |
39 | 39 | import textwrap |
40 | 40 | import time |
41 | | -from typing import Any, Dict, List, Optional, Tuple, Union |
| 41 | +from typing import Dict, List, Optional, Tuple, Union |
42 | 42 | |
43 | 43 | from .ci_helpers import ( |
44 | 44 | check_for_discrepancies_and_log_everything_into_a_json_file, |
@@ -609,36 +609,6 @@ def local_body_fn( |
609 | 609 | ] |
610 | 610 | |
611 | 611 | |
612 | | -def get_untrained_model(model_id: str, second_input: bool, verbose: int) -> Dict[str, Any]: |
613 | | - """ |
614 | | - Returns an untrained model. |
615 | | - |
616 | | - :param model_id: model id |
617 | | - :param second_input: second input set |
618 | | - :param verbose: verbosity |
619 | | - :return: model and data |
620 | | - """ |
621 | | - from ..torch_models.hghub.model_inputs import get_untrained_model_with_inputs |
622 | | - |
623 | | - if model_id == "arnir0/Tiny-LLM": |
624 | | - # used to run a unit test |
625 | | - _config_reduction = None |
626 | | - else: |
627 | | - |
628 | | - def _config_reduction(config, task): |
629 | | - return {"_attn_implementation": "sdpa"} |
630 | | - |
631 | | - config_reduction = _config_reduction |
632 | | - data = get_untrained_model_with_inputs( |
633 | | - model_id, |
634 | | - verbose=verbose, |
635 | | - add_second_input=second_input, |
636 | | - config_reduction=config_reduction, |
637 | | - skip_inputs=True, |
638 | | - ) |
639 | | - return data |
640 | | - |
641 | | - |
642 | 612 | def get_inputs_for_part( |
643 | 613 | model_id: str, |
644 | 614 | part: str, |
@@ -808,35 +778,30 @@ def main( |
808 | 778 | ) |
809 | 779 | torch_dtype = get_torch_dtype_from_command_line_args(dtype) |
810 | 780 | |
811 | | - # with torch_export_patches( |
812 | | - # patch_torch=False, |
813 | | - # patch_sympy=False, |
814 | | - # patch_transformers=True, |
815 | | - # verbose=1, |
816 | | - # stop_if_static=2, |
817 | | - ## profile=(f"{basename}.profile.html" if profile_exporter else None), |
818 | | - # custom_patches=get_patches_transformers(), |
819 | | - # ): |
820 | | - if 1: |
821 | | - if pretrained: |
822 | | - print("-- pretrained model") |
823 | | - config = AutoConfig.from_pretrained( |
824 | | - model_id, trust_remote_code=True, attn_implementation="sdpa" |
825 | | - ) |
826 | | - model = AutoModelForCausalLM.from_pretrained( |
827 | | - model_id, |
828 | | - config=config, |
829 | | - trust_remote_code=True, |
830 | | - torch_dtype=torch_dtype, |
831 | | - device_map=device, |
832 | | - attn_implementation="sdpa", |
833 | | - ).eval() |
834 | | - data = dict(model=model) |
835 | | - else: |
836 | | - print("-- random model") |
837 | | - data = get_untrained_model(model_id, second_input=second_input, verbose=1) |
838 | | - model = data["model"] |
839 | | - _config = data["configuration"] |
| 781 | + if pretrained: |
| 782 | + print("-- pretrained model") |
| 783 | + config = AutoConfig.from_pretrained( |
| 784 | + model_id, trust_remote_code=True, attn_implementation="sdpa" |
| 785 | + ) |
| 786 | + model = AutoModelForCausalLM.from_pretrained( |
| 787 | + model_id, |
| 788 | + config=config, |
| 789 | + trust_remote_code=True, |
| 790 | + torch_dtype=torch_dtype, |
| 791 | + device_map=device, |
| 792 | + attn_implementation="sdpa", |
| 793 | + ).eval() |
| 794 | + data = dict(model=model) |
| 795 | + else: |
| 796 | + print("-- random model") |
| 797 | + config = AutoConfig.from_pretrained( |
| 798 | + model_id, trust_remote_code=True, attn_implementation="sdpa" |
| 799 | + ) |
| 800 | + config.attn_implementation = "sdpa" |
| 801 | + config._attn_implementation = "sdpa" |
| 802 | + config.num_hidden_layers = 2 |
| 803 | + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) |
| 804 | + data = dict(model=model) |
840 | 805 | |
841 | 806 | main_mod_name = model.__module__ |
842 | 807 | assert ( |
0 commit comments