Skip to content

Commit b5d0029

Browse files
committed
Upgrade to gradio 5
1 parent 79725b7 commit b5d0029

File tree

5 files changed

+144
-155
lines changed

5 files changed

+144
-155
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ venv/
1313
*.obj
1414
*.zip
1515
*.txt
16+
*.idea/

Dockerfile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
22
# you will also find guides on how best to write your Dockerfile
3-
FROM python:3.8-slim
3+
FROM python:3.10-slim
44

55
# set language, format and stuff
66
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
@@ -50,10 +50,10 @@ WORKDIR $HOME/app
5050
COPY --chown=user . $HOME/app
5151

5252
# Download pretrained models
53-
RUN wget "https://github.com/raidionics/Raidionics-models/releases/download/1.2.0/Raidionics-CT_Airways-ONNX-v12.zip" && \
54-
unzip "Raidionics-CT_Airways-ONNX-v12.zip" && mkdir -p resources/models/ && mv CT_Airways/ resources/models/CT_Airways/
55-
RUN wget "https://github.com/raidionics/Raidionics-models/releases/download/1.2.0/Raidionics-CT_Lungs-ONNX-v12.zip" && \
56-
unzip "Raidionics-CT_Lungs-ONNX-v12.zip" && mv CT_Lungs/ resources/models/CT_Lungs/
53+
RUN wget "https://github.com/raidionics/Raidionics-models/releases/download/v1.3.0-rc/Raidionics-CT_Airways-v13.zip" && \
54+
unzip "Raidionics-CT_Airways-v13.zip" && mkdir -p resources/models/ && mv CT_Airways/ resources/models/CT_Airways/
55+
RUN wget "https://github.com/raidionics/Raidionics-models/releases/download/v1.3.0-rc/Raidionics-CT_Lungs-v13.zip" && \
56+
unzip "Raidionics-CT_Lungs-v13.zip" && mv CT_Lungs/ resources/models/CT_Lungs/
5757

5858
RUN rm -r *.zip
5959

demo/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
raidionicsrads@git+https://github.com/dbouget/raidionics_rads_lib
2-
gradio==4.29.0
1+
raidionicsrads
2+
gradio

demo/src/gui.py

Lines changed: 100 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22

33
import gradio as gr
4+
from zipfile import ZipFile
5+
from PIL import Image
46

57
from .convert import nifti_to_obj
68
from .css_style import css
@@ -22,48 +24,39 @@ def __init__(
2224
cwd: str = "/home/user/app/",
2325
share: int = 1,
2426
):
27+
self.file_output = None
28+
self.model_selector = None
29+
self.stripped_cb = None
30+
self.registered_cb = None
31+
self.run_btn = None
32+
self.slider = None
33+
self.download_file = None
34+
2535
# global states
2636
self.images = []
2737
self.pred_images = []
2838

29-
# @TODO: This should be dynamically set based on chosen volume size
30-
self.nb_slider_items = 820
31-
3239
self.model_name = model_name
3340
self.cwd = cwd
3441
self.share = share
3542

36-
self.filename = None
37-
self.extension = None
38-
39-
self.class_name = "airways" # default
43+
self.class_name = "Airways" # default
4044
self.class_names = {
41-
"airways": "CT_Airways",
42-
"lungs": "CT_Lungs",
45+
"Airways": "CT_Airways",
4346
}
4447

4548
self.result_names = {
46-
"airways": "Airways",
47-
"lungs": "Lungs",
49+
"Airways": "Airways",
4850
}
4951

50-
# define widgets not to be rendered immediantly, but later on
51-
self.slider = gr.Slider(
52-
minimum=1,
53-
maximum=self.nb_slider_items,
54-
value=1,
55-
step=1,
56-
label="Which 2D slice to show",
57-
)
5852
self.volume_renderer = gr.Model3D(
5953
clear_color=[0.0, 0.0, 0.0, 0.0],
6054
label="3D Model",
61-
show_label=True,
6255
visible=True,
6356
elem_id="model-3d",
64-
camera_position=[90, 180, 768],
6557
height=512,
6658
)
59+
# self.volume_renderer = ShinyModel3D()
6760

6861
def set_class_name(self, value):
6962
LOGGER.info(f"Changed task to: {value}")
@@ -79,75 +72,107 @@ def upload_file(self, file):
7972

