Skip to content

Commit d40de6b

Browse files
committed
Fix failing assertions in GetInternalIteratorSequenceStorage
1 parent c1f00fa commit d40de6b

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/SequenceStorageNodes.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4059,21 +4059,23 @@ public abstract static class CreateStorageFromIteratorNodeCached extends CreateS
40594059

40604060
@CompilationFinal private ListStorageType expectedElementType = Uninitialized;
40614061

4062+
private static final int MAX_PREALLOCATE_SIZE = 32;
40624063
@CompilationFinal int startSizeProfiled = START_SIZE;
40634064

40644065
public boolean isBuiltinIterator(Object iterator) {
40654066
return iterator instanceof PBuiltinIterator && getClass.execute((PBuiltinIterator) iterator) == PythonBuiltinClassType.PIterator;
40664067
}
40674068

40684069
public static SequenceStorage getSequenceStorage(GetInternalIteratorSequenceStorage node, PBuiltinIterator iterator) {
4069-
return node.execute(iterator);
4070+
return iterator.index != 0 || iterator.isExhausted() ? null : node.execute(iterator);
40704071
}
40714072

4072-
@Specialization(guards = {"isBuiltinIterator(it)", "it.index == 0", "storage != null"})
4073+
@Specialization(guards = {"isBuiltinIterator(it)", "storage != null"})
40734074
public SequenceStorage createBuiltinFastPath(PBuiltinIterator it, int len,
40744075
@Cached GetInternalIteratorSequenceStorage getIterSeqStorageNode,
40754076
@Bind("getSequenceStorage(getIterSeqStorageNode, it)") SequenceStorage storage,
40764077
@Cached CopyNode copyNode) {
4078+
it.setExhausted();
40774079
return copyNode.execute(storage);
40784080
}
40794081

