-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathDataLoader.py
More file actions
109 lines (97 loc) · 5.52 KB
/
DataLoader.py
File metadata and controls
109 lines (97 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 17-4-27 下午8:43
# @Author : Tianyu Liu
import tensorflow as tf
import time
import numpy as np
class DataLoader(object):
    """Load the train/test/dev splits of a box-to-summary dataset and serve
    padded mini-batches.

    Each split consists of five parallel files (one example per line,
    space-separated integer ids): summary ids, box value ids, box label ids,
    box positions, and box reverse positions.
    """

    def __init__(self, data_dir, limits):
        """Read all three splits from ``data_dir``.

        Args:
            data_dir: root directory containing train/, test/ and valid/.
            limits: keep at most this many examples per split (<= 0 keeps all).
        """
        self.train_data_path = [data_dir + '/train/train.summary.id', data_dir + '/train/train.box.val.id',
                                data_dir + '/train/train.box.lab.id', data_dir + '/train/train.box.pos',
                                data_dir + '/train/train.box.rpos']
        self.test_data_path = [data_dir + '/test/test.summary.id', data_dir + '/test/test.box.val.id',
                               data_dir + '/test/test.box.lab.id', data_dir + '/test/test.box.pos',
                               data_dir + '/test/test.box.rpos']
        self.dev_data_path = [data_dir + '/valid/valid.summary.id', data_dir + '/valid/valid.box.val.id',
                              data_dir + '/valid/valid.box.lab.id', data_dir + '/valid/valid.box.pos',
                              data_dir + '/valid/valid.box.rpos']
        self.limits = limits
        # Hard cap on encoder (box/text) sequence length; batch_iter truncates to this.
        self.man_text_len = 100
        start_time = time.time()
        print('Reading datasets ...')
        self.train_set = self.load_data(self.train_data_path)
        self.test_set = self.load_data(self.test_data_path)
        # self.small_test_set = self.load_data(self.small_test_data_path)
        self.dev_set = self.load_data(self.dev_data_path)
        print('Reading datasets consumes %.3f seconds' % (time.time() - start_time))

    def _read_id_file(self, path):
        """Read one id file: one example per line of space-separated ints.

        Applies ``self.limits`` (if positive) before parsing. The file is
        opened with a context manager so the handle is always closed.
        """
        with open(path, 'r') as f:
            lines = f.read().strip().split('\n')
        if self.limits > 0:
            lines = lines[:self.limits]
        return [list(map(int, line.strip().split(' '))) for line in lines]

    def load_data(self, path):
        """Load the five parallel files of one split.

        Args:
            path: sequence of five file paths in the order
                (summary, value, label, pos, rpos), as built in __init__.

        Returns:
            Tuple of five parallel lists; each element is a list of ints
            (one example's token ids).
        """
        summary_path, text_path, field_path, pos_path, rpos_path = path
        summaries = self._read_id_file(summary_path)
        texts = self._read_id_file(text_path)
        fields = self._read_id_file(field_path)
        poses = self._read_id_file(pos_path)
        rposes = self._read_id_file(rpos_path)
        return summaries, texts, fields, poses, rposes

    def batch_iter(self, data, batch_size, shuffle):
        """Yield padded mini-batches from a dataset tuple.

        Args:
            data: five parallel lists as returned by load_data.
            batch_size: examples per batch; the last batch may be smaller.
            shuffle: if true, iterate the examples in a random order.

        Yields:
            dict with keys 'enc_in', 'enc_fd', 'enc_pos', 'enc_rpos',
            'enc_len', 'dec_in', 'dec_len', 'dec_out'; sequences are
            zero-padded to the batch maximum and encoder streams are
            truncated to ``self.man_text_len``.
        """
        summaries, texts, fields, poses, rposes = data
        data_size = len(summaries)
        # ceil(data_size / batch_size) without float arithmetic.
        num_batches = (data_size + batch_size - 1) // batch_size

        if shuffle:
            # Shuffle via an index permutation. The per-example lists are
            # ragged (different lengths), so np.array(...)[indices] is not
            # safe — modern numpy refuses to build an array from ragged input.
            order = np.random.permutation(data_size)
            summaries = [summaries[i] for i in order]
            texts = [texts[i] for i in order]
            fields = [fields[i] for i in order]
            poses = [poses[i] for i in order]
            rposes = [rposes[i] for i in order]

        for batch_num in range(num_batches):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            max_summary_len = max(len(s) for s in summaries[start_index:end_index])
            max_text_len = max(len(t) for t in texts[start_index:end_index])
            batch_data = {'enc_in': [], 'enc_fd': [], 'enc_pos': [], 'enc_rpos': [], 'enc_len': [],
                          'dec_in': [], 'dec_len': [], 'dec_out': []}
            for summary, text, field, pos, rpos in zip(summaries[start_index:end_index],
                                                       texts[start_index:end_index],
                                                       fields[start_index:end_index],
                                                       poses[start_index:end_index],
                                                       rposes[start_index:end_index]):
                summary_len = len(summary)
                text_len = len(text)
                # The four box-side streams must be aligned token-for-token.
                assert text_len == len(field)
                assert len(pos) == len(field)
                assert len(rpos) == len(pos)
                # Decoder target: summary ids followed by id 2 (end-of-sequence
                # marker, presumably — confirm against the vocabulary) and padding.
                gold = summary + [2] + [0] * (max_summary_len - summary_len)
                summary = summary + [0] * (max_summary_len - summary_len)
                text = text + [0] * (max_text_len - text_len)
                field = field + [0] * (max_text_len - text_len)
                pos = pos + [0] * (max_text_len - text_len)
                rpos = rpos + [0] * (max_text_len - text_len)
                # Truncate encoder inputs that exceed the manual length cap.
                if max_text_len > self.man_text_len:
                    text = text[:self.man_text_len]
                    field = field[:self.man_text_len]
                    pos = pos[:self.man_text_len]
                    rpos = rpos[:self.man_text_len]
                    text_len = min(text_len, self.man_text_len)
                batch_data['enc_in'].append(text)
                batch_data['enc_len'].append(text_len)
                batch_data['enc_fd'].append(field)
                batch_data['enc_pos'].append(pos)
                batch_data['enc_rpos'].append(rpos)
                batch_data['dec_in'].append(summary)
                batch_data['dec_len'].append(summary_len)
                batch_data['dec_out'].append(gold)
            yield batch_data