8073
def process(self, mesh_file_name):
8174
path = mesh_file_name.name
82-
curr = path.split("/")[-1]
83-
self.extension = ".".join(curr.split(".")[1:])
84-
self.filename = (
85-
curr.split(".")[0] + "-" + self.class_names[self.class_name]
86-
)
8775
run_model(
8876
path,
8977
model_path=os.path.join(self.cwd, "resources/models/"),
9078
task=self.class_names[self.class_name],
9179
name=self.result_names[self.class_name],
92-
output_filename=self.filename + "." + self.extension,
9380
)
9481
LOGGER.info("Converting prediction NIfTI to OBJ...")
95-
nifti_to_obj(path=self.filename + "." + self.extension)
82+
nifti_to_obj("prediction.nii.gz")
9683

9784
LOGGER.info("Loading CT to numpy...")
9885
self.images = load_ct_to_numpy(path)
9986

10087
LOGGER.info("Loading prediction volume to numpy..")
101-
self.pred_images = load_pred_volume_to_numpy(
102-
self.filename + "." + self.extension
103-
)
88+
self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")
10489

105-
return "./prediction.obj"
90+
slider = gr.Slider(
91+
minimum=0,
92+
maximum=len(self.images) - 1,
93+
value=int(len(self.images) / 2),
94+
step=1,
95+
label="Which 2D slice to show",
96+
interactive=True,
97+
)
10698

107-
def download_prediction(self):
108-
if (self.filename is None) or (self.extension is None):
109-
LOGGER.error(
110-
"The prediction is not available or ready to download. Wait until the result is available in the 3D viewer."
111-
)
112-
raise ValueError("Run inference before downloading!")
113-
return self.filename + "." + self.extension
99+
return "./prediction.obj", slider
114100

115101
def get_img_pred_pair(self, k):
116-
k = int(k)
117-
out = gr.AnnotatedImage(
118-
self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
119-
visible=True,
120-
elem_id="model-2d",
121-
color_map={self.class_name: "#ffae00"},
122-
height=512,
123-
width=512,
124-
)
125-
return out
102+
img = self.images[k]
103+
img_pil = Image.fromarray(img)
104+
seg_list = []
105+
seg_list.append((self.pred_images[k], self.class_name))
106+
return img_pil, seg_list
126107

127108
def toggle_sidebar(self, state):
128109
state = not state
129110
return gr.update(visible=state), state
130111

112+
def package_results(self):
113+
"""Generates text files and zips them."""
114+
output_dir = "temp_output"
115+
os.makedirs(output_dir, exist_ok=True)
116+
117+
zip_filename = os.path.join(output_dir, "generated_files.zip")
118+
with ZipFile(zip_filename, 'w') as zf:
119+
zf.write("./prediction.nii.gz")
120+
121+
return zip_filename
122+
123+
def setup_interface_outputs(self):
124+
with gr.Row():
125+
with gr.Group():
126+
with gr.Column(scale=2):
127+
t = gr.AnnotatedImage(
128+
visible=True,
129+
elem_id="model-2d",
130+
color_map={self.class_name: "#ffae00"},
131+
height=512,
132+
width=512,
133+
)
134+
135+
self.slider = gr.Slider(
136+
minimum=0,
137+
maximum=1,
138+
value=0,
139+
step=1,
140+
label="Which 2D slice to show",
141+
interactive=True,
142+
)
143+
144+
self.slider.change(fn=self.get_img_pred_pair, inputs=self.slider, outputs=t)
145+
146+
with gr.Group():
147+
self.volume_renderer.render()
148+
self.download_btn = gr.DownloadButton(label="Download results", visible=False)
149+
self.download_file = gr.File(label="Download Zip", interactive=True, visible=False)
150+
151+
131152
def run(self):
132153
with gr.Blocks(css=css) as demo:
133154
with gr.Row():
134-
with gr.Column(visible=True, scale=0.2) as sidebar_left:
155+
with gr.Column(scale=1, visible=True) as sidebar_left:
135156
logs = gr.Textbox(
136157
placeholder="\n" * 16,
137158
label="Logs",
138159
info="Verbose from inference will be displayed below.",
139-
lines=36,
140-
max_lines=36,
160+
lines=38,
161+
max_lines=38,
141162
autoscroll=True,
142163
elem_id="logs",
143164
show_copy_button=True,
165+
# scroll_to_output=False,
144166
container=True,
167+
# line_breaks=True,
145168
)
146-
demo.load(read_logs, None, logs, every=1)
169+
timer = gr.Timer(value=1, active=True)
170+
timer.tick(fn=read_logs, inputs=None, outputs=logs)
171+
# demo.load(read_logs, None, logs, every=0.5)
147172

