Skip to content

Commit 00d949a

Browse files
maciej3031lina128
andauthored
Add support for converting models with ImageProjectiveTransformV3 op (#6206)
* Add support for converting models with ImageProjectiveTransformV3 op * Rename image to images Co-authored-by: Na Li <[email protected]>
1 parent 092c976 commit 00d949a

File tree

4 files changed

+98
-1
lines changed

4 files changed

+98
-1
lines changed

tfjs-converter/docs/supported_ops.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@
230230
|ResizeNearestNeighbor|resizeNearestNeighbor|
231231
|Not mapped|flipLeftRight|
232232
|Not mapped|rotateWithOffset|
233+
|ImageProjectiveTransformV3|transform|
233234

234235
## Operations - Matrices
235236

tfjs-converter/python/tensorflowjs/op_list/image.json

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,43 @@
104104
"type": "number"
105105
}
106106
]
107+
},
108+
{
109+
"tfOpName": "ImageProjectiveTransformV3",
110+
"category": "image",
111+
"inputs": [
112+
{
113+
"start": 0,
114+
"name": "images",
115+
"type": "tensor"
116+
},
117+
{
118+
"start": 1,
119+
"name": "transforms",
120+
"type": "tensor"
121+
},
122+
{
123+
"start": 2,
124+
"name": "outputShape",
125+
"type": "number[]"
126+
},
127+
{
128+
"start": 3,
129+
"name": "fillValue",
130+
"type": "number"
131+
}
132+
],
133+
"attrs": [
134+
{
135+
"tfName": "interpolation",
136+
"name": "interpolation",
137+
"type": "string"
138+
},
139+
{
140+
"tfName": "fill_mode",
141+
"name": "fillMode",
142+
"type": "string"
143+
}
144+
]
107145
}
108146
]

tfjs-converter/src/operations/executors/image_executor.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,29 @@ export const executeOp: InternalOpExecutor =
7878
cropSize as [number, number], method as 'bilinear' | 'nearest',
7979
extrapolationValue)];
8080
}
81+
case 'ImageProjectiveTransformV3': {
82+
const images =
83+
getParamValue('images', node, tensorMap, context) as Tensor;
84+
const transforms =
85+
getParamValue('transforms', node, tensorMap, context) as Tensor;
86+
const outputShape =
87+
getParamValue('outputShape', node, tensorMap, context) as
88+
number[];
89+
const fillValue =
90+
getParamValue('fillValue', node, tensorMap, context) as number;
91+
const interpolation =
92+
getParamValue('interpolation', node, tensorMap, context) as
93+
string;
94+
const fillMode =
95+
getParamValue('fillMode', node, tensorMap, context) as string;
96+
return [tfOps.image.transform(
97+
images as Tensor4D,
98+
transforms as Tensor2D,
99+
interpolation.toLowerCase() as 'bilinear' | 'nearest',
100+
fillMode.toLowerCase() as 'constant' | 'reflect' | 'wrap' | 'nearest',
101+
fillValue,
102+
outputShape as [number, number])];
103+
}
81104
default:
82105
throw TypeError(`Node type ${node.op} is not implemented`);
83106
}

tfjs-converter/src/operations/executors/image_executor_test.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import * as image from '../op_list/image';
2222
import {Node} from '../types';
2323

2424
import {executeOp} from './image_executor';
25-
import {createBoolAttr, createNumberAttr, createNumericArrayAttrFromIndex, createStrAttr, createTensorAttr, validateParam} from './test_helper';
25+
import {createBoolAttr, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createStrAttr, createTensorAttr, validateParam} from './test_helper';
2626

2727
describe('image', () => {
2828
let node: Node;
@@ -123,6 +123,41 @@ describe('image', () => {
123123
node.attrParams['extrapolationValue'] = createNumberAttr(0.5);
124124
node.inputNames = ['input1', 'input2', 'input3', 'input4'];
125125

126+
expect(validateParam(node, image.json)).toBeTruthy();
127+
});
128+
});
129+
describe('ImageProjectiveTransformV3', () => {
130+
it('should return input', () => {
131+
node.op = 'ImageProjectiveTransformV3';
132+
node.inputParams['images'] = createTensorAttr(0);
133+
node.inputParams['transforms'] = createTensorAttr(1);
134+
node.inputParams['outputShape'] = createNumericArrayAttrFromIndex(2);
135+
node.inputParams['fillValue'] = createNumberAttrFromIndex(3);
136+
node.attrParams['interpolation'] = createStrAttr('bilinear');
137+
node.attrParams['fillMode'] = createStrAttr('constant');
138+
node.inputNames = ['input1', 'input2', 'input3', 'input4'];
139+
140+
spyOn(tfOps.image, 'transform');
141+
const input2 = [tfOps.tensor1d([2])];
142+
const input3 = [tfOps.tensor1d([4, 5])];
143+
const input4 = [tfOps.scalar(3)];
144+
145+
executeOp(node, {input1, input2, input3, input4}, context);
146+
expect(tfOps.image.transform)
147+
.toHaveBeenCalledWith(
148+
input1[0], input2[0], 'bilinear', 'constant', 3, [4, 5]);
149+
});
150+
151+
it('should match json def', () => {
152+
node.op = 'ImageProjectiveTransformV3';
153+
node.inputParams['images'] = createTensorAttr(0);
154+
node.inputParams['transforms'] = createTensorAttr(1);
155+
node.inputParams['outputShape'] = createNumericArrayAttrFromIndex(2);
156+
node.inputParams['fillValue'] = createNumberAttrFromIndex(3);
157+
node.attrParams['interpolation'] = createStrAttr('bilinear');
158+
node.attrParams['fillMode'] = createStrAttr('constant');
159+
node.inputNames = ['input1', 'input2', 'input3', 'input4'];
160+
126161
expect(validateParam(node, image.json)).toBeTruthy();
127162
});
128163
});

0 commit comments

Comments
 (0)