diff --git a/NeuralAudio/WaveNet.h b/NeuralAudio/WaveNet.h index 5f0814b..5e52f46 100644 --- a/NeuralAudio/WaveNet.h +++ b/NeuralAudio/WaveNet.h @@ -122,6 +122,7 @@ namespace NeuralAudio DenseLayerT inputMixin; DenseLayerT oneByOne; Eigen::Matrix state; + bool isLast; public: static constexpr auto ReceptiveFieldSize = (KernelSize - 1) * Dilation; @@ -136,11 +137,17 @@ namespace NeuralAudio //Eigen::Matrix layerBuffer; size_t bufferStart; - WaveNetLayerT() + WaveNetLayerT() : + isLast(false) { state.setZero(); } + void SetLast() + { + isLast = true; + } + void AllocBuffer(int allocNum) { long size = BufferSize; @@ -225,9 +232,12 @@ namespace NeuralAudio const_cast&>(headInput).noalias() += block.topRows(Channels); - oneByOne.Process(block.topRows(Channels), const_cast&>(output).middleCols(outputStart, numFrames)); + if (!isLast) + { + oneByOne.Process(block.topRows(Channels), const_cast&>(output).middleCols(outputStart, numFrames)); - const_cast&>(output).middleCols(outputStart, numFrames).noalias() += layerBuffer.middleCols(bufferStart, numFrames); + const_cast&>(output).middleCols(outputStart, numFrames).noalias() += layerBuffer.middleCols(bufferStart, numFrames); + } } }; @@ -280,6 +290,11 @@ namespace NeuralAudio return allocNum; } + void SetLast() + { + std::get(layers).SetLast(); + } + void SetWeights(std::vector::iterator& weights) { rechannel.SetWeights(weights); @@ -350,6 +365,7 @@ namespace NeuralAudio allocNum = std::get(layerArrays).AllocBuffers(allocNum); }); + std::get(layerArrays).SetLast(); } void SetWeights(std::vector weights) @@ -431,4 +447,4 @@ namespace NeuralAudio Eigen::Matrix headArray; float headScale; }; -} \ No newline at end of file +}