Skip to content

Commit c6031a8

Browse files
committed
Make some api's compatible with python3.7+
1 parent b79f641 commit c6031a8

File tree

20 files changed

+63
-54
lines changed

20 files changed

+63
-54
lines changed

official/nlp/modeling/layers/transformer_encoder_block.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Keras-based TransformerEncoder block layer."""
16-
from typing import Any, Optional, Sequence
16+
from typing import Any, Optional, Sequence, Union
1717
from absl import logging
1818
import tensorflow as tf, tf_keras
1919

@@ -28,7 +28,7 @@ class RMSNorm(tf_keras.layers.Layer):
2828

2929
def __init__(
3030
self,
31-
axis: int | Sequence[int] = -1,
31+
axis: Union[int , Sequence[int]] = -1,
3232
epsilon: float = 1e-6,
3333
**kwargs,
3434
):
@@ -43,7 +43,7 @@ def __init__(
4343
self.axis = [axis] if isinstance(axis, int) else axis
4444
self.epsilon = epsilon
4545

46-
def build(self, input_shape: tf.TensorShape | Sequence[int | None]):
46+
def build(self, input_shape: Union[tf.TensorShape, Sequence[Union[int, None]]]):
4747
input_shape = tf.TensorShape(input_shape)
4848
scale_shape = [1] * input_shape.rank
4949
for dim in self.axis:

official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from official.recommendation.uplift import types
2222

23+
from typing import Union
24+
2325

2426
class BaseTwoTowerUpliftNetwork(tf_keras.layers.Layer, metaclass=abc.ABCMeta):
2527
"""Abstract class for uplift layers that compute control and treatment logits.
@@ -33,7 +35,7 @@ class BaseTwoTowerUpliftNetwork(tf_keras.layers.Layer, metaclass=abc.ABCMeta):
3335
def call(
3436
self,
3537
inputs: types.DictOfTensors,
36-
training: bool | None = None,
37-
mask: tf.Tensor | None = None,
38+
training: Union[bool, None] = None,
39+
mask: Union[tf.Tensor, None] = None,
3840
) -> types.TwoTowerTrainingOutputs:
3941
raise NotImplementedError()

official/recommendation/uplift/metrics/label_mean.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from official.recommendation.uplift import types
2020
from official.recommendation.uplift.metrics import treatment_sliced_metric
2121

22+
from typing import Union
2223

2324
@tf_keras.utils.register_keras_serializable(package="Uplift")
2425
class LabelMean(tf_keras.metrics.Metric):
@@ -71,7 +72,7 @@ def update_state(
7172
self,
7273
y_true: tf.Tensor,
7374
y_pred: types.TwoTowerTrainingOutputs,
74-
sample_weight: tf.Tensor | None = None,
75+
sample_weight: Union[tf.Tensor, None] = None,
7576
):
7677
"""Updates the overall, control and treatment label means.
7778

official/recommendation/uplift/metrics/label_variance.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from official.recommendation.uplift.metrics import treatment_sliced_metric
2121
from official.recommendation.uplift.metrics import variance
2222

23+
from typing import Union
2324

2425
@tf_keras.utils.register_keras_serializable(package="Uplift")
2526
class LabelVariance(tf_keras.metrics.Metric):
@@ -72,7 +73,7 @@ def update_state(
7273
self,
7374
y_true: tf.Tensor,
7475
y_pred: types.TwoTowerTrainingOutputs,
75-
sample_weight: tf.Tensor | None = None,
76+
sample_weight: Union[tf.Tensor, None] = None,
7677
):
7778
"""Updates the overall, control and treatment label variances.
7879

official/recommendation/uplift/metrics/metric_configs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
from collections.abc import Mapping
1818
import dataclasses
19-
from typing import Any
19+
from typing import Any, Union
2020

2121
from official.core.config_definitions import base_config
2222

2323

24-
@dataclasses.dataclass(kw_only=True)
24+
@dataclasses.dataclass
2525
class SlicedMetricConfig(base_config.Config):
2626
"""Sliced metric configuration.
2727
@@ -33,9 +33,9 @@ class SlicedMetricConfig(base_config.Config):
3333
values to slice on.
3434
"""
3535

