Skip to content

Commit ae487f6

Browse files
committed
Add endpoint for team create update and delete
1 parent fe58699 commit ae487f6

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

transformerlab/routers/teams.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from fastapi import APIRouter, Depends, HTTPException
2+
from sqlalchemy.ext.asyncio import AsyncSession
3+
from transformerlab.shared.models.user_model import User, Team, UserTeam, get_async_session
4+
from transformerlab.models.users import current_active_user
5+
from pydantic import BaseModel
6+
from sqlalchemy import select, delete, update
7+
8+
9+
class TeamCreate(BaseModel):
10+
name: str
11+
12+
13+
class TeamUpdate(BaseModel):
14+
name: str
15+
16+
17+
class TeamResponse(BaseModel):
18+
id: str
19+
name: str
20+
21+
22+
router = APIRouter(tags=["teams"])
23+
24+
25+
@router.post("/teams", response_model=TeamResponse)
26+
async def create_team(
27+
team_data: TeamCreate,
28+
session: AsyncSession = Depends(get_async_session),
29+
user: User = Depends(current_active_user),
30+
):
31+
# Create team
32+
team = Team(name=team_data.name)
33+
session.add(team)
34+
await session.commit()
35+
await session.refresh(team)
36+
37+
# Add user to the team
38+
user_team = UserTeam(user_id=str(user.id), team_id=team.id)
39+
session.add(user_team)
40+
await session.commit()
41+
42+
return TeamResponse(id=team.id, name=team.name)
43+
44+
45+
@router.put("/teams/{team_id}", response_model=TeamResponse)
46+
async def update_team(
47+
team_id: str,
48+
team_data: TeamUpdate,
49+
session: AsyncSession = Depends(get_async_session),
50+
user: User = Depends(current_active_user),
51+
):
52+
# Check if user is in the team
53+
stmt = select(UserTeam).where(UserTeam.user_id == str(user.id), UserTeam.team_id == team_id)
54+
result = await session.execute(stmt)
55+
if not result.scalar_one_or_none():
56+
raise HTTPException(status_code=403, detail="Not authorized to update this team")
57+
58+
# Update
59+
stmt = update(Team).where(Team.id == team_id).values(name=team_data.name)
60+
await session.execute(stmt)
61+
await session.commit()
62+
63+
# Fetch updated
64+
stmt = select(Team).where(Team.id == team_id)
65+
result = await session.execute(stmt)
66+
team = result.scalar_one()
67+
68+
return TeamResponse(id=team.id, name=team.name)
69+
70+
71+
@router.delete("/teams/{team_id}")
72+
async def delete_team(
73+
team_id: str,
74+
session: AsyncSession = Depends(get_async_session),
75+
user: User = Depends(current_active_user),
76+
):
77+
# Check if user is in the team
78+
stmt = select(UserTeam).where(UserTeam.user_id == str(user.id), UserTeam.team_id == team_id)
79+
result = await session.execute(stmt)
80+
if not result.scalar_one_or_none():
81+
raise HTTPException(status_code=403, detail="Not authorized to delete this team")
82+
83+
# Check if user has other teams
84+
stmt = select(UserTeam).where(UserTeam.user_id == str(user.id))
85+
result = await session.execute(stmt)
86+
user_teams = result.scalars().all()
87+
if len(user_teams) <= 1:
88+
raise HTTPException(status_code=400, detail="Cannot delete the last team")
89+
90+
# Check if team has only this user
91+
stmt = select(UserTeam).where(UserTeam.team_id == team_id)
92+
result = await session.execute(stmt)
93+
team_users = result.scalars().all()
94+
if len(team_users) > 1:
95+
raise HTTPException(status_code=400, detail="Cannot delete team with multiple users")
96+
97+
# Delete associations and team
98+
stmt = delete(UserTeam).where(UserTeam.team_id == team_id)
99+
await session.execute(stmt)
100+
stmt = delete(Team).where(Team.id == team_id)
101+
await session.execute(stmt)
102+
await session.commit()
103+
104+
return {"message": "Team deleted"}

0 commit comments

Comments
 (0)