Skip to content

Commit 3dca403

Browse files
authored
Update plots to include smoothing (#176)
1 parent 6d9f6b4 commit 3dca403

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

src/trustyai/explainers/extras/tssaliency.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,32 @@ def as_html(self) -> Styler:
3131
dataframe = self.as_dataframe()
3232
return dataframe.style
3333

34-
def plot(self):
34+
def plot(self, index: int, cpos, window: int = None):
3535
"""Plot tssaliency explanation for the test point
3636
Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tssaliency"""
37-
max_abs = np.max(np.abs(self.explanation["saliency"]))
37+
if window:
38+
scores = (
39+
np.convolve(
40+
self.explanation["saliency"].flatten(), np.ones(window), mode="same"
41+
)
42+
/ window
43+
)
44+
else:
45+
scores = self.explanation["saliency"]
3846

47+
vmax = np.max(np.abs(self.explanation["saliency"]))
48+
49+
plt.figure(layout="constrained")
3950
plt.imshow(
40-
self.explanation["saliency"][np.newaxis, :],
41-
aspect="auto",
42-
cmap="seismic",
43-
vmin=-max_abs,
44-
vmax=max_abs,
51+
scores[np.newaxis, :], aspect="auto", cmap="seismic", vmin=-vmax, vmax=vmax
4552
)
4653
plt.colorbar()
4754
plt.plot(self.explanation["input_data"])
55+
instance = self.explanation["instance_prediction"]
56+
plt.title(
57+
"Time Series Saliency Explanation Plot for test point"
58+
f" i={index} with P(Y={cpos})= {instance}"
59+
)
4860
plt.show()
4961

5062

0 commit comments

Comments
 (0)