Skip to content

Commit ebfdd8e

Browse files
committed
Add WithOps, use for KotlinOps
Signed-off-by: Ryan Nett <[email protected]>
1 parent 111ee68 commit ebfdd8e

File tree

12 files changed

+525
-459
lines changed

12 files changed

+525
-459
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.tensorflow.op;
1919

2020
import java.nio.charset.Charset;
21+
import java.util.Arrays;
2122
import java.util.List;
2223
import java.util.Map;
2324
import org.tensorflow.ConcreteFunction;
@@ -330,7 +331,7 @@
330331
* }
331332
* }</pre>
332333
*/
333-
public final class Ops {
334+
public final class Ops implements WithOps {
334335
public final NnOps nn;
335336

336337
public final SummaryOps summary;
@@ -7906,42 +7907,51 @@ public <T extends TType> ZerosLike<T> zerosLike(Operand<T> x) {
79067907
return ZerosLike.create(scope, x);
79077908
}
79087909

7910+
@Override
7911+
public Ops tf() {
7912+
return this;
7913+
}
7914+
79097915
/**
7910-
* Returns an API that builds operations with the provided name prefix.
7911-
*
7912-
* @see {@link Scope#withSubScope(String)}
7916+
* {@inheritDoc}
79137917
*/
7918+
@Override
79147919
public Ops withSubScope(String childScopeName) {
79157920
return new Ops(scope.withSubScope(childScopeName));
79167921
}
79177922

79187923
/**
7919-
* Returns an API that uses the provided name for an op.
7920-
*
7921-
* @see {@link Scope#withName(String)}
7924+
* {@inheritDoc}
79227925
*/
7926+
@Override
79237927
public Ops withName(String opName) {
79247928
return new Ops(scope.withName(opName));
79257929
}
79267930

79277931
/**
7928-
* Returns an API that places the created operations on the device(s) matching the provided spec.
7929-
*
7930-
* @see {@link Scope#withDevice(DeviceSpec)}
7932+
* {@inheritDoc}
79317933
*/
7934+
@Override
79327935
public Ops withDevice(DeviceSpec deviceSpec) {
79337936
return new Ops(scope.withDevice(deviceSpec));
79347937
}
79357938

79367939
/**
7937-
* Returns an API that adds operations to the graph with the provided control dependencies.
7938-
*
7939-
* @see {@link Scope#withControlDependencies(Iterable<Op<?>>)}
7940+
* {@inheritDoc}
79407941
*/
7942+
@Override
79417943
public Ops withControlDependencies(Iterable<Op> controls) {
79427944
return new Ops(scope.withControlDependencies(controls));
79437945
}
79447946

7947+
/**
7948+
* {@inheritDoc}
7949+
*/
7950+
@Override
7951+
public Ops withControlDependencies(Op... controls) {
7952+
return withControlDependencies(Arrays.asList(controls));
7953+
}
7954+
79457955
/**
79467956
* Returns the current {@link Scope scope} of this API
79477957
*/

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
package org.tensorflow;
1717

1818
import org.tensorflow.op.Op;
19+
import org.tensorflow.op.Ops;
1920
import org.tensorflow.op.Scope;
21+
import org.tensorflow.op.WithOps;
2022

2123
/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
22-
public interface ExecutionEnvironment {
24+
public interface ExecutionEnvironment extends WithOps {
2325

2426
enum Types {
2527
GRAPH,
@@ -86,4 +88,9 @@ default boolean isGraph() {
8688
* prevent name collisions.
8789
*/
8890
Scope baseScope();
91+
92+
@Override
93+
default Ops tf(){
94+
return Ops.create(this);
95+
}
8996
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
=======================================================================
16+
17+
*/
18+
package org.tensorflow.op;
19+
20+
import java.util.Arrays;
21+
import org.tensorflow.DeviceSpec;
22+
23+
/**
24+
* A context that provides a TensorFlow op builder.
25+
*/
26+
public interface WithOps {
27+
28+
/**
29+
* Get the op builder for this context.
30+
*/
31+
Ops tf();
32+
33+
/**
34+
* Returns an API that builds operations with the provided name prefix.
35+
*
36+
* @see Scope#withSubScope(String)
37+
*/
38+
default WithOps withSubScope(String childScopeName) {
39+
return tf().withSubScope(childScopeName);
40+
}
41+
42+
/**
43+
* Returns an API that uses the provided name for an op.
44+
*
45+
* @see Scope#withName(String)
46+
*/
47+
default WithOps withName(String opName) {
48+
return tf().withName(opName);
49+
}
50+
51+
/**
52+
* Returns an API that places the created operations on the device(s) matching the provided spec.
53+
*
54+
* @see Scope#withDevice(DeviceSpec)
55+
*/
56+
default WithOps withDevice(DeviceSpec deviceSpec) {
57+
return tf().withDevice(deviceSpec);
58+
}
59+
60+
/**
61+
* Returns an API that adds operations to the graph with the provided control dependencies.
62+
*
63+
* @see Scope#withControlDependencies(Iterable)
64+
*/
65+
default WithOps withControlDependencies(Iterable<Op> controls){
66+
return tf().withControlDependencies(controls);
67+
}
68+
69+
/**
70+
* Returns an API that adds operations to the graph with the provided control dependencies.
71+
*
72+
* @see Scope#withControlDependencies(Iterable)
73+
*/
74+
default WithOps withControlDependencies(Op... controls){
75+
return withControlDependencies(Arrays.asList(controls));
76+
}
77+
78+
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ public class Names {
5353
public static final ClassName OperationBuilder = ClassName.get(TensorflowPackage, "OperationBuilder");
5454
public static final TypeName IterableOp = ParameterizedTypeName.get(ClassName.get(Iterable.class), Op);
5555

56+
public static final ClassName WithOps = ClassName.get(OpPackage, "WithOps");
57+
5658
public static final ClassName Operand = ClassName.get(TensorflowPackage, "Operand");
5759
public static final ClassName Output = ClassName.get(TensorflowPackage, "Output");
5860

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
*/
1616
package org.tensorflow.processor.operator;
1717

18+
import com.squareup.javapoet.ArrayTypeName;
19+
import com.squareup.javapoet.ClassName;
1820
import com.squareup.javapoet.FieldSpec;
1921
import com.squareup.javapoet.JavaFile;
2022
import com.squareup.javapoet.MethodSpec;
2123
import com.squareup.javapoet.TypeSpec;
2224
import java.io.IOException;
25+
import java.util.Arrays;
2326
import java.util.List;
2427
import javax.lang.model.element.Modifier;
2528
import org.tensorflow.Names;
@@ -108,6 +111,7 @@ protected TypeSpec buildTopClass(OpsSpec spec) {
108111

109112
TypeSpec.Builder opsBuilder =
110113
TypeSpec.classBuilder("Ops")
114+
.addSuperinterface(Names.WithOps)
111115
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
112116
.addJavadoc(
113117
"An API for building operations as {@link $T Op}s\n<p>\n"
@@ -146,52 +150,65 @@ protected TypeSpec buildTopClass(OpsSpec spec) {
146150

147151
opsBuilder.addMethod(ctorBuilder.build());
148152

153+
opsBuilder.addMethod(MethodSpec
154+
.methodBuilder("tf")
155+
.addModifiers(Modifier.PUBLIC)
156+
.addAnnotation(Override.class)
157+
.returns(Names.Ops)
158+
.addStatement("return this")
159+
.build()
160+
);
161+
149162
opsBuilder.addMethod(
150163
MethodSpec.methodBuilder("withSubScope")
151164
.addModifiers(Modifier.PUBLIC)
165+
.addAnnotation(Override.class)
152166
.addParameter(Names.String, "childScopeName")
153167
.returns(Names.Ops)
154168
.addStatement("return new $T(scope.withSubScope(childScopeName))", Names.Ops)
155-
.addJavadoc(
156-
"Returns an API that builds operations with the provided name prefix.\n"
157-
+ "\n@see {@link $T#withSubScope(String)}\n",
158-
Names.Scope)
169+
.addJavadoc("{@inheritDoc}")
159170
.build());
160171

161172
opsBuilder.addMethod(
162173
MethodSpec.methodBuilder("withName")
163174
.addModifiers(Modifier.PUBLIC)
175+
.addAnnotation(Override.class)
164176
.addParameter(Names.String, "opName")
165177
.returns(Names.Ops)
166178
.addStatement("return new Ops(scope.withName(opName))")
167-
.addJavadoc(
168-
"Returns an API that uses the provided name for an op.\n\n"
169-
+ "@see {@link $T#withName(String)}\n",
170-
Names.Scope)
179+
.addJavadoc("{@inheritDoc}")
171180
.build());
172181

173182
opsBuilder.addMethod(
174183
MethodSpec.methodBuilder("withDevice")
175184
.addModifiers(Modifier.PUBLIC)
185+
.addAnnotation(Override.class)
176186
.addParameter(Names.DeviceSpec, "deviceSpec")
177187
.returns(Names.Ops)
178188
.addStatement("return new Ops(scope.withDevice(deviceSpec))")
179-
.addJavadoc(
180-
"Returns an API that places the created operations on the device(s) matching the provided spec.\n\n"
181-
+ "@see {@link $T#withDevice(DeviceSpec)}\n",
182-
Names.Scope)
189+
.addJavadoc("{@inheritDoc}")
183190
.build());
184191

185192
opsBuilder.addMethod(
186193
MethodSpec.methodBuilder("withControlDependencies")
187194
.addModifiers(Modifier.PUBLIC)
195+
.addAnnotation(Override.class)
188196
.addParameter(Names.IterableOp, "controls")
189197
.returns(Names.Ops)
190198
.addStatement("return new Ops(scope.withControlDependencies(controls))")
191-
.addJavadoc(
192-
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
193-
+ "@see {@link $T#withControlDependencies(Iterable<Op<?>>)}\n",
194-
Names.Scope)
199+
.addJavadoc("{@inheritDoc}")
200+
.build());
201+
202+
opsBuilder.addMethod(
203+
MethodSpec.methodBuilder("withControlDependencies")
204+
.addModifiers(Modifier.PUBLIC)
205+
.addAnnotation(Override.class)
206+
.addParameter(ArrayTypeName.of(Names.Op), "controls")
207+
.varargs()
208+
.returns(Names.Ops)
209+
.addStatement("return withControlDependencies($T.asList(controls))", ClassName.get(
210+
Arrays.class))
211+
.addJavadoc("{@inheritDoc}")
195212
.build());
196213

197214
opsBuilder.addField(

0 commit comments

Comments
 (0)