Commit 4d8090c

izmttk and malfet authored and committed
Avoid file encoding issues when loading cpp extensions (pytorch#138565)
I've found that when using `torch.utils.cpp_extension.load` on my Windows system, decoding errors occur when my .cpp/.cu files contain certain non-English characters.

`test.py`:

```py
from torch.utils.cpp_extension import load

my_lib = load(name='my_cuda_kernel',
              sources=['my_cuda_kernel.cu'],
              extra_cuda_cflags=['-O2', '-std=c++17'])
# ......
```

`my_cuda_kernel.cu`:

```cpp
#include <torch/types.h>
#include <torch/extension.h>
// 向量化 <------ some Chinese characters
// ......
```

The following error is reported:

```
Traceback (most recent call last):
  File "E:\test\test.py", line 8, in <module>
    my_lib = load(
             ^^^^^
  File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\cpp_extension.py", line 1314, in load
    return _jit_compile(
           ^^^^^^^^^^^^^
  File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\cpp_extension.py", line 1680, in _jit_compile
    version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\_cpp_extension_versioner.py", line 46, in bump_version_if_changed
    hash_value = hash_source_files(hash_value, source_files)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\_cpp_extension_versioner.py", line 17, in hash_source_files
    hash_value = update_hash(hash_value, file.read())
                                         ^^^^^^^^^^^
UnicodeDecodeError: 'gbk' codec can't decode byte 0x96 in position 141: illegal multibyte sequence
```

The root cause is that the default encoding used by Python's `open()` in text mode is platform-dependent, so decoding fails when a file contains bytes that are not valid in that default encoding.

PyTorch uses the file contents to generate a hash string:

https://github.com/pytorch/pytorch/blob/60c14330411de8f52bfb28d6406f1822edaad944/torch/utils/_cpp_extension_versioner.py#L16-L17

On my Windows system the default encoding is `gbk`, but all of my cpp files are saved as `utf-8`.

I think there is a simple solution to this problem: read the file in binary mode, which avoids decoding altogether and therefore any encoding-related issue. It works perfectly on my computer.

```diff
-        with open(filename) as file:
+        with open(filename, 'rb') as file:
             hash_value = update_hash(hash_value, file.read())
```

Pull Request resolved: pytorch#138565
Approved by: https://github.com/malfet, https://github.com/janeyx99

Co-authored-by: Nikita Shulga <[email protected]>
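
The failure mode described in the commit message can be reproduced without PyTorch or a Windows machine. The following is a minimal sketch, not part of the commit: it assumes that passing `encoding="gbk"` explicitly is a fair stand-in for a system whose default text encoding is `gbk`, and the file name `kernel.cu` is purely illustrative.

```python
import os
import tempfile

# UTF-8 source with non-ASCII characters, mirroring the .cu file in the report.
source_text = "// 向量化\nint f() { return 123; }\n"

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "kernel.cu")  # illustrative file name
    # newline="" keeps "\n" as-is so the bytes on disk are the same on every OS.
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write(source_text)

    # Text mode with a mismatched codec. The explicit "gbk" is an assumption
    # standing in for a platform whose default encoding is gbk.
    try:
        with open(path, encoding="gbk") as f:
            f.read()
        text_mode_failed = False
    except UnicodeDecodeError:
        text_mode_failed = True

    # Binary mode never decodes, so it cannot raise UnicodeDecodeError.
    with open(path, "rb") as f:
        raw = f.read()

print("text-mode read failed:", text_mode_failed)
print("bytes read in binary mode:", len(raw))
```

Because the hash only needs a stable fingerprint of the file, reading raw bytes loses nothing: the text is never interpreted, only compared.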
1 parent 1ec76dd commit 4d8090c

2 files changed: +36 -1 lines changed

test/test_cpp_extensions_jit.py

Lines changed: 35 additions & 0 deletions
@@ -1,6 +1,7 @@
 # Owner(s): ["module: cpp-extensions"]
 
 import glob
+import locale
 import os
 import re
 import shutil
@@ -529,6 +530,40 @@ def compile(code):
         module = compile("int f() { return 789; }")
         self.assertEqual(module.f(), 789)
 
+    @unittest.skipIf(
+        "utf" not in locale.getlocale()[1].lower(), "Only test in UTF-8 locale"
+    )
+    def test_load_with_non_platform_default_encoding(self):
+        # Assume the code is saved in UTF-8, but the locale is set to a different encoding.
+        # You might encounter decoding errors in ExtensionVersioner.
+        # But this case is quite hard to cover because CI environments may not in non-latin locale.
+        # So the following code just test source file in gbk and locale in utf-8.
+
+        cpp_source = """
+        #include <torch/extension.h>
+
+        // Non-latin1 character test: 字符.
+        // It will cause utf-8 decoding error.
+
+        int f() { return 123; }
+        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+            m.def("f", &f, "f");
+        }
+        """
+
+        build_dir = tempfile.mkdtemp()
+        src_path = os.path.join(build_dir, "main.cpp")
+
+        with open(src_path, encoding="gbk", mode="w") as f:
+            f.write(cpp_source)
+
+        module = torch.utils.cpp_extension.load(
+            name="non_default_encoding",
+            sources=src_path,
+            verbose=True,
+        )
+        self.assertEqual(module.f(), 123)
+
     def test_cpp_frontend_module_has_same_output_as_python(self, dtype=torch.double):
         extension = torch.utils.cpp_extension.load(
             name="cpp_frontend_extension",
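
The skip condition in the new test depends on which encoding the process treats as its default. As a hedged aside, not part of the commit: `locale.getlocale()` may return `(None, None)` in some environments, so a check that tolerates a missing value can be sketched with `locale.getpreferredencoding(False)`. The helper names below are hypothetical.

```python
import locale


def default_text_encoding():
    """Hypothetical helper: the locale's preferred text encoding.

    Passing False avoids calling setlocale() as a side effect.
    """
    return locale.getpreferredencoding(False)


def default_encoding_is_utf():
    """Hypothetical helper: True if the preferred encoding looks like a UTF codec."""
    encoding = default_text_encoding() or ""  # tolerate None or empty string
    return "utf" in encoding.lower()


print("preferred encoding:", default_text_encoding())
print("looks like UTF:", default_encoding_is_utf())
```

A check like this could back a `skipIf` decorator without raising when the locale is unset, though whether that matters depends on the CI environments involved.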

torch/utils/_cpp_extension_versioner.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def update_hash(seed, value):
 
 def hash_source_files(hash_value, source_files):
     for filename in source_files:
-        with open(filename) as file:
+        with open(filename, 'rb') as file:
             hash_value = update_hash(hash_value, file.read())
     return hash_value
 
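
To illustrate why binary-mode hashing is insensitive to the platform encoding, here is a minimal sketch. It is not PyTorch's implementation: `hashlib.sha256` is swapped in for the `update_hash` function, and `hash_source_bytes` and `write_and_hash` are hypothetical helpers introduced only for this example.

```python
import hashlib
import os
import tempfile


def hash_source_bytes(paths):
    """Hypothetical helper: SHA-256 over the raw bytes of each file, in order."""
    digest = hashlib.sha256()
    for path in paths:
        # Binary mode: no codec is involved, so any byte sequence is readable.
        with open(path, "rb") as f:
            digest.update(f.read())
    return digest.hexdigest()


def write_and_hash(text, encoding):
    """Hypothetical helper: write `text` in `encoding` to a temp file, then hash it."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "main.cpp")
        with open(path, "w", encoding=encoding, newline="") as f:
            f.write(text)
        return hash_source_bytes([path])


source = "// 字符\nint f() { return 123; }\n"
gbk_digest = write_and_hash(source, "gbk")
utf8_digest = write_and_hash(source, "utf-8")

# Both files hash without error regardless of the process locale. The digests
# differ because the bytes on disk differ between the two encodings.
print(gbk_digest != utf8_digest)
```

One consequence of hashing bytes rather than decoded text is that re-saving a source file in a different encoding counts as a change, which is the conservative behaviour for a rebuild check.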
