Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit cee0243

Browse files
authored
Fix CNNDM dataset tests (#2246)
1 parent ecb9ebc commit cee0243

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

torchtext/datasets/cnndm.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,21 @@ def _hash_urls(s: tuple):
7777

7878

7979
def _get_split_list(source: str, split: str):
80+
from torchdata.datapipes.iter import ( # noqa
81+
IterableWrapper,
82+
OnlineReader,
83+
)
8084
url_dp = IterableWrapper([SPLIT_LIST[source + "_" + split]])
8185
online_dp = OnlineReader(url_dp)
8286
return online_dp.readlines().map(fn=_hash_urls)
8387

8488

8589
def _load_stories(root: str, source: str, split: str):
90+
from torchdata.datapipes.iter import ( # noqa
91+
FileOpener,
92+
IterableWrapper,
93+
GDriveReader,
94+
)
8695
split_list = set(_get_split_list(source, split))
8796
story_dp = IterableWrapper([URL[source]])
8897
cache_compressed_dp = story_dp.on_disk_cache(
@@ -135,12 +144,6 @@ def CNNDM(root: str, split: Union[Tuple[str], str]):
135144
raise ModuleNotFoundError(
136145
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
137146
)
138-
from torchdata.datapipes.iter import ( # noqa
139-
FileOpener,
140-
IterableWrapper,
141-
OnlineReader,
142-
GDriveReader,
143-
)
144147

145148
cnn_dp = _load_stories(root, "cnn", split)
146149
dailymail_dp = _load_stories(root, "dailymail", split)

0 commit comments

Comments
 (0)