2020"""
2121
2222import collections
23- import functools
2423import itertools
25- from typing import Any , Callable , Dict , Iterable , List , Mapping , Optional , Sequence , Tuple , Union
24+ from typing import Any , Callable , Dict , Iterable , List , Mapping , Optional , Tuple , Union
2625from absl import logging
2726
2827import tensorflow as tf
@@ -188,7 +187,7 @@ def infer_feature_schema(
188187 tensor_ranges = session .run (tensor_ranges )
189188 tensor_annotations , global_annotations = _get_schema_annotations (
190189 graph , session )
191- sparse_output_annotations = _get_sparse_output_annotations_v1 (graph )
190+ sparse_output_annotations = _get_sparse_output_annotations_v1 (graph , session )
192191 modified_sparse_output_annotations = {}
193192 modified_tensor_ranges = {}
194193 feature_annotations = {}
@@ -790,24 +789,14 @@ def metadata_fn():
790789 return module .metadata_fn
791790
792791_ANNOTATED_SPARSE_SHAPE_TENSORS = 'annotated_sparse_shape_tensors'
793- _ANNOTATED_SPARSE_SHAPE_RANKS = 'annotated_sparse_shape_ranks'
794- _ANNOTATED_SPARSE_SHAPE_DIMENSIONS = 'annotated_sparse_shape_dimensions'
792+ _ANNOTATED_SPARSE_SHAPES = 'annotated_sparse_shape_dimensions'
795793_ANNOTATED_TRUELY_SPARSE_TENSORS = 'annotated_truely_sparse_tensors'
796794
797795
798- def annotate_sparse_output_shape (tensor : tf .SparseTensor , shape : Sequence [ int ] ):
796+ def annotate_sparse_output_shape (tensor : tf .SparseTensor , shape : tf . Tensor ):
799797 """Annotates a sparse output with a given shape."""
800- if tensor .shape .rank - 1 != len (shape ):
801- raise ValueError (
802- f'Output { tensor } was annotated with an incompatible shape: { shape } '
803- )
804798 tf .compat .v1 .add_to_collection (_ANNOTATED_SPARSE_SHAPE_TENSORS , tensor .values )
805- # We store rank and dimensions to separate collections since TF collections
806- # don't allow storing lists. This can be simplified if we switch away from
807- # collections.
808- tf .compat .v1 .add_to_collection (_ANNOTATED_SPARSE_SHAPE_RANKS , len (shape ))
809- for dim in shape :
810- tf .compat .v1 .add_to_collection (_ANNOTATED_SPARSE_SHAPE_DIMENSIONS , dim )
799+ tf .compat .v1 .add_to_collection (_ANNOTATED_SPARSE_SHAPES , shape )
811800
812801
813802def annotate_true_sparse_output (tensor : tf .SparseTensor ):
@@ -818,26 +807,16 @@ def annotate_true_sparse_output(tensor: tf.SparseTensor):
818807
819808
820809def _extract_true_sparse_annotations (
821- graph : tf .compat .v1 .Graph ,
822- ) -> List [tf .Tensor ]:
810+ graph : tf .compat .v1 .Graph ) -> List [tf .Tensor ]:
823811 """Extracts true sparse annotations from the graph."""
824812 return graph .get_collection (_ANNOTATED_TRUELY_SPARSE_TENSORS )
825813
826814
827815def _extract_sparse_output_annotations (
828- graph : tf .compat .v1 .Graph ,
829- ) -> List [Tuple [tf .Tensor , List [tf .Tensor ]]]:
816+ graph : tf .compat .v1 .Graph ) -> List [Tuple [tf .Tensor , List [tf .Tensor ]]]:
830817 """Extracts sparse output annotations from the graph."""
831818 tensors = graph .get_collection (_ANNOTATED_SPARSE_SHAPE_TENSORS )
832- ranks = graph .get_collection (_ANNOTATED_SPARSE_SHAPE_RANKS )
833- assert len (tensors ) == len (ranks ), f'{ tensors } != { ranks } '
834- shape_flattened = graph .get_collection (_ANNOTATED_SPARSE_SHAPE_DIMENSIONS )
835-
836- # Splitting dimensions per annotated tensor.
837- splits = functools .reduce (lambda lst , x : lst + [lst [- 1 ] + x ], ranks , [0 ])
838- # Composing the annotated shape per tensor.
839- shapes = tuple (shape_flattened [s :e ] for s , e in zip (splits [:- 1 ], splits [1 :]))
840-
819+ shapes = graph .get_collection (_ANNOTATED_SPARSE_SHAPES )
841820 assert len (tensors ) == len (shapes ), f'{ tensors } != { shapes } '
842821 return list (zip (tensors , shapes ))
843822
@@ -852,7 +831,7 @@ def _get_sparse_output_annotations(
852831 return list (
853832 itertools .chain (
854833 (
855- (a , ['' ])
834+ (a , tf . constant ( ['' ]) )
856835 for a in _extract_true_sparse_annotations (graph )
857836 if a .ref () not in annotated_refs
858837 ),
@@ -862,12 +841,15 @@ def _get_sparse_output_annotations(
862841
863842
864843def _get_sparse_output_annotations_v1 (
865- graph : tf .compat .v1 .Graph ,
844+ graph : tf .compat .v1 .Graph , session : Optional [ tf . compat . v1 . Session ]
866845) -> Dict [Any , List [Union [str , tf .Tensor ]]]:
867- return {
868- tf_utils .hashable_tensor_or_op (kv [0 ]): kv [1 ]
869- for kv in _get_sparse_output_annotations (graph )
870- }
846+ if not session :
847+ return {}
848+ else :
849+ return {
850+ tf_utils .hashable_tensor_or_op (kv [0 ]): session .run (kv [1 ])
851+ for kv in _get_sparse_output_annotations (graph )
852+ }
871853
872854
873855def _get_sparse_output_annotations_v2 (
0 commit comments