Skip to content

Commit fa32343

Browse files
committed
Fix permission mapping
Removes collision tests as categories and channel names can overlap in Discord anyway
1 parent a146270 commit fa32343

File tree

9 files changed

+94
-103
lines changed

9 files changed

+94
-103
lines changed

channels.example.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
permissions:
1212
everyone:
1313
send_messages: false
14-
volunteer:
14+
blueshirt:
1515
read_messages: true
1616
send_messages: true
1717
- name: blog
@@ -48,15 +48,15 @@
4848
everyone:
4949
read_messages: false
5050
send_messages: false
51-
volunteer:
51+
blueshirt:
5252
read_messages: true
5353
send_messages: true
5454
channels:
5555
- name: audit-log
5656
use_case: audit
5757
- name: blueshirt-onboarding
5858
permissions:
59-
unverified_volunteer:
59+
unverified_blueshirt:
6060
read_messages: true
6161
- name: blueshirt-banter
6262
- name: blueshirt-serious-business
@@ -85,7 +85,7 @@
8585
everyone:
8686
read_messages: false
8787
send_messages: false
88-
volunteer:
88+
blueshirt:
8989
read_messages: true
9090
send_messages: true
9191
supervisor:

src/sr/discord_bot/__main__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from discord import Intents
99

1010
from sr.discord_bot.bot import BotClient
11+
from sr.discord_bot.schema import ChannelDefinition
1112

1213
load_dotenv()
1314
logger = logging.getLogger("srbot")
@@ -33,11 +34,20 @@
3334
parser_run = subcommands.add_parser("run", help="Run the Discord bot")
3435
parser_plan = subcommands.add_parser("plan", help="List pending guild changes")
3536
parser_apply = subcommands.add_parser("apply", help="Apply pending guild changes")
37+
parser_diff = subcommands.add_parser("diff", help="Compare two channel configurations")
38+
parser_diff.add_argument("old_config", type=argparse.FileType('r'))
39+
parser_diff.add_argument("new_config", type=argparse.FileType('r'))
40+
3641
args = parser.parse_args()
3742

3843
if args.command is None:
3944
parser.print_help()
4045
exit(1)
46+
elif args.command == "diff":
47+
if args.old_config is None or args.new_config is None:
48+
parser.print_help()
49+
exit(1)
50+
4151

4252
bot = BotClient(logger=logger, intents=intents)
4353
bot.mode = args.command

src/sr/discord_bot/bot.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import asyncio
44
import logging
5+
from pathlib import Path
56
from typing import List, Literal
67

78
import yaml
@@ -124,7 +125,7 @@ async def on_ready(self) -> None:
124125
await self.apply_changes()
125126
await self.close()
126127

127-
async def _set_roles_and_channels(self, guild: Guild) -> None:
128+
async def set_roles_and_channels(self, guild: Guild) -> None:
128129
roles = await guild.fetch_roles()
129130
self.admin_role = find_role_by_name(roles, ADMIN_ROLE)
130131
self.verified_role = find_role_by_name(roles, VERIFIED_ROLE)
@@ -151,11 +152,11 @@ async def setup_bot(self) -> None:
151152
return
152153

153154
try:
154-
await self._set_roles_and_channels(guild)
155+
await self.set_roles_and_channels(guild)
155156
except ValueError:
156157
self.logger.info("Setting up guild...")
157158
await setup_guild(self)
158-
await self._set_roles_and_channels(guild)
159+
await self.set_roles_and_channels(guild)
159160

