@@ -1919,6 +1919,157 @@ def analyzer_fn(inputs):
19191919 expected_outputs ,
19201920 desired_batch_size = 10 )
19211921
@tft_unit.named_parameters(
    dict(
        testcase_name='_dense_2d',
        input_data=[{
            'x': [4, 8],
            'key': 'a'
        }, {
            'x': [1, 5],
            'key': 'a'
        }, {
            'x': [5, 9],
            'key': 'a'
        }, {
            'x': [2, 6],
            'key': 'a'
        }, {
            'x': [-2, 0],
            'key': 'b'
        }, {
            'x': [0, 2],
            'key': 'b'
        }, {
            'x': [2, 4],
            'key': 'b'
        }],
        input_metadata=tft.DatasetMetadata.from_feature_spec({
            'x': tf.io.FixedLenFeature([2], tf.float32),
            'key': tf.io.FixedLenFeature([], tf.string),
        }),
        reduce_instance_dims=True,
        expected_outputs={
            # Builtin `object` rather than `np.object`: the alias was
            # deprecated in NumPy 1.20 and removed in NumPy 1.24.
            'key_vocab': np.array([b'a', b'b'], object),
            'min_x_value': np.array([1, -2], np.float32),
            'max_x_value': np.array([9, 4], np.float32),
        }),
    dict(
        testcase_name='_dense_2d_elementwise',
        input_data=[{
            'x': [4, 8],
            'key': 'a'
        }, {
            'x': [1, 5],
            'key': 'a'
        }, {
            'x': [5, 9],
            'key': 'a'
        }, {
            'x': [2, 6],
            'key': 'a'
        }, {
            'x': [-2, 0],
            'key': 'b'
        }, {
            'x': [0, 2],
            'key': 'b'
        }, {
            'x': [2, 4],
            'key': 'b'
        }],
        input_metadata=tft.DatasetMetadata.from_feature_spec({
            'x': tf.io.FixedLenFeature([2], tf.float32),
            'key': tf.io.FixedLenFeature([], tf.string),
        }),
        reduce_instance_dims=False,
        expected_outputs={
            'key_vocab': np.array([b'a', b'b'], object),
            'min_x_value': np.array([[1, 5], [-2, 0]], np.float32),
            'max_x_value': np.array([[5, 9], [2, 4]], np.float32),
        }),
    dict(
        testcase_name='_dense_3d',
        input_data=[
            {
                'x': [[1, 5], [1, 1]],
                'key': 'a'
            },
            {
                'x': [[5, 1], [5, 5]],
                'key': 'a'
            },
            {
                'x': [[2, 2], [2, 5]],
                'key': 'a'
            },
            {
                'x': [[3, -3], [3, 3]],
                'key': 'b'
            },
        ],
        input_metadata=tft.DatasetMetadata.from_feature_spec({
            'x': tf.io.FixedLenFeature([2, 2], tf.float32),
            'key': tf.io.FixedLenFeature([], tf.string),
        }),
        reduce_instance_dims=True,
        expected_outputs={
            'key_vocab': np.array([b'a', b'b'], object),
            'min_x_value': np.array([1, -3], np.float32),
            'max_x_value': np.array([5, 3], np.float32),
        }),
    dict(
        testcase_name='_dense_3d_elementwise',
        input_data=[
            {
                'x': [[1, 5], [1, 1]],
                'key': 'a'
            },
            {
                'x': [[5, 1], [5, 5]],
                'key': 'a'
            },
            {
                'x': [[2, 2], [2, 5]],
                'key': 'a'
            },
            {
                'x': [[3, -3], [3, 3]],
                'key': 'b'
            },
        ],
        input_metadata=tft.DatasetMetadata.from_feature_spec({
            'x': tf.io.FixedLenFeature([2, 2], tf.float32),
            'key': tf.io.FixedLenFeature([], tf.string),
        }),
        reduce_instance_dims=False,
        expected_outputs={
            'key_vocab':
                np.array([b'a', b'b'], object),
            'min_x_value':
                np.array([[[1, 1], [1, 1]], [[3, -3], [3, 3]]], np.float32),
            'max_x_value':
                np.array([[[5, 5], [5, 5]], [[3, -3], [3, 3]]], np.float32),
        }),
)
def testMinAndMaxPerKey(self, input_data, input_metadata,
                        reduce_instance_dims, expected_outputs):
  """Checks `analyzers._min_and_max_per_key` on dense, keyed inputs.

  Each parameterized case feeds instances holding a numeric feature 'x' and
  a string feature 'key', then compares the analyzer's three outputs (key
  vocabulary, per-key minimum, per-key maximum) with `expected_outputs`.

  Args:
    input_data: List of instance dicts with 'x' and 'key' entries.
    input_metadata: Dataset metadata describing the 'x' and 'key' features.
    reduce_instance_dims: Forwarded to the analyzer. In the cases above,
      True expects one min/max scalar per key, while False expects min/max
      values per key with the same shape as a single 'x' instance.
    expected_outputs: Dict of expected arrays keyed by 'key_vocab',
      'min_x_value' and 'max_x_value'.
  """
  # NOTE(review): presumably skips this test when the harness emits record
  # batches -- confirm against the helper's definition.
  self._SkipIfOutputRecordBatches()

  def analyzer_fn(inputs):
    key_vocab, min_x_value, max_x_value = analyzers._min_and_max_per_key(
        x=inputs['x'],
        key=inputs['key'],
        reduce_instance_dims=reduce_instance_dims)
    return {
        'key_vocab': key_vocab,
        'min_x_value': min_x_value,
        'max_x_value': max_x_value,
    }

  self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
                             expected_outputs)
2072+
19222073 @tft_unit .parameters ((True ,), (False ,))
19232074 def testPerKeyWithOOVKeys (self , use_vocabulary ):
19242075 def preprocessing_fn (inputs ):
0 commit comments