1
1
#include " read_write_file_cpu.h"
2
2
3
- // According to
4
- // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
5
- // we should use _stat64 for 64-bit file size on Windows.
6
3
#ifdef _WIN32
7
- #define VISION_STAT _stat64
8
- #else
9
- #define VISION_STAT stat
4
+ #define WIN32_LEAN_AND_MEAN
5
+ #include < Windows.h>
6
+
7
+ std::wstring utf8_decode (const std::string& str) {
8
+ if (str.empty ()) {
9
+ return std::wstring ();
10
+ }
11
+ int size_needed = MultiByteToWideChar (
12
+ CP_UTF8, 0 , str.c_str (), static_cast <int >(str.size ()), NULL , 0 );
13
+ TORCH_CHECK (size_needed > 0 , " Error converting the content to Unicode" );
14
+ std::wstring wstrTo (size_needed, 0 );
15
+ MultiByteToWideChar (
16
+ CP_UTF8,
17
+ 0 ,
18
+ str.c_str (),
19
+ static_cast <int >(str.size ()),
20
+ &wstrTo[0 ],
21
+ size_needed);
22
+ return wstrTo;
23
+ }
10
24
#endif
11
25
12
- torch::Tensor read_file (std::string filename) {
13
- struct VISION_STAT stat_buf;
14
- int rc = VISION_STAT (filename.c_str (), &stat_buf);
26
+ torch::Tensor read_file (const std::string& filename) {
27
+ #ifdef _WIN32
28
+ // According to
29
+ // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
30
+ // we should use struct __stat64 and _wstat64 for 64-bit file size on Windows.
31
+ struct __stat64 stat_buf;
32
+ auto fileW = utf8_decode (filename);
33
+ int rc = _wstat64 (fileW.c_str (), &stat_buf);
34
+ #else
35
+ struct stat stat_buf;
36
+ int rc = stat (filename.c_str (), &stat_buf);
37
+ #endif
15
38
// errno is a variable defined in errno.h
16
39
TORCH_CHECK (
17
40
rc == 0 , " [Errno " , errno, " ] " , strerror (errno), " : '" , filename, " '" );
@@ -21,9 +44,20 @@ torch::Tensor read_file(std::string filename) {
21
44
TORCH_CHECK (size > 0 , " Expected a non empty file" );
22
45
23
46
#ifdef _WIN32
24
- auto data =
25
- torch::from_file (filename, /* shared=*/ false , /* size=*/ size, torch::kU8 )
26
- .clone ();
47
+ // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move
48
+ // back to use the following implementation since it uses file mapping.
49
+ // auto data =
50
+ // torch::from_file(filename, /*shared=*/false, /*size=*/size,
51
+ // torch::kU8).clone()
52
+ FILE* infile = _wfopen (fileW.c_str (), L" rb" );
53
+
54
+ TORCH_CHECK (infile != nullptr , " Error opening input file" );
55
+
56
+ auto data = torch::empty ({size}, torch::kU8 );
57
+ auto dataBytes = data.data_ptr <uint8_t >();
58
+
59
+ fread (dataBytes, sizeof (uint8_t ), size, infile);
60
+ fclose (infile);
27
61
#else
28
62
auto data =
29
63
torch::from_file (filename, /* shared=*/ false , /* size=*/ size, torch::kU8 );
@@ -32,7 +66,7 @@ torch::Tensor read_file(std::string filename) {
32
66
return data;
33
67
}
34
68
35
- void write_file (std::string filename, torch::Tensor& data) {
69
+ void write_file (const std::string& filename, torch::Tensor& data) {
36
70
// Check that the input tensor is on CPU
37
71
TORCH_CHECK (data.device () == torch::kCPU , " Input tensor should be on CPU" );
38
72
@@ -44,7 +78,12 @@ void write_file(std::string filename, torch::Tensor& data) {
44
78
45
79
auto fileBytes = data.data_ptr <uint8_t >();
46
80
auto fileCStr = filename.c_str ();
81
+ #ifdef _WIN32
82
+ auto fileW = utf8_decode (filename);
83
+ FILE* outfile = _wfopen (fileW.c_str (), L" wb" );
84
+ #else
47
85
FILE* outfile = fopen (fileCStr, " wb" );
86
+ #endif
48
87
49
88
TORCH_CHECK (outfile != NULL , " Error opening output file" );
50
89
0 commit comments