Skip to content

Commit 5415261

Browse files
authored
Merge pull request #22 from transformerlab/fix/change-all-dirs-to-functions
Change all dirs to functions
2 parents 452ca32 + e835f67 commit 5415261

File tree

6 files changed

+85
-56
lines changed

6 files changed

+85
-56
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "transformerlab"
7-
version = "0.0.14"
7+
version = "0.0.15"
88
description = "Python SDK for Transformer Lab"
99
readme = "README.md"
1010
requires-python = ">=3.10"

src/lab/dataset.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import os
22
from werkzeug.utils import secure_filename
33

4-
from .dirs import DATASETS_DIR
4+
from .dirs import get_datasets_dir
55
from .labresource import BaseLabResource
66

77

88
class Dataset(BaseLabResource):
99
def get_dir(self):
1010
"""Abstract method on BaseLabResource"""
1111
dataset_id_safe = secure_filename(str(self.id))
12-
return os.path.join(DATASETS_DIR, dataset_id_safe)
12+
return os.path.join(get_datasets_dir(), dataset_id_safe)
1313

1414
def _default_json(self):
1515
# Default metadata modeled after API dataset table fields
@@ -44,10 +44,11 @@ def get_metadata(self):
4444
@staticmethod
4545
def list_all():
4646
results = []
47-
if not os.path.isdir(DATASETS_DIR):
47+
datasets_dir = get_datasets_dir()
48+
if not os.path.isdir(datasets_dir):
4849
return results
49-
for entry in os.listdir(DATASETS_DIR):
50-
full = os.path.join(DATASETS_DIR, entry)
50+
for entry in os.listdir(datasets_dir):
51+
full = os.path.join(datasets_dir, entry)
5152
if not os.path.isdir(full):
5253
continue
5354
# Attempt to read index.json (or latest snapshot)

src/lab/dirs.py

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
print(f"Using default home directory: {HOME_DIR}")
1818

1919
# Context var for organization id (set by host app/session)
20-
_current_org_id: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_org_id", default=None)
20+
_current_org_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
21+
"current_org_id", default=None
22+
)
2123

2224

2325
def set_organization_id(organization_id: str | None) -> None:
@@ -70,23 +72,27 @@ def get_workspace_dir() -> str:
7072
"TFL_HOME_DIR", os.path.join(str(os.path.expanduser("~")), ".transformerlab")
7173
)
7274

75+
7376
def get_experiments_dir() -> str:
7477
path = os.path.join(get_workspace_dir(), "experiments")
7578
os.makedirs(name=path, exist_ok=True)
7679
return path
7780

81+
7882
def get_jobs_dir() -> str:
7983
path = os.path.join(get_workspace_dir(), "jobs")
8084
os.makedirs(name=path, exist_ok=True)
8185
return path
8286

83-
# GLOBAL_LOG_PATH
84-
# MTMIGRATE: This doesn't work in multi-tenant world
85-
GLOBAL_LOG_PATH = os.path.join(get_workspace_dir(), "transformerlab.log")
8687

87-
# OTHER LOGS DIR:
88-
LOGS_DIR = os.path.join(HOME_DIR, "logs")
89-
os.makedirs(name=LOGS_DIR, exist_ok=True)
88+
def get_global_log_path() -> str:
89+
return os.path.join(get_workspace_dir(), "transformerlab.log")
90+
91+
92+
def get_logs_dir() -> str:
93+
path = os.path.join(HOME_DIR, "logs")
94+
os.makedirs(name=path, exist_ok=True)
95+
return path
9096

9197

9298
# TODO: Move this to Experiment
@@ -95,47 +101,66 @@ def experiment_dir_by_name(experiment_name: str) -> str:
95101
return os.path.join(experiments_dir, experiment_name)
96102

97103

98-
PLUGIN_DIR = os.path.join(get_workspace_dir(), "plugins")
104+
def get_plugin_dir() -> str:
105+
return os.path.join(get_workspace_dir(), "plugins")
99106

100107

