Skip to content

Commit d686b80

Browse files
sjvendittogemini-code-assist[bot]skjerns
authored
correctly load 2D char arrays (#62)
* small change to correctly load in 2D char arrays * python 3.7 has been removed from runners see actions/runner-images#10893 * Update mat73/core.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update test_mat73.py * Update create_mat.m * Update core.py * fix string array reading * remove dependency on StringDType numpy>=2 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Simon Kern <14980558+skjerns@users.noreply.github.com>
1 parent 35bb606 commit d686b80

File tree

4 files changed

+64
-13
lines changed

4 files changed

+64
-13
lines changed

mat73/core.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,31 @@ def convert_mat(self, dataset, depth, MATLAB_class=None):
266266
return None
267267

268268
elif mtype=='char':
269-
string_array = np.ravel(dataset)
270-
string_array = ''.join([chr(x) for x in string_array])
271-
return string_array
269+
codes = np.asarray(dataset, dtype=np.uint16)
270+
271+
# object dtype → keeps '\x00'
272+
# see https://github.com/numpy/numpy/issues/28964
273+
to_char = np.vectorize(chr, otypes=[object])
274+
arr = to_char(codes)
275+
276+
char_axis = 0 if arr.ndim < 3 else -2
277+
char_arr = np.apply_along_axis(lambda x: ''.join(x), axis=char_axis, arr=arr)
278+
279+
string_list = char_arr.tolist()
280+
281+
if arr.ndim==2 and arr.shape[1]==1:
282+
string_list = string_list[0]
283+
284+
if arr.ndim>2:
285+
# print warning to be sure. I haven't encountered any char
286+
# arrays with ndim>2 in the wild yet so can't be sure that
287+
# they are actually the way I synthesized them
288+
logging.warning(f"Loading char array '{dataset.name}' with {arr.ndim} dimensions "
289+
f"might be wrong stacked (i.e. dimensions scrambled). "
290+
f"please check variable is correct and report errors "
291+
f"on github.com/skjerns/mat7.3")
292+
293+
return string_list
272294

273295
elif mtype=='bool':
274296
return bool(dataset)
@@ -361,8 +383,5 @@ def savemat(filename, verbose=True):
361383

362384
if __name__=='__main__':
363385
# for testing / debugging
364-
d = loadmat('../tests/testfile11.mat', only_include='foo')
365-
366-
367-
# file = '../tests/testfile8.mat'
368-
# data = loadmat(file)
386+
d = loadmat('../tests/testfile16.mat')
387+
print(d)

tests/create_mat.m

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,20 @@
6565
clear all
6666

6767

68+
%% file to test 2D char arrays
69+
char_arr_1d = ['abcd']
70+
char_arr_2d = [ ...
71+
'PSTH tensor for image sequences (averaged across frames):'; ...
72+
'dimension 1: 2 scales (zoom1x, zoom2x) '; ...
73+
'dimension 2: 3 category (natural, synthetic, contrast) '; ...
74+
'dimension 3: 10 movies '; ...
75+
'dimension 4: sorted units '; ...
76+
'dimension 5: PSTH time bins '];
77+
78+
char_arr_3d = cat(3, ...
79+
['abcd'; 'defg'], ... % First "page"
80+
['ghij'; 'jklm'], ... % Second "page"
81+
['mnöp'; 'pqrs']) % Third "page"
82+
83+
save('testfile16.mat','char_arr_1d', 'char_arr_2d', 'char_arr_3d', '-v7.3')
84+
clear all

tests/test_mat73.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
version = pkg_resources.get_distribution('mat73').version
1616
except:
1717
version = '0.00'
18-
18+
1919
try:
2020
repo = Repository('.')
2121
head = repo.head
@@ -35,7 +35,7 @@ class Testing(unittest.TestCase):
3535

3636
def setUp(self):
3737
"""make links to test files and make sure they are present"""
38-
for i in range(1, 16):
38+
for i in range(1, 17):
3939
file = 'testfile{}.mat'.format(i)
4040
if not os.path.exists(file):
4141
file = os.path.join('./tests', file)
@@ -390,7 +390,7 @@ def test_load_specific_vars(self):
390390
elapsed2 = time.time()-start
391391
assert elapsed2<elapsed1, 'loading specific var was not faster'
392392

393-
def test_file10_nullchars(self):
393+
def test_file8_nullchars(self):
394394
"""test if null chars are retained in char arrays"""
395395
data = mat73.loadmat(self.testfile8)
396396
self.assertEqual(len(data['char_array']), 7, 'not all elements loaded')
@@ -503,12 +503,27 @@ def test_file15_strip(self):
503503
'x_1_1_10_1_1': (1, 1, 10),
504504
'x_10_1_1_10': (10, 1, 1, 10),
505505
}
506-
507-
506+
507+
508508
for var, shape in expected.items():
509509
self.assertEqual(data[var].shape, shape)
510510
self.assertEqual(data[var].ndim, len(shape))
511511

512+
def test_file16_2d_char_array(self):
513+
"""Test loading of 2D char array"""
514+
# the matlab array has shape (6 X 57)
515+
516+
data = mat73.loadmat(self.testfile16)
517+
518+
expected = ['PSTH tensor for image sequences (averaged across frames):',
519+
'dimension 1: 2 scales (zoom1x, zoom2x) ',
520+
'dimension 2: 3 category (natural, synthetic, contrast) ',
521+
'dimension 3: 10 movies ',
522+
'dimension 4: sorted units ',
523+
'dimension 5: PSTH time bins ']
524+
self.assertEqual(data['char_arr_1d'], 'abcd')
525+
self.assertEqual(data['char_arr_2d'], expected)
526+
self.assertEqual(data['char_arr_3d'], [['abcd', 'defg'], ['ghij', 'jklm'], ['mnöp', 'pqrs']])
512527

513528
if __name__ == '__main__':
514529

tests/testfile16.mat

3.13 KB
Binary file not shown.

0 commit comments

Comments
 (0)