Skip to content

bominwang/Multi-Fidelity-Deep-Active-Learning

Repository files navigation

MultiFidelityBNN: 多保真度贝叶斯神经网络与主动学习框架

基于JAX和Numpyro的多保真度贝叶斯深度学习框架,实现了自回归神经网络架构、不确定性量化以及基于DPP的多保真度主动学习方法。


目录

  1. 项目结构
  2. 多保真度贝叶斯神经网络
  3. 多保真度主动学习方法
  4. 使用指南
  5. 案例配置
  6. 依赖项

项目结构

MultiFidelityBNN/
├── model/                              # 核心模型实现
│   ├── MLP_Units.py                   # 基础神经网络组件
│   ├── MF_Units.py                    # 保真度配置
│   ├── DeterministicAutoRegMF.py      # 确定性自回归网络
│   ├── StochasticAutoRegMF.py         # MCMC贝叶斯网络
│   ├── VariationalAutoRegMF.py        # 变分推断网络
│   ├── MCMC_Units.py                  # MCMC诊断工具
│   └── SWG_Units.py                   # 随机权重平均调度器
│
├── active_learning/                    # 主动学习模块
│   ├── sequential_sampling.py         # 顺序采样主循环
│   ├── multi_fidelity_dpp.py          # DPP采集函数
│   └── baseline.py                    # 基线采集策略
│
├── multi_fidelity_doe/                 # 试验设计
│   ├── MultiFidelityDOE.py            # 拉丁超立方采样
│   ├── nested.py                      # 嵌套采样
│   └── non_nested.py                  # 非嵌套采样
│
├── Funcs/                              # 测试函数
│   └── mathematic_funcs.py            # 多保真度测试函数
│
├── main.py                             # 入口点
├── cases_config.py                     # 案例配置
└── README.md                           # 本文档

多保真度贝叶斯神经网络

自回归架构

本框架采用自回归深度神经网络架构,低保真度层级的输出作为高保真度层级的输入:

数学表达式:

对于 $M$ 个保真度层级的系统:

$$y_m = f_m(\mathbf{x}, f_{m-1}(\mathbf{x}), \ldots, f_0(\mathbf{x})) + \epsilon_m, \quad m = 1, \ldots, M$$

