3838import com .oracle .graal .python .builtins .Builtin ;
3939import com .oracle .graal .python .builtins .CoreFunctions ;
4040import com .oracle .graal .python .builtins .PythonBuiltins ;
41+ import com .oracle .graal .python .builtins .objects .PNone ;
4142import com .oracle .graal .python .builtins .objects .PNotImplemented ;
43+ import com .oracle .graal .python .builtins .objects .common .EconomicMapStorage ;
4244import com .oracle .graal .python .builtins .objects .common .HashingStorage ;
45+ import com .oracle .graal .python .builtins .objects .common .HashingStorage .Equivalence ;
4346import com .oracle .graal .python .builtins .objects .common .HashingStorageNodes ;
47+ import com .oracle .graal .python .builtins .objects .common .HashingStorageNodes .PythonEquivalence ;
48+ import com .oracle .graal .python .builtins .objects .common .PHashingCollection ;
49+ import com .oracle .graal .python .builtins .objects .set .FrozenSetBuiltinsFactory .BinaryUnionNodeGen ;
50+ import com .oracle .graal .python .nodes .PBaseNode ;
51+ import com .oracle .graal .python .nodes .control .GetIteratorNode ;
52+ import com .oracle .graal .python .nodes .control .GetNextNode ;
4453import com .oracle .graal .python .nodes .function .PythonBuiltinBaseNode ;
54+ import com .oracle .graal .python .nodes .function .PythonBuiltinNode ;
4555import com .oracle .graal .python .nodes .function .builtins .PythonBinaryBuiltinNode ;
4656import com .oracle .graal .python .nodes .function .builtins .PythonUnaryBuiltinNode ;
57+ import com .oracle .graal .python .runtime .exception .PException ;
58+ import com .oracle .truffle .api .CompilerDirectives ;
59+ import com .oracle .truffle .api .CompilerDirectives .CompilationFinal ;
4760import com .oracle .truffle .api .dsl .Cached ;
4861import com .oracle .truffle .api .dsl .Fallback ;
4962import com .oracle .truffle .api .dsl .GenerateNodeFactory ;
5063import com .oracle .truffle .api .dsl .NodeFactory ;
5164import com .oracle .truffle .api .dsl .Specialization ;
65+ import com .oracle .truffle .api .profiles .ConditionProfile ;
66+ import com .oracle .truffle .api .profiles .ValueProfile ;
5267
5368@ CoreFunctions (extendClasses = {PFrozenSet .class , PSet .class })
5469public final class FrozenSetBuiltins extends PythonBuiltins {
@@ -116,35 +131,51 @@ Object run(PBaseSet self, PBaseSet other) {
116131 @ Builtin (name = __AND__ , fixedNumOfArguments = 2 )
117132 @ GenerateNodeFactory
118133 abstract static class AndNode extends PythonBinaryBuiltinNode {
134+ @ Child private HashingStorageNodes .IntersectNode intersectNode ;
135+
119136 @ Specialization
120- PBaseSet doPBaseSet (PSet left , PBaseSet right ,
121- @ Cached ("create()" ) HashingStorageNodes .IntersectNode intersectNode ) {
122- HashingStorage intersectedStorage = intersectNode .execute (left .getDictStorage (), right .getDictStorage ());
137+ PBaseSet doPBaseSet (PSet left , PBaseSet right ) {
138+ HashingStorage intersectedStorage = getIntersectNode ().execute (left .getDictStorage (), right .getDictStorage ());
123139 return factory ().createSet (intersectedStorage );
124140 }
125141
126142 @ Specialization
127- PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ,
128- @ Cached ("create()" ) HashingStorageNodes .IntersectNode intersectNode ) {
129- HashingStorage intersectedStorage = intersectNode .execute (left .getDictStorage (), right .getDictStorage ());
143+ PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ) {
144+ HashingStorage intersectedStorage = getIntersectNode ().execute (left .getDictStorage (), right .getDictStorage ());
130145 return factory ().createFrozenSet (intersectedStorage );
131146 }
147+
148+ private HashingStorageNodes .IntersectNode getIntersectNode () {
149+ if (intersectNode == null ) {
150+ CompilerDirectives .transferToInterpreterAndInvalidate ();
151+ intersectNode = insert (HashingStorageNodes .IntersectNode .create ());
152+ }
153+ return intersectNode ;
154+ }
132155 }
133156
134157 @ Builtin (name = __SUB__ , fixedNumOfArguments = 2 )
135158 @ GenerateNodeFactory
136159 abstract static class SubNode extends PythonBinaryBuiltinNode {
160+ @ Child private HashingStorageNodes .DiffNode diffNode ;
161+
162+ private HashingStorageNodes .DiffNode getDiffNode () {
163+ if (diffNode == null ) {
164+ CompilerDirectives .transferToInterpreterAndInvalidate ();
165+ diffNode = HashingStorageNodes .DiffNode .create ();
166+ }
167+ return diffNode ;
168+ }
169+
137170 @ Specialization
138- PBaseSet doPBaseSet (PSet left , PBaseSet right ,
139- @ Cached ("create()" ) HashingStorageNodes .DiffNode diffNode ) {
140- HashingStorage storage = diffNode .execute (left .getDictStorage (), right .getDictStorage ());
171+ PBaseSet doPBaseSet (PSet left , PBaseSet right ) {
172+ HashingStorage storage = getDiffNode ().execute (left .getDictStorage (), right .getDictStorage ());
141173 return factory ().createSet (storage );
142174 }
143175
144176 @ Specialization
145- PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ,
146- @ Cached ("create()" ) HashingStorageNodes .DiffNode diffNode ) {
147- HashingStorage storage = diffNode .execute (left .getDictStorage (), right .getDictStorage ());
177+ PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ) {
178+ HashingStorage storage = getDiffNode ().execute (left .getDictStorage (), right .getDictStorage ());
148179 return factory ().createSet (storage );
149180 }
150181 }
@@ -158,4 +189,105 @@ boolean contains(PBaseSet self, Object key,
158189 return containsKeyNode .execute (self .getDictStorage (), key );
159190 }
160191 }
192+
193+ @ Builtin (name = "union" , minNumOfArguments = 1 , takesVariableArguments = true )
194+ @ GenerateNodeFactory
195+ abstract static class UnionNode extends PythonBuiltinNode {
196+
197+ @ Child private BinaryUnionNode binaryUnionNode ;
198+
199+ @ CompilationFinal private ValueProfile setTypeProfile ;
200+
201+ @ Specialization (guards = {"args.length == len" , "args.length < 32" }, limit = "3" )
202+ PBaseSet doCached (PBaseSet self , Object [] args ,
203+ @ Cached ("args.length" ) int len ,
204+ @ Cached ("create()" ) HashingStorageNodes .CopyNode copyNode ) {
205+ PBaseSet result = create (self , copyNode .execute (self .getDictStorage ()));
206+ for (int i = 0 ; i < len ; i ++) {
207+ getBinaryUnionNode ().execute (result , result .getDictStorage (), args [i ]);
208+ }
209+ return result ;
210+ }
211+
212+ @ Specialization (replaces = "doCached" )
213+ PBaseSet doGeneric (PBaseSet self , Object [] args ,
214+ @ Cached ("create()" ) HashingStorageNodes .CopyNode copyNode ) {
215+ PBaseSet result = create (self , copyNode .execute (self .getDictStorage ()));
216+ for (int i = 0 ; i < args .length ; i ++) {
217+ getBinaryUnionNode ().execute (result , result .getDictStorage (), args [i ]);
218+ }
219+ return result ;
220+ }
221+
222+ private PBaseSet create (PBaseSet left , HashingStorage storage ) {
223+ if (getSetTypeProfile ().profile (left ) instanceof PFrozenSet ) {
224+ return factory ().createFrozenSet (storage );
225+ }
226+ return factory ().createSet (storage );
227+ }
228+
229+ private BinaryUnionNode getBinaryUnionNode () {
230+ if (binaryUnionNode == null ) {
231+ CompilerDirectives .transferToInterpreterAndInvalidate ();
232+ binaryUnionNode = insert (BinaryUnionNode .create ());
233+ }
234+ return binaryUnionNode ;
235+ }
236+
237+ private ValueProfile getSetTypeProfile () {
238+ if (setTypeProfile == null ) {
239+ CompilerDirectives .transferToInterpreterAndInvalidate ();
240+ setTypeProfile = ValueProfile .createClassProfile ();
241+ }
242+ return setTypeProfile ;
243+ }
244+
245+ }
246+
247+ abstract static class BinaryUnionNode extends PBaseNode {
248+ @ Child private Equivalence equivalenceNode ;
249+
250+ public abstract PBaseSet execute (PBaseSet container , HashingStorage left , Object right );
251+
252+ @ Specialization
253+ PBaseSet doHashingCollection (PBaseSet container , EconomicMapStorage selfStorage , PHashingCollection other ) {
254+ for (Object key : other .getDictStorage ().keys ()) {
255+ selfStorage .setItem (key , PNone .NO_VALUE , getEquivalence ());
256+ }
257+ return container ;
258+ }
259+
260+ @ Specialization
261+ PBaseSet doIterable (PBaseSet container , HashingStorage dictStorage , Object iterable ,
262+ @ Cached ("create()" ) GetIteratorNode getIteratorNode ,
263+ @ Cached ("create()" ) GetNextNode next ,
264+ @ Cached ("createBinaryProfile()" ) ConditionProfile errorProfile ,
265+ @ Cached ("create()" ) HashingStorageNodes .SetItemNode setItemNode ) {
266+
267+ Object iterator = getIteratorNode .executeWith (iterable );
268+ while (true ) {
269+ Object value ;
270+ try {
271+ value = next .execute (iterator );
272+ } catch (PException e ) {
273+ e .expectStopIteration (getCore (), errorProfile );
274+ return container ;
275+ }
276+ setItemNode .execute (container , dictStorage , value , PNone .NO_VALUE );
277+ }
278+ }
279+
280+ protected Equivalence getEquivalence () {
281+ if (equivalenceNode == null ) {
282+ CompilerDirectives .transferToInterpreterAndInvalidate ();
283+ equivalenceNode = insert (new PythonEquivalence ());
284+ }
285+ return equivalenceNode ;
286+ }
287+
288+ public static BinaryUnionNode create () {
289+ return BinaryUnionNodeGen .create ();
290+ }
291+
292+ }
161293}
0 commit comments