Skip to content

Commit cba0ba9

Browse files
authored
Upstream sync (#295)
Miscellaneous changes to sync upstream.

This pull request introduces several improvements and bug fixes to the Python tensor analysis engine, focusing on error handling, code clarity, and compatibility updates. The most significant changes include improved exception handling for shape source operations, refactoring for better code readability, and updates to project settings for enhanced compatibility with Java and Eclipse.

## Error handling and robustness improvements

* Exception handling was added to shape source operations in `PythonTensorAnalysisEngine.java` and `TensorType.java`, converting previously ignored IO errors into runtime exceptions to ensure failures are surfaced and handled appropriately.
* The `TensorType.shapeArg` method now throws `IOException`, propagating errors up the call stack instead of silently ignoring them.

## Code clarity and refactoring

* Several code blocks were refactored for readability and maintainability, such as removing unnecessary braces, improving variable naming, and restructuring loops and conditional statements.
* Documentation was updated to clarify the behavior of dataset tensor element checks and associated nodes.

## Project configuration and compatibility

* Project settings were updated to enable Java compiler release compatibility and preview features, ensuring the codebase remains up-to-date with the latest Java standards.
* Eclipse resource encoding settings were added for test directories to ensure consistent file encoding across environments.
* A new Maven test launch configuration was added for easier test execution within Eclipse.
1 parent 593cd68 commit cba0ba9

File tree

5 files changed

+133
-93
lines changed

5 files changed

+133
-93
lines changed

Maven test.launch

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
2+
<launchConfiguration type="org.eclipse.m2e.Maven2LaunchConfigurationType">
3+
<intAttribute key="M2_COLORS" value="0"/>
4+
<booleanAttribute key="M2_DEBUG_OUTPUT" value="false"/>
5+
<stringAttribute key="M2_GOALS" value="test"/>
6+
<booleanAttribute key="M2_NON_RECURSIVE" value="false"/>
7+
<booleanAttribute key="M2_OFFLINE" value="true"/>
8+
<stringAttribute key="M2_PROFILES" value=""/>
9+
<listAttribute key="M2_PROPERTIES"/>
10+
<stringAttribute key="M2_RUNTIME" value="EMBEDDED"/>
11+
<booleanAttribute key="M2_SKIP_TESTS" value="false"/>
12+
<intAttribute key="M2_THREADS" value="1"/>
13+
<booleanAttribute key="M2_UPDATE_SNAPSHOTS" value="true"/>
14+
<stringAttribute key="M2_USER_SETTINGS" value=""/>
15+
<booleanAttribute key="M2_WORKSPACE_RESOLUTION" value="true"/>
16+
<booleanAttribute key="org.eclipse.debug.core.ATTR_FORCE_SYSTEM_CONSOLE_ENCODING" value="false"/>
17+
<listAttribute key="org.eclipse.debug.ui.favoriteGroups">
18+
<listEntry value="org.eclipse.debug.ui.launchGroup.run"/>
19+
</listAttribute>
20+
<booleanAttribute key="org.eclipse.jdt.launching.ATTR_ATTR_USE_ARGFILE" value="false"/>
21+
<booleanAttribute key="org.eclipse.jdt.launching.ATTR_SHOW_CODEDETAILS_IN_EXCEPTION_MESSAGES" value="true"/>
22+
<booleanAttribute key="org.eclipse.jdt.launching.ATTR_USE_CLASSPATH_ONLY_JAR" value="false"/>
23+
<stringAttribute key="org.eclipse.jdt.launching.WORKING_DIRECTORY" value="${workspace_loc:/ml}"/>
24+
</launchConfiguration>

com.ibm.wala.cast.python.jython.test/.settings/org.eclipse.jdt.core.prefs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled
1818
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
1919
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
2020
org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore
21-
org.eclipse.jdt.core.compiler.release=disabled
21+
org.eclipse.jdt.core.compiler.release=enabled
2222
org.eclipse.jdt.core.compiler.source=21
2323
org.eclipse.jdt.core.incompatibleJDKLevel=ignore
2424
org.eclipse.jdt.core.incompleteClasspath=error

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java

Lines changed: 98 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
4545
import com.ibm.wala.util.intset.OrdinalSet;
4646
import java.io.File;
47+
import java.io.IOException;
4748
import java.util.Iterator;
4849
import java.util.List;
4950
import java.util.Map;
@@ -168,12 +169,12 @@ private static Set<PointsToSetVariable> getDataflowSources(
168169
int objectRef = propertyRead.getObjectRef();
169170
SSAInstruction def = du.getDef(objectRef);
170171

171-
if (def == null) {
172+
if (def == null)
172173
// definition is unavailable from the local DefUse. Use interprocedural analysis using
173174
// the PA.
174175
processInstructionInterprocedurally(
175176
propertyRead, objectRef, localPointerKeyNode, src, sources, pointerAnalysis);
176-
} else if (def instanceof EachElementGetInstruction
177+
else if (def instanceof EachElementGetInstruction
177178
|| def instanceof PythonPropertyRead
178179
|| def instanceof PythonInvokeInstruction) {
179180
boolean added = false;
@@ -408,7 +409,7 @@ private static boolean processInstructionInterprocedurally(
408409
IClass concreteType = asin.getConcreteType();
409410
TypeReference reference = concreteType.getReference();
410411

411-
if (reference.equals(DATASET) && isDatasetTensorElement(src, use, node, pointerAnalysis)) {
412+
if (reference.equals(DATASET) && isDatasetTensorElement(src, use, pointerAnalysis)) {
412413
sources.add(src);
413414
logger.info("Added dataflow source from tensor dataset: " + src + ".");
414415
return true;
@@ -421,69 +422,71 @@ private static boolean processInstructionInterprocedurally(
421422

422423
/**
423424
* Returns true iff the given {@link PointsToSetVariable} refers to a tensor dataset element of
424-
* the dataset defined by the given value number in the given {@link CGNode}.
425+
* the dataset defined by the given value number in its associated {@link CGNode}.
425426
*
426427
* @param variable The {@link PointsToSetVariable} to consider.
427428
* @param val The value number, in the {@link CGNode} associated with the given variable, representing the tensor dataset.
428-
* @param node The {@link CGNode} containing the given {@link PointsToSetVariable} and value.
429429
* @param pointerAnalysis The {@link PointerAnalysis} that includes points-to information for the
430430
* given {@link CGNode}.
431431
* @return True iff src refers to a tensor dataset element defined by the dataset represented by
432-
* val in node.
432+
* val in the node associated with src.
433433
*/
434434
private static boolean isDatasetTensorElement(
435-
PointsToSetVariable variable,
436-
int val,
437-
CGNode node,
438-
PointerAnalysis<InstanceKey> pointerAnalysis) {
439-
SSAInstruction def = node.getDU().getDef(val);
440-
441-
if (def instanceof PythonInvokeInstruction) {
442-
PythonInvokeInstruction invokeInstruction = (PythonInvokeInstruction) def;
443-
444-
// Check whether we are calling enumerate(), as that returns a tuple.
445-
// Get the invoked function.
446-
int invocationUse = invokeInstruction.getUse(0);
447-
448-
PointerKey invocationUsePointerKey =
449-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, invocationUse);
450-
451-
for (InstanceKey functionInstance : pointerAnalysis.getPointsToSet(invocationUsePointerKey)) {
452-
if (functionInstance instanceof ConcreteTypeKey) {
453-
ConcreteTypeKey typeKey = (ConcreteTypeKey) functionInstance;
454-
IClass type = typeKey.getType();
455-
TypeReference typeReference = type.getReference();
456-
457-
if (typeReference.equals(ENUMERATE.getDeclaringClass())) {
458-
// it's a call to enumerate(), where the returned value is an iterator over
459-
// tuples. Each tuple consists of the enumeration number and the dataset
460-
// element. Check that we are not looking at the enumeration number.
461-
462-
PythonPropertyRead srcDef =
463-
(PythonPropertyRead)
464-
node.getDU()
465-
.getDef(((LocalPointerKey) variable.getPointerKey()).getValueNumber());
466-
467-
// What does the member reference point to?
468-
PointerKey memberRefPointerKey =
469-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, srcDef.getMemberRef());
470-
471-
for (InstanceKey memberInstance : pointerAnalysis.getPointsToSet(memberRefPointerKey)) {
472-
ConstantKey<?> constant = (ConstantKey<?>) memberInstance;
473-
Object value = constant.getValue();
474-
475-
// if it's the first tuple element.
476-
if (value.equals(0)) {
477-
// Now that we know it's the first tuple element, we now need to know whether it's
478-
// the first tuple, i.e., the one returned by enumerate.
479-
// To do that, we examine the object being referenced on the RHS.
480-
481-
SSAInstruction objRefDef = node.getDU().getDef(srcDef.getObjectRef());
482-
483-
// If the object being read is that of the dataset, we know that this is the first
484-
// tuple read of the result of enumerate() on the dataset.
485-
if (objRefDef instanceof PythonPropertyRead
486-
&& ((PythonPropertyRead) objRefDef).getObjectRef() == val) return false;
435+
PointsToSetVariable variable, int val, PointerAnalysis<InstanceKey> pointerAnalysis) {
436+
if (variable.getPointerKey() instanceof LocalPointerKey) {
437+
LocalPointerKey localPointerKey = (LocalPointerKey) variable.getPointerKey();
438+
CGNode node = localPointerKey.getNode();
439+
SSAInstruction def = node.getDU().getDef(val);
440+
441+
if (def instanceof PythonInvokeInstruction) {
442+
PythonInvokeInstruction invokeInstruction = (PythonInvokeInstruction) def;
443+
444+
// Check whether we are calling enumerate(), as that returns a tuple.
445+
// Get the invoked function.
446+
int invocationUse = invokeInstruction.getUse(0);
447+
448+
PointerKey invocationUsePointerKey =
449+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, invocationUse);
450+
451+
for (InstanceKey functionInstance :
452+
pointerAnalysis.getPointsToSet(invocationUsePointerKey)) {
453+
if (functionInstance instanceof ConcreteTypeKey) {
454+
ConcreteTypeKey typeKey = (ConcreteTypeKey) functionInstance;
455+
IClass type = typeKey.getType();
456+
TypeReference typeReference = type.getReference();
457+
458+
if (typeReference.equals(ENUMERATE.getDeclaringClass())) {
459+
// it's a call to enumerate(), where the returned value is an iterator over
460+
// tuples. Each tuple consists of the enumeration number and the dataset
461+
// element. Check that we are not looking at the enumeration number.
462+
463+
PythonPropertyRead srcDef =
464+
(PythonPropertyRead)
465+
node.getDU()
466+
.getDef(((LocalPointerKey) variable.getPointerKey()).getValueNumber());
467+
468+
// What does the member reference point to?
469+
PointerKey memberRefPointerKey =
470+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, srcDef.getMemberRef());
471+
472+
for (InstanceKey memberInstance :
473+
pointerAnalysis.getPointsToSet(memberRefPointerKey)) {
474+
ConstantKey<?> constant = (ConstantKey<?>) memberInstance;
475+
Object value = constant.getValue();
476+
477+
// if it's the first tuple element.
478+
if (value.equals(0)) {
479+
// Now that we know it's the first tuple element, we now need to know whether it's
480+
// the first tuple, i.e., the one returned by enumerate.
481+
// To do that, we examine the object being referenced on the RHS.
482+
483+
SSAInstruction objRefDef = node.getDU().getDef(srcDef.getObjectRef());
484+
485+
// If the object being read is that of the dataset, we know that this is the first
486+
// tuple read of the result of enumerate() on the dataset.
487+
if (objRefDef instanceof PythonPropertyRead
488+
&& ((PythonPropertyRead) objRefDef).getObjectRef() == val) return false;
489+
}
487490
}
488491
}
489492
}
@@ -617,16 +620,19 @@ private Map<PointsToSetVariable, TensorType> getShapeSourceCalls(
617620
op,
618621
builder,
619622
(CGNode src, SSAAbstractInvokeInstruction call) -> {
620-
if (call.getNumberOfUses() > param) {
621-
targets.put(
622-
builder
623-
.getPropagationSystem()
624-
.findOrCreatePointsToSet(
625-
builder
626-
.getPointerAnalysis()
627-
.getHeapModel()
628-
.getPointerKeyForLocal(src, call.getDef())),
629-
TensorType.shapeArg(src, call.getUse(param)));
623+
try {
624+
if (call.getNumberOfUses() > param)
625+
targets.put(
626+
builder
627+
.getPropagationSystem()
628+
.findOrCreatePointsToSet(
629+
builder
630+
.getPointerAnalysis()
631+
.getHeapModel()
632+
.getPointerKeyForLocal(src, call.getDef())),
633+
TensorType.shapeArg(src, call.getUse(param)));
634+
} catch (IOException e) {
635+
throw new RuntimeException("Error while processing shape source call: " + call, e);
630636
}
631637
});
632638
return targets;
@@ -663,34 +669,46 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
663669

664670
TensorType mnistData = TensorType.mnistInput();
665671
Map<PointsToSetVariable, TensorType> init = HashMapFactory.make();
666-
for (PointsToSetVariable v : sources) {
667-
init.put(v, mnistData);
672+
673+
for (PointsToSetVariable v : sources) init.put(v, mnistData);
674+
675+
Map<PointsToSetVariable, TensorType> placeholders = null;
676+
try {
677+
placeholders = handleShapeSourceOp(builder, dataflow, placeholder, 2);
678+
} catch (IOException e) {
679+
throw new RuntimeException("Error while processing placeholder calls.", e);
668680
}
681+
logger.fine("Placeholders: " + placeholders);
669682

670-
Map<PointsToSetVariable, TensorType> placeholders =
671-
handleShapeSourceOp(builder, dataflow, placeholder, 2);
672-
logger.fine(() -> "Placeholders: " + placeholders);
673-
for (Map.Entry<PointsToSetVariable, TensorType> e : placeholders.entrySet()) {
683+
for (Map.Entry<PointsToSetVariable, TensorType> e : placeholders.entrySet())
674684
init.put(e.getKey(), e.getValue());
675-
}
676685

677686
Map<PointsToSetVariable, TensorType> setCalls = HashMapFactory.make();
678687
Map<PointsToSetVariable, TensorType> set_shapes = getShapeSourceCalls(set_shape, builder, 1);
688+
679689
for (Map.Entry<PointsToSetVariable, TensorType> x : set_shapes.entrySet()) {
680-
CGNode setNode = ((LocalPointerKey) x.getKey().getPointerKey()).getNode();
681-
int defVn = ((LocalPointerKey) x.getKey().getPointerKey()).getValueNumber();
690+
LocalPointerKey localPointerKey = (LocalPointerKey) x.getKey().getPointerKey();
691+
CGNode setNode = localPointerKey.getNode();
692+
int defVn = localPointerKey.getValueNumber();
682693
SSAInstruction read = setNode.getDU().getDef(defVn);
683694
SSAInstruction call = setNode.getDU().getDef(read.getUse(0));
695+
684696
PointerKey setKey =
685697
builder
686698
.getPointerAnalysis()
687699
.getHeapModel()
688700
.getPointerKeyForLocal(setNode, call.getUse(0));
701+
689702
setCalls.put(builder.getPropagationSystem().findOrCreatePointsToSet(setKey), x.getValue());
690703
}
691704

692705
Map<PointsToSetVariable, TensorType> shapeOps = HashMapFactory.make();
693-
shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2));
706+
707+
try {
708+
shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2));
709+
} catch (IOException e) {
710+
throw new RuntimeException("Error while processing reshape calls.", e);
711+
}
694712

695713
Set<PointsToSetVariable> conv2ds = getKeysDefinedByCall(conv2d, builder);
696714

@@ -708,7 +726,8 @@ private Map<PointsToSetVariable, TensorType> handleShapeSourceOp(
708726
PropagationCallGraphBuilder builder,
709727
Graph<PointsToSetVariable> dataflow,
710728
MethodReference op,
711-
int shapeSrcOperand) {
729+
int shapeSrcOperand)
730+
throws IOException {
712731
Map<PointsToSetVariable, TensorType> reshapeTypes =
713732
getShapeSourceCalls(op, builder, shapeSrcOperand);
714733
for (PointsToSetVariable to : reshapeTypes.keySet()) {

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ public static TensorType mnistInput() {
329329
return new TensorType("pixel", Arrays.asList(batch, vec));
330330
}
331331

332-
public static TensorType shapeArg(CGNode node, int literalVn) {
332+
public static TensorType shapeArg(CGNode node, int literalVn) throws IOException {
333333
logger.fine(() -> node.getIR().toString());
334334
ArrayList<Dimension<?>> r = new ArrayList<>();
335335
DefUse du = node.getDU();
@@ -360,18 +360,13 @@ public static TensorType shapeArg(CGNode node, int literalVn) {
360360
.debugInfo()
361361
.getInstructionPosition(du.getDef(val).iIndex());
362362
System.err.println(p);
363-
try {
364-
SourceBuffer b = new SourceBuffer(p);
365-
String expr = b.toString();
366-
System.err.println(expr);
367-
Integer ival = PythonInterpreter.interpretAsInt(expr);
368-
if (ival != null) {
369-
r.add(new NumericDim(ival));
370-
continue;
371-
}
372-
} catch (IOException e) {
373-
// TODO Auto-generated catch block
374-
e.printStackTrace();
363+
SourceBuffer b = new SourceBuffer(p);
364+
String expr = b.toString();
365+
System.err.println(expr);
366+
Integer ival = PythonInterpreter.interpretAsInt(expr);
367+
if (ival != null) {
368+
r.add(new NumericDim(ival));
369+
continue;
375370
}
376371
}
377372
r.add(new SymbolicDim("?"));
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
eclipse.preferences.version=1
2+
encoding//src/test/java=UTF-8
3+
encoding//src/test/resources=UTF-8
24
encoding/<project>=UTF-8
35
encoding/data=UTF-8
46
encoding/source=UTF-8

0 commit comments

Comments
 (0)