Skip to content

Commit accc01f

Browse files
author
Kevin Scott
committed
Lots of updates to prepare for 0.4.0
1 parent c9eb626 commit accc01f

File tree

8 files changed

+223
-83
lines changed

8 files changed

+223
-83
lines changed

README.md

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,63 @@ When you have a trained model you're happy with, save it with:
6464
mlClassifier.save();
6565
```
6666

67+
## Using the saved model
68+
69+
When you hit save, Tensorflow.js will download a weights file and a model topology file.
70+
71+
You'll need to combine both into a single `json` file. Open up your model topology file and at the top level of the JSON file, make sure to add a `weightsManifest` key pointing to your weights, like:
72+
73+
```
74+
{
75+
"weightsManifest": "ml-classifier-class1-class2.weights.bin",
76+
"modelTopology": {
77+
...
78+
}
79+
}
80+
```
81+
82+
When using the model in your app, there's a few things to keep in mind:
83+
84+
1. You need to make sure you transform images into the correct dimensions, depending on the pretrained model it was trained with. (For MOBILENET, this would be 1x224x224x3).
85+
2. You must create a pretrained model matching the dimensions used to train. An example is below for MOBILENET.
86+
3. You must first run your images through the pretrained model to activate them.
87+
4. After getting the final prediction, you must take the arg max.
88+
5. You'll get back a number indicating your class.
89+
90+
Full example for MOBILENET:
91+
92+
```
93+
const loadImage = (src) => new Promise((resolve, reject) => {
94+
const image = new Image();
95+
image.src = src;
96+
image.crossOrigin = 'Anonymous';
97+
image.onload = () => resolve(image);
98+
image.onerror = (err) => reject(err);
99+
});
100+
101+
const pretrainedModelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
102+
103+
tf.loadModel(pretrainedModelURL).then(model => {
104+
const layer = model.getLayer('conv_pw_13_relu');
105+
return tf.model({
106+
inputs: [model.inputs[0]],
107+
outputs: layer.output,
108+
});
109+
}).then(pretrainedModel => {
110+
return tf.loadModel('/model.json').then(model => {
111+
return loadImage('/trees/tree1.png').then(loadedImage => {
112+
const image = tf.reshape(tf.fromPixels(loadedImage), [1,224,224,3]);
113+
const pretrainedModelPrediction = pretrainedModel.predict(image);
114+
const modelPrediction = model.predict(pretrainedModelPrediction);
115+
const prediction = modelPrediction.as1D().argMax().dataSync()[0];
116+
console.log(prediction);
117+
});
118+
});
119+
}).catch(err => {
120+
console.error('Error', err);
121+
});
122+
```
123+
67124
## API Documentation
68125

69126
Start by instantiating a new instance of `MLClassifier` with:
@@ -76,10 +133,16 @@ This will begin loading the pretrained model and provide you with an object onto
76133

77134
### `constructor`
78135

79-
`MLClassifier` accepts a number of callbacks when initialized:
136+
`MLClassifier` accepts a number of callbacks for beginning and end of various methods.
137+
138+
You can provide a custom pretrained model as a `pretrainedModel`.
139+
140+
You can provide a custom training model as a `trainingModel`.
80141

81142
#### Parameters
82143

144+
* **pretrainedModel** (`string | tf.Model`) *Optional* - A string denoting which pretrained model to load from an internal config. Valid strings can be found on the exported object `PRETRAINED_MODELS`. You can also specify a preloaded pretrained model directly.
145+
* **trainingModel** (`tf.Model | Function`) *Optional* - A custom model to use during training. Can be provided as a `tf.Model` or as a function that accepts `{xs: [...], ys: [...]`, number of `classes`, and `params` provided to train.
83146
* **onLoadStart** (`Function`) *Optional* - A callback for when `load` (loading the pre-trained model) is first called.
84147
* **onLoadComplete** (`Function`) *Optional* - A callback for when `load` (loading the pre-trained model) is complete.
85148
* **onAddDataStart** (`Function`) *Optional* - A callback for when `addData` is first called.
@@ -98,8 +161,13 @@ This will begin loading the pretrained model and provide you with an object onto
98161

99162
#### Example
100163
```
101-
import MLClassifier from 'ml-classifier';
164+
import MLClassifier, {
165+
PRETRAINED_MODELS,
166+
} from 'ml-classifier';
167+
102168
const mlClassifier = new MLClassifier({
169+
pretrainedModel: PRETRAINED_MODELS.MOBILENET,
170+
103171
onLoadStart: () => console.log('onLoadStart'),
104172
onLoadComplete: () => console.log('onLoadComplete'),
105173
onAddDataStart: () => console.log('onAddDataStart'),
@@ -117,6 +185,18 @@ const mlClassifier = new MLClassifier({
117185
});
118186
```
119187

188+
Example of specifying a preloaded pretrained model:
189+
190+
```
191+
import MLClassifier from 'ml-classifier';
192+
193+
const mlClassifier = tf.loadModel('... some pretrained model ...').then(model => {
194+
return new MLClassifier({
195+
pretrainedModel: model,
196+
});
197+
});
198+
```
199+
120200
### `addData`
121201

122202
This method takes an array of incoming images, an optional array of labels, and an optional dataType.
@@ -302,3 +382,5 @@ yarn test
302382
## License
303383

304384
This project is licensed under the MIT License - see the LICENSE file for details
385+
386+
![](https://ga-beacon.appspot.com/UA-112845439-4/ml-classifier/readme)

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "ml-classifier",
3-
"version": "0.3.11",
3+
"version": "0.4.0",
44
"description": "A machine learning engine for quickly training image classification models in your browser",
55
"main": "dist/index.js",
66
"typings": "dist/index.d.ts",

src/index.ts

Lines changed: 71 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,29 @@ import train from './train';
55
import translateImages, {
66
IImageData,
77
} from './translateImages';
8-
import loadPretrainedModel, {
9-
PRETRAINED_MODELS_KEYS,
10-
} from './loadPretrainedModel';
8+
import loadPretrainedModel from './loadPretrainedModel';
9+
// import log, { resetLog } from './log';
1110
import {
1211
addData,
1312
addLabels,
14-
} from './prepareTrainingData';
13+
} from './prepareData';
1514
import getDefaultDownloadHandler from './getDefaultDownloadHandler';
1615

1716
import {
1817
IParams,
1918
// IConfigurationParams,
2019
IData,
2120
ICollectedData,
21+
IArgs,
2222
// DataType,
2323
} from './types';
2424

25-
interface IArgs {
26-
onLoadStart?: Function;
27-
onLoadComplete?: Function;
28-
onAddDataStart?: Function;
29-
onAddDataComplete?: Function;
30-
onClearDataStart?: Function;
31-
onClearDataComplete?: Function;
32-
onTrainStart?: Function;
33-
onTrainComplete?: Function;
34-
onPredictComplete?: Function;
35-
onPredictStart?: Function;
36-
onEvaluateStart?: Function;
37-
onEvaluateComplete?: Function;
38-
onSaveStart?: Function;
39-
onSaveComplete?: Function;
40-
}
41-
4225
// export { DataType } from './types';
4326

4427
class MLClassifier {
4528
// private pretrainedModel: typeof tf.model;
46-
private pretrainedModel: any;
29+
// private pretrainedModel: any;
30+
private pretrainedModel: tf.Model;
4731
// private model: tf.Sequential;
4832
private model: any;
4933
private callbacks: Function[] = [];
@@ -66,11 +50,17 @@ class MLClassifier {
6650

6751
private init = async () => {
6852
this.callbackFn('onLoad', 'start');
69-
this.pretrainedModel = await loadPretrainedModel(PRETRAINED_MODELS_KEYS.MOBILENET);
53+
this.pretrainedModel = await loadPretrainedModel(this.args.pretrainedModel);
7054

7155
this.callbacks.map(callback => callback());
7256

7357
this.callbackFn('onLoad', 'complete');
58+
59+
// Warmup the model
60+
const dims = await this.getInputDims();
61+
tf.tidy(() => {
62+
this.pretrainedModel.predict(tf.zeros([1, ...dims, 3]));
63+
});
7464
}
7565

7666
private loaded = async () => new Promise(resolve => {
@@ -93,8 +83,11 @@ class MLClassifier {
9383
batchInputShape,
9484
} = inputLayers[0];
9585
const dims = await this.getInputDims();
86+
await tf.nextFrame();
9687
const processedImage = await cropAndResizeImage(image, dims);
97-
return this.pretrainedModel.predict(processedImage);
88+
await tf.nextFrame();
89+
const pred = this.pretrainedModel.predict(processedImage);
90+
return pred;
9891
}
9992

10093
private getInputDims = async (): Promise<[number, number]> => {
@@ -128,7 +121,7 @@ class MLClassifier {
128121

129122
public getModel = () => this.model;
130123

131-
public addData = async (origImages: Array<tf.Tensor3D | IImageData | HTMLImageElement | string>, origLabels: string[], dataType: string = 'train') => {
124+
public addData = async (origImages: Array<tf.Tensor | IImageData | HTMLImageElement | string>, origLabels: string[], dataType: string = 'train') => {
132125
this.callbackFn('onAddData', 'start', origImages, origLabels, dataType);
133126
if (!origImages) {
134127
throw new Error('You must supply images');
@@ -148,20 +141,27 @@ class MLClassifier {
148141
}
149142

150143
if (dataType === 'train' || dataType === 'eval') {
151-
const activatedImages = await Promise.all(images.map(async (image: tf.Tensor3D, idx: number) => {
144+
const activatedImages: tf.Tensor[] = [];
145+
for (let i = 0; i < images.length; i++) {
146+
const image = images[i];
147+
// TODO: Debug this any type
148+
const activatedImage: any = await this.cropAndActivateImage(image);
149+
activatedImages.push(activatedImage);
152150
await tf.nextFrame();
153-
return await this.cropAndActivateImage(image);
154-
}));
151+
}
155152

156153
this.data.classes = getClasses(labels);
157154
const xs = addData(activatedImages);
155+
await tf.nextFrame();
158156
const ys = addLabels(labels, this.data.classes);
157+
await tf.nextFrame();
159158
this.data[dataType] = {
160159
xs,
161160
ys,
162161
};
163162
}
164163

164+
await tf.nextFrame();
165165
this.callbackFn('onAddData', 'complete', origImages, labels, dataType, errors);
166166
}
167167

@@ -196,47 +196,54 @@ class MLClassifier {
196196
const {
197197
model,
198198
history,
199-
} = await train(data, classes, params);
199+
} = await train(this.pretrainedModel, data, classes, params, this.args);
200200

201201
this.model = model;
202202
this.callbackFn('onTrain', 'complete', params, history);
203203
return history;
204204
}
205205

206-
public predict = async (origImage: tf.Tensor3D | HTMLImageElement | string, label?: string) => {
207-
this.callbackFn('onPredict', 'start', origImage);
208-
await this.loaded();
209-
if (!this.model) {
210-
throw new Error('You must call train prior to calling predict');
206+
public predict = async (origImage: tf.Tensor | HTMLImageElement | string, label?: string) => {
207+
try {
208+
this.callbackFn('onPredict', 'start', origImage);
209+
await this.loaded();
210+
if (!this.model) {
211+
throw new Error('You must call train prior to calling predict');
212+
}
213+
const dims = await this.getInputDims();
214+
const {
215+
images,
216+
errors,
217+
} = await translateImages([origImage], dims);
218+
if (errors && errors.length) {
219+
throw errors[0].error;
220+
}
221+
const data = images[0];
222+
const img = await this.cropAndActivateImage(data);
223+
// TODO: Do these images need to be activated?
224+
const predictedClass = tf.tidy(() => {
225+
const predictions = this.model.predict(img);
226+
// TODO: address this
227+
return (predictions as tf.Tensor).as1D().argMax();
228+
});
229+
230+
console.log(predictedClass.dataSync());
231+
232+
const classId = (await predictedClass.data())[0];
233+
predictedClass.dispose();
234+
const prediction = Object.entries(this.data.classes).reduce((obj, [
235+
key,
236+
val,
237+
]) => ({
238+
...obj,
239+
[val]: key,
240+
}), {})[classId];
241+
this.callbackFn('onPredict', 'complete', origImage, label, prediction);
242+
return prediction;
243+
} catch(err) {
244+
console.error(err, origImage, label);
245+
throw new Error(err);
211246
}
212-
const dims = await this.getInputDims();
213-
const {
214-
images,
215-
errors,
216-
} = await translateImages([origImage], dims);
217-
if (errors && errors.length) {
218-
throw errors[0].error;
219-
}
220-
const data = images[0];
221-
const img = await this.cropAndActivateImage(data);
222-
// TODO: Do these images need to be activated?
223-
const predictedClass = tf.tidy(() => {
224-
const predictions = this.model.predict(img);
225-
// TODO: address this
226-
return (predictions as tf.Tensor).as1D().argMax();
227-
});
228-
229-
const classId = (await predictedClass.data())[0];
230-
predictedClass.dispose();
231-
const prediction = Object.entries(this.data.classes).reduce((obj, [
232-
key,
233-
val,
234-
]) => ({
235-
...obj,
236-
[val]: key,
237-
}), {})[classId];
238-
this.callbackFn('onPredict', 'complete', origImage, label, prediction);
239-
return prediction;
240247
}
241248

242249
public evaluate = async (params: IParams = {}) => {
@@ -271,3 +278,5 @@ class MLClassifier {
271278
}
272279

273280
export default MLClassifier;
281+
282+
export { PRETRAINED_MODELS_KEYS as PRETRAINED_MODELS } from './loadPretrainedModel';

src/loadPretrainedModel.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ export const PRETRAINED_MODELS = {
1111
},
1212
};
1313

14-
const loadPretrainedModel = async (pretrainedModel: string) => {
14+
const loadPretrainedModel = async (pretrainedModel: string | tf.Model = PRETRAINED_MODELS_KEYS.MOBILENET) => {
15+
if (pretrainedModel instanceof tf.Model) {
16+
return pretrainedModel;
17+
}
18+
1519
if (!PRETRAINED_MODELS[pretrainedModel]) {
1620
throw new Error('You have supplied an invalid key for a pretrained model');
1721
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ const oneHot = (labelIndex: number, classLength: number) => tf.tidy(() => tf.one
1616
// return newData;
1717
// }), undefined);
1818

19-
export const addData = (tensors: tf.Tensor3D[]): tf.Tensor3D => {
19+
export const addData = (tensors: tf.Tensor[]): tf.Tensor => {
2020
const data = tf.keep(tensors[0]);
21-
return tensors.slice(1).reduce((data: tf.Tensor3D, tensor: tf.Tensor3D) => tf.tidy(() => {
21+
return tensors.slice(1).reduce((data: tf.Tensor, tensor: tf.Tensor) => tf.tidy(() => {
2222
const newData = tf.keep(data.concat(tensor, 0));
2323
data.dispose();
2424
return newData;

0 commit comments

Comments
 (0)