Skip to content

Commit ea0260e

Browse files
authored
Merge pull request #56 from transformerlab/add/save-dataset
Add lab.save_dataset() functionality
2 parents 24992e4 + 5195c2e commit ea0260e

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "transformerlab"
7-
version = "0.0.39"
7+
version = "0.0.40"
88
description = "Python SDK for Transformer Lab"
99
readme = "README.md"
1010
requires-python = ">=3.10"

src/lab/lab_facade.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .job import Job
1010
from . import dirs
1111
from .model import Model as ModelService
12+
from .dataset import Dataset
1213

1314
class Lab:
1415
"""
@@ -151,6 +152,98 @@ def save_artifact(self, source_path: str, name: Optional[str] = None) -> str:
151152

152153
return dest
153154

155+
def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, is_image: bool = False) -> str:
    """
    Save a dataset under the workspace datasets directory and mark it as generated.

    Args:
        df: A pandas DataFrame or a Hugging Face datasets.Dataset to serialize to disk.
        dataset_id: Identifier for the dataset directory under `datasets/`.
        additional_metadata: Optional dict to merge into dataset json_data.
        suffix: Optional suffix to append to the output filename stem (ignored when is_image is True).
        is_image: If True, save JSON Lines as ``metadata.jsonl`` (for image metadata-style rows).

    Returns:
        The path to the saved dataset file on disk.

    Raises:
        ValueError: If dataset_id is not a non-empty string.
        FileExistsError: If a dataset directory with this id already exists.
        RuntimeError: If serializing the dataframe to disk fails (the underlying
            exception is attached as the cause).
    """
    self._ensure_initialized()
    if not isinstance(dataset_id, str) or dataset_id.strip() == "":
        raise ValueError("dataset_id must be a non-empty string")

    # Normalize input: convert Hugging Face datasets.Dataset to pandas DataFrame.
    # Best-effort: if conversion fails, fall through and let the to_json check
    # below reject unusable inputs.
    try:
        if hasattr(df, "to_pandas") and callable(getattr(df, "to_pandas")):
            df = df.to_pandas()
    except Exception:
        pass

    # Prepare dataset directory; refuse to clobber an existing dataset.
    dataset_id_safe = dataset_id.strip()
    dataset_dir = dirs.dataset_dir_by_id(dataset_id_safe)
    if os.path.exists(dataset_dir):
        raise FileExistsError(f"Dataset with ID {dataset_id_safe} already exists")
    os.makedirs(dataset_dir, exist_ok=True)

    # Determine output filename: image datasets use the conventional
    # metadata.jsonl (JSON Lines); others use <id>[_<suffix>].json.
    if is_image:
        lines = True
        output_filename = "metadata.jsonl"
    else:
        lines = False
        stem = dataset_id_safe
        if isinstance(suffix, str) and suffix.strip() != "":
            stem = f"{stem}_{suffix.strip()}"
        output_filename = f"{stem}.json"

    output_path = os.path.join(dataset_dir, output_filename)

    # Persist dataframe. Chain the original exception (`from e`) so callers
    # can inspect the underlying cause — the previous code dropped it.
    try:
        if not hasattr(df, "to_json"):
            raise TypeError("df must be a pandas DataFrame or a Hugging Face datasets.Dataset")
        df.to_json(output_path, orient="records", lines=lines)
    except Exception as e:
        raise RuntimeError(f"Failed to save dataset to {output_path}: {str(e)}") from e

    # Create or update filesystem metadata so it appears under generated
    # datasets. A metadata failure must not fail the save; record the error
    # on the job instead (best-effort).
    try:
        try:
            ds = Dataset.get(dataset_id_safe)
        except FileNotFoundError:
            ds = Dataset.create(dataset_id_safe)

        # Base json_data with generated flag for UI filtering.
        json_data: Dict[str, Any] = {
            "generated": True,
            "sample_count": len(df) if hasattr(df, "__len__") else -1,
            "files": [output_filename],
        }
        if additional_metadata and isinstance(additional_metadata, dict):
            json_data.update(additional_metadata)

        ds.set_metadata(
            location="local",
            description=json_data.get("description", ""),
            size=-1,
            json_data=json_data,
        )
    except Exception as e:
        # Do not fail the save if metadata write fails; log to job data.
        try:
            self._job.update_job_data_field("dataset_metadata_error", str(e))  # type: ignore[union-attr]
        except Exception:
            pass

    # Track dataset on the job for provenance (best-effort).
    try:
        self._job.update_job_data_field("dataset_id", dataset_id_safe)  # type: ignore[union-attr]
    except Exception:
        pass

    self.log(f"Dataset saved to '{output_path}' and registered as generated dataset '{dataset_id_safe}'")
    return output_path
246+
154247
def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str:
155248
"""
156249
Save a checkpoint file or directory into this job's checkpoints folder.

0 commit comments

Comments
 (0)