Skip to content

Commit 98521c7

Browse files
fmassapeterjc123
andauthored
Make read_file and write_file accept unicode strings on Windows (#2949) (#3009)
* Make read_file accept unicode strings on Windows * More fixes * Remove definitions from source files * Move string definitions to header * Add checks * Fix comments * Update macro * Fix comments * Fix lint * include windows header * Change func signature in header * Use from_blob * Fix fread calls * Fix clang format * Fix missing return * Avoid copy Co-authored-by: peterjc123 <[email protected]>
1 parent 45f960c commit 98521c7

File tree

3 files changed

+78
-15
lines changed

3 files changed

+78
-15
lines changed

test/test_image.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,18 @@ def test_read_file(self):
221221
RuntimeError, "No such file or directory: 'tst'"):
222222
read_file('tst')
223223

224+
def test_read_file_non_ascii(self):
225+
with get_tmp_dir() as d:
226+
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
227+
fpath = os.path.join(d, fname)
228+
with open(fpath, 'wb') as f:
229+
f.write(content)
230+
231+
data = read_file(fpath)
232+
expected = torch.tensor(list(content), dtype=torch.uint8)
233+
self.assertTrue(data.equal(expected))
234+
os.unlink(fpath)
235+
224236
def test_write_file(self):
225237
with get_tmp_dir() as d:
226238
fname, content = 'test1.bin', b'TorchVision\211\n'
@@ -233,6 +245,18 @@ def test_write_file(self):
233245
self.assertEqual(content, saved_content)
234246
os.unlink(fpath)
235247

248+
def test_write_file_non_ascii(self):
249+
with get_tmp_dir() as d:
250+
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
251+
fpath = os.path.join(d, fname)
252+
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
253+
write_file(fpath, content_tensor)
254+
255+
with open(fpath, 'rb') as f:
256+
saved_content = f.read()
257+
self.assertEqual(content, saved_content)
258+
os.unlink(fpath)
259+
236260

237261
if __name__ == '__main__':
238262
unittest.main()

torchvision/csrc/cpu/image/read_write_file_cpu.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,40 @@
11
#include "read_write_file_cpu.h"
22

3-
// According to
4-
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
5-
// we should use _stat64 for 64-bit file size on Windows.
63
#ifdef _WIN32
7-
#define VISION_STAT _stat64
8-
#else
9-
#define VISION_STAT stat
4+
#define WIN32_LEAN_AND_MEAN
5+
#include <Windows.h>
6+
7+
std::wstring utf8_decode(const std::string& str) {
8+
if (str.empty()) {
9+
return std::wstring();
10+
}
11+
int size_needed = MultiByteToWideChar(
12+
CP_UTF8, 0, str.c_str(), static_cast<int>(str.size()), NULL, 0);
13+
TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode");
14+
std::wstring wstrTo(size_needed, 0);
15+
MultiByteToWideChar(
16+
CP_UTF8,
17+
0,
18+
str.c_str(),
19+
static_cast<int>(str.size()),
20+
&wstrTo[0],
21+
size_needed);
22+
return wstrTo;
23+
}
1024
#endif
1125

12-
torch::Tensor read_file(std::string filename) {
13-
struct VISION_STAT stat_buf;
14-
int rc = VISION_STAT(filename.c_str(), &stat_buf);
26+
torch::Tensor read_file(const std::string& filename) {
27+
#ifdef _WIN32
28+
// According to
29+
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
30+
// we should use struct __stat64 and _wstat64 for 64-bit file size on Windows.
31+
struct __stat64 stat_buf;
32+
auto fileW = utf8_decode(filename);
33+
int rc = _wstat64(fileW.c_str(), &stat_buf);
34+
#else
35+
struct stat stat_buf;
36+
int rc = stat(filename.c_str(), &stat_buf);
37+
#endif
1538
// errno is a variable defined in errno.h
1639
TORCH_CHECK(
1740
rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'");
@@ -21,9 +44,20 @@ torch::Tensor read_file(std::string filename) {
2144
TORCH_CHECK(size > 0, "Expected a non empty file");
2245

2346
#ifdef _WIN32
24-
auto data =
25-
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8)
26-
.clone();
47+
// TODO: Once torch::from_file handles UTF-8 paths correctly, we should move
48+
// back to use the following implementation since it uses file mapping.
49+
// auto data =
50+
// torch::from_file(filename, /*shared=*/false, /*size=*/size,
51+
// torch::kU8).clone()
52+
FILE* infile = _wfopen(fileW.c_str(), L"rb");
53+
54+
TORCH_CHECK(infile != nullptr, "Error opening input file");
55+
56+
auto data = torch::empty({size}, torch::kU8);
57+
auto dataBytes = data.data_ptr<uint8_t>();
58+
59+
fread(dataBytes, sizeof(uint8_t), size, infile);
60+
fclose(infile);
2761
#else
2862
auto data =
2963
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8);
@@ -32,7 +66,7 @@ torch::Tensor read_file(std::string filename) {
3266
return data;
3367
}
3468

35-
void write_file(std::string filename, torch::Tensor& data) {
69+
void write_file(const std::string& filename, torch::Tensor& data) {
3670
// Check that the input tensor is on CPU
3771
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
3872

@@ -44,7 +78,12 @@ void write_file(std::string filename, torch::Tensor& data) {
4478

4579
auto fileBytes = data.data_ptr<uint8_t>();
4680
auto fileCStr = filename.c_str();
81+
#ifdef _WIN32
82+
auto fileW = utf8_decode(filename);
83+
FILE* outfile = _wfopen(fileW.c_str(), L"wb");
84+
#else
4785
FILE* outfile = fopen(fileCStr, "wb");
86+
#endif
4887

4988
TORCH_CHECK(outfile != NULL, "Error opening output file");
5089

torchvision/csrc/cpu/image/read_write_file_cpu.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
#include <sys/stat.h>
55
#include <torch/torch.h>
66

7-
C10_EXPORT torch::Tensor read_file(std::string filename);
7+
C10_EXPORT torch::Tensor read_file(const std::string& filename);
88

9-
C10_EXPORT void write_file(std::string filename, torch::Tensor& data);
9+
C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data);

0 commit comments

Comments
 (0)