Skip to content

Commit a894902

Browse files
committed
chore: fixup pylint complaints
Signed-off-by: Maryam Tahhan <mtahhan@redhat.com>
1 parent ece957d commit a894902

File tree

10 files changed

+162
-31
lines changed

10 files changed

+162
-31
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: Build triton-util
2+
3+
on:
4+
push:
5+
branches: [main]
6+
paths:
7+
- triton_util/**
8+
- .github/workflows/build.yml
9+
pull_request:
10+
paths:
11+
- triton_util/**
12+
13+
jobs:
14+
build:
15+
runs-on: ubuntu-latest
16+
17+
steps:
18+
- name: Checkout code
19+
uses: actions/checkout@v4
20+
21+
- name: Set up Python
22+
uses: actions/setup-python@v5
23+
with:
24+
python-version: 3.12
25+
26+
- name: Upgrade pip and install tools
27+
run: |
28+
python -m pip install --upgrade pip
29+
pip install build setuptools wheel
30+
31+
- name: Build the package
32+
run: |
33+
cd triton_util
34+
python -m build
35+
36+
- name: Install the package
37+
run: |
38+
pip install ./triton_util[dev]

.github/workflows/pylint.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ name: Lint Python code with pylint
33
on: # yamllint disable-line rule:truthy
44
pull_request:
55
paths:
6-
- "**/*.py"
6+
- tcm/**/*.py
7+
- triton_util/**/*.py
78
push:
89
paths:
9-
- "**/*.py"
10+
- tcm/**/*.py
11+
- triton_util/**/*.py
1012

1113
jobs:
1214
pylint:
@@ -26,6 +28,7 @@ jobs:
2628
python -m pip install --upgrade pip
2729
pip install pylint
2830
pip install -r ./tcm/requirements.txt
31+
pip install -e ./triton_util[dev]
2932
3033
- name: Run pylint
3134
run: |

triton_util/setup.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
1+
"""Setup script for the triton-util package."""
2+
13
from setuptools import setup, find_packages
24

5+
with open("README.md", encoding="utf-8") as f:
6+
long_description = f.read()
7+
38
setup(
4-
name='triton-util',
5-
version='0.0.2',
9+
name="triton-util",
10+
version="0.0.2",
611
packages=find_packages(),
7-
install_requires=['triton'],
8-
author='Umer Adil',
9-
author_email='umer.hayat.adil@gmail.com',
10-
description='Make Triton easier - A utility package for OpenAI Triton',
11-
long_description=open('README.md').read(),
12-
long_description_content_type='text/markdown',
13-
url='https://github.com/umerHA/triton_util',
12+
install_requires=["triton"],
13+
extras_require={
14+
"dev": ["pytest", "pylint", "torch", "ipython"],
15+
},
16+
author="Umer Adil",
17+
author_email="umer.hayat.adil@gmail.com",
18+
description="Make Triton easier - A utility package for OpenAI Triton",
19+
long_description=long_description,
20+
long_description_content_type="text/markdown",
21+
url="https://github.com/redhat-et/TKDK/triton_util",
1422
classifiers=[
15-
'Programming Language :: Python :: 3',
16-
'License :: OSI Approved :: MIT License',
17-
'Operating System :: OS Independent',
23+
"Programming Language :: Python :: 3",
24+
"License :: OSI Approved :: MIT License",
25+
"Operating System :: OS Independent",
1826
],
19-
python_requires='>=3.6',
27+
python_requires=">=3.12",
2028
)

