|
79 | 79 | import com.oracle.graal.python.nodes.expression.UnaryArithmetic; |
80 | 80 | import com.oracle.graal.python.nodes.frame.DeleteGlobalNode; |
81 | 81 | import com.oracle.graal.python.nodes.frame.DestructuringAssignmentNode; |
| 82 | +import com.oracle.graal.python.nodes.frame.FrameSlotIDs; |
82 | 83 | import com.oracle.graal.python.nodes.frame.ReadGlobalOrBuiltinNode; |
83 | 84 | import com.oracle.graal.python.nodes.frame.WriteGlobalNode; |
| 85 | +import com.oracle.graal.python.nodes.frame.WriteLocalVariableNode; |
84 | 86 | import com.oracle.graal.python.nodes.frame.WriteNode; |
85 | 87 | import com.oracle.graal.python.nodes.function.FunctionDefinitionNode; |
86 | 88 | import com.oracle.graal.python.nodes.function.GeneratorExpressionNode; |
@@ -135,18 +137,32 @@ <T> T getChild(Node result, int num, Class<? extends T> klass) { |
135 | 137 | if (++i <= num) { |
136 | 138 | continue; |
137 | 139 | } |
138 | | - if (n instanceof ExpressionNode.ExpressionStatementNode) { |
139 | | - n = n.getChildren().iterator().next(); |
140 | | - } else if (n instanceof ExpressionNode.ExpressionWithSideEffects) { |
141 | | - n = n.getChildren().iterator().next(); |
142 | | - } |
| 140 | + n = unpackModuleBodyWrappers(n); |
143 | 141 | assertTrue("Expected an instance of " + klass + ", got " + n.getClass(), klass.isInstance(n)); |
144 | 142 | return klass.cast(n); |
145 | 143 | } |
146 | 144 | assertFalse("Expected an instance of " + klass + ", got null", true); |
147 | 145 | return null; |
148 | 146 | } |
149 | 147 |
|
| 148 | + private Node unpackModuleBodyWrappers(Node n) { |
| 149 | + Node actual = n; |
| 150 | + if (n instanceof ExpressionNode.ExpressionStatementNode) { |
| 151 | + actual = n.getChildren().iterator().next(); |
| 152 | + } else if (n instanceof ExpressionNode.ExpressionWithSideEffects) { |
| 153 | + actual = n.getChildren().iterator().next(); |
| 154 | + } else if (n instanceof WriteLocalVariableNode) { |
| 155 | + if (((WriteLocalVariableNode) n).getIdentifier().equals(FrameSlotIDs.RETURN_SLOT_ID)) { |
| 156 | + actual = ((WriteLocalVariableNode) n).getRhs(); |
| 157 | + } |
| 158 | + } |
| 159 | + if (actual == n) { |
| 160 | + return n; |
| 161 | + } else { |
| 162 | + return unpackModuleBodyWrappers(actual); |
| 163 | + } |
| 164 | + } |
| 165 | + |
150 | 166 | <T> T getFirstChild(Node result, Class<? extends T> klass) { |
151 | 167 | return getChild(result, 0, klass); |
152 | 168 | } |
|
0 commit comments