@@ -4085,7 +4087,7 @@ public SequenceStorage createBuiltinUnknownLen(VirtualFrame frame, PBuiltinItera
40854087
@Shared("arrayGrowProfile") @Cached("createCountingProfile()") ConditionProfile arrayGrowProfile,
40864088
@Cached NextNode nextNode) {
40874089
int expectedLen = lengthHint.execute(iterator);
4088-
if (expectedLen == -1) {
4090+
if (expectedLen < 0) {
40894091
expectedLen = startSizeProfiled;
40904092
}
40914093
SequenceStorage s = createStorageFromBuiltin(frame, iterator, expectedLen, expectedElementType, nextNode, errorProfile, arrayGrowProfile, loopProfile);
@@ -4123,7 +4125,7 @@ public SequenceStorage createGenericKnownLen(VirtualFrame frame, Object iterator
41234125
private SequenceStorage profileResult(SequenceStorage storage, boolean profileLength) {
41244126
if (CompilerDirectives.inInterpreter() && profileLength) {
41254127
int actualLen = storage.length();
4126-
if (startSizeProfiled < actualLen && actualLen <= 32) {
4128+
if (startSizeProfiled < actualLen && actualLen <= MAX_PREALLOCATE_SIZE) {
41274129
startSizeProfiled = actualLen;
41284130
}
41294131
}
@@ -4152,7 +4154,7 @@ public SequenceStorage execute(VirtualFrame frame, Object iterator, int len) {
41524154
private SequenceStorage executeImpl(Object iterator, int len) {
41534155
if (iterator instanceof PBuiltinIterator) {
41544156
PBuiltinIterator pbi = (PBuiltinIterator) iterator;
4155-
if (GetClassNode.getUncached().execute(pbi) == PythonBuiltinClassType.PIterator && pbi.index == 0) {
4157+
if (GetClassNode.getUncached().execute(pbi) == PythonBuiltinClassType.PIterator && pbi.index == 0 && !pbi.isExhausted()) {
41564158
SequenceStorage s = GetInternalIteratorSequenceStorage.getUncached().execute(pbi);
41574159
if (s != null) {
41584160
return s.copy();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/iterator/IteratorNodes.java

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
8181
import com.oracle.truffle.api.CompilerDirectives;
8282
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
83+
import com.oracle.truffle.api.dsl.Bind;
8384
import com.oracle.truffle.api.dsl.Cached;
8485
import com.oracle.truffle.api.dsl.Fallback;
8586
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -188,6 +189,7 @@ int length(VirtualFrame frame, Object iterable,
188189
}
189190

190191
@ImportStatic(PGuards.class)
192+
@GenerateUncached
191193
public abstract static class GetInternalIteratorSequenceStorage extends Node {
192194
public static GetInternalIteratorSequenceStorage getUncached() {
193195
return GetInternalIteratorSequenceStorageNodeGen.getUncached();
@@ -200,7 +202,7 @@ public static GetInternalIteratorSequenceStorage getUncached() {
200202
*/
201203
public final SequenceStorage execute(PBuiltinIterator iterator) {
202204
assert GetClassNode.getUncached().execute(iterator) == PIterator;
203-
assert iterator.index == 0;
205+
assert iterator.index == 0 && !iterator.isExhausted();
204206
return executeInternal(iterator);
205207
}
206208

@@ -244,43 +246,54 @@ static SequenceStorage doOthers(PBuiltinIterator it) {
244246

245247
@ImportStatic(PGuards.class)
246248
public abstract static class BuiltinIteratorLengthHint extends Node {
247-
@Child GetInternalIteratorSequenceStorage getSeqStorage = GetInternalIteratorSequenceStorageNodeGen.create();
248-
private final ConditionProfile noStorageProfile = ConditionProfile.createBinaryProfile();
249-
250249
/**
251250
* The argument must be a builtin iterator. Returns {@code -1} if the length hint is not
252-
* available.
251+
* available and rewrites itself to generic fallback that always returns {@code -1}.
253252
*/
254253
public final int execute(PBuiltinIterator iterator) {
255254
assert GetClassNode.getUncached().execute(iterator) == PIterator;
256-
SequenceStorage result = getSeqStorage.execute(iterator);
257-
if (noStorageProfile.profile(result != null)) {
258-
return result.length();
259-
}
260255
return executeInternal(iterator);
261256
}
262257

263258
protected abstract int executeInternal(PBuiltinIterator iterator);
264259

260+
protected static SequenceStorage getStorage(GetInternalIteratorSequenceStorage getSeqStorage, PBuiltinIterator it) {
261+
return it.index != 0 || it.isExhausted() ? null : getSeqStorage.execute(it);
262+
}
263+
264+
@Specialization(guards = "storage != null")
265+
static int doSeqStorage(@SuppressWarnings("unused") PBuiltinIterator it,
266+
@SuppressWarnings("unused") @Cached GetInternalIteratorSequenceStorage getSeqStorage,
267+
@Bind("getStorage(getSeqStorage, it)") SequenceStorage storage) {
268+
return ensurePositive(storage.length());
269+
}
270+
265271
@Specialization
266272
static int doString(PStringIterator it) {
267-
return it.value.length();
273+
return ensurePositive(it.value.length());
268274
}
269275

270276
@Specialization
271277
static int doSequenceArr(PArrayIterator it) {
272-
return it.array.getLength();
278+
return ensurePositive(it.array.getLength());
273279
}
274280

275281
@Specialization
276282
static int doSequenceIntRange(PIntRangeIterator it) {
277-
return it.getLength();
283+
return ensurePositive(it.getLength());
278284
}
279285

280-
@Fallback
281-
static int doOthers(PBuiltinIterator it) {
286+
@Specialization(replaces = {"doSeqStorage", "doString", "doSequenceArr", "doSequenceIntRange"})
287+
static int doGeneric(@SuppressWarnings("unused") PBuiltinIterator it) {
282288
return -1;
283289
}
290+
291+
static int ensurePositive(int len) {
292+
if (len < 0) {
293+
throw CompilerDirectives.shouldNotReachHere();
294+
}
295+
return len;
296+
}
284297
}
285298

286299
@GenerateUncached

0 commit comments

Comments
 (0)