triton_util/tests/conftest.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Pytest configuration file with Triton interpreter mode fixture."""
2+
13
import os
24
import importlib
35

@@ -6,9 +8,14 @@
68
import triton
79
import triton.language as tl
810

9-
@pytest.fixture(scope='class', params=['0', '1']) # Run tests in regular mode (TRITON_INTERPRET=0) and in interpreter mode (TRITON_INTERPRET=1)
11+
@pytest.fixture(scope='class', params=['0', '1'])
1012
def triton_interpret(request):
11-
'''Set env var TRITON_INTERPRET and reload triton'''
13+
"""
14+
Test Triton in both regular mode (TRITON_INTERPRET=0) and interpreter mode (TRITON_INTERPRET=1)
15+
16+
Sets the TRITON_INTERPRET environment variable to either "0" or "1",
17+
reloads the Triton modules, and ensures the env var is cleaned up afterward.
18+
"""
1219
os.environ['TRITON_INTERPRET'] = request.param
1320
importlib.reload(triton)
1421
importlib.reload(tl)

triton_util/tests/test_coding.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import inspect
2-
1+
"""Test suite for triton_util coding utilities."""
2+
# pylint: disable=missing-function-docstring,unused-argument,undefined-variable
33
import pytest
44

55
import torch
@@ -10,7 +10,9 @@
1010
import triton_util as tu
1111

1212
class TestCodingUtils:
13+
"""Unit tests for triton_util offset, mask, load, and store utilities."""
1314
def test_cdiv(self):
15+
"""Test cdiv (ceiling division) function."""
1416
assert tu.cdiv(10, 2)==5
1517
assert tu.cdiv(10, 3)==4
1618

@@ -60,7 +62,7 @@ def test_mask_1d(self, triton_interpret):
6062

6163
@triton.jit
6264
def partial_copy(i_ptr, o_ptr, n):
63-
offs = n*2 + tl.arage(0,2)
65+
offs = n*2 + tl.arange(0,2)
6466
mask = mask_1d(offs, 4)
6567
vals = tl.load(i_ptr + offs, mask)
6668
tl.store(o_ptr + offs, vals, mask)
@@ -115,7 +117,7 @@ def test_load_full_1d(self, triton_interpret):
115117

116118
@triton.jit
117119
def copy(i_ptr, o_ptr):
118-
offs = tl.arage(0,4)
120+
offs = tl.arange(0,4)
119121
mask = offs < 4
120122
vals = tu.load_full_1d(i_ptr, 4)
121123
tl.store(o_ptr + offs, vals, mask)

triton_util/tests/test_debugging.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
Debugging utilities for Triton GPU kernels.
3+
"""
4+
# pylint: disable=multiple-statements,unused-argument,no-value-for-parameter,missing-class-docstring,missing-module-docstring,missing-function-docstring
5+
16
import pytest
27

38
import torch

triton_util/tests/test_loading.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pylint: disable=line-too-long,missing-module-docstring,missing-function-docstring,multiple-statements,too-few-public-methods,unused-import
12
import os
23

34
import pytest
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,36 @@
1+
"""
2+
Triton-Util 🔱 – Utility functions for writing Triton GPU kernels with less friction.
3+
4+
This package provides high-level abstractions and helpers for writing
5+
fast and readable Triton code, reducing the need for repetitive and error-prone
6+
index calculations.
7+
8+
Features:
9+
- **Coding utilities**: Chunking, masking, offset computation, and bulk load/store helpers.
10+
- **Debugging utilities**: Convenient debugging hooks like `print_once` and `breakpoint_if`
11+
that make inspecting kernel behavior simpler and less intrusive.
12+
13+
The utilities are designed to be:
14+
- Minimal and interoperable: fully compatible with native Triton code.
15+
- Expressive: match how you actually think about GPU data access patterns.
16+
- Progressive: use as little or as much of the library as needed.
17+
18+
Example usage:
19+
>>> load_2d(ptr, sz0, sz1, n0, n1, max0, max1, stride0)
20+
21+
instead of:
22+
>>> offs0 = n0 * sz0 + tl.arange(0, sz0)
23+
>>> offs1 = n1 * sz1 + tl.arange(0, sz1)
24+
>>> offs = offs0[:, None] * stride0 + offs1[None, :] * stride1
25+
>>> mask = (offs0[:, None] < max0) & (offs1[None, :] < max1)
26+
>>> tl.load(ptr + offs, mask)
27+
28+
For documentation, examples, and community support, see:
29+
- GitHub: https://github.com/cuda-mode/triton-util
30+
- Discord: https://discord.gg/cudamode (Triton channel)
31+
32+
Author: Umer Hadil
33+
"""
34+
135
from .debugging import *
236
from .coding import *

triton_util/triton_util/coding.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,46 @@
1+
"""Triton utility functions for offset calculation, masking, and load/store operations."""
2+
# pylint: disable=too-many-arguments,too-many-positional-arguments,redefined-builtin,unused-argument
3+
14
import triton
25
import triton.language as tl
36
from triton.language import constexpr as const
47

5-
def cdiv(a,b): return (a + b - 1) // b
8+
def cdiv(a, b):
9+
"""Ceiling division."""
10+
return (a + b - 1) // b
611

712
# # offsets
813

914
@triton.jit
10-
def offset_1d(sz: const, n_prev_chunks=0): return n_prev_chunks * sz + tl.arange(0, sz)
15+
def offset_1d(sz: const, n_prev_chunks=0):
16+
"""Compute 1D offset based on chunk size and previous chunks."""
17+
return n_prev_chunks * sz + tl.arange(0, sz)
1118

1219
@triton.jit
13-
def offset_2d(offs0, offs1, stride0, stride1=1): return tl.expand_dims(offs0, 1)*stride0 + tl.expand_dims(offs1, 0)*stride1
20+
def offset_2d(offs0, offs1, stride0, stride1=1):
21+
"""Compute 2D offset using strides."""
22+
return tl.expand_dims(offs0, 1)*stride0 + tl.expand_dims(offs1, 0)*stride1
1423

