@@ -4425,10 +4425,22 @@ def test3dSparseWithTFXIO(self):
44254425 feature {
44264426 name: "x$sparse_indices_0"
44274427 type: INT
4428+ # TODO(b/184055743): Once TensorFlow is released with
4429+ # cl/342914534, uncomment.
4430+ # int_domain {
4431+ # min: 0
4432+ # max: 4
4433+ # }
44284434 }
44294435 feature {
44304436 name: "x$sparse_indices_1"
44314437 type: INT
4438+ # TODO(b/184055743): Once TensorFlow is released with
4439+ # cl/342914534, uncomment.
4440+ # int_domain {
4441+ # min: 0
4442+ # max: 4
4443+ # }
44324444 }
44334445 feature {
44344446 name: "x$sparse_values"
@@ -4450,14 +4462,28 @@ def test3dSparseWithTFXIO(self):
44504462 if not tft_unit .is_external_environment ():
44514463 expected_metadata .generate_legacy_feature_spec = False
44524464
4453- self .assertProtoEquals (transformed_metadata .schema , expected_metadata )
4465+ # TODO(b/184055743): Once TensorFlow is released with cl/342914534,
4466+ # remove this.
4467+ # TODO(b/184057384): Even with cl/342914534,
4468+ # transformed_metadata.deferred_schema still does not contain the shape
4469+ # information about the SparseTensor.
def int_domain_cleared(schema):
  """Returns a copy of `schema` whose features all have `int_domain` unset.

  The input proto is left untouched; the clearing is applied to a fresh
  `schema_pb2.Schema` populated via `CopyFrom`.
  """
  stripped = schema_pb2.Schema()
  stripped.CopyFrom(schema)
  for feature in stripped.feature:
    feature.ClearField('int_domain')
  return stripped
4476+
4477+ self .assertProtoEquals (int_domain_cleared (transformed_metadata .schema ),
4478+ expected_metadata )
44544479
44554480 beam_test_util .assert_that (
44564481 transformed_data , self ._MakeTransformOutputAssertFn (expected_data ))
44574482
# Assertion callback: requires the received list to hold exactly one item and
# compares that item's `.schema` against `expected_metadata`.
# In this diff the '-' line is the previous direct comparison; the '+' lines
# replace it with a comparison of the schema after `int_domain_cleared` has
# stripped `int_domain` from every feature.
44584483 def _assert_schemas_equal_fn (schema_dict_list ):
44594484 self .assertEqual (1 , len (schema_dict_list ))
4460- self .assertProtoEquals (schema_dict_list [0 ].schema , expected_metadata )
4485+ self .assertProtoEquals (
4486+ int_domain_cleared (schema_dict_list [0 ].schema ), expected_metadata )
44614487
44624488 beam_test_util .assert_that (
44634489 transformed_metadata .deferred_metadata ,
@@ -4468,15 +4494,26 @@ def _assert_schemas_equal_fn(schema_dict_list):
44684494 dataset = tf .data .TFRecordDataset (materialize_path )
44694495 tft_out = tft .TFTransformOutput (transform_output_path )
44704496 transformed_feature_spec = tft_out .transformed_feature_spec ()
4471- self .assertEqual (
4472- transformed_feature_spec , {
4473- 'x' :
4474- tf .io .SparseFeature (
4475- ['x$sparse_indices_0' , 'x$sparse_indices_1' ],
4476- 'x$sparse_values' ,
4477- tf .float32 , [- 1 , - 1 ],
4478- already_sorted = True )
4479- })
4497+ self .assertLen (transformed_feature_spec , 1 )
4498+ self .assertIn ('x' , transformed_feature_spec )
4499+ self .assertIn (
4500+ transformed_feature_spec ['x' ],
4501+ (tf .io .SparseFeature (['x$sparse_indices_0' , 'x$sparse_indices_1' ],
4502+ 'x$sparse_values' ,
4503+ tf .float32 , [5 , 5 ],
4504+ already_sorted = True ),
4505+ # TODO(b/184055743): Once TensorFlow is released with cl/342914534,
4506+ # remove this.
4507+ # TODO(b/184057384): Even with cl/342914534, still may not contain
4508+ # the shape information about the SparseTensor.
4509+ tf .io .SparseFeature (['x$sparse_indices_0' , 'x$sparse_indices_1' ],
4510+ 'x$sparse_values' ,
4511+ tf .float32 , [- 1 , - 1 ],
4512+ already_sorted = True )))
4513+
4514+ transformed_feature_spec ['x' ] = tf .io .SparseFeature (
4515+ ['x$sparse_indices_0' , 'x$sparse_indices_1' ],
4516+ 'x$sparse_values' , tf .float32 , [5 , 5 ], already_sorted = True )
44804517
44814518 def parse_fn (serialized_input ):
44824519 result = tf .io .parse_single_example (serialized_input ,
@@ -4493,7 +4530,7 @@ def parse_fn(serialized_input):
44934530 expected_sparse_components = [
44944531 np .array ([[arr ] for arr in zip (x_idx0 , x_idx1 )]),
44954532 np .array ([[x ] for x in x_data ]),
4496- np .array ([[- 1 , - 1 ]] * len (x_data ))
4533+ np .array ([[5 , 5 ]] * len (x_data ))
44974534 ]
44984535 self .assertLen (transformed_sparse_components ,
44994536 len (expected_sparse_components ))
0 commit comments