Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions channels.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
permissions:
everyone:
send_messages: false
volunteer:
blueshirt:
read_messages: true
send_messages: true
- name: blog
Expand Down Expand Up @@ -48,15 +48,15 @@
everyone:
read_messages: false
send_messages: false
volunteer:
blueshirt:
read_messages: true
send_messages: true
channels:
- name: audit-log
use_case: audit
- name: blueshirt-onboarding
permissions:
unverified_volunteer:
unverified_blueshirt:
read_messages: true
- name: blueshirt-banter
- name: blueshirt-serious-business
Expand Down Expand Up @@ -85,7 +85,7 @@
everyone:
read_messages: false
send_messages: false
volunteer:
blueshirt:
read_messages: true
send_messages: true
supervisor:
Expand Down
16 changes: 8 additions & 8 deletions src/sr/discord_bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
from typing import List, Literal
from pathlib import Path

import yaml
import discord
Expand Down Expand Up @@ -124,7 +125,7 @@ async def on_ready(self) -> None:
await self.apply_changes()
await self.close()

async def _set_roles_and_channels(self, guild: Guild) -> None:
async def set_roles_and_channels(self, guild: Guild) -> None:
roles = await guild.fetch_roles()
self.admin_role = find_role_by_name(roles, ADMIN_ROLE)
self.verified_role = find_role_by_name(roles, VERIFIED_ROLE)
Expand All @@ -151,11 +152,11 @@ async def setup_bot(self) -> None:
return

try:
await self._set_roles_and_channels(guild)
await self.set_roles_and_channels(guild)
except ValueError:
self.logger.info("Setting up guild...")
await setup_guild(self)
await self._set_roles_and_channels(guild)
await self.set_roles_and_channels(guild)

