This repository is a Maven project that wraps the Segment Anything Model. You can read more about the TorchScript conversion process in /pytorch_convert/README.md.
To interface with the TorchScript model, we use DJL (Deep Java Library), a deep learning framework for Java that supports PyTorch, TensorFlow, and MXNet. It provides a Java API to load and run TorchScript models.
The project is structured as follows:
- /pytorch_convert: Python code to patch the Segment Anything Model (SAM) and save it as TorchScript to a new file.
- /src/main/java/djlsam/Sam.java: Java code to load the TorchScript model and run inference.
- /src/main/java/djlsam/translators: Java classes to convert the input/output tensors to/from the TorchScript model.
- /src/main/test/djlsam/SamTest.java: Java code to test the model.
- /src/resources/images: Test images.
- /src/resources/pytorch_models: TorchScript models.
It is recommended to use an IDE such as IntelliJ IDEA to run the project.
To install the dependencies, run the following command:
mvn clean install
To run the tests, run the following command:
mvn test
Before implementing a model with the DJL framework, you should first convert your model to TorchScript.
You can also find an example in the DJL documentation here.
Add the following dependencies to your pom.xml file:
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.21.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.22.0</version>
<scope>runtime</scope>
</dependency>
Note: You can find the latest version of the dependencies here.
Create a new class for your model. Within the class you can load the TorchScript model and run inference. You can find an example here.
The main idea is to create the following objects:
Translator<Image, SamRawOutput> translator;
Criteria<Image, SamRawOutput> criteria;
ZooModel<Image, SamRawOutput> model;
Predictor<Image, SamRawOutput> predictor;
Each object has an input and output type which should match the input and output types of the translator object.
DJL has many input/output types as well as translators already implemented. You can find them here.
The translator object is used to convert the input/output tensors to/from the TorchScript model. You can find an example here.
It overrides the following methods:
- processInput(TranslatorContext ctx, Image input) to convert the input image to an NDList object.
- processOutput(TranslatorContext ctx, NDList list) to convert the output NDList object to a SamRawOutput object.
The SamRawOutput object is a custom class wrapper that contains the output tensors of the model. You can find an example here.
The criteria object is used to specify the input and output types of the model. You can find an example here.
Note: The path of the TorchScript model must be a directory that contains the .pt file, and the .pt file must have the same name as the directory.
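The naming rule in the note above can be checked before handing the path to DJL. The following is a minimal sketch using only the Java standard library; the class name `ModelDirLayout`, its methods, and the directory name `sam_vit_b` in the usage note are hypothetical and not part of this repository.

```java
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical helper: checks the "directory/<directory name>.pt" layout described above. */
final class ModelDirLayout {
    private ModelDirLayout() {}

    /** Returns the .pt path that the naming rule implies for the given model directory. */
    static Path expectedModelFile(Path modelDir) {
        Path name = modelDir.toAbsolutePath().normalize().getFileName();
        if (name == null) {
            // A filesystem root has no name, so no file name can be derived from it.
            throw new IllegalArgumentException("Model directory has no name: " + modelDir);
        }
        return modelDir.resolve(name + ".pt");
    }

    /** True if modelDir is a directory that contains a regular file named <directory name>.pt. */
    static boolean followsConvention(Path modelDir) {
        return Files.isDirectory(modelDir) && Files.isRegularFile(expectedModelFile(modelDir));
    }
}
```

For a directory such as `models/sam_vit_b`, `expectedModelFile` yields `models/sam_vit_b/sam_vit_b.pt`, and `followsConvention` returns true only if that file exists.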
Calling criteria.loadModel() creates the model object. You can find an example here.
Finally, calling model.newPredictor() creates the predictor object. You can find an example here.
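The four objects can be wired together roughly as follows. This is a sketch under stated assumptions, not the code in Sam.java: `SketchOutput`, `SketchTranslator`, and `SamPipelineSketch` are hypothetical names, `SketchOutput` stands in for SamRawOutput, and the pre-processing is reduced to a plain image-to-tensor conversion. It needs the DJL dependencies listed earlier on the classpath.

```java
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Path;

/** Hypothetical stand-in for SamRawOutput: one output tensor copied into Java arrays. */
final class SketchOutput {
    final long[] shape;
    final float[] values;

    SketchOutput(long[] shape, float[] values) {
        this.shape = shape;
        this.values = values;
    }
}

/** Hypothetical translator; the repository's real pre/post-processing may differ. */
final class SketchTranslator implements Translator<Image, SketchOutput> {
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        // Convert the image to a tensor. Any resizing or normalization the model
        // expects would go here; it is omitted because it is model-specific.
        NDArray pixels = input.toNDArray(ctx.getNDManager());
        return new NDList(pixels);
    }

    @Override
    public SketchOutput processOutput(TranslatorContext ctx, NDList list) {
        // Copy the first output tensor into plain arrays so the result stays
        // usable after the predictor releases its tensors.
        NDArray first = list.get(0).toType(DataType.FLOAT32, false);
        return new SketchOutput(first.getShape().getShape(), first.toFloatArray());
    }
}

final class SamPipelineSketch {
    private SamPipelineSketch() {}

    /** Builds the criteria: input/output types, model directory, engine, translator. */
    static Criteria<Image, SketchOutput> buildCriteria(Path modelDir) {
        return Criteria.builder()
                .setTypes(Image.class, SketchOutput.class)
                .optModelPath(modelDir)
                .optEngine("PyTorch")
                .optTranslator(new SketchTranslator())
                .build();
    }

    /** Loads the model, creates a predictor, and runs one image through it. */
    static SketchOutput run(Path modelDir, Path imageFile) throws Exception {
        Image image = ImageFactory.getInstance().fromFile(imageFile);
        // Both the model and the predictor hold native resources, so close them.
        try (ZooModel<Image, SketchOutput> model = buildCriteria(modelDir).loadModel();
                Predictor<Image, SketchOutput> predictor = model.newPredictor()) {
            return predictor.predict(image);
        }
    }
}
```

Building the criteria does not load anything; the model file is only read when `loadModel()` is called inside `run`.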
- You can use the NDManager object to create NDArray objects.

