通俗易懂讲透随机梯度下降法(SGD)

张开发
2026/4/16 14:03:43 15 分钟阅读

分享文章

通俗易懂讲透随机梯度下降法(SGD)
通俗易懂讲透随机梯度下降法SGD本科生/研究生都能看懂本文用大白话下山比喻公式拆解完整代码可视化把随机梯度下降SGD从原理、流程、优缺点到实战讲得明明白白适合机器学习入门、面试复习、课程笔记。一、先搞懂什么是随机梯度下降SGD一句话定义SGD 每次只随机抽一个样本算梯度然后更新参数的梯度下降。超级形象比喻你在下山找谷底批量梯度下降BGD每走一步把整座山地形看一遍 → 准但超级慢随机梯度下降SGD每步只看脚下一小块 → 快但会晃悠二、为什么要用 SGD在大数据时代数据动不动几百万、几千万批量梯度下降根本跑不动SGD 每步只算一个样本速度起飞三、SGD 核心思想超简单随机抽一个样本用它算梯度沿梯度反方向更新参数重复几万次 → 自然收敛因为是随机抽样本梯度带点“噪声”反而能跳出局部最优。四、数学公式超级易懂1. 损失函数L(θ)1N∑i1Nℓ(θ;xi,yi) L(\theta) \frac{1}{N}\sum_{i1}^N \ell(\theta;x_i,y_i)L(θ)N1​i1∑N​ℓ(θ;xi​,yi​)2. SGD 更新公式θt1θt−η⋅∇ℓ(θt;xi,yi) \theta_{t1} \theta_t - \eta \cdot \nabla \ell(\theta_t;x_i,y_i)θt1​θt​−η⋅∇ℓ(θt​;xi​,yi​)η学习率∇ℓ随机一个样本的梯度五、SGD 完整算法流程4步背会初始化参数 θ随机抽 1 个样本 (xix_ixi​,yiy_iyi​)计算梯度更新参数重复直到损失收敛六、代码实战SGD 训练线性回归直接复制可运行包含大数据集生成SGD 实现损失曲线 预测对比图importnumpyasnpimportmatplotlib.pyplotaspltfromsklearn.datasetsimportmake_regressionfromsklearn.model_selectionimporttrain_test_splitfromsklearn.preprocessingimportStandardScalerfromsklearn.metricsimportmean_squared_error# 1. 生成大数据集 X,ymake_regression(n_samples100000,n_features10,noise0.1,random_state42)X_train,X_test,y_train,y_testtrain_test_split(X,y,test_size0.2,random_state42)# 标准化scalerStandardScaler()X_trainscaler.fit_transform(X_train)X_testscaler.transform(X_test)# 2. SGD 回归实现 classSGDRegressor:def__init__(self,learning_rate0.01,n_iterations100,batch_size1):self.lrlearning_rate self.n_itern_iterations self.batch_sizebatch_size# 1SGD1小批量self.losses[]deffit(self,X,y):n_samples,n_featuresX.shape self.wnp.zeros(n_features)self.b0for_inrange(self.n_iter):# 随机打乱idxnp.random.permutation(n_samples)X_shufX[idx]y_shufy[idx]# 按批次遍历foriinrange(0,n_samples,self.batch_size):XbX_shuf[i:iself.batch_size]yby_shuf[i:iself.batch_size]y_predXb self.wself.b errory_pred-yb# 梯度grad_w(1/len(Xb))*Xb.T error grad_b(1/len(Xb))*np.sum(error)# 更新self.w-self.lr*grad_w self.b-self.lr*grad_b# 记录损失y_train_predX self.wself.b self.losses.append(mean_squared_error(y,y_train_pred))returnselfdefpredict(self,X):returnX self.wself.b# 3. 训练 SGD modelSGDRegressor(learning_rate0.01,n_iterations100,batch_size1)model.fit(X_train,y_train)# 4. 评估 y_predmodel.predict(X_test)msemean_squared_error(y_test,y_pred)print(f测试集 MSE {mse:.4f})# 5. 损失曲线 plt.figure(figsize(12,5))plt.plot(model.losses,b-,linewidth2)plt.title(SGD 训练损失曲线)plt.xlabel(迭代轮次)plt.ylabel(MSE)plt.grid()plt.show()# 6. 预测对比 plt.figure(figsize(12,5))plt.scatter(y_test,y_pred,alpha0.2)plt.plot([y_test.min(),y_test.max()],[y_test.min(),y_test.max()],r-,linewidth2)plt.title(真实值 vs 预测值)plt.xlabel(真实)plt.ylabel(预测)plt.grid()plt.show()七、SGD 优点速度极快每步只算一个样本内存占用小不用加载全部数据能跳出局部最优随机噪声帮助脱困适合大规模数据深度学习标配八、SGD 缺点梯度噪声大更新路径震荡收敛不稳定后期抖动学习率难调太大发散太小太慢后期收敛慢震荡着接近最低点九、BGD vs SGD vs Mini-batch GD速记算法每次用多少数据速度稳定性适用场景BGD全部最慢最稳小数据集SGD1个最快震荡大数据、深度学习Mini-batch一小批中较稳工业界通用十、SGD 适用场景✅适合大规模数据集深度学习CNN、RNN、Transformer在线学习、流式数据非凸优化❌不适合极小数据集追求绝对稳定收敛十一、一句话总结随机梯度下降SGD是大数据与深度学习的基石优化器用“随机采样快速更新”实现高效训练虽然会震荡但速度无人能敌。

更多文章