@@ -144,22 +144,141 @@ protected void raiseDivisionByZero(boolean cond) {
144144 @ TypeSystemReference (PythonArithmeticTypes .class )
145145 abstract static class RoundNode extends PythonBinaryBuiltinNode {
146146 @ SuppressWarnings ("unused" )
147- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
148- public int roundInt (int arg , Object n ) {
147+ @ Specialization
148+ public int roundIntNone (int arg , PNone n ) {
149149 return arg ;
150150 }
151151
152152 @ SuppressWarnings ("unused" )
153- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
154- public long roundLong (long arg , Object n ) {
153+ @ Specialization
154+ public long roundLongNone (long arg , PNone n ) {
155155 return arg ;
156156 }
157157
158158 @ SuppressWarnings ("unused" )
159- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
160- public PInt roundPInt (PInt arg , Object n ) {
159+ @ Specialization
160+ public PInt roundPIntNone (PInt arg , PNone n ) {
161161 return factory ().createInt (arg .getValue ());
162162 }
163+
164+ @ Specialization
165+ public Object roundLongInt (long arg , int n ) {
166+ if (n >= 0 ) {
167+ return arg ;
168+ }
169+ return makeInt (op (arg , n ));
170+ }
171+
172+ @ Specialization
173+ public Object roundPIntInt (PInt arg , int n ) {
174+ if (n >= 0 ) {
175+ return arg ;
176+ }
177+ return makeInt (op (arg .getValue (), n ));
178+ }
179+
180+ @ Specialization
181+ public Object roundLongLong (long arg , long n ) {
182+ if (n >= 0 ) {
183+ return arg ;
184+ }
185+ if (n < Integer .MIN_VALUE ) {
186+ return 0 ;
187+ }
188+ return makeInt (op (arg , (int ) n ));
189+ }
190+
191+ @ Specialization
192+ public Object roundPIntLong (PInt arg , long n ) {
193+ if (n >= 0 ) {
194+ return arg ;
195+ }
196+ if (n < Integer .MIN_VALUE ) {
197+ return 0 ;
198+ }
199+ return makeInt (op (arg .getValue (), (int ) n ));
200+ }
201+
202+ @ Specialization
203+ public Object roundPIntLong (long arg , PInt n ) {
204+ if (n .isZeroOrPositive ()) {
205+ return arg ;
206+ }
207+ try {
208+ return makeInt (op (arg , n .intValueExact ()));
209+ } catch (ArithmeticException e ) {
210+ // n is < -2^31, max. number of base-10 digits in BigInteger is 2^31 * log10(2)
211+ return 0 ;
212+ }
213+ }
214+
215+ @ Specialization
216+ public Object roundPIntPInt (PInt arg , PInt n ) {
217+ if (n .isZeroOrPositive ()) {
218+ return arg ;
219+ }
220+ try {
221+ return makeInt (op (arg .getValue (), n .intValueExact ()));
222+ } catch (ArithmeticException e ) {
223+ // n is < -2^31, max. number of base-10 digits in BigInteger is 2^31 * log10(2)
224+ return 0 ;
225+ }
226+ }
227+
228+ @ Specialization (guards = {"!isInteger(n)" })
229+ @ SuppressWarnings ("unused" )
230+ public Object roundPIntPInt (Object arg , Object n ) {
231+ throw raise (PythonErrorType .TypeError , ErrorMessages .OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER , n );
232+ }
233+
234+ private Object makeInt (BigDecimal d ) {
235+ try {
236+ return intValueExact (d );
237+ } catch (ArithmeticException e ) {
238+ // does not fit int, so try long
239+ }
240+ try {
241+ return longValueExact (d );
242+ } catch (ArithmeticException e ) {
243+ // does not fit long, try BigInteger
244+ }
245+ try {
246+ return factory ().createInt (d .toBigIntegerExact ());
247+ } catch (ArithmeticException e ) {
248+ // has non-zero fractional part, which should not happen
249+ throw CompilerDirectives .shouldNotReachHere ("non-integer produced after rounding an integer" , e );
250+ }
251+ }
252+
253+ @ TruffleBoundary
254+ private static int intValueExact (BigDecimal d ) {
255+ return d .intValueExact ();
256+ }
257+
258+ @ TruffleBoundary
259+ private static long longValueExact (BigDecimal d ) {
260+ return d .longValueExact ();
261+ }
262+
263+ @ TruffleBoundary
264+ private static BigDecimal op (long arg , int n ) {
265+ try {
266+ return new BigDecimal (arg ).setScale (n , RoundingMode .HALF_EVEN );
267+ } catch (ArithmeticException e ) {
268+ // -n exceeds max. number of base-10 digits in BigInteger
269+ return BigDecimal .ZERO ;
270+ }
271+ }
272+
273+ @ TruffleBoundary
274+ private static BigDecimal op (BigInteger arg , int n ) {
275+ try {
276+ return new BigDecimal (arg ).setScale (n , RoundingMode .HALF_EVEN );
277+ } catch (ArithmeticException e ) {
278+ // -n exceeds max. number of base-10 digits in BigInteger
279+ return BigDecimal .ZERO ;
280+ }
281+ }
163282 }
164283
165284 @ Builtin (name = SpecialMethodNames .__RADD__ , minNumOfPositionalArgs = 2 )
@@ -704,7 +823,9 @@ static long doLLFast(long left, long right, @SuppressWarnings("unused") PNone no
704823 result = Math .multiplyExact (result , base );
705824 }
706825 exponent >>= 1 ;
707- base = Math .multiplyExact (base , base );
826+ if (exponent != 0 ) { // prevent overflow in last iteration
827+ base = Math .multiplyExact (base , base );
828+ }
708829 }
709830 return result ;
710831 }
@@ -1313,15 +1434,13 @@ long doLL(long left, long right) {
13131434 }
13141435
13151436 @ Specialization
1316- PInt doIPi (int left , PInt right ) {
1317- raiseNegativeShiftCount (!right .isZeroOrPositive ());
1318- return factory ().createInt (op (PInt .longToBigInteger (left ), right .intValue ()));
1437+ Object doIPi (int left , PInt right ) {
1438+ return doHugeShift (PInt .longToBigInteger (left ), right );
13191439 }
13201440
13211441 @ Specialization
1322- PInt doLPi (long left , PInt right ) {
1323- raiseNegativeShiftCount (!right .isZeroOrPositive ());
1324- return factory ().createInt (op (PInt .longToBigInteger (left ), right .intValue ()));
1442+ Object doLPi (long left , PInt right ) {
1443+ return doHugeShift (PInt .longToBigInteger (left ), right );
13251444 }
13261445
13271446 @ Specialization
@@ -1331,15 +1450,20 @@ PInt doPiI(PInt left, int right) {
13311450 }
13321451
13331452 @ Specialization
1334- PInt doPiL (PInt left , long right ) {
1453+ Object doPiL (PInt left , long right ) {
13351454 raiseNegativeShiftCount (right < 0 );
1336- return factory ().createInt (op (left .getValue (), (int ) right ));
1455+ int rightI = (int ) right ;
1456+ if (rightI == right ) {
1457+ return factory ().createInt (op (left .getValue (), rightI ));
1458+ }
1459+ // right is >= 2**31, BigInteger's bitLength is at most 2**31-1
1460+ // therefore the result of shifting right is just the sign bit
1461+ return left .isNegative () ? -1 : 0 ;
13371462 }
13381463
13391464 @ Specialization
1340- PInt doPInt (PInt left , PInt right ) {
1341- raiseNegativeShiftCount (!right .isZeroOrPositive ());
1342- return factory ().createInt (op (left .getValue (), right .intValue ()));
1465+ Object doPInt (PInt left , PInt right ) {
1466+ return doHugeShift (left .getValue (), right );
13431467 }
13441468
13451469 private void raiseNegativeShiftCount (boolean cond ) {
@@ -1354,8 +1478,19 @@ PNotImplemented doGeneric(Object a, Object b) {
13541478 return PNotImplemented .NOT_IMPLEMENTED ;
13551479 }
13561480
1481+ private Object doHugeShift (BigInteger left , PInt right ) {
1482+ raiseNegativeShiftCount (!right .isZeroOrPositive ());
1483+ try {
1484+ return factory ().createInt (op (left , right .intValueExact ()));
1485+ } catch (ArithmeticException e ) {
1486+ // right is >= 2**31, BigInteger's bitLength is at most 2**31-1
1487+ // therefore the result of shifting right is just the sign bit
1488+ return left .signum () < 0 ? -1 : 0 ;
1489+ }
1490+ }
1491+
13571492 @ TruffleBoundary
1358- public static BigInteger op (BigInteger left , int right ) {
1493+ private static BigInteger op (BigInteger left , int right ) {
13591494 return left .shiftRight (right );
13601495 }
13611496
0 commit comments