await check_bot_messages(self, self.guild)
self.teams_data.gen_team_memberships(self.guild, self.supervisor_role)
Expand Down Expand Up @@ -288,13 +289,12 @@ def _load_passwords(self) -> None:
teamname:password
```
"""
path = Path('passwords.json')
try:
with open('passwords.json') as f:
self.passwords = json.load(f)
self.passwords = json.loads(path.read_text())
except (json.JSONDecodeError, FileNotFoundError):
with open('passwords.json', 'w') as f:
f.write('{}')
self.passwords = {}
path.write_text('{}')
self.passwords = {}

def set_password(self, tla: str, password: str) -> None:
self.passwords[tla.upper()] = password
Expand Down
72 changes: 40 additions & 32 deletions src/sr/discord_bot/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,14 @@


def match_overwrites(overwrites: Overwrites, guild: Guild) -> Mapping[Role | Member | Object, PermissionOverwrite]:
role_overwrites: Mapping[Role | Member | Object, PermissionOverwrite] = {
guild.default_role: PermissionOverwrite(**overwrites.get(RoleType.EVERYONE, {})),
find_role_by_name(guild.roles, VERIFIED_ROLE): PermissionOverwrite(**overwrites.get(RoleType.VERIFIED, {})),
find_role_by_name(guild.roles, VOLUNTEER_ROLE): PermissionOverwrite(**overwrites.get(RoleType.BLUESHIRT, {})),
find_role_by_name(guild.roles, TEAM_LEADER_ROLE): PermissionOverwrite(
**overwrites.get(RoleType.SUPERVISOR, {}),
),
find_role_by_name(guild.roles, SPECIAL_ROLE): PermissionOverwrite(
**overwrites.get(RoleType.UNVERIFIED_BLUESHIRT, {}),
),
return {
guild.default_role: overwrites.get(RoleType.EVERYONE, PermissionOverwrite()),
find_role_by_name(guild.roles, VERIFIED_ROLE): overwrites.get(RoleType.VERIFIED, PermissionOverwrite()),
find_role_by_name(guild.roles, VOLUNTEER_ROLE): overwrites.get(RoleType.BLUESHIRT, PermissionOverwrite()),
find_role_by_name(guild.roles, TEAM_LEADER_ROLE): overwrites.get(RoleType.SUPERVISOR, PermissionOverwrite()),
find_role_by_name(guild.roles, SPECIAL_ROLE):
overwrites.get(RoleType.UNVERIFIED_BLUESHIRT, PermissionOverwrite()),
}
return role_overwrites


class ChannelSet:
Expand All @@ -84,14 +80,11 @@ def _get_role_type(self, role: Role, guild: Guild) -> RoleType:
return self._role_map[role]

def _get_overwrites(self, channel: GuildChannel) -> Overwrites:
overwrites: Overwrites = {}
for subject, permissions in channel.overwrites.items():
if isinstance(subject, Role):
overwrites[self._get_role_type(subject, channel.guild)] = {}
for permission, value in permissions:
if value is not None:
overwrites[self._get_role_type(subject, channel.guild)][permission] = value
return overwrites
return {
self._get_role_type(subject, channel.guild): permissions
for subject, permissions in channel.overwrites.items()
if isinstance(subject, Role)
}

def count(self, category: Channel | None) -> int:
"""Count the number of channels in a category, or top-level channels if category is None."""
Expand Down Expand Up @@ -197,7 +190,6 @@ def create_forum_channel(
def diff(cls, old: ChannelSet, new: ChannelSet) -> list[ChannelCommand]:
"""Calculate the difference between two ChannelSets."""
commands: list[ChannelCommand] = []
# TODO: Diff categories first (in case one gets renamed)
cls._diff_channels(old._channels, new._channels, commands)
return commands

Expand Down Expand Up @@ -253,7 +245,7 @@ def _diff_channels(cls, old: list[Channel], new: list[Channel], commands: list[C
commands.append(command)
else:
seen_old_channel_names.add(old_channel.name)
command = AlterChannelCommand.diff(old_channel, new_channel, new_index, old_channel.channel_type)
command = AlterChannelCommand.diff(old_channel, new_channel, old_channel.channel_type)
if command.has_changes:
commands.append(command)

Expand Down Expand Up @@ -424,7 +416,6 @@ class AlterChannelCommand(Command):
overwrites: Overwrites | None = None
topic: str | None = None
category: Channel | None = None
position: int | None = None
is_category: bool = False # Whether this command is for a category, only for display purposes
channel_type: ChannelType = ChannelType.text # Used for sorting
# Forum-specific:
Expand All @@ -434,8 +425,7 @@ class AlterChannelCommand(Command):
use_case: ChannelUseCase | None = None

@classmethod
def diff(cls, old_channel: Channel, new_channel: Channel,
new_position: int, channel_type: ChannelType) -> AlterChannelCommand:
def diff(cls, old_channel: Channel, new_channel: Channel, channel_type: ChannelType) -> AlterChannelCommand:
command = cls(old_name=old_channel.name, new_name=new_channel.name,
is_category=new_channel.is_category, channel_type=channel_type, use_case=new_channel.use_case)
changed_category = new_channel.category is not None and old_channel.category != new_channel.category
Expand All @@ -445,7 +435,6 @@ def diff(cls, old_channel: Channel, new_channel: Channel,
command.topic = new_channel.topic
if not old_channel.category or changed_category:
command.category = new_channel.category
command.position = new_position
return command

def is_rename(self) -> bool:
Expand Down Expand Up @@ -484,6 +473,22 @@ def __str__(self) -> str:
return f"ALTER Category \"{self.old_name}\" ({', '.join(changes)})"
return f"ALTER Channel #{self.old_name} ({', '.join(changes)})"

@classmethod
def get_role_by_role_type(cls, role_type: RoleType, guild: Guild) -> Role:
if role_type == RoleType.EVERYONE:
return guild.default_role

role_names = {
RoleType.VERIFIED: VERIFIED_ROLE,
RoleType.BLUESHIRT: VOLUNTEER_ROLE,
RoleType.SUPERVISOR: TEAM_LEADER_ROLE,
RoleType.UNVERIFIED_BLUESHIRT: SPECIAL_ROLE,
}

if role := discord.utils.get(guild.roles, name=role_names[role_type]):
return role
raise ValueError("Invalid role type")

async def apply(self, guild: Guild) -> None:
channel = discord.utils.get(guild.channels, name=self.old_name)
if channel is None:
Expand All @@ -499,8 +504,13 @@ async def apply(self, guild: Guild) -> None:
if self.category is not None:
kwargs["category"] = find_channel_by_name(guild.channels, self.category.name) # type: ignore

if self.overwrites is not None and self.overwrites != {}:
kwargs["overwrites"] = self.overwrites # type: ignore
if self.overwrites is not None:
overwrites = {
AlterChannelCommand.get_role_by_role_type(role_type, guild): permissions
for role_type, permissions in self.overwrites.items()
}
if overwrites != channel.overwrites:
kwargs["overwrites"] = overwrites # type: ignore

if channel.type == ChannelType.forum:
await guild.fetch_emojis()
Expand Down Expand Up @@ -547,14 +557,12 @@ async def apply(self, guild: Guild) -> None:
@dataclasses.dataclass(frozen=True)
class CreateCategoryCommand(Command):
name: str
overwrites: Overwrites | None = None
overwrites: Overwrites = dataclasses.field(default_factory=dict)

def __str__(self) -> str:
return f'CREATE Category "{self.name}"'

async def apply(self, guild: Guild) -> None:
if self.overwrites is None:
return
await guild.create_category(self.name, overwrites=match_overwrites(self.overwrites, guild))


Expand All @@ -563,7 +571,7 @@ class CreateChannelCommand(Command):
name: str
channel_type: ChannelType
category: Channel | None = None
overwrites: Overwrites | None = None
overwrites: Overwrites = dataclasses.field(default_factory=dict)
topic: str = ""
# Bot use only:
use_case: ChannelUseCase | None = None
Expand Down Expand Up @@ -615,7 +623,7 @@ async def apply(self, guild: Guild) -> None:
class CreateForumCommand(Command):
name: str
category: Channel | None = None
overwrites: Overwrites | None = None
overwrites: Overwrites = dataclasses.field(default_factory=dict)
topic: str = ""
default_reaction_emoji: str | None = None
available_tags: list[ForumTagDefinition] = dataclasses.field(default_factory=list)
Expand Down
2 changes: 1 addition & 1 deletion src/sr/discord_bot/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def create_channels(client: "BotClient", guild: Guild) -> None:
await asyncio.sleep(.5) # avoid hitting rate limits

logging.info("Channels created")
await client._set_roles_and_channels(guild)
await client.set_roles_and_channels(guild)


async def send_template_messages(client: "BotClient", guild: Guild) -> None:
Expand Down
9 changes: 6 additions & 3 deletions src/sr/discord_bot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Any

from discord import ForumTag, ChannelType, PartialEmoji
from discord import ForumTag, ChannelType, PartialEmoji, PermissionOverwrite


class RoleType(StrEnum):
Expand All @@ -19,7 +19,7 @@ class RoleType(StrEnum):
UNVERIFIED_BLUESHIRT = "unverified_blueshirt"


Overwrites = dict[RoleType, dict[str, bool]]
Overwrites = dict[RoleType, PermissionOverwrite]


class ChannelUseCase(StrEnum):
Expand Down Expand Up @@ -77,7 +77,10 @@ def load( # type: ignore[misc] # This gets validated when the YAML is validated
channel = cls(
name=data["name"],
old_names=data.get("old_names", []),
overwrites=data.get("permissions", {}),
overwrites={
RoleType(role_type): PermissionOverwrite(**permissions)
for role_type, permissions in data.get("permissions", {}).items()
},
topic=data.get("topic", ""),
category=category,
channel_type=ChannelType[data.get("channel_type", default_type)],
Expand Down
13 changes: 9 additions & 4 deletions src/sr/discord_bot/teams.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from typing import List, NamedTuple
from statistics import mean
from collections import defaultdict

import discord

from sr.discord_bot.constants import ROLE_PREFIX
TEAM_ROLE_REGEX = re.compile(r'Team (?P<TLA>[A-Z]{3}\d?)')


class TeamData(NamedTuple):
Expand All @@ -28,7 +29,7 @@ def school(self) -> str:

def __str__(self) -> str:
data_str = f'{self.TLA:<15} {self.members:>2}'
if self.leader is False:
if not self.leader:
data_str += ' No supervisor'
return data_str

Expand All @@ -42,9 +43,13 @@ def gen_team_memberships(self, guild: discord.Guild, leader_role: discord.Role)
"""Generate a list of TeamData objects for the given guild, stored in teams_data."""
teams_data = []

for role in filter(lambda role: role.name.startswith(ROLE_PREFIX), guild.roles):
for role in guild.roles:
match = TEAM_ROLE_REGEX.match(role.name)
if match is None:
continue

team_data = TeamData(
TLA=role.name[len(ROLE_PREFIX):],
TLA=match['TLA'],
members=len(list(filter(
lambda member: leader_role not in member.roles,
role.members,
Expand Down
Loading
Loading