别被公式吓到!用Python和PyTorch手把手实现NeRF里的球面谐波(Spherical Harmonics)

张开发
2026/6/11 10:48:11 15 分钟阅读

分享文章

别被公式吓到!用Python和PyTorch手把手实现NeRF里的球面谐波(Spherical Harmonics)
别被公式吓到用Python和PyTorch手把手实现NeRF里的球面谐波Spherical Harmonics在3D重建领域球面谐波Spherical Harmonics, SH正成为NeRF、3D高斯泼溅3DGS等技术的核心组件。许多开发者被其复杂的数学表达式劝退却不知其代码实现远比公式直观。本文将用PyTorch从零构建SH函数带你穿透数学迷雾直击工程实现的本质。1. 环境准备与基础概念首先确保你的Python环境已安装以下库pip install torch matplotlib numpy球面谐波本质是一组定义在球面上的正交基函数类似于傅里叶级数在球坐标系的扩展。在NeRF中SH主要用于编码视角相关的颜色变化。其核心优势在于紧凑性低阶SH即可高精度拟合球面函数旋转不变性基函数在旋转时保持正交性计算高效只需预计算基函数值即可重复使用import torch import math import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D2. SH基函数的PyTorch实现2.1 极坐标转换SH基函数在球坐标系下定义需先将笛卡尔坐标转换为极坐标def cartesian_to_spherical(xyz): Convert Cartesian coordinates to spherical coordinates x, y, z xyz.unbind(-1) r torch.norm(xyz, dim-1) theta torch.acos(z / (r 1e-8)) # polar angle phi torch.atan2(y, x) # azimuthal angle return torch.stack([r, theta, phi], dim-1)2.2 关联勒让德多项式SH的实现依赖于关联勒让德多项式。以下是PyTorch优化版本def associated_legendre(l, m, x): Compute associated Legendre polynomials P_l^m(x) p_mm torch.ones_like(x) if m 0: p_mm (-1)**m * torch.prod(torch.arange(1, 2*m1, 2)) * (1 - x**2)**(m/2) if l m: return p_mm p_mp1m x * (2*m 1) * p_mm if l m 1: return p_mp1m p_lm torch.zeros_like(x) for n in range(m 2, l 1): p_lm ((2*n - 1) * x * p_mp1m - (n m - 1) * p_mm) / (n - m) p_mm, p_mp1m p_mp1m, p_lm return p_mp1m2.3 完整SH基函数组合上述组件实现SH基函数def spherical_harmonics(l, m, theta, phi): Compute real spherical harmonics Y_l^m(theta, phi) if m 0: Y math.sqrt(2) * associated_legendre(l, m, torch.cos(theta)) * torch.cos(m * phi) elif m 0: Y math.sqrt(2) * associated_legendre(l, -m, torch.cos(theta)) * torch.sin(-m * phi) else: Y associated_legendre(l, 0, torch.cos(theta)) return Y * math.sqrt((2*l 1)/(4*math.pi))3. 可视化与验证3.1 SH基函数可视化使用matplotlib绘制前9个SH基函数l0,1,2def visualize_sh(l_max2): fig plt.figure(figsize(15, 10)) theta torch.linspace(0, math.pi, 100) phi torch.linspace(0, 2*math.pi, 100) theta, phi torch.meshgrid(theta, phi) pos 1 for l in range(l_max 1): for m in range(-l, l 1): ax fig.add_subplot(l_max 1, 2*l_max 1, pos, projection3d) Y spherical_harmonics(l, m, theta, phi) # Convert to Cartesian for visualization x torch.sin(theta) * torch.cos(phi) * Y.abs() y torch.sin(theta) * torch.sin(phi) * Y.abs() z torch.cos(theta) * Y.abs() ax.plot_surface(x.numpy(), y.numpy(), z.numpy(), cmapviridis, edgecolornone) ax.set_title(fl{l}, m{m}) pos 1 plt.tight_layout() plt.show()3.2 数值验证验证SH的正交性def verify_orthogonality(l1, m1, l2, m2, n_samples1000): Verify orthogonality of SH functions theta torch.rand(n_samples) * math.pi phi torch.rand(n_samples) * 2 * math.pi Y1 spherical_harmonics(l1, m1, theta, phi) Y2 spherical_harmonics(l2, m2, theta, phi) integral torch.mean(Y1 * Y2 * torch.sin(theta)) * 4 * math.pi print(fY_{l1}^{m1}|Y_{l2}^{m2} {integral.item():.4f})提示实际应用中SH基函数通常预计算并存储为查找表以提升性能4. 集成到NeRF颜色网络4.1 SH系数学习在NeRF中SH系数通常作为网络输出的一部分class SHColorNetwork(torch.nn.Module): def __init__(self, sh_degree2, hidden_dim128): super().__init__() self.sh_degree sh_degree self.n_sh_coeffs (sh_degree 1)**2 # MLP to predict SH coefficients and density self.mlp torch.nn.Sequential( torch.nn.Linear(3, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, self.n_sh_coeffs * 3 1) # RGB × SH sigma ) def forward(self, x, d): # x: 3D position, d: viewing direction (normalized) output self.mlp(x) sigma torch.sigmoid(output[..., :1]) sh_coeffs output[..., 1:].view(-1, 3, self.n_sh_coeffs) # Compute SH basis for viewing direction spherical cartesian_to_spherical(d) theta, phi spherical[..., 1], spherical[..., 2] basis [] for l in range(self.sh_degree 1): for m in range(-l, l 1): basis.append(spherical_harmonics(l, m, theta, phi)) basis torch.stack(basis, dim-1) # [..., n_coeffs] # Compute RGB color rgb torch.einsum(...c, ...s - ...c, sh_coeffs, basis) return torch.sigmoid(rgb), sigma4.2 训练技巧实际训练时需注意初始化策略def init_weights(m): if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_uniform_(m.weight) torch.nn.init.zeros_(m.bias) model.apply(init_weights)学习率调整optimizer torch.optim.Adam(model.parameters(), lr1e-3) scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size1000, gamma0.5)正则化方法# Add L2 regularization on SH coefficients def sh_regularization(model): loss 0 for param in model.mlp[-1].parameters(): loss torch.norm(param, p2) return loss * 0.015. 性能优化与调试5.1 内存优化技巧当处理高分辨率图像时# 使用torch.utils.checkpoint减少内存占用 from torch.utils.checkpoint import checkpoint class MemoryEfficientSH(torch.nn.Module): def forward(self, x, d): def create_custom_forward(module): def custom_forward(*inputs): return module(inputs[0]) return custom_forward # Only save intermediate activations for the MLP output checkpoint(create_custom_forward(self.mlp), x) # ... rest of the computation ...5.2 常见问题排查问题现象可能原因解决方案颜色出现带状伪影SH阶数不足增加sh_degree到3或4训练不收敛系数初始化不当使用Xavier初始化并减小初始学习率渲染速度慢重复计算基函数预计算SH基函数查找表5.3 混合精度训练利用AMP加速训练scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): rgb_pred, sigma_pred model(x, d) loss compute_loss(rgb_pred, sigma_pred, rgb_gt) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()6. 进阶应用与扩展6.1 动态场景处理对于动态3DGS可扩展SH系数为时变函数class DynamicSH(torch.nn.Module): def __init__(self, n_frames, sh_degree3): super().__init__() self.sh_coeffs torch.nn.Parameter( torch.rand(n_frames, (sh_degree 1)**2, 3) * 0.01) def get_coeffs(self, frame_idx): return self.sh_coeffs[frame_idx]6.2 各向异性反射建模通过组合不同阶数的SH实现复杂材质def anisotropic_sh(d, sh_coeffs_list): Combine multiple SH representations basis compute_sh_basis(d) rgb 0 for coeffs, weight in zip(sh_coeffs_list, [0.3, 0.5, 0.2]): rgb weight * torch.einsum(...c, ...s - ...c, coeffs, basis) return rgb6.3 与其他编码方式结合将SH与位置编码结合提升表现力class HybridEncoder(torch.nn.Module): def __init__(self, pos_enc_dim10, sh_degree2): super().__init__() self.pos_encoder PositionalEncoding(pos_enc_dim) self.sh_encoder SHEncoder(sh_degree) def forward(self, x, d): pos_feat self.pos_encoder(x) sh_feat self.sh_encoder(d) return torch.cat([pos_feat, sh_feat], dim-1)

更多文章