Skip to content

Commit 7386cab

Browse files
authored
Test ragged batches. (#77)
Extended dynamic batching tests to check ragged batches support. Signed-off-by: Rafal <[email protected]>
1 parent 191ae68 commit 7386cab

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

qa/L0_DALI_GPU_ensemble/client.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,13 @@ def ref_func(inp1, inp2):
9797
return inp1 * 2 / 3, (inp2 * 3).astype(np.half).astype(np.single) / 2
9898

9999

100-
def random_gen(max_batch_size, uniform_groups=1):
100+
def random_gen(max_batch_size):
101101
while True:
102102
size1 = randint(100, 300)
103103
size2 = randint(100, 300)
104-
for i in range(uniform_groups):
105-
bs = randint(1, max_batch_size + 1)
106-
yield np.random.random((bs, size1)).astype(np.single), \
107-
np.random.random((bs, size2)).astype(np.single)
104+
bs = randint(1, max_batch_size + 1)
105+
yield np.random.random((bs, size1)).astype(np.single), \
106+
np.random.random((bs, size2)).astype(np.single)
108107

109108
def parse_args():
110109
parser = argparse.ArgumentParser()
@@ -120,7 +119,7 @@ def main():
120119
args = parse_args()
121120
client = TestClient('dali_ensemble', ['INPUT_0', 'INPUT_1'], ['OUTPUT_0', 'OUTPUT_1'], args.url,
122121
concurrency=args.concurrency)
123-
client.run_tests(random_gen(args.max_batch_size, args.concurrency), ref_func,
122+
client.run_tests(random_gen(args.max_batch_size), ref_func,
124123
n_infers=args.n_iters, eps=1e-4)
125124

126125
if __name__ == '__main__':

qa/L0_DALI_GPU_ensemble/model_repository/dali_1/config.pbtxt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ input [
2727
name: "DALI_INPUT_0"
2828
data_type: TYPE_FP32
2929
dims: [ -1 ]
30+
allow_ragged_batch: true
3031
}
3132
]
3233

@@ -35,6 +36,7 @@ input [
3536
name: "DALI_INPUT_1"
3637
data_type: TYPE_FP32
3738
dims: [ -1 ]
39+
allow_ragged_batch: true
3840
}
3941
]
4042

qa/L0_DALI_GPU_ensemble/model_repository/dali_2/config.pbtxt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ input [
2727
name: "DALI_INPUT_0"
2828
data_type: TYPE_FP32
2929
dims: [ -1 ]
30+
allow_ragged_batch: true
3031
}
3132
]
3233

@@ -35,6 +36,7 @@ input [
3536
name: "DALI_INPUT_1"
3637
data_type: TYPE_FP16
3738
dims: [ -1 ]
39+
allow_ragged_batch: true
3840
}
3941
]
4042

0 commit comments

Comments
 (0)