
Commit 0327121

Author: tianxin
Merge branch 'develop' into few_shot_rdrop
2 parents: eec0299 + 928c34a (commit 0327121)

File tree: 9 files changed, +229 -9 lines changed


README.md

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ pip install --upgrade paddlenlp
 
 ### Transformer API: 强大的预训练模型生态底座
 
-覆盖**15**个网络结构和**67**个预训练模型参数,既包括百度自研的预训练模型如ERNIE系列, PLATO, SKEP等,也涵盖业界主流的中文预训练模型。也欢迎开发者进预训练模贡献!🤗
+覆盖**15**个网络结构和**67**个预训练模型参数,既包括百度自研的预训练模型如ERNIE系列, PLATO, SKEP等,也涵盖业界主流的中文预训练模型。也欢迎开发者贡献更多预训练模型!🤗
 
 ```python
 from paddlenlp.transformers import *
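
The changed paragraph (in Chinese) says that the Transformer API covers 15 network architectures and 67 sets of pretrained weights, including Baidu's own ERNIE, PLATO and SKEP families alongside mainstream Chinese pretrained models, and invites developers to contribute more. For illustration only, a minimal sketch of the usual loading pattern; 'ernie-1.0' is assumed here as a representative released weight name and is not part of this diff:

```python
from paddlenlp.transformers import ErnieModel, ErnieTokenizer

# Download (on first use) and load a pretrained checkpoint by name.
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
model = ErnieModel.from_pretrained('ernie-1.0')
```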

docs/data_prepare/dataset_list.md

Lines changed: 1 addition & 0 deletions

@@ -92,6 +92,7 @@ PaddleNLP提供了以下数据集的快速读取API,实际使用时请根据
 | [DuReaderQG](https://github.com/PaddlePaddle/Research/tree/master/NLP/DuReader-Robust-BASELINE) | 基于DuReader的问题生成数据集| `paddlenlp.datasets.load_dataset('dureader_qg')`|
 | [AdvertiseGen](https://github.com/ZhihongShao/Planning-based-Hierarchical-Variational-Model) | 中文文案生成数据集| `paddlenlp.datasets.load_dataset('advertisegen')`|
 | [LCSTS_new](https://aclanthology.org/D15-1229.pdf) | 中文摘要生成数据集| `paddlenlp.datasets.load_dataset('lcsts_new')`|
+| [CNN/Dailymail](https://github.com/abisee/cnn-dailymail) | 英文摘要生成数据集| `paddlenlp.datasets.load_dataset('cnn_dailymail')`|
 
 ## 语料库
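
The added row registers CNN/Dailymail with the same `load_dataset` entry point as the other datasets in this table. A minimal usage sketch; the `splits` keyword follows the convention of the existing dataset docs, and the split names come from the builder added in this commit:

```python
from paddlenlp.datasets import load_dataset

# The first call downloads the story archives and split files, then builds the splits.
train_ds, dev_ds, test_ds = load_dataset(
    'cnn_dailymail', splits=('train', 'dev', 'test'))

print(train_ds[0].keys())  # dict_keys(['article', 'highlights', 'id'])
```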

paddlenlp/datasets/cnn_dailymail.py

Lines changed: 219 additions & 0 deletions

@@ -0,0 +1,219 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import hashlib

from paddle.dataset.common import md5file
from paddlenlp.utils.downloader import get_path_from_url, _decompress
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.utils.log import logger
from . import DatasetBuilder


class CnnDailymail(DatasetBuilder):
    """
    CNN/DailyMail non-anonymized summarization dataset.
    The CNN / DailyMail Dataset is an English-language dataset containing
    just over 300k unique news articles as written by journalists at CNN
    and the Daily Mail. The current version supports both extractive and
    abstractive summarization, though the original version was created
    for machine reading and comprehension and abstractive question answering.

    Version 1.0.0 aimed to support supervised neural methodologies for machine
    reading and question answering with a large amount of real natural language
    training data and released about 313k unique articles and nearly 1M Cloze
    style questions to go with the articles.
    Versions 2.0.0 and 3.0.0 changed the structure of the dataset to support
    summarization rather than question answering. Version 3.0.0 provided a
    non-anonymized version of the data, whereas both the previous versions were
    preprocessed to replace named entities with unique identifier labels.

    An updated version of the code that does not anonymize the data is available
    at https://github.com/abisee/cnn-dailymail.
    """
    lazy = False
    META_INFO = collections.namedtuple("META_INFO", ("file", "url", "md5"))
    SPLITS = {
        "train": META_INFO(
            "all_train.txt",
            "https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/all_train.txt",
            "c8ca98cfcb6cf3f99a404552568490bc"),
        "dev": META_INFO(
            "all_val.txt",
            "https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/all_val.txt",
            "83a3c483b3ed38b1392285bed668bfee"),
        "test": META_INFO(
            "all_test.txt",
            "https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/all_test.txt",
            "4f3ac04669934dbc746b7061e68a0258")
    }
    cnn_dailymail = {
        "cnn": {
            "url":
            "https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/cnn_stories.tgz",
            "md5": "85ac23a1926a831e8f46a6b8eaf57263",
            "file_num": 92579
        },
        "dailymail": {
            "url":
            "https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/dailymail_stories.tgz",
            "md5": "f9c5f565e8abe86c38bfa4ae8f96fd72",
            "file_num": 219506
        }
    }

    def _read_text_file(self, text_file):
        lines = []
        with open(text_file, "r", encoding="utf8") as f:
            for line in f:
                lines.append(line.strip())
        return lines

    def _get_url_hashes(self, path):
        """Get hashes of urls in file."""
        urls = self._read_text_file(path)

        def url_hash(u):
            h = hashlib.sha1()
            try:
                u = u.encode("utf-8")
            except UnicodeDecodeError:
                logger.error("Cannot hash url: %s", u)
            h.update(u)
            return h.hexdigest()

        return {url_hash(u): True for u in urls}

    def _get_hash_from_path(self, p):
        """Extract hash from path."""
        basename = os.path.basename(p)
        return basename[0:basename.find(".story")]

    def _find_files(self, dl_paths, publisher, url_dict):
        """Find files corresponding to urls."""
        if publisher == "cnn":
            top_dir = os.path.join(dl_paths["cnn"], "stories")
        elif publisher == "dailymail":
            top_dir = os.path.join(dl_paths["dailymail"], "stories")
        else:
            logger.error("Unsupported publisher: %s", publisher)
        files = sorted(os.listdir(top_dir))

        ret_files = []
        for p in files:
            if self._get_hash_from_path(p) in url_dict:
                ret_files.append(os.path.join(top_dir, p))
        return ret_files

    def _subset_filenames(self, dl_paths, split):
        """Get filenames for a particular split."""
        urls = self._get_url_hashes(dl_paths[split])
        cnn = self._find_files(dl_paths, "cnn", urls)
        dm = self._find_files(dl_paths, "dailymail", urls)
        return cnn + dm

    def _get_art_abs(self, story_file, version):
        """Get abstract (highlights) and article from a story file path."""
        # Based on https://github.com/abisee/cnn-dailymail/blob/master/
        # make_datafiles.py

        lines = self._read_text_file(story_file)

        # The original GitHub code lowercases the text; that step was removed
        # in version 3.0.0.

        # Put periods on the ends of lines that are missing them
        # (this is a problem in the dataset because many image captions don't
        # end in periods; consequently they end up in the body of the article
        # as run-on sentences).
        def fix_missing_period(line):
            """Adds a period to a line that is missing a period."""
            if "@highlight" in line:
                return line
            if not line:
                return line
            if line[-1] in [
                    ".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")"
            ]:
                return line
            return line + " ."

        lines = [fix_missing_period(line) for line in lines]

        # Separate out article and abstract sentences
        article_lines = []
        highlights = []
        next_is_highlight = False
        for line in lines:
            if not line:
                continue  # empty line
            elif line.startswith("@highlight"):
                next_is_highlight = True
            elif next_is_highlight:
                highlights.append(line)
            else:
                article_lines.append(line)

        # Make article into a single string
        article = " ".join(article_lines)

        if version >= "2.0.0":
            abstract = "\n".join(highlights)
        else:
            abstract = " ".join(highlights)

        return article, abstract

    def _get_data(self, mode):
        """Check and download the dataset."""
        dl_paths = {}
        version = self.config.get("version", "3.0.0")
        if version not in ["1.0.0", "2.0.0", "3.0.0"]:
            raise ValueError("Unsupported version: %s" % version)
        dl_paths["version"] = version
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        for k, v in self.cnn_dailymail.items():
            dir_path = os.path.join(default_root, k)
            if not os.path.exists(dir_path):
                get_path_from_url(v["url"], default_root, v["md5"])
            file_num = len(os.listdir(os.path.join(dir_path, "stories")))
            if file_num != v["file_num"]:
                logger.warning(
                    "Number of %s stories is %d != %d, decompress again." %
                    (k, file_num, v["file_num"]))
                _decompress(
                    os.path.join(default_root, os.path.basename(v["url"])))
            dl_paths[k] = dir_path
        filename, url, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (data_hash and
                                            not md5file(fullname) == data_hash):
            get_path_from_url(url, default_root, data_hash)
        dl_paths[mode] = fullname
        return dl_paths

    def _read(self, dl_paths, split):
        files = self._subset_filenames(dl_paths, split)
        for p in files:
            article, highlights = self._get_art_abs(p, dl_paths["version"])
            if not article or not highlights:
                continue
            yield {
                "article": article,
                "highlights": highlights,
                "id": self._get_hash_from_path(p),
            }
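
Two details of the builder above are worth noting: `_get_art_abs` joins the highlights with newlines for dataset versions of at least 2.0.0 (and with spaces for 1.0.0), and `_get_data` reads an optional `version` entry from the builder config, defaulting to "3.0.0", the non-anonymized data. Assuming that extra keyword arguments to `load_dataset` are forwarded into that config (an assumption; the plumbing is not shown in this diff), inspecting one record would look roughly like this:

```python
from paddlenlp.datasets import load_dataset

# Assumption: extra kwargs reach DatasetBuilder.config, where _get_data
# reads "version" (default "3.0.0").
test_ds = load_dataset('cnn_dailymail', splits='test', version='3.0.0')

sample = test_ds[0]
print(sample['id'])                      # hash derived from the .story file name
print(sample['article'][:200])           # article sentences joined with spaces
print(sample['highlights'].split('\n'))  # one highlight per line for >= 2.0.0
```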

paddlenlp/taskflow/dependency_parsing.py

Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ def __init__(self,
                 ddparser, ddparser-ernie-1.0 and ddoarser-ernie-gram-zh")
         word_vocab_path = download_file(
             self._task_path, self.model + os.path.sep + "word_vocab.json",
-            URLS[self.model][0], URLS[self.model][1])
+            URLS[self.model][0], URLS[self.model][1], self.model)
         rel_vocab_path = download_file(
             self._task_path, self.model + os.path.sep + "rel_vocab.json",
             URLS[self.model][0], URLS[self.model][1])

paddlenlp/taskflow/lexical_analysis.py

Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ def __init__(self, task, model, **kwargs):
         self._usage = usage
         word_dict_path = download_file(
             self._task_path, "lac_params" + os.path.sep + "word.dic",
-            URLS['lac_params'][0], URLS['lac_params'][1])
+            URLS['lac_params'][0], URLS['lac_params'][1], 'lexical_analysis')
         tag_dict_path = download_file(
             self._task_path, "lac_params" + os.path.sep + "tag.dic",
             URLS['lac_params'][0], URLS['lac_params'][1])

paddlenlp/taskflow/sentiment_analysis.py

Lines changed: 2 additions & 2 deletions

@@ -115,7 +115,7 @@ def _construct_model(self, model):
             padding_idx=pad_token_id,
             pooling_type='max')
         model_path = download_file(self._task_path, model + ".pdparams",
-                                   URLS[model][0], URLS[model][1])
+                                   URLS[model][0], URLS[model][1], model)
 
         # Load the model parameter for the predict
         state_dict = paddle.load(model_path)
@@ -234,7 +234,7 @@ def _construct_model(self, model):
         model_instance = SkepSequenceModel.from_pretrained(
             model, num_classes=len(self._label_map))
         model_path = download_file(self._task_path, model + ".pdparams",
-                                   URLS[model][0], URLS[model][1])
+                                   URLS[model][0], URLS[model][1], model)
         state_dict = paddle.load(model_path)
         model_instance.set_state_dict(state_dict)
         self._model = model_instance

paddlenlp/taskflow/text_correction.py

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ def _construct_model(self, model):
             pad_pinyin_id=self._pinyin_vocab[self._pinyin_vocab.pad_token])
         # Load the model parameter for the predict
         model_path = download_file(self._task_path, model + ".pdparams",
-                                   URLS[model][0], URLS[model][1])
+                                   URLS[model][0], URLS[model][1], model)
         state_dict = paddle.load(model_path)
         model_instance.set_state_dict(state_dict)
         model_instance.eval()

paddlenlp/taskflow/text_generation.py

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def __init__(self, task, model, **kwargs):
         if self._static_mode:
             download_file(self._task_path,
                           "static" + os.path.sep + "inference.pdiparams",
-                          URLS[self.model][0], URLS[self.model][1])
+                          URLS[self.model][0], URLS[self.model][1], model)
             self._get_inference_model()
         else:
             self._construct_model(model)

paddlenlp/transformers/albert/tokenizer.py

Lines changed: 2 additions & 2 deletions

@@ -151,12 +151,12 @@ def __init__(
 
         if vocab_file is not None:
             self.tokenizer = AlbertChineseTokenizer(
-                vocab_file,
+                vocab_file=vocab_file,
                 do_lower_case=False,
             )
         elif sentencepiece_model_file is not None:
             self.tokenizer = AlbertEnglishTokenizer(
-                sentencepiece_model_file,
+                sentencepiece_model_file=sentencepiece_model_file,
                 do_lower_case=True,
             )
         else:
