-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain_svm.py
More file actions
43 lines (35 loc) · 1.22 KB
/
train_svm.py
File metadata and controls
43 lines (35 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 26 20:47:21 2021
@author: Lung-GANs
"""
from time import time
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.externals import joblib
from sklearn.metrics import classification_report
from sklearn import svm
#
import numpy as np
# Train and evaluate a linear SVM on pre-extracted feature files.
#
# For every feature configuration `num` in `nums` the script loads
#   features/features<num>_train.npy, features/label<num>_train.npy
#   features/features<num>_test.npy,  features/label<num>_test.npy
# fits sklearn's SVC with a linear kernel, and prints the accuracy plus a
# per-class classification report (class names come from `style_label_file`).
# Each configuration's test accuracy is appended to `acc`, printed at the end.
acc = []
nums = [75]
style_label_file = './style_names.txt'
C = 1000.0  # SVM regularization parameter; identical for every `num`


def _read_target_names(path):
    """Return the class names listed in *path*, one per non-empty line.

    Follows the previous ``np.loadtxt(path, str, delimiter='\\n')`` call:
    anything after ``#`` is treated as a comment, line endings are stripped,
    and lines left empty are skipped.  Reading the file directly avoids two
    problems with that call: NumPy's rewritten ``loadtxt`` (1.23+) does not
    accept a newline delimiter, and a one-line file produced a 0-d array
    that ``list()`` cannot iterate.
    """
    names = []
    with open(path) as handle:
        for line in handle:
            name = line.split('#', 1)[0].strip('\r\n')
            if name:
                names.append(name)
    return names


target_names = _read_target_names(style_label_file)

for num in nums:
    X_train = np.load('features/features%d_train.npy' % num)
    y_train = np.load('features/label%d_train.npy' % num)
    X_test = np.load('features/features%d_test.npy' % num)
    y_test = np.load('features/label%d_test.npy' % num)

    print("Fitting the classifier to the training set")
    t0 = time()
    clf = svm.SVC(kernel='linear', C=C).fit(X_train, y_train)
    print("done in %0.3fs" % (time() - t0))

    print("Predicting...")
    t0 = time()
    y_pred = clf.predict(X_test)
    print("done in %0.3fs" % (time() - t0))

    # Compute the metric once and reuse it for both the printout and `acc`.
    accuracy = accuracy_score(y_test, y_pred)
    print("Accuracy: %.3f" % accuracy)
    acc.append(accuracy)

    print("Classification Report")
    print(classification_report(y_test, y_pred, target_names=target_names))

# Single-argument print(...) behaves the same on Python 2 and Python 3.
print(acc)