Skip to content

Commit fcf9dc3

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
Migrate from Tuple -> tuple in benchmarks (pytorch#144259)
Pull Request resolved: pytorch#144259 Approved by: https://github.com/yanboliang
1 parent 2e42be0 commit fcf9dc3

File tree

12 files changed

+77
-80
lines changed

12 files changed

+77
-80
lines changed

benchmarks/dynamo/common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
NamedTuple,
3333
Optional,
3434
Sequence,
35-
Tuple,
3635
Type,
3736
TYPE_CHECKING,
3837
)
@@ -746,7 +745,7 @@ def timed(
746745
return (time_total, result) if return_result else time_total
747746

748747

749-
def _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]:
748+
def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]:
750749
# NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
751750
# and consumed like `model(**example_inputs)`.
752751
# For other benchmarks, example_inputs are formatted as tuple and consumed

benchmarks/dynamo/microbenchmarks/operator_inp_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from collections import Counter, defaultdict
66
from functools import partial
7-
from typing import Any, Dict, Generator, Iterable, Tuple
7+
from typing import Any, Dict, Generator, Iterable
88

99
import torch
1010
from torch.testing import make_tensor
@@ -263,7 +263,7 @@ def __init__(self, json_file_path):
263263

264264
def get_inputs_for_operator(
265265
self, operator, dtype=None, device="cuda"
266-
) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], None, None]:
266+
) -> Generator[tuple[Iterable[Any], Dict[str, Any]], None, None]:
267267
assert (
268268
str(operator) in self.operator_db
269269
), f"Could not find {operator}, must provide overload"

benchmarks/fastrnns/cells.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Tuple
2-
31
import torch
42
from torch import Tensor
53

@@ -27,12 +25,12 @@ def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
2725

2826
def lstm_cell(
2927
input: Tensor,
30-
hidden: Tuple[Tensor, Tensor],
28+
hidden: tuple[Tensor, Tensor],
3129
w_ih: Tensor,
3230
w_hh: Tensor,
3331
b_ih: Tensor,
3432
b_hh: Tensor,
35-
) -> Tuple[Tensor, Tensor]:
33+
) -> tuple[Tensor, Tensor]:
3634
hx, cx = hidden
3735
gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
3836

@@ -57,7 +55,7 @@ def flat_lstm_cell(
5755
w_hh: Tensor,
5856
b_ih: Tensor,
5957
b_hh: Tensor,
60-
) -> Tuple[Tensor, Tensor]:
58+
) -> tuple[Tensor, Tensor]:
6159
gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
6260

6361
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
@@ -75,11 +73,11 @@ def flat_lstm_cell(
7573

7674
def premul_lstm_cell(
7775
igates: Tensor,
78-
hidden: Tuple[Tensor, Tensor],
76+
hidden: tuple[Tensor, Tensor],
7977
w_hh: Tensor,
8078
b_ih: Tensor,
8179
b_hh: Tensor,
82-
) -> Tuple[Tensor, Tensor]:
80+
) -> tuple[Tensor, Tensor]:
8381
hx, cx = hidden
8482
gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh
8583

@@ -97,8 +95,8 @@ def premul_lstm_cell(
9795

9896

9997
def premul_lstm_cell_no_bias(
100-
igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
101-
) -> Tuple[Tensor, Tensor]:
98+
igates: Tensor, hidden: tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
99+
) -> tuple[Tensor, Tensor]:
102100
hx, cx = hidden
103101
gates = igates + torch.mm(hx, w_hh.t()) + b_hh
104102

benchmarks/fastrnns/custom_lstms.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numbers
22
import warnings
33
from collections import namedtuple
4-
from typing import List, Tuple
4+
from typing import List
55

66
import torch
77
import torch.jit as jit
@@ -131,8 +131,8 @@ def __init__(self, input_size, hidden_size):
131131

