Skip to content

Commit 5cdb6ed

Browse files
committed
[tmva][sofie] Fix testSofieModels for new torch version
New torch versions comes with a new export mode to ONNX. One needs to specify forst to have a single onnx file with the weights (external_data=False) and then in some case (new export mode, which is with dynamo=True) does not fork for batchNorm and recurrent network. Disable it for these cases.
1 parent 3f5a22f commit 5cdb6ed

File tree

7 files changed

+152
-127
lines changed

7 files changed

+152
-127
lines changed

tmva/sofie/test/Conv1dModelGenerator.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
class Net(nn.Module):
16-
16+
1717
def __init__(self, nc = 1, ng = 1, nl = 4, use_bn = False, use_maxpool = False, use_avgpool = False):
1818
super(Net, self).__init__()
1919

@@ -23,7 +23,7 @@ def __init__(self, nc = 1, ng = 1, nl = 4, use_bn = False, use_maxpool = False,
2323
self.use_bn = use_bn
2424
self.use_maxpool = use_maxpool
2525
self.use_avgpool = use_avgpool
26-
26+
2727
self.conv0 = nn.Conv1d(in_channels=self.nc, out_channels=4, kernel_size=2, groups=1, stride=1, padding=1)
2828
if (self.use_bn): self.bn1 = nn.BatchNorm2d(4)
2929
if (self.use_maxpool): self.pool1 = nn.MaxPool2d(2)
@@ -55,7 +55,7 @@ def main():
5555
parser = argparse.ArgumentParser(description='PyTorch model generator')
5656
parser.add_argument('params', type=int, nargs='+',
5757
help='parameters for the Conv network : batchSize , inputChannels, inputImageSize, nGroups, nLayers ')
58-
58+
5959
parser.add_argument('--bn', action='store_true', default=False,
6060
help='For using batch norm layer')
6161
parser.add_argument('--maxpool', action='store_true', default=False,
@@ -69,13 +69,13 @@ def main():
6969

7070

7171
args = parser.parse_args()
72-
72+
7373
#args.params = (4,2,4,1,4)
7474

7575
np = len(args.params)
7676
if (np < 5) : exit()
7777
bsize = args.params[0]
78-
nc = args.params[1]
78+
nc = args.params[1]
7979
d = args.params[2]
8080
ngroups = args.params[3]
8181
nlayers = args.params[4]
@@ -92,24 +92,24 @@ def main():
9292
input = torch.zeros([])
9393
for ib in range(0,bsize):
9494
xa = torch.ones([1, 1, d]) * (ib+1)
95-
if (nc > 1) :
95+
if (nc > 1) :
9696
xb = xa.neg()
9797
xc = torch.cat((xa,xb),1) # concatenate tensors
9898
if (nc > 2) :
9999
xd = torch.zeros([1,nc-2,d])
100100
xc = torch.cat((xa,xb,xd),1)
101101
else:
102102
xc = xa
103-
104-
#concatenate tensors
105-
if (ib == 0) :
103+
104+
#concatenate tensors
105+
if (ib == 0) :
106106
xinput = xc
107107
else :
108-
xinput = torch.cat((xinput,xc),0)
108+
xinput = torch.cat((xinput,xc),0)
109109

110110
print("input data",xinput.shape)
111111
print(xinput)
112-
112+
113113
name = "Conv1dModel"
114114
if (use_bn): name += "_BN"
115115
if (use_maxpool): name += "_MAXP"
@@ -120,24 +120,26 @@ def main():
120120
loadModel=False
121121
savePtModel = False
122122

123-
123+
124124
model = Net(nc,ngroups,nlayers, use_bn, use_maxpool, use_avgpool)
125125
print(model)
126126

127127
model(xinput)
128-
128+
129129
model.forward(xinput)
130130

131131
if savePtModel :
132132
torch.save({'model_state_dict':model.state_dict()}, name + ".pt")
133133

134134
if saveOnnx:
135-
torch.onnx.export(
136-
model,
137-
xinput,
138-
name + ".onnx",
139-
export_params=True
140-
)
135+
torch.onnx.export(
136+
model,
137+
xinput,
138+
name + ".onnx",
139+
export_params=True,
140+
dynamo=True,
141+
external_data=False
142+
)
141143

142144
if loadModel :
143145
print('Loading model from file....')
@@ -158,10 +160,10 @@ def main():
158160

159161
f = open(name + ".out", "w")
160162
for i in range(0,outSize):
161-
f.write(str(float(yvec[i]))+" ")
162-
163-
164-
163+
f.write(str(float(yvec[i].detach()))+" ")
164+
165+
166+
165167

166168
if __name__ == '__main__':
167169
main()

tmva/sofie/test/Conv2dModelGenerator.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
result = []
1313

1414
class Net(nn.Module):
15-
15+
1616
def __init__(self, nc = 1, ng = 1, nl = 4, use_bn = False, use_maxpool = False, use_avgpool = False):
1717
super(Net, self).__init__()
1818

@@ -22,7 +22,7 @@ def __init__(self, nc = 1, ng = 1, nl = 4, use_bn = False, use_maxpool = False,
2222
self.use_bn = use_bn
2323
self.use_maxpool = use_maxpool
2424
self.use_avgpool = use_avgpool
25-
25+
2626
self.conv0 = nn.Conv2d(in_channels=self.nc, out_channels=4, kernel_size=2, groups=1, stride=1, padding=1)
2727
if (self.use_bn): self.bn1 = nn.BatchNorm2d(4)
2828
if (self.use_maxpool): self.pool1 = nn.MaxPool2d(2)
@@ -32,7 +32,7 @@ def __init__(self, nc = 1, ng = 1, nl = 4, use_bn = False, use_maxpool = False,
3232
self.conv1 = nn.Conv2d(in_channels=4, out_channels=8, groups = self.ng, kernel_size=3, stride=1, padding=1)
3333
#output is same 4x4
3434
self.conv2 = nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, stride=1, padding=1)
35-
#use stride last layer
35+
#use stride last layer
3636
self.conv3 = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=2, stride=2, padding=0)
3737

3838

@@ -60,7 +60,7 @@ def main():
6060
parser = argparse.ArgumentParser(description='PyTorch model generator')
6161
parser.add_argument('params', type=int, nargs='+',
6262
help='parameters for the Conv network : batchSize , inputChannels, inputImageSize, nGroups, nLayers ')
63-
63+
6464
parser.add_argument('--bn', action='store_true', default=False,
6565
help='For using batch norm layer')
6666
parser.add_argument('--maxpool', action='store_true', default=False,
@@ -72,13 +72,13 @@ def main():
7272

7373

7474
args = parser.parse_args()
75-
75+
7676
#args.params = (4,2,4,1,4)
7777

7878
np = len(args.params)
7979
if (np < 5) : exit()
8080
bsize = args.params[0]
81-
nc = args.params[1]
81+
nc = args.params[1]
8282
d = args.params[2]
8383
ngroups = args.params[3]
8484
nlayers = args.params[4]
@@ -94,24 +94,24 @@ def main():
9494
input = torch.zeros([])
9595
for ib in range(0,bsize):
9696
xa = torch.ones([1, 1, d, d]) * (ib+1)
97-
if (nc > 1) :
97+
if (nc > 1) :
9898
xb = xa.neg()
9999
xc = torch.cat((xa,xb),1) # concatenate tensors
100100
if (nc > 2) :
101101
xd = torch.zeros([1,nc-2,d,d])
102102
xc = torch.cat((xa,xb,xd),1)
103103
else:
104104
xc = xa
105-
106-
#concatenate tensors
107-
if (ib == 0) :
105+
106+
#concatenate tensors
107+
if (ib == 0) :
108108
xinput = xc
109109
else :
110-
xinput = torch.cat((xinput,xc),0)
110+
xinput = torch.cat((xinput,xc),0)
111111

112112
print("input data",xinput.shape)
113113
print(xinput)
114-
114+
115115
name = "Conv2dModel"
116116
if (use_bn): name += "_BN"
117117
if (use_maxpool): name += "_MAXP"
@@ -122,24 +122,33 @@ def main():
122122
loadModel=False
123123
savePtModel = False
124124

125-
125+
126126
model = Net(nc,ngroups,nlayers, use_bn, use_maxpool, use_avgpool)
127127
print(model)
128128

129129
model(xinput)
130-
130+
131131
model.forward(xinput)
132132

133133
if savePtModel :
134134
torch.save({'model_state_dict':model.state_dict()}, name + ".pt")
135135

136+
136137
if saveOnnx:
137-
torch.onnx.export(
138-
model,
139-
xinput,
140-
name + ".onnx",
141-
export_params=True
142-
)
138+
139+
#new ONNX exporter does not work for batchmorm
140+
dynamo_export=True
141+
if (use_bn): dynamo_export=False
142+
143+
torch.onnx.export(
144+
model,
145+
xinput,
146+
name + ".onnx",
147+
export_params=True,
148+
dynamo=dynamo_export,
149+
external_data=False
150+
)
151+
143152

144153
if loadModel :
145154
print('Loading model from file....')
@@ -160,10 +169,10 @@ def main():
160169

161170
f = open(name + ".out", "w")
162171
for i in range(0,outSize):
163-
f.write(str(float(yvec[i]))+" ")
164-
165-
166-
172+
f.write(str(float(yvec[i].detach()))+" ")
173+
174+
175+
167176

168177
if __name__ == '__main__':
169178
main()

0 commit comments

Comments
 (0)