160161
await check_bot_messages(self, self.guild)
161162
self.teams_data.gen_team_memberships(self.guild, self.supervisor_role)
@@ -288,13 +289,12 @@ def _load_passwords(self) -> None:
288289
teamname:password
289290
```
290291
"""
292+
path = Path('passwords.json')
291293
try:
292-
with open('passwords.json') as f:
293-
self.passwords = json.load(f)
294+
self.passwords = json.loads(path.read_text())
294295
except (json.JSONDecodeError, FileNotFoundError):
295-
with open('passwords.json', 'w') as f:
296-
f.write('{}')
297-
self.passwords = {}
296+
path.write_text('{}')
297+
self.passwords = {}
298298

299299
def set_password(self, tla: str, password: str) -> None:
300300
self.passwords[tla.upper()] = password

src/sr/discord_bot/channel.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,13 @@
4949

5050

5151
def match_overwrites(overwrites: Overwrites, guild: Guild) -> Mapping[Role | Member | Object, PermissionOverwrite]:
52-
role_overwrites: Mapping[Role | Member | Object, PermissionOverwrite] = {
53-
guild.default_role: PermissionOverwrite(**overwrites.get(RoleType.EVERYONE, {})),
54-
find_role_by_name(guild.roles, VERIFIED_ROLE): PermissionOverwrite(**overwrites.get(RoleType.VERIFIED, {})),
55-
find_role_by_name(guild.roles, VOLUNTEER_ROLE): PermissionOverwrite(**overwrites.get(RoleType.BLUESHIRT, {})),
56-
find_role_by_name(guild.roles, TEAM_LEADER_ROLE): PermissionOverwrite(
57-
**overwrites.get(RoleType.SUPERVISOR, {}),
58-
),
59-
find_role_by_name(guild.roles, SPECIAL_ROLE): PermissionOverwrite(
60-
**overwrites.get(RoleType.UNVERIFIED_BLUESHIRT, {}),
61-
),
52+
return {
53+
guild.default_role: overwrites.get(RoleType.EVERYONE, PermissionOverwrite()),
54+
find_role_by_name(guild.roles, VERIFIED_ROLE): overwrites.get(RoleType.VERIFIED, PermissionOverwrite()),
55+
find_role_by_name(guild.roles, VOLUNTEER_ROLE): overwrites.get(RoleType.BLUESHIRT, PermissionOverwrite()),
56+
find_role_by_name(guild.roles, TEAM_LEADER_ROLE): overwrites.get(RoleType.SUPERVISOR, PermissionOverwrite()),
57+
find_role_by_name(guild.roles, SPECIAL_ROLE): overwrites.get(RoleType.UNVERIFIED_BLUESHIRT, PermissionOverwrite()),
6258
}
63-
return role_overwrites
6459

6560

6661
class ChannelSet:
@@ -84,14 +79,11 @@ def _get_role_type(self, role: Role, guild: Guild) -> RoleType:
8479
return self._role_map[role]
8580

8681
def _get_overwrites(self, channel: GuildChannel) -> Overwrites:
87-
overwrites: Overwrites = {}
88-
for subject, permissions in channel.overwrites.items():
89-
if isinstance(subject, Role):
90-
overwrites[self._get_role_type(subject, channel.guild)] = {}
91-
for permission, value in permissions:
92-
if value is not None:
93-
overwrites[self._get_role_type(subject, channel.guild)][permission] = value
94-
return overwrites
82+
return {
83+
self._get_role_type(subject, channel.guild): permissions
84+
for subject, permissions in channel.overwrites.items()
85+
if isinstance(subject, Role)
86+
}
9587

9688
def count(self, category: Channel | None) -> int:
9789
"""Count the number of channels in a category, or top-level channels if category is None."""
@@ -197,7 +189,6 @@ def create_forum_channel(
197189
def diff(cls, old: ChannelSet, new: ChannelSet) -> list[ChannelCommand]:
198190
"""Calculate the difference between two ChannelSets."""
199191
commands: list[ChannelCommand] = []
200-
# TODO: Diff categories first (in case one gets renamed)
201192
cls._diff_channels(old._channels, new._channels, commands)
202193
return commands
203194

@@ -253,7 +244,7 @@ def _diff_channels(cls, old: list[Channel], new: list[Channel], commands: list[C
253244
commands.append(command)
254245
else:
255246
seen_old_channel_names.add(old_channel.name)
256-
command = AlterChannelCommand.diff(old_channel, new_channel, new_index, old_channel.channel_type)
247+
command = AlterChannelCommand.diff(old_channel, new_channel, old_channel.channel_type)
257248
if command.has_changes:
258249
commands.append(command)
259250

@@ -424,7 +415,6 @@ class AlterChannelCommand(Command):
424415
overwrites: Overwrites | None = None
425416
topic: str | None = None
426417
category: Channel | None = None
427-
position: int | None = None
428418
is_category: bool = False # Whether this command is for a category, only for display purposes
429419
channel_type: ChannelType = ChannelType.text # Used for sorting
430420
# Forum-specific:
@@ -434,8 +424,7 @@ class AlterChannelCommand(Command):
434424
use_case: ChannelUseCase | None = None
435425

436426
@classmethod
437-
def diff(cls, old_channel: Channel, new_channel: Channel,
438-
new_position: int, channel_type: ChannelType) -> AlterChannelCommand:
427+
def diff(cls, old_channel: Channel, new_channel: Channel, channel_type: ChannelType) -> AlterChannelCommand:
439428
command = cls(old_name=old_channel.name, new_name=new_channel.name,
440429
is_category=new_channel.is_category, channel_type=channel_type, use_case=new_channel.use_case)
441430
changed_category = new_channel.category is not None and old_channel.category != new_channel.category
@@ -445,7 +434,6 @@ def diff(cls, old_channel: Channel, new_channel: Channel,
445434
command.topic = new_channel.topic
446435
if not old_channel.category or changed_category:
447436
command.category = new_channel.category
448-
command.position = new_position
449437
return command
450438

451439
def is_rename(self) -> bool:
@@ -484,6 +472,22 @@ def __str__(self) -> str:
484472
return f"ALTER Category \"{self.old_name}\" ({', '.join(changes)})"
485473
return f"ALTER Channel #{self.old_name} ({', '.join(changes)})"
486474

475+
@classmethod
476+
def get_role_by_role_type(cls, role_type: RoleType, guild: Guild) -> Role:
477+
if role_type == RoleType.EVERYONE:
478+
return guild.default_role
479+
480+
role_names = {
481+
RoleType.VERIFIED: VERIFIED_ROLE,
482+
RoleType.BLUESHIRT: VOLUNTEER_ROLE,
483+
RoleType.SUPERVISOR: TEAM_LEADER_ROLE,
484+
RoleType.UNVERIFIED_BLUESHIRT: SPECIAL_ROLE,
485+
}
486+
487+
if role := discord.utils.get(guild.roles, name=role_names[role_type]):
488+
return role
489+
raise ValueError("Invalid role type")
490+
487491
async def apply(self, guild: Guild) -> None:
488492
channel = discord.utils.get(guild.channels, name=self.old_name)
489493
if channel is None:
@@ -499,8 +503,13 @@ async def apply(self, guild: Guild) -> None:
499503
if self.category is not None:
500504
kwargs["category"] = find_channel_by_name(guild.channels, self.category.name) # type: ignore
501505

502-
if self.overwrites is not None and self.overwrites != {}:
503-
kwargs["overwrites"] = self.overwrites # type: ignore
506+
if self.overwrites is not None:
507+
overwrites = {
508+
AlterChannelCommand.get_role_by_role_type(role_type, guild): permissions
509+
for role_type, permissions in self.overwrites.items()
510+
}
511+
if overwrites != channel.overwrites:
512+
kwargs["overwrites"] = overwrites # type: ignore
504513

505514
if channel.type == ChannelType.forum:
506515
await guild.fetch_emojis()
@@ -547,7 +556,7 @@ async def apply(self, guild: Guild) -> None:
547556
@dataclasses.dataclass(frozen=True)
548557
class CreateCategoryCommand(Command):
549558
name: str
550-
overwrites: Overwrites | None = None
559+
overwrites: Overwrites = dataclasses.field(default_factory=dict)
551560

552561
def __str__(self) -> str:
553562
return f'CREATE Category "{self.name}"'
@@ -563,7 +572,7 @@ class CreateChannelCommand(Command):
563572
name: str
564573
channel_type: ChannelType
565574
category: Channel | None = None
566-
overwrites: Overwrites | None = None
575+
overwrites: Overwrites = dataclasses.field(default_factory=dict)
567576
topic: str = ""
568577
# Bot use only:
569578
use_case: ChannelUseCase | None = None
@@ -615,7 +624,7 @@ async def apply(self, guild: Guild) -> None:
615624
class CreateForumCommand(Command):
616625
name: str
617626
category: Channel | None = None
618-
overwrites: Overwrites | None = None
627+
overwrites: Overwrites = dataclasses.field(default_factory=dict)
619628
topic: str = ""
620629
default_reaction_emoji: str | None = None
621630
available_tags: list[ForumTagDefinition] = dataclasses.field(default_factory=list)

src/sr/discord_bot/guild.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ async def create_channels(client: "BotClient", guild: Guild) -> None:
127127
await asyncio.sleep(.5) # avoid hitting rate limits
128128

129129
logging.info("Channels created")
130-
await client._set_roles_and_channels(guild)
130+
await client.set_roles_and_channels(guild)
131131

132132

133133
async def send_template_messages(client: "BotClient", guild: Guild) -> None:

src/sr/discord_bot/schema.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from typing import Any
1010

11-
from discord import ForumTag, ChannelType, PartialEmoji
11+
from discord import ForumTag, ChannelType, PartialEmoji, PermissionOverwrite
1212

1313

1414
class RoleType(StrEnum):
@@ -19,7 +19,7 @@ class RoleType(StrEnum):
1919
UNVERIFIED_BLUESHIRT = "unverified_blueshirt"
2020

2121

22-
Overwrites = dict[RoleType, dict[str, bool]]
22+
Overwrites = dict[RoleType, PermissionOverwrite]
2323

2424

2525
class ChannelUseCase(StrEnum):
@@ -77,7 +77,10 @@ def load( # type: ignore[misc] # This gets validated when the YAML is validated
7777
channel = cls(
7878
name=data["name"],
7979
old_names=data.get("old_names", []),
80-
overwrites=data.get("permissions", {}),
80+
overwrites={
81+
RoleType(role_type): PermissionOverwrite(**permissions)
82+
for role_type, permissions in data.get("permissions", {}).items()
83+
},
8184
topic=data.get("topic", ""),
8285
category=category,
8386
channel_type=ChannelType[data.get("channel_type", default_type)],

src/sr/discord_bot/teams.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import re
12
from typing import List, NamedTuple
23
from statistics import mean
34
from collections import defaultdict
45

56
import discord
67

7-
from sr.discord_bot.constants import ROLE_PREFIX
8+
TEAM_ROLE_REGEX = re.compile('Team (?P<TLA>[A-Z]{3}\d?)')
89

910

1011
class TeamData(NamedTuple):
@@ -28,7 +29,7 @@ def school(self) -> str:
2829

2930
def __str__(self) -> str:
3031
data_str = f'{self.TLA:<15} {self.members:>2}'
31-
if self.leader is False:
32+
if not self.leader:
3233
data_str += ' No supervisor'
3334
return data_str
3435

@@ -42,9 +43,13 @@ def gen_team_memberships(self, guild: discord.Guild, leader_role: discord.Role)
4243
"""Generate a list of TeamData objects for the given guild, stored in teams_data."""
4344
teams_data = []
4445

45-
for role in filter(lambda role: role.name.startswith(ROLE_PREFIX), guild.roles):
46+
for role in guild.roles:
47+
match = TEAM_ROLE_REGEX.match(role.name)
48+
if match is None:
49+
continue
50+
4651
team_data = TeamData(
47-
TLA=role.name[len(ROLE_PREFIX):],
52+
TLA=match['TLA'],
4853
members=len(list(filter(
4954
lambda member: leader_role not in member.roles,
5055
role.members,

0 commit comments

Comments
 (0)