The FluxWrapper provides a unified interface for running Flux diffusion models on either the Black Forest Labs or the Diffusers backend, so you can switch implementations without changing the calling code.
- Unified Interface: Same API regardless of backend
- Backend Flexibility: Easy switching between BlackForest and Diffusers
- Standard PyTorch Methods: Full support for .to(), .eval(), .train(), etc.
- Consistent Output: unified_forward() always returns the same format
- Backend-Specific Features: Access to unique features of each backend
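To make the "Consistent Output" point concrete, here is a minimal sketch of how differing backend results could be reduced to one format. It assumes the result is either a tuple whose first element is the output (as in the `output, updated_info = wrapper.forward(...)` call later in this document), a Diffusers-style object exposing `.sample`, or the output itself. `first_output` is a hypothetical helper for illustration, not the wrapper's actual implementation.

```python
from typing import Any


def first_output(result: Any) -> Any:
    """Reduce a backend-specific forward() result to the bare output.

    Hypothetical helper for illustration; not part of FluxWrapper.
    """
    # Tuple results such as (output, info) or (output,) carry the output first.
    if isinstance(result, tuple):
        return result[0]
    # Diffusers-style output objects expose the tensor as `.sample`.
    if hasattr(result, "sample"):
        return result.sample
    # Otherwise assume the result already is the output.
    return result
```

With a helper like this, calling code can treat every result the same way, which is the behavior unified_forward() is described as providing.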
Make sure you have the required dependencies:
pip install torch diffusers transformers

from src.flux.flux_wrapper import create_flux_wrapper
# Create wrapper with BlackForest backend
wrapper = create_flux_wrapper(
    model_name="flux-dev",
    backend="blackforest",
    device="cuda"
)
# Get model information
info = wrapper.get_model_info()
print(f"Loaded {info['backend']} backend with {info['num_parameters']:,} parameters")

# BlackForest backend
wrapper_bf = create_flux_wrapper(
    model_name="flux-dev",
    backend="blackforest",
    device="cuda"
)
# Diffusers backend
wrapper_diff = create_flux_wrapper(
    model_name="flux-dev",
    backend="diffusers",
    device="cuda"
)
# Both use the same interface!

import torch
from src.flux.flux_wrapper import FluxWrapper, FluxWrapperConfig

config = FluxWrapperConfig(
    model_name="flux-dev",
    backend="blackforest",
    device="cuda",
    torch_dtype=torch.bfloat16,
    hf_download=True
)
wrapper = FluxWrapper(config)

The unified_forward() method provides consistent output regardless of backend:
# Same code works for both backends
output = wrapper.unified_forward(
    img=img_tensor,
    img_ids=img_ids,
    txt=txt_tensor,
    txt_ids=txt_ids,
    timesteps=timesteps,
    y=pooled_embeddings,
    guidance=guidance_tensor
)
# Always returns just the output tensor

# Use the info dictionary for feature injection
info = {
    'feature_path': 'features/',
    'feature': {},
    'inject_step': 5
}
output, updated_info = wrapper.forward(
    img=img_tensor,
    img_ids=img_ids,
    txt=txt_tensor,
    txt_ids=txt_ids,
    timesteps=timesteps,
    y=pooled_embeddings,
    guidance=guidance_tensor,
    info=info
)

# Use joint attention kwargs for LoRA scaling
joint_attention_kwargs = {'scale': 1.0}
output = wrapper.forward(
    img=img_tensor,
    img_ids=img_ids,
    txt=txt_tensor,
    txt_ids=txt_ids,
    timesteps=timesteps,
    y=pooled_embeddings,
    guidance=guidance_tensor,
    joint_attention_kwargs=joint_attention_kwargs,
    return_dict=False
)

- __init__(config: FluxWrapperConfig): Initialize wrapper with configuration
- forward(...): Full forward pass with backend-specific parameters
- unified_forward(...): Simplified forward pass with consistent output
- get_model_info(): Get model information dictionary
- to(device_or_dtype): Move model to device or change dtype
- eval(): Set model to evaluation mode
- train(mode=True): Set model to training mode

- backend: Current backend ("blackforest" or "diffusers")
- device: Current device
- params: Model parameters (unified across backends)
- in_channels: Input channels
- out_channels: Output channels
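Because `backend` takes one of two documented values, and the configuration comments list two model names ("flux-dev" and "flux-schnell"), a caller may want to reject typos before a large model download starts. The following is a rough sketch of such a check; `check_wrapper_args` and the two constants are hypothetical names, and the wrapper may well perform its own validation.

```python
VALID_BACKENDS = ("blackforest", "diffusers")
VALID_MODEL_NAMES = ("flux-dev", "flux-schnell")


def check_wrapper_args(model_name: str, backend: str) -> None:
    """Raise ValueError for names this document does not list.

    Hypothetical caller-side helper; FluxWrapper may do its own validation.
    """
    if model_name not in VALID_MODEL_NAMES:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {VALID_MODEL_NAMES}"
        )
    if backend not in VALID_BACKENDS:
        raise ValueError(
            f"Unknown backend {backend!r}; expected one of {VALID_BACKENDS}"
        )
```

Calling `check_wrapper_args("flux-dev", "diffusers")` returns silently, while a misspelled backend such as "diffuser" raises immediately instead of failing later during loading.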
Configuration dataclass for FluxWrapper:
@dataclass
class FluxWrapperConfig:
    model_name: str = "flux-dev"  # "flux-dev" or "flux-schnell"
    backend: str = "blackforest"  # "blackforest" or "diffusers"
    device: str = "cuda"  # Device to load model
    hf_download: bool = True  # Download from HuggingFace if needed
    torch_dtype: torch.dtype = torch.bfloat16  # Model dtype

def create_flux_wrapper(
    model_name: str = "flux-dev",
    backend: str = "blackforest",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    hf_download: bool = True
) -> FluxWrapper

import torch
from src.flux.flux_wrapper import create_flux_wrapper
# Load both backends
backends = ["blackforest", "diffusers"]
wrappers = {}
for backend in backends:
    wrapper = create_flux_wrapper(
        model_name="flux-dev",
        backend=backend,
        device="cuda"
    )
    wrappers[backend] = wrapper
    info = wrapper.get_model_info()
    print(f"{backend}: {info['num_parameters']:,} parameters")
# Use same inputs for both
# Packed latents have shape (batch, sequence_length, in_channels); Flux uses 64 input channels
img = torch.randn(1, 1024, 64, device="cuda", dtype=torch.bfloat16)
# ... other inputs
# Compare outputs
for name, wrapper in wrappers.items():
    output = wrapper.unified_forward(
        img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
        timesteps=timesteps, y=y, guidance=guidance
    )
    print(f"{name} output shape: {output.shape}")

# Easy backend switching
def load_model(backend="blackforest"):
    return create_flux_wrapper(
        model_name="flux-dev",
        backend=backend,
        device="cuda"
    )
# Switch between backends
model = load_model("blackforest")  # Use BlackForest
# model = load_model("diffusers")  # Use Diffusers
# Same code works for both
output = model.unified_forward(...)

The wrapper includes comprehensive error handling:
try:
    wrapper = create_flux_wrapper(
        model_name="flux-dev",
        backend="blackforest",
        device="cuda"
    )
    output = wrapper.unified_forward(...)
except Exception as e:
    print(f"Error: {e}")
    # Handle error or fallback to different backend

- Memory Usage: Both backends have similar memory requirements
- Speed: BlackForest may be slightly faster for custom implementations
- Compatibility: Diffusers backend offers better ecosystem integration
- Features: BlackForest backend provides more advanced feature injection
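The speed note above is best confirmed on your own hardware. Here is a minimal timing sketch using only the standard library; `time_call` is a hypothetical helper, not part of the wrapper. It assumes the callable blocks until the work is finished, so for CUDA workloads the callable should synchronize the device (for example with `torch.cuda.synchronize()`) before returning.

```python
import time
from statistics import median
from typing import Callable


def time_call(fn: Callable[[], object], warmup: int = 1, repeats: int = 5) -> float:
    """Return the median wall-clock seconds of fn() over `repeats` runs.

    Hypothetical helper. GPU kernels run asynchronously, so for CUDA work
    fn should synchronize the device before it returns.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    # Warm-up runs absorb one-time costs such as lazy initialization.
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return median(samples)
```

The median is used rather than the mean so that a single slow run has less influence on the reported figure. Timing the same unified_forward() call on each backend with identical inputs gives a like-for-like comparison.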
Run the test script to verify both backends work:
python test_flux_wrapper.py

This will test:
- Model loading for both backends
- Forward pass compatibility
- Memory usage
- Error handling
- CUDA Out of Memory: Reduce batch size or use CPU
- Model Download Errors: Check internet connection and HuggingFace access
- Import Errors: Ensure all dependencies are installed
- Backend Differences: Use unified_forward() for consistent behavior
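When one backend fails to load, for instance because of a download or import error from the list above, falling back to the other backend is one option. The sketch below tries backends in order and reports every failure if none succeeds. `create_with_fallback` is a hypothetical helper; it takes the factory as an argument so it makes no assumptions about the wrapper's internals or the exception types it raises.

```python
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def create_with_fallback(
    factory: Callable[[str], T],
    backends: Sequence[str] = ("blackforest", "diffusers"),
) -> T:
    """Call factory(backend) for each backend in order; return the first success.

    Hypothetical helper for illustration; not part of FluxWrapper.
    """
    errors = []
    for backend in backends:
        try:
            return factory(backend)
        except Exception as exc:
            # Broad catch on purpose: the wrapper's exception types
            # are not documented here.
            errors.append(f"{backend}: {exc}")
    raise RuntimeError("All backends failed: " + "; ".join(errors))
```

A call might look like `create_with_fallback(lambda b: create_flux_wrapper(model_name="flux-dev", backend=b, device="cuda"))`, using the factory function described earlier in this document.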
# Get detailed model information
info = wrapper.get_model_info()
print(f"Backend: {info['backend']}")
print(f"Device: {info['device']}")
print(f"Parameters: {info['num_parameters']:,}")
print(f"Guidance: {info['guidance_embed']}")

When adding new features:
- Ensure compatibility with both backends
- Add appropriate error handling
- Update documentation
- Test with both backends
- Consider backward compatibility
This wrapper follows the same license as the underlying Flux models:
- Black Forest Labs models: Check individual model licenses
- Diffusers: Apache 2.0 License