diff --git a/Cargo.lock b/Cargo.lock index ec44bbc0b..1af7e9860 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,6 +371,15 @@ dependencies = [ "serde", ] +[[package]] +name = "bytestring" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "113b4343b5f6617e7ad401ced8de3cc8b012e73a594347c307b90db3e9271289" +dependencies = [ + "bytes", +] + [[package]] name = "cbindgen" version = "0.29.2" @@ -1132,6 +1141,7 @@ dependencies = [ "num_enum", "regex", "reqwest", + "scuffle-h265", "serde", "serde_json", "serde_with", @@ -2024,6 +2034,15 @@ dependencies = [ "syn", ] +[[package]] +name = "nutype-enum" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1e13adea6de269faa0724df58f43f6fe2a81af7094f1dcb8b5b968eb2103cb3" +dependencies = [ + "scuffle-workspace-hack", +] + [[package]] name = "object" version = "0.37.3" @@ -2767,6 +2786,49 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scuffle-bytes-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0417748c2a42f4a08d4e634b68b1d64f22a8c24bef2e7ac93df33aa61202a45b" +dependencies = [ + "byteorder", + "bytes", + "bytestring", + "scuffle-workspace-hack", +] + +[[package]] +name = "scuffle-expgolomb" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d21330974c941e4c0aedc1e7255ea809e8cbac51e135209f6d67843ad1b94d" +dependencies = [ + "scuffle-bytes-util", + "scuffle-workspace-hack", +] + +[[package]] +name = "scuffle-h265" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b04b276c2f79846b7968abe6f87cedf951e06fd2a2b72d99c457e85d7e40f3fb" +dependencies = [ + "bitflags", + "byteorder", + "bytes", + "nutype-enum", + "scuffle-bytes-util", + "scuffle-expgolomb", + "scuffle-workspace-hack", +] + +[[package]] +name = "scuffle-workspace-hack" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8028ded836a0d9fabdfa4d713389b76a2098b5153f50a135c8faed7e3a3d5ae2" + [[package]] name = "sd-notify" version = "0.4.5" diff --git a/rs/hang/Cargo.toml b/rs/hang/Cargo.toml index fd9705eb2..75928bc1f 100644 --- a/rs/hang/Cargo.toml +++ b/rs/hang/Cargo.toml @@ -29,6 +29,7 @@ reqwest = { version = "0.12", default-features = false, features = [ "rustls-tls", "gzip", ] } +scuffle-h265 = "0.2.2" serde = { workspace = true } serde_json = "1" serde_with = { version = "3", features = ["hex"] } diff --git a/rs/hang/src/import/annexb.rs b/rs/hang/src/import/annexb.rs new file mode 100644 index 000000000..03e0c88e0 --- /dev/null +++ b/rs/hang/src/import/annexb.rs @@ -0,0 +1,535 @@ +use anyhow::{self}; +use bytes::{Buf, Bytes}; + +pub const START_CODE: Bytes = Bytes::from_static(&[0, 0, 0, 1]); + +pub struct NalIterator<'a, T: Buf + AsRef<[u8]> + 'a> { + buf: &'a mut T, + start: Option, +} + +impl<'a, T: Buf + AsRef<[u8]> + 'a> NalIterator<'a, T> { + pub fn new(buf: &'a mut T) -> Self { + Self { buf, start: None } + } + + /// Assume the buffer ends with a NAL unit and flush it. + /// This is more efficient because we cache the last "start" code position. + pub fn flush(self) -> anyhow::Result> { + let start = match self.start { + Some(start) => start, + None => match after_start_code(self.buf.as_ref())? { + Some(start) => start, + None => return Ok(None), + }, + }; + + self.buf.advance(start); + + let nal = self.buf.copy_to_bytes(self.buf.remaining()); + Ok(Some(nal)) + } +} + +impl<'a, T: Buf + AsRef<[u8]> + 'a> Iterator for NalIterator<'a, T> { + type Item = anyhow::Result; + + fn next(&mut self) -> Option { + let start = match self.start { + Some(start) => start, + None => match after_start_code(self.buf.as_ref()).transpose()? { + Ok(start) => start, + Err(err) => return Some(Err(err)), + }, + }; + + let (size, new_start) = find_start_code(&self.buf.as_ref()[start..])?; + self.buf.advance(start); + + let nal = self.buf.copy_to_bytes(size); + self.start = Some(new_start); + Some(Ok(nal)) + } +} + +// Return the size of the start code at the start of the buffer. +pub fn after_start_code(b: &[u8]) -> anyhow::Result> { + if b.len() < 3 { + return Ok(None); + } + + // NOTE: We have to check every byte, so the `find_start_code` optimization doesn't matter. + anyhow::ensure!(b[0] == 0, "missing Annex B start code"); + anyhow::ensure!(b[1] == 0, "missing Annex B start code"); + + match b[2] { + 0 if b.len() < 4 => Ok(None), + 0 if b[3] != 1 => anyhow::bail!("missing Annex B start code"), + 0 => Ok(Some(4)), + 1 => Ok(Some(3)), + _ => anyhow::bail!("invalid Annex B start code"), + } +} + +// Return the number of bytes until the next start code, and the size of that start code. +pub fn find_start_code(mut b: &[u8]) -> Option<(usize, usize)> { + // Okay this is over-engineered because this was my interview question. + // We need to find either a 3 byte or 4 byte start code. + // 3-byte: 0 0 1 + // 4-byte: 0 0 0 1 + // + // You fail the interview if you call string.split twice or something. + // You get a pass if you do index += 1 and check the next 3-4 bytes. + // You get my eternal respect if you check the 3rd byte first. + // What? + // + // If we check the 3rd byte and it's not a 0 or 1, then we immediately index += 3 + // Sometimes we might only skip 1 or 2 bytes, but it's still better than checking every byte. + // + // TODO Is this the type of thing that SIMD could further improve? + // If somebody can figure that out, I'll buy you a beer. + let size = b.len(); + + while b.len() >= 3 { + // ? ? ? + match b[2] { + // ? ? 0 + 0 if b.len() >= 4 => match b[3] { + // ? ? 0 1 + 1 => match b[1] { + // ? 0 0 1 + 0 => match b[0] { + // 0 0 0 1 + 0 => return Some((size - b.len(), 4)), + // ? 0 0 1 + _ => return Some((size - b.len() + 1, 3)), + }, + // ? x 0 1 + _ => b = &b[4..], + }, + // ? ? 0 0 - skip only 1 byte to check for potential 0 0 0 1 + 0 => b = &b[1..], + // ? ? 0 x + _ => b = &b[4..], + }, + // ? ? 0 FIN + 0 => return None, + // ? ? 1 + 1 => match b[1] { + // ? 0 1 + 0 => match b[0] { + // 0 0 1 + 0 => return Some((size - b.len(), 3)), + // ? 0 1 + _ => b = &b[3..], + }, + // ? x 1 + _ => b = &b[3..], + }, + // ? ? x + _ => b = &b[3..], + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + // Tests for after_start_code - validates and measures start code at buffer beginning + + #[test] + fn test_after_start_code_3_byte() { + let buf = &[0, 0, 1, 0x67]; + assert_eq!(after_start_code(buf).unwrap(), Some(3)); + } + + #[test] + fn test_after_start_code_4_byte() { + let buf = &[0, 0, 0, 1, 0x67]; + assert_eq!(after_start_code(buf).unwrap(), Some(4)); + } + + #[test] + fn test_after_start_code_too_short() { + let buf = &[0, 0]; + assert_eq!(after_start_code(buf).unwrap(), None); + } + + #[test] + fn test_after_start_code_incomplete_4_byte() { + let buf = &[0, 0, 0]; + assert_eq!(after_start_code(buf).unwrap(), None); + } + + #[test] + fn test_after_start_code_invalid_first_byte() { + let buf = &[1, 0, 1]; + assert!(after_start_code(buf).is_err()); + } + + #[test] + fn test_after_start_code_invalid_second_byte() { + let buf = &[0, 1, 1]; + assert!(after_start_code(buf).is_err()); + } + + #[test] + fn test_after_start_code_invalid_third_byte() { + let buf = &[0, 0, 2]; + assert!(after_start_code(buf).is_err()); + } + + #[test] + fn test_after_start_code_invalid_4_byte_pattern() { + let buf = &[0, 0, 0, 2]; + assert!(after_start_code(buf).is_err()); + } + + // Tests for find_start_code - finds next start code in NAL data + + #[test] + fn test_find_start_code_3_byte() { + let buf = &[0x67, 0x42, 0x00, 0x1f, 0, 0, 1]; + assert_eq!(find_start_code(buf), Some((4, 3))); + } + + #[test] + fn test_find_start_code_4_byte() { + // Should detect 4-byte start code at beginning + let buf = &[0, 0, 0, 1, 0x67]; + assert_eq!(find_start_code(buf), Some((0, 4))); + } + + #[test] + fn test_find_start_code_4_byte_after_data() { + // Should detect 4-byte start code after NAL data + let buf = &[0x67, 0x42, 0xff, 0x1f, 0, 0, 0, 1]; + assert_eq!(find_start_code(buf), Some((4, 4))); + } + + #[test] + fn test_find_start_code_at_start_3_byte() { + let buf = &[0, 0, 1, 0x67]; + assert_eq!(find_start_code(buf), Some((0, 3))); + } + + #[test] + fn test_find_start_code_none() { + let buf = &[0x67, 0x42, 0x00, 0x1f, 0xff]; + assert_eq!(find_start_code(buf), None); + } + + #[test] + fn test_find_start_code_trailing_zeros() { + let buf = &[0x67, 0x42, 0x00, 0x1f, 0, 0]; + assert_eq!(find_start_code(buf), None); + } + + #[test] + fn test_find_start_code_edge_case_3_byte() { + let buf = &[0xff, 0, 0, 1]; + assert_eq!(find_start_code(buf), Some((1, 3))); + } + + #[test] + fn test_find_start_code_false_positive_avoidance() { + // Pattern like: x 0 0 y (where y != 1) - should skip ahead + let buf = &[0xff, 0, 0, 0xff, 0, 0, 1]; + assert_eq!(find_start_code(buf), Some((4, 3))); + } + + #[test] + fn test_find_start_code_4_byte_after_nonzero() { + // Critical edge case: x 0 0 0 1 should find 4-byte start code at position 1 + // This tests that we only skip 1 byte when seeing ? ? 0 0 + let buf = &[0xff, 0, 0, 0, 1]; + assert_eq!(find_start_code(buf), Some((1, 4))); + } + + #[test] + fn test_find_start_code_consecutive_zeros() { + // Multiple consecutive zeros before the 1 + let buf = &[0xff, 0, 0, 0, 0, 0, 1]; + // Should skip past leading zeros and find the start code + let result = find_start_code(buf); + assert!(result.is_some()); + let (pos, size) = result.unwrap(); + // The exact position depends on the algorithm, but it should find a valid start code + assert!(size == 3 || size == 4); + assert!(pos < buf.len()); + } + + // Tests for NalIterator - iterates over NAL units in Annex B format + + #[test] + fn test_nal_iterator_simple_3_byte() { + let mut data = Bytes::from(vec![0, 0, 1, 0x67, 0x42, 0, 0, 1]); + let mut iter = NalIterator::new(&mut data); + + let nal = iter.next().unwrap().unwrap(); + assert_eq!(nal.as_ref(), &[0x67, 0x42]); + assert!(iter.next().is_none()); + + // Make sure the trailing 001 is still in the buffer. + assert_eq!(data.as_ref(), &[0, 0, 1]); + } + + #[test] + fn test_nal_iterator_simple_4_byte() { + let mut data = Bytes::from(vec![0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1]); + let mut iter = NalIterator::new(&mut data); + + let nal = iter.next().unwrap().unwrap(); + assert_eq!(nal.as_ref(), &[0x67, 0x42]); + assert!(iter.next().is_none()); + + // Make sure the trailing 0001 is still in the buffer. + assert_eq!(data.as_ref(), &[0, 0, 0, 1]); + } + + #[test] + fn test_nal_iterator_multiple_nals() { + let mut data = Bytes::from(vec![0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1, 0x68, 0xce, 0, 0, 0, 1]); + let mut iter = NalIterator::new(&mut data); + + let nal1 = iter.next().unwrap().unwrap(); + assert_eq!(nal1.as_ref(), &[0x67, 0x42]); + + let nal2 = iter.next().unwrap().unwrap(); + assert_eq!(nal2.as_ref(), &[0x68, 0xce]); + + assert!(iter.next().is_none()); + + // Make sure the trailing 0001 is still in the buffer. + assert_eq!(data.as_ref(), &[0, 0, 0, 1]); + } + + #[test] + fn test_nal_iterator_realistic_h264() { + // A realistic H.264 stream with SPS, PPS, and IDR + let mut data = Bytes::from(vec![ + 0, 0, 0, 1, 0x67, 0x42, 0x00, 0x1f, // SPS NAL + 0, 0, 0, 1, 0x68, 0xce, 0x3c, 0x80, // PPS NAL + 0, 0, 0, 1, 0x65, 0x88, 0x84, 0x00, // IDR slice + // Trailing start code (needed to detect the end of the last NAL) + 0, 0, 0, 1, + ]); + let mut iter = NalIterator::new(&mut data); + + let sps = iter.next().unwrap().unwrap(); + assert_eq!(sps[0] & 0x1f, 7); // SPS type + assert_eq!(sps.as_ref(), &[0x67, 0x42, 0x00, 0x1f]); + + let pps = iter.next().unwrap().unwrap(); + assert_eq!(pps[0] & 0x1f, 8); // PPS type + assert_eq!(pps.as_ref(), &[0x68, 0xce, 0x3c, 0x80]); + + let idr = iter.next().unwrap().unwrap(); + assert_eq!(idr[0] & 0x1f, 5); // IDR type + assert_eq!(idr.as_ref(), &[0x65, 0x88, 0x84, 0x00]); + + assert!(iter.next().is_none()); + + // Make sure the trailing 0001 is still in the buffer. + assert_eq!(data.as_ref(), &[0, 0, 0, 1]); + } + + #[test] + fn test_nal_iterator_realistic_h265() { + // A realistic H.265 stream with VPS, SPS, PPS, and IDR + let mut data = Bytes::from(vec![ + 0, 0, 0, 1, 0x40, 0x01, 0x0c, 0x01, // VPS NAL + 0, 0, 0, 1, 0x42, 0x01, 0x01, 0x60, // SPS NAL + 0, 0, 0, 1, 0x44, 0x01, 0xc0, 0xf1, // PPS NAL + 0, 0, 0, 1, 0x26, 0x01, 0x9a, 0x20, // IDR_W_RADL slice + // Trailing start code (needed to detect the end of the last NAL) + 0, 0, 0, 1, + ]); + let mut iter = NalIterator::new(&mut data); + + let vps = iter.next().unwrap().unwrap(); + assert_eq!((vps[0] >> 1) & 0x3f, 32); // VPS type + assert_eq!(vps.as_ref(), &[0x40, 0x01, 0x0c, 0x01]); + + let sps = iter.next().unwrap().unwrap(); + assert_eq!((sps[0] >> 1) & 0x3f, 33); // SPS type + assert_eq!(sps.as_ref(), &[0x42, 0x01, 0x01, 0x60]); + + let pps = iter.next().unwrap().unwrap(); + assert_eq!((pps[0] >> 1) & 0x3f, 34); // PPS type + assert_eq!(pps.as_ref(), &[0x44, 0x01, 0xc0, 0xf1]); + + let idr = iter.next().unwrap().unwrap(); + assert_eq!((idr[0] >> 1) & 0x3f, 19); // IDR slice type (IDR_W_RADL) + assert_eq!(idr.as_ref(), &[0x26, 0x01, 0x9a, 0x20]); + + assert!(iter.next().is_none()); + + // Make sure the trailing 0001 is still in the buffer. + assert_eq!(data.as_ref(), &[0, 0, 0, 1]); + } + + #[test] + fn test_nal_iterator_invalid_start() { + let mut data = Bytes::from(vec![1, 0, 1, 0x67]); + let mut iter = NalIterator::new(&mut data); + + assert!(iter.next().unwrap().is_err()); + + // Make sure the data is still in the buffer. + assert_eq!(data.as_ref(), &[1, 0, 1, 0x67]); + } + + #[test] + fn test_nal_iterator_empty_nal() { + // Two consecutive start codes create an empty NAL + let mut data = Bytes::from(vec![0, 0, 1, 0, 0, 1, 0x67, 0, 0, 1]); + let mut iter = NalIterator::new(&mut data); + + let nal1 = iter.next().unwrap().unwrap(); + assert_eq!(nal1.len(), 0); + + let nal2 = iter.next().unwrap().unwrap(); + assert_eq!(nal2.as_ref(), &[0x67]); + + assert!(iter.next().is_none()); + + // Make sure the data is still in the buffer. + assert_eq!(data.as_ref(), &[0, 0, 1]); + } + + #[test] + fn test_nal_iterator_nal_with_embedded_zeros() { + // NAL data that contains zeros (but not a start code pattern) + let mut data = Bytes::from(vec![ + 0, 0, 1, 0x67, 0x00, 0x00, 0x00, 0xff, // NAL with embedded zeros + 0, 0, 1, 0x68, // Next NAL + 0, 0, 1, + ]); + let mut iter = NalIterator::new(&mut data); + + let nal1 = iter.next().unwrap().unwrap(); + assert_eq!(nal1.as_ref(), &[0x67, 0x00, 0x00, 0x00, 0xff]); + + let nal2 = iter.next().unwrap().unwrap(); + assert_eq!(nal2.as_ref(), &[0x68]); + + assert!(iter.next().is_none()); + + // Make sure the data is still in the buffer. + assert_eq!(data.as_ref(), &[0, 0, 1]); + } + + // Tests for flush - extracts final NAL without trailing start code + + #[test] + fn test_flush_after_iteration() { + // Normal case: iterate over NALs, then flush the final one + let mut data = Bytes::from(vec![ + 0, 0, 0, 1, 0x67, 0x42, // First NAL + 0, 0, 0, 1, 0x68, 0xce, 0x3c, 0x80, // Second NAL (final, no trailing start code) + ]); + let mut iter = NalIterator::new(&mut data); + + let nal1 = iter.next().unwrap().unwrap(); + assert_eq!(nal1.as_ref(), &[0x67, 0x42]); + + assert!(iter.next().is_none()); + + let final_nal = iter.flush().unwrap().unwrap(); + assert_eq!(final_nal.as_ref(), &[0x68, 0xce, 0x3c, 0x80]); + } + + #[test] + fn test_flush_single_nal() { + // Buffer contains only a single NAL with no trailing start code + let mut data = Bytes::from(vec![0, 0, 1, 0x67, 0x42, 0x00, 0x1f]); + let iter = NalIterator::new(&mut data); + + let final_nal = iter.flush().unwrap().unwrap(); + assert_eq!(final_nal.as_ref(), &[0x67, 0x42, 0x00, 0x1f]); + } + + #[test] + fn test_flush_4_byte_start_code() { + // Test flush with 4-byte start code + let mut data = Bytes::from(vec![0, 0, 0, 1, 0x65, 0x88, 0x84, 0x00, 0xff]); + let iter = NalIterator::new(&mut data); + + let final_nal = iter.flush().unwrap().unwrap(); + assert_eq!(final_nal.as_ref(), &[0x65, 0x88, 0x84, 0x00, 0xff]); + } + + #[test] + fn test_flush_no_start_code() { + // Buffer doesn't start with a start code and has no cached start position + let mut data = Bytes::from(vec![0x67, 0x42, 0x00, 0x1f]); + let iter = NalIterator::new(&mut data); + + let result = iter.flush(); + assert!(result.is_err()); + } + + #[test] + fn test_flush_empty_buffer() { + // Empty buffer should return None + let mut data = Bytes::from(vec![]); + let iter = NalIterator::new(&mut data); + + let result = iter.flush().unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_flush_incomplete_start_code() { + // Buffer has incomplete start code (not enough bytes) + let mut data = Bytes::from(vec![0, 0]); + let iter = NalIterator::new(&mut data); + + let result = iter.flush().unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_flush_multiple_nals_then_flush() { + // Iterate over multiple NALs, then flush the final one + let mut data = Bytes::from(vec![ + 0, 0, 0, 1, 0x67, 0x42, // SPS + 0, 0, 0, 1, 0x68, 0xce, // PPS + 0, 0, 0, 1, 0x65, 0x88, 0x84, // IDR (final NAL) + ]); + let mut iter = NalIterator::new(&mut data); + + let sps = iter.next().unwrap().unwrap(); + assert_eq!(sps.as_ref(), &[0x67, 0x42]); + + let pps = iter.next().unwrap().unwrap(); + assert_eq!(pps.as_ref(), &[0x68, 0xce]); + + assert!(iter.next().is_none()); + + let idr = iter.flush().unwrap().unwrap(); + assert_eq!(idr.as_ref(), &[0x65, 0x88, 0x84]); + } + + #[test] + fn test_flush_empty_final_nal() { + // Edge case: final NAL is empty (just a start code with no data) + let mut data = Bytes::from(vec![ + 0, 0, 0, 1, 0x67, 0x42, // First NAL + 0, 0, 0, 1, // Second NAL (empty) + ]); + let mut iter = NalIterator::new(&mut data); + + let nal1 = iter.next().unwrap().unwrap(); + assert_eq!(nal1.as_ref(), &[0x67, 0x42]); + + assert!(iter.next().is_none()); + + let final_nal = iter.flush().unwrap().unwrap(); + assert_eq!(final_nal.len(), 0); + } +} diff --git a/rs/hang/src/import/avc3.rs b/rs/hang/src/import/avc3.rs index 9027473ec..ec8d472ba 100644 --- a/rs/hang/src/import/avc3.rs +++ b/rs/hang/src/import/avc3.rs @@ -1,13 +1,11 @@ use crate as hang; +use crate::import::annexb::{NalIterator, START_CODE}; + use anyhow::Context; use buf_list::BufList; use bytes::{Buf, Bytes}; use moq_lite as moq; -// Prepend each NAL with a 4 byte start code. -// Yes, it's one byte longer than the 3 byte start code, but it's easier to convert to MP4. -const START_CODE: Bytes = Bytes::from_static(&[0, 0, 0, 1]); - /// A decoder for H.264 with inline SPS/PPS. pub struct Avc3 { // The broadcast being produced. @@ -288,506 +286,9 @@ pub enum NalType { DepthParameterSet = 16, } -struct NalIterator<'a, T: Buf + AsRef<[u8]> + 'a> { - buf: &'a mut T, - start: Option, -} - -impl<'a, T: Buf + AsRef<[u8]> + 'a> NalIterator<'a, T> { - pub fn new(buf: &'a mut T) -> Self { - Self { buf, start: None } - } - - /// Assume the buffer ends with a NAL unit and flush it. - /// This is more efficient because we cache the last "start" code position. - pub fn flush(self) -> anyhow::Result> { - let start = match self.start { - Some(start) => start, - None => match after_start_code(self.buf.as_ref())? { - Some(start) => start, - None => return Ok(None), - }, - }; - - self.buf.advance(start); - - let nal = self.buf.copy_to_bytes(self.buf.remaining()); - Ok(Some(nal)) - } -} - -impl<'a, T: Buf + AsRef<[u8]> + 'a> Iterator for NalIterator<'a, T> { - type Item = anyhow::Result; - - fn next(&mut self) -> Option { - let start = match self.start { - Some(start) => start, - None => match after_start_code(self.buf.as_ref()).transpose()? { - Ok(start) => start, - Err(err) => return Some(Err(err)), - }, - }; - - let (size, new_start) = find_start_code(&self.buf.as_ref()[start..])?; - self.buf.advance(start); - - let nal = self.buf.copy_to_bytes(size); - self.start = Some(new_start); - Some(Ok(nal)) - } -} - -// Return the size of the start code at the start of the buffer. -fn after_start_code(b: &[u8]) -> anyhow::Result> { - if b.len() < 3 { - return Ok(None); - } - - // NOTE: We have to check every byte, so the `find_start_code` optimization doesn't matter. - anyhow::ensure!(b[0] == 0, "missing Annex B start code"); - anyhow::ensure!(b[1] == 0, "missing Annex B start code"); - - match b[2] { - 0 if b.len() < 4 => Ok(None), - 0 if b[3] != 1 => anyhow::bail!("missing Annex B start code"), - 0 => Ok(Some(4)), - 1 => Ok(Some(3)), - _ => anyhow::bail!("invalid Annex B start code"), - } -} - -// Return the number of bytes until the next start code, and the size of that start code. -fn find_start_code(mut b: &[u8]) -> Option<(usize, usize)> { - // Okay this is over-engineered because this was my interview question. - // We need to find either a 3 byte or 4 byte start code. - // 3-byte: 0 0 1 - // 4-byte: 0 0 0 1 - // - // You fail the interview if you call string.split twice or something. - // You get a pass if you do index += 1 and check the next 3-4 bytes. - // You get my eternal respect if you check the 3rd byte first. - // What? - // - // If we check the 3rd byte and it's not a 0 or 1, then we immediately index += 3 - // Sometimes we might only skip 1 or 2 bytes, but it's still better than checking every byte. - // - // TODO Is this the type of thing that SIMD could further improve? - // If somebody can figure that out, I'll buy you a beer. - let size = b.len(); - - while b.len() >= 3 { - // ? ? ? - match b[2] { - // ? ? 0 - 0 if b.len() >= 4 => match b[3] { - // ? ? 0 1 - 1 => match b[1] { - // ? 0 0 1 - 0 => match b[0] { - // 0 0 0 1 - 0 => return Some((size - b.len(), 4)), - // ? 0 0 1 - _ => return Some((size - b.len() + 1, 3)), - }, - // ? x 0 1 - _ => b = &b[4..], - }, - // ? ? 0 0 - skip only 1 byte to check for potential 0 0 0 1 - 0 => b = &b[1..], - // ? ? 0 x - _ => b = &b[4..], - }, - // ? ? 0 FIN - 0 => return None, - // ? ? 1 - 1 => match b[1] { - // ? 0 1 - 0 => match b[0] { - // 0 0 1 - 0 => return Some((size - b.len(), 3)), - // ? 0 1 - _ => b = &b[3..], - }, - // ? x 1 - _ => b = &b[3..], - }, - // ? ? x - _ => b = &b[3..], - } - } - - None -} - #[derive(Default)] struct Frame { chunks: BufList, contains_idr: bool, contains_slice: bool, } - -#[cfg(test)] -mod tests { - use super::*; - - // Tests for after_start_code - validates and measures start code at buffer beginning - - #[test] - fn test_after_start_code_3_byte() { - let buf = &[0, 0, 1, 0x67]; - assert_eq!(after_start_code(buf).unwrap(), Some(3)); - } - - #[test] - fn test_after_start_code_4_byte() { - let buf = &[0, 0, 0, 1, 0x67]; - assert_eq!(after_start_code(buf).unwrap(), Some(4)); - } - - #[test] - fn test_after_start_code_too_short() { - let buf = &[0, 0]; - assert_eq!(after_start_code(buf).unwrap(), None); - } - - #[test] - fn test_after_start_code_incomplete_4_byte() { - let buf = &[0, 0, 0]; - assert_eq!(after_start_code(buf).unwrap(), None); - } - - #[test] - fn test_after_start_code_invalid_first_byte() { - let buf = &[1, 0, 1]; - assert!(after_start_code(buf).is_err()); - } - - #[test] - fn test_after_start_code_invalid_second_byte() { - let buf = &[0, 1, 1]; - assert!(after_start_code(buf).is_err()); - } - - #[test] - fn test_after_start_code_invalid_third_byte() { - let buf = &[0, 0, 2]; - assert!(after_start_code(buf).is_err()); - } - - #[test] - fn test_after_start_code_invalid_4_byte_pattern() { - let buf = &[0, 0, 0, 2]; - assert!(after_start_code(buf).is_err()); - } - - // Tests for find_start_code - finds next start code in NAL data - - #[test] - fn test_find_start_code_3_byte() { - let buf = &[0x67, 0x42, 0x00, 0x1f, 0, 0, 1]; - assert_eq!(find_start_code(buf), Some((4, 3))); - } - - #[test] - fn test_find_start_code_4_byte() { - // Should detect 4-byte start code at beginning - let buf = &[0, 0, 0, 1, 0x67]; - assert_eq!(find_start_code(buf), Some((0, 4))); - } - - #[test] - fn test_find_start_code_4_byte_after_data() { - // Should detect 4-byte start code after NAL data - let buf = &[0x67, 0x42, 0xff, 0x1f, 0, 0, 0, 1]; - assert_eq!(find_start_code(buf), Some((4, 4))); - } - - #[test] - fn test_find_start_code_at_start_3_byte() { - let buf = &[0, 0, 1, 0x67]; - assert_eq!(find_start_code(buf), Some((0, 3))); - } - - #[test] - fn test_find_start_code_none() { - let buf = &[0x67, 0x42, 0x00, 0x1f, 0xff]; - assert_eq!(find_start_code(buf), None); - } - - #[test] - fn test_find_start_code_trailing_zeros() { - let buf = &[0x67, 0x42, 0x00, 0x1f, 0, 0]; - assert_eq!(find_start_code(buf), None); - } - - #[test] - fn test_find_start_code_edge_case_3_byte() { - let buf = &[0xff, 0, 0, 1]; - assert_eq!(find_start_code(buf), Some((1, 3))); - } - - #[test] - fn test_find_start_code_false_positive_avoidance() { - // Pattern like: x 0 0 y (where y != 1) - should skip ahead - let buf = &[0xff, 0, 0, 0xff, 0, 0, 1]; - assert_eq!(find_start_code(buf), Some((4, 3))); - } - - #[test] - fn test_find_start_code_4_byte_after_nonzero() { - // Critical edge case: x 0 0 0 1 should find 4-byte start code at position 1 - // This tests that we only skip 1 byte when seeing ? ? 0 0 - let buf = &[0xff, 0, 0, 0, 1]; - assert_eq!(find_start_code(buf), Some((1, 4))); - } - - #[test] - fn test_find_start_code_consecutive_zeros() { - // Multiple consecutive zeros before the 1 - let buf = &[0xff, 0, 0, 0, 0, 0, 1]; - // Should skip past leading zeros and find the start code - let result = find_start_code(buf); - assert!(result.is_some()); - let (pos, size) = result.unwrap(); - // The exact position depends on the algorithm, but it should find a valid start code - assert!(size == 3 || size == 4); - assert!(pos < buf.len()); - } - - // Tests for NalIterator - iterates over NAL units in Annex B format - - #[test] - fn test_nal_iterator_simple_3_byte() { - let mut data = Bytes::from(vec![0, 0, 1, 0x67, 0x42, 0, 0, 1]); - let mut iter = NalIterator::new(&mut data); - - let nal = iter.next().unwrap().unwrap(); - assert_eq!(nal.as_ref(), &[0x67, 0x42]); - assert!(iter.next().is_none()); - - // Make sure the trailing 001 is still in the buffer. - assert_eq!(data.as_ref(), &[0, 0, 1]); - } - - #[test] - fn test_nal_iterator_simple_4_byte() { - let mut data = Bytes::from(vec![0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1]); - let mut iter = NalIterator::new(&mut data); - - let nal = iter.next().unwrap().unwrap(); - assert_eq!(nal.as_ref(), &[0x67, 0x42]); - assert!(iter.next().is_none()); - - // Make sure the trailing 0001 is still in the buffer. - assert_eq!(data.as_ref(), &[0, 0, 0, 1]); - } - - #[test] - fn test_nal_iterator_multiple_nals() { - let mut data = Bytes::from(vec![0, 0, 0, 1, 0x67, 0x42, 0, 0, 0, 1, 0x68, 0xce, 0, 0, 0, 1]); - let mut iter = NalIterator::new(&mut data); - - let nal1 = iter.next().unwrap().unwrap(); - assert_eq!(nal1.as_ref(), &[0x67, 0x42]); - - let nal2 = iter.next().unwrap().unwrap(); - assert_eq!(nal2.as_ref(), &[0x68, 0xce]); - - assert!(iter.next().is_none()); - - // Make sure the trailing 0001 is still in the buffer. - assert_eq!(data.as_ref(), &[0, 0, 0, 1]); - } - - #[test] - fn test_nal_iterator_realistic_h264() { - // A realistic H.264 stream with SPS, PPS, and IDR - let mut data = Bytes::from(vec![ - // SPS NAL - 0, 0, 0, 1, 0x67, 0x42, 0x00, 0x1f, // PPS NAL - 0, 0, 0, 1, 0x68, 0xce, 0x3c, 0x80, // IDR slice - 0, 0, 0, 1, 0x65, 0x88, 0x84, 0x00, - // Trailing start code (needed to detect the end of the last NAL) - 0, 0, 0, 1, - ]); - let mut iter = NalIterator::new(&mut data); - - let sps = iter.next().unwrap().unwrap(); - assert_eq!(sps[0] & 0x1f, 7); // SPS type - assert_eq!(sps.as_ref(), &[0x67, 0x42, 0x00, 0x1f]); - - let pps = iter.next().unwrap().unwrap(); - assert_eq!(pps[0] & 0x1f, 8); // PPS type - assert_eq!(pps.as_ref(), &[0x68, 0xce, 0x3c, 0x80]); - - let idr = iter.next().unwrap().unwrap(); - assert_eq!(idr[0] & 0x1f, 5); // IDR type - assert_eq!(idr.as_ref(), &[0x65, 0x88, 0x84, 0x00]); - - assert!(iter.next().is_none()); - - // Make sure the trailing 0001 is still in the buffer. - assert_eq!(data.as_ref(), &[0, 0, 0, 1]); - } - - #[test] - fn test_nal_iterator_invalid_start() { - let mut data = Bytes::from(vec![1, 0, 1, 0x67]); - let mut iter = NalIterator::new(&mut data); - - assert!(iter.next().unwrap().is_err()); - - // Make sure the data is still in the buffer. - assert_eq!(data.as_ref(), &[1, 0, 1, 0x67]); - } - - #[test] - fn test_nal_iterator_empty_nal() { - // Two consecutive start codes create an empty NAL - let mut data = Bytes::from(vec![0, 0, 1, 0, 0, 1, 0x67, 0, 0, 1]); - let mut iter = NalIterator::new(&mut data); - - let nal1 = iter.next().unwrap().unwrap(); - assert_eq!(nal1.len(), 0); - - let nal2 = iter.next().unwrap().unwrap(); - assert_eq!(nal2.as_ref(), &[0x67]); - - assert!(iter.next().is_none()); - - // Make sure the data is still in the buffer. - assert_eq!(data.as_ref(), &[0, 0, 1]); - } - - #[test] - fn test_nal_iterator_nal_with_embedded_zeros() { - // NAL data that contains zeros (but not a start code pattern) - let mut data = Bytes::from(vec![ - 0, 0, 1, 0x67, 0x00, 0x00, 0x00, 0xff, // NAL with embedded zeros - 0, 0, 1, 0x68, // Next NAL - 0, 0, 1, - ]); - let mut iter = NalIterator::new(&mut data); - - let nal1 = iter.next().unwrap().unwrap(); - assert_eq!(nal1.as_ref(), &[0x67, 0x00, 0x00, 0x00, 0xff]); - - let nal2 = iter.next().unwrap().unwrap(); - assert_eq!(nal2.as_ref(), &[0x68]); - - assert!(iter.next().is_none()); - - // Make sure the data is still in the buffer. - assert_eq!(data.as_ref(), &[0, 0, 1]); - } - - // Tests for flush - extracts final NAL without trailing start code - - #[test] - fn test_flush_after_iteration() { - // Normal case: iterate over NALs, then flush the final one - let mut data = Bytes::from(vec![ - 0, 0, 0, 1, 0x67, 0x42, // First NAL - 0, 0, 0, 1, 0x68, 0xce, 0x3c, 0x80, // Second NAL (final, no trailing start code) - ]); - let mut iter = NalIterator::new(&mut data); - - let nal1 = iter.next().unwrap().unwrap(); - assert_eq!(nal1.as_ref(), &[0x67, 0x42]); - - assert!(iter.next().is_none()); - - let final_nal = iter.flush().unwrap().unwrap(); - assert_eq!(final_nal.as_ref(), &[0x68, 0xce, 0x3c, 0x80]); - } - - #[test] - fn test_flush_single_nal() { - // Buffer contains only a single NAL with no trailing start code - let mut data = Bytes::from(vec![0, 0, 1, 0x67, 0x42, 0x00, 0x1f]); - let iter = NalIterator::new(&mut data); - - let final_nal = iter.flush().unwrap().unwrap(); - assert_eq!(final_nal.as_ref(), &[0x67, 0x42, 0x00, 0x1f]); - } - - #[test] - fn test_flush_4_byte_start_code() { - // Test flush with 4-byte start code - let mut data = Bytes::from(vec![0, 0, 0, 1, 0x65, 0x88, 0x84, 0x00, 0xff]); - let iter = NalIterator::new(&mut data); - - let final_nal = iter.flush().unwrap().unwrap(); - assert_eq!(final_nal.as_ref(), &[0x65, 0x88, 0x84, 0x00, 0xff]); - } - - #[test] - fn test_flush_no_start_code() { - // Buffer doesn't start with a start code and has no cached start position - let mut data = Bytes::from(vec![0x67, 0x42, 0x00, 0x1f]); - let iter = NalIterator::new(&mut data); - - let result = iter.flush(); - assert!(result.is_err()); - } - - #[test] - fn test_flush_empty_buffer() { - // Empty buffer should return None - let mut data = Bytes::from(vec![]); - let iter = NalIterator::new(&mut data); - - let result = iter.flush().unwrap(); - assert!(result.is_none()); - } - - #[test] - fn test_flush_incomplete_start_code() { - // Buffer has incomplete start code (not enough bytes) - let mut data = Bytes::from(vec![0, 0]); - let iter = NalIterator::new(&mut data); - - let result = iter.flush().unwrap(); - assert!(result.is_none()); - } - - #[test] - fn test_flush_multiple_nals_then_flush() { - // Iterate over multiple NALs, then flush the final one - let mut data = Bytes::from(vec![ - 0, 0, 0, 1, 0x67, 0x42, // SPS - 0, 0, 0, 1, 0x68, 0xce, // PPS - 0, 0, 0, 1, 0x65, 0x88, 0x84, // IDR (final NAL) - ]); - let mut iter = NalIterator::new(&mut data); - - let sps = iter.next().unwrap().unwrap(); - assert_eq!(sps.as_ref(), &[0x67, 0x42]); - - let pps = iter.next().unwrap().unwrap(); - assert_eq!(pps.as_ref(), &[0x68, 0xce]); - - assert!(iter.next().is_none()); - - let idr = iter.flush().unwrap().unwrap(); - assert_eq!(idr.as_ref(), &[0x65, 0x88, 0x84]); - } - - #[test] - fn test_flush_empty_final_nal() { - // Edge case: final NAL is empty (just a start code with no data) - let mut data = Bytes::from(vec![ - 0, 0, 0, 1, 0x67, 0x42, // First NAL - 0, 0, 0, 1, // Second NAL (empty) - ]); - let mut iter = NalIterator::new(&mut data); - - let nal1 = iter.next().unwrap().unwrap(); - assert_eq!(nal1.as_ref(), &[0x67, 0x42]); - - assert!(iter.next().is_none()); - - let final_nal = iter.flush().unwrap().unwrap(); - assert_eq!(final_nal.len(), 0); - } -} diff --git a/rs/hang/src/import/decoder.rs b/rs/hang/src/import/decoder.rs index 7bc8eb670..7e6c25d12 100644 --- a/rs/hang/src/import/decoder.rs +++ b/rs/hang/src/import/decoder.rs @@ -2,7 +2,7 @@ use std::{fmt, str::FromStr}; use bytes::Buf; -use crate::{self as hang, import::Aac, import::Opus, Error}; +use crate::{self as hang, import::Aac, import::Hev1, import::Opus, Error}; use super::{Avc3, Fmp4}; @@ -12,6 +12,8 @@ pub enum DecoderFormat { Avc3, /// fMP4/CMAF container. Fmp4, + /// aka H265 with inline SPS/PPS + Hev1, /// Raw AAC frames (not ADTS). Aac, /// Raw Opus frames (not Ogg). @@ -28,6 +30,7 @@ impl FromStr for DecoderFormat { tracing::warn!("format '{s}' is deprecated, use 'avc3' instead"); Ok(DecoderFormat::Avc3) } + "hev1" => Ok(DecoderFormat::Hev1), "fmp4" | "cmaf" => Ok(DecoderFormat::Fmp4), "aac" => Ok(DecoderFormat::Aac), "opus" => Ok(DecoderFormat::Opus), @@ -41,6 +44,7 @@ impl fmt::Display for DecoderFormat { match self { DecoderFormat::Avc3 => write!(f, "avc3"), DecoderFormat::Fmp4 => write!(f, "fmp4"), + DecoderFormat::Hev1 => write!(f, "hev1"), DecoderFormat::Aac => write!(f, "aac"), DecoderFormat::Opus => write!(f, "opus"), } @@ -53,6 +57,8 @@ enum DecoderKind { Avc3(Avc3), // Boxed because it's a large struct and clippy complains about the size. Fmp4(Box), + /// aka H265 with inline SPS/PPS + Hev1(Hev1), Aac(Aac), Opus(Opus), } @@ -71,6 +77,7 @@ impl Decoder { let decoder = match format { DecoderFormat::Avc3 => Avc3::new(broadcast).into(), DecoderFormat::Fmp4 => Box::new(Fmp4::new(broadcast)).into(), + DecoderFormat::Hev1 => Hev1::new(broadcast).into(), DecoderFormat::Aac => Aac::new(broadcast).into(), DecoderFormat::Opus => Opus::new(broadcast).into(), }; @@ -88,6 +95,7 @@ impl Decoder { match &mut self.decoder { DecoderKind::Avc3(decoder) => decoder.initialize(buf)?, DecoderKind::Fmp4(decoder) => decoder.decode(buf)?, + DecoderKind::Hev1(decoder) => decoder.initialize(buf)?, DecoderKind::Aac(decoder) => decoder.initialize(buf)?, DecoderKind::Opus(decoder) => decoder.initialize(buf)?, } @@ -114,6 +122,7 @@ impl Decoder { match &mut self.decoder { DecoderKind::Avc3(decoder) => decoder.decode_stream(buf, None)?, DecoderKind::Fmp4(decoder) => decoder.decode(buf)?, + DecoderKind::Hev1(decoder) => decoder.decode_stream(buf, None)?, // TODO Fix or make these more type safe. DecoderKind::Aac(_) => anyhow::bail!("AAC does not support stream decoding"), DecoderKind::Opus(_) => anyhow::bail!("Opus does not support stream decoding"), @@ -140,6 +149,7 @@ impl Decoder { match &mut self.decoder { DecoderKind::Avc3(decoder) => decoder.decode_frame(buf, pts)?, DecoderKind::Fmp4(decoder) => decoder.decode(buf)?, + DecoderKind::Hev1(decoder) => decoder.decode_frame(buf, pts)?, DecoderKind::Aac(decoder) => decoder.decode(buf, pts)?, DecoderKind::Opus(decoder) => decoder.decode(buf, pts)?, } @@ -152,6 +162,7 @@ impl Decoder { match &self.decoder { DecoderKind::Avc3(decoder) => decoder.is_initialized(), DecoderKind::Fmp4(decoder) => decoder.is_initialized(), + DecoderKind::Hev1(decoder) => decoder.is_initialized(), DecoderKind::Aac(decoder) => decoder.is_initialized(), DecoderKind::Opus(decoder) => decoder.is_initialized(), } diff --git a/rs/hang/src/import/hev1.rs b/rs/hang/src/import/hev1.rs new file mode 100644 index 000000000..57d3ff111 --- /dev/null +++ b/rs/hang/src/import/hev1.rs @@ -0,0 +1,357 @@ +use crate as hang; +use crate::import::annexb::{NalIterator, START_CODE}; + +use anyhow::Context; +use buf_list::BufList; +use bytes::{Buf, Bytes}; +use moq_lite as moq; +use scuffle_h265::{NALUnitType, SpsNALUnit}; + +/// A decoder for H.265 with inline SPS/PPS. +/// Only supports single layer streams, ignores VPS. +pub struct Hev1 { + // The broadcast being produced. + // This `hang` variant includes a catalog. + broadcast: hang::BroadcastProducer, + + // The track being produced. + track: Option, + + // Whether the track has been initialized. + // If it changes, then we'll reinitialize with a new track. + config: Option, + + // The current frame being built. + current: Frame, + + // Used to compute wall clock timestamps if needed. + zero: Option, +} + +impl Hev1 { + pub fn new(broadcast: hang::BroadcastProducer) -> Self { + Self { + broadcast, + track: None, + config: None, + current: Default::default(), + zero: None, + } + } + + fn init(&mut self, sps: &SpsNALUnit) -> anyhow::Result<()> { + let profile = &sps.rbsp.profile_tier_level.general_profile; + let vui_data = sps.rbsp.vui_parameters.as_ref().map(VuiData::new).unwrap_or_default(); + + let config = hang::catalog::VideoConfig { + coded_width: Some(sps.rbsp.cropped_width() as u32), + coded_height: Some(sps.rbsp.cropped_height() as u32), + codec: hang::catalog::H265 { + in_band: true, // We only support `hev1` with inline SPS/PPS for now + profile_space: profile.profile_space, + profile_idc: profile.profile_idc, + profile_compatibility_flags: profile.profile_compatibility_flag.bits().to_be_bytes(), + tier_flag: profile.tier_flag, + level_idc: profile.level_idc.context("missing level_idc in SPS")?, + constraint_flags: pack_constraint_flags(profile), + } + .into(), + description: None, + framerate: vui_data.framerate, + bitrate: None, + display_ratio_width: vui_data.display_ratio_width, + display_ratio_height: vui_data.display_ratio_height, + optimize_for_latency: None, + }; + + if let Some(old) = &self.config { + if old == &config { + return Ok(()); + } + } + + if let Some(track) = &self.track.take() { + tracing::debug!(name = ?track.info.name, "reinitializing track"); + self.broadcast.catalog.lock().remove_video(&track.info.name); + } + + let track = moq::Track { + name: self.broadcast.track_name("video"), + priority: 2, + }; + + tracing::debug!(name = ?track.name, ?config, "starting track"); + + { + let mut catalog = self.broadcast.catalog.lock(); + let video = catalog.insert_video(track.name.clone(), config.clone()); + video.priority = 2; + } + + let track = track.produce(); + self.broadcast.insert_track(track.consumer); + + self.config = Some(config); + self.track = Some(track.producer.into()); + + Ok(()) + } + + /// Initialize the decoder with SPS/PPS and other non-slice NALs. + pub fn initialize>(&mut self, buf: &mut T) -> anyhow::Result<()> { + let mut nals = NalIterator::new(buf); + + while let Some(nal) = nals.next().transpose()? { + self.decode_nal(nal, None)?; + } + + if let Some(nal) = nals.flush()? { + self.decode_nal(nal, None)?; + } + + Ok(()) + } + + /// Decode as much data as possible from the given buffer. + /// + /// Unlike [Self::decode_frame], this method needs the start code for the next frame. + /// This means it works for streaming media (ex. stdin) but adds a frame of latency. + /// + /// TODO: This currently associates PTS with the *previous* frame, as part of `maybe_start_frame`. + pub fn decode_stream>( + &mut self, + buf: &mut T, + pts: Option, + ) -> anyhow::Result<()> { + let pts = self.pts(pts)?; + + // Iterate over the NAL units in the buffer based on start codes. + let nals = NalIterator::new(buf); + + for nal in nals { + self.decode_nal(nal?, Some(pts))?; + } + + Ok(()) + } + + /// Decode all data in the buffer, assuming the buffer contains (the rest of) a frame. + /// + /// Unlike [Self::decode_stream], this is called when we know NAL boundaries. + /// This can avoid a frame of latency just waiting for the next frame's start code. + /// This can also be used when EOF is detected to flush the final frame. + /// + /// NOTE: The next decode will fail if it doesn't begin with a start code. + pub fn decode_frame>( + &mut self, + buf: &mut T, + pts: Option, + ) -> anyhow::Result<()> { + let pts = self.pts(pts)?; + // Iterate over the NAL units in the buffer based on start codes. + let mut nals = NalIterator::new(buf); + + // Iterate over each NAL that is followed by a start code. + while let Some(nal) = nals.next().transpose()? { + self.decode_nal(nal, Some(pts))?; + } + + // Assume the rest of the buffer is a single NAL. + if let Some(nal) = nals.flush()? { + self.decode_nal(nal, Some(pts))?; + } + + // Flush the frame if we read a slice. + self.maybe_start_frame(Some(pts))?; + + Ok(()) + } + + /// Decode a single NAL unit. Only reads the first header byte to extract nal_unit_type, + /// Ignores nuh_layer_id and nuh_temporal_id_plus1. + fn decode_nal(&mut self, nal: Bytes, pts: Option) -> anyhow::Result<()> { + anyhow::ensure!(nal.len() >= 2, "NAL unit is too short"); + // u16 header: [forbidden_zero_bit(1) | nal_unit_type(6) | nuh_layer_id(6) | nuh_temporal_id_plus1(3)] + let header = nal.first().context("NAL unit is too short")?; + + let forbidden_zero_bit = (header >> 7) & 1; + anyhow::ensure!(forbidden_zero_bit == 0, "forbidden zero bit is not zero"); + + // Bits 1-6: nal_unit_type + let nal_unit_type = (header >> 1) & 0b111111; + let nal_type = NALUnitType::from(nal_unit_type); + + match nal_type { + NALUnitType::SpsNut => { + self.maybe_start_frame(pts)?; + + // Try to reinitialize the track if the SPS has changed. + let sps = SpsNALUnit::parse(&mut &nal[..]).context("failed to parse SPS NAL unit")?; + self.init(&sps)?; + } + // TODO parse the SPS again and reinitialize the track if needed + NALUnitType::AudNut | NALUnitType::PpsNut | NALUnitType::PrefixSeiNut | NALUnitType::SuffixSeiNut => { + self.maybe_start_frame(pts)?; + } + // Keyframe containing slices + NALUnitType::IdrWRadl + | NALUnitType::IdrNLp + | NALUnitType::BlaNLp + | NALUnitType::BlaWRadl + | NALUnitType::BlaWLp + | NALUnitType::CraNut => { + self.current.contains_idr = true; + self.current.contains_slice = true; + } + // All other slice types (both N and R variants) + NALUnitType::TrailN + | NALUnitType::TrailR + | NALUnitType::TsaN + | NALUnitType::TsaR + | NALUnitType::StsaN + | NALUnitType::StsaR + | NALUnitType::RadlN + | NALUnitType::RadlR + | NALUnitType::RaslN + | NALUnitType::RaslR => { + // Check first_slice_segment_in_pic_flag (bit 7 of third byte, after 2-byte header) + if nal.get(2).context("NAL unit is too short")? & 0x80 != 0 { + self.maybe_start_frame(pts)?; + } + self.current.contains_slice = true; + } + _ => {} + } + + // Rather than keeping the original size of the start code, we replace it with a 4 byte start code. + // It's just marginally easier and potentially more efficient down the line (JS player with MSE). + // NOTE: This is ref-counted and static, so it's extremely cheap to clone. + self.current.chunks.push_chunk(START_CODE.clone()); + self.current.chunks.push_chunk(nal); + + Ok(()) + } + + fn maybe_start_frame(&mut self, pts: Option) -> anyhow::Result<()> { + // If we haven't seen any slices, we shouldn't flush yet. + if !self.current.contains_slice { + return Ok(()); + } + + let track = self.track.as_mut().context("expected SPS before any frames")?; + let pts = pts.context("missing timestamp")?; + + let payload = std::mem::take(&mut self.current.chunks); + let frame = hang::Frame { + timestamp: pts, + keyframe: self.current.contains_idr, + payload, + }; + + track.write(frame)?; + + self.current.contains_idr = false; + self.current.contains_slice = false; + + Ok(()) + } + + pub fn is_initialized(&self) -> bool { + self.track.is_some() + } + + fn pts(&mut self, hint: Option) -> anyhow::Result { + if let Some(pts) = hint { + return Ok(pts); + } + + let zero = self.zero.get_or_insert_with(tokio::time::Instant::now); + Ok(hang::Timestamp::from_micros(zero.elapsed().as_micros() as u64)?) + } +} + +impl Drop for Hev1 { + fn drop(&mut self) { + if let Some(track) = &self.track { + tracing::debug!(name = ?track.info.name, "ending track"); + self.broadcast.catalog.lock().remove_video(&track.info.name); + } + } +} + +// Packs the constraint flags from ITU H.265 V10 Section 7.3.3 Profile, tier and level syntax +fn pack_constraint_flags(profile: &scuffle_h265::Profile) -> [u8; 6] { + let mut flags = [0u8; 6]; + flags[0] = ((profile.progressive_source_flag as u8) << 7) + | ((profile.interlaced_source_flag as u8) << 6) + | ((profile.non_packed_constraint_flag as u8) << 5) + | ((profile.frame_only_constraint_flag as u8) << 4); + + // @todo: pack the rest of the optional flags in profile.additional_flags + flags +} + +#[derive(Default)] +struct Frame { + chunks: BufList, + contains_idr: bool, + contains_slice: bool, +} + +#[derive(Default)] +struct VuiData { + framerate: Option, + display_ratio_width: Option, + display_ratio_height: Option, +} + +impl VuiData { + fn new(vui: &scuffle_h265::VuiParameters) -> Self { + // FPS = time_scale / num_units_in_tick + let framerate = vui + .vui_timing_info + .as_ref() + .map(|t| t.time_scale.get() as f64 / t.num_units_in_tick.get() as f64); + + let (display_ratio_width, display_ratio_height) = match &vui.aspect_ratio_info { + // Extended SAR has explicit arbitrary values for width and height. + scuffle_h265::AspectRatioInfo::ExtendedSar { sar_width, sar_height } => { + (Some(*sar_width as u32), Some(*sar_height as u32)) + } + // Predefined map to known values. + scuffle_h265::AspectRatioInfo::Predefined(idc) => aspect_ratio_from_idc(*idc) + .map(|(w, h)| (Some(w), Some(h))) + .unwrap_or((None, None)), + }; + + VuiData { + framerate, + display_ratio_width, + display_ratio_height, + } + } +} + +fn aspect_ratio_from_idc(idc: scuffle_h265::AspectRatioIdc) -> Option<(u32, u32)> { + match idc { + scuffle_h265::AspectRatioIdc::Unspecified => None, + scuffle_h265::AspectRatioIdc::Square => Some((1, 1)), + scuffle_h265::AspectRatioIdc::Aspect12_11 => Some((12, 11)), + scuffle_h265::AspectRatioIdc::Aspect10_11 => Some((10, 11)), + scuffle_h265::AspectRatioIdc::Aspect16_11 => Some((16, 11)), + scuffle_h265::AspectRatioIdc::Aspect40_33 => Some((40, 33)), + scuffle_h265::AspectRatioIdc::Aspect24_11 => Some((24, 11)), + scuffle_h265::AspectRatioIdc::Aspect20_11 => Some((20, 11)), + scuffle_h265::AspectRatioIdc::Aspect32_11 => Some((32, 11)), + scuffle_h265::AspectRatioIdc::Aspect80_33 => Some((80, 33)), + scuffle_h265::AspectRatioIdc::Aspect18_11 => Some((18, 11)), + scuffle_h265::AspectRatioIdc::Aspect15_11 => Some((15, 11)), + scuffle_h265::AspectRatioIdc::Aspect64_33 => Some((64, 33)), + scuffle_h265::AspectRatioIdc::Aspect160_99 => Some((160, 99)), + scuffle_h265::AspectRatioIdc::Aspect4_3 => Some((4, 3)), + scuffle_h265::AspectRatioIdc::Aspect3_2 => Some((3, 2)), + scuffle_h265::AspectRatioIdc::Aspect2_1 => Some((2, 1)), + scuffle_h265::AspectRatioIdc::ExtendedSar => None, + _ => None, // Reserved + } +} diff --git a/rs/hang/src/import/mod.rs b/rs/hang/src/import/mod.rs index abc45c9b6..26e44ca55 100644 --- a/rs/hang/src/import/mod.rs +++ b/rs/hang/src/import/mod.rs @@ -1,7 +1,9 @@ mod aac; +mod annexb; mod avc3; mod decoder; mod fmp4; +mod hev1; mod hls; mod opus; @@ -9,5 +11,6 @@ pub use aac::*; pub use avc3::*; pub use decoder::*; pub use fmp4::*; +pub use hev1::*; pub use hls::*; pub use opus::*;