Skip to content

Commit a6b0cc3

Browse files
authored
Merge pull request #30 from transformerlab/add/config-string-checks
add back string checks for exp config
2 parents 48028e0 + 24c4aa7 commit a6b0cc3

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "transformerlab"
7-
version = "0.0.19"
7+
version = "0.0.20"
88
description = "Python SDK for Transformer Lab"
99
readme = "README.md"
1010
requires-python = ">=3.10"

src/lab/experiment.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,36 @@ def _initialize(self):
3838
def update_config_field(self, key, value):
3939
"""Update a single key in config."""
4040
current_config = self._get_json_data_field("config", {})
41+
if isinstance(current_config, str):
42+
try:
43+
current_config = json.loads(current_config)
44+
except json.JSONDecodeError:
45+
current_config = {}
4146
current_config[key] = value
4247
self._update_json_data_field("config", current_config)
4348

4449
@classmethod
4550
def create_with_config(cls, name: str, config: dict) -> 'Experiment':
4651
"""Create an experiment with config."""
47-
if not isinstance(config, dict):
48-
raise TypeError("Config must be a dictionary")
52+
if isinstance(config, str):
53+
try:
54+
config = json.loads(config)
55+
except json.JSONDecodeError:
56+
raise TypeError("config must be a dict or valid JSON string")
57+
elif not isinstance(config, dict):
58+
raise TypeError("config must be a dict")
4959
exp = cls.create(name)
5060
exp._update_json_data_field("config", config)
5161
return exp
5262

5363
def update_config(self, config: dict):
5464
"""Update entire config."""
55-
if not isinstance(config, dict):
56-
raise TypeError("Config must be a dictionary")
5765
current_config = self._get_json_data_field("config", {})
66+
if isinstance(current_config, str):
67+
try:
68+
current_config = json.loads(current_config)
69+
except json.JSONDecodeError:
70+
current_config = {}
5871
current_config.update(config)
5972
self._update_json_data_field("config", current_config)
6073

0 commit comments

Comments
 (0)