88from concurrent .futures import ProcessPoolExecutor , as_completed
99from yt .frontends .boxlib .api import CastroDataset
1010from yt .units import cm
11+ from dataclasses import dataclass
1112
12- def track_flame_front (ds , args ):
13+ # class to hold necessary parameters for tracking flame
14+ @dataclass
15+ class Metric :
16+ field : str
17+ threshold : float
18+ percent : float
19+
20+ def track_flame_front (ds , metric ):
1321 '''
1422 This function tracks the flame front position for a given dataset.
1523 It returns a list of the form: [Time (in ms), Theta, averaged_max_field, Theta_max, max_field]
@@ -48,10 +56,10 @@ def track_flame_front(ds, args):
4856 averaged_field = []
4957
5058 # First determine the global max of field quantity
51- max_val = ds .all_data ()[args .field ].max ()
59+ max_val = ds .all_data ()[metric .field ].max ()
5260
5361 # Determine a threshold of selecting zones for the average, i.e. minimum value allowed
54- min_val = max_val * args .threshold
62+ min_val = max_val * metric .threshold
5563
5664 # track the theta that has the maximum global value
5765 max_theta_loc = 0.0
@@ -65,12 +73,12 @@ def track_flame_front(ds, args):
6573 # isrt = np.argsort(ray["t"])
6674
6775 # Do the tracking
68- if any (ray [args .field ) == max_val ):
76+ if any (ray [metric .field ) == max_val ):
6977 max_theta_loc = theta
7078
7179 # Consider zones that are larger than minimum value
72- valid_zones = ray [args .field ] > min_val
73- valid_values = ray [args .field ][valid_zones ]
80+ valid_zones = ray [metric .field ] > min_val
81+ valid_values = ray [metric .field ][valid_zones ]
7482
7583 if len (valid_values ) > 0 :
7684 averaged_field .append (valid_values .mean ())
@@ -81,7 +89,7 @@ def track_flame_front(ds, args):
8189 max_index = np .argmax (averaged_field )
8290
8391 # Now assuming flame moves forward in theta, find theta such that the field drops below some threshold of the averaged max
84- loc_index = averaged_field [max_index :] < args .percent * max (averaged_field )
92+ loc_index = averaged_field [max_index :] < metric .percent * max (averaged_field )
8593
8694 # Find the first theta that the field drops below the threshold.
8795 theta_loc = thetas [max_index :][loc_index ][0 ]
@@ -97,7 +105,7 @@ def track_flame_front(ds, args):
97105 return timeTheta
98106
99107
100- def process_dataset (fname , args ):
108+ def process_dataset (fname , metric ):
101109 ds = CastroDataset (fname )
102110
103111 # Returns a list [time, theta, max averaged value, theta_max, max value]
@@ -122,7 +130,7 @@ def process_dataset(fname, args):
122130 the global maximum of the field quantity used to select valid zones
123131 for averaging""" )
124132 parser .add_argument ('--jobs' , '-j' , default = 1 , type = int ,
125- help = """Number of workers to process plot files in parallel"""
133+ help = """Number of workers to process plot files in parallel""" )
126134 parser .add_argument ('--out' , '-o' , default = "front_tracking.dat" , type = str ,
127135 help = """Output filename for the tracking information""" )
128136
@@ -138,14 +146,21 @@ def process_dataset(fname, args):
138146 if args .threshold <= 0.0 or args .percent > 1.0 :
139147 parser .error ("threshold must be a float between (0, 1]" )
140148
149+ # create a metric class to hold data needed to track flame
150+ metric = Metric (
151+ field = args .field ,
152+ threshold = args .threshold ,
153+ percent = args .percent ,
154+ )
155+
141156 timeThetaArray = []
142157
143158 ###
144159 ### Parallelize the loop. Copied from flame_wave/analysis/front_tracker.py
145160 ###
146161 with concurrent .futures .ProcessPoolExecutor (max_workers = args .jobs ) as executor :
147162 future_to_index = {
148- executor .submit (process_dataset , fname , args ): i
163+ executor .submit (process_dataset , fname , metric ): i
149164 for i , fname in enumerate (args .fnames )
150165 }
151166 try :
0 commit comments