
Commit ba139b2

hotzenklotz, coderabbitai[bot], and MichaelBuessemeyer authored
Enable custom AI Instance Segmentation jobs (#8849)
This PR does three things:

1. Adds support for training and running AI instance segmentation models. This is already supported by the worker but was not available from the UI yet. This workflow is distinct from the regular neuron model training and is designed for nuclei and other instance segmentation tasks. There is also a lot of naming confusion between our pre-trained models, the custom models that one can train, generic instance segmentation, EM_nuclei models (essentially the same), etc. Something for a follow-up PR.
2. Adds AI job settings to better support VX instance segmentations (see issue #8278):
   - Instance Model Training: `max_distance_nm`
   - Instance Model Inference: `seed_generator_distance_threshold`
3. Splits the monolithic `starts_jobs_modal.tsx` into separate files for components, tabs, forms, hooks, etc.
   - Most of the code is simply moved into sub-components
   - Added `React.useCallback` to some callbacks and click handlers

### URL of deployed dev instance (used for testing):

- https://___.webknossos.xyz

### Steps to test:

- Enable worker

Training a new instance model:

1. Create an annotation with a bounding box and open the "AI" modal
2. Switch to the "Train a Model" tab to start a new model training. Switch "Model category" to "EM Instance Segmentation for Nuclei, ..."
3. Enter a distance threshold, or use the default value
4. Start the training

Using an instance model:

1. Create an annotation and open the "AI" modal
2. Switch to the "Run a Model" tab
3. Switch from pre-trained to custom models with the toggle switch in the upper right
4. Select the previously trained instance model from the dropdown
5. Confirm that the "seed generator distance" option is shown
6. Start the inference

### Issues:

- fixes #8278

------

(Please delete unneeded items, merge only when none are left open)

- [x] Added changelog entry (create a `$PR_NUMBER.md` file in `unreleased_changes` or use `./tools/create-changelog-entry.py`)
- [ ] Added migration guide entry if applicable (edit the same file as for the changelog)
- [ ] Updated [documentation](../blob/master/docs) if applicable
- [ ] Adapted [wk-libs python client](https://github.com/scalableminds/webknossos-libs/tree/master/webknossos/webknossos/client) if relevant API parts change
- [ ] Removed dev-only changes like prints and application.conf edits
- [x] Considered [common edge cases](../blob/master/.github/common_edge_cases.md)
- [ ] Needs datastore update after deployment

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: MichaelBuessemeyer <[email protected]>
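
As a quick reference (not part of this commit), here is a minimal sketch of how the new training setting could be passed through the `runInstanceModelTraining` helper added in `frontend/javascripts/admin/api/jobs.ts` further down. The import path, ids, and values are placeholders, and the training-annotation spec is abridged:

```typescript
// Hypothetical call site; not part of this PR.
import { APIAiModelCategory, runInstanceModelTraining } from "admin/api/jobs"; // import path assumed

export async function startExampleInstanceTraining(annotationId: string) {
  return runInstanceModelTraining({
    trainingAnnotations: [
      {
        annotationId, // id of an annotation with a training bounding box
        mag: [1, 1, 1],
        // The remaining AiModelTrainingAnnotationSpecification fields are
        // truncated in the diff below and therefore omitted here.
      } as any, // cast only because the spec is abridged in this sketch
    ],
    name: "example-nuclei-model", // placeholder model name
    aiModelCategory: APIAiModelCategory.EM_NUCLEI,
    maxDistanceNm: 1000, // the new instance-training setting; value is a placeholder
    comment: "trained via the new instance workflow",
  });
}
```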
1 parent cdbd039 commit ba139b2


45 files changed: +2810 −2343 lines

app/controllers/AiModelController.scala

Lines changed: 113 additions & 19 deletions
```diff
@@ -30,14 +30,25 @@ object TrainingAnnotationSpecification {
   implicit val jsonFormat: OFormat[TrainingAnnotationSpecification] = Json.format[TrainingAnnotationSpecification]
 }

-case class RunTrainingParameters(trainingAnnotations: List[TrainingAnnotationSpecification],
-                                 name: String,
-                                 comment: Option[String],
-                                 aiModelCategory: Option[AiModelCategory],
-                                 workflowYaml: Option[String])
+case class RunNeuronModelTrainingParameters(trainingAnnotations: List[TrainingAnnotationSpecification],
+                                            name: String,
+                                            aiModelCategory: Option[AiModelCategory],
+                                            comment: Option[String],
+                                            workflowYaml: Option[String])

-object RunTrainingParameters {
-  implicit val jsonFormat: OFormat[RunTrainingParameters] = Json.format[RunTrainingParameters]
+object RunNeuronModelTrainingParameters {
+  implicit val jsonFormat: OFormat[RunNeuronModelTrainingParameters] = Json.format[RunNeuronModelTrainingParameters]
+}
+
+case class RunInstanceModelTrainingParameters(trainingAnnotations: List[TrainingAnnotationSpecification],
+                                              name: String,
+                                              aiModelCategory: Option[AiModelCategory],
+                                              maxDistanceNm: Option[Double],
+                                              comment: Option[String],
+                                              workflowYaml: Option[String])
+
+object RunInstanceModelTrainingParameters {
+  implicit val jsonFormat: OFormat[RunInstanceModelTrainingParameters] = Json.format[RunInstanceModelTrainingParameters]
 }

 case class RunInferenceParameters(annotationId: Option[ObjectId],
@@ -48,7 +59,8 @@ case class RunInferenceParameters(annotationId: Option[ObjectId],
                                   boundingBox: String,
                                   newDatasetName: String,
                                   maskAnnotationLayerName: Option[String],
-                                  workflowYaml: Option[String])
+                                  workflowYaml: Option[String],
+                                  seedGeneratorDistanceThreshold: Option[Double])

 object RunInferenceParameters {
   implicit val jsonFormat: OFormat[RunInferenceParameters] = Json.format[RunInferenceParameters]
@@ -127,21 +139,18 @@ class AiModelController @Inject()(
     }
   }

-  def runNeuronTraining: Action[RunTrainingParameters] = sil.SecuredAction.async(validateJson[RunTrainingParameters]) {
-    implicit request =>
+  def runNeuronTraining: Action[RunNeuronModelTrainingParameters] =
+    sil.SecuredAction.async(validateJson[RunNeuronModelTrainingParameters]) { implicit request =>
       for {
         _ <- userService.assertIsSuperUser(request.identity)
         trainingAnnotations = request.body.trainingAnnotations
-        _ <- Fox
-          .fromBool(trainingAnnotations.nonEmpty || request.body.workflowYaml.isDefined) ?~> "aiModel.training.zeroAnnotations"
+        _ <- Fox.fromBool(trainingAnnotations.nonEmpty || request.body.workflowYaml.isDefined) ?~> "aiModel.training.zeroAnnotations"
         firstAnnotationId <- trainingAnnotations.headOption.map(_.annotationId).toFox
         annotation <- annotationDAO.findOne(firstAnnotationId)
         dataset <- datasetDAO.findOne(annotation._dataset)
-        _ <- Fox
-          .fromBool(request.identity._organization == dataset._organization) ?~> "job.trainModel.notAllowed.organization" ~> FORBIDDEN
+        _ <- Fox.fromBool(request.identity._organization == dataset._organization) ?~> "job.trainModel.notAllowed.organization" ~> FORBIDDEN
         dataStore <- dataStoreDAO.findOneByName(dataset._dataStore) ?~> "dataStore.notFound"
-        _ <- Fox
-          .serialCombined(request.body.trainingAnnotations.map(_.annotationId))(annotationDAO.findOne) ?~> "annotation.notFound"
+        _ <- Fox.serialCombined(request.body.trainingAnnotations.map(_.annotationId))(annotationDAO.findOne) ?~> "annotation.notFound"
         modelId = ObjectId.generate
         organization <- organizationDAO.findOne(request.identity._organization)
         jobCommand = JobCommand.train_neuron_model
@@ -154,8 +163,7 @@ class AiModelController @Inject()(
         existingAiModelsCount <- aiModelDAO.countByNameAndOrganization(request.body.name,
                                                                        request.identity._organization)
         _ <- Fox.fromBool(existingAiModelsCount == 0) ?~> "aiModel.nameInUse"
-        newTrainingJob <- jobService
-          .submitJob(jobCommand, commandArgs, request.identity, dataStore.name) ?~> "job.couldNotRunTrainModel"
+        newTrainingJob <- jobService.submitJob(jobCommand, commandArgs, request.identity, dataStore.name) ?~> "job.couldNotRunTrainModel"
         newAiModel = AiModel(
           _id = modelId,
           _organization = request.identity._organization,
@@ -171,7 +179,93 @@
         _ <- aiModelDAO.insertOne(newAiModel)
         newAiModelJs <- aiModelService.publicWrites(newAiModel, request.identity)
       } yield Ok(newAiModelJs)
-  }
+    }
+
+  def runInstanceTraining: Action[RunInstanceModelTrainingParameters] =
+    sil.SecuredAction.async(validateJson[RunInstanceModelTrainingParameters]) { implicit request =>
+      for {
+        _ <- userService.assertIsSuperUser(request.identity)
+        trainingAnnotations = request.body.trainingAnnotations
+        _ <- Fox.fromBool(trainingAnnotations.nonEmpty || request.body.workflowYaml.isDefined) ?~> "aiModel.training.zeroAnnotations"
+        firstAnnotationId <- trainingAnnotations.headOption.map(_.annotationId).toFox
+        annotation <- annotationDAO.findOne(firstAnnotationId)
+        dataset <- datasetDAO.findOne(annotation._dataset)
+        _ <- Fox.fromBool(request.identity._organization == dataset._organization) ?~> "job.trainModel.notAllowed.organization" ~> FORBIDDEN
+        dataStore <- dataStoreDAO.findOneByName(dataset._dataStore) ?~> "dataStore.notFound"
+        _ <- Fox.serialCombined(request.body.trainingAnnotations.map(_.annotationId))(annotationDAO.findOne) ?~> "annotation.notFound"
+        modelId = ObjectId.generate
+        organization <- organizationDAO.findOne(request.identity._organization)
+        jobCommand = JobCommand.train_instance_model
+        commandArgs = Json.obj(
+          "training_annotations" -> Json.toJson(trainingAnnotations),
+          "organization_id" -> organization._id,
+          "model_id" -> modelId,
+          "custom_workflow_provided_by_user" -> request.body.workflowYaml,
+          "max_distance_nm" -> request.body.maxDistanceNm
+        )
+        existingAiModelsCount <- aiModelDAO.countByNameAndOrganization(request.body.name,
+                                                                       request.identity._organization)
+        _ <- Fox.fromBool(existingAiModelsCount == 0) ?~> "aiModel.nameInUse"
+        newTrainingJob <- jobService.submitJob(jobCommand, commandArgs, request.identity, dataStore.name) ?~> "job.couldNotRunTrainModel"
+        newAiModel = AiModel(
+          _id = modelId,
+          _organization = request.identity._organization,
+          _sharedOrganizations = List(),
+          _dataStore = dataStore.name,
+          _user = request.identity._id,
+          _trainingJob = Some(newTrainingJob._id),
+          _trainingAnnotations = trainingAnnotations.map(_.annotationId),
+          name = request.body.name,
+          comment = request.body.comment,
+          category = request.body.aiModelCategory
+        )
+        _ <- aiModelDAO.insertOne(newAiModel)
+        newAiModelJs <- aiModelService.publicWrites(newAiModel, request.identity)
+      } yield Ok(newAiModelJs)
+    }
+
+  def runCustomInstanceModelInference: Action[RunInferenceParameters] =
+    sil.SecuredAction.async(validateJson[RunInferenceParameters]) { implicit request =>
+      for {
+        _ <- userService.assertIsSuperUser(request.identity)
+        organization <- organizationDAO.findOne(request.body.organizationId)(GlobalAccessContext) ?~> Messages(
+          "organization.notFound",
+          request.body.organizationId)
+        _ <- Fox.fromBool(request.identity._organization == organization._id) ?~> "job.runInference.notAllowed.organization" ~> FORBIDDEN
+        dataset <- datasetDAO.findOneByDirectoryNameAndOrganization(request.body.datasetDirectoryName, organization._id)
+        dataStore <- dataStoreDAO.findOneByName(dataset._dataStore) ?~> "dataStore.notFound"
+        _ <- aiModelDAO.findOne(request.body.aiModelId) ?~> "aiModel.notFound"
+        _ <- datasetService.assertValidDatasetName(request.body.newDatasetName)
+        jobCommand = JobCommand.infer_instances
+        boundingBox <- BoundingBox.fromLiteral(request.body.boundingBox).toFox
+        commandArgs = Json.obj(
+          "dataset_id" -> dataset._id,
+          "organization_id" -> organization._id,
+          "dataset_name" -> dataset.name,
+          "layer_name" -> request.body.colorLayerName,
+          "bbox" -> boundingBox.toLiteral,
+          "model_id" -> request.body.aiModelId,
+          "dataset_directory_name" -> request.body.datasetDirectoryName,
+          "new_dataset_name" -> request.body.newDatasetName,
+          "custom_workflow_provided_by_user" -> request.body.workflowYaml,
+          "seed_generator_distance_threshold" -> request.body.seedGeneratorDistanceThreshold
+        )
+        newInferenceJob <- jobService.submitJob(jobCommand, commandArgs, request.identity, dataStore.name) ?~> "job.couldNotRunInferWithModel"
+        newAiInference = AiInference(
+          _id = ObjectId.generate,
+          _organization = request.identity._organization,
+          _aiModel = request.body.aiModelId,
+          _newDataset = None,
+          _annotation = request.body.annotationId,
+          boundingBox = boundingBox,
+          _inferenceJob = newInferenceJob._id,
+          newSegmentationLayerName = "segmentation",
+          maskAnnotationLayerName = request.body.maskAnnotationLayerName
+        )
+        _ <- aiInferenceDAO.insertOne(newAiInference)
+        newAiModelJs <- aiInferenceService.publicWrites(newAiInference, request.identity)
+      } yield Ok(newAiModelJs)
+    }

   def runCustomNeuronInference: Action[RunInferenceParameters] =
     sil.SecuredAction.async(validateJson[RunInferenceParameters]) { implicit request =>
```

app/models/aimodels/AiModel.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -71,7 +71,8 @@ class AiModelService @Inject()(dataStoreDAO: DataStoreDAO,
       "comment" -> aiModel.comment,
       "trainingJob" -> trainingJobJsOpt,
       "created" -> aiModel.created,
-      "sharedOrganizationIds" -> sharedOrganizationIds
+      "sharedOrganizationIds" -> sharedOrganizationIds,
+      "category" -> aiModel.category
     )
   }
```

app/models/job/JobCommand.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -12,8 +12,8 @@ object JobCommand extends ExtendedEnumeration {
   */

  val compute_mesh_file, compute_segment_index_file, convert_to_wkw, export_tiff, find_largest_segment_id,
-  globalize_floodfills, infer_nuclei, infer_neurons, materialize_volume_annotation, render_animation,
-  infer_mitochondria, align_sections, train_model, infer_with_model, train_neuron_model = Value
+  globalize_floodfills, infer_nuclei, infer_neurons, infer_instances, materialize_volume_annotation, render_animation,
+  infer_mitochondria, align_sections, train_model, infer_with_model, train_neuron_model, train_instance_model = Value

  val highPriorityJobs: Set[Value] = Set(convert_to_wkw, export_tiff)
  val lowPriorityJobs: Set[Value] = values.diff(highPriorityJobs)
```

conf/webknossos.latest.routes

Lines changed: 4 additions & 2 deletions
```diff
@@ -285,8 +285,10 @@ POST /jobs/:id/attachDatasetToInference
 GET   /jobs/:id/export                                      controllers.JobController.redirectToExport(id: ObjectId)

 # AI Models
-POST  /aiModels/runNeuronTraining                           controllers.AiModelController.runNeuronTraining
-POST  /aiModels/inferences/runCustomNeuronInference         controllers.AiModelController.runCustomNeuronInference
+POST  /aiModels/runNeuronModelTraining                      controllers.AiModelController.runNeuronTraining
+POST  /aiModels/runInstanceModelTraining                    controllers.AiModelController.runInstanceTraining
+POST  /aiModels/inferences/runCustomNeuronModelInference    controllers.AiModelController.runCustomNeuronInference
+POST  /aiModels/inferences/runCustomInstanceModelInference  controllers.AiModelController.runCustomInstanceModelInference
 GET   /aiModels/inferences/:id                              controllers.AiModelController.readAiInferenceInfo(id: ObjectId)
 GET   /aiModels/inferences                                  controllers.AiModelController.listAiInferences
 GET   /aiModels                                             controllers.AiModelController.listAiModels
```

frontend/javascripts/admin/api/jobs.ts

Lines changed: 43 additions & 9 deletions
```diff
@@ -351,7 +351,11 @@ export function startAlignSectionsJob(
   });
 }

-type AiModelCategory = "em_neurons" | "em_nuclei";
+// This enum needs to be kept in sync with the backend/database
+export enum APIAiModelCategory {
+  EM_NEURONS = "em_neurons",
+  EM_NUCLEI = "em_nuclei",
+}

 type AiModelTrainingAnnotationSpecification = {
   annotationId: string;
@@ -360,22 +364,38 @@ type AiModelTrainingAnnotationSpecification = {
   mag: Vector3;
 };

-type RunTrainingParameters = {
-  trainingAnnotations: Array<AiModelTrainingAnnotationSpecification>;
+type RunNeuronModelTrainingParameters = {
+  trainingAnnotations: AiModelTrainingAnnotationSpecification[];
+  name: string;
+  aiModelCategory: APIAiModelCategory.EM_NEURONS;
+  comment?: string;
+  workflowYaml?: string;
+};
+
+export function runNeuronTraining(params: RunNeuronModelTrainingParameters) {
+  return Request.sendJSONReceiveJSON("/api/aiModels/runNeuronModelTraining", {
+    method: "POST",
+    data: JSON.stringify(params),
+  });
+}
+
+type RunInstanceModelTrainingParameters = {
+  trainingAnnotations: AiModelTrainingAnnotationSpecification[];
   name: string;
+  aiModelCategory: APIAiModelCategory.EM_NUCLEI;
+  maxDistanceNm: number;
   comment?: string;
-  aiModelCategory?: AiModelCategory;
   workflowYaml?: string;
 };

-export function runNeuronTraining(params: RunTrainingParameters) {
-  return Request.sendJSONReceiveJSON("/api/aiModels/runNeuronTraining", {
+export function runInstanceModelTraining(params: RunInstanceModelTrainingParameters) {
+  return Request.sendJSONReceiveJSON("/api/aiModels/runInstanceModelTraining", {
     method: "POST",
     data: JSON.stringify(params),
   });
 }

-type RunInferenceParameters = {
+export type BaseModelInferenceParameters = {
   annotationId?: string;
   aiModelId: string;
   datasetDirectoryName: string;
@@ -386,9 +406,23 @@ type RunInferenceParameters = {
   workflowYaml?: string;
   // maskAnnotationLayerName?: string | null
 };
+type RunNeuronModelInferenceParameters = BaseModelInferenceParameters;
+
+type RunInstanceModelInferenceParameters = BaseModelInferenceParameters & {
+  seedGeneratorDistanceThreshold: number;
+};
+
+export function runNeuronModelInferenceWithAiModelJob(params: RunNeuronModelInferenceParameters) {
+  return Request.sendJSONReceiveJSON("/api/aiModels/inferences/runCustomNeuronModelInference", {
+    method: "POST",
+    data: JSON.stringify({ ...params, boundingBox: params.boundingBox.join(",") }),
+  });
+}

-export function runNeuronInferenceWithAiModelJob(params: RunInferenceParameters) {
-  return Request.sendJSONReceiveJSON("/api/aiModels/inferences/runCustomNeuronInference", {
+export function runInstanceModelInferenceWithAiModelJob(
+  params: RunInstanceModelInferenceParameters,
+) {
+  return Request.sendJSONReceiveJSON("/api/aiModels/inferences/runCustomInstanceModelInference", {
     method: "POST",
     data: JSON.stringify({ ...params, boundingBox: params.boundingBox.join(",") }),
   });
```
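
A similarly hedged sketch of starting an inference with a previously trained instance model via the helper defined above. The fields of `BaseModelInferenceParameters` are truncated in this diff, so the field set below is only assumed to mirror the backend's `RunInferenceParameters`; all concrete values are placeholders:

```typescript
// Hypothetical call site; not part of this PR.
import { runInstanceModelInferenceWithAiModelJob } from "admin/api/jobs"; // import path assumed

export async function startExampleInstanceInference(aiModelId: string) {
  return runInstanceModelInferenceWithAiModelJob({
    aiModelId, // id of a previously trained instance model
    organizationId: "sample_organization", // assumed field, mirroring RunInferenceParameters
    datasetDirectoryName: "sample_dataset",
    colorLayerName: "color", // assumed field, truncated in the diff above
    newDatasetName: "sample_dataset_nuclei",
    boundingBox: [0, 0, 0, 512, 512, 512], // joined to "x,y,z,w,h,d" by the helper
    seedGeneratorDistanceThreshold: 1000, // the new instance-inference setting; value is a placeholder
  } as any); // cast only because the parameter type is abridged in this diff
}
```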

frontend/javascripts/admin/job/job_list_view.tsx

Lines changed: 22 additions & 7 deletions
```diff
@@ -126,7 +126,7 @@ export function JobState({ job }: { job: APIJob }) {

   return (
     <Tooltip title={tooltip}>
-      <span className="icon-margin-right">{icon}</span>
+      <span>{icon}</span>
       {jobStateNormalized}
     </Tooltip>
   );
@@ -235,7 +235,7 @@ function JobListView() {
     ) {
       return (
         <span>
-          Neuron inferral for layer {job.layerName} of{" "}
+          AI Neuron inferral for layer <i>{job.layerName}</i> of{" "}
           <Link to={linkToDataset}>{job.datasetName}</Link>{" "}
         </span>
       );
@@ -256,14 +256,21 @@
     ) {
       return (
         <span>
-          Mitochondria inferral for layer {job.layerName} of{" "}
+          AI Mitochondria inferral for layer <i>{job.layerName}</i> of{" "}
+          <Link to={linkToDataset}>{job.datasetName}</Link>{" "}
+        </span>
+      );
+    } else if (job.type === APIJobType.INFER_INSTANCES && linkToDataset != null && job.layerName) {
+      return (
+        <span>
+          AI instance segmentation for layer <i>{job.layerName}</i> of{" "}
           <Link to={linkToDataset}>{job.datasetName}</Link>{" "}
         </span>
       );
     } else if (job.type === APIJobType.ALIGN_SECTIONS && linkToDataset != null && job.layerName) {
       return (
         <span>
-          Align sections for layer {job.layerName} of{" "}
+          Align sections for layer <i>{job.layerName}</i> of{" "}
           <Link to={linkToDataset}>{job.datasetName}</Link>{" "}
         </span>
       );
@@ -277,11 +284,19 @@
             : null}
         </span>
       );
-    } else if (job.type === APIJobType.TRAIN_NEURON_MODEL || APIJobType.DEPRECATED_TRAIN_MODEL) {
-      const numberOfTrainingAnnotations = job.trainingAnnotations.length;
+    } else if (
+      job.type === APIJobType.TRAIN_NEURON_MODEL ||
+      job.type === APIJobType.TRAIN_INSTANCE_MODEL ||
+      job.type === APIJobType.DEPRECATED_TRAIN_MODEL
+    ) {
+      const numberOfTrainingAnnotations = job.trainingAnnotations?.length || 0;
+      const modelName =
+        job.type === APIJobType.TRAIN_NEURON_MODEL || job.type === APIJobType.DEPRECATED_TRAIN_MODEL
+          ? "neuron model"
+          : "instance model";
       return (
         <span>
-          {`Train neuron model on ${numberOfTrainingAnnotations} ${Utils.pluralize("annotation", numberOfTrainingAnnotations)}. `}
+          {`Train ${modelName} on ${numberOfTrainingAnnotations} ${Utils.pluralize("annotation", numberOfTrainingAnnotations)}. `}
          {getShowTrainingDataLink(job.trainingAnnotations)}
        </span>
      );
```