132132
@jit.script_method
133133
def forward(
134-
self, input: Tensor, state: Tuple[Tensor, Tensor]
135-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
134+
self, input: Tensor, state: tuple[Tensor, Tensor]
135+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
136136
hx, cx = state
137137
gates = (
138138
torch.mm(input, self.weight_ih.t())
@@ -199,8 +199,8 @@ def __init__(self, input_size, hidden_size, decompose_layernorm=False):
199199

200200
@jit.script_method
201201
def forward(
202-
self, input: Tensor, state: Tuple[Tensor, Tensor]
203-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
202+
self, input: Tensor, state: tuple[Tensor, Tensor]
203+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
204204
hx, cx = state
205205
igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
206206
hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
@@ -225,8 +225,8 @@ def __init__(self, cell, *cell_args):
225225

226226
@jit.script_method
227227
def forward(
228-
self, input: Tensor, state: Tuple[Tensor, Tensor]
229-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
228+
self, input: Tensor, state: tuple[Tensor, Tensor]
229+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
230230
inputs = input.unbind(0)
231231
outputs = torch.jit.annotate(List[Tensor], [])
232232
for i in range(len(inputs)):
@@ -242,8 +242,8 @@ def __init__(self, cell, *cell_args):
242242

243243
@jit.script_method
244244
def forward(
245-
self, input: Tensor, state: Tuple[Tensor, Tensor]
246-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
245+
self, input: Tensor, state: tuple[Tensor, Tensor]
246+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
247247
inputs = reverse(input.unbind(0))
248248
outputs = jit.annotate(List[Tensor], [])
249249
for i in range(len(inputs)):
@@ -266,11 +266,11 @@ def __init__(self, cell, *cell_args):
266266

267267
@jit.script_method
268268
def forward(
269-
self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
270-
) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
269+
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
270+
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
271271
# List[LSTMState]: [forward LSTMState, backward LSTMState]
272272
outputs = jit.annotate(List[Tensor], [])
273-
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
273+
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
274274
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
275275
i = 0
276276
for direction in self.directions:
@@ -300,10 +300,10 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
300300

301301
@jit.script_method
302302
def forward(
303-
self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
304-
) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
303+
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
304+
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
305305
# List[LSTMState]: One state per layer
306-
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
306+
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
307307
output = input
308308
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
309309
i = 0
@@ -330,11 +330,11 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
330330

331331
@jit.script_method
332332
def forward(
333-
self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
334-
) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
333+
self, input: Tensor, states: List[List[tuple[Tensor, Tensor]]]
334+
) -> tuple[Tensor, List[List[tuple[Tensor, Tensor]]]]:
335335
# List[List[LSTMState]]: The outer list is for layers,
336336
# inner list is for directions.
337-
output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
337+
output_states = jit.annotate(List[List[tuple[Tensor, Tensor]]], [])
338338
output = input
339339
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
340340
i = 0
@@ -370,10 +370,10 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
370370

371371
@jit.script_method
372372
def forward(
373-
self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
374-
) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
373+
self, input: Tensor, states: List[tuple[Tensor, Tensor]]
374+
) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
375375
# List[LSTMState]: One state per layer
376-
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
376+
output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
377377
output = input
378378
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
379379
i = 0

benchmarks/fastrnns/factory.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import namedtuple
2-
from typing import List, Tuple
2+
from typing import List
33

44
import torch
55
from torch import Tensor
@@ -266,12 +266,12 @@ def forward(sequences, hidden):
266266
def varlen_lstm_factory(cell, script):
267267
def dynamic_rnn(
268268
sequences: List[Tensor],
269-
hiddens: Tuple[Tensor, Tensor],
269+
hiddens: tuple[Tensor, Tensor],
270270
wih: Tensor,
271271
whh: Tensor,
272272
bih: Tensor,
273273
bhh: Tensor,
274-
) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
274+
) -> tuple[List[Tensor], tuple[List[Tensor], List[Tensor]]]:
275275
hx, cx = hiddens
276276
hxs = hx.unbind(1)
277277
cxs = cx.unbind(1)
@@ -406,12 +406,12 @@ def lstm_inputs(
406406
def lstm_factory(cell, script):
407407
def dynamic_rnn(
408408
input: Tensor,
409-
hidden: Tuple[Tensor, Tensor],
409+
hidden: tuple[Tensor, Tensor],
410410
wih: Tensor,
411411
whh: Tensor,
412412
bih: Tensor,
413413
bhh: Tensor,
414-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
414+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
415415
hx, cx = hidden
416416
outputs = []
417417
inputs = input.unbind(0)
@@ -432,12 +432,12 @@ def dynamic_rnn(
432432
def lstm_factory_premul(premul_cell, script):
433433
def dynamic_rnn(
434434
input: Tensor,
435-
hidden: Tuple[Tensor, Tensor],
435+
hidden: tuple[Tensor, Tensor],
436436
wih: Tensor,
437437
whh: Tensor,
438438
bih: Tensor,
439439
bhh: Tensor,
440-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
440+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
441441
hx, cx = hidden
442442
outputs = []
443443
inputs = torch.matmul(input, wih.t()).unbind(0)
@@ -458,12 +458,12 @@ def dynamic_rnn(
458458
def lstm_factory_premul_bias(premul_cell, script):
459459
def dynamic_rnn(
460460
input: Tensor,
461-
hidden: Tuple[Tensor, Tensor],
461+
hidden: tuple[Tensor, Tensor],
462462
wih: Tensor,
463463
whh: Tensor,
464464
bih: Tensor,
465465
bhh: Tensor,
466-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
466+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
467467
hx, cx = hidden
468468
outputs = []
469469
inpSize = input.size()
@@ -506,8 +506,8 @@ def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
506506

507507
def lstm_factory_multilayer(cell, script):
508508
def dynamic_rnn(
509-
input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
510-
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
509+
input: Tensor, hidden: tuple[Tensor, Tensor], params: List[Tensor]
510+
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
511511
params_stride = 4 # NB: this assumes that biases are there
512512
hx, cx = hidden
513513
hy, cy = hidden # for scoping...

benchmarks/functional_autograd_benchmark/torchaudio_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import math
55
from collections import OrderedDict
6-
from typing import Optional, Tuple
6+
from typing import Optional
77

88
import torch
99
import torch.nn.functional as F
@@ -512,7 +512,7 @@ def forward(
512512
attn_mask: Optional[torch.Tensor] = None,
513513
bias_k: Optional[torch.Tensor] = None,
514514
bias_v: Optional[torch.Tensor] = None,
515-
) -> Tuple[torch.Tensor, torch.Tensor]:
515+
) -> tuple[torch.Tensor, torch.Tensor]:
516516
r"""
517517
Args:
518518
query, key, value (Tensor): map a query and a set of key-value pairs to an output.
@@ -589,7 +589,7 @@ def forward(
589589
attn_mask: Optional[torch.Tensor] = None,
590590
bias_k: Optional[torch.Tensor] = None,
591591
bias_v: Optional[torch.Tensor] = None,
592-
) -> Tuple[torch.Tensor, torch.Tensor]:
592+
) -> tuple[torch.Tensor, torch.Tensor]:
593593
r"""Uses a scaled dot product with the projected key-value pair to update
594594
the projected query.
595595
Args:
@@ -686,7 +686,7 @@ def __init__(self, query_proj, key_proj, value_proj):
686686

687687
def forward(
688688
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
689-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
689+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
690690
r"""Projects the input sequences using in-proj layers.
691691
Args:
692692
query, key, value (Tensors): sequence to be projected

benchmarks/functional_autograd_benchmark/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
from collections import defaultdict
2-
from typing import Callable, Dict, List, Optional, Tuple, Union
2+
from typing import Callable, Dict, List, Optional, Union
33

44
import torch
55
from torch import nn, Tensor
66

77

88
# Type helpers
9-
InputsType = Union[Tensor, Tuple[Tensor, ...]]
9+
InputsType = Union[Tensor, tuple[Tensor, ...]]
1010
# A Getter takes in a device and returns a callable and the inputs to that callable
11-
GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
11+
GetterReturnType = tuple[Callable[..., Tensor], InputsType]
1212
GetterType = Callable[[torch.device], GetterReturnType]
1313
# V here refers to the v in either vjp, jvp, vhp or hvp
14-
VType = Union[None, Tensor, Tuple[Tensor, ...]]
14+
VType = Union[None, Tensor, tuple[Tensor, ...]]
1515
# Type used to store timing results. The first key is the model name, the second key
1616
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
17-
TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]
17+
TimingResultType = Dict[str, Dict[str, tuple[float, ...]]]
1818

1919

2020
# Utilities to make nn.Module "functional"
@@ -44,7 +44,7 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
4444
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
4545

4646

47-
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
47+
def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]:
4848
"""
4949
This function removes all the Parameters from the model and
5050
return them as a tuple as well as their original attribute names.
@@ -65,7 +65,7 @@ def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
6565
return params, names
6666

6767

68-
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
68+
def load_weights(mod: nn.Module, names: List[str], params: tuple[Tensor, ...]) -> None:
6969
"""
7070
Reload a set of weights so that `mod` can be used again to perform a forward pass.
7171
Note that the `params` are regular Tensors (that can have history) and so are left
@@ -77,7 +77,7 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -
7777

7878
# Utilities to read/write markdown table-like content.
7979
def to_markdown_table(
80-
res: TimingResultType, header: Optional[Tuple[str, ...]] = None
80+
res: TimingResultType, header: Optional[tuple[str, ...]] = None
8181
) -> str:
8282
if header is None:
8383
header = ("model", "task", "mean", "var")

benchmarks/gpt_fast/generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import itertools
33
import platform
44
import time
5-
from typing import Optional, Tuple
5+
from typing import Optional
66

77
import torchao
88
from common import Experiment, register_experiment
@@ -89,7 +89,7 @@ def prefill(
8989

9090
def decode_one_token(
9191
model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
92-
) -> Tuple[torch.Tensor, torch.Tensor]:
92+
) -> tuple[torch.Tensor, torch.Tensor]:
9393
# input_pos: [B, 1]
9494
assert input_pos.shape[-1] == 1
9595
logits = model(x, input_pos)

0 commit comments

Comments
 (0)