Skip to content

Commit 111ffc4

Browse files
HunterTracerzenghongtairusty1s
authored
Use static constexpr for VS compilation (#342)
* static constexpr should be used in cuda/reducer.cuh, otherwise compilation fails in VS * modify torch/include/ATen/Parallel.h in testing.yml * update * update * update * typo * typo * typo * remove prints Co-authored-by: zenghongtai <[email protected]> Co-authored-by: rusty1s <[email protected]>
1 parent 8e6635b commit 111ffc4

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

.github/workflows/building.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ jobs:
5959
python -c "import torch; print('PyTorch:', torch.__version__)"
6060
python -c "import torch; print('CUDA:', torch.version.cuda)"
6161
62+
- name: Patch PyTorch static constexpr on Windows
63+
if: ${{ runner.os == 'Windows' }}
64+
run: |
65+
Torch_DIR=`python -c 'import os; import torch; print(os.path.dirname(torch.__file__))'`
66+
sed -i '31,38c\
67+
TORCH_API void lazy_init_num_threads();' ${Torch_DIR}/include/ATen/Parallel.h
68+
shell: bash
69+
6270
- name: Set version
6371
if: ${{ runner.os != 'macOS' }}
6472
run: |

.github/workflows/testing.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ jobs:
2929
run: |
3030
pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
3131
32+
- name: Patch PyTorch static constexpr on Windows
33+
if: ${{ runner.os == 'Windows' }}
34+
run: |
35+
Torch_DIR=`python -c 'import os; import torch; print(os.path.dirname(torch.__file__))'`
36+
sed -i '31,38c\
37+
TORCH_API void lazy_init_num_threads();' ${Torch_DIR}/include/ATen/Parallel.h
38+
shell: bash
39+
3240
- name: Install main package
3341
run: |
3442
pip install -e .[test]

csrc/cuda/reducer.cuh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,27 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
1616
[&] { \
1717
switch (reduce2REDUCE.at(reduce)) { \
1818
case SUM: { \
19-
const ReductionType REDUCE = SUM; \
19+
static constexpr ReductionType REDUCE = SUM; \
2020
return __VA_ARGS__(); \
2121
} \
2222
case MEAN: { \
23-
const ReductionType REDUCE = MEAN; \
23+
static constexpr ReductionType REDUCE = MEAN; \
2424
return __VA_ARGS__(); \
2525
} \
2626
case MUL: { \
27-
const ReductionType REDUCE = MUL; \
27+
static constexpr ReductionType REDUCE = MUL; \
2828
return __VA_ARGS__(); \
2929
} \
3030
case DIV: { \
31-
const ReductionType REDUCE = DIV; \
31+
static constexpr ReductionType REDUCE = DIV; \
3232
return __VA_ARGS__(); \
3333
} \
3434
case MIN: { \
35-
const ReductionType REDUCE = MIN; \
35+
static constexpr ReductionType REDUCE = MIN; \
3636
return __VA_ARGS__(); \
3737
} \
3838
case MAX: { \
39-
const ReductionType REDUCE = MAX; \
39+
static constexpr ReductionType REDUCE = MAX; \
4040
return __VA_ARGS__(); \
4141
} \
4242
} \

0 commit comments

Comments
 (0)