Skip to content

Commit 66cfc79

Browse files
authored
Merge branch 'main' into fix/run-button-compatible-message
2 parents d67ba08 + 11d0af3 commit 66cfc79

File tree

136 files changed

+3744
-1190
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

136 files changed

+3744
-1190
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""add_team_id_to_config_table
2+
3+
Revision ID: c78d76a6d65c
4+
Revises: 1f7cb465d15a
5+
Create Date: 2025-12-04 11:23:22.165544
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
import sqlalchemy as sa
13+
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = "c78d76a6d65c"
17+
down_revision: Union[str, Sequence[str], None] = "1f7cb465d15a"
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
23+
"""Upgrade schema."""
24+
connection = op.get_bind()
25+
26+
# Check existing columns
27+
column_result = connection.execute(sa.text("PRAGMA table_info(config)"))
28+
existing_columns = [row[1] for row in column_result.fetchall()]
29+
30+
# Get existing indexes by querying SQLite directly
31+
# SQLite stores unique constraints as unique indexes
32+
index_result = connection.execute(
33+
sa.text("SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='config'")
34+
)
35+
existing_index_names = [row[0] for row in index_result.fetchall()]
36+
37+
# Add columns first (outside batch mode to avoid circular dependency)
38+
# Only add if they don't already exist
39+
if "user_id" not in existing_columns:
40+
op.add_column("config", sa.Column("user_id", sa.String(), nullable=True))
41+
if "team_id" not in existing_columns:
42+
op.add_column("config", sa.Column("team_id", sa.String(), nullable=True))
43+
44+
# Handle indexes outside of batch mode to avoid type inference issues
45+
# Drop existing unique index on key if it exists (to recreate as non-unique)
46+
if "ix_config_key" in existing_index_names:
47+
# Check if it's unique by querying the index definition
48+
index_info = connection.execute(
49+
sa.text("SELECT sql FROM sqlite_master WHERE type='index' AND name='ix_config_key'")
50+
).fetchone()
51+
if index_info and index_info[0] and "UNIQUE" in index_info[0].upper():
52+
# Drop the unique index using raw SQL to avoid batch mode issues
53+
connection.execute(sa.text("DROP INDEX IF EXISTS ix_config_key"))
54+
existing_index_names.remove("ix_config_key") # Update our list
55+
56+
# Create new indexes (non-unique) - these can be done outside batch mode
57+
if "ix_config_key" not in existing_index_names:
58+
op.create_index("ix_config_key", "config", ["key"], unique=False)
59+
if "ix_config_user_id" not in existing_index_names:
60+
op.create_index("ix_config_user_id", "config", ["user_id"], unique=False)
61+
if "ix_config_team_id" not in existing_index_names:
62+
op.create_index("ix_config_team_id", "config", ["team_id"], unique=False)
63+
64+
# For SQLite, unique constraints are stored as unique indexes
65+
# Create the unique constraint as a unique index using raw SQL to avoid batch mode issues
66+
if "uq_config_user_team_key" not in existing_index_names:
67+
connection.execute(
68+
sa.text("CREATE UNIQUE INDEX IF NOT EXISTS uq_config_user_team_key ON config(user_id, team_id, key)")
69+
)
70+
71+
# Migrate existing configs to admin user's first team
72+
# Note: Don't call connection.commit() - Alembic manages transactions
73+
connection = op.get_bind()
74+
# Find admin user's first team
75+
admin_team_result = connection.execute(
76+
sa.text("""
77+
SELECT ut.team_id
78+
FROM users_teams ut
79+
JOIN user u ON ut.user_id = u.id
80+
WHERE u.email = 'admin@example.com'
81+
LIMIT 1
82+
""")
83+
)
84+
admin_team_row = admin_team_result.fetchone()
85+
86+
if admin_team_row:
87+
admin_team_id = admin_team_row[0]
88+
# Update all existing configs (where team_id is NULL) to use admin team
89+
connection.execute(
90+
sa.text("UPDATE config SET team_id = :team_id WHERE team_id IS NULL"), {"team_id": admin_team_id}
91+
)
92+
print(f"✅ Migrated existing configs to team {admin_team_id}")
93+
else:
94+
# If no admin team found, try to get any user's first team
95+
any_team_result = connection.execute(sa.text("SELECT team_id FROM users_teams LIMIT 1"))
96+
any_team_row = any_team_result.fetchone()
97+
if any_team_row:
98+
any_team_id = any_team_row[0]
99+
connection.execute(
100+
sa.text("UPDATE config SET team_id = :team_id WHERE team_id IS NULL"), {"team_id": any_team_id}
101+
)
102+
print(f"✅ Migrated existing configs to team {any_team_id}")
103+
else:
104+
# No teams found, delete existing configs
105+
deleted_count = connection.execute(sa.text("DELETE FROM config WHERE team_id IS NULL")).rowcount
106+
print(f"⚠️ No teams found, deleted {deleted_count} config entries")
107+
# ### end Alembic commands ###
108+
109+
110+
def downgrade() -> None:
111+
"""Downgrade schema."""
112+
connection = op.get_bind()
113+
114+
# Check existing indexes
115+
index_result = connection.execute(
116+
sa.text("SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='config'")
117+
)
118+
existing_index_names = [row[0] for row in index_result.fetchall()]
119+
120+
# Check existing columns
121+
column_result = connection.execute(sa.text("PRAGMA table_info(config)"))
122+
existing_columns = [row[1] for row in column_result.fetchall()]
123+
124+
# Drop indexes and constraints outside of batch mode to avoid type inference issues
125+
# Drop unique constraint (stored as unique index in SQLite)
126+
if "uq_config_user_team_key" in existing_index_names:
127+
connection.execute(sa.text("DROP INDEX IF EXISTS uq_config_user_team_key"))
128+
129+
# Drop indexes
130+
if "ix_config_team_id" in existing_index_names:
131+
op.drop_index("ix_config_team_id", table_name="config")
132+
if "ix_config_user_id" in existing_index_names:
133+
op.drop_index("ix_config_user_id", table_name="config")
134+
if "ix_config_key" in existing_index_names:
135+
op.drop_index("ix_config_key", table_name="config")
136+
137+
# Drop columns using raw SQL to avoid batch mode type inference issues
138+
# SQLite doesn't support DROP COLUMN directly, so we recreate the table
139+
if "team_id" in existing_columns or "user_id" in existing_columns:
140+
# Create new table without user_id and team_id columns
141+
connection.execute(
142+
sa.text("""
143+
CREATE TABLE config_new (
144+
id INTEGER NOT NULL PRIMARY KEY,
145+
key VARCHAR NOT NULL,
146+
value VARCHAR
147+
)
148+
""")
149+
)
150+
# Copy data from old table to new table (only id, key, value columns)
151+
connection.execute(sa.text("INSERT INTO config_new (id, key, value) SELECT id, key, value FROM config"))
152+
# Drop old table (this also drops all indexes)
153+
connection.execute(sa.text("DROP TABLE config"))
154+
# Rename new table to original name
155+
connection.execute(sa.text("ALTER TABLE config_new RENAME TO config"))
156+
# Recreate the original unique index on key (it was dropped with the old table)
157+
op.create_index("ix_config_key", "config", ["key"], unique=True)
158+
else:
159+
# If we're not dropping columns, just recreate the unique index on key
160+
op.create_index("ix_config_key", "config", ["key"], unique=True)
161+
# ### end Alembic commands ###

