Skip to content

Commit e1942cd

Browse files
committed
Improve request parser for better compatibility with Socket Mode
1 parent ba3ca95 commit e1942cd

File tree

2 files changed

+64
-17
lines changed

2 files changed

+64
-17
lines changed

slack_bolt/request/internals.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def parse_body(body: str, content_type: Optional[str]) -> Dict[str, Any]:
3636
) or body.startswith("{"):
3737
return json.loads(body)
3838
else:
39-
if "payload" in body:
39+
if "payload" in body: # This is not JSON format yet
4040
params = dict(parse_qsl(body))
41-
if "payload" in params:
41+
if params.get("payload") is not None:
4242
return json.loads(params.get("payload"))
4343
else:
4444
return {}
@@ -56,72 +56,72 @@ def extract_is_enterprise_install(payload: Dict[str, Any]) -> Optional[bool]:
5656

5757

5858
def extract_enterprise_id(payload: Dict[str, Any]) -> Optional[str]:
59-
if "enterprise" in payload:
59+
if payload.get("enterprise") is not None:
6060
org = payload.get("enterprise")
6161
if isinstance(org, str):
6262
return org
6363
elif "id" in org:
6464
return org.get("id") # type: ignore
65-
if "authorizations" in payload and len(payload["authorizations"]) > 0:
65+
if payload.get("authorizations") is not None and len(payload["authorizations"]) > 0:
6666
# To make Events API handling functioning also for shared channels,
6767
# we should use .authorizations[0].enterprise_id over .enterprise_id
6868
return extract_enterprise_id(payload["authorizations"][0])
6969
if "enterprise_id" in payload:
7070
return payload.get("enterprise_id")
71-
if "team" in payload and "enterprise_id" in payload["team"]:
71+
if payload.get("team") is not None and "enterprise_id" in payload["team"]:
7272
# In the case where the type is view_submission
7373
return payload["team"].get("enterprise_id")
74-
if "event" in payload:
74+
if payload.get("event") is not None:
7575
return extract_enterprise_id(payload["event"])
7676
return None
7777

7878

7979
def extract_team_id(payload: Dict[str, Any]) -> Optional[str]:
80-
if "team" in payload:
80+
if payload.get("team") is not None:
8181
team = payload.get("team")
8282
if isinstance(team, str):
8383
return team
8484
elif team and "id" in team:
8585
return team.get("id")
86-
if "authorizations" in payload and len(payload["authorizations"]) > 0:
86+
if payload.get("authorizations") is not None and len(payload["authorizations"]) > 0:
8787
# To make Events API handling functioning also for shared channels,
8888
# we should use .authorizations[0].team_id over .team_id
8989
return extract_team_id(payload["authorizations"][0])
9090
if "team_id" in payload:
9191
return payload.get("team_id")
92-
if "event" in payload:
92+
if payload.get("event") is not None:
9393
return extract_team_id(payload["event"])
94-
if "user" in payload:
94+
if payload.get("user") is not None:
9595
return payload.get("user")["team_id"]
9696
return None
9797

9898

9999
def extract_user_id(payload: Dict[str, Any]) -> Optional[str]:
100-
if "user" in payload:
100+
if payload.get("user") is not None:
101101
user = payload.get("user")
102102
if isinstance(user, str):
103103
return user
104104
elif "id" in user:
105105
return user.get("id") # type: ignore
106106
if "user_id" in payload:
107107
return payload.get("user_id")
108-
if "event" in payload:
108+
if payload.get("event") is not None:
109109
return extract_user_id(payload["event"])
110110
return None
111111

112112

113113
def extract_channel_id(payload: Dict[str, Any]) -> Optional[str]:
114-
if "channel" in payload:
114+
if payload.get("channel") is not None:
115115
channel = payload.get("channel")
116116
if isinstance(channel, str):
117117
return channel
118118
elif "id" in channel:
119119
return channel.get("id") # type: ignore
120120
if "channel_id" in payload:
121121
return payload.get("channel_id")
122-
if "event" in payload:
122+
if payload.get("event") is not None:
123123
return extract_channel_id(payload["event"])
124-
if "item" in payload:
124+
if payload.get("item") is not None:
125125
# reaction_added: body["event"]["item"]
126126
return extract_channel_id(payload["item"])
127127
return None

tests/slack_bolt/request/test_internals.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,35 @@ def teardown_method(self):
5050
},
5151
]
5252

53+
enterprise_no_channel_requests = [
54+
{
55+
"type": "shortcut",
56+
"token": "xxx",
57+
"action_ts": "1606983924.521157",
58+
"team": {"id": "T111", "domain": "ddd"},
59+
"user": {"id": "U111", "username": "use", "team_id": "T111"},
60+
"is_enterprise_install": False,
61+
"enterprise": {"id": "E111", "domain": "eee"},
62+
"callback_id": "run-socket-mode",
63+
"trigger_id": "111.222.xxx",
64+
},
65+
]
66+
67+
no_enterprise_no_channel_requests = [
68+
{
69+
"type": "shortcut",
70+
"token": "xxx",
71+
"action_ts": "1606983924.521157",
72+
"team": {"id": "T111", "domain": "ddd"},
73+
"user": {"id": "U111", "username": "use", "team_id": "T111"},
74+
"is_enterprise_install": False,
75+
# This may be "null" in Socket Mode
76+
"enterprise": None,
77+
"callback_id": "run-socket-mode",
78+
"trigger_id": "111.222.xxx",
79+
},
80+
]
81+
5382
def test_channel_id_extraction(self):
5483
for req in self.requests:
5584
channel_id = extract_channel_id(req)
@@ -59,16 +88,34 @@ def test_user_id_extraction(self):
5988
for req in self.requests:
6089
user_id = extract_user_id(req)
6190
assert user_id == "U111"
91+
for req in self.enterprise_no_channel_requests:
92+
user_id = extract_user_id(req)
93+
assert user_id == "U111"
94+
for req in self.no_enterprise_no_channel_requests:
95+
user_id = extract_user_id(req)
96+
assert user_id == "U111"
6297

6398
def test_team_id_extraction(self):
6499
for req in self.requests:
65100
team_id = extract_team_id(req)
66101
assert team_id == "T111"
102+
for req in self.enterprise_no_channel_requests:
103+
team_id = extract_team_id(req)
104+
assert team_id == "T111"
105+
for req in self.no_enterprise_no_channel_requests:
106+
team_id = extract_team_id(req)
107+
assert team_id == "T111"
67108

68109
def test_enterprise_id_extraction(self):
69110
for req in self.requests:
70-
team_id = extract_enterprise_id(req)
71-
assert team_id == "E111"
111+
enterprise_id = extract_enterprise_id(req)
112+
assert enterprise_id == "E111"
113+
for req in self.enterprise_no_channel_requests:
114+
enterprise_id = extract_enterprise_id(req)
115+
assert enterprise_id == "E111"
116+
for req in self.no_enterprise_no_channel_requests:
117+
enterprise_id = extract_enterprise_id(req)
118+
assert enterprise_id is None
72119

73120
def test_is_enterprise_install_extraction(self):
74121
for req in self.requests:

0 commit comments

Comments
 (0)