33from grpclib .client import Channel
44
55from viam import logging
6- from viam .proto .app .data import Filter
76from viam .proto .app .mltraining import (
87 CancelTrainingJobRequest ,
8+ DeleteCompletedTrainingJobRequest ,
99 GetTrainingJobRequest ,
1010 GetTrainingJobResponse ,
1111 ListTrainingJobsRequest ,
1212 ListTrainingJobsResponse ,
1313 MLTrainingServiceStub ,
1414 ModelType ,
15+ SubmitCustomTrainingJobRequest ,
16+ SubmitCustomTrainingJobResponse ,
17+ SubmitTrainingJobRequest ,
18+ SubmitTrainingJobResponse ,
1519 TrainingJobMetadata ,
1620 TrainingStatus ,
1721)
@@ -66,13 +70,62 @@ def __init__(self, channel: Channel, metadata: Mapping[str, str]):
6670 async def submit_training_job (
6771 self ,
6872 org_id : str ,
73+ dataset_id : str ,
6974 model_name : str ,
7075 model_version : str ,
71- model_type : ModelType ,
76+ model_type : ModelType . ValueType ,
7277 tags : List [str ],
73- filter : Optional [Filter ] = None ,
7478 ) -> str :
75- raise NotImplementedError ()
79+ """Submit a training job.
80+
81+ Args:
82+ org_id (str): the id of the org to submit the training job to
83+ dataset_id (str): the id of the dataset
84+ model_name (str): the model name
85+ model_version (str): the model version
86+ model_type (ModelType.ValueType): the model type
87+ tags (List[str]): the tags
88+
89+ Returns:
90+ str: the id of the training job
91+ """
92+
93+ request = SubmitTrainingJobRequest (
94+ dataset_id = dataset_id ,
95+ organization_id = org_id ,
96+ model_name = model_name ,
97+ model_version = model_version ,
98+ model_type = model_type ,
99+ tags = tags ,
100+ )
101+ response : SubmitTrainingJobResponse = await self ._ml_training_client .SubmitTrainingJob (request , metadata = self ._metadata )
102+ return response .id
103+
104+ async def submit_custom_training_job (
105+ self , org_id : str , dataset_id : str , registry_item_id : str , model_name : str , model_version : str
106+ ) -> str :
107+ """Submit a custom training job.
108+
109+ Args:
110+ org_id (str): the id of the org to submit the training job to
111+ dataset_id (str): the id of the dataset
112+ registry_item_id (List[str]): the id of the registry item
113+ model_name (str): the model name
114+ model_version (str): the model version
115+
116+ Returns:
117+ str: the id of the training job
118+ """
119+
120+ request = SubmitCustomTrainingJobRequest (
121+ dataset_id = dataset_id ,
122+ registry_item_id = registry_item_id ,
123+ organization_id = org_id ,
124+ model_name = model_name ,
125+ model_version = model_version ,
126+ )
127+ response : SubmitCustomTrainingJobResponse = await self ._ml_training_client .SubmitCustomTrainingJob (request , metadata = self ._metadata )
128+ return response .id
76129
77130 async def get_training_job (self , id : str ) -> TrainingJobMetadata :
78131 """Gets training job data.
@@ -83,7 +136,7 @@ async def get_training_job(self, id: str) -> TrainingJobMetadata:
83136 id="INSERT YOUR JOB ID")
84137
85138 Args:
86- id (str): id of the requested training job.
139+ id (str): the id of the requested training job.
87140
88141 Returns:
89142 viam.proto.app.mltraining.TrainingJobMetadata: training job data.
@@ -140,3 +193,12 @@ async def cancel_training_job(self, id: str) -> None:
140193
141194 request = CancelTrainingJobRequest (id = id )
142195 await self ._ml_training_client .CancelTrainingJob (request , metadata = self ._metadata )
196+
197+ async def delete_completed_training_job (self , id : str ) -> None :
198+ """Delete a completed training job from the database, whether the job succeeded or failed
199+ Args:
200+ id (str): the id of the training job
201+ """
202+
203+ request = DeleteCompletedTrainingJobRequest (id = id )
204+ await self ._ml_training_client .DeleteCompletedTrainingJob (request , metadata = self ._metadata )
0 commit comments