api/pyproject.toml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ dependencies = [
4646
[project.optional-dependencies]
4747
nvidia = [
4848
# NVIDIA-specific packages
49-
"torch==2.8.0",
50-
"torchaudio==2.8.0",
51-
"torchvision==0.23.0",
49+
"torch==2.9.1",
50+
"torchaudio==2.9.1",
51+
"torchvision==0.24.1",
5252
"nvidia-ml-py==12.575.51",
5353
# Packages with versions (using no-gpu/nvidia versions)
5454
"accelerate==1.3.0",
@@ -68,9 +68,9 @@ nvidia = [
6868
]
6969
rocm = [
7070
# ROCm-specific packages
71-
"torch==2.8.0+rocm6.4",
72-
"torchaudio==2.8.0+rocm6.4",
73-
"torchvision==0.23.0+rocm6.4",
71+
"torch==2.9.1+rocm6.4",
72+
"torchaudio==2.9.1+rocm6.4",
73+
"torchvision==0.24.1+rocm6.4",
7474
"pyrsmi==0.2.0",
7575
# Packages with ROCm-specific versions
7676
"accelerate==1.6.0",
@@ -90,9 +90,9 @@ rocm = [
9090
]
9191
cpu = [
9292
# CPU-specific packages
93-
"torch==2.8.0",
94-
"torchaudio==2.8.0",
95-
"torchvision==0.23.0",
93+
"torch==2.9.1",
94+
"torchaudio==2.9.1",
95+
"torchvision==0.24.1",
9696
# Packages with versions (using no-gpu/cpu versions)
9797
"accelerate==1.3.0",
9898
"aiosqlite==0.20.0",
@@ -109,4 +109,4 @@ cpu = [
109109
"tensorboard==2.18.0",
110110
"tiktoken==0.8.0",
111111
"watchfiles==1.0.4",
112-
]
112+
]

api/run.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ elif command -v rocminfo &> /dev/null; then
105105
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib:/opt/rocm/lib64
106106
fi
107107

108+
# Temporary: Turn off python buffering or debug output made by print() may not show up in logs
109+
export PYTHONUNBUFFERED=1
110+
108111
echo "▶️ Starting the API server:"
109112
if [ "$RELOAD" = true ]; then
110113
echo "🔁 Reload the server on file changes"

api/test/api/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
def test_set_config(client):
33
response = client.get("/config/set", params={"k": "api_test_key", "v": "test_value"})
44
assert response.status_code == 200
5-
assert response.json() == {"key": "api_test_key", "value": "test_value"}
5+
assert response.json() == {"key": "api_test_key", "value": "test_value", "team_wide": True}
66

77

88
def test_get_config(client):

api/test/server/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_set(live_server):
2222

2323
response = requests.get(f"{live_server}/config/set", params={"k": "message", "v": "Hello, World!"}, headers=headers)
2424
assert response.status_code == 200
25-
assert response.json() == {"key": "message", "value": "Hello, World!"}
25+
assert response.json() == {"key": "message", "value": "Hello, World!", "team_wide": True}
2626

2727

2828
@pytest.mark.live_server

api/transformerlab/db/db.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from sqlalchemy import select
2-
from sqlalchemy.dialects.sqlite import insert # Correct import for SQLite upsert
32

43
# from sqlalchemy import create_engine
54
from sqlalchemy.ext.asyncio import AsyncSession
@@ -25,17 +24,93 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
2524
###############
2625

2726

28-
async def config_get(key: str):
27+
async def config_get(key: str, user_id: str | None = None, team_id: str | None = None):
28+
"""
29+
Get config value with priority: user-specific -> team-specific -> global.
30+
31+
Priority order:
32+
1. User-specific (user_id set, team_id matches current team)
33+
2. Team-specific (user_id IS NULL, team_id set)
34+
"""
2935
async with async_session() as session:
30-
result = await session.execute(select(Config.value).where(Config.key == key))
36+
# First try user-specific config (if user_id provided)
37+
if user_id and team_id:
38+
result = await session.execute(
39+
select(Config.value)
40+
.where(Config.key == key, Config.user_id == user_id, Config.team_id == team_id)
41+
.limit(1)
42+
)
43+
row = result.scalar_one_or_none()
44+
if row is not None:
45+
return row
46+
47+
# Then try team-specific config (user_id IS NULL, team_id set)
48+
if team_id:
49+
result = await session.execute(
50+
select(Config.value)
51+
.where(Config.key == key, Config.user_id.is_(None), Config.team_id == team_id)
52+
.limit(1)
53+
)
54+
row = result.scalar_one_or_none()
55+
if row is not None:
56+
return row
57+
58+
# Finally fallback to global config (user_id IS NULL, team_id IS NULL)
59+
result = await session.execute(
60+
select(Config.value).where(Config.key == key, Config.user_id.is_(None), Config.team_id.is_(None)).limit(1)
61+
)
3162
row = result.scalar_one_or_none()
3263
return row
3364

3465

35-
async def config_set(key: str, value: str):
36-
stmt = insert(Config).values(key=key, value=value)
37-
stmt = stmt.on_conflict_do_update(index_elements=["key"], set_={"value": value})
66+
async def config_set(key: str, value: str, user_id: str | None = None, team_id: str | None = None):
67+
"""
68+
Set config value.
69+
70+
Args:
71+
key: Config key
72+
value: Config value
73+
user_id: User ID for user-specific config. If None, sets team-wide config.
74+
team_id: Team ID for team-specific config. If None, sets global config.
75+
76+
Config types:
77+
- User-specific: user_id is set, team_id is set
78+
- Team-wide: user_id is None, team_id is set
79+
"""
3880
async with async_session() as session:
39-
await session.execute(stmt)
81+
# Check if config already exists
82+
if user_id is None and team_id is None:
83+
# Global config: both user_id and team_id are NULL
84+
result = await session.execute(
85+
select(Config).where(Config.key == key, Config.user_id.is_(None), Config.team_id.is_(None))
86+
)
87+
elif user_id is None:
88+
# Team-wide config: user_id is NULL, team_id is set
89+
result = await session.execute(
90+
select(Config).where(Config.key == key, Config.user_id.is_(None), Config.team_id == team_id)
91+
)
92+
else:
93+
# User-specific config: both user_id and team_id are set
94+
# Note: team_id should always be set when user_id is set (validated by router)
95+
if team_id is None:
96+
raise ValueError("team_id is required when user_id is set for user-specific configs")
97+
result = await session.execute(
98+
select(Config).where(
99+
Config.key == key,
100+
Config.user_id == user_id,
101+
Config.team_id == team_id,
102+
)
103+
)
104+
105+
existing = result.scalar_one_or_none()
106+
107+
if existing:
108+
# Update existing config
109+
existing.value = value
110+
else:
111+
# Insert new config
112+
new_config = Config(key=key, value=value, user_id=user_id, team_id=team_id)
113+
session.add(new_config)
114+
40115
await session.commit()
41116
return

api/transformerlab/fastchat_openai_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
ErrorResponse,
4242
ModelCard,
4343
ModelList,
44-
ModelPermission,
4544
UsageInfo,
4645
)
4746
from pydantic import BaseModel as PydanticBaseModel
@@ -520,7 +519,7 @@ async def show_available_models():
520519
# TODO: return real model permission details
521520
model_cards = []
522521
for m in models:
523-
model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()]))
522+
model_cards.append(ModelCard(id=m, root=m, permission=[]))
524523
return ModelList(data=model_cards)
525524

526525
# If no models, refresh and try again
@@ -532,7 +531,7 @@ async def show_available_models():
532531
# TODO: return real model permission details
533532
model_cards = []
534533
for m in models:
535-
model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()]))
534+
model_cards.append(ModelCard(id=m, root=m, permission=[]))
536535
return ModelList(data=model_cards)
537536

538537

0 commit comments

Comments
 (0)