@@ -124,6 +124,9 @@ abstract static class AcceptNode extends PythonUnaryBuiltinNode {
124124 @ Specialization
125125 @ TruffleBoundary
126126 Object accept (PSocket socket ) {
127+ if (socket .getServerSocket () == null ) {
128+ throw raiseOSError (null , OSErrorEnum .EINVAL );
129+ }
127130 try {
128131 SocketChannel acceptSocket = SocketUtils .accept (this , socket );
129132 if (acceptSocket == null ) {
@@ -338,6 +341,9 @@ Object listen(PSocket socket, PNone backlog) {
338341 abstract static class RecvNode extends PythonTernaryClinicBuiltinNode {
339342 @ Specialization
340343 Object recv (VirtualFrame frame , PSocket socket , int bufsize , int flags ) {
344+ if (socket .getSocket () == null ) {
345+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
346+ }
341347 ByteBuffer readBytes = PythonUtils .allocateByteBuffer (bufsize );
342348 try {
343349 int length = SocketUtils .recv (this , socket , readBytes );
@@ -384,6 +390,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PMemoryView buffer, Object f
384390 @ CachedLibrary (limit = "getCallSiteInlineCacheMaxDepth()" ) PythonObjectLibrary lib ,
385391 @ Cached ("create(__LEN__)" ) LookupAndCallUnaryNode callLen ,
386392 @ Cached ("create(__SETITEM__)" ) LookupAndCallTernaryNode setItem ) {
393+ if (socket .getSocket () == null ) {
394+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
395+ }
387396 int bufferLen = lib .asSizeWithState (callLen .executeObject (frame , buffer ), PArguments .getThreadState (frame ));
388397 byte [] targetBuffer = new byte [bufferLen ];
389398 ByteBuffer byteBuffer = PythonUtils .wrapByteBuffer (targetBuffer );
@@ -410,6 +419,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PByteArray buffer, Object fl
410419 @ Cached ("createBinaryProfile()" ) ConditionProfile byteStorage ,
411420 @ Cached SequenceStorageNodes .LenNode lenNode ,
412421 @ Cached ("createSetItem()" ) SequenceStorageNodes .SetItemNode setItem ) {
422+ if (socket .getSocket () == null ) {
423+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
424+ }
413425 SequenceStorage storage = buffer .getSequenceStorage ();
414426 int bufferLen = lenNode .execute (storage );
415427 if (byteStorage .profile (storage instanceof ByteSequenceStorage )) {
@@ -470,17 +482,14 @@ Object send(VirtualFrame frame, PSocket socket, PBytes bytes, Object flags,
470482 @ Cached SequenceStorageNodes .ToByteArrayNode toBytes ) {
471483 // TODO: do not ignore flags
472484 if (socket .getSocket () == null ) {
473- throw raise (OSError );
474- }
475-
476- if (!socket .isOpen ()) {
477- throw raise (OSError );
485+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
478486 }
479-
480487 int written ;
481488 ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
482489 try {
483490 written = SocketUtils .send (this , socket , buffer );
491+ } catch (NotYetConnectedException e ) {
492+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
484493 } catch (IOException e ) {
485494 throw raise (OSError );
486495 }
@@ -500,6 +509,9 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
500509 @ Cached SequenceStorageNodes .ToByteArrayNode toBytes ,
501510 @ Cached ConditionProfile hasTimeoutProfile ) {
502511 // TODO: do not ignore flags
512+ if (socket .getSocket () == null ) {
513+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
514+ }
503515 ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
504516 long timeoutMillis = socket .getTimeoutInMilliseconds ();
505517 TimeoutHelper timeoutHelper = null ;
@@ -513,6 +525,8 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
513525 int written ;
514526 try {
515527 written = SocketUtils .send (this , socket , buffer , timeoutMillis );
528+ } catch (NotYetConnectedException e ) {
529+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
516530 } catch (IOException e ) {
517531 throw raise (OSError );
518532 }
@@ -606,24 +620,29 @@ Object setTimeout(PSocket socket, Object secondsObj,
606620 @ GenerateNodeFactory
607621 abstract static class shutdownNode extends PythonBinaryBuiltinNode {
608622 @ Specialization
609- @ TruffleBoundary
610- Object family (PSocket socket , int how ) {
611- if (socket .getSocket () != null ) {
612- try {
613- if (how == 0 || how == 2 ) {
614- socket .getSocket ().shutdownInput ();
615- }
616- if (how == 1 || how == 2 ) {
617- socket .getSocket ().shutdownOutput ();
618- }
619- } catch (IOException e ) {
620- throw raise (OSError );
621- }
622- } else {
623+ Object family (VirtualFrame frame , PSocket socket , int how ) {
624+ if (socket .getSocket () == null ) {
625+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
626+ }
627+ try {
628+ shutdown (socket , how );
629+ } catch (NotYetConnectedException e ) {
630+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
631+ } catch (IOException e ) {
623632 throw raise (OSError );
624633 }
625634 return PNone .NO_VALUE ;
626635 }
636+
637+ @ TruffleBoundary
638+ private static void shutdown (PSocket socket , int how ) throws IOException {
639+ if (how == 0 || how == 2 ) {
640+ socket .getSocket ().shutdownInput ();
641+ }
642+ if (how == 1 || how == 2 ) {
643+ socket .getSocket ().shutdownOutput ();
644+ }
645+ }
627646 }
628647
629648 // family
0 commit comments