11import os
22
33import gradio as gr
4+ from zipfile import ZipFile
5+ from PIL import Image
46
57from .convert import nifti_to_obj
68from .css_style import css
@@ -22,48 +24,39 @@ def __init__(
2224 cwd : str = "/home/user/app/" ,
2325 share : int = 1 ,
2426 ):
27+ self .file_output = None
28+ self .model_selector = None
29+ self .stripped_cb = None
30+ self .registered_cb = None
31+ self .run_btn = None
32+ self .slider = None
33+ self .download_file = None
34+
2535 # global states
2636 self .images = []
2737 self .pred_images = []
2838
29- # @TODO: This should be dynamically set based on chosen volume size
30- self .nb_slider_items = 820
31-
3239 self .model_name = model_name
3340 self .cwd = cwd
3441 self .share = share
3542
36- self .filename = None
37- self .extension = None
38-
39- self .class_name = "airways" # default
43+ self .class_name = "Airways" # default
4044 self .class_names = {
41- "airways" : "CT_Airways" ,
42- "lungs" : "CT_Lungs" ,
45+ "Airways" : "CT_Airways" ,
4346 }
4447
4548 self .result_names = {
46- "airways" : "Airways" ,
47- "lungs" : "Lungs" ,
49+ "Airways" : "Airways" ,
4850 }
4951
50- # define widgets not to be rendered immediantly, but later on
51- self .slider = gr .Slider (
52- minimum = 1 ,
53- maximum = self .nb_slider_items ,
54- value = 1 ,
55- step = 1 ,
56- label = "Which 2D slice to show" ,
57- )
5852 self .volume_renderer = gr .Model3D (
5953 clear_color = [0.0 , 0.0 , 0.0 , 0.0 ],
6054 label = "3D Model" ,
61- show_label = True ,
6255 visible = True ,
6356 elem_id = "model-3d" ,
64- camera_position = [90 , 180 , 768 ],
6557 height = 512 ,
6658 )
59+ # self.volume_renderer = ShinyModel3D()
6760
6861 def set_class_name (self , value ):
6962 LOGGER .info (f"Changed task to: { value } " )
@@ -79,75 +72,107 @@ def upload_file(self, file):
7972
8073 def process (self , mesh_file_name ):
8174 path = mesh_file_name .name
82- curr = path .split ("/" )[- 1 ]
83- self .extension = "." .join (curr .split ("." )[1 :])
84- self .filename = (
85- curr .split ("." )[0 ] + "-" + self .class_names [self .class_name ]
86- )
8775 run_model (
8876 path ,
8977 model_path = os .path .join (self .cwd , "resources/models/" ),
9078 task = self .class_names [self .class_name ],
9179 name = self .result_names [self .class_name ],
92- output_filename = self .filename + "." + self .extension ,
9380 )
9481 LOGGER .info ("Converting prediction NIfTI to OBJ..." )
95- nifti_to_obj (path = self . filename + "." + self . extension )
82+ nifti_to_obj ("prediction.nii.gz" )
9683
9784 LOGGER .info ("Loading CT to numpy..." )
9885 self .images = load_ct_to_numpy (path )
9986
10087 LOGGER .info ("Loading prediction volume to numpy.." )
101- self .pred_images = load_pred_volume_to_numpy (
102- self .filename + "." + self .extension
103- )
88+ self .pred_images = load_pred_volume_to_numpy ("./prediction.nii.gz" )
10489
105- return "./prediction.obj"
90+ slider = gr .Slider (
91+ minimum = 0 ,
92+ maximum = len (self .images ) - 1 ,
93+ value = int (len (self .images ) / 2 ),
94+ step = 1 ,
95+ label = "Which 2D slice to show" ,
96+ interactive = True ,
97+ )
10698
107- def download_prediction (self ):
108- if (self .filename is None ) or (self .extension is None ):
109- LOGGER .error (
110- "The prediction is not available or ready to download. Wait until the result is available in the 3D viewer."
111- )
112- raise ValueError ("Run inference before downloading!" )
113- return self .filename + "." + self .extension
99+ return "./prediction.obj" , slider
114100
115101 def get_img_pred_pair (self , k ):
116- k = int (k )
117- out = gr .AnnotatedImage (
118- self .combine_ct_and_seg (self .images [k ], self .pred_images [k ]),
119- visible = True ,
120- elem_id = "model-2d" ,
121- color_map = {self .class_name : "#ffae00" },
122- height = 512 ,
123- width = 512 ,
124- )
125- return out
102+ img = self .images [k ]
103+ img_pil = Image .fromarray (img )
104+ seg_list = []
105+ seg_list .append ((self .pred_images [k ], self .class_name ))
106+ return img_pil , seg_list
126107
127108 def toggle_sidebar (self , state ):
128109 state = not state
129110 return gr .update (visible = state ), state
130111
112+ def package_results (self ):
113+ """Generates text files and zips them."""
114+ output_dir = "temp_output"
115+ os .makedirs (output_dir , exist_ok = True )
116+
117+ zip_filename = os .path .join (output_dir , "generated_files.zip" )
118+ with ZipFile (zip_filename , 'w' ) as zf :
119+ zf .write ("./prediction.nii.gz" )
120+
121+ return zip_filename
122+
123+ def setup_interface_outputs (self ):
124+ with gr .Row ():
125+ with gr .Group ():
126+ with gr .Column (scale = 2 ):
127+ t = gr .AnnotatedImage (
128+ visible = True ,
129+ elem_id = "model-2d" ,
130+ color_map = {self .class_name : "#ffae00" },
131+ height = 512 ,
132+ width = 512 ,
133+ )
134+
135+ self .slider = gr .Slider (
136+ minimum = 0 ,
137+ maximum = 1 ,
138+ value = 0 ,
139+ step = 1 ,
140+ label = "Which 2D slice to show" ,
141+ interactive = True ,
142+ )
143+
144+ self .slider .change (fn = self .get_img_pred_pair , inputs = self .slider , outputs = t )
145+
146+ with gr .Group ():
147+ self .volume_renderer .render ()
148+ self .download_btn = gr .DownloadButton (label = "Download results" , visible = False )
149+ self .download_file = gr .File (label = "Download Zip" , interactive = True , visible = False )
150+
151+
131152 def run (self ):
132153 with gr .Blocks (css = css ) as demo :
133154 with gr .Row ():
134- with gr .Column (visible = True , scale = 0.2 ) as sidebar_left :
155+ with gr .Column (scale = 1 , visible = True ) as sidebar_left :
135156 logs = gr .Textbox (
136157 placeholder = "\n " * 16 ,
137158 label = "Logs" ,
138159 info = "Verbose from inference will be displayed below." ,
139- lines = 36 ,
140- max_lines = 36 ,
160+ lines = 38 ,
161+ max_lines = 38 ,
141162 autoscroll = True ,
142163 elem_id = "logs" ,
143164 show_copy_button = True ,
165+ # scroll_to_output=False,
144166 container = True ,
167+ # line_breaks=True,
145168 )
146- demo .load (read_logs , None , logs , every = 1 )
169+ timer = gr .Timer (value = 1 , active = True )
170+ timer .tick (fn = read_logs , inputs = None , outputs = logs )
171+ # demo.load(read_logs, None, logs, every=0.5)
147172
148- with gr .Column ():
173+ with gr .Column (scale = 2 ):
149174 with gr .Row ():
150- with gr .Column (scale = 1 , min_width = 150 ):
175+ with gr .Column (min_width = 150 ):
151176 sidebar_state = gr .State (True )
152177
153178 btn_toggle_sidebar = gr .Button (
@@ -160,66 +185,30 @@ def run(self):
160185 [sidebar_left , sidebar_state ],
161186 )
162187
163- btn_clear_logs = gr .Button (
164- "Clear logs" , elem_id = "logs-button"
165- )
188+ btn_clear_logs = gr .Button ("Clear logs" , elem_id = "logs-button" )
166189 btn_clear_logs .click (flush_logs , [], [])
167190
168- file_output = gr .File (
169- file_count = "single" ,
170- elem_id = "upload" ,
171- scale = 3 ,
172- )
173- file_output .upload (
174- self .upload_file , file_output , file_output
175- )
191+ self .file_output = gr .File (file_count = "single" , elem_id = "upload" )
176192
177- model_selector = gr .Dropdown (
193+ self . model_selector = gr .Dropdown (
178194 list (self .class_names .keys ()),
179195 label = "Task" ,
180196 info = "Which structure to segment." ,
181197 multiselect = False ,
182- scale = 1 ,
183- )
184- model_selector .input (
185- fn = lambda x : self .set_class_name (x ),
186- inputs = model_selector ,
187- outputs = None ,
188198 )
189199
190- with gr .Column (scale = 1 , min_width = 150 ):
191- run_btn = gr .Button (
192- "Run analysis" ,
193- variant = "primary" ,
194- elem_id = "run-button" ,
195- )
196- run_btn .click (
197- fn = lambda x : self .process (x ),
198- inputs = file_output ,
199- outputs = self .volume_renderer ,
200- )
201-
202- download_btn = gr .DownloadButton (
203- "Download prediction" ,
204- visible = True ,
205- variant = "secondary" ,
206- elem_id = "download" ,
207- )
208- download_btn .click (
209- fn = self .download_prediction ,
210- inputs = None ,
211- outputs = download_btn ,
212- )
200+ with gr .Column (min_width = 150 ):
201+ self .run_btn = gr .Button ("Run segmentation" , variant = "primary" , elem_id = "run-button" )
213202
214203 with gr .Row ():
215204 gr .Examples (
216205 examples = [
217206 os .path .join (self .cwd , "test_thorax_CT.nii.gz" ),
218207 ],
219- inputs = file_output ,
220- outputs = file_output ,
208+ inputs = self . file_output ,
209+ outputs = self . file_output ,
221210 fn = self .upload_file ,
222- cache_examples = True ,
211+ cache_examples = False ,
223212 )
224213
225214 gr .Markdown (
@@ -229,32 +218,19 @@ def run(self):
229218 """
230219 )
231220
232- with gr .Row ():
233- with gr .Group ():
234- with gr .Column ():
235- # create dummy image to be replaced by loaded images
236- t = gr .AnnotatedImage (
237- visible = True ,
238- elem_id = "model-2d" ,
239- color_map = {self .class_name : "#ffae00" },
240- # height=512,
241- # width=512,
242- )
243- self .slider .input (
244- self .get_img_pred_pair ,
245- self .slider ,
246- t ,
247- )
248-
249- self .slider .render ()
250-
251- with gr .Group (): # gr.Box():
252- self .volume_renderer .render ()
253-
221+ self .setup_interface_outputs ()
222+
223+ # Define the signals/slots
224+ self .file_output .upload (self .upload_file , self .file_output , self .file_output )
225+ self .model_selector .input (fn = lambda x : self .set_class_name (x ), inputs = self .model_selector , outputs = None )
226+ self .run_btn .click (fn = self .process , inputs = [self .file_output ],
227+ outputs = [self .volume_renderer , self .slider ]).then (fn = lambda :
228+ gr .DownloadButton (visible = True ), inputs = None , outputs = self .download_btn )
229+ self .download_btn .click (fn = self .package_results , inputs = [], outputs = self .download_file ).then (fn = lambda
230+ file_path : gr .File (label = "Download Zip" , visible = True , value = file_path ), inputs = self .download_file ,
231+ outputs = self .download_file )
254232 # sharing app publicly -> share=True:
255233 # https://gradio.app/sharing-your-app/
256234 # inference times > 60 seconds -> need queue():
257235 # https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
258- demo .queue ().launch (
259- server_name = "0.0.0.0" , server_port = 7860 , share = self .share
260- )
236+ demo .queue ().launch (server_name = "0.0.0.0" , server_port = 7860 , share = self .share )
0 commit comments