Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions channels.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
permissions:
everyone:
send_messages: false
volunteer:
blueshirt:
read_messages: true
send_messages: true
- name: blog
Expand Down Expand Up @@ -48,15 +48,15 @@
everyone:
read_messages: false
send_messages: false
volunteer:
blueshirt:
read_messages: true
send_messages: true
channels:
- name: audit-log
use_case: audit
- name: blueshirt-onboarding
permissions:
unverified_volunteer:
unverified_blueshirt:
read_messages: true
- name: blueshirt-banter
- name: blueshirt-serious-business
Expand Down Expand Up @@ -85,7 +85,7 @@
everyone:
read_messages: false
send_messages: false
volunteer:
blueshirt:
read_messages: true
send_messages: true
supervisor:
Expand Down
16 changes: 8 additions & 8 deletions src/sr/discord_bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import asyncio
import logging
from pathlib import Path
from typing import List, Literal

import yaml
Expand Down Expand Up @@ -124,7 +125,7 @@ async def on_ready(self) -> None:
await self.apply_changes()
await self.close()

async def _set_roles_and_channels(self, guild: Guild) -> None:
async def set_roles_and_channels(self, guild: Guild) -> None:
roles = await guild.fetch_roles()
self.admin_role = find_role_by_name(roles, ADMIN_ROLE)
self.verified_role = find_role_by_name(roles, VERIFIED_ROLE)
Expand All @@ -151,11 +152,11 @@ async def setup_bot(self) -> None:
return

try:
await self._set_roles_and_channels(guild)
await self.set_roles_and_channels(guild)
except ValueError:
self.logger.info("Setting up guild...")
await setup_guild(self)
await self._set_roles_and_channels(guild)
await self.set_roles_and_channels(guild)

