基于JAX和Numpyro的多保真度贝叶斯深度学习框架,实现了自回归神经网络架构、不确定性量化以及基于DPP的多保真度主动学习方法。
MultiFidelityBNN/
├── model/ # 核心模型实现
│ ├── MLP_Units.py # 基础神经网络组件
│ ├── MF_Units.py # 保真度配置
│ ├── DeterministicAutoRegMF.py # 确定性自回归网络
│ ├── StochasticAutoRegMF.py # MCMC贝叶斯网络
│ ├── VariationalAutoRegMF.py # 变分推断网络
│ ├── MCMC_Units.py # MCMC诊断工具
│ └── SWG_Units.py # 随机权重平均调度器
│
├── active_learning/ # 主动学习模块
│ ├── sequential_sampling.py # 顺序采样主循环
│ ├── multi_fidelity_dpp.py # DPP采集函数
│ └── baseline.py # 基线采集策略
│
├── multi_fidelity_doe/ # 试验设计
│ ├── MultiFidelityDOE.py # 拉丁超立方采样
│ ├── nested.py # 嵌套采样
│ └── non_nested.py # 非嵌套采样
│
├── Funcs/ # 测试函数
│ └── mathematic_funcs.py # 多保真度测试函数
│
├── main.py # 入口点
├── cases_config.py # 案例配置
└── README.md # 本文档
本框架采用自回归深度神经网络架构,低保真度层级的输出作为高保真度层级的输入:
数学表达式:
对于
其中:
-
$\mathbf{x} \in \mathbb{R}^D$ 为输入向量 -
$f_m$ 为第$m$ 层的神经网络 -
$\epsilon_m \sim \mathcal{N}(0, \sigma_m^2)$ 为观测噪声 -
$c_m$ 为第$m$ 层的计算成本(通常$c_1 < c_2 < \ldots < c_M$ )
连接模式:
| 模式 | 输入维度 | 描述 |
|---|---|---|
| 稀疏连接 |
|
|
| 全连接 |
生成模型:
先验分布(信息性先验):
利用确定性预训练网络的参数作为先验均值:
其中:
-
$\boldsymbol{\theta}_{\text{det}}$ 为确定性网络的预训练参数 -
$\lambda$ 为先验标准差(默认0.6),控制先验的信息量
后验分布:
使用NUTS (No-U-Turn Sampler) 进行后验采样:
预测分布:
均值预测: $$\mu_m(\mathbf{x}) = \mathbb{E}{\boldsymbol{\theta}}[f_m(\mathbf{x}, \boldsymbol{\theta}) | \mathcal{D}] \approx \frac{1}{T}\sum{t=1}^{T} f_m(\mathbf{x}, \boldsymbol{\theta}^{(t)})$$
总不确定性: $$\sigma_m^2(\mathbf{x}) = \underbrace{\text{Var}{\boldsymbol{\theta}}[f_m(\mathbf{x}, \boldsymbol{\theta}) | \mathcal{D}]}{\text{认知不确定性}} + \underbrace{\sigma_{\text{obs}}^2}_{\text{偶然不确定性}}$$
最大化证据下界(ELBO):
支持的变分族:
AutoNormal: 单变量正态近似AutoDiagonalNormal: 对角协方差AutoMultivariateNormal: 全协方差AutoLowRankMultivariateNormal: 低秩近似
核心思想: 在给定计算预算下,选择能够最大化高保真度预测信息增益的样本点和保真度层级。
基于Wasserstein距离的互信息度量:
其中
基于相关系数的互信息(可选):
其中
基于各向同性高斯核密度估计,降低已采样区域的权重:
其中:
-
$\mathcal{S}_m$ 为第$m$ 层已采样点集合 -
$\sigma$ 为自适应聚类阈值(基于平均最近邻距离)
聚类阈值计算:
对于候选点
其中
最优保真度选择:
使用行列式点过程(Determinantal Point Process) 实现质量-多样性权衡的批量采集。
其中:
-
$\mathbf{q} = [q(\mathbf{x}_1), \ldots, q(\mathbf{x}_N)]^T$ 为质量分数向量 -
$\mathbf{S}$ 为相似度/多样性核矩阵
覆盖自适应核(推荐):
其中期望覆盖半径:
-
$n_U$ 为批量大小 -
$D$ 为输入维度
距离基核:
其中
选择子集
算法步骤(基于Cholesky分解):
- 初始化:选择
$i_1 = \arg\max_i L_{ii}$ - 对于
$k = 2, \ldots, n_U$ :- 更新Cholesky因子
$\mathbf{C}$ - 计算条件增益:$g_j = L_{jj} - |\mathbf{v}_j|^2$
- 选择
$i_k = \arg\max_j g_j$
- 更新Cholesky因子
- 返回选中索引集合
import jax.numpy as jnp
from MultiFidelityBNN.model.DeterministicAutoRegMF import DeterministicAutoRegressiveNetwork
from MultiFidelityBNN.model.StochasticAutoRegMF import StochasticAutoRegMF
from MultiFidelityBNN.model.MF_Units import FidelityConfig
# 1. 准备多保真度数据
train_data = {
'x_1': low_fidelity_inputs, # shape: (N1, D)
'y_1': low_fidelity_outputs, # shape: (N1, 1)
'x_2': mid_fidelity_inputs, # shape: (N2, D)
'y_2': mid_fidelity_outputs, # shape: (N2, 1)
'x_3': high_fidelity_inputs, # shape: (N3, D)
'y_3': high_fidelity_outputs, # shape: (N3, 1)
}
# 2. 定义网络配置
configs = [
FidelityConfig(features=[32, 16], activation='elu'), # 低保真度
FidelityConfig(features=[32, 16], activation='elu'), # 中保真度
FidelityConfig(features=[32, 16], activation='elu'), # 高保真度
]
# 3. 训练确定性网络(提供信息性先验)
det_model = DeterministicAutoRegressiveNetwork(
input_dim=D,
output_dim=1,
fidelity_configs=configs,
learning_rate=0.01,
sgd_epochs=1000,
standardize=True
)
det_model.train(train_data)
det_params = det_model.get_params()
# 4. 训练贝叶斯网络
bayes_model = StochasticAutoRegMF(
input_dim=D,
output_dim=1,
fidelity_configs=configs,
deterministic_params=det_params, # 信息性先验
prior_std=0.6,
rng_seed=42,
standardize=True
)
bayes_model.fit(
train_data,
num_warmup=500, # MCMC预热步数
num_samples=1000, # MCMC采样步数
num_chains=2, # 并行链数
progress_bar=True
)
# 5. 预测
x_test = jnp.linspace(low_bound, high_bound, 100).reshape(-1, 1)
predictions = bayes_model.predict(x_test, return_observation=True)
# 访问预测结果
for m in range(3):
mean = predictions[f'f_{m+1}']['mean'] # 预测均值
std = predictions[f'f_{m+1}']['std'] # 预测标准差from MultiFidelityBNN.active_learning.sequential_sampling import SequentialSampling
from MultiFidelityBNN.cases_config import Case_2
# 使用预定义案例配置
config = Case_2
# 运行主动学习
results = SequentialSampling(
funcs=config.Funcs, # 多保真度函数列表
D=config.Dim, # 输入维度
fidelity_num=len(config.Funcs),
cost=config.COST, # 各层级计算成本 [c1, c2, c3]
sample_sizes=config.SampleSize, # 初始样本数 [n1, n2, n3]
pool_size=config.PoolSize, # 候选池大小
configs=config.NetConfigs, # 网络配置
method='DDP', # 'DDP' 或 'baseline'
low_bounds=config.LOW,
high_bounds=config.HIGH,
warm_up_num=config.WarmUpNum,
sample_num=config.SampleNum,
seed=config.Seed,
batch_size=config.BatchSize, # 每轮采集样本数
iteration_num=config.IterationNum,
test_set=config.TestSet,
error_threshold=config.ErrorThreshold
)
model, train_data, error_history, cost_history, sample_history = results# 计算各保真度层级关于最高保真度的互信息
MI = bayes_model.mutual_information(x_test, wasserstein=True)
# MI['f_1']: 低保真度与高保真度的互信息
# MI['f_2']: 中保真度与高保真度的互信息
# MI['f_3']: 高保真度自身(恒为1.0)
# 计算经验预测核(用于DPP)
S = bayes_model.empirical_prediction_kernel(x_test)
# S: N x N 相关系数矩阵本框架提供6个测试案例,维度从1D到15D:
| 案例 | 维度 | 保真度数 | 成本 | 初始样本 | 网络结构 |
|---|---|---|---|---|---|
| Case_1 | 1D | 2 | [2, 8] | [8, 4] | [16, 8] |
| Case_2 | 1D | 3 | [2, 4, 8] | [8, 6, 4] | [16, 4] |
| Case_3 | 2D | 3 | [2, 4, 12] | [20, 10, 5] | [32, 16, 8] |
| Case_4 | 6D | 3 | [2, 4, 12] | [60, 30, 10] | [64, 64] |
| Case_5 | 10D | 3 | [2, 4, 12] | [100, 30, 20] | [64, 64] |
| Case_6 | 15D | 3 | [2, 4, 12] | [100, 30, 20] | [64, 64] |
运行案例:
cd G:/code/MultiFidelityBayes/MultiFidelityBNN
python main.py # 默认运行Case_2from main import main
results = main(case=3) # 运行2D案例jax
jaxlib
flax
numpyro
optax
scipy
numpy
matplotlib
tqdm
| 符号 | 含义 |
|---|---|
| 保真度层级数 | |
| 输入维度 | |
| 输入向量 | |
| 第 |
|
| 第 |
|
| 网络参数 | |
| 第 |
|
| 第 |
|
| 覆盖权重函数 | |
| 加权互信息增益 | |
| DPP核矩阵 | |
| 多样性核矩阵 | |
| 批量采集大小 |
┌─────────────────────────────────────────────────────────────────┐
│ 初始化试验设计 │
│ (拉丁超立方采样, 各层级) │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 主动学习循环 │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ 1. 训练确定性网络 (SGD + SWA) │ │
│ │ └─> 获取预训练参数 θ_det │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ 2. 训练贝叶斯网络 (MCMC-NUTS) │ │
│ │ └─> 后验采样 {θ^(t)} │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ 3. 评估测试误差 │ │
│ │ └─> 若 AAE < threshold, 提前停止 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ 4. DPP批量采集 │ │
│ │ ├─> 计算候选池质量分数 q(x) │ │
│ │ ├─> 构建多样性核 S │ │
│ │ ├─> 构建DPP核 L = diag(q) S diag(q) │ │
│ │ └─> 贪心选择 n_U 个样本 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ 5. 评估新样本并更新数据集 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 返回步骤1 │
└─────────────────────────────────────────────────────────────────┘
- Kennedy, M. C., & O'Hagan, A. (2000). Predicting the output from a complex computer code when fast approximations are available. Biometrika.
- Kulesza, A., & Taskar, B. (2012). Determinantal point processes for machine learning. Foundations and Trends in Machine Learning.
- Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn sampler. Journal of Machine Learning Research.