101108
def plugin_dir_by_name(plugin_name: str) -> str:
102109
plugin_name = secure_filename(plugin_name)
103-
return os.path.join(PLUGIN_DIR, plugin_name)
110+
return os.path.join(get_plugin_dir(), plugin_name)
104111

105112

106-
MODELS_DIR = os.path.join(get_workspace_dir(), "models")
107-
os.makedirs(name=MODELS_DIR, exist_ok=True)
113+
def get_models_dir() -> str:
114+
path = os.path.join(get_workspace_dir(), "models")
115+
os.makedirs(name=path, exist_ok=True)
116+
return path
117+
108118

109-
DATASETS_DIR = os.path.join(get_workspace_dir(), "datasets")
110-
os.makedirs(name=DATASETS_DIR, exist_ok=True)
119+
def get_datasets_dir() -> str:
120+
path = os.path.join(get_workspace_dir(), "datasets")
121+
os.makedirs(name=path, exist_ok=True)
122+
return path
111123

112-
# TASKS_DIR
113-
TASKS_DIR = os.path.join(WORKSPACE_DIR, "tasks")
114-
os.makedirs(name=TASKS_DIR, exist_ok=True)
124+
125+
def get_tasks_dir() -> str:
126+
path = os.path.join(get_workspace_dir(), "tasks")
127+
os.makedirs(name=path, exist_ok=True)
128+
return path
115129

116130

117131
def dataset_dir_by_id(dataset_id: str) -> str:
118-
return os.path.join(DATASETS_DIR, dataset_id)
132+
return os.path.join(get_datasets_dir(), dataset_id)
119133

120134

121-
TEMP_DIR = os.path.join(get_workspace_dir(), "temp")
122-
os.makedirs(name=TEMP_DIR, exist_ok=True)
135+
def get_temp_dir() -> str:
136+
path = os.path.join(get_workspace_dir(), "temp")
137+
os.makedirs(name=path, exist_ok=True)
138+
return path
139+
140+
141+
def get_prompt_templates_dir() -> str:
142+
path = os.path.join(get_workspace_dir(), "prompt_templates")
143+
os.makedirs(name=path, exist_ok=True)
144+
return path
145+
123146

147+
def get_tools_dir() -> str:
148+
path = os.path.join(get_workspace_dir(), "tools")
149+
os.makedirs(name=path, exist_ok=True)
150+
return path
124151

125-
# Prompt Templates Dir:
126-
PROMPT_TEMPLATES_DIR = os.path.join(get_workspace_dir(), "prompt_templates")
127-
os.makedirs(name=PROMPT_TEMPLATES_DIR, exist_ok=True)
128152

129-
# Tools Dir:
130-
TOOLS_DIR = os.path.join(get_workspace_dir(), "tools")
131-
os.makedirs(name=TOOLS_DIR, exist_ok=True)
153+
def get_batched_prompts_dir() -> str:
154+
path = os.path.join(get_workspace_dir(), "batched_prompts")
155+
os.makedirs(name=path, exist_ok=True)
156+
return path
132157

133-
# Batched Prompts Dir:
134-
BATCHED_PROMPTS_DIR = os.path.join(get_workspace_dir(), "batched_prompts")
135-
os.makedirs(name=BATCHED_PROMPTS_DIR, exist_ok=True)
136158

137-
GALLERIES_CACHE_DIR = os.path.join(get_workspace_dir(), "galleries")
138-
os.makedirs(name=GALLERIES_CACHE_DIR, exist_ok=True)
159+
def get_galleries_cache_dir() -> str:
160+
path = os.path.join(get_workspace_dir(), "galleries")
161+
os.makedirs(name=path, exist_ok=True)
162+
return path
163+
139164

140165
# Evals output file:
141166
# TODO: These should probably be in the plugin subclasses

src/lab/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import json
33
from werkzeug.utils import secure_filename
44

5-
from .dirs import MODELS_DIR
5+
from .dirs import get_models_dir
66
from .labresource import BaseLabResource
77

88

99
class Model(BaseLabResource):
1010
def get_dir(self):
1111
"""Abstract method on BaseLabResource"""
1212
model_id_safe = secure_filename(str(self.id))
13-
return os.path.join(MODELS_DIR, model_id_safe)
13+
return os.path.join(get_models_dir(), model_id_safe)
1414