await check_bot_messages(self, self.guild)
self.teams_data.gen_team_memberships(self.guild, self.supervisor_role)
Expand Down Expand Up @@ -288,13 +289,12 @@ def _load_passwords(self) -> None:
teamname:password
```
"""
path = Path('passwords.json')
try:
with open('passwords.json') as f:
self.passwords = json.load(f)
self.passwords = json.loads(path.read_text())
except (json.JSONDecodeError, FileNotFoundError):
with open('passwords.json', 'w') as f:
f.write('{}')
self.passwords = {}
path.write_text('{}')
self.passwords = {}

def set_password(self, tla: str, password: str) -> None:
self.passwords[tla.upper()] = password
Expand Down
69 changes: 39 additions & 30 deletions src/sr/discord_bot/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,13 @@


def match_overwrites(overwrites: Overwrites, guild: Guild) -> Mapping[Role | Member | Object, PermissionOverwrite]:
role_overwrites: Mapping[Role | Member | Object, PermissionOverwrite] = {
guild.default_role: PermissionOverwrite(**overwrites.get(RoleType.EVERYONE, {})),
find_role_by_name(guild.roles, VERIFIED_ROLE): PermissionOverwrite(**overwrites.get(RoleType.VERIFIED, {})),
find_role_by_name(guild.roles, VOLUNTEER_ROLE): PermissionOverwrite(**overwrites.get(RoleType.BLUESHIRT, {})),
find_role_by_name(guild.roles, TEAM_LEADER_ROLE): PermissionOverwrite(
**overwrites.get(RoleType.SUPERVISOR, {}),
),
find_role_by_name(guild.roles, SPECIAL_ROLE): PermissionOverwrite(
**overwrites.get(RoleType.UNVERIFIED_BLUESHIRT, {}),
),
return {
guild.default_role: overwrites.get(RoleType.EVERYONE, PermissionOverwrite()),
find_role_by_name(guild.roles, VERIFIED_ROLE): overwrites.get(RoleType.VERIFIED, PermissionOverwrite()),
find_role_by_name(guild.roles, VOLUNTEER_ROLE): overwrites.get(RoleType.BLUESHIRT, PermissionOverwrite()),
find_role_by_name(guild.roles, TEAM_LEADER_ROLE): overwrites.get(RoleType.SUPERVISOR, PermissionOverwrite()),
find_role_by_name(guild.roles, SPECIAL_ROLE): overwrites.get(RoleType.UNVERIFIED_BLUESHIRT, PermissionOverwrite()),
}
return role_overwrites


class ChannelSet:
Expand All @@ -84,14 +79,11 @@ def _get_role_type(self, role: Role, guild: Guild) -> RoleType:
return self._role_map[role]

def _get_overwrites(self, channel: GuildChannel) -> Overwrites:
overwrites: Overwrites = {}
for subject, permissions in channel.overwrites.items():
if isinstance(subject, Role):
overwrites[self._get_role_type(subject, channel.guild)] = {}
for permission, value in permissions:
if value is not None:
overwrites[self._get_role_type(subject, channel.guild)][permission] = value
return overwrites
return {
self._get_role_type(subject, channel.guild): permissions
for subject, permissions in channel.overwrites.items()
if isinstance(subject, Role)
}

def count(self, category: Channel | None) -> int:
"""Count the number of channels in a category, or top-level channels if category is None."""
Expand Down Expand Up @@ -197,7 +189,6 @@ def create_forum_channel(
def diff(cls, old: ChannelSet, new: ChannelSet) -> list[ChannelCommand]:
"""Calculate the difference between two ChannelSets."""
commands: list[ChannelCommand] = []
# TODO: Diff categories first (in case one gets renamed)
cls._diff_channels(old._channels, new._channels, commands)
return commands

Expand Down Expand Up @@ -253,7 +244,7 @@ def _diff_channels(cls, old: list[Channel], new: list[Channel], commands: list[C
commands.append(command)
else:
seen_old_channel_names.add(old_channel.name)
command = AlterChannelCommand.diff(old_channel, new_channel, new_index, old_channel.channel_type)
command = AlterChannelCommand.diff(old_channel, new_channel, old_channel.channel_type)
if command.has_changes:
commands.append(command)

Expand Down Expand Up @@ -424,7 +415,6 @@ class AlterChannelCommand(Command):
overwrites: Overwrites | None = None
topic: str | None = None
category: Channel | None = None
position: int | None = None
is_category: bool = False # Whether this command is for a category, only for display purposes
channel_type: ChannelType = ChannelType.text # Used for sorting
# Forum-specific:
Expand All @@ -434,8 +424,7 @@ class AlterChannelCommand(Command):
use_case: ChannelUseCase | None = None

@classmethod
def diff(cls, old_channel: Channel, new_channel: Channel,
new_position: int, channel_type: ChannelType) -> AlterChannelCommand:
def diff(cls, old_channel: Channel, new_channel: Channel, channel_type: ChannelType) -> AlterChannelCommand:
command = cls(old_name=old_channel.name, new_name=new_channel.name,
is_category=new_channel.is_category, channel_type=channel_type, use_case=new_channel.use_case)
changed_category = new_channel.category is not None and old_channel.category != new_channel.category
Expand All @@ -445,7 +434,6 @@ def diff(cls, old_channel: Channel, new_channel: Channel,
command.topic = new_channel.topic
if not old_channel.category or changed_category:
command.category = new_channel.category
command.position = new_position
return command

def is_rename(self) -> bool:
Expand Down Expand Up @@ -484,6 +472,22 @@ def __str__(self) -> str:
return f"ALTER Category \"{self.old_name}\" ({', '.join(changes)})"
return f"ALTER Channel #{self.old_name} ({', '.join(changes)})"

@classmethod
def get_role_by_role_type(cls, role_type: RoleType, guild: Guild) -> Role:
if role_type == RoleType.EVERYONE:
return guild.default_role

role_names = {
RoleType.VERIFIED: VERIFIED_ROLE,
RoleType.BLUESHIRT: VOLUNTEER_ROLE,
RoleType.SUPERVISOR: TEAM_LEADER_ROLE,
RoleType.UNVERIFIED_BLUESHIRT: SPECIAL_ROLE,
}

if role := discord.utils.get(guild.roles, name=role_names[role_type]):
return role
raise ValueError("Invalid role type")

async def apply(self, guild: Guild) -> None:
channel = discord.utils.get(guild.channels, name=self.old_name)
if channel is None:
Expand All @@ -499,8 +503,13 @@ async def apply(self, guild: Guild) -> None:
if self.category is not None:
kwargs["category"] = find_channel_by_name(guild.channels, self.category.name) # type: ignore

if self.overwrites is not None and self.overwrites != {}:
kwargs["overwrites"] = self.overwrites # type: ignore
if self.overwrites is not None:
overwrites = {
AlterChannelCommand.get_role_by_role_type(role_type, guild): permissions
for role_type, permissions in self.overwrites.items()
}
if overwrites != channel.overwrites:
kwargs["overwrites"] = overwrites # type: ignore

if channel.type == ChannelType.forum:
await guild.fetch_emojis()
Expand Down Expand Up @@ -547,7 +556,7 @@ async def apply(self, guild: Guild) -> None:
@dataclasses.dataclass(frozen=True)
class CreateCategoryCommand(Command):
name: str
overwrites: Overwrites | None = None
overwrites: Overwrites = dataclasses.field(default_factory=dict)

def __str__(self) -> str:
return f'CREATE Category "{self.name}"'
Expand All @@ -563,7 +572,7 @@ class CreateChannelCommand(Command):
name: str
channel_type: ChannelType
category: Channel | None = None
overwrites: Overwrites | None = None
overwrites: Overwrites = dataclasses.field(default_factory=dict)
topic: str = ""
# Bot use only:
use_case: ChannelUseCase | None = None
Expand Down Expand Up @@ -615,7 +624,7 @@ async def apply(self, guild: Guild) -> None:
class CreateForumCommand(Command):
name: str
category: Channel | None = None
overwrites: Overwrites | None = None
overwrites: Overwrites = dataclasses.field(default_factory=dict)
topic: str = ""
default_reaction_emoji: str | None = None
available_tags: list[ForumTagDefinition] = dataclasses.field(default_factory=list)
Expand Down
2 changes: 1 addition & 1 deletion src/sr/discord_bot/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def create_channels(client: "BotClient", guild: Guild) -> None:
await asyncio.sleep(.5) # avoid hitting rate limits

logging.info("Channels created")
await client._set_roles_and_channels(guild)
await client.set_roles_and_channels(guild)


async def send_template_messages(client: "BotClient", guild: Guild) -> None:
Expand Down
9 changes: 6 additions & 3 deletions src/sr/discord_bot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Any

from discord import ForumTag, ChannelType, PartialEmoji
from discord import ForumTag, ChannelType, PartialEmoji, PermissionOverwrite


class RoleType(StrEnum):
Expand All @@ -19,7 +19,7 @@ class RoleType(StrEnum):
UNVERIFIED_BLUESHIRT = "unverified_blueshirt"


Overwrites = dict[RoleType, dict[str, bool]]
Overwrites = dict[RoleType, PermissionOverwrite]


class ChannelUseCase(StrEnum):
Expand Down Expand Up @@ -77,7 +77,10 @@ def load( # type: ignore[misc] # This gets validated when the YAML is validated
channel = cls(
name=data["name"],
old_names=data.get("old_names", []),
overwrites=data.get("permissions", {}),
overwrites={
RoleType(role_type): PermissionOverwrite(**permissions)
for role_type, permissions in data.get("permissions", {}).items()
},
topic=data.get("topic", ""),
category=category,
channel_type=ChannelType[data.get("channel_type", default_type)],
Expand Down
13 changes: 9 additions & 4 deletions src/sr/discord_bot/teams.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from typing import List, NamedTuple
from statistics import mean
from collections import defaultdict

import discord

from sr.discord_bot.constants import ROLE_PREFIX
TEAM_ROLE_REGEX = re.compile('Team (?P<TLA>[A-Z]{3}\d?)')


class TeamData(NamedTuple):
Expand All @@ -28,7 +29,7 @@ def school(self) -> str:

def __str__(self) -> str:
data_str = f'{self.TLA:<15} {self.members:>2}'
if self.leader is False:
if not self.leader:
data_str += ' No supervisor'
return data_str

Expand All @@ -42,9 +43,13 @@ def gen_team_memberships(self, guild: discord.Guild, leader_role: discord.Role)
"""Generate a list of TeamData objects for the given guild, stored in teams_data."""
teams_data = []

for role in filter(lambda role: role.name.startswith(ROLE_PREFIX), guild.roles):
for role in guild.roles:
match = TEAM_ROLE_REGEX.match(role.name)
if match is None:
continue

team_data = TeamData(
TLA=role.name[len(ROLE_PREFIX):],
TLA=match['TLA'],
members=len(list(filter(
lambda member: leader_role not in member.roles,
role.members,
Expand Down
Loading
Loading