@@ -496,10 +496,10 @@ class Conv2D(object):
496496
497497 def __init__ (self , strides , padding , data_format = 'NHWC' , dilations = None , out_channel = None , k_size = None ):
498498 self .data_format , self .padding = preprocess_2d_format (data_format , padding )
499- if self .data_format is 'NHWC' :
499+ if self .data_format == 'NHWC' :
500500 self ._stride = (strides [1 ], strides [2 ])
501501 self ._dilation = (dilations [1 ], dilations [2 ])
502- elif self .data_format is 'NCHW' :
502+ elif self .data_format == 'NCHW' :
503503 self ._stride = (strides [2 ], strides [3 ])
504504 self ._dilation = (dilations [2 ], dilations [3 ])
505505
@@ -537,10 +537,10 @@ def conv2d(input, filters, strides, padding, data_format='NCHW', dilations=None)
537537 A Tensor. Has the same type as input.
538538 """
539539 data_format , padding = preprocess_2d_format (data_format , padding )
540- if data_format is 'NHWC' :
540+ if data_format == 'NHWC' :
541541 _stride = (strides [1 ], strides [2 ])
542542 _dilation = (dilations [1 ], dilations [2 ])
543- elif data_format is 'NCHW' :
543+ elif data_format == 'NCHW' :
544544 _stride = (strides [2 ], strides [3 ])
545545 _dilation = (dilations [2 ], dilations [3 ])
546546 outputs = F .conv2d (
@@ -553,10 +553,10 @@ class Conv3D(object):
553553
554554 def __init__ (self , strides , padding , data_format = 'NDHWC' , dilations = None , out_channel = None , k_size = None ):
555555 self .data_format , self .padding = preprocess_3d_format (data_format , padding )
556- if self .data_format is 'NDHWC' :
556+ if self .data_format == 'NDHWC' :
557557 self ._strides = (strides [1 ], strides [2 ], strides [3 ])
558558 self ._dilations = (dilations [1 ], dilations [2 ], dilations [3 ])
559- elif self .data_format is 'NCDHW' :
559+ elif self .data_format == 'NCDHW' :
560560 self ._strides = (strides [2 ], strides [3 ], strides [4 ])
561561 self ._dilations = (dilations [2 ], dilations [3 ], dilations [4 ])
562562
@@ -603,10 +603,10 @@ def conv3d(input, filters, strides, padding, data_format='NDHWC', dilations=None
603603 A Tensor. Has the same type as input.
604604 """
605605 data_format , padding = preprocess_3d_format (data_format , padding )
606- if data_format is 'NDHWC' :
606+ if data_format == 'NDHWC' :
607607 _strides = (strides [1 ], strides [2 ], strides [3 ])
608608 _dilations = (dilations [1 ], dilations [2 ], dilations [3 ])
609- elif data_format is 'NCDHW' :
609+ elif data_format == 'NCDHW' :
610610 _strides = (strides [2 ], strides [3 ], strides [4 ])
611611 _dilations = (dilations [2 ], dilations [3 ], dilations [4 ])
612612 outputs = F .conv3d (
@@ -1195,10 +1195,10 @@ def __init__(self, strides, padding, data_format, dilations, out_channel, k_size
11951195 self .k_size = k_size
11961196 self .groups = groups
11971197 self .data_format , self .padding = preprocess_2d_format (data_format , padding )
1198- if self .data_format is 'NHWC' :
1198+ if self .data_format == 'NHWC' :
11991199 self .strides = (strides [1 ], strides [2 ])
12001200 self .dilations = (dilations [1 ], dilations [2 ])
1201- elif self .data_format is 'NCHW' :
1201+ elif self .data_format == 'NCHW' :
12021202 self .strides = (strides [2 ], strides [3 ])
12031203 self .dilations = (dilations [2 ], dilations [3 ])
12041204
@@ -1241,10 +1241,10 @@ def __init__(self, strides, padding, data_format, dilations, out_channel, k_size
12411241 self .in_channel = int (in_channel )
12421242 self .depth_multiplier = depth_multiplier
12431243 self .data_format , self .padding = preprocess_2d_format (data_format , padding )
1244- if self .data_format is 'NHWC' :
1244+ if self .data_format == 'NHWC' :
12451245 self .strides = (strides [1 ], strides [2 ])
12461246 self .dilations = (dilations [1 ], dilations [2 ])
1247- elif self .data_format is 'NCHW' :
1247+ elif self .data_format == 'NCHW' :
12481248 self .strides = (strides [2 ], strides [3 ])
12491249 self .dilations = (dilations [2 ], dilations [3 ])
12501250
0 commit comments