Skip to content

Commit c9eb626

Browse files
author
Kevin Scott
committed
Update documentation around accepted image formats
1 parent 2a7e5c0 commit c9eb626

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ mlClassifier.addData(images, labels, 'train');
131131

132132
#### Parameters
133133

134-
* **images** (`Tensor3D[]`) - an array of 3D tensors. Images can be any sizes, but will be cropped and sized down to match the pretrained model.
134+
* **images** (`Array<tf.Tensor3D | ImageData | HTMLImageElement | string>`) - an array of 3D tensors, ImageData (output from a canvas `toPixels`, a native browser `Image`, or a string representing the image `src`. Images can be any sizes, but will be cropped and sized down to match the pretrained model.
135135
* **labels** (`string[]`) - an array of strings, matching the images passed above.
136136
* **dataType** (`string`) *Optional* - an enum specifying which data type the images match. Data types can be `train` for data used in `model.train()`, and `eval`, for data used in `model.evaluate()`. If no argument is supplied, `dataType` will default to `train`.
137137

src/index.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ import * as tf from '@tensorflow/tfjs';
22
import cropAndResizeImage from './cropAndResizeImage';
33
import getClasses from './getClasses';
44
import train from './train';
5-
import translateImages from './translateImages';
5+
import translateImages, {
6+
IImageData,
7+
} from './translateImages';
68
import loadPretrainedModel, {
79
PRETRAINED_MODELS_KEYS,
810
} from './loadPretrainedModel';
@@ -126,7 +128,7 @@ class MLClassifier {
126128

127129
public getModel = () => this.model;
128130

129-
public addData = async (origImages: Array<tf.Tensor3D | HTMLImageElement>, origLabels: string[], dataType: string = 'train') => {
131+
public addData = async (origImages: Array<tf.Tensor3D | IImageData | HTMLImageElement | string>, origLabels: string[], dataType: string = 'train') => {
130132
this.callbackFn('onAddData', 'start', origImages, origLabels, dataType);
131133
if (!origImages) {
132134
throw new Error('You must supply images');

src/translateImages.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const imageDataToTensor = async ({
3838
return tf.tensor3d(Array.from(data), [width, height, 4]);
3939
};
4040

41-
interface IImageData {
41+
export interface IImageData {
4242
data: Uint8ClampedArray;
4343
width: number;
4444
height: number;

0 commit comments

Comments
 (0)