-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathplot_backdoor_mnist_results.py
More file actions
77 lines (66 loc) · 2.31 KB
/
plot_backdoor_mnist_results.py
File metadata and controls
77 lines (66 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_theme(style="whitegrid", palette="colorblind")
results_df = pd.read_csv("out/backdoor_mnist_results.csv")

# Turn each poison rate into an absolute number of poisoned samples
# (truncated toward zero, as int() would).
# NOTE(review): 14000 and 0.8 look like the dataset size and the training
# split fraction used by the experiment — confirm against the script that
# produced the CSV.
results_df["poisoned samples"] = (
    results_df["poison rate"] * 0.8 * 14000
).astype("int64")

# Spell out the privacy budget ε in the method labels shown in the legend.
results_df["method"] = results_df["method"].replace(
    {
        "PrivaTree (0.1)": "PrivaTree (ε = 0.1)",
        "PrivaTree (0.01)": "PrivaTree (ε = 0.01)",
    }
)
# x values for the theoretical bound curves. np.arange is half-open, so add 1
# to make the curves reach the largest poisoned-sample count in the data.
n_poison_samples = np.arange(results_df["poisoned samples"].max() + 1)


def _clean_mean(method, column):
    """Return the mean of *column* over the unpoisoned runs of *method*.

    "Unpoisoned" means rows with ``poison rate == 0``. Both conditions are
    combined into one mask and applied with a single ``.loc`` lookup. This
    avoids chaining ``df[mask_a][mask_b]``, where ``mask_b`` is indexed on the
    unfiltered frame and pandas has to reindex it (with a warning).
    """
    mask = (results_df["method"] == method) & (results_df["poison rate"] == 0)
    return results_df.loc[mask, column].mean()


# Baseline attack success rate (ASR) of each PrivaTree variant without poisoning.
base_asr_01 = _clean_mean("PrivaTree (ε = 0.1)", "ASR")
base_asr_001 = _clean_mean("PrivaTree (ε = 0.01)", "ASR")

# Bound curves 1 - (1 - base ASR) * exp(-ε * k) for k poisoned samples, with ε
# matching each PrivaTree variant.
# NOTE(review): presumably the differential-privacy guarantee on backdoor
# success — confirm the formula against the paper.
bound_01 = 1 - (1 - base_asr_01) * np.exp(-0.1 * n_poison_samples)
bound_001 = 1 - (1 - base_asr_001) * np.exp(-0.01 * n_poison_samples)

# Clean-data accuracy of each model, reported on stdout.
accuracy_dt = _clean_mean("decision tree", "accuracy")
accuracy_01 = _clean_mean("PrivaTree (ε = 0.1)", "accuracy")
accuracy_001 = _clean_mean("PrivaTree (ε = 0.01)", "accuracy")
print("Base accuracy decision tree:", accuracy_dt)
print("Base accuracy PrivaTree (ε = 0.1):", accuracy_01)
print("Base accuracy PrivaTree (ε = 0.01):", accuracy_001)
fig, ax = plt.subplots(figsize=(6.4, 3.0))

# Seaborn's default hue order for string data is order of first appearance.
# Make it explicit so each method's palette color can be looked up by name.
hue_order = list(results_df["method"].dropna().unique())
method_colors = dict(zip(hue_order, sns.color_palette()))

sns.lineplot(
    x="poisoned samples",
    y="ASR",
    hue="method",
    hue_order=hue_order,
    marker="o",
    data=results_df,
    ax=ax,
)

# Dashed bound curves, colored like the PrivaTree variant with the same ε.
ax.plot(
    n_poison_samples,
    bound_01,
    c=method_colors["PrivaTree (ε = 0.1)"],
    linestyle="--",
    label="bound (ε = 0.1)",
)
ax.plot(
    n_poison_samples,
    bound_001,
    c=method_colors["PrivaTree (ε = 0.01)"],
    linestyle="--",
    label="bound (ε = 0.01)",
)

# Reorder legend items for nicer fit: move the first entry (the first hue
# level) to the end. Rotating the lists works for any number of entries,
# unlike a hard-coded index order.
handles, labels = ax.get_legend_handles_labels()
handles = handles[1:] + handles[:1]
labels = labels[1:] + labels[:1]
ax.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.48, 1.0),
    ncols=3,
)

fig.tight_layout()
fig.savefig("out/mnist_backdoor_asr_bounds.png", bbox_inches="tight")
fig.savefig("out/mnist_backdoor_asr_bounds.pdf", bbox_inches="tight")
plt.close(fig)