@@ -59,19 +59,31 @@ def __init__(
5959
6060 if not isinstance (size , (list , tuple )) and len (size ) == 2 :
6161 raise AssertionError ()
62-
6362 if len (self .inputs .get_shape ()) == 3 :
6463 if is_scale :
65- size_h = size [0 ] * tf .shape (self .inputs )[0 ]
66- size_w = size [1 ] * tf .shape (self .inputs )[1 ]
64+ input_shape = self .inputs .shape .as_list ()
65+ if input_shape [0 ] is not None :
66+ size_h = size [0 ] * input_shape [0 ]
67+ else :
68+ size_h = size [0 ] * tf .shape (self .inputs )[0 ]
69+ if input_shape [1 ] is not None :
70+ size_w = size [1 ] * input_shape [1 ]
71+ else :
72+ size_w = size [1 ] * tf .shape (self .inputs )[1 ]
6773 size = [size_h , size_w ]
6874
6975 elif len (self .inputs .get_shape ()) == 4 :
7076 if is_scale :
71- size_h = size [0 ] * tf .shape (self .inputs )[1 ]
72- size_w = size [1 ] * tf .shape (self .inputs )[2 ]
77+ input_shape = self .inputs .shape .as_list ()
78+ if input_shape [1 ] is not None :
79+ size_h = size [0 ] * input_shape [1 ]
80+ else :
81+ size_h = size [0 ] * tf .shape (self .inputs )[1 ]
82+ if input_shape [2 ] is not None :
83+ size_w = size [1 ] * input_shape [2 ]
84+ else :
85+ size_w = size [1 ] * tf .shape (self .inputs )[2 ]
7386 size = [size_h , size_w ]
74-
7587 else :
7688 raise Exception ("Donot support shape %s" % tf .shape (self .inputs ))
7789
@@ -135,14 +147,28 @@ def __init__(
135147
136148 if len (self .inputs .get_shape ()) == 3 :
137149 if is_scale :
138- size_h = size [0 ] * tf .shape (self .inputs )[0 ]
139- size_w = size [1 ] * tf .shape (self .inputs )[1 ]
150+ input_shape = self .inputs .shape .as_list ()
151+ if input_shape [0 ] is not None :
152+ size_h = size [0 ] * input_shape [0 ]
153+ else :
154+ size_h = size [0 ] * tf .shape (self .inputs )[0 ]
155+ if input_shape [1 ] is not None :
156+ size_w = size [1 ] * input_shape [1 ]
157+ else :
158+ size_w = size [1 ] * tf .shape (self .inputs )[1 ]
140159 size = [size_h , size_w ]
141160
142161 elif len (self .inputs .get_shape ()) == 4 :
143162 if is_scale :
144- size_h = size [0 ] * tf .shape (self .inputs )[1 ]
145- size_w = size [1 ] * tf .shape (self .inputs )[2 ]
163+ input_shape = self .inputs .shape .as_list ()
164+ if input_shape [1 ] is not None :
165+ size_h = size [0 ] * input_shape [1 ]
166+ else :
167+ size_h = size [0 ] * tf .shape (self .inputs )[1 ]
168+ if input_shape [2 ] is not None :
169+ size_w = size [1 ] * input_shape [2 ]
170+ else :
171+ size_w = size [1 ] * tf .shape (self .inputs )[2 ]
146172 size = [size_h , size_w ]
147173
148174 else :
0 commit comments