-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest1.py
More file actions
105 lines (80 loc) · 2.65 KB
/
test1.py
File metadata and controls
105 lines (80 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
'''
仿照jupyter notebook的文本分类模型
简略框架
'''
import numpy as np
import csv
from string import punctuation
import re
import nltk
import ssl
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from collections import Counter
def read_csv(path, encoding='utf-8-sig', headers=None, sep=',', dropna=True):
    """Read a labeled CSV (e.g. ID,txt,Label) into feature dicts and labels.

    Parameters:
        path: path of the CSV file to read.
        encoding: file encoding; 'utf-8-sig' also strips a UTF-8 BOM.
        headers: column names to use; when None, the first row is consumed
            as the header row.
        sep: field delimiter passed to csv.reader.
        dropna: accepted for interface compatibility but currently unused.
            NOTE(review): TODO confirm intended NA-dropping behavior.

    Returns:
        (sentences, labels, headers) where `sentences` is a list of dicts
        mapping each non-label header to the lowercased cell text, `labels`
        collects raw values of the column whose header is "label"
        (case-insensitive), and `headers` is the header list actually used.
    """
    # newline='' lets the csv module handle embedded newlines correctly.
    with open(path, 'r', encoding=encoding, newline='') as csv_file:
        reader = csv.reader(csv_file, delimiter=sep)
        if headers is None:
            headers = next(reader)
        sentences = []
        labels = []
        for row in reader:
            record = {}
            for header, cell in zip(headers, row):
                if str(header).lower() == "label":
                    labels.append(cell)
                else:
                    record[header] = str(cell).lower()  # normalize to lowercase
            sentences.append(record)
    return sentences, labels, headers
def preprocess_input(data, input_cols):
    """Clean the text columns of each row dict.

    For every row in *data* (a dict per row), takes the values of the
    columns named in *input_cols*, strips <br> tags and punctuation, and
    returns (texts, all_words): the cleaned strings, one per processed
    cell, and the flat list of all whitespace-separated tokens.
    """
    texts = []
    all_words = []
    for row in data:
        # each row is a header -> cell-text mapping
        for column, value in row.items():
            if column not in input_cols:
                continue
            cleaned = deletebr(str(value))
            stripped = ''.join(ch for ch in cleaned if ch not in punctuation)
            texts.append(stripped)
            all_words.extend(stripped.split())
    return texts, all_words
def preprocess_labels(data):
    """Preprocess raw label values.

    TODO: not implemented yet — the main script currently uses the raw
    label strings returned by read_csv directly.
    """
    pass
def get_stopwords():
    """Return the NLTK English stop-word list as a set for O(1) membership tests."""
    return set(stopwords.words('english'))
def deletebr(line):
    r"""Remove HTML line-break tags (<br>, <br/>, <br />) from *line*.

    Fixes the original pattern r'<br\s*.?>': the '.' wildcard also matched
    tags such as '<bra>' and deleted real text. '/?' restricts the optional
    trailing character to the closing slash of a self-closing tag.
    """
    return re.sub(r'<br\s*/?>', '', line)
def printten(data, n=10):
    """Print the first *n* items of *data*, one per line (default 10).

    Fixes the original, which indexed data[0..9] unconditionally and
    raised IndexError for sequences shorter than ten items; slicing
    prints whatever is available.
    """
    for item in data[:n]:
        print(item)
if __name__ == '__main__':
    ## Read csv data: ID,txt,Label rows
    path = "data/train.csv"
    sentences, labels, headers = read_csv(path)
    labels = np.array(labels)

    ## Preprocess input: only the free-text 'txt' column feeds the model
    # (removed the unused, misspelled `ingore_cols` variable)
    input_cols = ['txt']
    texts, all_words = preprocess_input(sentences, input_cols)
    # printten(texts)
    # printten(all_words)

    ## Removing outliers
    # Distribution of review lengths in tokens (diagnostic):
    sentence_lens = Counter([len(x.split()) for x in texts])
    # print("Minimum review length: {}".format(min(sentence_lens)))
    # print("Maximum review length: {}".format(max(sentence_lens)))
    # Drop reviews that became empty after cleaning, keeping labels aligned
    non_zero_idx = [ii for ii, review in enumerate(texts) if len(review.split()) != 0]
    texts = [texts[ii] for ii in non_zero_idx]
    labels = [labels[ii] for ii in non_zero_idx]
    # print(len(texts), len(labels))