4444import com .ibm .wala .util .graph .impl .SlowSparseNumberedGraph ;
4545import com .ibm .wala .util .intset .OrdinalSet ;
4646import java .io .File ;
47+ import java .io .IOException ;
4748import java .util .Iterator ;
4849import java .util .List ;
4950import java .util .Map ;
@@ -168,12 +169,12 @@ private static Set<PointsToSetVariable> getDataflowSources(
168169 int objectRef = propertyRead .getObjectRef ();
169170 SSAInstruction def = du .getDef (objectRef );
170171
171- if (def == null ) {
172+ if (def == null )
172173 // definition is unavailable from the local DefUse. Use interprocedural analysis using
173174 // the PA.
174175 processInstructionInterprocedurally (
175176 propertyRead , objectRef , localPointerKeyNode , src , sources , pointerAnalysis );
176- } else if (def instanceof EachElementGetInstruction
177+ else if (def instanceof EachElementGetInstruction
177178 || def instanceof PythonPropertyRead
178179 || def instanceof PythonInvokeInstruction ) {
179180 boolean added = false ;
@@ -408,7 +409,7 @@ private static boolean processInstructionInterprocedurally(
408409 IClass concreteType = asin .getConcreteType ();
409410 TypeReference reference = concreteType .getReference ();
410411
411- if (reference .equals (DATASET ) && isDatasetTensorElement (src , use , node , pointerAnalysis )) {
412+ if (reference .equals (DATASET ) && isDatasetTensorElement (src , use , pointerAnalysis )) {
412413 sources .add (src );
413414 logger .info ("Added dataflow source from tensor dataset: " + src + "." );
414415 return true ;
@@ -421,69 +422,71 @@ private static boolean processInstructionInterprocedurally(
421422
422423 /**
423424 * Returns true iff the given {@link PointsToSetVariable} refers to a tensor dataset element of
424- * the dataset defined by the given value number in the given {@link CGNode}.
425+ * the dataset defined by the given value number in the its associated {@link CGNode}.
425426 *
426427 * @param variable The {@link PointsToSetVariable} to consider.
427428 * @param val The value in the given {@link CGNode} representing the tensor dataset.
428- * @param node The {@link CGNode} containing the given {@link PointsToSetVariable} and value.
429429 * @param pointerAnalysis The {@link PointerAnalysis} that includes points-to information for the
430430 * given {@link CGNode}.
431431 * @return True iff src refers to a tensor dataset element defined by the dataset represented by
432- * val in node.
432+ * val in the node associated with src .
433433 */
434434 private static boolean isDatasetTensorElement (
435- PointsToSetVariable variable ,
436- int val ,
437- CGNode node ,
438- PointerAnalysis <InstanceKey > pointerAnalysis ) {
439- SSAInstruction def = node .getDU ().getDef (val );
440-
441- if (def instanceof PythonInvokeInstruction ) {
442- PythonInvokeInstruction invokeInstruction = (PythonInvokeInstruction ) def ;
443-
444- // Check whether we are calling enumerate(), as that returns a tuple.
445- // Get the invoked function.
446- int invocationUse = invokeInstruction .getUse (0 );
447-
448- PointerKey invocationUsePointerKey =
449- pointerAnalysis .getHeapModel ().getPointerKeyForLocal (node , invocationUse );
450-
451- for (InstanceKey functionInstance : pointerAnalysis .getPointsToSet (invocationUsePointerKey )) {
452- if (functionInstance instanceof ConcreteTypeKey ) {
453- ConcreteTypeKey typeKey = (ConcreteTypeKey ) functionInstance ;
454- IClass type = typeKey .getType ();
455- TypeReference typeReference = type .getReference ();
456-
457- if (typeReference .equals (ENUMERATE .getDeclaringClass ())) {
458- // it's a call to enumerate(), where the returned value is an iterator over
459- // tuples. Each tuple consists of the enumeration number and the dataset
460- // element. Check that we are not looking at the enumeration number.
461-
462- PythonPropertyRead srcDef =
463- (PythonPropertyRead )
464- node .getDU ()
465- .getDef (((LocalPointerKey ) variable .getPointerKey ()).getValueNumber ());
466-
467- // What does the member reference point to?
468- PointerKey memberRefPointerKey =
469- pointerAnalysis .getHeapModel ().getPointerKeyForLocal (node , srcDef .getMemberRef ());
470-
471- for (InstanceKey memberInstance : pointerAnalysis .getPointsToSet (memberRefPointerKey )) {
472- ConstantKey <?> constant = (ConstantKey <?>) memberInstance ;
473- Object value = constant .getValue ();
474-
475- // if it's the first tuple element.
476- if (value .equals (0 )) {
477- // Now that we know it's the first tuple element, we now need to know whether it's
478- // the first tuple, i.e., the one returned by enumerate.
479- // To do that, we examine the object being referenced on the RHS.
480-
481- SSAInstruction objRefDef = node .getDU ().getDef (srcDef .getObjectRef ());
482-
483- // If the object being read is that of the dataset, we know that this is the first
484- // tuple read of the result of enumerate() on the dataset.
485- if (objRefDef instanceof PythonPropertyRead
486- && ((PythonPropertyRead ) objRefDef ).getObjectRef () == val ) return false ;
435+ PointsToSetVariable variable , int val , PointerAnalysis <InstanceKey > pointerAnalysis ) {
436+ if (variable .getPointerKey () instanceof LocalPointerKey ) {
437+ LocalPointerKey localPointerKey = (LocalPointerKey ) variable .getPointerKey ();
438+ CGNode node = localPointerKey .getNode ();
439+ SSAInstruction def = node .getDU ().getDef (val );
440+
441+ if (def instanceof PythonInvokeInstruction ) {
442+ PythonInvokeInstruction invokeInstruction = (PythonInvokeInstruction ) def ;
443+
444+ // Check whether we are calling enumerate(), as that returns a tuple.
445+ // Get the invoked function.
446+ int invocationUse = invokeInstruction .getUse (0 );
447+
448+ PointerKey invocationUsePointerKey =
449+ pointerAnalysis .getHeapModel ().getPointerKeyForLocal (node , invocationUse );
450+
451+ for (InstanceKey functionInstance :
452+ pointerAnalysis .getPointsToSet (invocationUsePointerKey )) {
453+ if (functionInstance instanceof ConcreteTypeKey ) {
454+ ConcreteTypeKey typeKey = (ConcreteTypeKey ) functionInstance ;
455+ IClass type = typeKey .getType ();
456+ TypeReference typeReference = type .getReference ();
457+
458+ if (typeReference .equals (ENUMERATE .getDeclaringClass ())) {
459+ // it's a call to enumerate(), where the returned value is an iterator over
460+ // tuples. Each tuple consists of the enumeration number and the dataset
461+ // element. Check that we are not looking at the enumeration number.
462+
463+ PythonPropertyRead srcDef =
464+ (PythonPropertyRead )
465+ node .getDU ()
466+ .getDef (((LocalPointerKey ) variable .getPointerKey ()).getValueNumber ());
467+
468+ // What does the member reference point to?
469+ PointerKey memberRefPointerKey =
470+ pointerAnalysis .getHeapModel ().getPointerKeyForLocal (node , srcDef .getMemberRef ());
471+
472+ for (InstanceKey memberInstance :
473+ pointerAnalysis .getPointsToSet (memberRefPointerKey )) {
474+ ConstantKey <?> constant = (ConstantKey <?>) memberInstance ;
475+ Object value = constant .getValue ();
476+
477+ // if it's the first tuple element.
478+ if (value .equals (0 )) {
479+ // Now that we know it's the first tuple element, we now need to know whether it's
480+ // the first tuple, i.e., the one returned by enumerate.
481+ // To do that, we examine the object being referenced on the RHS.
482+
483+ SSAInstruction objRefDef = node .getDU ().getDef (srcDef .getObjectRef ());
484+
485+ // If the object being read is that of the dataset, we know that this is the first
486+ // tuple read of the result of enumerate() on the dataset.
487+ if (objRefDef instanceof PythonPropertyRead
488+ && ((PythonPropertyRead ) objRefDef ).getObjectRef () == val ) return false ;
489+ }
487490 }
488491 }
489492 }
@@ -617,16 +620,19 @@ private Map<PointsToSetVariable, TensorType> getShapeSourceCalls(
617620 op ,
618621 builder ,
619622 (CGNode src , SSAAbstractInvokeInstruction call ) -> {
620- if (call .getNumberOfUses () > param ) {
621- targets .put (
622- builder
623- .getPropagationSystem ()
624- .findOrCreatePointsToSet (
625- builder
626- .getPointerAnalysis ()
627- .getHeapModel ()
628- .getPointerKeyForLocal (src , call .getDef ())),
629- TensorType .shapeArg (src , call .getUse (param )));
623+ try {
624+ if (call .getNumberOfUses () > param )
625+ targets .put (
626+ builder
627+ .getPropagationSystem ()
628+ .findOrCreatePointsToSet (
629+ builder
630+ .getPointerAnalysis ()
631+ .getHeapModel ()
632+ .getPointerKeyForLocal (src , call .getDef ())),
633+ TensorType .shapeArg (src , call .getUse (param )));
634+ } catch (IOException e ) {
635+ throw new RuntimeException ("Error while processing shape source call: " + call , e );
630636 }
631637 });
632638 return targets ;
@@ -663,34 +669,46 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
663669
664670 TensorType mnistData = TensorType .mnistInput ();
665671 Map <PointsToSetVariable , TensorType > init = HashMapFactory .make ();
666- for (PointsToSetVariable v : sources ) {
667- init .put (v , mnistData );
672+
673+ for (PointsToSetVariable v : sources ) init .put (v , mnistData );
674+
675+ Map <PointsToSetVariable , TensorType > placeholders = null ;
676+ try {
677+ placeholders = handleShapeSourceOp (builder , dataflow , placeholder , 2 );
678+ } catch (IOException e ) {
679+ throw new RuntimeException ("Error while processing placeholder calls." , e );
668680 }
681+ logger .fine ("Placeholders: " + placeholders );
669682
670- Map <PointsToSetVariable , TensorType > placeholders =
671- handleShapeSourceOp (builder , dataflow , placeholder , 2 );
672- logger .fine (() -> "Placeholders: " + placeholders );
673- for (Map .Entry <PointsToSetVariable , TensorType > e : placeholders .entrySet ()) {
683+ for (Map .Entry <PointsToSetVariable , TensorType > e : placeholders .entrySet ())
674684 init .put (e .getKey (), e .getValue ());
675- }
676685
677686 Map <PointsToSetVariable , TensorType > setCalls = HashMapFactory .make ();
678687 Map <PointsToSetVariable , TensorType > set_shapes = getShapeSourceCalls (set_shape , builder , 1 );
688+
679689 for (Map .Entry <PointsToSetVariable , TensorType > x : set_shapes .entrySet ()) {
680- CGNode setNode = ((LocalPointerKey ) x .getKey ().getPointerKey ()).getNode ();
681- int defVn = ((LocalPointerKey ) x .getKey ().getPointerKey ()).getValueNumber ();
690+ LocalPointerKey localPointerKey = (LocalPointerKey ) x .getKey ().getPointerKey ();
691+ CGNode setNode = localPointerKey .getNode ();
692+ int defVn = localPointerKey .getValueNumber ();
682693 SSAInstruction read = setNode .getDU ().getDef (defVn );
683694 SSAInstruction call = setNode .getDU ().getDef (read .getUse (0 ));
695+
684696 PointerKey setKey =
685697 builder
686698 .getPointerAnalysis ()
687699 .getHeapModel ()
688700 .getPointerKeyForLocal (setNode , call .getUse (0 ));
701+
689702 setCalls .put (builder .getPropagationSystem ().findOrCreatePointsToSet (setKey ), x .getValue ());
690703 }
691704
692705 Map <PointsToSetVariable , TensorType > shapeOps = HashMapFactory .make ();
693- shapeOps .putAll (handleShapeSourceOp (builder , dataflow , reshape , 2 ));
706+
707+ try {
708+ shapeOps .putAll (handleShapeSourceOp (builder , dataflow , reshape , 2 ));
709+ } catch (IOException e ) {
710+ throw new RuntimeException ("Error while processing reshape calls." , e );
711+ }
694712
695713 Set <PointsToSetVariable > conv2ds = getKeysDefinedByCall (conv2d , builder );
696714
@@ -708,7 +726,8 @@ private Map<PointsToSetVariable, TensorType> handleShapeSourceOp(
708726 PropagationCallGraphBuilder builder ,
709727 Graph <PointsToSetVariable > dataflow ,
710728 MethodReference op ,
711- int shapeSrcOperand ) {
729+ int shapeSrcOperand )
730+ throws IOException {
712731 Map <PointsToSetVariable , TensorType > reshapeTypes =
713732 getShapeSourceCalls (op , builder , shapeSrcOperand );
714733 for (PointsToSetVariable to : reshapeTypes .keySet ()) {
0 commit comments