@@ -49,7 +49,11 @@ def _split_generators(self, dl_manager):
49
49
del dl_manager # Unused
50
50
51
51
metadata_dict = dict ()
52
- metadata_path = os .path .join (_DATA_DIR , 'metadata.csv' )
52
+ if self ._data_dir :
53
+ data_dir = self ._data_dir
54
+ else :
55
+ data_dir = _DATA_DIR
56
+ metadata_path = os .path .join (data_dir , 'metadata.csv' )
53
57
metadata = tf .io .gfile .GFile (metadata_path ).read ().splitlines ()
54
58
55
59
for row in metadata :
@@ -62,21 +66,21 @@ def _split_generators(self, dl_manager):
62
66
name = tfds .Split .TRAIN ,
63
67
gen_kwargs = {
64
68
'metadata' : metadata_dict ,
65
- 'filepath' : os .path .join (_DATA_DIR , 'train' ),
69
+ 'filepath' : os .path .join (data_dir , 'train' ),
66
70
},
67
71
),
68
72
tfds .core .SplitGenerator (
69
73
name = tfds .Split .VALIDATION ,
70
74
gen_kwargs = {
71
75
'metadata' : metadata_dict ,
72
- 'filepath' : os .path .join (_DATA_DIR , 'validation' ),
76
+ 'filepath' : os .path .join (data_dir , 'validation' ),
73
77
},
74
78
),
75
79
tfds .core .SplitGenerator (
76
80
name = tfds .Split .TEST ,
77
81
gen_kwargs = {
78
82
'metadata' : metadata_dict ,
79
- 'filepath' : os .path .join (_DATA_DIR , 'test' ),
83
+ 'filepath' : os .path .join (data_dir , 'test' ),
80
84
},
81
85
),
82
86
]
0 commit comments