Skip to content

Commit 7acf717

Browse files
RSDK-7191: add mltraining wrappers (#602)
1 parent 97ddc1d commit 7acf717

File tree

4 files changed

+113
-8
lines changed

4 files changed

+113
-8
lines changed

src/viam/app/ml_training_client.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from grpclib.client import Channel
44

55
from viam import logging
6-
from viam.proto.app.data import Filter
76
from viam.proto.app.mltraining import (
87
CancelTrainingJobRequest,
8+
DeleteCompletedTrainingJobRequest,
99
GetTrainingJobRequest,
1010
GetTrainingJobResponse,
1111
ListTrainingJobsRequest,
1212
ListTrainingJobsResponse,
1313
MLTrainingServiceStub,
1414
ModelType,
15+
SubmitCustomTrainingJobRequest,
16+
SubmitCustomTrainingJobResponse,
17+
SubmitTrainingJobRequest,
18+
SubmitTrainingJobResponse,
1519
TrainingJobMetadata,
1620
TrainingStatus,
1721
)
@@ -66,13 +70,62 @@ def __init__(self, channel: Channel, metadata: Mapping[str, str]):
6670
async def submit_training_job(
6771
self,
6872
org_id: str,
73+
dataset_id: str,
6974
model_name: str,
7075
model_version: str,
71-
model_type: ModelType,
76+
model_type: ModelType.ValueType,
7277
tags: List[str],
73-
filter: Optional[Filter] = None,
7478
) -> str:
75-
raise NotImplementedError()
79+
"""Submit a training job.
80+
81+
Args:
82+
org_id (str): the id of the org to submit the training job to
83+
dataset_id (str): the id of the dataset
84+
model_name (str): the model name
85+
model_version (str): the model version
86+
model_type (ModelType.ValueType): the model type
87+
tags (List[str]): the tags
88+
89+
Returns:
90+
str: the id of the training job
91+
"""
92+
93+
request = SubmitTrainingJobRequest(
94+
dataset_id=dataset_id,
95+
organization_id=org_id,
96+
model_name=model_name,
97+
model_version=model_version,
98+
model_type=model_type,
99+
tags=tags,
100+
)
101+
response: SubmitTrainingJobResponse = await self._ml_training_client.SubmitTrainingJob(request, metadata=self._metadata)
102+
return response.id
103+
104+
async def submit_custom_training_job(
105+
self, org_id: str, dataset_id: str, registry_item_id: str, model_name: str, model_version: str
106+
) -> str:
107+
"""Submit a custom training job.
108+
109+
Args:
110+
org_id (str): the id of the org to submit the training job to
111+
dataset_id (str): the id of the dataset
112+
registry_item_id (List[str]): the id of the registry item
113+
model_name (str): the model name
114+
model_version (str): the model version
115+
116+
Returns:
117+
str: the id of the training job
118+
"""
119+
120+
request = SubmitCustomTrainingJobRequest(
121+
dataset_id=dataset_id,
122+
registry_item_id=registry_item_id,
123+
organization_id=org_id,
124+
model_name=model_name,
125+
model_version=model_version,
126+
)
127+
response: SubmitCustomTrainingJobResponse = await self._ml_training_client.SubmitCustomTrainingJob(request, metadata=self._metadata)
128+
return response.id
76129

77130
async def get_training_job(self, id: str) -> TrainingJobMetadata:
78131
"""Gets training job data.
@@ -83,7 +136,7 @@ async def get_training_job(self, id: str) -> TrainingJobMetadata:
83136
id="INSERT YOUR JOB ID")
84137
85138
Args:
86-
id (str): id of the requested training job.
139+
id (str): the id of the requested training job.
87140
88141
Returns:
89142
viam.proto.app.mltraining.TrainingJobMetadata: training job data.
@@ -140,3 +193,12 @@ async def cancel_training_job(self, id: str) -> None:
140193

141194
request = CancelTrainingJobRequest(id=id)
142195
await self._ml_training_client.CancelTrainingJob(request, metadata=self._metadata)
196+
197+
async def delete_completed_training_job(self, id: str) -> None:
198+
"""Delete a completed training job from the database, whether the job succeeded or failed
199+
Args:
200+
id (str): the id of the training job
201+
"""
202+
203+
request = DeleteCompletedTrainingJobRequest(id=id)
204+
await self._ml_training_client.DeleteCompletedTrainingJob(request, metadata=self._metadata)

tests/mocks/services.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,7 @@ def __init__(self, job_id: str, training_metadata: TrainingJobMetadata):
10231023
async def SubmitTrainingJob(self, stream: Stream[SubmitTrainingJobRequest, SubmitTrainingJobResponse]) -> None:
10241024
request = await stream.recv_message()
10251025
assert request is not None
1026+
self.dataset_id = request.dataset_id
10261027
self.org_id = request.organization_id
10271028
self.model_name = request.model_name
10281029
self.model_version = request.model_version
@@ -1031,7 +1032,14 @@ async def SubmitTrainingJob(self, stream: Stream[SubmitTrainingJobRequest, Submi
10311032
await stream.send_message(SubmitTrainingJobResponse(id=self.job_id))
10321033

10331034
async def SubmitCustomTrainingJob(self, stream: Stream[SubmitCustomTrainingJobRequest, SubmitCustomTrainingJobResponse]) -> None:
1034-
return await super().SubmitCustomTrainingJob(stream)
1035+
request = await stream.recv_message()
1036+
assert request is not None
1037+
self.dataset_id = request.dataset_id
1038+
self.registry_item_id = request.registry_item_id
1039+
self.org_id = request.organization_id
1040+
self.model_name = request.model_name
1041+
self.model_version = request.model_version
1042+
await stream.send_message(SubmitCustomTrainingJobResponse(id=self.job_id))
10351043

10361044
async def GetTrainingJob(self, stream: Stream[GetTrainingJobRequest, GetTrainingJobResponse]) -> None:
10371045
request = await stream.recv_message()
@@ -1055,7 +1063,10 @@ async def CancelTrainingJob(self, stream: Stream[CancelTrainingJobRequest, Cance
10551063
async def DeleteCompletedTrainingJob(
10561064
self, stream: Stream[DeleteCompletedTrainingJobRequest, DeleteCompletedTrainingJobResponse]
10571065
) -> None:
1058-
raise NotImplementedError()
1066+
request = await stream.recv_message()
1067+
assert request is not None
1068+
self.delete_id = request.id
1069+
await stream.send_message(DeleteCompletedTrainingJobResponse())
10591070

10601071

10611072
class MockBilling(BillingServiceBase):

tests/test_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,11 @@ async def test_list_datasets_by_organization_id(self, service: MockDataset):
5757
datasets = await client.list_datasets_by_organization_id(ORG_ID)
5858
assert service.org_id == ORG_ID
5959
assert datasets == DATASETS
60+
61+
@pytest.mark.asyncio
62+
async def test_rename_dataset(self, service: MockDataset):
63+
async with ChannelFor([service]) as channel:
64+
client = DataClient(channel, DATA_SERVICE_METADATA)
65+
await client.rename_dataset(ID, NAME)
66+
assert service.id == ID
67+
assert service.name == NAME

tests/test_ml_training_client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
ID = "id"
1414
TRAINING_JOB_ID = "training-job-id"
1515
CANCEL_ID = "cancel-id"
16+
DELETE_ID = "delete-id"
1617
JOB_ID = "job-id"
1718
ORG_ID = "org-id"
19+
DATASET_ID = "dataset-id"
20+
REGISTRY_ITEM_ID = "registry-item-id"
1821
MODEL_ID = "model-id"
1922
MODEL_NAME = "model-name"
2023
MODEL_VERSION = "model-version"
@@ -66,7 +69,21 @@ async def test_cancel_training_job(self, service: MockMLTraining):
6669

6770
@pytest.mark.asyncio
6871
async def test_submit_training_job(self, service: MockMLTraining):
69-
assert True
72+
async with ChannelFor([service]) as channel:
73+
client = MLTrainingClient(channel, ML_TRAINING_SERVICE_METADATA)
74+
id = await client.submit_training_job(
75+
org_id=ORG_ID, dataset_id=DATASET_ID, model_name=MODEL_NAME, model_version=MODEL_VERSION, model_type=MODEL_TYPE, tags=TAGS
76+
)
77+
assert id == JOB_ID
78+
79+
@pytest.mark.asyncio
80+
async def test_custom_submit_training_job(self, service: MockMLTraining):
81+
async with ChannelFor([service]) as channel:
82+
client = MLTrainingClient(channel, ML_TRAINING_SERVICE_METADATA)
83+
id = await client.submit_custom_training_job(
84+
org_id=ORG_ID, dataset_id=DATASET_ID, registry_item_id=REGISTRY_ITEM_ID, model_name=MODEL_NAME, model_version=MODEL_VERSION
85+
)
86+
assert id == JOB_ID
7087

7188
@pytest.mark.asyncio
7289
async def test_get_training_job(self, service: MockMLTraining):
@@ -85,3 +102,10 @@ async def test_list_training_jobs(self, service: MockMLTraining):
85102
assert training_jobs[0] == TRAINING_METADATA
86103
assert service.training_status == TRAINING_STATUS
87104
assert service.org_id == ORG_ID
105+
106+
@pytest.mark.asyncio
107+
async def test_delete_completed_training_job(self, service: MockMLTraining):
108+
async with ChannelFor([service]) as channel:
109+
client = MLTrainingClient(channel, ML_TRAINING_SERVICE_METADATA)
110+
await client.delete_completed_training_job(DELETE_ID)
111+
assert service.delete_id == DELETE_ID

0 commit comments

Comments
 (0)