Skip to content

Commit cdb1bf4

Browse files
committed
Update CUDNN_home for v9.7.0
1 parent c132ac1 commit cdb1bf4

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

tools/gen_ort_dockerfile.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,10 @@ def dockerfile_for_linux(output_file):
263263
ep_flags += ' --cuda_version "{}"'.format(FLAGS.cuda_version)
264264
if FLAGS.cuda_home is not None:
265265
ep_flags += ' --cuda_home "{}"'.format(FLAGS.cuda_home)
266+
if FLAGS.cudnn_home is not None:
267+
ep_flags += ' --cudnn_home "{}"'.format(FLAGS.cudnn_home)
268+
elif target_platform() == "igpu":
269+
ep_flags += ' --cudnn_home "/usr/lib/aarch64-linux-gnu"'
266270
if FLAGS.ort_tensorrt:
267271
ep_flags += " --use_tensorrt"
268272
if FLAGS.ort_version >= "1.12.1":
@@ -292,16 +296,16 @@ def dockerfile_for_linux(output_file):
292296
""".format(
293297
cuda_archs
294298
)
295-
if FLAGS.enable_gpu : #and target_platform() != "igpu"
296-
# For GPU build, include the cudnn_home flag
297-
df += """
298-
RUN _CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2) && ./build.sh ${{COMMON_BUILD_ARGS}} --update --build {} --cudnn_home /usr/local/cudnn-$_CUDNN_VERSION/cuda
299-
""".format(ep_flags)
300-
else:
301-
# For non-GPU
302-
df += """
303-
RUN ./build.sh ${{COMMON_BUILD_ARGS}} --update --build {}
304-
""".format(ep_flags)
299+
# if FLAGS.enable_gpu : #and target_platform() != "igpu"
300+
# # For GPU build, include the cudnn_home flag
301+
# df += """
302+
# RUN _CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2) && ./build.sh ${{COMMON_BUILD_ARGS}} --update --build {} --cudnn_home /usr/local/cudnn-$_CUDNN_VERSION/cuda
303+
# """.format(ep_flags)
304+
# else:
305+
# # For non-GPU
306+
# df += """
307+
# RUN ./build.sh ${{COMMON_BUILD_ARGS}} --update --build {}
308+
# """.format(ep_flags)
305309

306310
df += """
307311
RUN ./build.sh ${{COMMON_BUILD_ARGS}} --update --build {}
@@ -575,7 +579,8 @@ def preprocess_gpu_flags():
575579
# version = m.group(1)
576580
if FLAGS.cudnn_home is None:
577581
#FLAGS.cudnn_home = "/usr/local/cudnn-{}/cuda".format(version)
578-
FLAGS.cudnn_home = "/usr"
582+
#FLAGS.cudnn_home = "/usr/include"
583+
FLAGS.cudnn_home = "/usr/lib/x86_64-linux-gnu/"
579584

580585
if FLAGS.cuda_home is None:
581586
FLAGS.cuda_home = "/usr/local/cuda"

0 commit comments

Comments
 (0)