Gradient Descent
1. Definition
Scope: supervised learning relies on gradient descent.
Goal: minimize the loss function $loss = (f(x) - y)^2$, where $f(x)$ is the predicted value and $y$ is the ground-truth value.
Function class: here the loss function is assumed convex, meaning the line segment connecting any two points on the curve never lies below the curve (see the condition below).
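For reference, the standard convexity condition (a known fact, stated here for completeness rather than taken from the original): for any two points $x_1, x_2$ and any $\lambda \in [0,1]$,

$f(\lambda x_1 + (1-\lambda)x_2) \le \lambda f(x_1) + (1-\lambda) f(x_2)$,

i.e. the chord lies on or above the graph, so any local minimum is also a global minimum.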
At $x_1$, $\frac{\partial y}{\partial x} > 0$: $y$ increases as $x$ increases, so to move toward the minimum we update $x_{new} = x_1 + (-1) \cdot \frac{\partial y}{\partial x}$.
At $x_2$, $\frac{\partial y}{\partial x} < 0$: $y$ decreases as $x$ increases, so the same update $x_{new} = x_2 + (-1) \cdot \frac{\partial y}{\partial x}$ moves $x$ to the right, again toward the minimum.
At $x_3$, the partial derivative is large, so $x_{new} = x_3 + (-1) \cdot \frac{\partial y}{\partial x}$ lands far to the left of $x_3$; in later iterations, $x_{new}$ keeps jumping back and forth across the minimum. To avoid this, a learning rate $\alpha$ is introduced, e.g. $\alpha = 10^{-4}$: $x_{new} = x_3 + (-1) \cdot \frac{\partial y}{\partial x} \cdot \alpha$, which prevents the oscillation (the sketch below demonstrates this).
The gradient descent update is therefore: $x = x + (-1) \cdot \frac{\partial y}{\partial x} \cdot \alpha$
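To make the overshooting concrete, here is a minimal sketch (not part of the original example; the function and step sizes are chosen purely for illustration) comparing a step size that is too large against a small one on $y = 10x^2 + 37x + 9$, whose minimum is at $x = -1.85$:

# Minimal sketch: the update x <- x + (-1) * dy/dx * alpha on y = 10x^2 + 37x + 9,
# whose derivative is 20x + 37 and whose minimum sits at x = -1.85.
def gradient(x):
    return 20 * x + 37

for alpha in (0.11, 1e-3):  # 0.11 is too large here, 1e-3 is safe
    x = 1.0
    for _ in range(10):
        x += (-1) * gradient(x) * alpha
    print(alpha, x)  # alpha=0.11 overshoots and swings ever farther from -1.85;
                     # alpha=1e-3 creeps steadily toward it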
2. Code Example
import numpy as np
import matplotlib.pyplot as plt
import random

# Define the function
def func(x):
    return 10 * x**2 + 37 * x + 9

# Define its derivative
def gradient(x):
    return 20 * x + 37

# Plot the function on (-10, 10)
x = np.linspace(-10, 10)
plt.plot(x, func(x))
plt.show()

# steps stores the value of x at each iteration
steps = []
# Randomly initialize x (picks one of the linspace points)
x_optimal = random.choice(x)
# Set the learning rate
alpha = 1e-3
# Run 200 steps of gradient descent
for i in range(200):
    x_optimal += (-1) * gradient(x_optimal) * alpha
    steps.append(x_optimal)
# Print the intermediate steps
for s in steps:
    print(s, func(s))
Output:
0.9630000000000006 53.90469000000003
0.9067400000000007 50.77115427600004
0.8516052000000006 47.76170656667043
0.7975730960000006 44.87143298663028
0.7446216340800006 42.09561424035972
0.6927292013984006 39.429717916441476
0.6418746173704325 36.869391086950394
0.5920371250230239 34.41045319990716
0.5431963825225634 32.048889253190836
0.49533245487211214 29.78084323876448
0.4484258057746699 27.602611846509404
0.4024572896591765 25.510638417387632
0.35740814386599296 23.501507136059082
0.3132599809886731 21.571937453471143
0.2699947813688996 19.718778730313684
0.2275948857415216 17.939005092593263
0.18604298802669117 16.229710490926568
0.14532212826615734 14.588103955485877
0.10541568570083419 13.011505038848636
0.0663073719868175 11.49733943931023
0.02798122454708115 10.043134797513545
-0.00957839994386047 8.646516659532008
-0.04638683194498326 7.305204599814541
-0.0824590953060836 6.017008497661885
-0.11780991339996194 4.779824961154473
-0.1524537151319627 3.591633892692756
-0.18640464082932345 2.4504951905421235
-0.21967654801273698 1.3545455809966551
-0.25228301705248224 0.3019955759891868
-0.2842373567114326 -0.7088734488199826
-0.31555260957720394 -1.6797120602467128
-0.34624155738565987 -2.612105462660944
-0.37631672623794665 -3.5075760863395686
-0.4057903917131877 -4.367586073320522
-0.43467458387892394 -5.193539664817031
-0.4629810922013455 -5.986785494090274
-0.49072147035731856 -6.7486187885243005
-0.5179070409501721 -7.4802834844987345
-0.5445489001311687 -8.182974258512587
-0.5706579221285454 -8.85783847787549
-0.5962447636859745 -9.50597807415162
-0.621319868412255 -10.128451342415218
-0.6458934710440098 -10.726274669255574
-0.6699756016231296 -11.300424192353052
-0.6935760895906671 -11.85183739433587
-0.7167045677988537 -12.381414633520173
-0.7393704764428767 -12.890020614032771
-0.7615830669140191 -13.378485797717076
-0.7833514055757387 -13.847607760127474
-0.8046843774642239 -14.298152492826432
-0.8255906899149394 -14.730855654110503
-0.8460788761166407 -15.146423770207726
-0.8661572985943079 -15.5455353889075
-0.8858341526224217 -15.928842187506763
-0.9051174695699733 -16.296970036881497
-0.9240151201785738 -16.650520023420988
-0.9425348177750024 -16.99006943049352
-0.9606841214195023 -17.316172681045977
-0.9784704389911123 -17.629362242876557
-0.9959010302112901 -17.930149498058643
-1.0129830096070642 -18.219025577935522
-1.029723349414923 -18.496462165049273
-1.0461288824266244 -18.76291226331332
-1.0622063047780919 -19.01881093768611
-1.07796217868253 -19.264576024553744
-1.0934029351088794 -19.500608813981415
-1.1085348764067018 -19.727294704947752
-1.1233641788785678 -19.94500383463182
-1.1378968953009965 -20.154091682780397
-1.1521389573949765 -20.354899652142294
-1.166096178247077 -20.547755625917457
-1.1797742546821353 -20.73297450313113
-1.1931787695884926 -20.910858712807133
-1.2063151941967227 -21.08169870777997
-1.2191888903127883 -21.245773438951886
-1.2318051125065326 -21.403350810769393
-1.244169010256402 -21.554688118662927
-1.256285630051274 -21.70003246916387
-1.2681599174502485 -21.83962118338498
-1.2797967191012436 -21.973682184522936
-1.2912007847192186 -22.10243437001583
-1.3023767690248342 -22.226087968963206
-1.3133292336443376 -22.344844885392263
-1.3240626489714509 -22.45889902793073
-1.3345813959920219 -22.56843662642467
-1.3448897680721814 -22.673636536018254
-1.3549919727107378 -22.77467052919193
-1.3648921332565231 -22.871703576235934
-1.3745942905913926 -22.964894114616985
-1.3841024047795647 -23.05439430767815
-1.3934203566839733 -23.140350293094095
-1.4025519495502938 -23.222902421487568
-1.411500910559288 -23.302185485596667
-1.4202708923481022 -23.37832894036704
-1.4288654745011402 -23.451457114328505
-1.4372881650111173 -23.52168941260109
-1.445542401710895 -23.58914051186209
-1.4536315536766773 -23.653920547592357
-1.4615589226031438 -23.716135293907698
-1.469327744151081 -23.775886336268954
-1.4769411892680595 -23.833271237352704
-1.4844023654826983 -23.888383696353536
-1.4917143181730443 -23.941313701977933
-1.4988800318095834 -23.992147679379613
-1.5059024311733917 -24.040968631276172
-1.5127843825499239 -24.08785627347764
-1.5195286948989253 -24.13288716504792
-1.5261381210009468 -24.176134833312034
-1.532615358580928 -24.21766989391287
-1.5389630514093093 -24.25756016611392
-1.5451837903811232 -24.29587078353581
-1.5512801145735007 -24.332664300507794
-1.5572545122820307 -24.368000794207674
-1.5631094220363901 -24.40193796275706
-1.5688472335956622 -24.434531219431875
-1.574470288923749 -24.465833783142372
-1.579980883145274 -24.495896765329938
-1.5853812654823685 -24.524769253422875
-1.590673640172721 -24.552498390987324
-1.5958601673692667 -24.57912945470423
-1.6009429640218813 -24.604705928297946
-1.6059241047414436 -24.629269573537343
-1.6108056226466148 -24.652860498425262
-1.6155895101936824 -24.675517222687617
-1.6202777199898086 -24.69727674066919
-1.6248721655900125 -24.718174581738694
-1.6293747222782122 -24.738244868301834
-1.633787227832648 -24.75752037151709
-1.638111483275995 -24.776032564805007
-1.6423492536104751 -24.79381167523873
-1.6465022685382655 -24.810886732899277
-1.6505722231675002 -24.82728561827647
-1.65456077870415 -24.843035107792716
-1.658469563130067 -24.85816091752413
-1.6623001718674657 -24.87268774519017
-1.6660541684301164 -24.886639310480646
-1.6697330850615142 -24.900038393785607
-1.673338423360284 -24.91290687339169
-1.6768716548930782 -24.92526576120538
-1.6803342217952166 -24.937135237061653
-1.6837275373593124 -24.948534681674012
-1.6870529866121262 -24.959482708279722
-1.6903119268798836 -24.969997193031844
-1.6935056883422859 -24.980095304187785
-1.6966355745754402 -24.989793530141945
-1.6997028630839315 -24.999107706348326
-1.7027088058222528 -25.008053041176936
-1.7056546297058077 -25.016644140746322
-1.7085415371116917 -25.024895032772775
-1.7113707063694579 -25.03281918947497
-1.7141432922420687 -25.040429549571762
-1.7168604263972274 -25.04773853940872
-1.7195232178692827 -25.05475809324814
-1.722132753511897 -25.06149967275551
-1.724690098441659 -25.067974285714385
-1.727196296472826 -25.074192504000102
-1.7296523705433695 -25.080164480841702
-1.732059323132502 -25.085899967400373
-1.7344181366698521 -25.091408328691315
-1.7367297739364551 -25.09669855887514
-1.738995178457726 -25.101779295943672
-1.7412152748885714 -25.106658835824312
-1.7433909693908 -25.111345145925668
-1.745523150002984 -25.115845878147013
-1.7476126870029243 -25.120168381372388
-1.7496604332628658 -25.124319713470037
-1.7516672245976086 -25.128306652816633
-1.7536338801056564 -25.132135709365095
-1.7555612025035432 -25.135813135274233
-1.7574499784534723 -25.13934493511738
-1.7593009788844027 -25.142736875686722
-1.7611149593067146 -25.145994495409532
-1.7628926601205803 -25.149123113391312
-1.7646348069181688 -25.152127838101023
-1.7663421107798054 -25.155013575712218
-1.7680152685642092 -25.157785038114014
-1.769654963192925 -25.1604467506047
-1.7712618639290665 -25.163003059280754
-1.7728366266504851 -25.165458138133232
-1.7743798941174755 -25.16781599586315
-1.775892296235126 -25.170080482426968
-1.7773744503104234 -25.172255295322863
-1.778826961304215 -25.174343985628084
-1.7802504220781306 -25.176349963797207
-1.781645413636568 -25.178276505230848
-1.7830125053638366 -25.18012675562369
-1.7843522552565598 -25.181903736101006
-1.7856652101514285 -25.18361034815139
-1.7869519059484 -25.185249378364603
-1.788212867829432 -25.18682350298137
-1.7894486104728433 -25.18833529226331
-1.7906596382633864 -25.18978721468968
-1.7918464454981187 -25.19118164098797
-1.7930095165881563 -25.192520848004847
-1.7941493262563932 -25.193807022423847
-1.7952663397312654 -25.195042264335875
-1.79636101293664 -25.196228590668163
-1.7974337926779074 -25.19736793847771
-1.7984851168243492 -25.19846216811399
-1.7995154144878622 -25.199513066256678
Mathematical explanation:
To find the minimum of $y = 10x^2 + 37x + 9$, first differentiate the function: $\frac{\partial{y}}{\partial{x}} = 20x + 37$. The minimum of $y$ lies where $\frac{\partial{y}}{\partial{x}}$ equals 0; setting $\frac{\partial{y}}{\partial{x}} = 0$ gives x = -1.85, at which y = -25.225.
As long as the code above runs for enough iterations, it gets arbitrarily close to this mathematical minimum.
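As a quick check (a sketch reusing func and gradient from the code above; the starting point and iteration count are arbitrary):

x_optimal = 1.0          # arbitrary starting point
alpha = 1e-3
for i in range(20000):   # far more than the 200 iterations above
    x_optimal += (-1) * gradient(x_optimal) * alpha
print(x_optimal, func(x_optimal))  # ~ -1.85, ~ -25.225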
3. Summary:
The above is only a simple illustration of gradient descent; the functions encountered in real work are far more complex. PyTorch also provides an automatic differentiation tool that updates weights via the chain rule; this example is only meant to show concretely how gradient descent is implemented.
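For illustration, here is a minimal sketch of the same minimization using PyTorch's autograd in place of the hand-written gradient(x) (the starting point and iteration count are chosen arbitrarily, not taken from the original post):

import torch

x = torch.tensor(1.0, requires_grad=True)  # arbitrary starting point
alpha = 1e-3
for _ in range(20000):
    y = 10 * x**2 + 37 * x + 9
    y.backward()              # autograd computes dy/dx = 20x + 37 via the chain rule
    with torch.no_grad():     # update x without recording the operation
        x -= alpha * x.grad
    x.grad.zero_()            # reset the gradient before the next iteration
print(x.item())               # approaches -1.85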