|
4 | 4 | import transformers |
5 | 5 | from ...helpers.config_helper import update_config |
6 | 6 | from ...tasks import reduce_model_config, random_input_kwargs |
7 | | -from .hub_api import task_from_arch, get_pretrained_config |
| 7 | +from .hub_api import task_from_arch, task_from_id, get_pretrained_config |
8 | 8 |
|
9 | 9 |
|
10 | 10 | def get_untrained_model_with_inputs( |
@@ -64,17 +64,21 @@ def get_untrained_model_with_inputs( |
64 | 64 | config = get_pretrained_config( |
65 | 65 | model_id, use_preinstalled=use_preinstalled, **(model_kwargs or {}) |
66 | 66 | ) |
| 67 | + if hasattr(config, "architecture") and config.architecture: |
| 68 | + archs = [config.architecture] |
67 | 69 | archs = config.architectures # type: ignore |
68 | | - assert archs is not None and len(archs) == 1, ( |
| 70 | + task = None |
| 71 | + if archs is None: |
| 72 | + task = task_from_id(model_id) |
| 73 | + assert task is not None or (archs is not None and len(archs) == 1), ( |
69 | 74 | f"Unable to determine the architecture for model {model_id!r}, " |
70 | 75 | f"architectures={archs!r}, conf={config}" |
71 | 76 | ) |
72 | | - arch = archs[0] |
73 | | - if verbose: |
74 | | - print(f"[get_untrained_model_with_inputs] architecture={arch!r}") |
75 | 77 | if verbose: |
| 78 | + print(f"[get_untrained_model_with_inputs] architectures={archs!r}") |
76 | 79 | print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}") |
77 | | - task = task_from_arch(arch) |
| 80 | + if task is None: |
| 81 | + task = task_from_arch(archs[0]) |
78 | 82 | if verbose: |
79 | 83 | print(f"[get_untrained_model_with_inputs] task={task!r}") |
80 | 84 |
|
@@ -106,7 +110,15 @@ def get_untrained_model_with_inputs( |
106 | 110 | if inputs_kwargs: |
107 | 111 | kwargs.update(inputs_kwargs) |
108 | 112 |
|
109 | | - model = getattr(transformers, arch)(config) |
| 113 | + if archs is not None: |
| 114 | + model = getattr(transformers, archs[0])(config) |
| 115 | + else: |
| 116 | + assert same_as_pretrained, ( |
| 117 | + f"Model {model_id!r} cannot be built, the model cannot be built. " |
| 118 | + f"It must be downloaded. Use same_as_pretrained=True." |
| 119 | + ) |
| 120 | + model = None |
| 121 | + |
110 | 122 | # This line is important. Some models may produce different |
111 | 123 | # outputs even with the same inputs in training mode. |
112 | 124 | model.eval() |
|
0 commit comments