diff --git a/build.py b/build.py index e35e910d98..052f77aed9 100755 --- a/build.py +++ b/build.py @@ -1220,8 +1220,20 @@ def create_dockerfile_cibase(ddir, dockerfile_name, argmap): def create_dockerfile_linux( - ddir, dockerfile_name, argmap, backends, repoagents, caches, endpoints + ddir, + dockerfile_name, + argmap, + backends, + repoagents, + caches, + endpoints, + runtime_image=None, ): + base_image = argmap["BASE_IMAGE"] + # If runtime base image is provided, use it as the base image + if runtime_image: + base_image = runtime_image + df = """ ARG TRITON_VERSION={} ARG TRITON_CONTAINER_VERSION={} @@ -1230,7 +1242,7 @@ def create_dockerfile_linux( """.format( argmap["TRITON_VERSION"], argmap["TRITON_CONTAINER_VERSION"], - argmap["BASE_IMAGE"], + base_image, ) # PyTorch and TensorFlow backends need extra CUDA and other @@ -1642,8 +1654,12 @@ def change_default_python_version_rhel(version): def create_dockerfile_windows( - ddir, dockerfile_name, argmap, backends, repoagents, caches + ddir, dockerfile_name, argmap, backends, repoagents, caches, runtime_image=None ): + base_image = argmap["BASE_IMAGE"] + # If runtime base image is provided, use it as the base image + if runtime_image: + base_image = runtime_image df = """ ARG TRITON_VERSION={} ARG TRITON_CONTAINER_VERSION={} @@ -1666,7 +1682,7 @@ def create_dockerfile_windows( """.format( argmap["TRITON_VERSION"], argmap["TRITON_CONTAINER_VERSION"], - argmap["BASE_IMAGE"], + base_image, ) df += """ WORKDIR /opt @@ -1752,6 +1768,7 @@ def create_build_dockerfiles( backends, repoagents, caches, + images.get("runtime"), ) else: create_dockerfile_linux( @@ -1762,6 +1779,7 @@ def create_build_dockerfiles( repoagents, caches, endpoints, + images.get("runtime"), ) # Dockerfile used for the creating the CI base image. @@ -2940,7 +2958,14 @@ def enable_all(): ) fail_if( parts[0] - not in ["base", "gpu-base", "pytorch", "tensorflow", "tensorflow2"], + not in [ + "base", + "gpu-base", + "pytorch", + "tensorflow", + "tensorflow2", + "runtime", + ], "unsupported value for --image", ) log('image "{}": "{}"'.format(parts[0], parts[1]))