@@ -24,7 +24,7 @@ class Stream(Generic[_T]):
2424
2525 response : httpx .Response
2626
27- _decoder : SSEDecoder | SSEBytesDecoder
27+ _decoder : SSEBytesDecoder
2828
2929 def __init__ (
3030 self ,
@@ -47,10 +47,7 @@ def __iter__(self) -> Iterator[_T]:
4747 yield item
4848
4949 def _iter_events (self ) -> Iterator [ServerSentEvent ]:
50- if isinstance (self ._decoder , SSEBytesDecoder ):
51- yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
52- else :
53- yield from self ._decoder .iter (self .response .iter_lines ())
50+ yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
5451
5552 def __stream__ (self ) -> Iterator [_T ]:
5653 cast_to = cast (Any , self ._cast_to )
@@ -151,12 +148,8 @@ async def __aiter__(self) -> AsyncIterator[_T]:
151148 yield item
152149
153150 async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
154- if isinstance (self ._decoder , SSEBytesDecoder ):
155- async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
156- yield sse
157- else :
158- async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
159- yield sse
151+ async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
152+ yield sse
160153
161154 async def __stream__ (self ) -> AsyncIterator [_T ]:
162155 cast_to = cast (Any , self ._cast_to )
@@ -282,21 +275,49 @@ def __init__(self) -> None:
282275 self ._last_event_id = None
283276 self ._retry = None
284277
285- def iter (self , iterator : Iterator [str ]) -> Iterator [ServerSentEvent ]:
286- """Given an iterator that yields lines, iterate over it & yield every event encountered"""
287- for line in iterator :
288- line = line .rstrip ("\n " )
289- sse = self .decode (line )
290- if sse is not None :
291- yield sse
292-
293- async def aiter (self , iterator : AsyncIterator [str ]) -> AsyncIterator [ServerSentEvent ]:
294- """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
295- async for line in iterator :
296- line = line .rstrip ("\n " )
297- sse = self .decode (line )
298- if sse is not None :
299- yield sse
278+ def iter_bytes (self , iterator : Iterator [bytes ]) -> Iterator [ServerSentEvent ]:
279+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
280+ for chunk in self ._iter_chunks (iterator ):
281+ # Split before decoding so splitlines() only uses \r and \n
282+ for raw_line in chunk .splitlines ():
283+ line = raw_line .decode ("utf-8" )
284+ sse = self .decode (line )
285+ if sse :
286+ yield sse
287+
288+ def _iter_chunks (self , iterator : Iterator [bytes ]) -> Iterator [bytes ]:
289+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
290+ data = b""
291+ for chunk in iterator :
292+ for line in chunk .splitlines (keepends = True ):
293+ data += line
294+ if data .endswith ((b"\r \r " , b"\n \n " , b"\r \n \r \n " )):
295+ yield data
296+ data = b""
297+ if data :
298+ yield data
299+
300+ async def aiter_bytes (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [ServerSentEvent ]:
301+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
302+ async for chunk in self ._aiter_chunks (iterator ):
303+ # Split before decoding so splitlines() only uses \r and \n
304+ for raw_line in chunk .splitlines ():
305+ line = raw_line .decode ("utf-8" )
306+ sse = self .decode (line )
307+ if sse :
308+ yield sse
309+
310+ async def _aiter_chunks (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [bytes ]:
311+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
312+ data = b""
313+ async for chunk in iterator :
314+ for line in chunk .splitlines (keepends = True ):
315+ data += line
316+ if data .endswith ((b"\r \r " , b"\n \n " , b"\r \n \r \n " )):
317+ yield data
318+ data = b""
319+ if data :
320+ yield data
300321
301322 def decode (self , line : str ) -> ServerSentEvent | None :
302323 # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
0 commit comments