1515
def _default_json(self):
1616
# Default metadata modeled after API model table fields
@@ -44,10 +44,11 @@ def get_metadata(self):
4444
def list_all():
4545
"""List all models in the filesystem, similar to dataset service"""
4646
results = []
47-
if not os.path.isdir(MODELS_DIR):
47+
models_dir = get_models_dir()
48+
if not os.path.isdir(models_dir):
4849
return results
49-
for entry in os.listdir(MODELS_DIR):
50-
full = os.path.join(MODELS_DIR, entry)
50+
for entry in os.listdir(models_dir):
51+
full = os.path.join(models_dir, entry)
5152
if not os.path.isdir(full):
5253
continue
5354
# Attempt to read index.json (or latest snapshot)

src/lab/task.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from datetime import datetime
33
from werkzeug.utils import secure_filename
44

5-
from .dirs import TASKS_DIR
5+
from .dirs import get_tasks_dir
66
from .labresource import BaseLabResource
77

88

99
class Task(BaseLabResource):
1010
def get_dir(self):
1111
"""Abstract method on BaseLabResource"""
1212
task_id_safe = secure_filename(str(self.id))
13-
return os.path.join(TASKS_DIR, task_id_safe)
13+
return os.path.join(get_tasks_dir(), task_id_safe)
1414

1515
def _default_json(self):
1616
# Default metadata modeled after API tasks table fields
@@ -61,10 +61,11 @@ def get_metadata(self):
6161
def list_all():
6262
"""List all tasks in the filesystem"""
6363
results = []
64-
if not os.path.isdir(TASKS_DIR):
64+
tasks_dir = get_tasks_dir()
65+
if not os.path.isdir(tasks_dir):
6566
return results
66-
for entry in os.listdir(TASKS_DIR):
67-
full = os.path.join(TASKS_DIR, entry)
67+
for entry in os.listdir(tasks_dir):
68+
full = os.path.join(tasks_dir, entry)
6869
if not os.path.isdir(full):
6970
continue
7071
# Attempt to read index.json (or latest snapshot)
@@ -108,10 +109,11 @@ def get_by_id(task_id: str):
108109
@staticmethod
109110
def delete_all():
110111
"""Delete all tasks"""
111-
if not os.path.isdir(TASKS_DIR):
112+
tasks_dir = get_tasks_dir()
113+
if not os.path.isdir(tasks_dir):
112114
return
113-
for entry in os.listdir(TASKS_DIR):
114-
full = os.path.join(TASKS_DIR, entry)
115+
for entry in os.listdir(tasks_dir):
116+
full = os.path.join(tasks_dir, entry)
115117
if os.path.isdir(full):
116118
import shutil
117119
shutil.rmtree(full)

tests/test_dirs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ def test_dirs_structure_created(monkeypatch, tmp_path):
2424
# Key directories exist
2525
assert os.path.isdir(dirs.get_experiments_dir())
2626
assert os.path.isdir(dirs.get_jobs_dir())
27-
assert os.path.isdir(dirs.MODELS_DIR)
28-
assert os.path.isdir(dirs.DATASETS_DIR)
29-
assert os.path.isdir(dirs.TEMP_DIR)
30-
assert os.path.isdir(dirs.PROMPT_TEMPLATES_DIR)
31-
assert os.path.isdir(dirs.TOOLS_DIR)
32-
assert os.path.isdir(dirs.BATCHED_PROMPTS_DIR)
33-
assert os.path.isdir(dirs.GALLERIES_CACHE_DIR)
27+
assert os.path.isdir(dirs.get_models_dir())
28+
assert os.path.isdir(dirs.get_datasets_dir())
29+
assert os.path.isdir(dirs.get_temp_dir())
30+
assert os.path.isdir(dirs.get_prompt_templates_dir())
31+
assert os.path.isdir(dirs.get_tools_dir())
32+
assert os.path.isdir(dirs.get_batched_prompts_dir())
33+
assert os.path.isdir(dirs.get_galleries_cache_dir())

0 commit comments

Comments
 (0)