@@ -124,6 +124,9 @@ abstract static class AcceptNode extends PythonUnaryBuiltinNode {
124124 @ Specialization
125125 @ TruffleBoundary
126126 Object accept (PSocket socket ) {
127+ if (socket .getServerSocket () == null ) {
128+ throw raiseOSError (null , OSErrorEnum .EINVAL );
129+ }
127130 try {
128131 SocketChannel acceptSocket = SocketUtils .accept (this , socket );
129132 if (acceptSocket == null ) {
@@ -214,6 +217,7 @@ private static void doConnect(PSocket socket, Object[] hostAndPort) throws IOExc
214217 InetSocketAddress socketAddress = new InetSocketAddress ((String ) hostAndPort [0 ], (Integer ) hostAndPort [1 ]);
215218 SocketChannel channel = SocketChannel .open ();
216219 channel .connect (socketAddress );
220+ channel .configureBlocking (socket .isBlocking ());
217221 socket .setSocket (channel );
218222 }
219223 }
@@ -338,6 +342,9 @@ Object listen(PSocket socket, PNone backlog) {
338342 abstract static class RecvNode extends PythonTernaryClinicBuiltinNode {
339343 @ Specialization
340344 Object recv (VirtualFrame frame , PSocket socket , int bufsize , int flags ) {
345+ if (socket .getSocket () == null ) {
346+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
347+ }
341348 ByteBuffer readBytes = PythonUtils .allocateByteBuffer (bufsize );
342349 try {
343350 int length = SocketUtils .recv (this , socket , readBytes );
@@ -384,6 +391,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PMemoryView buffer, Object f
384391 @ CachedLibrary (limit = "getCallSiteInlineCacheMaxDepth()" ) PythonObjectLibrary lib ,
385392 @ Cached ("create(__LEN__)" ) LookupAndCallUnaryNode callLen ,
386393 @ Cached ("create(__SETITEM__)" ) LookupAndCallTernaryNode setItem ) {
394+ if (socket .getSocket () == null ) {
395+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
396+ }
387397 int bufferLen = lib .asSizeWithState (callLen .executeObject (frame , buffer ), PArguments .getThreadState (frame ));
388398 byte [] targetBuffer = new byte [bufferLen ];
389399 ByteBuffer byteBuffer = PythonUtils .wrapByteBuffer (targetBuffer );
@@ -410,6 +420,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PByteArray buffer, Object fl
410420 @ Cached ("createBinaryProfile()" ) ConditionProfile byteStorage ,
411421 @ Cached SequenceStorageNodes .LenNode lenNode ,
412422 @ Cached ("createSetItem()" ) SequenceStorageNodes .SetItemNode setItem ) {
423+ if (socket .getSocket () == null ) {
424+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
425+ }
413426 SequenceStorage storage = buffer .getSequenceStorage ();
414427 int bufferLen = lenNode .execute (storage );
415428 if (byteStorage .profile (storage instanceof ByteSequenceStorage )) {
@@ -470,17 +483,14 @@ Object send(VirtualFrame frame, PSocket socket, PBytes bytes, Object flags,
470483 @ Cached SequenceStorageNodes .ToByteArrayNode toBytes ) {
471484 // TODO: do not ignore flags
472485 if (socket .getSocket () == null ) {
473- throw raise (OSError );
474- }
475-
476- if (!socket .isOpen ()) {
477- throw raise (OSError );
486+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
478487 }
479-
480488 int written ;
481489 ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
482490 try {
483491 written = SocketUtils .send (this , socket , buffer );
492+ } catch (NotYetConnectedException e ) {
493+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
484494 } catch (IOException e ) {
485495 throw raise (OSError );
486496 }
@@ -500,6 +510,9 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
500510 @ Cached SequenceStorageNodes .ToByteArrayNode toBytes ,
501511 @ Cached ConditionProfile hasTimeoutProfile ) {
502512 // TODO: do not ignore flags
513+ if (socket .getSocket () == null ) {
514+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
515+ }
503516 ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
504517 long timeoutMillis = socket .getTimeoutInMilliseconds ();
505518 TimeoutHelper timeoutHelper = null ;
@@ -513,6 +526,8 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
513526 int written ;
514527 try {
515528 written = SocketUtils .send (this , socket , buffer , timeoutMillis );
529+ } catch (NotYetConnectedException e ) {
530+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
516531 } catch (IOException e ) {
517532 throw raise (OSError );
518533 }
@@ -606,24 +621,29 @@ Object setTimeout(PSocket socket, Object secondsObj,
606621 @ GenerateNodeFactory
607622 abstract static class shutdownNode extends PythonBinaryBuiltinNode {
608623 @ Specialization
609- @ TruffleBoundary
610- Object family (PSocket socket , int how ) {
611- if (socket .getSocket () != null ) {
612- try {
613- if (how == 0 || how == 2 ) {
614- socket .getSocket ().shutdownInput ();
615- }
616- if (how == 1 || how == 2 ) {
617- socket .getSocket ().shutdownOutput ();
618- }
619- } catch (IOException e ) {
620- throw raise (OSError );
621- }
622- } else {
624+ Object family (VirtualFrame frame , PSocket socket , int how ) {
625+ if (socket .getSocket () == null ) {
626+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
627+ }
628+ try {
629+ shutdown (socket , how );
630+ } catch (NotYetConnectedException e ) {
631+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
632+ } catch (IOException e ) {
623633 throw raise (OSError );
624634 }
625635 return PNone .NO_VALUE ;
626636 }
637+
638+ @ TruffleBoundary
639+ private static void shutdown (PSocket socket , int how ) throws IOException {
640+ if (how == 0 || how == 2 ) {
641+ socket .getSocket ().shutdownInput ();
642+ }
643+ if (how == 1 || how == 2 ) {
644+ socket .getSocket ().shutdownOutput ();
645+ }
646+ }
627647 }
628648
629649 // family
0 commit comments