Skip to content
40 changes: 23 additions & 17 deletions src/main/java/qupath/ext/bioimageio/BioimageIoCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import javax.swing.SwingUtilities;

import javafx.stage.FileChooser;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
Expand Down Expand Up @@ -60,10 +61,11 @@
import javafx.util.StringConverter;
import qupath.bioimageio.spec.BioimageIoSpec;
import qupath.bioimageio.spec.BioimageIoSpec.BioimageIoModel;
import qupath.fx.dialogs.FileChoosers;
import qupath.imagej.tools.IJTools;
import qupath.lib.common.GeneralTools;
import qupath.lib.gui.QuPathGUI;
import qupath.lib.gui.dialogs.Dialogs;
import qupath.fx.dialogs.Dialogs;
import qupath.lib.gui.tools.PaneTools;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ColorTransforms;
Expand All @@ -80,6 +82,9 @@
import qupath.opencv.tools.NumpyTools;
import qupath.opencv.tools.OpenCVTools;

import static qupath.bioimageio.spec.BioimageIoSpec.getAxesString;


/**
* Very early exploration of BioImage Model Zoo support within QuPath.
*
Expand All @@ -102,7 +107,8 @@ public class BioimageIoCommand {
public void promptForModel() {

// TODO: In the future consider handling .zip files
var file = Dialogs.promptForFile(title, null, "BioImage Model Zoo spec", "", ".yml", ".yaml");
var file = FileChoosers.promptForFile(title,
new FileChooser.ExtensionFilter("BioImage Model Zoo YAML file", "*.yml", "*.yaml"));
if (file == null)
return;

Expand Down Expand Up @@ -134,14 +140,15 @@ public void promptForModel() {
showLoadPixelClassifier = true;
}
} else {
var fileSaved = Dialogs.promptToSaveFile(title, null, null, "Pixel classifier", ".json");
var fileSaved = FileChoosers.promptToSaveFile(title,
FileChoosers.promptToSaveFile(new FileChooser.ExtensionFilter("Pixel classifier", "*.json")));
if (fileSaved != null) {
PixelClassifiers.writeClassifier(classifier, fileSaved.toPath());
Dialogs.showInfoNotification(title, "Pixel classifier saved to \n" + fileSaved.getAbsolutePath());
}
}


// Offer to show the prediction in the current image, if it's small enough
var imageData = qupath.getImageData();
if (imageData != null) {
Expand All @@ -164,7 +171,8 @@ public void promptForModel() {


} catch (Exception e) {
Dialogs.showErrorMessage(title, e);
Dialogs.showErrorMessage(title, "Error loading or running model. See the log for more details.");
logger.error("Error loading model", e);
}
}

Expand All @@ -189,18 +197,16 @@ static void showDialog(ImageData<?> imageData, String path) throws IOException {

static class DnnBuilderPane {

private QuPathGUI qupath;
private String title;
private final QuPathGUI qupath;
private final String title;

private static Font font = Font.font("Arial");
private static final Font font = Font.font("Arial");

private DnnBuilderPane(QuPathGUI qupath, String title) {
this.qupath = qupath;
this.title = title;
}

private GridPane pane;


private PatchClassifierParams promptForParameters(BioimageIoModel model, ImageData<?> imageData) {

Objects.requireNonNull(imageData, "ImageData must not be null!");
Expand All @@ -213,8 +219,8 @@ private PatchClassifierParams promptForParameters(BioimageIoModel model, ImageDa

int nChannels = params.getInputChannels().size();
int nOutputClasses = params.getOutputClasses().size();
pane = new GridPane();

GridPane pane = new GridPane();
pane.setHgap(5);
pane.setVgap(5);

Expand All @@ -234,8 +240,7 @@ private PatchClassifierParams promptForParameters(BioimageIoModel model, ImageDa
, row++);

addSeparatorRow(pane, row++);



// Handle input channels & their order
addTitleRow(pane, "Input channels", row++);
addDescriptionRow(pane, "The image channels provided as input to the model", row++);
Expand Down Expand Up @@ -307,8 +312,9 @@ private PatchClassifierParams promptForParameters(BioimageIoModel model, ImageDa
int[] shape = output.getShape().getShape();
int[] steps = output.getShape().getShapeStep();
int[] minSize = output.getShape().getShapeMin();
int indX = output.getAxes().toLowerCase().indexOf("x");
int indY = output.getAxes().toLowerCase().indexOf("y");
String outputAxes = getAxesString(output.getAxes()).toLowerCase();
int indX = outputAxes.indexOf("x");
int indY = outputAxes.indexOf("y");
if (indX >= 0 && indY >= 0) {
if (minSize.length > 0) {
width = minSize[indX];
Expand Down
Loading