From e29b9b4b90ad98f79e0848036ea23a177a6ba4ec Mon Sep 17 00:00:00 2001 From: Tom Herold Date: Mon, 25 Aug 2025 16:35:42 +0200 Subject: [PATCH 01/72] add first draft for new AI job runner --- frontend/javascripts/types/api_types.ts | 4 +- .../viewer/view/action_bar_view.tsx | 5 +- .../viewer/view/ai_jobs/ai_jobs_drawer.tsx | 24 +++ .../alignment/ai_image_alignment_job.tsx | 5 + .../view/ai_jobs/bounding_box_selector.tsx | 38 ++++ .../view/ai_jobs/credit_information.tsx | 88 ++++++++++ .../run_ai_model/ai_analysis_parameters.tsx | 53 ++++++ .../ai_image_segmentation_job.tsx | 21 +++ .../ai_image_segmentation_job_context.tsx | 165 ++++++++++++++++++ .../run_ai_model/ai_model_selector.tsx | 153 ++++++++++++++++ .../train_ai_model/ai_model_training_job.tsx | 5 + 11 files changed, 557 insertions(+), 4 deletions(-) create mode 100644 frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_image_segmentation_job.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_image_segmentation_job_context.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_model_selector.tsx create mode 100644 frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_model_training_job.tsx diff --git a/frontend/javascripts/types/api_types.ts b/frontend/javascripts/types/api_types.ts index aa76b2a0d55..8ba17ade061 100644 --- a/frontend/javascripts/types/api_types.ts +++ b/frontend/javascripts/types/api_types.ts @@ -787,11 +787,11 @@ export enum APIJobType { FIND_LARGEST_SEGMENT_ID = "find_largest_segment_id", INFER_NUCLEI = "infer_nuclei", INFER_NEURONS = "infer_neurons", + INFER_MITOCHONDRIA = "infer_mitochondria", + INFER_INSTANCES = "infer_instances", MATERIALIZE_VOLUME_ANNOTATION = "materialize_volume_annotation", TRAIN_NEURON_MODEL = "train_neuron_model", TRAIN_INSTANCE_MODEL = "train_instance_model", - INFER_MITOCHONDRIA = "infer_mitochondria", - INFER_INSTANCES = "infer_instances", // Only used for backwards compatibility, e.g. to display results. DEPRECATED_INFER_WITH_MODEL = "infer_with_model", DEPRECATED_TRAIN_MODEL = "train_model", diff --git a/frontend/javascripts/viewer/view/action_bar_view.tsx b/frontend/javascripts/viewer/view/action_bar_view.tsx index 3e263c2ef4d..18a262a48a7 100644 --- a/frontend/javascripts/viewer/view/action_bar_view.tsx +++ b/frontend/javascripts/viewer/view/action_bar_view.tsx @@ -42,10 +42,11 @@ import { layoutEmitter, } from "viewer/view/layouting/layout_persistence"; import type { StartAIJobModalState } from "./action-bar/ai_job_modals/constants"; -import { StartAIJobModal } from "./action-bar/ai_job_modals/start_ai_job_modal"; +// import { StartAIJobModal } from "./action-bar/ai_job_modals/start_ai_job_modal"; import ToolkitView from "./action-bar/tools/toolkit_switcher_view"; import ButtonComponent from "./components/button_component"; import { NumberSliderSetting } from "./components/setting_input_views"; +import { AiJobsDrawer } from "./ai_jobs/ai_jobs_drawer"; const VersionRestoreWarning = ( { }) } /> - + ); } diff --git a/frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx b/frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx new file mode 100644 index 00000000000..fd04c30a5fd --- /dev/null +++ b/frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx @@ -0,0 +1,24 @@ +import { Drawer, Tabs } from "antd"; +import { AiImageSegmentation } from "./run_ai_model/ai_image_segmentation_job"; +import { AiModelTraining } from "./train_ai_model/ai_model_training_job"; +import { AiImageAlignment } from "./alignment/ai_image_alignment_job"; + +const { TabPane } = Tabs; + +export const AiJobsDrawer = () => { + return ( + + + + + + + + + + + + + + ); +}; diff --git a/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx b/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx new file mode 100644 index 00000000000..65e5be31776 --- /dev/null +++ b/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx @@ -0,0 +1,5 @@ +import { Empty } from "antd"; + +export const AiImageAlignment = () => { + return ; +}; diff --git a/frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx b/frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx new file mode 100644 index 00000000000..5d344ca15a8 --- /dev/null +++ b/frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx @@ -0,0 +1,38 @@ +import { useWkSelector } from "libs/react_hooks"; +import type React from "react"; +import { useMemo } from "react"; +import { getColorLayers } from "viewer/model/accessors/dataset_accessor"; +import { getUserBoundingBoxesFromState } from "viewer/model/accessors/tracing_accessor"; +import type { UserBoundingBox } from "viewer/store"; +import { BoundingBoxSelection } from "viewer/view/action-bar/ai_job_modals/components/bounding_box_selection"; +import { getBoundingBoxesForLayers } from "viewer/view/action-bar/ai_job_modals/utils"; + +interface BoundingBoxSelectorProps { + value: UserBoundingBox | null; + onChange: (value: UserBoundingBox | null) => void; +} + +export const BoundingBoxSelector: React.FC = ({ value, onChange }) => { + const userBoundingBoxes = useWkSelector(getUserBoundingBoxesFromState); + const dataset = useWkSelector((state) => state.dataset); + const colorLayers = getColorLayers(dataset); + + const allBoundingBoxes = useMemo(() => { + const defaultBBs = getBoundingBoxesForLayers(colorLayers); + return defaultBBs.concat(userBoundingBoxes); + }, [colorLayers, userBoundingBoxes]); + + const handleChange = (id: number | null) => { + const selected = allBoundingBoxes.find((bb) => bb.id === id) || null; + onChange(selected); + }; + + return ( + + ); +}; diff --git a/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx new file mode 100644 index 00000000000..d45d5c34ba6 --- /dev/null +++ b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx @@ -0,0 +1,88 @@ +import { type JobCreditCostInfo, getJobCreditCost } from "admin/rest_api"; +import { Button, Card, Col, Row, Spin, Typography } from "antd"; +import { formatCreditsString } from "libs/format_utils"; +import { useFetch } from "libs/react_helpers"; +import { useWkSelector } from "libs/react_hooks"; +import { computeArrayFromBoundingBox } from "libs/utils"; +import type React from "react"; +import { useRunAiModelJobContext } from "./run_ai_model/ai_image_segmentation_job_context"; + +const { Title, Text } = Typography; + +export const CreditInformation: React.FC = () => { + const { selectedModel, selectedJobType, selectedBoundingBox, handleStartAnalysis } = + useRunAiModelJobContext(); + const organizationCredits = useWkSelector( + (state) => state.activeOrganization?.creditBalance || "0", + ); + + const jobCreditCostInfo = useFetch( + async () => + selectedBoundingBox && selectedJobType + ? await getJobCreditCost( + selectedJobType, + computeArrayFromBoundingBox(selectedBoundingBox.boundingBox), + ) + : undefined, + undefined, + [selectedBoundingBox, selectedJobType], + ); + + const costInCredits = jobCreditCostInfo?.costInCredits; + + return ( + + + + Available Credits + + + + {formatCreditsString(organizationCredits)} + + + +
+ Cost Breakdown: + + + Selected Model: + + + {selectedModel ? selectedModel.name : "-"} + + + + + Dataset Size (Est.): + + + {selectedBoundingBox ? "Selected" : "-"} + + +
+ + + Total Cost: + + + {jobCreditCostInfo === undefined && selectedBoundingBox && selectedModel ? ( + + ) : ( + {costInCredits ? formatCreditsString(costInCredits) : "FREE"} + )} + + + +
+ ); +}; diff --git a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx new file mode 100644 index 00000000000..6ed58108390 --- /dev/null +++ b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx @@ -0,0 +1,53 @@ +import { Card, Col, Input, Row, Select, Typography } from "antd"; +import { useWkSelector } from "libs/react_hooks"; +import type React from "react"; +import { getColorLayers } from "viewer/model/accessors/dataset_accessor"; +import { BoundingBoxSelector } from "../bounding_box_selector"; +import { useRunAiModelJobContext } from "./ai_image_segmentation_job_context"; + +const { Text } = Typography; + +export const AiAnalysisParameters: React.FC = () => { + const { + selectedBoundingBox, + setSelectedBoundingBox, + newDatasetName, + setNewDatasetName, + selectedLayerName, + setSelectedLayerName, + } = useRunAiModelJobContext(); + const dataset = useWkSelector((state) => state.dataset); + const colorLayers = getColorLayers(dataset); + + return ( + + + +
+ New Dataset Name + setNewDatasetName(e.target.value)} /> +
+
+ Image Data Layer + + + + + + + + + + ); +}; diff --git a/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx b/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx index 65e5be31776..9120ef431e9 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_image_alignment_job.tsx @@ -1,5 +1,21 @@ -import { Empty } from "antd"; +import { Flex } from "antd"; +import { AlignmentCreditInformation } from "../credit_information"; +import { AlignmentJobContextProvider } from "./ai_alignment_job_context"; +import { AiAlignmentModelSelector } from "./ai_alignment_model_selector"; +import { AiAlignmentParameters } from "./ai_alignment_parameters"; -export const AiImageAlignment = () => { - return ; +export const AiImageAlignmentJob = () => { + return ( + + + + + + + + + + + + ); }; diff --git a/frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx b/frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx index 5d344ca15a8..a5a2ee3ebb5 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/bounding_box_selector.tsx @@ -8,8 +8,8 @@ import { BoundingBoxSelection } from "viewer/view/action-bar/ai_job_modals/compo import { getBoundingBoxesForLayers } from "viewer/view/action-bar/ai_job_modals/utils"; interface BoundingBoxSelectorProps { - value: UserBoundingBox | null; - onChange: (value: UserBoundingBox | null) => void; + value?: UserBoundingBox | null; + onChange?: (value: UserBoundingBox | null) => void; } export const BoundingBoxSelector: React.FC = ({ value, onChange }) => { @@ -31,7 +31,7 @@ export const BoundingBoxSelector: React.FC = ({ value, ); diff --git a/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx index 5cc8136766b..008c875acb7 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx @@ -8,25 +8,68 @@ import { useWkSelector } from "libs/react_hooks"; import { computeArrayFromBoundingBox } from "libs/utils"; import type React from "react"; import { useCallback, useMemo } from "react"; -import { APIJobType } from "types/api_types"; +import { APIJobType, type AiModel } from "types/api_types"; import BoundingBox from "viewer/model/bucket_data_handling/bounding_box"; +import type { UserBoundingBox } from "viewer/store"; +import { useAlignmentJobContext } from "./alignment/ai_alignment_job_context"; import { useRunAiModelJobContext } from "./run_ai_model/ai_image_segmentation_job_context"; const { Title, Text } = Typography; -export const CreditInformation: React.FC = () => { +export const RunAiModelCreditInformation: React.FC = () => { + const { selectedModel, selectedJobType, selectedBoundingBox, handleStartAnalysis } = + useRunAiModelJobContext(); + return ( + + ); +}; + +export const AlignmentCreditInformation: React.FC = () => { + const { selectedTask, selectedJobType, selectedBoundingBox, handleStartAnalysis } = + useAlignmentJobContext(); + return ( + + ); +}; + +interface CreditInformationProps { + selectedModel: AiModel | Partial | null; + selectedJobType: APIJobType | null; + selectedBoundingBox: UserBoundingBox | null; + handleStartAnalysis: () => void; + startButtonTitle: string; +} + +export const CreditInformation: React.FC = ({ + selectedModel, + selectedJobType, + selectedBoundingBox, + handleStartAnalysis, + startButtonTitle, +}) => { const jobTypeToCreditCostPerGVx: Partial> = useMemo( () => ({ [APIJobType.INFER_NUCLEI]: features().neuronInferralCostPerGVx, [APIJobType.INFER_NEURONS]: features().neuronInferralCostPerGVx, [APIJobType.INFER_MITOCHONDRIA]: features().mitochondriaInferralCostPerGVx, [APIJobType.INFER_INSTANCES]: features().neuronInferralCostPerGVx, + [APIJobType.ALIGN_SECTIONS]: features().alignmentCostPerGVx, }), [], ); - const { selectedModel, selectedJobType, selectedBoundingBox, handleStartAnalysis } = - useRunAiModelJobContext(); const organizationCredits = useWkSelector( (state) => state.activeOrganization?.creditBalance || "0", ); @@ -121,7 +164,7 @@ export const CreditInformation: React.FC = () => { disabled={!selectedModel || !selectedBoundingBox || !jobCreditCostInfo?.hasEnoughCredits} onClick={handleStartAnalysis} > - Start Analysis + {startButtonTitle} ); diff --git a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx index 5df936c3501..fd02fec9e62 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx @@ -1,13 +1,28 @@ import { SettingOutlined } from "@ant-design/icons"; -import { Card, Col, Input, Row, Select, Space, Typography } from "antd"; +import { APIAiModelCategory } from "admin/rest_api"; +import { + Card, + Col, + Collapse, + ConfigProvider, + Form, + Input, + InputNumber, + Row, + Select, + Space, +} from "antd"; +import type { FormProps } from "antd"; import { useWkSelector } from "libs/react_hooks"; import type React from "react"; import { getColorLayers } from "viewer/model/accessors/dataset_accessor"; +import { + CollapsibleSplitMergerEvaluationSettings, + type SplitMergerEvaluationSettings, +} from "viewer/view/action-bar/ai_job_modals/components/collapsible_split_merger_evaluation_settings"; import { BoundingBoxSelector } from "../bounding_box_selector"; import { useRunAiModelJobContext } from "./ai_image_segmentation_job_context"; -const { Text, Title } = Typography; - export const AiAnalysisParameters: React.FC = () => { const { selectedBoundingBox, @@ -16,10 +31,69 @@ export const AiAnalysisParameters: React.FC = () => { setNewDatasetName, selectedLayerName, setSelectedLayerName, + selectedModel, + seedGeneratorDistanceThreshold, + setSeedGeneratorDistanceThreshold, + isEvaluationActive, + setIsEvaluationActive, + splitMergerEvaluationSettings, + setSplitMergerEvaluationSettings, + useAnnotation, + setUseAnnotation, } = useRunAiModelJobContext(); const dataset = useWkSelector((state) => state.dataset); const colorLayers = getColorLayers(dataset); + const handleValuesChange: FormProps["onValuesChange"] = (changedValues, allValues) => { + if (Object.prototype.hasOwnProperty.call(changedValues, "newDatasetName")) { + setNewDatasetName(changedValues.newDatasetName); + } + if (Object.prototype.hasOwnProperty.call(changedValues, "selectedLayerName")) { + setSelectedLayerName(changedValues.selectedLayerName); + } + if (Object.prototype.hasOwnProperty.call(changedValues, "selectedBoundingBox")) { + setSelectedBoundingBox(changedValues.selectedBoundingBox); + } + if (Object.prototype.hasOwnProperty.call(changedValues, "seedGeneratorDistanceThreshold")) { + setSeedGeneratorDistanceThreshold(changedValues.seedGeneratorDistanceThreshold); + } + if (Object.prototype.hasOwnProperty.call(allValues, "splitMergerEvaluationSettings")) { + setSplitMergerEvaluationSettings( + allValues.splitMergerEvaluationSettings as SplitMergerEvaluationSettings, + ); + } + if (Object.prototype.hasOwnProperty.call(changedValues, "useAnnotation")) { + setUseAnnotation(changedValues.useAnnotation); + } + }; + + const isInstanceModel = selectedModel?.category === APIAiModelCategory.EM_NUCLEI; + const isNeuronModel = selectedModel ? !isInstanceModel : false; + + const formFields = [ + { name: ["newDatasetName"], value: newDatasetName }, + { name: ["selectedLayerName"], value: selectedLayerName }, + { name: ["selectedBoundingBox"], value: selectedBoundingBox }, + { name: ["seedGeneratorDistanceThreshold"], value: seedGeneratorDistanceThreshold }, + { + name: ["splitMergerEvaluationSettings", "useSparseTracing"], + value: splitMergerEvaluationSettings?.useSparseTracing, + }, + { + name: ["splitMergerEvaluationSettings", "maxEdgeLength"], + value: splitMergerEvaluationSettings?.maxEdgeLength, + }, + { + name: ["splitMergerEvaluationSettings", "sparseTubeThresholdInNm"], + value: splitMergerEvaluationSettings?.sparseTubeThresholdInNm, + }, + { + name: ["splitMergerEvaluationSettings", "minimumMergerPathLengthInNm"], + value: splitMergerEvaluationSettings?.minimumMergerPathLengthInNm, + }, + { name: ["useAnnotation"], value: useAnnotation }, + ]; + return ( { } > - - -
- New Dataset Name - setNewDatasetName(e.target.value)} /> -
-
- Image Data Layer - + + + + + + ({ value: l.name, label: l.name }))} /> + + + + + + + { + if (value && value.elementClass === "uint24") { + return Promise.reject( + new Error( + "The selected layer of type uint24 is not supported. Please select a different one.", + ), + ); + } + return Promise.resolve(); + }, + }, + ]} > ({ value: l.name, label: l.name }))} /> + + + ({ + value: m, + label: `${m[0]}x, ${m[1]}x, ${m[2]}x`, + }))} + /> + + + + + ); +}; + +export const AiTrainingDataSection = () => { + return ( + + + Training Data + + } + > +
+ + + +
+ ); +}; diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx index 2b4e66f3082..b7334792eb8 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx @@ -3,6 +3,7 @@ import { TrainingCreditInformation } from "../credit_information"; import { AiTrainingJobContextProvider } from "./ai_training_job_context"; import { AiTrainingModelSelector } from "./ai_training_model_selector"; import { AiTrainingParameters } from "./ai_training_parameters"; +import { AiTrainingDataSection } from "./ai_training_data_selector"; export const AiModelTrainingJob = () => { return ( @@ -10,6 +11,7 @@ export const AiModelTrainingJob = () => { + diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx index e8079074dd0..7017424fc40 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx @@ -113,7 +113,7 @@ export const AiTrainingJobContextProvider: React.FC<{ children: React.ReactNode } }, [ modelName, - selectedTask, + selectedJobType, imageDataLayer, groundTruthLayer, magnification, diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_parameters.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_parameters.tsx index e211fbcc303..f0c95c0c20b 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_parameters.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_parameters.tsx @@ -1,15 +1,8 @@ import { SettingOutlined } from "@ant-design/icons"; -import { Card, Col, Form, Input, InputNumber, Row, Select, Space } from "antd"; +import { Card, Col, Form, Input, InputNumber, Row, Space } from "antd"; import type { FormProps } from "antd"; -import { useWkSelector } from "libs/react_hooks"; import type React from "react"; -import { useMemo } from "react"; import { APIJobType } from "types/api_types"; -import { - getColorLayers, - getMagInfo, - getSegmentationLayers, -} from "viewer/model/accessors/dataset_accessor"; import { useAiTrainingJobContext } from "./ai_training_job_context"; @@ -17,12 +10,6 @@ export const AiTrainingParameters: React.FC = () => { const { modelName, setModelName, - imageDataLayer, - setImageDataLayer, - groundTruthLayer, - setGroundTruthLayer, - magnification, - setMagnification, comments, setComments, selectedTask, @@ -30,35 +17,10 @@ export const AiTrainingParameters: React.FC = () => { setMaxDistanceNm, } = useAiTrainingJobContext(); - const dataset = useWkSelector((state) => state.dataset); - const colorLayers = getColorLayers(dataset); - const segmentationLayers = getSegmentationLayers(dataset); - - const availableMagnifications = useMemo(() => { - if (!imageDataLayer) { - return []; - } - const selectedLayer = colorLayers.find((l) => l.name === imageDataLayer); - if (!selectedLayer) { - return []; - } - return getMagInfo(selectedLayer.resolutions).getMagList(); - }, [imageDataLayer, colorLayers]); - const handleValuesChange: FormProps["onValuesChange"] = (changedValues) => { if (Object.prototype.hasOwnProperty.call(changedValues, "modelName")) { setModelName(changedValues.modelName); } - if (Object.prototype.hasOwnProperty.call(changedValues, "imageDataLayer")) { - setImageDataLayer(changedValues.imageDataLayer); - setMagnification(null); - } - if (Object.prototype.hasOwnProperty.call(changedValues, "groundTruthLayer")) { - setGroundTruthLayer(changedValues.groundTruthLayer); - } - if (Object.prototype.hasOwnProperty.call(changedValues, "magnification")) { - setMagnification(changedValues.magnification); - } if (Object.prototype.hasOwnProperty.call(changedValues, "comments")) { setComments(changedValues.comments); } @@ -69,9 +31,6 @@ export const AiTrainingParameters: React.FC = () => { const formFields = [ { name: ["modelName"], value: modelName }, - { name: ["imageDataLayer"], value: imageDataLayer }, - { name: ["groundTruthLayer"], value: groundTruthLayer }, - { name: ["magnification"], value: magnification }, { name: ["comments"], value: comments }, { name: ["maxDistanceNm"], value: maxDistanceNm }, ]; @@ -95,38 +54,6 @@ export const AiTrainingParameters: React.FC = () => { > - - ({ value: l.name, label: l.name }))} /> - - - - - ({ value: l.name, label: l.name }))} /> + ({ value: l.name, label: l.name }))} /> + ({ - value: m, - label: `${m[0]}x, ${m[1]}x, ${m[2]}x`, + disabled={!selection?.imageDataLayer || !selection?.groundTruthLayer} + options={availableMagnifications.map((m, index) => ({ + value: index, + label: `${m[0]}-${m[1]}-${m[2]}`, }))} + value={selection?.magnification} + onChange={(index: number) => + handleSelectionChange(annotationId, { + magnification: availableMagnifications[index], + }) + } /> @@ -96,8 +150,10 @@ export const AiTrainingDataSection = () => { } >
- - + {annotationInfos.map((info) => { + const annotationId = "id" in info.annotation ? info.annotation.id : info.annotation; + return ; + })} ); diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx index b7334792eb8..ea172263b92 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job.tsx @@ -1,9 +1,9 @@ import { Flex } from "antd"; import { TrainingCreditInformation } from "../credit_information"; +import { AiTrainingDataSection } from "./ai_training_data_selector"; import { AiTrainingJobContextProvider } from "./ai_training_job_context"; import { AiTrainingModelSelector } from "./ai_training_model_selector"; import { AiTrainingParameters } from "./ai_training_parameters"; -import { AiTrainingDataSection } from "./ai_training_data_selector"; export const AiModelTrainingJob = () => { return ( diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx index e8b9a5d5685..446f1b78aae 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx @@ -7,15 +7,22 @@ import { import { useWkSelector } from "libs/react_hooks"; import Toast from "libs/toast"; import type React from "react"; -import { createContext, useCallback, useContext, useState } from "react"; +import { createContext, useCallback, useContext, useEffect, useState } from "react"; import { useDispatch } from "react-redux"; -import { APIAnnotation, APIJobType } from "types/api_types"; +import { type APIAnnotation, APIJobType } from "types/api_types"; import type { Vector3 } from "viewer/constants"; import { getUserBoundingBoxesFromState } from "viewer/model/accessors/tracing_accessor"; import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; import type { UserBoundingBox } from "viewer/store"; +import type { AnnotationInfoForAITrainingJob } from "viewer/view/action-bar/ai_job_modals/utils"; import type { AiTrainingTask } from "./ai_training_model_selector"; -import { AnnotationInfoForAITrainingJob } from "viewer/view/action-bar/ai_job_modals/utils"; + +export interface AiTrainingAnnotationSelection { + annotationId: string; + imageDataLayer?: string; + groundTruthLayer?: string; + magnification?: Vector3; +} interface AiTrainingJobContextType { handleStartAnalysis: () => void; @@ -31,6 +38,13 @@ interface AiTrainingJobContextType { setComments: (comments: string) => void; maxDistanceNm: number; setMaxDistanceNm: (dist: number) => void; + + annotationInfos: AnnotationInfoForAITrainingJob[]; + selections: AiTrainingAnnotationSelection[]; + handleSelectionChange: ( + annotationId: string, + newValues: Partial>, + ) => void; } const AiTrainingJobContext = createContext(undefined); @@ -45,29 +59,82 @@ export const AiTrainingJobContextProvider: React.FC<{ children: React.ReactNode const [annotationInfos, setAnnotationInfos] = useState< AnnotationInfoForAITrainingJob[] >([]); + const [selections, setSelections] = useState([]); const [comments, setComments] = useState(""); const [maxDistanceNm, setMaxDistanceNm] = useState(1000.0); const dispatch = useDispatch(); - const annotationId = useWkSelector((state) => state.annotation.annotationId); - const selectedBoundingBoxes = useWkSelector((state) => getUserBoundingBoxesFromState(state)); + const annotation = useWkSelector((state) => state.annotation); + const dataset = useWkSelector((state) => state.dataset); + const userBoundingBoxes = useWkSelector((state) => getUserBoundingBoxesFromState(state)); + + useEffect(() => { + // Initialize with current annotation if nothing is selected + if (annotationInfos.length === 0 && userBoundingBoxes) { + if (dataset) { + setAnnotationInfos([ + { + annotation: annotation as unknown as APIAnnotation, + dataset, + volumeTracings: annotation.volumes, + volumeTracingMags: [], + userBoundingBoxes, + }, + ]); + } + } + }, [annotation, dataset, userBoundingBoxes, annotationInfos.length]); + + useEffect(() => { + if (annotationInfos) { + setSelections( + annotationInfos.map((info) => ({ + annotationId: "id" in info.annotation ? info.annotation.id : info.annotation.annotationId, + })), + ); + } + }, [annotationInfos]); + + const handleSelectionChange = useCallback( + ( + annotationId: string, + newValues: Partial>, + ) => { + setSelections((prev) => { + const newSelections = [...prev]; + const index = newSelections.findIndex((s) => s.annotationId === annotationId); + if (index > -1) { + newSelections[index] = { ...newSelections[index], ...newValues }; + // When a layer changes, reset magnification + if (newValues.imageDataLayer || newValues.groundTruthLayer) { + delete newSelections[index].magnification; + } + } + return newSelections; + }); + }, + [], + ); const handleStartAnalysis = useCallback(async () => { - if (!modelName || !selectedJobType || !imageDataLayer || !groundTruthLayer || !magnification) { + if (!modelName || !selectedJobType) { Toast.error("Please fill all required fields."); return; } - const trainingAnnotations: AiModelTrainingAnnotationSpecification[] = annotationInfos.map( - (annotationInfo) => { - return { - annotationId: annotationInfo.annotation.annotationId, - colorLayerName: annotationInfo.imageDataLayer, - segmentationLayerName: annotationInfo.groundTruthLayer, - mag: annotationInfo.magnification, - }; - }, + if (selections.some((s) => !s.imageDataLayer || !s.groundTruthLayer || !s.magnification)) { + Toast.error("Please fill all required fields for all annotations."); + return; + } + + const trainingAnnotations: AiModelTrainingAnnotationSpecification[] = selections.map( + (selection) => ({ + annotationId: selection.annotationId, + colorLayerName: selection.imageDataLayer!, + segmentationLayerName: selection.groundTruthLayer!, + mag: selection.magnification!, + }), ); const commonJobArgmuments = { @@ -95,37 +162,23 @@ export const AiTrainingJobContextProvider: React.FC<{ children: React.ReactNode console.error(error); Toast.error("Failed to start training."); } - }, [ - modelName, - selectedJobType, - imageDataLayer, - groundTruthLayer, - magnification, - comments, - maxDistanceNm, - annotationId, - dispatch, - ]); + }, [modelName, selectedJobType, selections, comments, maxDistanceNm, dispatch]); const value = { selectedJobType, selectedTask, - selectedBoundingBoxes, setSelectedJobType, setSelectedTask, handleStartAnalysis, modelName, setModelName, - imageDataLayer, - setImageDataLayer, - groundTruthLayer, - setGroundTruthLayer, - magnification, - setMagnification, comments, setComments, maxDistanceNm, setMaxDistanceNm, + annotationInfos, + selections, + handleSelectionChange, }; return {children}; From 50e1f6befe3accb244e4a06ce73b815e74ac950d Mon Sep 17 00:00:00 2001 From: Tom Herold Date: Tue, 2 Sep 2025 14:39:24 +0200 Subject: [PATCH 10/72] properly connect AI job drawer open/close state --- frontend/javascripts/viewer/default_state.ts | 2 +- .../viewer/model/actions/ui_actions.ts | 10 +-- .../viewer/model/reducers/ui_reducer.ts | 4 +- frontend/javascripts/viewer/store.ts | 4 +- .../action-bar/ai_job_modals/constants.ts | 11 ++- .../forms/align_sections_form.tsx | 4 +- .../forms/custom_ai_model_inference_form.tsx | 4 +- .../forms/mitochondria_segmentation_form.tsx | 4 +- .../forms/neuron_segmentation_form.tsx | 4 +- .../forms/nuclei_detection_form.tsx | 4 +- .../ai_job_modals/start_ai_job_modal.tsx | 18 ++--- .../ai_job_modals/tabs/run_ai_model_tab.tsx | 24 +++---- .../viewer/view/action_bar_view.tsx | 67 ++++++++++++++----- .../viewer/view/ai_jobs/ai_jobs_drawer.tsx | 23 +++++-- .../alignment/ai_alignment_job_context.tsx | 4 +- .../view/ai_jobs/credit_information.tsx | 6 +- .../ai_image_segmentation_job_context.tsx | 4 +- .../ai_training_job_context.tsx | 4 +- 18 files changed, 121 insertions(+), 80 deletions(-) diff --git a/frontend/javascripts/viewer/default_state.ts b/frontend/javascripts/viewer/default_state.ts index caf89b83baf..e32f5ca4072 100644 --- a/frontend/javascripts/viewer/default_state.ts +++ b/frontend/javascripts/viewer/default_state.ts @@ -245,7 +245,7 @@ const defaultState: WebknossosState = { showMergeAnnotationModal: false, showZarrPrivateLinksModal: false, showPythonClientModal: false, - aIJobModalState: "invisible", + aIJobDrawerState: "invisible", showRenderAnimationModal: false, showShareModal: false, storedLayouts: {}, diff --git a/frontend/javascripts/viewer/model/actions/ui_actions.ts b/frontend/javascripts/viewer/model/actions/ui_actions.ts index b71682cb42e..1d72401c0ee 100644 --- a/frontend/javascripts/viewer/model/actions/ui_actions.ts +++ b/frontend/javascripts/viewer/model/actions/ui_actions.ts @@ -1,7 +1,7 @@ import type { OrthoView, Vector3 } from "viewer/constants"; import type { AnnotationTool } from "viewer/model/accessors/tool_accessor"; import type { BorderOpenStatus, Theme, WebknossosState } from "viewer/store"; -import type { StartAIJobModalState } from "viewer/view/action-bar/ai_job_modals/constants"; +import type { StartAiJobDrawerState } from "viewer/view/action-bar/ai_job_modals/constants"; type SetDropzoneModalVisibilityAction = ReturnType; type SetVersionRestoreVisibilityAction = ReturnType; @@ -18,7 +18,7 @@ type SetShareModalVisibilityAction = ReturnType; type SetBusyBlockingInfoAction = ReturnType; type SetPythonClientModalVisibilityAction = ReturnType; -type SetAIJobModalStateAction = ReturnType; +type SetaIJobDrawerStateAction = ReturnType; export type EnterAction = ReturnType; export type EscapeAction = ReturnType; export type SetQuickSelectStateAction = ReturnType; @@ -53,7 +53,7 @@ export type UiAction = | SetDownloadModalVisibilityAction | SetPythonClientModalVisibilityAction | SetShareModalVisibilityAction - | SetAIJobModalStateAction + | SetaIJobDrawerStateAction | SetRenderAnimationModalVisibilityAction | SetMergeModalVisibilityAction | SetUserScriptsModalVisibilityAction @@ -133,9 +133,9 @@ export const setShareModalVisibilityAction = (visible: boolean) => type: "SET_SHARE_MODAL_VISIBILITY", visible, }) as const; -export const setAIJobModalStateAction = (state: StartAIJobModalState) => +export const setAIJobDrawerStateAction = (state: StartAiJobDrawerState) => ({ - type: "SET_AI_JOB_MODAL_STATE", + type: "SET_AI_JOB_DRAWER_STATE", state, }) as const; export const setRenderAnimationModalVisibilityAction = (visible: boolean) => diff --git a/frontend/javascripts/viewer/model/reducers/ui_reducer.ts b/frontend/javascripts/viewer/model/reducers/ui_reducer.ts index 1dea24ea337..d517c623bab 100644 --- a/frontend/javascripts/viewer/model/reducers/ui_reducer.ts +++ b/frontend/javascripts/viewer/model/reducers/ui_reducer.ts @@ -101,9 +101,9 @@ function UiReducer(state: WebknossosState, action: Action): WebknossosState { }); } - case "SET_AI_JOB_MODAL_STATE": { + case "SET_AI_JOB_DRAWER_STATE": { return updateKey(state, "uiInformation", { - aIJobModalState: action.state, + aIJobDrawerState: action.state, }); } diff --git a/frontend/javascripts/viewer/store.ts b/frontend/javascripts/viewer/store.ts index fa6d4846488..41657fa3977 100644 --- a/frontend/javascripts/viewer/store.ts +++ b/frontend/javascripts/viewer/store.ts @@ -81,7 +81,7 @@ import VolumeTracingReducer from "viewer/model/reducers/volumetracing_reducer"; import { eventEmitterMiddleware } from "./model/helpers/event_emitter_middleware"; import FlycamInfoCacheReducer from "./model/reducers/flycam_info_cache_reducer"; import OrganizationReducer from "./model/reducers/organization_reducer"; -import type { StartAIJobModalState } from "./view/action-bar/ai_job_modals/constants"; +import type { StartAiJobDrawerState } from "./view/action-bar/ai_job_modals/constants"; export type { BoundingBoxObject } from "types/bounding_box"; @@ -494,7 +494,7 @@ type UiInformation = { readonly showMergeAnnotationModal: boolean; readonly showZarrPrivateLinksModal: boolean; readonly showAddScriptModal: boolean; - readonly aIJobModalState: StartAIJobModalState; + readonly aIJobDrawerState: StartAiJobDrawerState; readonly showRenderAnimationModal: boolean; readonly activeTool: AnnotationTool; readonly activeUserBoundingBoxId: number | null | undefined; diff --git a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/constants.ts b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/constants.ts index 124e96986f2..653042535ec 100644 --- a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/constants.ts +++ b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/constants.ts @@ -1,12 +1,11 @@ import { APIJobType } from "types/api_types"; import type { Vector3 } from "viewer/constants"; -export type ModalJobTypes = - | APIJobType.INFER_NEURONS - | APIJobType.INFER_NUCLEI - | APIJobType.INFER_MITOCHONDRIA; - -export type StartAIJobModalState = ModalJobTypes | "invisible"; +export type StartAiJobDrawerState = + | "open_ai_training" + | "open_ai_inference" + | "open_ai_alignment" + | "invisible"; // "materialize_volume_annotation" is only used in this module export const jobNameToImagePath = { diff --git a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/align_sections_form.tsx b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/align_sections_form.tsx index 9ee71cf16ea..6668e50384c 100644 --- a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/align_sections_form.tsx +++ b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/align_sections_form.tsx @@ -5,7 +5,7 @@ import { useWkSelector } from "libs/react_hooks"; import { useCallback } from "react"; import { useDispatch } from "react-redux"; import { APIJobType } from "types/api_types"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { type JobApiCallArgsType, StartJobForm } from "./start_job_form"; export function AlignSectionsForm() { @@ -14,7 +14,7 @@ export function AlignSectionsForm() { const { alignmentCostPerGVx } = features(); const handleClose = useCallback( - () => dispatch(setAIJobModalStateAction("invisible")), + () => dispatch(setAIJobDrawerStateAction("invisible")), [dispatch], ); const jobApiCall = useCallback( diff --git a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/custom_ai_model_inference_form.tsx b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/custom_ai_model_inference_form.tsx index 93e2b57928e..cd6768b9bb0 100644 --- a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/custom_ai_model_inference_form.tsx +++ b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/custom_ai_model_inference_form.tsx @@ -13,7 +13,7 @@ import { useCallback, useState } from "react"; import { useDispatch } from "react-redux"; import { APIJobType } from "types/api_types"; import { ControlModeEnum } from "viewer/constants"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { ExperimentalInferenceAlert } from "../components/experimental_inference_alert"; import { type JobApiCallArgsType, StartJobForm } from "./start_job_form"; @@ -87,7 +87,7 @@ export function CustomAiModelInferenceForm() { ); const handleClose = useCallback( - () => dispatch(setAIJobModalStateAction("invisible")), + () => dispatch(setAIJobDrawerStateAction("invisible")), [dispatch], ); diff --git a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/mitochondria_segmentation_form.tsx b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/mitochondria_segmentation_form.tsx index d21aa639894..d9470c40bf4 100644 --- a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/mitochondria_segmentation_form.tsx +++ b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/mitochondria_segmentation_form.tsx @@ -6,7 +6,7 @@ import { computeArrayFromBoundingBox } from "libs/utils"; import { useCallback } from "react"; import { useDispatch } from "react-redux"; import { APIJobType } from "types/api_types"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { ExperimentalInferenceAlert } from "../components/experimental_inference_alert"; import { getBestFittingMagComparedToTrainingDS, isDatasetOrBoundingBoxTooSmall } from "../utils"; import { type JobApiCallArgsType, StartJobForm } from "./start_job_form"; @@ -17,7 +17,7 @@ export function MitochondriaSegmentationForm() { const dispatch = useDispatch(); const handleClose = useCallback( - () => dispatch(setAIJobModalStateAction("invisible")), + () => dispatch(setAIJobDrawerStateAction("invisible")), [dispatch], ); const jobApiCall = useCallback( diff --git a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/neuron_segmentation_form.tsx b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/neuron_segmentation_form.tsx index d6c51ee9802..5ced1acc9f2 100644 --- a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/neuron_segmentation_form.tsx +++ b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/neuron_segmentation_form.tsx @@ -13,7 +13,7 @@ import { getTaskBoundingBoxes, getUserBoundingBoxesFromState, } from "viewer/model/accessors/tracing_accessor"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { CollapsibleSplitMergerEvaluationSettings, type SplitMergerEvaluationSettings, @@ -35,7 +35,7 @@ export function NeuronSegmentationForm() { const taskBoundingBoxes = useWkSelector(getTaskBoundingBoxes); const handleClose = useCallback( - () => dispatch(setAIJobModalStateAction("invisible")), + () => dispatch(setAIJobDrawerStateAction("invisible")), [dispatch], ); const jobApiCall = useCallback( diff --git a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/nuclei_detection_form.tsx b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/nuclei_detection_form.tsx index 984d7d0aecb..d5fb0fd4d12 100644 --- a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/nuclei_detection_form.tsx +++ b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/forms/nuclei_detection_form.tsx @@ -5,7 +5,7 @@ import { useCallback } from "react"; import { useDispatch } from "react-redux"; import { APIJobType } from "types/api_types"; import { Unicode } from "viewer/constants"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { getBestFittingMagComparedToTrainingDS, isDatasetOrBoundingBoxTooSmall } from "../utils"; import { type JobApiCallArgsType, StartJobForm } from "./start_job_form"; @@ -16,7 +16,7 @@ export function NucleiDetectionForm() { const dispatch = useDispatch(); const handleClose = useCallback( - () => dispatch(setAIJobModalStateAction("invisible")), + () => dispatch(setAIJobDrawerStateAction("invisible")), [dispatch], ); const jobApiCall = useCallback( diff --git a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/start_ai_job_modal.tsx b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/start_ai_job_modal.tsx index baccb49dad3..af60a77d19c 100644 --- a/frontend/javascripts/viewer/view/action-bar/ai_job_modals/start_ai_job_modal.tsx +++ b/frontend/javascripts/viewer/view/action-bar/ai_job_modals/start_ai_job_modal.tsx @@ -3,19 +3,19 @@ import { useWkSelector } from "libs/react_hooks"; import _ from "lodash"; import { useCallback, useMemo } from "react"; import { useDispatch } from "react-redux"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; -import type { StartAIJobModalState } from "./constants"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; +import type { StartAiJobDrawerState } from "./constants"; import { AlignmentTab } from "./tabs/alignment_tab"; import { RunAiModelTab } from "./tabs/run_ai_model_tab"; import { TrainAiModelFromAnnotationTab } from "./tabs/train_ai_model_tab"; -export type StartAIJobModalProps = { - aIJobModalState: StartAIJobModalState; +export type StartAIJobDrawerProps = { + aIJobDrawerState: StartAiJobDrawerState; }; -export function StartAIJobModal({ aIJobModalState }: StartAIJobModalProps) { +export function StartAIJobModal({ aIJobDrawerState }: StartAIJobDrawerProps) { const dispatch = useDispatch(); - const onClose = useCallback(() => dispatch(setAIJobModalStateAction("invisible")), [dispatch]); + const onClose = useCallback(() => dispatch(setAIJobDrawerStateAction("invisible")), [dispatch]); const isSuperUser = useWkSelector((state) => state.activeUser?.isSuperUser || false); const tabs = useMemo( () => @@ -23,7 +23,7 @@ export function StartAIJobModal({ aIJobModalState }: StartAIJobModalProps) { { label: "Run a model", key: "runModel", - children: , + children: , }, isSuperUser ? { @@ -38,9 +38,9 @@ export function StartAIJobModal({ aIJobModalState }: StartAIJobModalProps) { children: , }, ]), - [isSuperUser, aIJobModalState, onClose], + [isSuperUser, aIJobDrawerState, onClose], ); - return aIJobModalState !== "invisible" ? ( + return aIJobDrawerState !== "invisible" ? ( dispatch(setAIJobModalStateAction(APIJobType.INFER_NEURONS))} + checked={aIJobDrawerState === APIJobType.INFER_NEURONS} + onClick={() => dispatch(setAIJobDrawerStateAction(APIJobType.INFER_NEURONS))} > @@ -85,8 +85,8 @@ export function RunAiModelTab({ aIJobModalState }: { aIJobModalState: string }) dispatch(setAIJobModalStateAction(APIJobType.INFER_MITOCHONDRIA))} + checked={aIJobDrawerState === APIJobType.INFER_MITOCHONDRIA} + onClick={() => dispatch(setAIJobDrawerStateAction(APIJobType.INFER_MITOCHONDRIA))} > @@ -106,8 +106,8 @@ export function RunAiModelTab({ aIJobModalState }: { aIJobModalState: string }) dispatch(setAIJobModalStateAction(APIJobType.INFER_NUCLEI))} + checked={aIJobDrawerState === APIJobType.INFER_NUCLEI} + onClick={() => dispatch(setAIJobDrawerStateAction(APIJobType.INFER_NUCLEI))} > @@ -124,12 +124,12 @@ export function RunAiModelTab({ aIJobModalState }: { aIJobModalState: string }) - {aIJobModalState === APIJobType.INFER_NEURONS ? : null} - {aIJobModalState === APIJobType.INFER_NUCLEI ? : null} - {aIJobModalState === APIJobType.INFER_MITOCHONDRIA ? ( + {aIJobDrawerState === APIJobType.INFER_NEURONS ? : null} + {aIJobDrawerState === APIJobType.INFER_NUCLEI ? : null} + {aIJobDrawerState === APIJobType.INFER_MITOCHONDRIA ? ( ) : null} - {aIJobModalState === APIJobType.ALIGN_SECTIONS ? : null} + {aIJobDrawerState === APIJobType.ALIGN_SECTIONS ? : null} )} diff --git a/frontend/javascripts/viewer/view/action_bar_view.tsx b/frontend/javascripts/viewer/view/action_bar_view.tsx index bcf4e87329f..849b7715b51 100644 --- a/frontend/javascripts/viewer/view/action_bar_view.tsx +++ b/frontend/javascripts/viewer/view/action_bar_view.tsx @@ -1,6 +1,6 @@ import { withAuthentication } from "admin/auth/authentication_modal"; import { createExplorational } from "admin/rest_api"; -import { Alert, Modal, Popover, Space } from "antd"; +import { Alert, Button, Dropdown, Modal, Popover, Space } from "antd"; import { AsyncButton, type AsyncButtonProps } from "components/async_clickables"; import { NewVolumeLayerSelection } from "dashboard/advanced_dataset/create_explorative_modal"; import { useWkSelector } from "libs/react_hooks"; @@ -23,7 +23,7 @@ import { is2dDataset, } from "viewer/model/accessors/dataset_accessor"; import { setAdditionalCoordinatesAction } from "viewer/model/actions/flycam_actions"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import type { WebknossosState } from "viewer/store"; import Store from "viewer/store"; import AddNewLayoutModal from "viewer/view/action-bar/add_new_layout_modal"; @@ -41,7 +41,7 @@ import { getLayoutConfig, layoutEmitter, } from "viewer/view/layouting/layout_persistence"; -import type { StartAIJobModalState } from "./action-bar/ai_job_modals/constants"; +import type { StartAiJobDrawerState } from "./action-bar/ai_job_modals/constants"; // import { StartAIJobModal } from "./action-bar/ai_job_modals/start_ai_job_modal"; import ToolkitView from "./action-bar/tools/toolkit_switcher_view"; import { AiJobsDrawer } from "./ai_jobs/ai_jobs_drawer"; @@ -64,7 +64,7 @@ type StateProps = { showVersionRestore: boolean; is2d: boolean; viewMode: ViewMode; - aiJobModalState: StartAIJobModalState; + aiJobDrawerState: StartAiJobDrawerState; }; type OwnProps = { layoutProps: LayoutProps; @@ -292,17 +292,52 @@ class ActionBarView extends React.PureComponent { renderStartAIJobButton(disabled: boolean, tooltipTextIfDisabled: string): React.ReactNode { const tooltipText = disabled ? tooltipTextIfDisabled : "Start a processing job using AI"; + const menuItems = [ + { + key: "open_ai_inference_button", + onClick: () => Store.dispatch(setAIJobDrawerStateAction("open_ai_inference")), + label: "Run AI model", + }, + { + key: "open_ai_training_button", + onClick: () => Store.dispatch(setAIJobDrawerStateAction("open_ai_training")), + label: "Train AI model", + }, + { + key: "open_ai_alignment_button", + onClick: () => Store.dispatch(setAIJobDrawerStateAction("open_ai_alignment")), + label: "Run AI Alignment", + }, + ]; + return ( - Store.dispatch(setAIJobModalStateAction(APIJobType.INFER_NEURONS))} - style={{ marginLeft: 12, pointerEvents: "auto" }} - disabled={disabled} - title={tooltipText} - icon={} - > - AI Analysis - + <> + Store.dispatch(setAIJobDrawerStateAction("open_ai_inference"))} + style={{ marginLeft: 12, pointerEvents: "auto" }} + disabled={disabled} + title={tooltipText} + icon={} + > + AI Analysis + + + + + ); } @@ -377,7 +412,7 @@ class ActionBarView extends React.PureComponent { }) } /> - + ); } @@ -390,7 +425,7 @@ const mapStateToProps = (state: WebknossosState): StateProps => ({ showVersionRestore: state.uiInformation.showVersionRestore, is2d: is2dDataset(state.dataset), viewMode: state.temporaryConfiguration.viewMode, - aiJobModalState: state.uiInformation.aIJobModalState, + aiJobDrawerState: state.uiInformation.aIJobDrawerState, }); const connector = connect(mapStateToProps); diff --git a/frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx b/frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx index 43b11575489..b17e07e343f 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/ai_jobs_drawer.tsx @@ -1,30 +1,41 @@ import { Drawer, Tabs } from "antd"; +import { useWkSelector } from "libs/react_hooks"; +import { useCallback } from "react"; +import { useDispatch } from "react-redux"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { AiImageAlignmentJob } from "./alignment/ai_image_alignment_job"; import { AiImageSegmentationJob } from "./run_ai_model/ai_image_segmentation_job"; import { AiModelTrainingJob } from "./train_ai_model/ai_training_job"; -export const AiJobsDrawer = () => { +export const AiJobsDrawer = ({ isOpen }: { isOpen: boolean }) => { + const dispatch = useDispatch(); + const ai_job_drawer_state = useWkSelector((state) => state.uiInformation.aIJobDrawerState); + + const handleClose = useCallback(() => { + dispatch(setAIJobDrawerStateAction("invisible")); + }, [dispatch]); + const items = [ { label: "Image Segmentation", - key: "1", + key: "open_ai_inference", children: , }, { label: "Model Training", - key: "2", + key: "open_ai_training", children: , }, { label: "Image Alignment", - key: "3", + key: "open_ai_alignment", children: , }, ]; return ( - - + + ); }; diff --git a/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_alignment_job_context.tsx b/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_alignment_job_context.tsx index d23b3f258ac..cfd2e5bba9f 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_alignment_job_context.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/alignment/ai_alignment_job_context.tsx @@ -6,7 +6,7 @@ import { createContext, useCallback, useContext, useEffect, useState } from "rea import { useDispatch } from "react-redux"; import type { APIJobType } from "types/api_types"; import { getColorLayers } from "viewer/model/accessors/dataset_accessor"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import type { UserBoundingBox } from "viewer/store"; import { getBoundingBoxesForLayers } from "viewer/view/action-bar/ai_job_modals/utils"; import type { AlignmentTask } from "./ai_alignment_model_selector"; @@ -57,7 +57,7 @@ export const AlignmentJobContextProvider: React.FC<{ children: React.ReactNode } shouldUseManualMatches ? annotationId : undefined, ); Toast.success("Alignment started successfully!"); - dispatch(setAIJobModalStateAction("invisible")); + dispatch(setAIJobDrawerStateAction("invisible")); } catch (error) { console.error(error); Toast.error("Failed to start alignment."); diff --git a/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx index f3d1bf66e62..e85c3b527cb 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx @@ -5,11 +5,7 @@ import { Button, Card, Col, Row, Space, Spin, Typography } from "antd"; import features from "features"; import { formatCreditsString, formatVoxels } from "libs/format_utils"; import { useWkSelector } from "libs/react_hooks"; -import { - computeArrayFromBoundingBox, - computeShapeFromBoundingBox, - computeVolumeFromBoundingBox, -} from "libs/utils"; +import { computeArrayFromBoundingBox, computeVolumeFromBoundingBox } from "libs/utils"; import type React from "react"; import { useCallback, useMemo } from "react"; import { APIJobType, type AiModel } from "types/api_types"; diff --git a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_image_segmentation_job_context.tsx b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_image_segmentation_job_context.tsx index 265143c0478..4447f829794 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_image_segmentation_job_context.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_image_segmentation_job_context.tsx @@ -21,7 +21,7 @@ import { getTaskBoundingBoxes, getUserBoundingBoxesFromState, } from "viewer/model/accessors/tracing_accessor"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { Model } from "viewer/singletons"; import type { UserBoundingBox } from "viewer/store"; import type { SplitMergerEvaluationSettings } from "viewer/view/action-bar/ai_job_modals/components/collapsible_split_merger_evaluation_settings"; @@ -197,7 +197,7 @@ export const RunAiModelJobContextProvider: React.FC<{ children: React.ReactNode } } Toast.success("Analysis started successfully!"); - dispatch(setAIJobModalStateAction("invisible")); + dispatch(setAIJobDrawerStateAction("invisible")); } catch (error) { console.error(error); Toast.error("Failed to start analysis."); diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx index 446f1b78aae..f80e3ccf72a 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_job_context.tsx @@ -12,7 +12,7 @@ import { useDispatch } from "react-redux"; import { type APIAnnotation, APIJobType } from "types/api_types"; import type { Vector3 } from "viewer/constants"; import { getUserBoundingBoxesFromState } from "viewer/model/accessors/tracing_accessor"; -import { setAIJobModalStateAction } from "viewer/model/actions/ui_actions"; +import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import type { UserBoundingBox } from "viewer/store"; import type { AnnotationInfoForAITrainingJob } from "viewer/view/action-bar/ai_job_modals/utils"; import type { AiTrainingTask } from "./ai_training_model_selector"; @@ -157,7 +157,7 @@ export const AiTrainingJobContextProvider: React.FC<{ children: React.ReactNode }); } Toast.success("The training has successfully started."); - dispatch(setAIJobModalStateAction("invisible")); + dispatch(setAIJobDrawerStateAction("invisible")); } catch (error) { console.error(error); Toast.error("Failed to start training."); From 4dff13fd6a0a54a97361c658943084875540b605 Mon Sep 17 00:00:00 2001 From: Tom Herold Date: Tue, 2 Sep 2025 16:00:32 +0200 Subject: [PATCH 11/72] stuff --- frontend/javascripts/viewer/default_state.ts | 2 +- .../view/ai_jobs/credit_information.tsx | 4 +- .../ai_training_data_selector.tsx | 72 ++++++++++++++----- .../ai_training_job_context.tsx | 68 +++++++----------- 4 files changed, 81 insertions(+), 65 deletions(-) diff --git a/frontend/javascripts/viewer/default_state.ts b/frontend/javascripts/viewer/default_state.ts index e32f5ca4072..d6305ba30d5 100644 --- a/frontend/javascripts/viewer/default_state.ts +++ b/frontend/javascripts/viewer/default_state.ts @@ -245,7 +245,7 @@ const defaultState: WebknossosState = { showMergeAnnotationModal: false, showZarrPrivateLinksModal: false, showPythonClientModal: false, - aIJobDrawerState: "invisible", + aIJobDrawerState: "open_ai_training", showRenderAnimationModal: false, showShareModal: false, storedLayouts: {}, diff --git a/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx index e85c3b527cb..9e4ddda3399 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/credit_information.tsx @@ -46,12 +46,12 @@ export const AlignmentCreditInformation: React.FC = () => { }; export const TrainingCreditInformation: React.FC = () => { - const { selectedTask, selectedJobType, annotationInfos, handleStartAnalysis } = + const { selectedTask, selectedJobType, selectedAnnotations, handleStartAnalysis } = useAiTrainingJobContext(); // sum all training volumes into a single bounding box // This is a shitty way to do it, but it works for now. - const totalVolume = annotationInfos.reduce( + const totalVolume = selectedAnnotations.reduce( (total, { userBoundingBoxes }) => total + userBoundingBoxes.reduce( diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx index 99472a95dc7..d883cce3391 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx @@ -1,6 +1,8 @@ import { FolderOutlined } from "@ant-design/icons"; -import { Card, Col, Form, Row, Select, Space } from "antd"; +import { Card, Col, Form, Row, Select, Space, Statistic } from "antd"; +import { formatVoxels } from "libs/format_utils"; import { V3 } from "libs/mjs"; +import { computeVolumeFromBoundingBox } from "libs/utils"; import _ from "lodash"; import { useMemo } from "react"; import type { APIAnnotation, APIDataLayer, APIDataset } from "types/api_types"; @@ -12,7 +14,9 @@ import { import { getSegmentationLayerByHumanReadableName } from "viewer/model/accessors/volumetracing_accessor"; import type { StoreAnnotation } from "viewer/store"; import type { AnnotationInfoForAITrainingJob } from "viewer/view/action-bar/ai_job_modals/utils"; -import { useAiTrainingJobContext } from "./ai_training_job_context"; +import { AiTrainingAnnotationSelection, useAiTrainingJobContext } from "./ai_training_job_context"; +import { Store } from "antd/es/form/interface"; +import { useWkSelector } from "libs/react_hooks"; const getMagsForColorLayer = (colorLayers: APIDataLayer[], layerName: string) => { const colorLayer = colorLayers.find((layer) => layer.name === layerName); @@ -39,12 +43,14 @@ const getIntersectingMagList = ( ); }; -const AiTrainingDataSelector = (props: AnnotationInfoForAITrainingJob) => { - const { annotation, dataset } = props; - const { selections, handleSelectionChange } = useAiTrainingJobContext(); +const AiTrainingDataSelector = ({ + selectedAnnotation, +}: { selectedAnnotation: AiTrainingAnnotationSelection }) => { + const dataset = useWkSelector((state) => state.dataset); + const { handleSelectionChange } = useAiTrainingJobContext(); - const annotationId = annotation.annotationId; - const selection = selections.find((s) => s.annotationId === annotationId); + const annotation = selectedAnnotation.annotation; + const annotationId = selectedAnnotation.annotation.annotationId; // Gather layer names from dataset. Omit the layers that are also present // in annotationLayers. @@ -71,18 +77,36 @@ const AiTrainingDataSelector = (props: AnnotationInfoForAITrainingJob layer.elementClass !== "uint24"); const availableMagnifications = useMemo(() => { - if (selection?.imageDataLayer && selection?.groundTruthLayer) { + if (selectedAnnotation?.imageDataLayer && selectedAnnotation?.groundTruthLayer) { return ( getIntersectingMagList( annotation, dataset, - selection.groundTruthLayer, - selection.imageDataLayer, + selectedAnnotation.groundTruthLayer, + selectedAnnotation.imageDataLayer, ) || [] ); } return []; - }, [selection?.imageDataLayer, selection?.groundTruthLayer, annotation, dataset]); + }, [ + selectedAnnotation?.imageDataLayer, + selectedAnnotation?.groundTruthLayer, + annotation, + dataset, + ]); + + const boundingBoxCount = useMemo( + () => selectedAnnotation.userBoundingBoxes.length, + [selectedAnnotation.userBoundingBoxes], + ); + const boundingBoxVolume = useMemo( + () => + selectedAnnotation.userBoundingBoxes.reduce( + (sum, box) => sum + computeVolumeFromBoundingBox(box.boundingBox), + 0, + ), + [selectedAnnotation?.userBoundingBoxes], + ); return ( @@ -95,7 +119,7 @@ const AiTrainingDataSelector = (props: AnnotationInfoForAITrainingJob ({ value: l, label: l }))} - value={selection?.groundTruthLayer} + value={selectedAnnotation?.groundTruthLayer} onChange={(value) => handleSelectionChange(annotationId, { groundTruthLayer: value })} /> @@ -118,12 +142,14 @@ const AiTrainingDataSelector = (props: AnnotationInfoForAITrainingJob { }; export const AiModelSelector: React.FC = () => { + const dispatch = useDispatch(); const { selectedModel, setSelectedModel, setSelectedJobType } = useRunAiModelJobContext(); const [searchTerm, setSearchTerm] = useState(""); @@ -86,15 +90,20 @@ export const AiModelSelector: React.FC = () => { const lowerCaseSearchTerm = searchTerm.toLowerCase(); return models.filter( (model) => - model.name.toLowerCase().includes(lowerCaseSearchTerm) || - (model.comment && model.comment.toLowerCase().includes(lowerCaseSearchTerm)), + model.name?.toLowerCase().includes(lowerCaseSearchTerm) || + model.comment?.toLowerCase().includes(lowerCaseSearchTerm), ); }; + // biome-ignore lint/correctness/useExhaustiveDependencies: filtered models need an update after searchTerm changes const filteredPreTrainedModels = useMemo(() => filterModels(preTrainedModels), [searchTerm]); - const filteredCustomModels = useMemo( - () => filterModels(customModels), - [searchTerm, customModels], + // biome-ignore lint/correctness/useExhaustiveDependencies: filtered models need an update after searchTerm changes + const filteredCustomModels = useMemo(() => filterModels([]), [searchTerm, customModels]); + + const switchToTraininButton = ( + ); return ( @@ -118,6 +127,7 @@ export const AiModelSelector: React.FC = () => { ( { 0 ? "No models match your search." : switchToTraininButton, + }} renderItem={(item) => ( { const colorLayer = colorLayers.find((layer) => layer.name === layerName); @@ -109,13 +110,16 @@ const AiTrainingDataSelector = ({ ); return ( - + ({ value: l, label: l }))} diff --git a/frontend/javascripts/viewer/view/ai_jobs/utils.ts b/frontend/javascripts/viewer/view/ai_jobs/utils.ts new file mode 100644 index 00000000000..dc284da4978 --- /dev/null +++ b/frontend/javascripts/viewer/view/ai_jobs/utils.ts @@ -0,0 +1,15 @@ +import type { Rule } from "antd/es/form"; +import type { APIDataLayer } from "types/api_types"; + +export const colorLayerMustNotBeUint24Rule = { + validator: (_: Rule, value: APIDataLayer) => { + if (value && value.elementClass === "uint24") { + return Promise.reject( + new Error( + "The selected layer of type uint24 is not supported. Please select a different one.", + ), + ); + } + return Promise.resolve(); + }, +}; From d6e9930dbfbe4d09e24bd991f8f15a209eed6db2 Mon Sep 17 00:00:00 2001 From: Tom Herold Date: Wed, 3 Sep 2025 15:50:34 +0200 Subject: [PATCH 14/72] add more validation --- app/models/job/JobService.scala | 2 + .../run_ai_model/ai_analysis_parameters.tsx | 2 +- .../run_ai_model/ai_model_selector.tsx | 3 +- .../ai_training_data_selector.tsx | 146 ++++++++++++++---- 4 files changed, 118 insertions(+), 35 deletions(-) diff --git a/app/models/job/JobService.scala b/app/models/job/JobService.scala index 4fb56b87a73..1b02a7390b2 100644 --- a/app/models/job/JobService.scala +++ b/app/models/job/JobService.scala @@ -264,6 +264,8 @@ class JobService @Inject()(wkConf: WkConf, case JobCommand.infer_neurons => Fox.successful(wkConf.Features.neuronInferralCostPerGVx) case JobCommand.infer_nuclei => Fox.successful(wkConf.Features.neuronInferralCostPerGVx) case JobCommand.infer_mitochondria => Fox.successful(wkConf.Features.mitochondriaInferralCostPerGVx) + case JobCommand.train_neuron_model => Fox.successful(0) + case JobCommand.train_instance_model => Fox.successful(0) case JobCommand.align_sections => Fox.successful(wkConf.Features.alignmentCostPerGVx) case _ => Fox.failure(s"Unsupported job command $jobCommand") } diff --git a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx index fc1a6f41a58..6c529b0a3e6 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_analysis_parameters.tsx @@ -29,8 +29,8 @@ import { isDatasetOrBoundingBoxTooSmall, } from "viewer/view/action-bar/ai_job_modals/utils"; import { BoundingBoxSelector } from "../bounding_box_selector"; -import { useRunAiModelJobContext } from "./ai_image_segmentation_job_context"; import { colorLayerMustNotBeUint24Rule } from "../utils"; +import { useRunAiModelJobContext } from "./ai_image_segmentation_job_context"; export const AiAnalysisParameters: React.FC = () => { const { diff --git a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_model_selector.tsx b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_model_selector.tsx index 922a55ea586..ef247dc0357 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_model_selector.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/run_ai_model/ai_model_selector.tsx @@ -8,7 +8,6 @@ import { useDispatch } from "react-redux"; import { APIJobType, type AiModel } from "types/api_types"; import { setAIJobDrawerStateAction } from "viewer/model/actions/ui_actions"; import { useRunAiModelJobContext } from "./ai_image_segmentation_job_context"; -import { Store } from "viewer/singletons"; const { Title, Text } = Typography; @@ -101,7 +100,7 @@ export const AiModelSelector: React.FC = () => { const filteredCustomModels = useMemo(() => filterModels([]), [searchTerm, customModels]); const switchToTraininButton = ( - ); diff --git a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx index d2b8678cfe2..cf57a89db0e 100644 --- a/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx +++ b/frontend/javascripts/viewer/view/ai_jobs/train_ai_model/ai_training_data_selector.tsx @@ -1,5 +1,5 @@ import { FolderOutlined } from "@ant-design/icons"; -import { Card, Col, Form, Row, Select, Space, Statistic } from "antd"; +import { Alert, Card, Col, Form, Row, Select, Space, Statistic } from "antd"; import { formatVoxels } from "libs/format_utils"; import { V3 } from "libs/mjs"; import { useWkSelector } from "libs/react_hooks"; @@ -13,11 +13,12 @@ import { getSegmentationLayers, } from "viewer/model/accessors/dataset_accessor"; import { getSegmentationLayerByHumanReadableName } from "viewer/model/accessors/volumetracing_accessor"; +import BoundingBox from "viewer/model/bucket_data_handling/bounding_box"; +import { colorLayerMustNotBeUint24Rule } from "../utils"; import { type AiTrainingAnnotationSelection, useAiTrainingJobContext, } from "./ai_training_job_context"; -import { colorLayerMustNotBeUint24Rule } from "../utils"; const getMagsForColorLayer = (colorLayers: APIDataLayer[], layerName: string) => { const colorLayer = colorLayers.find((layer) => layer.name === layerName); @@ -46,12 +47,15 @@ const getIntersectingMagList = ( const AiTrainingDataSelector = ({ selectedAnnotation, -}: { selectedAnnotation: AiTrainingAnnotationSelection }) => { +}: { + selectedAnnotation: AiTrainingAnnotationSelection; +}) => { const dataset = useWkSelector((state) => state.dataset); const { handleSelectionChange } = useAiTrainingJobContext(); - const annotation = selectedAnnotation.annotation; - const annotationId = selectedAnnotation.annotation.annotationId; + const { annotation, imageDataLayer, groundTruthLayer, magnification, userBoundingBoxes } = + selectedAnnotation; + const annotationId = annotation.annotationId; // Gather layer names from dataset. Omit the layers that are also present // in annotationLayers. @@ -78,37 +82,107 @@ const AiTrainingDataSelector = ({ const colorLayers = getColorLayers(dataset).filter((layer) => layer.elementClass !== "uint24"); const availableMagnifications = useMemo(() => { - if (selectedAnnotation?.imageDataLayer && selectedAnnotation?.groundTruthLayer) { - return ( - getIntersectingMagList( - annotation, - dataset, - selectedAnnotation.groundTruthLayer, - selectedAnnotation.imageDataLayer, - ) || [] - ); + if (imageDataLayer && groundTruthLayer) { + return getIntersectingMagList(annotation, dataset, groundTruthLayer, imageDataLayer) || []; } return []; - }, [ - selectedAnnotation?.imageDataLayer, - selectedAnnotation?.groundTruthLayer, - annotation, - dataset, - ]); + }, [imageDataLayer, groundTruthLayer, annotation, dataset]); - const boundingBoxCount = useMemo( - () => selectedAnnotation.userBoundingBoxes.length, - [selectedAnnotation.userBoundingBoxes], - ); + const boundingBoxCount = useMemo(() => userBoundingBoxes.length, [userBoundingBoxes]); const boundingBoxVolume = useMemo( () => - selectedAnnotation.userBoundingBoxes.reduce( + userBoundingBoxes.reduce( (sum, box) => sum + computeVolumeFromBoundingBox(box.boundingBox), 0, ), - [selectedAnnotation?.userBoundingBoxes], + [userBoundingBoxes], + ); + + const shouldValidate = useMemo( + () => !!(imageDataLayer && groundTruthLayer && magnification), + [imageDataLayer, groundTruthLayer, magnification], ); + const layerValidationError = useMemo(() => { + if (!shouldValidate) { + return undefined; + } + if (imageDataLayer && groundTruthLayer && imageDataLayer === groundTruthLayer) { + return "Image Data and Ground Truth layers must be different."; + } + return undefined; + }, [shouldValidate, imageDataLayer, groundTruthLayer]); + + const magnificationValidationError = useMemo(() => { + if (!shouldValidate) { + return undefined; + } + if (imageDataLayer && groundTruthLayer && availableMagnifications.length === 0) { + return "No common magnification found for the selected layers."; + } + return undefined; + }, [shouldValidate, imageDataLayer, groundTruthLayer, availableMagnifications]); + + const { bboxErrors, bboxWarnings } = useMemo(() => { + if (!shouldValidate) { + return { bboxErrors: [], bboxWarnings: [] }; + } + const errors: string[] = []; + const warnings: string[] = []; + + if (userBoundingBoxes.length === 0) { + errors.push("At least one bounding box is required for training."); + return { bboxErrors: errors, bboxWarnings: warnings }; + } + + if (boundingBoxVolume === 0) { + errors.push("Total volume of bounding boxes cannot be zero."); + } + + const MIN_BBOX_EXTENT_IN_EACH_DIM = 32; + const tooSmallBoxes: string[] = []; + const notMagAlignedBoundingBoxes: string[] = []; + + userBoundingBoxes.forEach((box) => { + const boundingBox = new BoundingBox(box.boundingBox); + let effectiveBbox = boundingBox; + if (magnification) { + const alignedBoundingBox = boundingBox.alignFromMag1ToMag(magnification, "shrink"); + if (!alignedBoundingBox.equals(boundingBox)) { + notMagAlignedBoundingBoxes.push(box.name); + } + effectiveBbox = alignedBoundingBox; + } + + const [width, height, depth] = effectiveBbox.getSize(); + if ( + width < MIN_BBOX_EXTENT_IN_EACH_DIM || + height < MIN_BBOX_EXTENT_IN_EACH_DIM || + depth < MIN_BBOX_EXTENT_IN_EACH_DIM + ) { + tooSmallBoxes.push(box.name); + } + }); + + if (tooSmallBoxes.length > 0) { + warnings.push( + `The following bounding boxes are too small. They should be at least ${MIN_BBOX_EXTENT_IN_EACH_DIM} Vx in each dimension: ${tooSmallBoxes.join( + ", ", + )}`, + ); + } + + if (notMagAlignedBoundingBoxes.length > 0) { + warnings.push( + `The following bounding boxes are not aligned with the selected magnification and will be automatically shrunk: ${notMagAlignedBoundingBoxes.join( + ", ", + )}`, + ); + } + + return { bboxErrors: errors, bboxWarnings: warnings }; + }, [shouldValidate, userBoundingBoxes, magnification, boundingBoxVolume]); + return ( @@ -123,7 +197,7 @@ const AiTrainingDataSelector = ({ > ({ value: l, label: l }))} - value={selectedAnnotation?.groundTruthLayer} + value={groundTruthLayer} onChange={(value) => handleSelectionChange(annotationId, { groundTruthLayer: value })} /> @@ -149,16 +225,16 @@ const AiTrainingDataSelector = ({ label="Magnification" required rules={[{ required: true, message: "Please select a magnification" }]} + validateStatus={magnificationValidationError ? "error" : undefined} + help={magnificationValidationError} >