Skip to content

Commit 932183a

Browse files
committed
Properly track coroutine-ness of comprehensions
1 parent 323b134 commit 932183a

File tree

2 files changed

+18
-9
lines changed
  • graalpython
    • com.oracle.graal.python.pegparser/src/com/oracle/graal/python/pegparser/scope
    • com.oracle.graal.python/src/com/oracle/graal/python/compiler

2 files changed

+18
-9
lines changed

graalpython/com.oracle.graal.python.pegparser/src/com/oracle/graal/python/pegparser/scope/ScopeEnvironment.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ private void handleComprehension(ExprTy e, String scopeName, ComprehensionTy[] g
393393
outermost.iter.accept(this);
394394
currentScope.comprehensionIterExpression--;
395395
enterBlock(scopeName, Scope.ScopeType.Function, e);
396+
boolean isAsync;
396397
try {
397398
currentScope.comprehensionType = comprehensionType;
398399
if (outermost.isAsync) {
@@ -414,9 +415,13 @@ private void handleComprehension(ExprTy e, String scopeName, ComprehensionTy[] g
414415
if (isGenerator) {
415416
currentScope.flags.add(ScopeFlags.IsGenerator);
416417
}
418+
isAsync = currentScope.isCoroutine() && !isGenerator;
417419
} finally {
418420
exitBlock();
419421
}
422+
if (isAsync) {
423+
currentScope.flags.add(ScopeFlags.IsCoroutine);
424+
}
420425
}
421426

422427
private void raiseIfComprehensionBlock(ExprTy node) {
@@ -545,6 +550,7 @@ public Void visit(ExprTy.Attribute node) {
545550
public Void visit(ExprTy.Await node) {
546551
raiseIfAnnotationBlock("await expression", node);
547552
node.value.accept(this);
553+
currentScope.flags.add(ScopeFlags.IsCoroutine);
548554
return null;
549555
}
550556

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/compiler/Compiler.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ private Void addOp(OpCodes code, int arg, byte[] followingArgs) {
526526
}
527527

528528
private void addOp(OpCodes code, int arg, byte[] followingArgs, SourceRange location) {
529+
assert code != YIELD_VALUE || unit.scope.isGenerator() || unit.scope.isCoroutine();
529530
Block b = unit.currentBlock;
530531
Instruction insn = new Instruction(code, arg, followingArgs, null, location);
531532
b.instr.add(insn);
@@ -1570,10 +1571,15 @@ private Void visitComprehension(ExprTy node, String name, ComprehensionTy[] gene
15701571
SourceRange savedLocation = setLocation(node);
15711572
try {
15721573
enterScope(name, CompilationScope.Comprehension, node, 1, 0, 0, false, false);
1574+
boolean isAsyncGenerator = unit.scope.isCoroutine();
15731575
if (type != ComprehensionType.GENEXPR) {
15741576
// The result accumulator, empty at the beginning
15751577
addOp(COLLECTION_FROM_STACK, type.typeBits);
15761578
}
1579+
// TODO allow top-level await
1580+
if (isAsyncGenerator && type != ComprehensionType.GENEXPR && unit.scopeType != CompilationScope.AsyncFunction && unit.scopeType != CompilationScope.Comprehension) {
1581+
errorCallback.onError(ErrorType.Syntax, unit.currentLocation, "asynchronous comprehension outside of an asynchronous function");
1582+
}
15771583
visitComprehensionGenerator(generators, 0, element, value, type);
15781584
if (type != ComprehensionType.GENEXPR) {
15791585
addOp(RETURN_VALUE);
@@ -1590,15 +1596,12 @@ private Void visitComprehension(ExprTy node, String name, ComprehensionTy[] gene
15901596
addOp(CALL_COMPREHENSION);
15911597
// a genexpr will create an asyncgen, which we cannot await
15921598
if (type != ComprehensionType.GENEXPR) {
1593-
for (ComprehensionTy gen : generators) {
1594-
// if we have a non-genexpr async comprehension, the call will produce a
1595-
// coroutine which we need to await
1596-
if (gen.isAsync) {
1597-
addOp(GET_AWAITABLE);
1598-
addOp(LOAD_NONE);
1599-
addYieldFrom();
1600-
break;
1601-
}
1599+
// if we have a non-genexpr async comprehension, the call will produce a
1600+
// coroutine which we need to await
1601+
if (isAsyncGenerator) {
1602+
addOp(GET_AWAITABLE);
1603+
addOp(LOAD_NONE);
1604+
addYieldFrom();
16021605
}
16031606
}
16041607
return null;

0 commit comments

Comments
 (0)