Skip to content

Commit a3832d3

Browse files
committed
[ADD] parser.py to systematically control each argument.
[ADD] smt related files. [DELETE] options.py & similarity.py [DELETE] jax module [DELETE] unnecessary perturbation types, only Linf left.
1 parent 7e2e423 commit a3832d3

File tree

18 files changed

+72764
-32588
lines changed

18 files changed

+72764
-32588
lines changed
Lines changed: 28 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,12 @@
88
)
99
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
1010
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
11-
import argparse
1211

1312
from matplotlib import pyplot as plt
1413
import onnxruntime as ort
1514
import numpy as np
16-
import jax.numpy as jnp
1715

1816
from src.mip import mip_verifier
19-
from src.similarity import Similarity
2017

2118
from utils.gurobi_modeling import GurobiModel
2219
from utils.scip_modeling import SCIPModel
@@ -28,25 +25,24 @@
2825
write_vnnlib,
2926
export_vnnlib,
3027
)
31-
from utils.options import VerificationSolver, RobustnessType
3228
from utils.save_results import Results
3329
from utils.log import Logger
34-
30+
import utils.parser as parser
3531

3632
# global configuration
3733
T: int = 0 # true label
3834

3935

4036
# TODO: implement verification algorithm in different ways
4137
def verify(
42-
solver: VerificationSolver,
38+
args,
4339
dataset: DataSet,
44-
input: jnp.ndarray,
40+
input: np.ndarray,
4541
) -> str:
4642
"""
4743
Verification algorithm:
4844
49-
Support: MIP (SCIP, Gurobi), CROWN
45+
Support: MIP (SCIP, Gurobi), SMT
5046
"""
5147
result: str = "UNSAT"
5248
vnnlib_filename: str = write_vnnlib(data=input,
@@ -57,31 +53,27 @@ def verify(
5753
onnx_file_path=dataset.onnx_filename, vnnlib_file_path=vnnlib_filename
5854
)
5955

60-
if solver is VerificationSolver.SCIP or solver is VerificationSolver.GUROBI:
61-
Logger.info(messages=f"Verification Algorithm is MIP solver ({solver})")
62-
63-
m: SCIPModel | GurobiModel = mip_verifier(solver_name=solver, networks=networks)
56+
if args.solver == "scip" or args.solver == "gurobi":
57+
Logger.info(messages=f"Verification Algorithm is MIP solver ({args.solver})")
58+
m: SCIPModel | GurobiModel = mip_verifier(solver_name=args.solver, networks=networks)
6459
result = "UNSAT" if m.get_solution_status() == "Infeasible" else "SAT"
65-
elif solver is VerificationSolver.BOX:
66-
Logger.info(messages="Verification Algorithm is BOX")
67-
Logger.info(messages="[ERROR] Box verifier is not able to use in this version.")
6860
if result == "UNSAT":
69-
Logger.info(messages="UNSAT, New template generated!")
61+
Logger.info(messages="UNSAT")
7062
return "UNSAT"
7163
else:
7264
Logger.info(messages="SAT")
7365
# * Testing checker for counter-example found by verifier
7466
counter_example: List[float] = get_ce(
75-
solver=solver, networks=networks, filename="./test_cex.txt", m=m
67+
solver=args.solver, networks=networks, filename="./test_cex.txt", m=m
7668
)
77-
counter_example: jnp.ndarray = jnp.array(counter_example)
69+
counter_example: np.ndarray = np.array(counter_example)
7870

7971
return result
8072

8173

82-
def _execute(solver: VerificationSolver) -> None:
74+
def _execute(args) -> None:
8375
"""
84-
MIP verification
76+
Complete verification
8577
8678
Build a mixed-integer programming model to verify neural networks.
8779
@@ -97,12 +89,11 @@ def _execute(solver: VerificationSolver) -> None:
9789
# step 0.
9890
Logger.info(messages="step 0: read the input files")
9991
dataset: DataSet = load_dataset(
100-
dataset_name="mnist",
101-
onnx_filename="./utils/benchmarks/onnx/fc_5x100.onnx",
102-
robustness_type=RobustnessType.LP_NORM,
103-
num_inputs=1, # len(distribution_filtered_test_labels[test_true_label])
104-
distance_type="linf",
105-
epsilon=0.1,
92+
dataset_name=args.dataset,
93+
onnx_filename=f"./utils/benchmarks/onnx/{args.network}",
94+
robustness_type=args.perturbation_type,
95+
num_test=args.num_test, # len(distribution_filtered_test_labels[test_true_label])
96+
epsilon=args.epsilon,
10697
)
10798

10899
# step 1.
@@ -136,7 +127,7 @@ def _execute(solver: VerificationSolver) -> None:
136127
# * ======================= * #
137128
# * testing dataset part
138129
# * ======================= * #
139-
filterd_test_images: List[jnp.ndarray] = []
130+
filterd_test_images: List[np.ndarray] = []
140131
filterd_test_labels: List[int] = []
141132
all_inference_result = dict()
142133
num_test_images: int = len(dataset.test_images)
@@ -146,11 +137,11 @@ def _execute(solver: VerificationSolver) -> None:
146137
[output_name],
147138
{
148139
input_name: dataset.test_images[data_id]
149-
.astype(jnp.float32)
140+
.astype(np.float32)
150141
.reshape(1, dataset.num_pixels, 1)
151142
},
152143
)[0]
153-
inference_label = jnp.argmax(inference_result)
144+
inference_label = np.argmax(inference_result)
154145
true_label = dataset.test_labels[data_id]
155146
if inference_label == true_label:
156147
filterd_test_images.append(dataset.test_images[data_id])
@@ -170,7 +161,7 @@ def _execute(solver: VerificationSolver) -> None:
170161
# * ************************ * #
171162
# * step 2. based on each label, separate into different groups.
172163
# * ************************ * #
173-
distribution_filtered_test_labels: Dict[int, List[jnp.ndarray]] = (
164+
distribution_filtered_test_labels: Dict[int, List[np.ndarray]] = (
174165
{}
175166
) # * key: label, value: images
176167
for label in range(dataset.num_labels):
@@ -182,73 +173,36 @@ def _execute(solver: VerificationSolver) -> None:
182173
messages=f"label: {label}, number of images: {len(distribution_filtered_test_labels[label])}"
183174
)
184175

185-
# # * ************************ * #
186-
# # * [Archived] step 3. similarity analysis
187-
# # * ************************ * #
188-
# test_true_label: int = 1 # YES: 0, Y3(1), label 1: 0 & 1109 可以結合
189-
# Logger.debugging(
190-
# messages=f"number of testing images: {len(distribution_filtered_test_labels[test_true_label])}"
191-
# )
192-
# distance_matrix: jnp.ndarray = Similarity.generate_distance_matrix(
193-
# all_data=distribution_filtered_test_labels[test_true_label],
194-
# distance_type=dataset.distance_type,
195-
# chunk_size=100,
196-
# )
197-
198-
# # * find the similar data
199-
# all_inputs: List[jnp.ndarray] = []
200-
# similarity_data: List[int] = Similarity.greedy(distance_matrix=distance_matrix)
201-
# for id, value in enumerate(similarity_data):
202-
# all_inputs.append(distribution_filtered_test_labels[test_true_label][value])
203-
204-
# for i in all_inputs:
205-
# verify(solver=solver, dataset=dataset, all_inputs=[i], true_label=test_true_label)
206-
207-
# end_time = time.time()
208-
# Logger.info(
209-
# messages=f"elapsed time for batch verification is : {end_time - start_time}"
210-
# )
211-
# Logger.info(messages=f"number of iterations: {COUNT}")
212-
213-
214176
# step 4. & step 5.
215-
all_inputs: List[jnp.ndarray] = distribution_filtered_test_labels[T]
177+
all_inputs: List[np.ndarray] = distribution_filtered_test_labels[T]
216178
results: Results = Results()
217179
for i, each_input in enumerate(all_inputs):
218180
if i < dataset.num_inputs:
219181
start_time = time.time()
220-
status: str = verify(solver=solver, dataset=dataset, input=each_input)
182+
status: str = verify(args, dataset=dataset, input=each_input)
221183
end_time = time.time()
222184
new_result: List[Any] = ["Lp",
223185
"mnist",
224186
i,
225-
dataset.distance_type,
226187
str(end_time - start_time),
227188
status,
228-
dataset.epsilon,
229-
dataset.rotation_degree,
230-
dataset.brightness_level]
189+
dataset.epsilon]
231190
results.add_result(new_result)
232191

233192
return
234193

235194

236-
def main(solver: VerificationSolver = VerificationSolver.SCIP) -> str:
195+
def main(args) -> str:
237196
Logger.initialize(filename="log.txt", with_log_file=False)
238197
Logger.info(messages="mip verification is starting...")
239198

240-
_execute(solver=solver)
199+
_execute(args)
241200

242201
Logger.info(messages="mip verification is finished!")
243202

244203
return
245204

246205

247206
if __name__ == "__main__":
248-
parser = argparse.ArgumentParser()
249-
parser.add_argument("--solver", type=str, default="scip")
250-
251-
args = parser.parse_args()
252-
solver: VerificationSolver = VerificationSolver(args.solver)
253-
254-
main(solver=solver)
207+
args = parser.parse()
208+
main(args)

0 commit comments

Comments
 (0)