Skip to content
35 changes: 30 additions & 5 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,8 +1220,20 @@ def create_dockerfile_cibase(ddir, dockerfile_name, argmap):


def create_dockerfile_linux(
ddir, dockerfile_name, argmap, backends, repoagents, caches, endpoints
ddir,
dockerfile_name,
argmap,
backends,
repoagents,
caches,
endpoints,
runtime_image=None,
):
base_image = argmap["BASE_IMAGE"]
# If runtime base image is provided, use it as the base image
if runtime_image:
base_image = runtime_image

df = """
ARG TRITON_VERSION={}
ARG TRITON_CONTAINER_VERSION={}
Expand All @@ -1230,7 +1242,7 @@ def create_dockerfile_linux(
""".format(
argmap["TRITON_VERSION"],
argmap["TRITON_CONTAINER_VERSION"],
argmap["BASE_IMAGE"],
base_image,
)

# PyTorch and TensorFlow backends need extra CUDA and other
Expand Down Expand Up @@ -1642,8 +1654,12 @@ def change_default_python_version_rhel(version):


def create_dockerfile_windows(
ddir, dockerfile_name, argmap, backends, repoagents, caches
ddir, dockerfile_name, argmap, backends, repoagents, caches, runtime_image=None
):
base_image = argmap["BASE_IMAGE"]
# If runtime base image is provided, use it as the base image
if runtime_image:
base_image = runtime_image
df = """
ARG TRITON_VERSION={}
ARG TRITON_CONTAINER_VERSION={}
Expand All @@ -1666,7 +1682,7 @@ def create_dockerfile_windows(
""".format(
argmap["TRITON_VERSION"],
argmap["TRITON_CONTAINER_VERSION"],
argmap["BASE_IMAGE"],
base_image,
)
df += """
WORKDIR /opt
Expand Down Expand Up @@ -1752,6 +1768,7 @@ def create_build_dockerfiles(
backends,
repoagents,
caches,
images.get("runtime"),
)
else:
create_dockerfile_linux(
Expand All @@ -1762,6 +1779,7 @@ def create_build_dockerfiles(
repoagents,
caches,
endpoints,
images.get("runtime"),
)

# Dockerfile used for the creating the CI base image.
Expand Down Expand Up @@ -2940,7 +2958,14 @@ def enable_all():
)
fail_if(
parts[0]
not in ["base", "gpu-base", "pytorch", "tensorflow", "tensorflow2"],
not in [
"base",
"gpu-base",
"pytorch",
"tensorflow",
"tensorflow2",
"runtime",
],
"unsupported value for --image",
)
log('image "{}": "{}"'.format(parts[0], parts[1]))
Expand Down
Loading