diff --git a/smelter-core/src/pipeline/decoder/dynamic_stream.rs b/smelter-core/src/pipeline/decoder/dynamic_stream.rs index 7ec0f5afb..d0d9754be 100644 --- a/smelter-core/src/pipeline/decoder/dynamic_stream.rs +++ b/smelter-core/src/pipeline/decoder/dynamic_stream.rs @@ -1,5 +1,6 @@ use std::{iter, sync::Arc}; +use crossbeam_channel::Sender; use smelter_render::{Frame, error::ErrorStack}; use tracing::error; @@ -20,6 +21,7 @@ where source: Source, eos_sent: bool, decoders_info: VideoDecoderMapping, + keyframe_request_sender: Sender<()>, } impl DynamicVideoDecoderStream @@ -30,6 +32,7 @@ where ctx: Arc, decoders_info: VideoDecoderMapping, source: Source, + keyframe_request_sender: Sender<()>, ) -> Self { Self { ctx, @@ -38,6 +41,7 @@ where source, eos_sent: false, decoders_info, + keyframe_request_sender, } } @@ -99,6 +103,9 @@ where self.ensure_decoder(samples.kind); let decoder = self.decoder.as_mut()?; let chunks = decoder.decode(samples); + if chunks.len() == 0 { + let _ = self.keyframe_request_sender.try_send(()); + } Some(chunks.into_iter().map(PipelineEvent::Data).collect()) } Some(PipelineEvent::EOS) | None => match self.eos_sent { diff --git a/smelter-core/src/pipeline/webrtc/video_input_processing_loop.rs b/smelter-core/src/pipeline/webrtc/video_input_processing_loop.rs index d73abc6a1..96ad6b6d3 100644 --- a/smelter-core/src/pipeline/webrtc/video_input_processing_loop.rs +++ b/smelter-core/src/pipeline/webrtc/video_input_processing_loop.rs @@ -71,13 +71,15 @@ impl InitializableThread for VideoTrackThread { VideoDecoderMapping, VideoPayloadTypeMapping, Sender>, + Sender<()>, ); type SpawnOutput = VideoTrackThreadHandle; type SpawnError = DecoderInitError; fn init(options: Self::InitOptions) -> Result<(Self, Self::SpawnOutput), Self::SpawnError> { - let (ctx, decoder_mapping, payload_type_mapping, frame_sender) = options; + let (ctx, decoder_mapping, payload_type_mapping, frame_sender, keyframe_request_sender) = + options; let (rtp_packet_sender, rtp_packet_receiver) = tokio::sync::mpsc::channel(5000); let packet_stream = AsyncReceiverIter { @@ -87,8 +89,13 @@ impl InitializableThread for VideoTrackThread { let depayloader_stream = DynamicDepayloaderStream::new(payload_type_mapping, packet_stream).flatten(); - let decoder_stream = - DynamicVideoDecoderStream::new(ctx, decoder_mapping, depayloader_stream).flatten(); + let decoder_stream = DynamicVideoDecoderStream::new( + ctx, + decoder_mapping, + depayloader_stream, + keyframe_request_sender, + ) + .flatten(); let result_stream = decoder_stream .filter_map(|event| match event { diff --git a/smelter-core/src/pipeline/webrtc/whip_input/process_tracks.rs b/smelter-core/src/pipeline/webrtc/whip_input/process_tracks.rs index 084df8504..028433e2f 100644 --- a/smelter-core/src/pipeline/webrtc/whip_input/process_tracks.rs +++ b/smelter-core/src/pipeline/webrtc/whip_input/process_tracks.rs @@ -1,7 +1,11 @@ use std::sync::Arc; +use crossbeam_channel::bounded; use tracing::warn; -use webrtc::{rtp_transceiver::RTCRtpTransceiver, track::track_remote::TrackRemote}; +use webrtc::{ + rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication, + rtp_transceiver::RTCRtpTransceiver, track::track_remote::TrackRemote, +}; use crate::{ codecs::VideoDecoderOptions, @@ -79,6 +83,25 @@ pub async fn process_video_track( let WhipWhepServerState { inputs, ctx, .. } = state; let frame_sender = inputs.get_with(&endpoint_id, |input| Ok(input.frame_sender.clone()))?; + let (keyframe_request_sender, keyframe_request_receiver) = bounded(1); + let ssrc = track.ssrc(); + let rtc_receiver_clone = rtc_receiver.clone(); + tokio::spawn(async move { + let transport = rtc_receiver_clone.transport(); + for _ in keyframe_request_receiver.into_iter() { + warn!("Sending PLI"); + let pli = PictureLossIndication { + // For receive-only endpoints RTP sender SSRC can be set to 0. + sender_ssrc: 0, + media_ssrc: ssrc, + }; + + if let Err(err) = transport.write_rtcp(&[Box::new(pli)]).await { + warn!(?err) + } + } + }); + let handle = VideoTrackThread::spawn( format!("WHIP input video, endpoint_id: {}", endpoint_id), ( @@ -86,6 +109,7 @@ pub async fn process_video_track( decoder_mapping, payload_type_mapping, frame_sender, + keyframe_request_sender, ), )?;