148-
with gr.Column():
173+
with gr.Column(scale=2):
149174
with gr.Row():
150-
with gr.Column(scale=1, min_width=150):
175+
with gr.Column(min_width=150):
151176
sidebar_state = gr.State(True)
152177

153178
btn_toggle_sidebar = gr.Button(
@@ -160,66 +185,30 @@ def run(self):
160185
[sidebar_left, sidebar_state],
161186
)
162187

163-
btn_clear_logs = gr.Button(
164-
"Clear logs", elem_id="logs-button"
165-
)
188+
btn_clear_logs = gr.Button("Clear logs", elem_id="logs-button")
166189
btn_clear_logs.click(flush_logs, [], [])
167190

168-
file_output = gr.File(
169-
file_count="single",
170-
elem_id="upload",
171-
scale=3,
172-
)
173-
file_output.upload(
174-
self.upload_file, file_output, file_output
175-
)
191+
self.file_output = gr.File(file_count="single", elem_id="upload")
176192

177-
model_selector = gr.Dropdown(
193+
self.model_selector = gr.Dropdown(
178194
list(self.class_names.keys()),
179195
label="Task",
180196
info="Which structure to segment.",
181197
multiselect=False,
182-
scale=1,
183-
)
184-
model_selector.input(
185-
fn=lambda x: self.set_class_name(x),
186-
inputs=model_selector,
187-
outputs=None,
188198
)
189199

190-
with gr.Column(scale=1, min_width=150):
191-
run_btn = gr.Button(
192-
"Run analysis",
193-
variant="primary",
194-
elem_id="run-button",
195-
)
196-
run_btn.click(
197-
fn=lambda x: self.process(x),
198-
inputs=file_output,
199-
outputs=self.volume_renderer,
200-
)
201-
202-
download_btn = gr.DownloadButton(
203-
"Download prediction",
204-
visible=True,
205-
variant="secondary",
206-
elem_id="download",
207-
)
208-
download_btn.click(
209-
fn=self.download_prediction,
210-
inputs=None,
211-
outputs=download_btn,
212-
)
200+
with gr.Column(min_width=150):
201+
self.run_btn = gr.Button("Run segmentation", variant="primary", elem_id="run-button")
213202

214203
with gr.Row():
215204
gr.Examples(
216205
examples=[
217206
os.path.join(self.cwd, "test_thorax_CT.nii.gz"),
218207
],
219-
inputs=file_output,
220-
outputs=file_output,
208+
inputs=self.file_output,
209+
outputs=self.file_output,
221210
fn=self.upload_file,
222-
cache_examples=True,
211+
cache_examples=False,
223212
)
224213

225214
gr.Markdown(
@@ -229,32 +218,19 @@ def run(self):
229218
"""
230219
)
231220

232-
with gr.Row():
233-
with gr.Group():
234-
with gr.Column():
235-
# create dummy image to be replaced by loaded images
236-
t = gr.AnnotatedImage(
237-
visible=True,
238-
elem_id="model-2d",
239-
color_map={self.class_name: "#ffae00"},
240-
# height=512,
241-
# width=512,
242-
)
243-
self.slider.input(
244-
self.get_img_pred_pair,
245-
self.slider,
246-
t,
247-
)
248-
249-
self.slider.render()
250-
251-
with gr.Group(): # gr.Box():
252-
self.volume_renderer.render()
253-
221+
self.setup_interface_outputs()
222+
223+
# Define the signals/slots
224+
self.file_output.upload(self.upload_file, self.file_output, self.file_output)
225+
self.model_selector.input(fn=lambda x: self.set_class_name(x), inputs=self.model_selector, outputs=None)
226+
self.run_btn.click(fn=self.process, inputs=[self.file_output],
227+
outputs=[self.volume_renderer, self.slider]).then(fn=lambda:
228+
gr.DownloadButton(visible=True), inputs=None, outputs=self.download_btn)
229+
self.download_btn.click(fn=self.package_results, inputs=[], outputs=self.download_file).then(fn=lambda
230+
file_path: gr.File(label="Download Zip", visible=True, value=file_path), inputs=self.download_file,
231+
outputs=self.download_file)
254232
# sharing app publicly -> share=True:
255233
# https://gradio.app/sharing-your-app/
256234
# inference times > 60 seconds -> need queue():
257235
# https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
258-
demo.queue().launch(
259-
server_name="0.0.0.0", server_port=7860, share=self.share
260-
)
236+
demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=self.share)

0 commit comments

Comments
 (0)