从理论到代码:手把手推导MMD公式,并用NumPy/PyTorch复现一个简易的图像质量评估器

张开发
2026/4/18 11:00:33 15 分钟阅读

分享文章

从理论到代码:手把手推导MMD公式,并用NumPy/PyTorch复现一个简易的图像质量评估器
从零推导MMD用NumPy和PyTorch构建图像质量评估器在计算机视觉领域判断两组图像是否来自同一分布是个基础但关键的问题。想象你训练了一个生成模型如何量化它生成的图片与真实照片的差距传统方法如PSNR、SSIM只能捕捉像素级差异而**最大均值差异MMD**提供了一种更本质的分布相似性度量方式。今天我们将抛开框架封装从数学原理出发用纯手工实现揭开MMD的神秘面纱。1. MMD的核心思想与数学原理1.1 分布差异的直观理解假设你面前有两堆沙子如何判断它们是否来自同一个沙坑最直接的方法是抓取样本进行对比传统方法比较沙粒的平均大小、颜色类似PSNR/SSIMMMD方法将沙粒放入特殊显微镜核函数观察微观结构比较所有可能视角下的平均特征数学上MMD通过将数据映射到再生核希尔伯特空间RKHS在该空间中计算两个分布样本均值的距离。关键公式如下MMD^2(P,Q) \mathbb{E}_{x,x}[k(x,x)] \mathbb{E}_{y,y}[k(y,y)] - 2\mathbb{E}_{x,y}[k(x,y)]提示当且仅当PQ时MMD0。高斯核是最常用的核函数其带宽参数σ控制着特征空间的尺度敏感性。1.2 核函数的选择艺术不同的核函数就像不同的显微镜镜头核类型公式适用场景高斯核exp(-拉普拉斯核exp(-线性核xᵀy计算简单表达能力弱在图像领域我们通常采用多尺度高斯核组合来捕捉不同层次的特征差异。2. NumPy基础实现2.1 核矩阵计算我们先实现核心的高斯核函数import numpy as np def gaussian_kernel(X, Y, sigma1.0): 计算样本集X和Y之间的高斯核矩阵 X: (m,d) numpy数组 Y: (n,d) numpy数组 sigma: 高斯核带宽 返回: (m,n)核矩阵 XX np.sum(X**2, axis1)[:, np.newaxis] YY np.sum(Y**2, axis1)[np.newaxis, :] distances XX YY - 2 * np.dot(X, Y.T) return np.exp(-distances / (2 * sigma**2))2.2 完整MMD计算流程结合核函数实现MMDdef mmd_naive(X, Y, kernelgaussian, sigma1.0): 基础MMD实现 if kernel gaussian: K_XX gaussian_kernel(X, X, sigma) K_YY gaussian_kernel(Y, Y, sigma) K_XY gaussian_kernel(X, Y, sigma) else: raise ValueError(Unsupported kernel) m X.shape[0] n Y.shape[0] # 无偏估计版本 term1 (K_XX.sum() - np.trace(K_XX)) / (m*(m-1)) term2 (K_YY.sum() - np.trace(K_YY)) / (n*(n-1)) term3 K_XY.sum() * 2 / (m*n) return np.sqrt(term1 term2 - term3)注意这里使用了无偏估计版本避免对角线元素自相似度对结果的影响。实际应用中可能需要考虑计算效率的优化。3. PyTorch优化实现3.1 批处理与GPU加速用PyTorch重构实现支持自动微分和GPU加速import torch def mmd_rbf(X, Y, kernel_mul2.0, kernel_num5): 多尺度高斯核MMD实现 batch_size X.size(0) total torch.cat([X, Y], dim0) # 计算成对距离矩阵 XX torch.sum(X**2, dim1, keepdimTrue) YY torch.sum(Y**2, dim1, keepdimTrue) distances XX YY.T - 2 * torch.mm(X, Y.T) # 自适应带宽选择 median_distance torch.median(distances.detach()) sigma_list [median_distance * (kernel_mul**i) for i in range(-kernel_num//2, kernel_num//21)] # 多核组合 kernel_val sum(torch.exp(-distances / (2 * sigma**2)) for sigma in sigma_list) # MMD计算 K_XX kernel_val[:batch_size, :batch_size] K_YY kernel_val[batch_size:, batch_size:] K_XY kernel_val[:batch_size, batch_size:] return torch.mean(K_XX) torch.mean(K_YY) - 2 * torch.mean(K_XY)3.2 实际图像评估案例让我们在CIFAR-10数据上测试from torchvision import datasets, transforms # 数据准备 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]) real_data datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtransform) fake_data datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformtransform) # 提取特征 def get_features(dataloader, sample_size1000): features [] for i, (images, _) in enumerate(dataloader): features.append(images.flatten(start_dim1)) if len(features) * images.size(0) sample_size: break return torch.cat(features)[:sample_size] real_features get_features(torch.utils.data.DataLoader(real_data, batch_size100)) fake_features get_features(torch.utils.data.DataLoader(fake_data, batch_size100)) # 计算MMD distance mmd_rbf(real_features, fake_features) print(fMMD距离: {distance.item():.4f})4. 工业级实现对比4.1 Inception特征空间的重要性原始像素空间的MMD可能无法捕捉语义级差异。工业界常用做法使用Inception-v3的中间层作为特征提取器在特征空间计算MMD即著名的FID指标基础采用多尺度金字塔结构增强鲁棒性from torchvision.models import inception_v3 class InceptionMMD: def __init__(self, devicecuda): self.model inception_v3(pretrainedTrue) self.model.fc torch.nn.Identity() # 移除全连接层 self.model.to(device) self.model.eval() def get_features(self, images): with torch.no_grad(): return self.model(images) def __call__(self, X, Y): feat_X self.get_features(X) feat_Y self.get_features(Y) return mmd_rbf(feat_X, feat_Y)4.2 实际应用建议在真实项目中对于256x256以上图像建议使用多层级特征提取批量大小至少64以上以获得稳定估计可结合自适应带宽选择策略考虑使用线性时间估计版本处理大数据集def linear_mmd(feat_X, feat_Y): 线性复杂度MMD估计 kernel lambda x,y: torch.exp(-torch.norm(x-y, p2)**2 / feat_X.size(1)) phi_X torch.mean(torch.stack([kernel(x, x) for x in feat_X]), dim0) phi_Y torch.mean(torch.stack([kernel(y, y) for y in feat_Y]), dim0) return torch.norm(phi_X - phi_Y, p2)5. 常见问题与调优技巧5.1 超参数选择策略带宽σ通常取数据 pairwise 距离的中位数核数量3-5个不同尺度通常足够样本量每类至少1000个样本可获得稳定估计5.2 数值稳定性处理当数据尺度差异大时def stable_mmd(X, Y): # 数据标准化 X (X - X.mean()) / X.std() Y (Y - Y.mean()) / Y.std() # 添加小常数防止数值溢出 epsilon 1e-6 distances torch.cdist(X, Y) epsilon # 使用log-sum-exp技巧 max_val torch.max(distances) kernel_val torch.exp(-(distances - max_val)**2 / 2) return kernel_val.mean()5.3 与其他指标对比在图像生成评估中指标计算复杂度感知相关性是否需要参考图像MMDO(n²)中是FIDO(n²)高是ISO(n)中否PSNRO(n)低是实际项目中我通常会同时计算MMD和FID作为互补指标。当发现MMD降低但FID升高时往往意味着模型出现了模式崩溃。

更多文章