@@ -144,22 +144,131 @@ protected void raiseDivisionByZero(boolean cond) {
144144 @ TypeSystemReference (PythonArithmeticTypes .class )
145145 abstract static class RoundNode extends PythonBinaryBuiltinNode {
146146 @ SuppressWarnings ("unused" )
147- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
148- public int roundInt (int arg , Object n ) {
147+ @ Specialization
148+ public int roundIntNone (int arg , PNone n ) {
149149 return arg ;
150150 }
151151
152152 @ SuppressWarnings ("unused" )
153- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
154- public long roundLong (long arg , Object n ) {
153+ @ Specialization
154+ public long roundLongNone (long arg , PNone n ) {
155155 return arg ;
156156 }
157157
158158 @ SuppressWarnings ("unused" )
159- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
160- public PInt roundPInt (PInt arg , Object n ) {
159+ @ Specialization
160+ public PInt roundPIntNone (PInt arg , PNone n ) {
161161 return factory ().createInt (arg .getValue ());
162162 }
163+
164+ @ Specialization
165+ public Object roundLongInt (long arg , int n ) {
166+ if (n >= 0 ) {
167+ return arg ;
168+ }
169+ return makeInt (op (arg , n ));
170+ }
171+
172+ @ Specialization
173+ public Object roundPIntInt (PInt arg , int n ) {
174+ if (n >= 0 ) {
175+ return arg ;
176+ }
177+ return makeInt (op (arg .getValue (), n ));
178+ }
179+
180+ @ Specialization
181+ public Object roundLongLong (long arg , long n ) {
182+ if (n >= 0 ) {
183+ return arg ;
184+ }
185+ if (n < Integer .MIN_VALUE ) {
186+ return 0 ;
187+ }
188+ return makeInt (op (arg , (int ) n ));
189+ }
190+
191+ @ Specialization
192+ public Object roundPIntLong (PInt arg , long n ) {
193+ if (n >= 0 ) {
194+ return arg ;
195+ }
196+ if (n < Integer .MIN_VALUE ) {
197+ return 0 ;
198+ }
199+ return makeInt (op (arg .getValue (), (int ) n ));
200+ }
201+
202+ @ Specialization
203+ public Object roundPIntLong (long arg , PInt n ) {
204+ if (n .isZeroOrPositive ()) {
205+ return arg ;
206+ }
207+ try {
208+ return makeInt (op (arg , n .intValueExact ()));
209+ } catch (ArithmeticException e ) {
210+ // n is < -2^31, max. number of base-10 digits in BigInteger is 2^31 * log10(2)
211+ return 0 ;
212+ }
213+ }
214+
215+ @ Specialization
216+ public Object roundPIntPInt (PInt arg , PInt n ) {
217+ if (n .isZeroOrPositive ()) {
218+ return arg ;
219+ }
220+ try {
221+ return makeInt (op (arg .getValue (), n .intValueExact ()));
222+ } catch (ArithmeticException e ) {
223+ // n is < -2^31, max. number of base-10 digits in BigInteger is 2^31 * log10(2)
224+ return 0 ;
225+ }
226+ }
227+
228+ @ Specialization (guards = {"!isInteger(n)" })
229+ @ SuppressWarnings ("unused" )
230+ public Object roundPIntPInt (Object arg , Object n ) {
231+ throw raise (PythonErrorType .TypeError , ErrorMessages .OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER , n );
232+ }
233+
234+ private Object makeInt (BigDecimal d ) {
235+ try {
236+ return d .intValueExact ();
237+ } catch (ArithmeticException e ) {
238+ // does not fit int, so try long
239+ }
240+ try {
241+ return d .longValueExact ();
242+ } catch (ArithmeticException e ) {
243+ // does not fit long, try BigInteger
244+ }
245+ try {
246+ return factory ().createInt (d .toBigIntegerExact ());
247+ } catch (ArithmeticException e ) {
248+ // has non-zero fractional part, which should not happen
249+ throw new IllegalStateException ();
250+ }
251+ }
252+
253+ @ TruffleBoundary
254+ private static BigDecimal op (long arg , int n ) {
255+ try {
256+ return new BigDecimal (arg ).setScale (n , RoundingMode .HALF_EVEN );
257+ } catch (ArithmeticException e ) {
258+ // -n exceeds max. number of base-10 digits in BigInteger
259+ return BigDecimal .ZERO ;
260+ }
261+ }
262+
263+ @ TruffleBoundary
264+ private static BigDecimal op (BigInteger arg , int n ) {
265+ try {
266+ return new BigDecimal (arg ).setScale (n , RoundingMode .HALF_EVEN );
267+ } catch (ArithmeticException e ) {
268+ // -n exceeds max. number of base-10 digits in BigInteger
269+ return BigDecimal .ZERO ;
270+ }
271+ }
163272 }
164273
165274 @ Builtin (name = SpecialMethodNames .__RADD__ , minNumOfPositionalArgs = 2 )
0 commit comments