diff --git a/av/container/output.py b/av/container/output.py index 145437752..e57f269e8 100644 --- a/av/container/output.py +++ b/av/container/output.py @@ -278,15 +278,23 @@ def add_data_stream(self, codec_name=None, options: dict | None = None): :rtype: The new :class:`~av.data.stream.DataStream`. """ codec: cython.pointer[cython.const[lib.AVCodec]] = cython.NULL + codec_descriptor: cython.pointer[lib.AVCodecDescriptor] = cython.NULL if codec_name is not None: codec = lib.avcodec_find_encoder_by_name(codec_name.encode()) if codec == cython.NULL: - raise ValueError(f"Unknown data codec: {codec_name}") + codec = lib.avcodec_find_decoder_by_name(codec_name.encode()) + if codec == cython.NULL: + codec_descriptor = lib.avcodec_descriptor_get_by_name( + codec_name.encode() + ) + if codec_descriptor == cython.NULL: + raise ValueError(f"Unknown data codec: {codec_name}") - # Assert that this format supports the requested codec + # Verify format supports this codec + codec_id = codec.id if codec != cython.NULL else codec_descriptor.id if not lib.avformat_query_codec( - self.ptr.oformat, codec.id, lib.FF_COMPLIANCE_NORMAL + self.ptr.oformat, codec_id, lib.FF_COMPLIANCE_NORMAL ): raise ValueError( f"{self.format.name!r} format does not support {codec_name!r} codec" @@ -297,7 +305,7 @@ def add_data_stream(self, codec_name=None, options: dict | None = None): if stream == cython.NULL: raise MemoryError("Could not allocate stream") - # Set up codec context if we have a codec + # Set up codec context and parameters ctx: cython.pointer[lib.AVCodecContext] = cython.NULL if codec != cython.NULL: ctx = lib.avcodec_alloc_context3(codec) @@ -311,8 +319,10 @@ def add_data_stream(self, codec_name=None, options: dict | None = None): # Initialize stream codec parameters err_check(lib.avcodec_parameters_from_context(stream.codecpar, ctx)) else: - # For raw data streams, just set the codec type + # No codec available - set basic parameters for data stream stream.codecpar.codec_type = lib.AVMEDIA_TYPE_DATA + if codec_descriptor != cython.NULL: + stream.codecpar.codec_id = codec_descriptor.id # Construct the user-land stream py_codec_context: CodecContext | None = None diff --git a/tests/test_streams.py b/tests/test_streams.py index f82ce384b..9387d68cc 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -17,6 +17,8 @@ def cleanup(self): "data.ts", "data_source.ts", "data_copy.ts", + "data_with_codec.ts", + "data_invalid.ts", "out.mkv", "video_with_attachment.mkv", "remuxed_attachment.mkv", @@ -201,6 +203,50 @@ def test_data_stream_from_template(self) -> None: assert remuxed_payloads == copied_payloads + def test_data_stream_with_codec(self) -> None: + """Test adding a data stream with a specific codec name.""" + # Test that invalid codec names raise appropriate errors + with pytest.raises(ValueError, match="Unknown data codec"): + container = av.open("data_invalid.ts", "w") + try: + container.add_data_stream("not_a_real_codec_name_12345") + finally: + container.close() + + # Test that add_data_stream with codec parameter works correctly + # We use "bin_data" which is a data codec that's always available + output_path = "data_with_codec.ts" + with av.open(output_path, "w") as container: + # Try to create a data stream with a codec + # bin_data is a simple passthrough codec for binary data + data_stream = container.add_data_stream("bin_data") + klv_stream = container.add_data_stream("klv") + + assert data_stream.type == "data" + assert klv_stream.type == "data" + # Note: codec_context may be None for descriptor-only data codecs + + test_data = [b"test1", b"test2", b"test3"] + for i, data in enumerate(test_data): + packet = av.Packet(data) + packet.pts = i + packet.stream = data_stream + container.mux(packet) + + with av.open(output_path) as newcontainer: + data_stream = newcontainer.streams.data[0] + klv_stream = newcontainer.streams.data[1] + assert data_stream.type == "data" + assert klv_stream.type == "data" + assert "bin_data" in str(data_stream) + assert "klv" in str(klv_stream) + assert data_stream.name == "bin_data" + assert klv_stream.name == "klv" + try: + os.remove(output_path) + except Exception: + pass + def test_attachment_stream(self) -> None: input_path = av.datasets.curated( "pexels/time-lapse-video-of-night-sky-857195.mp4"