36-
slicing_feature: str | None = None
37-
slicing_spec: Mapping[str, int] | None = None
38-
slicing_feature_dtype: str | None = None
36+
slicing_feature: Union[str, None] = None
37+
slicing_spec: Union[Mapping[str, int], None] = None
38+
slicing_feature_dtype: Union[str, None ]= None
3939

4040
def __post_init__(
4141
self, default_params: dict[str, Any], restrictions: list[str]

official/recommendation/uplift/metrics/sliced_metric.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import tensorflow as tf, tf_keras
2020

21+
from typing import Union
22+
2123

2224
class SlicedMetric(tf_keras.metrics.Metric):
2325
"""A metric sliced by integer, boolean, or string features.
@@ -66,9 +68,9 @@ class SlicedMetric(tf_keras.metrics.Metric):
6668
def __init__(
6769
self,
6870
metric: tf_keras.metrics.Metric,
69-
slicing_spec: dict[str, str] | dict[str, int],
70-
slicing_feature_dtype: tf.DType | None = None,
71-
name: str | None = None,
71+
slicing_spec: Union[dict[str, str], dict[str, int]],
72+
slicing_feature_dtype: Union[tf.DType, None] = None,
73+
name: Union[str, None] = None,
7274
):
7375
"""Initializes the instance.
7476
@@ -123,7 +125,7 @@ def __init__(
123125
def update_state(
124126
self,
125127
*args: tf.Tensor,
126-
sample_weight: tf.Tensor | None = None,
128+
sample_weight: Union[tf.Tensor, None] = None,
127129
slicing_feature: tf.Tensor,
128130
**kwargs,
129131
):

official/recommendation/uplift/metrics/treatment_fraction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from official.recommendation.uplift import types
2020

21+
from typing import Union
2122

2223
@tf_keras.utils.register_keras_serializable(package="Uplift")
2324
class TreatmentFraction(tf_keras.metrics.Metric):
@@ -57,7 +58,7 @@ def update_state(
5758
self,
5859
y_true: tf.Tensor,
5960
y_pred: types.TwoTowerTrainingOutputs,
60-
sample_weight: tf.Tensor | None = None,
61+
sample_weight: Union[tf.Tensor, None] = None,
6162
) -> None:
6263
"""Updates the treatment fraction.
6364

official/recommendation/uplift/metrics/uplift_mean.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from official.recommendation.uplift import types
2020
from official.recommendation.uplift.metrics import treatment_sliced_metric
2121

22+
from typing import Union
2223

2324
@tf_keras.utils.register_keras_serializable(package="Uplift")
2425
class UpliftMean(tf_keras.metrics.Metric):
@@ -68,7 +69,7 @@ def update_state(
6869
self,
6970
y_true: tf.Tensor,
7071
y_pred: types.TwoTowerTrainingOutputs,
71-
sample_weight: tf.Tensor | None = None,
72+
sample_weight: Union[tf.Tensor, None] = None,
7273
) -> None:
7374
"""Updates the overall, control and treatment uplift means.
7475

official/recommendation/uplift/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""Defines types used by the keras uplift modeling library."""
1616

1717
import tensorflow as tf, tf_keras
18+
from typing import Union
1819

19-
TensorType = tf.Tensor | tf.SparseTensor | tf.RaggedTensor
20+
TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
2021

2122
ListOfTensors = list[TensorType]
2223
TupleOfTensors = tuple[TensorType, ...]
2324
DictOfTensors = dict[str, TensorType]
2425

25-
CollectionOfTensors = ListOfTensors | TupleOfTensors | DictOfTensors
26+
CollectionOfTensors = Union[ListOfTensors, TupleOfTensors, DictOfTensors]
2627

2728

2829
class TwoTowerNetworkOutputs(tf.experimental.ExtensionType):

official/vision/configs/retinanet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class Parser(hyperparams.Config):
5353
match_threshold: float = 0.5
5454
unmatched_threshold: float = 0.5
5555
aug_rand_hflip: bool = False
56-
aug_rand_jpeg: common.RandJpegQuality | None = None
56+
aug_rand_jpeg: Union[common.RandJpegQuality, None] = None
5757
aug_scale_min: float = 1.0
5858
aug_scale_max: float = 1.0
5959
skip_crowd_during_training: bool = True

0 commit comments

Comments
 (0)