diff --git a/core/src/ops/fft.rs b/core/src/ops/fft.rs index 4e55aac9ad..11b05af225 100644 --- a/core/src/ops/fft.rs +++ b/core/src/ops/fft.rs @@ -223,5 +223,42 @@ impl TypedOp for Stft { Ok(tvec!(inputs[0].datum_type.fact(shape))) } + fn axes_mapping( + &self, + inputs: &[&TypedFact], + _outputs: &[&TypedFact], + ) -> TractResult { + // Stft is NOT rank-preserving: it inserts a frame axis at + // `axis + 1`. The mapping is: + // - axes 0..self.axis (leading dims): 1-to-1 input <-> output. + // - input axis `self.axis` (the time axis) <-> output axis + // `self.axis` (now the n_frames axis -- same position, the + // dim shrinks from `T` to `(T - frame) / stride + 1`). + // - output axis `self.axis + 1` (the inserted frame axis): + // output-only, no input correspondence. + // - input axes `self.axis + 1..rank` (trailing dims incl. + // the complex pair) <-> output axes `self.axis + 2..rank+1` + // (shifted right by 1 to make room for the frame axis). + // + // Without this mapping the generic `PulseWrappingOp` fallback + // bails with "could not track pulsing axis" the moment a user + // streams a non-time axis through STFT (typical pattern: a + // batched STFT pipeline that streams the batch axis). + let in_rank = inputs[0].rank(); + let mut axes = tvec!(); + let mut alphabet = 'a'..; + for i in 0..in_rank { + let out_axis = if i <= self.axis { i } else { i + 1 }; + axes.push( + crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1) + .input(0, i) + .output(0, out_axis), + ); + } + // Inserted frame axis (output-only). + axes.push(crate::axes::Axis::new(alphabet.next().unwrap(), 1, 1).output(0, self.axis + 1)); + crate::axes::AxesMapping::new(1, 1, axes) + } + as_op!(); } diff --git a/harness/core-proptest-pulse/src/lib.rs b/harness/core-proptest-pulse/src/lib.rs index a8866e9294..252e0ee0e6 100644 --- a/harness/core-proptest-pulse/src/lib.rs +++ b/harness/core-proptest-pulse/src/lib.rs @@ -26,6 +26,7 @@ mod delay_plus_downsample; mod delay_plus_pool; mod einsum; mod pad_plus_conv; +mod stft; #[allow(dead_code)] fn setup_test_logger() { diff --git a/harness/core-proptest-pulse/src/stft.rs b/harness/core-proptest-pulse/src/stft.rs new file mode 100644 index 0000000000..f3ed69f576 --- /dev/null +++ b/harness/core-proptest-pulse/src/stft.rs @@ -0,0 +1,65 @@ +use proptest::test_runner::TestCaseResult; +use tract_core::ops::fft::Stft; + +use super::*; + +/// STFT applied with the streaming axis distinct from the STFT axis +/// must be pulsifiable: every non-STFT axis is a 1-to-1 passthrough +/// once `Stft::axes_mapping` declares the relationship (input axis +/// `op.axis` maps to output `op.axis` as `n_frames`; output `op.axis + +/// 1` is the inserted frame axis; the rest shift naturally). Without +/// the mapping the pulse pass bails with "could not track pulsing +/// axis" the moment a batched STFT pipeline streams its batch axis. +/// +/// Setup: input is rank-3 `(B_stream, T, 2)`. B_stream is the +/// streaming axis (axis 0); STFT runs on the T axis (axis 1); the +/// trailing 2 holds (re, im). One pulse = one batch element; tract +/// runs the full-length STFT inside each pulse. +fn stft_on_non_stft_axis( + batch_len: usize, + pulse: usize, + time_len: usize, + frame: usize, + stride: usize, +) -> TestCaseResult { + let mut model = TypedModel::default(); + let s = model.symbols.sym("S"); + let a = model.add_source("a", f32::fact(dims!(s, time_len, 2))).unwrap(); + model.wire_node("stft", Stft { axis: 1, frame, stride, window: None }, &[a]).unwrap(); + model.auto_outputs().unwrap(); + + let input: ArrayD = ArrayD::from_shape_fn(vec![batch_len, time_len, 2], |idx| { + (idx[0] * time_len * 2 + idx[1] * 2 + idx[2]) as f32 * 0.01 + }); + proptest_regular_against_pulse(model, pulse, input, 0) +} + +#[test] +fn stft_pulse_batch_axis_smoke_4_pulse2_t8_frame4_stride2() { + stft_on_non_stft_axis(4, 2, 8, 4, 2).unwrap(); +} + +#[test] +fn stft_pulse_batch_axis_smoke_3_pulse1_t6_frame3_stride1() { + stft_on_non_stft_axis(3, 1, 6, 3, 1).unwrap(); +} + +#[test] +fn stft_pulse_batch_axis_smoke_2_pulse2_t12_frame4_stride4() { + stft_on_non_stft_axis(2, 2, 12, 4, 4).unwrap(); +} + +proptest! { + #[test] + fn proptest_stft_pulse_batch_axis( + batch_len in 1usize..6, + pulse in 1usize..3, + time_len in 4usize..16, + frame in proptest::sample::select(vec![2usize, 4]), + stride in proptest::sample::select(vec![1usize, 2]), + ) { + // Skip frame > time_len -- the STFT would produce 0 frames. + prop_assume!(time_len >= frame); + stft_on_non_stft_axis(batch_len, pulse, time_len, frame, stride)? + } +}