
Commit db9b214

[CI][Enhancement] Add Pytest in unit test (#154)
1. Refactor the current testing framework to use pytest.
2. Remove profile testing from ci.yml, retaining only unit tests.
3. Fix the precision thresholds for some test cases.
4. Other minor bug fixes.
5. The '@py_assert1' issue in `test_deepseek_nsa_cmp_fwd.py` will be fixed in #171.
1 parent 98d5659 commit db9b214
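
Every refactored test file below follows the same pattern: the argparse-based entry point is replaced by `@pytest.mark.parametrize`, and the `__main__` block shrinks to a `pytest.main` call. A minimal sketch of that pattern (the test name and parameter values here are illustrative, not taken from the commit):

import pytest
import torch


@pytest.mark.parametrize(
    "batch, heads, dtype",
    [
        (1, 32, torch.float16),  # illustrative case; the real files list their own shapes
    ],
)
def test_example(batch: int, heads: int, dtype: torch.dtype):
    # Each tuple in the list above is collected as a separate test case.
    assert batch > 0 and heads > 0
    assert dtype in (torch.float16, torch.bfloat16)


if __name__ == "__main__":
    # Running the file directly still works; it simply delegates to pytest.
    pytest.main([__file__, "-vvs"])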


51 files changed (+452, -1022 lines)

.github/workflows/ci.yml

Lines changed: 4 additions & 29 deletions

@@ -71,7 +71,8 @@ jobs:
           source "${{ runner.tool_cache }}/${VENV_DIR}/bin/activate"
           export PYTHONPATH="$(pwd):$PYTHONPATH"
           echo "PYTHONPATH=$PYTHONPATH"
-          bash tests/ci_test.sh tileops_test_release.log
+          set -o pipefail
+          python -m pytest -q tests | tee tileops_test_release.log
         shell: bash

       - name: Cleanup venv
@@ -145,7 +146,8 @@ jobs:
           source "${{ runner.tool_cache }}/${VENV_DIR}/bin/activate"
           export PYTHONPATH="$(pwd):$PYTHONPATH"
           echo "PYTHONPATH=$PYTHONPATH"
-          bash tests/ci_test.sh tileops_test_nightly.log
+          set -o pipefail
+          python -m pytest -q tests | tee tileops_test_nightly.log
         shell: bash

       - name: Cleanup venv
@@ -165,30 +167,3 @@ jobs:
           name: tileops_test_nightly.log
           path: tileops_test_nightly.log
           retention-days: 7 # Equivalent to expire_in: 1 week
-
-  tileops_profile_release:
-    # needs: [pre-commit, tileops_test_release]
-    needs: [tileops_test_release]
-    runs-on: [self-hosted, profile]
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v3
-        with:
-          fetch-depth: 0 # Equivalent to GIT_STRATEGY: fetch
-
-      - name: Setup & Run tests
-        run: |
-          source ~/miniconda3/etc/profile.d/conda.sh
-          conda activate tileops-release
-          export PYTHONPATH="$(pwd):$PYTHONPATH"
-          echo "PYTHONPATH=$PYTHONPATH"
-          bash benchmarks/profile_run.sh --log profile_out/tileops_profile_release.log
-        shell: bash
-
-      - name: Upload profile_out artifacts
-        uses: actions/upload-artifact@v4
-        if: always()
-        with:
-          name: profile_out
-          path: profile_out/
-          retention-days: 7
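
The `set -o pipefail` addition matters here: without it, the exit status of `python -m pytest -q tests | tee ...` is that of `tee`, which virtually always succeeds, so failing tests would not fail the CI step. With `pipefail`, the step fails when pytest fails while `tee` still captures the log for artifact upload.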

benchmarks/flash_attn/mha.py

Lines changed: 0 additions & 1 deletion

@@ -148,7 +148,6 @@ def ref_program(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torc
             q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
         output = output_bhsd.transpose(1, 2).contiguous()

-        # from IPython import embed; embed()
         output.backward(grad_output)
         return q.grad, k.grad, v.grad

tests/__init__.py

Whitespace-only changes.

tests/ci_test.sh

Lines changed: 0 additions & 78 deletions
This file was deleted.

tests/conftest.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+import pytest
+import torch
+
+
+@pytest.fixture(autouse=True)
+def setup() -> None:
+    torch.manual_seed(1235)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(1235)
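
Because the fixture is marked `autouse=True`, every test in the suite starts from the same RNG state without requesting `setup` explicitly, which is what makes fixed precision thresholds reproducible across runs. A hypothetical test illustrating the guarantee (not part of the commit):

import torch


def test_rng_is_seeded():
    # The autouse `setup` fixture has already run torch.manual_seed(1235),
    # so the first draw below is identical on every run of the suite.
    first = torch.randn(3)
    torch.manual_seed(1235)  # rewind to the fixture's seed
    again = torch.randn(3)
    assert torch.equal(first, again)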

tests/functions/test_deepseek_dsa_decode_func.py

Lines changed: 12 additions & 35 deletions

@@ -1,24 +1,20 @@
-import argparse
+import pytest
+import torch

 from benchmarks import DeepSeekSparseAttentionDecodeBenchmark
 from top.functions import DeepSeekSparseAttentionDecodeWithKVCacheFunc
 from top.layers import DeepSeekSparseAttentionDecodeLayer
-from top.utils import str2dtype


-def test_sparse_mla_decode(batch,
-                           heads,
-                           seq_len_q,
-                           seq_len_kv,
-                           dim,
-                           dim_tail,
-                           topk,
-                           stride_kv,
-                           group_kv,
-                           q_start_index_s,
-                           sm_scale,
-                           dtype,
-                           tune=False):
+@pytest.mark.parametrize(
+    "batch, heads, seq_len_q, seq_len_kv, dim, dim_tail, topk, stride_kv, group_kv, q_start_index_s, sm_scale, dtype, tune",
+    [
+        (1, 128, 1024, 2048, 512, 64, 2048, 1, 1, 1024, None, torch.float16, False),
+    ],
+)
+def test_sparse_mla_decode(batch: int, heads: int, seq_len_q: int, seq_len_kv: int, dim: int,
+                           dim_tail: int, topk: int, stride_kv: int, group_kv: int,
+                           q_start_index_s: int, sm_scale: float, dtype: torch.dtype, tune: bool):
     fn = DeepSeekSparseAttentionDecodeWithKVCacheFunc(
         batch,
         heads,
@@ -81,23 +77,4 @@ def test_sparse_mla_decode(batch,


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--seq_len', type=int, default=1024, help='sequence length')
-    parser.add_argument('--seq_len_kv', type=int, default=2048, help='key/value sequence length')
-    parser.add_argument('--heads', type=int, default=128, help='num heads')
-    parser.add_argument('--dim', type=int, default=512, help='head dim')
-    parser.add_argument('--dim_tail', type=int, default=64, help='tail dim')
-    parser.add_argument('--topk', type=int, default=2048, help='topk')
-    parser.add_argument('--stride_kv', type=int, default=1, help='stride_kv')
-    parser.add_argument('--group_kv', type=int, default=1, help='group_kv')
-    parser.add_argument('--sm_scale', type=float, default=None, help='softmax scaling factor')
-    parser.add_argument('--q_start_index_s', type=int, default=1024, help='query start index')
-    parser.add_argument(
-        '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
-    parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
-    args = parser.parse_args()
-
-    test_sparse_mla_decode(args.batch, args.heads, args.seq_len, args.seq_len_kv, args.dim,
-                           args.dim_tail, args.topk, args.stride_kv, args.group_kv,
-                           args.q_start_index_s, args.sm_scale, str2dtype[args.dtype], args.tune)
+    pytest.main([__file__, "-vvs"])
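
The `__main__` hook kept at the bottom means `python tests/functions/test_deepseek_dsa_decode_func.py` still runs this file standalone, now through pytest with verbose output (`-vvs`); new shapes or dtypes are covered by appending tuples to the `parametrize` list rather than by passing CLI flags.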

tests/functions/test_deepseek_mla_decode_func.py

Lines changed: 12 additions & 17 deletions

@@ -1,12 +1,20 @@
-import argparse
+import pytest
+
+import torch

 from benchmarks import MultiHeadLatentAttentionDecodeBenchmark
 from top.functions import MultiHeadLatentAttentionDecodeWithKVCacheFunc, mla_decode_with_kvcache
 from top.layers import MultiHeadLatentAttentionDecodeLayer
-from top.utils import str2dtype


-def test_mla_decode_fn(batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype):
+@pytest.mark.parametrize(
+    "batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype",
+    [
+        (32, 1, 8192, 128, 512, 64, torch.float16),
+    ],
+)
+def test_mla_decode_fn(batch: int, kv_head_num: int, seq_len_kv: int, heads: int, dim: int,
+                       pe_dim: int, dtype: torch.dtype):

     mla_layer = MultiHeadLatentAttentionDecodeLayer(batch, heads, kv_head_num, seq_len_kv, dim,
                                                     pe_dim, dtype)
@@ -43,17 +51,4 @@ def test_mla_decode_fn(batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=32, help='batch size')
-    parser.add_argument('--kv_head_num', type=int, default=1, help='number of key/value heads')
-    parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
-    parser.add_argument('--heads', type=int, default=128, help='num heads')
-    parser.add_argument('--dim', type=int, default=512, help='head dim')
-    parser.add_argument('--pe_dim', type=int, default=64, help='positional encoding dim')
-    parser.add_argument(
-        '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
-    parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
-    args = parser.parse_args()
-
-    test_mla_decode_fn(args.batch, args.kv_head_num, args.seq_len_kv, args.heads, args.dim,
-                       args.pe_dim, str2dtype[args.dtype])
+    pytest.main([__file__, "-vvs"])
(file name hidden in the large-commit view)

Lines changed: 10 additions & 63 deletions

@@ -1,11 +1,18 @@
-import argparse
+import pytest

 from benchmarks.deepseek_mla import Fp8LightingIndexerBenchmark
 from top.functions import Fp8LightingIndexerFunc
 from top.layers import Fp8LightingIndexerDecodeLayer


-def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logits, config):
+@pytest.mark.parametrize(
+    "seq_len, heads, index_dim, seq_len_kv, clean_logits, config",
+    [
+        (4096, 32, 64, 8192, True, None),
+    ],
+)
+def test_fp8_lighting_indexer(seq_len: int, heads: int, index_dim: int, seq_len_kv: int,
+                              clean_logits: bool, config):
     fn = Fp8LightingIndexerFunc(seq_len, heads, index_dim, seq_len_kv, clean_logits, config)
     layer = Fp8LightingIndexerDecodeLayer(seq_len, heads, index_dim, seq_len_kv, clean_logits,
                                           config)
@@ -32,64 +39,4 @@ def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logit


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--heads', type=int, default=32, help='number of heads')
-    parser.add_argument('--index_dim', type=int, default=64, help='index dim')
-    parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
-    parser.add_argument(
-        '--clean_logits',
-        action=argparse.BooleanOptionalAction,
-        default=True,
-        help='whether to clean logits outside the valid range')
-    parser.add_argument('--config', type=str, default=None, help='positional encoding dim')
-    parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
-    args = parser.parse_args()
-
-    test_fp8_lighting_indexer(args.seq_len, args.heads, args.index_dim, args.seq_len_kv,
-                              args.clean_logits, args.config)
-
-
-def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logits, config):
-    fn = Fp8LightingIndexerFunc(seq_len, heads, index_dim, seq_len_kv, clean_logits, config)
-    layer = Fp8LightingIndexerDecodeLayer(seq_len, heads, index_dim, seq_len_kv, clean_logits,
-                                          config)
-    benchmark = Fp8LightingIndexerBenchmark(seq_len, heads, index_dim, seq_len_kv, clean_logits,
-                                            config)
-
-    inputs = benchmark.gen_inputs()
-
-    try:
-        print("Testing indexer_fn...")
-        benchmark.check_fn(fn, *inputs, grad=False)
-        print("✅ indexer_fn test passed")
-    except Exception as e:
-        print(f"❌ indexer_fn test failed: {e}")
-        raise
-
-    try:
-        print("Testing indexer_layer...")
-        benchmark.check_fn(layer, *inputs, grad=False)
-        print("✅ indexer_layer test passed")
-    except Exception as e:
-        print(f"❌ indexer_layer test failed: {e}")
-        raise
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--heads', type=int, default=32, help='number of heads')
-    parser.add_argument('--index_dim', type=int, default=64, help='index dim')
-    parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
-    parser.add_argument(
-        '--clean_logits',
-        action=argparse.BooleanOptionalAction,
-        default=True,
-        help='whether to clean logits outside the valid range')
-    parser.add_argument('--config', type=str, default=None, help='positional encoding dim')
-    parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
-    args = parser.parse_args()
-
-    test_fp8_lighting_indexer(args.seq_len, args.heads, args.index_dim, args.seq_len_kv,
-                              args.clean_logits, args.config)
+    pytest.main([__file__, "-vvs"])
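
Note that the deletions above also remove a second, verbatim copy of `test_fp8_lighting_indexer` and its `__main__` block that had accumulated in this file, presumably one of the "other minor bug fixes" mentioned in the commit message.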
(file name hidden in the large-commit view)

Lines changed: 2 additions & 5 deletions

@@ -15,7 +15,7 @@
         (16384, 32, torch.float32, False),
     ],
 )
-def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune=False):
+def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune):
     fn = Fp8QuantFunc(seq_len_kv, index_dim, in_dtype, tune=tune)
     layer = Fp8QuantLayer(seq_len_kv, index_dim, in_dtype, tune=tune)
     benchmark = Fp8QuantBenchmark(seq_len_kv, index_dim, in_dtype)
@@ -39,7 +39,4 @@ def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune=False):


 if __name__ == "__main__":
-    test_fp8_quant(8192, 64, torch.float16, False)
-    test_fp8_quant(8192, 64, torch.bfloat16, False)
-    test_fp8_quant(4096, 128, torch.float32, False)
-    test_fp8_quant(16384, 32, torch.float32, False)
+    pytest.main([__file__, "-vvs"])
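
This file was already parametrized, so the change here is smaller: the `tune=False` default is dropped (the value now always comes from the `parametrize` list), and the four manual invocations in `__main__`, which mirrored the parametrized cases, give way to the standard `pytest.main` hook.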
(file name hidden in the large-commit view)

Lines changed: 11 additions & 16 deletions

@@ -1,11 +1,18 @@
-import argparse
+import pytest
+import torch

 from benchmarks import GroupQueryAttentionDecodeBenchmark
 from top.functions import GroupQueryAttentionDecodeWithKVCacheFunc, gqa_decode_with_kvcache
-from top.utils import str2dtype


-def test_gqa_decode_fn(batch, heads, seq_len_kv, dim, groups, dtype):
+@pytest.mark.parametrize(
+    "batch, heads, seq_len_kv, dim, groups, dtype",
+    [
+        (1, 32, 8192, 128, 1, torch.float16),
+    ],
+)
+def test_gqa_decode_fn(batch: int, heads: int, seq_len_kv: int, dim: int, groups: int,
+                       dtype: torch.dtype):
     benchmark = GroupQueryAttentionDecodeBenchmark(batch, heads, groups, seq_len_kv, dim, dtype)

     inputs = benchmark.gen_inputs()
@@ -19,16 +26,4 @@ def test_gqa_decode_fn(batch, heads, seq_len_kv, dim, groups, dtype):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--groups', type=int, default=1, help='num groups')
-    parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
-    parser.add_argument('--heads', type=int, default=32, help='num heads')
-    parser.add_argument('--dim', type=int, default=128, help='head dim')
-    parser.add_argument(
-        '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
-    parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
-    args = parser.parse_args()
-
-    test_gqa_decode_fn(args.batch, args.heads, args.seq_len_kv, args.dim, args.groups,
-                       str2dtype[args.dtype])
+    pytest.main([__file__, "-vvs"])