1524
# # masks
1625

1726
@triton.jit
18-
def mask_1d(offs, max): return offs < max
27+
def mask_1d(offs, max):
28+
"""Create a 1D mask based on a max bound."""
29+
return offs < max
1930

2031
@triton.jit
21-
def mask_2d(offs0, offs1, max0, max1): return (tl.expand_dims(offs0, 1) < max0) & (tl.expand_dims(offs1, 0) < max1)
32+
def mask_2d(offs0, offs1, max0, max1):
33+
"""Create a 2D mask using upper bounds for each axis."""
34+
return (tl.expand_dims(offs0, 1) < max0) & (tl.expand_dims(offs1, 0) < max1)
2235

2336
# # load
2437

2538
@triton.jit
2639
def load_1d(ptr, sz: const, n, max, stride=1):
27-
'''Chunk 1d vector (defined by ptr) into 1d grid, where each chunk has size sz. Load the nth chunk. Ie, load [n*sz,...,(n+1)*sz-1].'''
40+
"""
41+
Chunk 1d vector (defined by ptr) into 1d grid, where each chunk has size sz.
42+
Load the nth chunk. Ie, load [n*sz,...,(n+1)*sz-1].
43+
"""
2844
offs = offset_1d(sz, n)
2945
mask = mask_1d(offs, max)
3046
return tl.load(ptr + offs, mask)
@@ -38,7 +54,10 @@ def load_full_1d(ptr, sz: const, stride=1):
3854

3955
@triton.jit
4056
def load_2d(ptr, sz0: const, sz1: const, n0, n1, max0, max1, stride0=None, stride1=1):
41-
'''Chunk 2d matrix (defined by ptr) into 2d grid, where each chunk has size (sz0,sz1). Load the (n0,n1)th chunk. Ie, load [n0*sz0,...,(n0+1)*sz0-1] x [n1*sz1,...,(n1+1)*sz1-1].'''
57+
"""
58+
Chunk 2d matrix (defined by ptr) into 2d grid, where each chunk has size (sz0,sz1).
59+
Load the (n0,n1)th chunk. Ie, load [n0*sz0,...,(n0+1)*sz0-1] x [n1*sz1,...,(n1+1)*sz1-1].
60+
"""
4261
stride0 = stride0 or sz1
4362
offs0 = offset_1d(sz0, n0)
4463
offs1 = offset_1d(sz1, n1)
@@ -72,7 +91,10 @@ def store_full_1d(vals, ptr, sz: const, stride=1):
7291

7392
@triton.jit
7493
def store_2d(vals, ptr, sz0: const, sz1: const, n0, n1, max0, max1, stride0=None, stride1=1):
75-
'''Store 2d block into (n0,n1)th chunk of matrix (defined by ptr), where each chunk has size (sz0, sz1)'''
94+
"""
95+
Store 2d block into (n0,n1)th chunk of matrix (defined by ptr), where each chunk has size
96+
(sz0, sz1)
97+
"""
7698
stride0 = stride0 or sz1
7799
offs0 = offset_1d(sz0, n0)
78100
offs1 = offset_1d(sz1, n1)

triton_util/triton_util/debugging.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""Debugging utilities for Triton kernels.
2+
3+
Includes conditional breakpoints and printing based on thread identifiers,
4+
plus tensor readiness checks for CUDA or interpreted environments.
5+
"""
6+
# pylint: disable=multiple-statements,line-too-long,import-outside-toplevel,eval-used,fixme,unused-variable
17
import os
28
import triton
39
import triton.language as tl
@@ -42,12 +48,17 @@ def print_if(*txt, conds):
4248
if test_pid_conds(conds): print(*txt)
4349

4450
@triton.jit
45-
def breakpoint_once(): breakpoint_if('=0,=0,=0')
51+
def breakpoint_once():
52+
"""Trigger a breakpoint."""
53+
breakpoint_if('=0,=0,=0')
4654

4755
@triton.jit
48-
def print_once(*txt): print_if(*txt,conds='=0,=0,=0')
56+
def print_once(*txt):
57+
"""Print a message."""
58+
print_if(*txt,conds='=0,=0,=0')
4959

5060
def assert_tensors_gpu_ready(*tensors):
61+
"""Assert that each tensor is contiguous and on the GPU (unless TRITON_INTERPRET=1)."""
5162
for t in tensors:
5263
assert t.is_contiguous(), "A tensor is not contiguous"
5364
if not os.environ.get('TRITON_INTERPRET') == '1': assert t.is_cuda, "A tensor is not on cuda"

0 commit comments

Comments
 (0)