88)
99os .environ ["CUDA_VISIBLE_DEVICES" ] = "0"
1010os .environ ["JAX_TRACEBACK_FILTERING" ] = "off"
11- import argparse
1211
1312from matplotlib import pyplot as plt
1413import onnxruntime as ort
1514import numpy as np
16- import jax .numpy as jnp
1715
1816from src .mip import mip_verifier
19- from src .similarity import Similarity
2017
2118from utils .gurobi_modeling import GurobiModel
2219from utils .scip_modeling import SCIPModel
2825 write_vnnlib ,
2926 export_vnnlib ,
3027)
31- from utils .options import VerificationSolver , RobustnessType
3228from utils .save_results import Results
3329from utils .log import Logger
34-
30+ import utils . parser as parser
3531
3632# global configuration
3733T : int = 0 # true label
3834
3935
4036# TODO: implement verification algorithm in different ways
4137def verify (
42- solver : VerificationSolver ,
38+ args ,
4339 dataset : DataSet ,
44- input : jnp .ndarray ,
40+ input : np .ndarray ,
4541) -> str :
4642 """
4743 Verification algorithm:
4844
49- Support: MIP (SCIP, Gurobi), CROWN
45+ Support: MIP (SCIP, Gurobi), SMT
5046 """
5147 result : str = "UNSAT"
5248 vnnlib_filename : str = write_vnnlib (data = input ,
@@ -57,31 +53,27 @@ def verify(
5753 onnx_file_path = dataset .onnx_filename , vnnlib_file_path = vnnlib_filename
5854 )
5955
60- if solver is VerificationSolver .SCIP or solver is VerificationSolver .GUROBI :
61- Logger .info (messages = f"Verification Algorithm is MIP solver ({ solver } )" )
62-
63- m : SCIPModel | GurobiModel = mip_verifier (solver_name = solver , networks = networks )
56+ if args .solver == "scip" or args .solver == "gurobi" :
57+ Logger .info (messages = f"Verification Algorithm is MIP solver ({ args .solver } )" )
58+ m : SCIPModel | GurobiModel = mip_verifier (solver_name = args .solver , networks = networks )
6459 result = "UNSAT" if m .get_solution_status () == "Infeasible" else "SAT"
65- elif solver is VerificationSolver .BOX :
66- Logger .info (messages = "Verification Algorithm is BOX" )
67- Logger .info (messages = "[ERROR] Box verifier is not able to use in this version." )
6860 if result == "UNSAT" :
69- Logger .info (messages = "UNSAT, New template generated! " )
61+ Logger .info (messages = "UNSAT" )
7062 return "UNSAT"
7163 else :
7264 Logger .info (messages = "SAT" )
7365 # * Testing checker for counter-example found by verifier
7466 counter_example : List [float ] = get_ce (
75- solver = solver , networks = networks , filename = "./test_cex.txt" , m = m
67+ solver = args . solver , networks = networks , filename = "./test_cex.txt" , m = m
7668 )
77- counter_example : jnp .ndarray = jnp .array (counter_example )
69+ counter_example : np .ndarray = np .array (counter_example )
7870
7971 return result
8072
8173
82- def _execute (solver : VerificationSolver ) -> None :
74+ def _execute (args ) -> None :
8375 """
84- MIP verification
76+ Complete verification
8577
8678 Build a mixed-integer programming model to verify neural networks.
8779
@@ -97,12 +89,11 @@ def _execute(solver: VerificationSolver) -> None:
9789 # step 0.
9890 Logger .info (messages = "step 0: read the input files" )
9991 dataset : DataSet = load_dataset (
100- dataset_name = "mnist" ,
101- onnx_filename = "./utils/benchmarks/onnx/fc_5x100.onnx" ,
102- robustness_type = RobustnessType .LP_NORM ,
103- num_inputs = 1 , # len(distribution_filtered_test_labels[test_true_label])
104- distance_type = "linf" ,
105- epsilon = 0.1 ,
92+ dataset_name = args .dataset ,
93+ onnx_filename = f"./utils/benchmarks/onnx/{ args .network } " ,
94+ robustness_type = args .perturbation_type ,
95+ num_test = args .num_test , # len(distribution_filtered_test_labels[test_true_label])
96+ epsilon = args .epsilon ,
10697 )
10798
10899 # step 1.
@@ -136,7 +127,7 @@ def _execute(solver: VerificationSolver) -> None:
136127 # * ======================= * #
137128 # * testing dataset part
138129 # * ======================= * #
139- filterd_test_images : List [jnp .ndarray ] = []
130+ filterd_test_images : List [np .ndarray ] = []
140131 filterd_test_labels : List [int ] = []
141132 all_inference_result = dict ()
142133 num_test_images : int = len (dataset .test_images )
@@ -146,11 +137,11 @@ def _execute(solver: VerificationSolver) -> None:
146137 [output_name ],
147138 {
148139 input_name : dataset .test_images [data_id ]
149- .astype (jnp .float32 )
140+ .astype (np .float32 )
150141 .reshape (1 , dataset .num_pixels , 1 )
151142 },
152143 )[0 ]
153- inference_label = jnp .argmax (inference_result )
144+ inference_label = np .argmax (inference_result )
154145 true_label = dataset .test_labels [data_id ]
155146 if inference_label == true_label :
156147 filterd_test_images .append (dataset .test_images [data_id ])
@@ -170,7 +161,7 @@ def _execute(solver: VerificationSolver) -> None:
170161 # * ************************ * #
171162 # * step 2. based on each label, separate into different groups.
172163 # * ************************ * #
173- distribution_filtered_test_labels : Dict [int , List [jnp .ndarray ]] = (
164+ distribution_filtered_test_labels : Dict [int , List [np .ndarray ]] = (
174165 {}
175166 ) # * key: label, value: images
176167 for label in range (dataset .num_labels ):
@@ -182,73 +173,36 @@ def _execute(solver: VerificationSolver) -> None:
182173 messages = f"label: { label } , number of images: { len (distribution_filtered_test_labels [label ])} "
183174 )
184175
185- # # * ************************ * #
186- # # * [Archived] step 3. similarity analysis
187- # # * ************************ * #
188- # test_true_label: int = 1 # YES: 0, Y3(1), label 1: 0 & 1109 可以結合
189- # Logger.debugging(
190- # messages=f"number of testing images: {len(distribution_filtered_test_labels[test_true_label])}"
191- # )
192- # distance_matrix: jnp.ndarray = Similarity.generate_distance_matrix(
193- # all_data=distribution_filtered_test_labels[test_true_label],
194- # distance_type=dataset.distance_type,
195- # chunk_size=100,
196- # )
197-
198- # # * find the similar data
199- # all_inputs: List[jnp.ndarray] = []
200- # similarity_data: List[int] = Similarity.greedy(distance_matrix=distance_matrix)
201- # for id, value in enumerate(similarity_data):
202- # all_inputs.append(distribution_filtered_test_labels[test_true_label][value])
203-
204- # for i in all_inputs:
205- # verify(solver=solver, dataset=dataset, all_inputs=[i], true_label=test_true_label)
206-
207- # end_time = time.time()
208- # Logger.info(
209- # messages=f"elapsed time for batch verification is : {end_time - start_time}"
210- # )
211- # Logger.info(messages=f"number of iterations: {COUNT}")
212-
213-
214176 # step 4. & step 5.
215- all_inputs : List [jnp .ndarray ] = distribution_filtered_test_labels [T ]
177+ all_inputs : List [np .ndarray ] = distribution_filtered_test_labels [T ]
216178 results : Results = Results ()
217179 for i , each_input in enumerate (all_inputs ):
218180 if i < dataset .num_inputs :
219181 start_time = time .time ()
220- status : str = verify (solver = solver , dataset = dataset , input = each_input )
182+ status : str = verify (args , dataset = dataset , input = each_input )
221183 end_time = time .time ()
222184 new_result : List [Any ] = ["Lp" ,
223185 "mnist" ,
224186 i ,
225- dataset .distance_type ,
226187 str (end_time - start_time ),
227188 status ,
228- dataset .epsilon ,
229- dataset .rotation_degree ,
230- dataset .brightness_level ]
189+ dataset .epsilon ]
231190 results .add_result (new_result )
232191
233192 return
234193
235194
236- def main (solver : VerificationSolver = VerificationSolver . SCIP ) -> str :
195+ def main (args ) -> str :
237196 Logger .initialize (filename = "log.txt" , with_log_file = False )
238197 Logger .info (messages = "mip verification is starting..." )
239198
240- _execute (solver = solver )
199+ _execute (args )
241200
242201 Logger .info (messages = "mip verification is finished!" )
243202
244203 return
245204
246205
247206if __name__ == "__main__" :
248- parser = argparse .ArgumentParser ()
249- parser .add_argument ("--solver" , type = str , default = "scip" )
250-
251- args = parser .parse_args ()
252- solver : VerificationSolver = VerificationSolver (args .solver )
253-
254- main (solver = solver )
207+ args = parser .parse ()
208+ main (args )
0 commit comments