Skip to content

Commit 6822cde

Browse files
committed
Update object detection example to use new scannertools api
1 parent f56adf3 commit 6822cde

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

examples/apps/object_detection_tensorflow/kernels.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import cv2
44
import os
55
import scannerpy
6-
import scannerpy.stdlib.util
76
import pickle
87
import visualization_utils as vis_util
98
import tarfile
109

1110
from scannerpy import FrameType, DeviceType
12-
from scannerpy.stdlib import tensorflow
1311
from scannerpy.types import NumpyArrayFloat32
12+
from scannertools import tensorflow
1413
from typing import Tuple, Sequence
1514
from tqdm import tqdm
1615

@@ -24,12 +23,12 @@
2423
category_index = vis_util.create_category_index(categories)
2524

2625
def download_and_extract_model(url, local_path=None):
27-
path = scannerpy.stdlib.util.download_temp_file(url, local_path)
26+
path = scannerpy.util.download_temp_file(url, local_path)
2827
tar_file = tarfile.open(path)
2928
for f in tar_file.getmembers():
3029
file_name = os.path.basename(f.name)
3130
if 'frozen_inference_graph.pb' in file_name:
32-
local_path = scannerpy.stdlib.util.temp_directory()
31+
local_path = scannerpy.util.temp_directory()
3332
tar_file.extract(f, local_path)
3433
model_path = os.path.join(local_path, f.name)
3534
break
@@ -40,11 +39,13 @@ def download_and_extract_model(url, local_path=None):
4039
batch=2)
4140
class ObjDetect(tensorflow.TensorFlowKernel):
4241
def __init__(self, config, dnn_url):
42+
print('objdet', config)
43+
print([d.id for d in config.devices])
4344
tensorflow.TensorFlowKernel.__init__(self, config)
4445
self.dnn_url = dnn_url
4546
self.model_name = dnn_url.rsplit('/', 1)[-1]
4647
self.local_model_path = os.path.join(
47-
scannerpy.stdlib.util.temp_directory(),
48+
scannerpy.util.temp_directory(),
4849
self.model_name.rsplit('.')[0],
4950
'frozen_inference_graph.pb')
5051

examples/apps/object_detection_tensorflow/main.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from scannerpy import Client, DeviceType
2-
from scannerpy.storage import NamedVideoStream, PythonStream
1+
from scannertools.storage.python import PythonStream
2+
import scannerpy as sp
33
import os
44
import sys
55
import math
@@ -21,10 +21,10 @@ def main():
2121
print('Detecting objects in movie {}'.format(movie_path))
2222
movie_name = os.path.splitext(os.path.basename(movie_path))[0]
2323

24-
sc = Client()
24+
sc = sp.Client()
2525

2626
stride = 1
27-
input_stream = NamedVideoStream(sc, movie_name, path=movie_path)
27+
input_stream = sp.NamedVideoStream(sc, movie_name, path=movie_path)
2828
frame = sc.io.Input([input_stream])
2929
strided_frame = sc.streams.Stride(frame, [stride])
3030

@@ -33,12 +33,14 @@ def main():
3333
objdet_frame = sc.ops.ObjDetect(
3434
frame=strided_frame,
3535
dnn_url=model_url,
36-
device=DeviceType.GPU if sc.has_gpu() else DeviceType.CPU,
36+
device=sp.DeviceType.GPU if sc.has_gpu() else sp.DeviceType.CPU,
3737
batch=2)
3838

39-
detect_stream = NamedVideoStream(sc, movie_name + '_detect')
39+
detect_stream = sp.NamedVideoStream(sc, movie_name + '_detect')
4040
output_op = sc.io.Output(objdet_frame, [detect_stream])
41-
sc.run(output_op)
41+
sc.run(output_op,
42+
sp.PerfParams.estimate(),
43+
cache_mode=sp.CacheMode.Overwrite)
4244

4345
print('Extracting data from Scanner output...')
4446
# bundled_data_list is a list of bundled_data
@@ -58,9 +60,11 @@ def main():
5860
drawn_frame = sc.ops.TFDrawBoxes(frame=strided_frame,
5961
bundled_data=bundled_data,
6062
min_score_thresh=0.5)
61-
drawn_stream = NamedVideoStream(sc, movie_name + '_drawn_frames')
63+
drawn_stream = sp.NamedVideoStream(sc, movie_name + '_drawn_frames')
6264
output_op = sc.io.Output(drawn_frame, [drawn_stream])
63-
sc.run(output_op)
65+
sc.run(output_op,
66+
sp.PerfParams.estimate(),
67+
cache_mode=sp.CacheMode.Overwrite)
6468

6569
drawn_stream.save_mp4(movie_name + '_obj_detect')
6670

@@ -73,5 +77,3 @@ def main():
7377

7478
if __name__ == '__main__':
7579
main()
76-
77-

0 commit comments

Comments
 (0)