@@ -1555,15 +1555,32 @@ def call(self, inputs, inputs_positions=None):
 class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
   """TransformerEncoderBlock layer with stochastic depth and layerscale."""
 
-  def __init__(self,
-               *args,
-               stochastic_depth_drop_rate=0.0,
-               layer_scale_init_value=0.0,
-               **kwargs):
-    """Initializes TransformerEncoderBlock."""
+  def __init__(
+      self,
+      *args,
+      stochastic_depth_drop_rate=0.0,
+      layer_scale_init_value=0.0,
+      max_attention_inference_parallelism=None,
+      **kwargs
+  ):
+    """Initializes TransformerEncoderBlock.
+
+    Args:
+      *args: positional arguments passed to super().__init__.
+      stochastic_depth_drop_rate: the drop rate for the stochastic depth layer.
+      layer_scale_init_value: the initial value for the layer scale.
+      max_attention_inference_parallelism: the number of examples to run in
+        parallel in the attention blocks during inference. Set this limit to
+        reduce the peak memory usage. If None, use vectorized operations to run
+        the whole batch in parallel.
+      **kwargs: keyword arguments passed to super().__init__.
+    """
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
     self._layer_scale_init_value = layer_scale_init_value
+    self._max_attention_inference_parallelism = (
+        max_attention_inference_parallelism
+    )
 
   def build(self, input_shape):
     if self._stochastic_depth_drop_rate:
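Note: the layer_scale_init_value argument feeds the layer-scale branches created in build(). As a rough illustration of the idea only (not the exact layer this file builds), layer scale multiplies a residual branch by a small learnable per-channel gain initialized to that value:

# Illustrative sketch only: a per-channel learnable gain on a residual branch,
# initialized to `layer_scale_init_value`. The real layer constructed in
# build() may differ in naming and details.
import tensorflow as tf

class LayerScaleSketch(tf.keras.layers.Layer):

  def __init__(self, init_value: float, **kwargs):
    super().__init__(**kwargs)
    self._init_value = init_value

  def build(self, input_shape):
    # One gain per feature channel, starting near zero so the residual branch
    # contributes little early in training.
    self.gamma = self.add_weight(
        name='gamma',
        shape=(input_shape[-1],),
        initializer=tf.keras.initializers.Constant(self._init_value),
        trainable=True,
    )
    super().build(input_shape)

  def call(self, inputs):
    return inputs * self.gamma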
@@ -1582,10 +1599,25 @@ def build(self, input_shape):
       self._layer_scale_mlp = lambda x, *args, **kwargs: tf.identity(x)
     super().build(input_shape)
 
+    if self._max_attention_inference_parallelism is not None:
+      attention_layer_config = self._attention_layer.get_config()
+      self._attention_layer = nn_layers.MultiHeadAttention.from_config({
+          **attention_layer_config,
+          'max_inference_parallelism': (
+              self._max_attention_inference_parallelism
+          ),
+      })
+
   def get_config(self):
-    config = {'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate}
-    base_config = super().get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+    config = super().get_config()
+    config.update({
+        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
+        'layer_scale_init_value': self._layer_scale_init_value,
+        'max_attention_inference_parallelism': (
+            self._max_attention_inference_parallelism
+        ),
+    })
+    return config
 
   def call(self, inputs, output_range=None, training=None):
     """Transformer self-attention encoder block call."""
@@ -1675,29 +1707,39 @@ def call(self, inputs, output_range=None, training=None):
 
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class TransformerScaffold(nlp_modeling.layers.TransformerScaffold):
-  """TransformerScaffold layer for vision applications.
-
-  This layer is a subclass of NLP TransformerScaffold:
+  """TransformerScaffold layer for vision applications."""
 
-  Attributes:
-    stochastic_depth_drop_rate: Drop rate for the residual connections.
-    return_attention_scores: Optionally return the attention output.
-    ffn_has_residual_connection: Whether the feedforward network has internal
-      residual connection and layer norm. If False, the residual connection and
-      the layer norm op are called inside TransformerScaffold.
-  """
+  def __init__(
+      self,
+      *args,
+      stochastic_depth_drop_rate: float = 0.0,
+      return_attention_scores: bool = False,
+      ffn_has_residual_connection: bool = False,
+      max_attention_inference_parallelism: Optional[int] = None,
+      **kwargs
+  ):
+    """Initializes TransformerScaffold.
 
-  def __init__(self,
-               *args,
-               stochastic_depth_drop_rate: float = 0.0,
-               return_attention_scores: bool = False,
-               ffn_has_residual_connection: bool = False,
-               **kwargs):
-    """Initializes TransformerEncoderBlock."""
+    Args:
+      *args: positional arguments passed to super().__init__.
+      stochastic_depth_drop_rate: the drop rate for the stochastic depth layer.
+      return_attention_scores: whether to return the attention output.
+      ffn_has_residual_connection: whether the feedforward network has internal
+        residual connection and layer norm. If False, the residual connection
+        and the layer norm op are called inside TransformerScaffold.
+      max_attention_inference_parallelism: the number of examples to run in
+        parallel in the attention blocks during inference. Set this limit to
+        reduce the peak memory usage. If None, use vectorized operations to run
+        the whole batch in parallel.
+      **kwargs: keyword arguments passed to super().__init__.
+    """
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
     self._return_attention_scores = return_attention_scores
     self._ffn_has_residual_connection = ffn_has_residual_connection
+    self._max_attention_inference_parallelism = (
+        max_attention_inference_parallelism
+    )
 
   def build(self, input_shape: Union[tf.TensorShape, List[int]]):
     if self._stochastic_depth_drop_rate:
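Note: for intuition on max_attention_inference_parallelism: instead of vectorizing attention over the whole batch, inference can be mapped over examples with a bounded number of parallel iterations, which caps how many attention score tensors are alive at once. A rough sketch of the idea follows; it is not the actual nn_layers.MultiHeadAttention implementation, and `attention_fn` is a hypothetical callable.

# Illustrative sketch only; the real `max_inference_parallelism` handling
# lives inside nn_layers.MultiHeadAttention and may work differently.
import tensorflow as tf

def attention_with_parallelism_cap(attention_fn, query, value,
                                   max_inference_parallelism=None):
  """Runs `attention_fn` per example with a cap on parallel iterations."""
  if max_inference_parallelism is None:
    # Fully vectorized: whole batch at once, highest peak memory.
    return attention_fn(query, value)
  return tf.map_fn(
      lambda qv: attention_fn(qv[0][tf.newaxis, ...],
                              qv[1][tf.newaxis, ...])[0],
      elems=(query, value),
      fn_output_signature=query.dtype,
      # Limits how many examples are processed concurrently.
      parallel_iterations=max_inference_parallelism,
  )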
@@ -1708,15 +1750,26 @@ def build(self, input_shape: Union[tf.TensorShape, List[int]]):
 
     super().build(input_shape)
 
+    if self._max_attention_inference_parallelism is not None:
+      attention_layer_config = self._attention_layer.get_config()
+      self._attention_layer = self._attention_cls.from_config({
+          **attention_layer_config,
+          'max_inference_parallelism': (
+              self._max_attention_inference_parallelism
+          ),
+      })
+
   def get_config(self):
-    config = {
+    config = super().get_config()
+    config.update({
         'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
         'return_attention_scores': self._return_attention_scores,
-        'ffn_has_residual_connection': self._ffn_has_residual_connection
-    }
-    base_config = super().get_config()
-    base_config.update(config)
-    return base_config
+        'ffn_has_residual_connection': self._ffn_has_residual_connection,
+        'max_attention_inference_parallelism': (
+            self._max_attention_inference_parallelism
+        ),
+    })
+    return config
 
   def call(
       self,