Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,83 @@ Note that the 48 kHz model processes the audio by chunks of 1 seconds, with an o
and renormalizes the audio to have unit scale. For this model, the output of `model.encode(wav)`
would be a list (for each frame of 1 second) of a tuple `(codes, scale)` with `scale` a scalar tensor.

## Audio Style Transfer

EnCodec now includes a neural audio style transfer module. Given a content recording and a style reference recording, it generates audio that takes on the style of the reference while preserving the content structure of the content recording.

### Features
- Transfer voice styles between different speakers
- Maintain content while changing audio characteristics
- Support for both mono and stereo audio
- GPU acceleration for faster processing
- Compatible with existing EnCodec compression pipeline

### Usage

#### Training
```bash
python -m encodec.train_style_transfer \
--content-dir /path/to/content/audio \
--style-dir /path/to/style/audio \
--checkpoint-dir checkpoints
```

#### Inference
```bash
python -m encodec.apply_style_transfer \
--model-path checkpoints/model_epoch_100.pt \
--content-path input.wav \
--style-path style.wav \
--output-path output.wav
```

### Example
```python
import torch
import torchaudio

from encodec import NeuralAudioStyleTransfer
from encodec.utils import convert_audio

# Load model
model = NeuralAudioStyleTransfer()
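# Checkpoints store the weights under the "model_state_dict" key (as read by `encodec.apply_style_transfer`)
checkpoint = torch.load("checkpoints/model_epoch_100.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()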

# Load content and style audio
content_audio, sr = torchaudio.load("content.wav")
style_audio, style_sr = torchaudio.load("style.wav")

# Convert audio to model's sample rate (each file is converted from its own sample rate)
content_audio = convert_audio(content_audio, sr, model.sample_rate, model.channels)
style_audio = convert_audio(style_audio, style_sr, model.sample_rate, model.channels)

# Apply style transfer
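with torch.no_grad():
    # The model works on batches shaped [batch, channels, time]:
    # add a batch dimension for the call and drop it again before saving.
    output_audio = model(content_audio.unsqueeze(0), style_audio.unsqueeze(0)).squeeze(0)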

# Save result
torchaudio.save("output.wav", output_audio, model.sample_rate)
```

### Parameters
- `--content-dir`: Directory containing content audio files
- `--style-dir`: Directory containing style reference audio files
- `--sample-rate`: Target sample rate (default: 16000)
- `--batch-size`: Training batch size (default: 32)
- `--epochs`: Number of training epochs (default: 100)
- `--learning-rate`: Learning rate (default: 0.001)

### Integration with Compression
The style transfer module can be used in conjunction with EnCodec's compression pipeline:

```python
# First apply style transfer
output_audio = model(content_audio, style_audio)

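# Then compress the result.
# A separate name is used so the style transfer model is not overwritten.
# Resample `output_audio` to the codec's sample rate first if the two differ.
codec = EncodecModel.encodec_model_24khz()
codec.set_target_bandwidth(6.0)
compressed = codec.encode(output_audio)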
```

## Installation for development

This will install the dependencies and a `encodec` in developer mode (changes to the files
Expand Down
93 changes: 93 additions & 0 deletions encodec/apply_style_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchaudio
import argparse
from pathlib import Path
from typing import Optional

from .modules.style_transfer import NeuralAudioStyleTransfer

def _load_audio(path: Path, sample_rate: int) -> torch.Tensor:
    """Load an audio file and resample it to ``sample_rate`` when it differs.

    The waveform is returned exactly as ``torchaudio.load`` produced it
    (apart from resampling); no channel conversion is performed.
    """
    waveform, source_rate = torchaudio.load(path)
    if source_rate != sample_rate:
        waveform = torchaudio.transforms.Resample(source_rate, sample_rate)(waveform)
    return waveform


def apply_style_transfer(
    model: NeuralAudioStyleTransfer,
    content_path: Path,
    style_path: Path,
    output_path: Path,
    sample_rate: int = 16000,
    device: Optional[str] = None,
) -> None:
    """Apply style transfer to an audio file and write the result to disk.

    Both input files are loaded, resampled to ``sample_rate`` if necessary,
    passed through ``model`` in evaluation mode without gradient tracking,
    and the generated waveform is saved as 16-bit signed PCM.

    Args:
        model: Trained style transfer model. It is moved to ``device`` and
            switched to evaluation mode (both happen in place).
        content_path: Path to content audio file.
        style_path: Path to style reference audio file.
        output_path: Path to save the output audio.
        sample_rate: Target sample rate used for both inputs and the output.
        device: Device to run inference on. When ``None``, CUDA is used if
            available, otherwise the CPU.

    Note:
        No channel conversion is done here; the number of channels in the
        audio files must match what ``model`` was built for.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # The model must live on the same device as its inputs; otherwise the
    # forward pass fails whenever the device is auto-selected as CUDA.
    model = model.to(device)
    model.eval()

    # torchaudio returns [channels, time]; the encoders operate on
    # [batch, channels, time], so add a batch dimension of size one.
    content_audio = _load_audio(content_path, sample_rate).unsqueeze(0).to(device)
    style_audio = _load_audio(style_path, sample_rate).unsqueeze(0).to(device)

    # Apply style transfer
    with torch.no_grad():
        generated_audio = model(content_audio, style_audio)

    # Drop the batch dimension again: torchaudio.save expects a 2-D tensor.
    torchaudio.save(
        output_path,
        generated_audio.squeeze(0).cpu(),
        sample_rate,
        encoding='PCM_S',
        bits_per_sample=16,
    )

def main():
    """Command-line entry point for applying a trained style transfer model.

    Parses the command-line arguments, restores the model weights from the
    ``model_state_dict`` entry of the checkpoint, and runs
    ``apply_style_transfer`` on the given content/style files.
    """
    parser = argparse.ArgumentParser(description="Apply Neural Audio Style Transfer")

    parser.add_argument("--model-path", type=str, required=True, help="Path to trained model checkpoint")
    parser.add_argument("--content-path", type=str, required=True, help="Path to content audio file")
    parser.add_argument("--style-path", type=str, required=True, help="Path to style reference audio file")
    parser.add_argument("--output-path", type=str, required=True, help="Path to save output audio")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Target sample rate")
    parser.add_argument("--device", type=str, default=None, help="Device to run inference on")

    args = parser.parse_args()

    # Resolve the device once so that checkpoint loading, the model and the
    # audio tensors all agree on it (``--device`` defaults to None).
    device = args.device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(args.model_path, map_location=device)
    model = NeuralAudioStyleTransfer()
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Apply style transfer
    apply_style_transfer(
        model=model,
        content_path=Path(args.content_path),
        style_path=Path(args.style_path),
        output_path=Path(args.output_path),
        sample_rate=args.sample_rate,
        device=device,
    )


if __name__ == "__main__":
    main()
149 changes: 149 additions & 0 deletions encodec/modules/style_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class StyleEncoder(nn.Module):
    """Style feature extractor for audio.

    The input is downsampled by three stride-2 convolutions, refined with
    multi-head self-attention along the time axis, averaged over time and
    finally projected to a ``style_dim``-dimensional vector.
    """

    def __init__(self, input_channels: int = 1, style_dim: int = 256):
        super().__init__()

        # (in_channels, out_channels, kernel_size, padding); stride is always 2.
        conv_specs = (
            (input_channels, 64, 7, 3),
            (64, 128, 4, 1),
            (128, 256, 4, 1),
        )
        stack = []
        for in_ch, out_ch, kernel, pad in conv_specs:
            stack.append(nn.Conv1d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=pad))
            stack.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*stack)

        self.attention = nn.MultiheadAttention(256, num_heads=8)
        self.style_projection = nn.Linear(256, style_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map audio of shape [batch, channels, time] to [batch, style_dim]."""
        hidden = self.conv_layers(x)

        # The attention layer is sequence-first: [time, batch, channels].
        sequence = hidden.permute(2, 0, 1)
        attended = self.attention(sequence, sequence, sequence)[0]

        # Average over the time axis, leaving [batch, channels].
        pooled = attended.mean(dim=0)
        return self.style_projection(pooled)

class ContentEncoder(nn.Module):
    """Content feature extractor for audio.

    Three stride-2 convolutions, each followed by instance normalisation and
    a ReLU, produce a feature map that is averaged over time and projected
    to a ``content_dim``-dimensional vector.
    """

    def __init__(self, input_channels: int = 1, content_dim: int = 256):
        super().__init__()

        def conv_block(in_ch: int, out_ch: int, kernel: int, pad: int) -> list:
            # One downsampling stage: strided conv -> instance norm -> ReLU.
            return [
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=pad),
                nn.InstanceNorm1d(out_ch),
                nn.ReLU(),
            ]

        self.conv_layers = nn.Sequential(
            *conv_block(input_channels, 64, 7, 3),
            *conv_block(64, 128, 4, 1),
            *conv_block(128, 256, 4, 1),
        )

        self.content_projection = nn.Linear(256, content_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map audio of shape [batch, channels, time] to [batch, content_dim]."""
        # Collapse the time axis by averaging, leaving [batch, channels].
        pooled = self.conv_layers(x).mean(dim=2)
        return self.content_projection(pooled)

class AudioDecoder(nn.Module):
    """Generate audio from a content vector and a style vector.

    The two vectors are concatenated, passed through a small MLP and the
    result is treated as a length-1 sequence that four stride-2 transposed
    convolutions upsample into a waveform bounded by ``tanh``.

    NOTE(review): because the decoder always starts from a single time step,
    the layer hyper-parameters below give an output of 15 samples
    (1 -> 2 -> 4 -> 8 -> 15) regardless of the length of the input audio.
    """

    def __init__(self, content_dim: int = 256, style_dim: int = 256, output_channels: int = 1):
        super().__init__()

        fusion_layers = [
            nn.Linear(content_dim + style_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
        ]
        self.fusion = nn.Sequential(*fusion_layers)

        upsampling = []
        # Three identical upsampling stages that halve the channel count.
        for in_ch, out_ch in ((1024, 512), (512, 256), (256, 128)):
            upsampling.append(nn.ConvTranspose1d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            upsampling.append(nn.InstanceNorm1d(out_ch))
            upsampling.append(nn.ReLU())
        # Output stage: map to the requested channel count and squash to [-1, 1].
        upsampling.append(nn.ConvTranspose1d(128, output_channels, kernel_size=7, stride=2, padding=3))
        upsampling.append(nn.Tanh())
        self.deconv_layers = nn.Sequential(*upsampling)

    def forward(self, content: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """Decode ``content`` and ``style`` feature vectors into audio.

        Both inputs are concatenated along dimension 1, so they are expected
        to be [batch, content_dim] and [batch, style_dim] respectively.
        """
        fused = self.fusion(torch.cat((content, style), dim=1))

        # Append a time axis of length one so the transposed convs can run.
        seed = fused.unsqueeze(2)
        return self.deconv_layers(seed)

class NeuralAudioStyleTransfer(nn.Module):
    """Complete neural audio style transfer model.

    Combines a content encoder, a style encoder and a decoder: the content
    audio is encoded by the content encoder, the style audio by the style
    encoder, and the decoder generates audio from the two feature vectors.

    Args:
        input_channels: Number of audio channels consumed and produced.
        content_dim: Size of the content feature vector.
        style_dim: Size of the style feature vector.
        sample_rate: Sample rate (Hz) the model is meant to be used with.
            It is only stored as metadata and does not affect the network.

    Attributes:
        sample_rate: The ``sample_rate`` given at construction.
        channels: Same value as ``input_channels``.
    """

    def __init__(
        self,
        input_channels: int = 1,
        content_dim: int = 256,
        style_dim: int = 256,
        sample_rate: int = 16000,
    ):
        super().__init__()
        # Plain attributes (not parameters or buffers), so existing
        # checkpoints load unchanged. They let callers prepare audio with
        # e.g. ``convert_audio(wav, sr, model.sample_rate, model.channels)``.
        self.sample_rate = sample_rate
        self.channels = input_channels

        self.content_encoder = ContentEncoder(input_channels, content_dim)
        self.style_encoder = StyleEncoder(input_channels, style_dim)
        self.decoder = AudioDecoder(content_dim, style_dim, input_channels)

    def encode_content(self, x: torch.Tensor) -> torch.Tensor:
        """Extract content features from input audio."""
        return self.content_encoder(x)

    def encode_style(self, x: torch.Tensor) -> torch.Tensor:
        """Extract style features from reference audio."""
        return self.style_encoder(x)

    def forward(self, content_audio: torch.Tensor, style_audio: torch.Tensor) -> torch.Tensor:
        """Transfer style from style_audio to content_audio.

        Both inputs are batched waveforms shaped [batch, channels, time].
        Returns the audio produced by the decoder from the content features
        of ``content_audio`` and the style features of ``style_audio``.
        """
        content_features = self.encode_content(content_audio)
        style_features = self.encode_style(style_audio)
        return self.decoder(content_features, style_features)

    def compute_style_loss(self, generated: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """Compute style loss between generated and style audio.

        The loss is the mean squared error between the style encoder's
        features of ``generated`` and of ``style``. Neither branch is
        detached, so gradients flow through both.
        """
        gen_features = self.style_encoder(generated)
        style_features = self.style_encoder(style)
        return F.mse_loss(gen_features, style_features)

    def compute_content_loss(self, generated: torch.Tensor, content: torch.Tensor) -> torch.Tensor:
        """Compute content loss between generated and content audio.

        The loss is the mean squared error between the content encoder's
        features of ``generated`` and of ``content``. Neither branch is
        detached, so gradients flow through both.
        """
        gen_features = self.content_encoder(generated)
        content_features = self.content_encoder(content)
        return F.mse_loss(gen_features, content_features)
Loading