其中:

  • $\mathbf{x} \in \mathbb{R}^D$ 为输入向量
  • $f_m$ 为第 $m$ 层的神经网络
  • $\epsilon_m \sim \mathcal{N}(0, \sigma_m^2)$ 为观测噪声
  • $c_m$ 为第 $m$ 层的计算成本(通常 $c_1 < c_2 < \ldots < c_M$

连接模式:

模式 输入维度 描述
稀疏连接 $D + 1$ (当 $m > 0$) $f_m(\mathbf{x}, f_{m-1}(\mathbf{x}))$
全连接 $D + m$ $f_m(\mathbf{x}, f_0(\mathbf{x}), \ldots, f_{m-1}(\mathbf{x}))$

概率模型

生成模型:

$$p(y_1, \ldots, y_M | \mathbf{x}, \boldsymbol{\theta}) = \prod_{m=1}^{M} \mathcal{N}(y_m | f_m(\mathbf{x}, \boldsymbol{\theta}), \sigma_m^2)$$

先验分布(信息性先验):

利用确定性预训练网络的参数作为先验均值:

$$p(\boldsymbol{\theta}) = \mathcal{N}(\boldsymbol{\theta}_{\text{det}}, \lambda^2 \mathbf{I})$$

其中:

  • $\boldsymbol{\theta}_{\text{det}}$ 为确定性网络的预训练参数
  • $\lambda$ 为先验标准差(默认0.6),控制先验的信息量

后验分布:

$$p(\boldsymbol{\theta} | \mathcal{D}) \propto p(\mathcal{D} | \boldsymbol{\theta}) p(\boldsymbol{\theta})$$


贝叶斯推断

MCMC推断 (StochasticAutoRegMF)

使用NUTS (No-U-Turn Sampler) 进行后验采样:

$${\boldsymbol{\theta}^{(t)}}_{t=1}^{T} \sim p(\boldsymbol{\theta} | \mathcal{D})$$

预测分布:

均值预测: $$\mu_m(\mathbf{x}) = \mathbb{E}{\boldsymbol{\theta}}[f_m(\mathbf{x}, \boldsymbol{\theta}) | \mathcal{D}] \approx \frac{1}{T}\sum{t=1}^{T} f_m(\mathbf{x}, \boldsymbol{\theta}^{(t)})$$

总不确定性: $$\sigma_m^2(\mathbf{x}) = \underbrace{\text{Var}{\boldsymbol{\theta}}[f_m(\mathbf{x}, \boldsymbol{\theta}) | \mathcal{D}]}{\text{认知不确定性}} + \underbrace{\sigma_{\text{obs}}^2}_{\text{偶然不确定性}}$$

变分推断 (VariationalAutoRegMF)

最大化证据下界(ELBO):

$$\mathcal{L}_{\text{ELBO}} = \mathbb{E}_{q(\boldsymbol{\theta})}[\log p(\mathcal{D}|\boldsymbol{\theta})] - \text{KL}(q(\boldsymbol{\theta}) | p(\boldsymbol{\theta}))$$

支持的变分族:

  • AutoNormal: 单变量正态近似
  • AutoDiagonalNormal: 对角协方差
  • AutoMultivariateNormal: 全协方差
  • AutoLowRankMultivariateNormal: 低秩近似

多保真度主动学习方法

最大加权互信息采集函数

核心思想: 在给定计算预算下,选择能够最大化高保真度预测信息增益的样本点和保真度层级。

互信息计算

基于Wasserstein距离的互信息度量:

$$\rho_m(\mathbf{x}) = \exp\left(-W_2(f_m(\mathbf{x}), f_M(\mathbf{x}))\right)$$

其中 $W_2$ 为2-Wasserstein距离:

$$W_2(f_m, f_M) = \sqrt{(\mu_m - \mu_M)^2 + (\sigma_m - \sigma_M)^2}$$

基于相关系数的互信息(可选):

$$I(f_m; f_M) \approx -\frac{1}{2}\log(1 - \rho^2 + \epsilon)$$

其中 $\rho = \text{Corr}(f_m^{(1:T)}, f_M^{(1:T)})$,裁剪到 $[-0.99, 0.99]$

覆盖权重函数

基于各向同性高斯核密度估计,降低已采样区域的权重:

$$w(\mathbf{x}) = 1 - \min\left(1, \sum_{j \in \mathcal{S}_m} \exp\left(-\frac{|\mathbf{x} - \mathbf{x}_j|^2}{0.3\sigma^2}\right)\right)$$

其中:

  • $\mathcal{S}_m$ 为第 $m$ 层已采样点集合
  • $\sigma$ 为自适应聚类阈值(基于平均最近邻距离)

聚类阈值计算:

$$\sigma = \alpha \cdot \frac{1}{N}\sum_{i=1}^{N} \min_{j \neq i} |\mathbf{x}_i - \mathbf{x}_j|$$

加权互信息增益

对于候选点 $\mathbf{x}$ 和保真度层级 $m$

$$G_m(\mathbf{x}) = w(\mathbf{x}) \cdot \frac{\rho_m(\mathbf{x}) \cdot H(f_M(\mathbf{x}))}{c_m}$$

其中 $H(f_M) = \sigma_M(\mathbf{x})$ 为最高保真度的不确定性(熵代理)。

最优保真度选择:

$$m^*(\mathbf{x}) = \arg\max_{m \in {1,\ldots,M}} G_m(\mathbf{x})$$

$$q(\mathbf{x}) = \max_{m} G_m(\mathbf{x})$$


DPP批量采集

使用行列式点过程(Determinantal Point Process) 实现质量-多样性权衡的批量采集。

DPP核矩阵构造

$$\mathbf{L} = \text{diag}(\mathbf{q}) \cdot \mathbf{S} \cdot \text{diag}(\mathbf{q})$$

其中:

  • $\mathbf{q} = [q(\mathbf{x}_1), \ldots, q(\mathbf{x}_N)]^T$ 为质量分数向量
  • $\mathbf{S}$ 为相似度/多样性核矩阵

多样性核函数

覆盖自适应核(推荐):

$$S_{ij} = \exp\left(-\frac{d_{ij}^2}{2r_{\text{exp}}^2}\right)$$

其中期望覆盖半径:

$$r_{\text{exp}} = \left(\frac{1}{n_U}\right)^{1/D} \cdot \sqrt{D}$$

  • $n_U$ 为批量大小
  • $D$ 为输入维度

距离基核:

$$S_{ij} = \max\left(0, 1 - \frac{d_{ij}}{d_{\max}}\right)^2$$

其中 $d_{\max} = \sqrt{D}$(归一化空间中的最大距离)。

贪心DPP-MAP选择

选择子集 $\mathcal{Y}^*$ 最大化:

$$\Pr(\mathcal{Y}^* \subseteq \mathbf{L}) = \det(\mathbf{L}_{\mathcal{Y}^*})$$

算法步骤(基于Cholesky分解):

  1. 初始化:选择 $i_1 = \arg\max_i L_{ii}$
  2. 对于 $k = 2, \ldots, n_U$
    • 更新Cholesky因子 $\mathbf{C}$
    • 计算条件增益:$g_j = L_{jj} - |\mathbf{v}_j|^2$
    • 选择 $i_k = \arg\max_j g_j$
  3. 返回选中索引集合

使用指南

基本训练流程

import jax.numpy as jnp
from MultiFidelityBNN.model.DeterministicAutoRegMF import DeterministicAutoRegressiveNetwork
from MultiFidelityBNN.model.StochasticAutoRegMF import StochasticAutoRegMF
from MultiFidelityBNN.model.MF_Units import FidelityConfig

# 1. 准备多保真度数据
train_data = {
    'x_1': low_fidelity_inputs,    # shape: (N1, D)
    'y_1': low_fidelity_outputs,   # shape: (N1, 1)
    'x_2': mid_fidelity_inputs,    # shape: (N2, D)
    'y_2': mid_fidelity_outputs,   # shape: (N2, 1)
    'x_3': high_fidelity_inputs,   # shape: (N3, D)
    'y_3': high_fidelity_outputs,  # shape: (N3, 1)
}

# 2. 定义网络配置
configs = [
    FidelityConfig(features=[32, 16], activation='elu'),  # 低保真度
    FidelityConfig(features=[32, 16], activation='elu'),  # 中保真度
    FidelityConfig(features=[32, 16], activation='elu'),  # 高保真度
]

# 3. 训练确定性网络(提供信息性先验)
det_model = DeterministicAutoRegressiveNetwork(
    input_dim=D,
    output_dim=1,
    fidelity_configs=configs,
    learning_rate=0.01,
    sgd_epochs=1000,
    standardize=True
)
det_model.train(train_data)
det_params = det_model.get_params()

# 4. 训练贝叶斯网络
bayes_model = StochasticAutoRegMF(
    input_dim=D,
    output_dim=1,
    fidelity_configs=configs,
    deterministic_params=det_params,  # 信息性先验
    prior_std=0.6,
    rng_seed=42,
    standardize=True
)
bayes_model.fit(
    train_data,
    num_warmup=500,      # MCMC预热步数
    num_samples=1000,    # MCMC采样步数
    num_chains=2,        # 并行链数
    progress_bar=True
)

# 5. 预测
x_test = jnp.linspace(low_bound, high_bound, 100).reshape(-1, 1)
predictions = bayes_model.predict(x_test, return_observation=True)

# 访问预测结果
for m in range(3):
    mean = predictions[f'f_{m+1}']['mean']  # 预测均值
    std = predictions[f'f_{m+1}']['std']    # 预测标准差

主动学习流程

from MultiFidelityBNN.active_learning.sequential_sampling import SequentialSampling
from MultiFidelityBNN.cases_config import Case_2

# 使用预定义案例配置
config = Case_2

# 运行主动学习
results = SequentialSampling(
    funcs=config.Funcs,           # 多保真度函数列表
    D=config.Dim,                 # 输入维度
    fidelity_num=len(config.Funcs),
    cost=config.COST,             # 各层级计算成本 [c1, c2, c3]
    sample_sizes=config.SampleSize,  # 初始样本数 [n1, n2, n3]
    pool_size=config.PoolSize,    # 候选池大小
    configs=config.NetConfigs,    # 网络配置
    method='DDP',                 # 'DDP' 或 'baseline'
    low_bounds=config.LOW,
    high_bounds=config.HIGH,
    warm_up_num=config.WarmUpNum,
    sample_num=config.SampleNum,
    seed=config.Seed,
    batch_size=config.BatchSize,  # 每轮采集样本数
    iteration_num=config.IterationNum,
    test_set=config.TestSet,
    error_threshold=config.ErrorThreshold
)

model, train_data, error_history, cost_history, sample_history = results

互信息计算

# 计算各保真度层级关于最高保真度的互信息
MI = bayes_model.mutual_information(x_test, wasserstein=True)

# MI['f_1']: 低保真度与高保真度的互信息
# MI['f_2']: 中保真度与高保真度的互信息
# MI['f_3']: 高保真度自身(恒为1.0)

# 计算经验预测核(用于DPP)
S = bayes_model.empirical_prediction_kernel(x_test)
# S: N x N 相关系数矩阵

案例配置

本框架提供6个测试案例,维度从1D到15D:

案例 维度 保真度数 成本 初始样本 网络结构
Case_1 1D 2 [2, 8] [8, 4] [16, 8]
Case_2 1D 3 [2, 4, 8] [8, 6, 4] [16, 4]
Case_3 2D 3 [2, 4, 12] [20, 10, 5] [32, 16, 8]
Case_4 6D 3 [2, 4, 12] [60, 30, 10] [64, 64]
Case_5 10D 3 [2, 4, 12] [100, 30, 20] [64, 64]
Case_6 15D 3 [2, 4, 12] [100, 30, 20] [64, 64]

运行案例:

cd G:/code/MultiFidelityBayes/MultiFidelityBNN
python main.py  # 默认运行Case_2
from main import main
results = main(case=3)  # 运行2D案例

依赖项

jax
jaxlib
flax
numpyro
optax
scipy
numpy
matplotlib
tqdm

数学符号汇总

符号 含义
$M$ 保真度层级数
$D$ 输入维度
$\mathbf{x}$ 输入向量
$f_m$ $m$ 层神经网络
$c_m$ $m$ 层计算成本
$\boldsymbol{\theta}$ 网络参数
$\sigma_m^2$ $m$ 层观测噪声方差
$\rho_m$ $m$ 层与最高层的相似度
$w(\mathbf{x})$ 覆盖权重函数
$G_m(\mathbf{x})$ 加权互信息增益
$\mathbf{L}$ DPP核矩阵
$\mathbf{S}$ 多样性核矩阵
$n_U$ 批量采集大小

算法流程图

┌─────────────────────────────────────────────────────────────────┐
│                      初始化试验设计                               │
│                 (拉丁超立方采样, 各层级)                          │
└─────────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────────┐
│                     主动学习循环                                  │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ 1. 训练确定性网络 (SGD + SWA)                              │  │
│  │    └─> 获取预训练参数 θ_det                                │  │
│  └───────────────────────────────────────────────────────────┘  │
│                              │                                   │
│                              ▼                                   │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ 2. 训练贝叶斯网络 (MCMC-NUTS)                              │  │
│  │    └─> 后验采样 {θ^(t)}                                    │  │
│  └───────────────────────────────────────────────────────────┘  │
│                              │                                   │
│                              ▼                                   │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ 3. 评估测试误差                                            │  │
│  │    └─> 若 AAE < threshold, 提前停止                        │  │
│  └───────────────────────────────────────────────────────────┘  │
│                              │                                   │
│                              ▼                                   │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ 4. DPP批量采集                                             │  │
│  │    ├─> 计算候选池质量分数 q(x)                             │  │
│  │    ├─> 构建多样性核 S                                      │  │
│  │    ├─> 构建DPP核 L = diag(q) S diag(q)                    │  │
│  │    └─> 贪心选择 n_U 个样本                                 │  │
│  └───────────────────────────────────────────────────────────┘  │
│                              │                                   │
│                              ▼                                   │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ 5. 评估新样本并更新数据集                                   │  │
│  └───────────────────────────────────────────────────────────┘  │
│                              │                                   │
│                              ▼                                   │
│                        返回步骤1                                 │
└─────────────────────────────────────────────────────────────────┘

参考文献

  1. Kennedy, M. C., & O'Hagan, A. (2000). Predicting the output from a complex computer code when fast approximations are available. Biometrika.
  2. Kulesza, A., & Taskar, B. (2012). Determinantal point processes for machine learning. Foundations and Trends in Machine Learning.
  3. Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn sampler. Journal of Machine Learning Research.

About

基于JAX和Numpyro的多保真度贝叶斯深度学习框架,实现了自回归神经网络架构、不确定性量化以及基于DPP的多保真度主动学习方法。

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages