Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 25 additions & 116 deletions .github/workflows/build.yaml → .github/workflows/_build.yaml
Original file line number Diff line number Diff line change
@@ -1,102 +1,21 @@
name: PJRT GPU library
name: Build PJRT GPU libraries

on:
push:
tags:
- "*"
pull_request:
workflow_dispatch:
inputs:
xla_commit:
required: true
type: string
default: "main"
rocm_xla_commit:
required: true
type: string
default: "main"
workflow_call:
inputs:
xla_commit:
required: false
type: string
default: "main"
rocm_xla_commit:
required: false
type: string
default: "main"

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

permissions:
contents: write

env:
# From master 2025-12-12
XLA_COMMIT: ${{ inputs.xla_commit || 'f238b48769d2ab8d62eeb09b5d31a972dfa4841a' }}
# rocm-jaxlib-v0.8.0
ROCM_XLA_COMMIT: ${{ inputs.rocm_xla_commit || '06402b44669c52956732678772104dcb85c53806' }}
XLA_COMMIT: ${{ inputs.xla_commit || 'f238b48769d2ab8d62eeb09b5d31a972dfa4841a' }} # main from 2025-12-12
ROCM_XLA_COMMIT: ${{ inputs.rocm_xla_commit || '06402b44669c52956732678772104dcb85c53806' }} # rocm-jaxlib-v0.8.0
TF_ROCM_AMDGPU_TARGETS: "gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100"

jobs:
setup_openxla:
runs-on: ubuntu-latest
outputs:
xla_commit: ${{ steps.patches.outputs.XLA_COMMIT_ID }}
rocm_xla_commit: ${{ steps.rocm_patches.outputs.XLA_COMMIT_ID }}
steps:
- name: "Checking out repository"
uses: actions/checkout@v4
with:
path: "pjrt-artifacts"
- name: "Checking out openxla repository"
uses: actions/checkout@v4
with:
ref: ${{ env.XLA_COMMIT }}
repository: openxla/xla
path: "xla"
- name: Apply patches to openxla
id: patches
working-directory: ./xla
run: |
xla_commit=$(git rev-parse HEAD)
echo "XLA_COMMIT_ID=$xla_commit" >> $GITHUB_OUTPUT
echo ::notice::Applying patches to openxla $xla_commit
for patch in $(ls ../pjrt-artifacts/openxla/patches/upstream/*.patch | sort); do
echo "Applying patch $patch"
git apply "$patch"
done
- name: Upload openxla repository artifact
uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: xla-${{ steps.patches.outputs.XLA_COMMIT_ID }}
path: ./xla
- name: "Checking out ROCm xla repository"
uses: actions/checkout@v4
with:
ref: ${{ env.ROCM_XLA_COMMIT }}
repository: ROCm/xla
path: "xla-rocm"
- name: Apply patches to ROCm openxla
id: rocm_patches
working-directory: ./xla-rocm
run: |
xla_commit=$(git rev-parse HEAD)
echo "XLA_COMMIT_ID=$xla_commit" >> $GITHUB_OUTPUT
echo ::notice::Applying patches to ROCm openxla $xla_commit
for patch in $(ls ../pjrt-artifacts/openxla/patches/rocm/*.patch | sort); do
echo "Applying patch $patch"
git apply "$patch"
done
- name: Upload ROCm openxla repository artifact
uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: xla-rocm-${{ steps.rocm_patches.outputs.XLA_COMMIT_ID }}
path: ./xla-rocm

pjrt-artifacts:
runs-on: ${{ matrix.pjrt.runs_on }}
strategy:
Expand Down Expand Up @@ -143,8 +62,6 @@ jobs:
runs_on: ["runs-on", "runner=32cpu-linux-x64", "image=ubuntu22-amd64"]
platform: linux-amd64
bazel_target: //xla/pjrt/c:pjrt_c_api_cpu_plugin

needs: ["setup_openxla"]
steps:
- uses: runs-on/action@v1
if: matrix.pjrt.platform != 'darwin-arm64' && matrix.pjrt.platform != 'darwin-amd64'
Expand All @@ -154,6 +71,27 @@ jobs:
with:
path: "pjrt-artifacts"

- name: "Checking out openxla repository"
uses: actions/checkout@v4
with:
ref: ${{ matrix.pjrt.target == 'rocm' && env.ROCM_XLA_COMMIT || env.XLA_COMMIT }}
repository: ${{ matrix.pjrt.target == 'rocm' && 'ROCm/xla' || 'openxla/xla' }}
path: "xla"

- name: Apply patches to openxla
id: patches
working-directory: ./xla
env:
PATCH_DIRECTORY: ${{ matrix.pjrt.target == 'rocm' && 'rocm' || 'upstream' }}
run: |
xla_commit=$(git rev-parse HEAD)
echo "XLA_COMMIT_ID=$xla_commit" >> $GITHUB_OUTPUT
echo ::notice::Applying patches to openxla $xla_commit
for patch in $(ls ../pjrt-artifacts/openxla/patches/$PATCH_DIRECTORY/*.patch | sort); do
echo "Applying patch $patch"
git apply "$patch"
done

- uses: runs-on/snapshot@v1
if: matrix.pjrt.platform != 'darwin-arm64' && matrix.pjrt.platform != 'darwin-amd64'
with:
Expand All @@ -175,20 +113,6 @@ jobs:
sudo apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y
sudo amdgpu-install --usecase=rocm,rocmdev,hiplibsdk -y --no-dkms

- name: Download xla artifact
if: matrix.pjrt.target != 'rocm'
uses: actions/download-artifact@v4
with:
name: xla-${{ needs.setup_openxla.outputs.xla_commit }}
path: xla

- name: Download ROCm xla artifact
if: matrix.pjrt.target == 'rocm'
uses: actions/download-artifact@v4
with:
name: xla-rocm-${{ needs.setup_openxla.outputs.rocm_xla_commit }}
path: xla

- uses: bazel-contrib/setup-bazel@0.15.0
with:
bazelisk-version: 1.26.0
Expand Down Expand Up @@ -234,18 +158,3 @@ jobs:
with:
name: pjrt-${{ matrix.pjrt.target }}_${{ matrix.pjrt.platform }}.tar.gz
path: pjrt-${{ matrix.pjrt.target }}_${{ matrix.pjrt.platform }}.tar.gz

release:
needs: ["pjrt-artifacts"]
runs-on: ubuntu-latest
steps:
- run: rm -rf pjrt*.tar.gz
- name: Download all artifacts
uses: actions/download-artifact@v4
if: startsWith(github.ref, 'refs/tags/')
- name: Release
uses: softprops/action-gh-release@v2
if: startsWith(github.ref, 'refs/tags/')
with:
files: |
*.tar.gz/*.tar.gz
11 changes: 11 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: CI
on:
pull_request:

concurrency:
group: ${{ github.workflow_sha }}-${{ github.event.pull_request.number }}
cancel-in-progress: true

jobs:
pjrt-artifacts:
uses: ./.github/workflows/_build.yaml
14 changes: 5 additions & 9 deletions .github/workflows/nightly.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
name: nightly - PJRT CUDA library
name: Nightly

on:
workflow_dispatch:
pull_request:
schedule:
- cron: '0 */4 * * *'
- cron: '0 0 * * *'

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
group: ${{ github.workflow_sha }}-${{ github.ref }}
cancel-in-progress: true

permissions:
contents: write

jobs:
nightly-pjrt-artifacts:
if: github.event_name != 'pull_request' || contains(join(github.event.pull_request.labels.*.name, ','), 'nightly')
uses: ./.github/workflows/build.yaml
uses: ./.github/workflows/_build.yaml
with:
xla_commit: main
rocm_xla_commit: rocm-jaxlib-v0.8.0
31 changes: 31 additions & 0 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Release

on:
push:
tags:
- "*"

concurrency:
group: ${{ github.workflow_sha }}-${{ github.ref_name }}
cancel-in-progress: true

permissions:
contents: write

jobs:
pjrt-artifacts:
uses: ./.github/workflows/_build.yaml

release:
name: "Release PJRT Artifacts"
needs: ["pjrt-artifacts"]
runs-on: ubuntu-latest
steps:
- run: rm -rf pjrt*.tar.gz
- name: Download all artifacts
uses: actions/download-artifact@v4
- name: Release
uses: softprops/action-gh-release@v2
with:
files: |
*.tar.gz/*.tar.gz
Loading