From 0e8d6e39b33feb3274a8972afe581d5d2bf13e83 Mon Sep 17 00:00:00 2001 From: Amit Parag <aparag@laas.fr> Date: Sat, 30 Jan 2021 22:58:26 +0100 Subject: [PATCH] Added comment set create_graph = True in calculation of grads. This was the problem --- README.md | 8 ++- __pycache__/datagen.cpython-36.pyc | Bin 1809 -> 2199 bytes .../function_definitions.cpython-36.pyc | Bin 3709 -> 3917 bytes __pycache__/neural_network.cpython-36.pyc | Bin 1649 -> 1718 bytes sobolev_training.py | 53 +++++++++--------- 5 files changed, 34 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 8b33a74..b581cc1 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,10 @@ This is a minimal reproduction of Sobolev learning to highlight that the hessian of the approximated function does not get better. For instance, see the loss curves ( in images folder) while training different functions. -To run experiments, change function_name in sobolev_training.py and run in python3. \ No newline at end of file +To run experiments, change function_name in sobolev_training.py and run in python3. + +The conclusion is : + +1: To use Sobolev to guarantee that the derivatives also become better, set create_graph = True in the calculation of hessian and jacobian + +2: To use Sobolev loss as a regularizer to the function, set create_graph = False. This will just guarantee that the function approximation is better. 
\ No newline at end of file diff --git a/__pycache__/datagen.cpython-36.pyc b/__pycache__/datagen.cpython-36.pyc index 5e5b6d4824942b1d4869d75a79f78d9a63fa14ea..ec7461e9c21b457f4a50f7eb25423c2c801d153e 100644 GIT binary patch literal 2199 zcmah~OKjXk7@o1cYp?f_O|p45ZA~RaBl2idiBlR?B2^2jDkNH3EGpyd%x=8#tK&(t zyEqrN(Dnc-cZ4{CkT`PS#EBz9h|`=voK`{{IaHwV&vcUpiHEKE|8M4-m%sm?@r7#D z{^HxIw{KSv`UM?(ihwUc$$kJskc$W=T8LdudEM2OH{1evJuF71Yetr9MJ2Zs*{&Ux z-EvfMD^b<0Mm4t<jk#kO5rY)&Bezb+NpVMG-x1TDAeK8x>vV!n4z#xnx8Y8?)9wtN zrH#%Uoq{qAWu`MXFkrmg!+ht2yFeE@CkF^4@TX{lXheGtcGlA7&{3_VYE#EKVt=5+ zEZXa6vD|;fewk69Q-}E6cUlQ^vScd>={<)tKZt|4?R1^x<>d{hb+)^><PhmrNf*%B zLoDmIJp_r+^~I7X`pun?_JmD>ZW5;~Xi{OdSwQ#vct;p8cTbd~pt<dbbT4T3g^^Mg ziDH`U(J&Mx=4YG*?I0GFEQr#OdRx0unl&}K06aA)5GbTW$pl8I_gfC4fRqay?SlyW zI+-DQTYH4!?}0IZd5Pl@GfN7J`3e{*|7%<uF>}OJ%*Vi3z;w7iV&+F<-UX%v%=_FJ zF*YxBifU8~xH52`a&yF;Aev&X0aF3yYi^C01yWPY8ZcwP{K!iq<|L^rW)_%nV1DH` zFca`ar^uvaZlU;$fq0o$TH22N7_&Nt1i*+_Nn?O9dK1OkZIt6j7<N{BfF3|PkSQ{K z-{NDrKERu3fCqTr;PuXUt^q0lnoy|NM*Aid3reZGnH!tv&icNcYq<^m%6UQgN?ugH znwy=;ye!wuEkG58N`R^oJ;rz4TZlJ0Q@NqmlgJ$4&NMkqAj$Vl#aYlYGk`T!tpq(w zxsjW>m6uuuB&gHhvK@GCYzLXM7lfhHe44JlO4y8YCuuoZ(5KEtC*F;=DD#q*CzIJr zlOX2Vl1gg6O`UddkH$`GH*WGEiC5IT$}pK(LzzxP2MkllmdF4T+y<HzWK~`V3=$48 zQyL4*P<9zAwgjmCPUFifKldA#m>fS;6Uzbp`ptMJPWGOy-;~Yrk$vRElTflVVA?20 zcjosoM;+YJ9-@bspqAFcpJ*SVkF>i|7R<(><DxFKI2A?a$0Uh_xf=&?fTKlCXe;Lg z_5c*OlHHiVDTzNvPcE)*ClOusQ=j?m)y>OqT)wtSyKwk|h=Nk9qm#dUyiF?ErtwPJ zdt$GLLB^X&lx|#=dwsg`{&TX+6&h3K^MtLehe^{9vkd@;y@Q_)4i19frTOg75gGzG z82$m~P_l1^P&r_y<cH0wkf=bY7#x#-0(haN!LtxBDZJ8Kly3A=dYQ3=iPCjH8o>hG zK`|G_?To|KBnqHcn}T%2wO;j`$yVUUu6dPa834Gw(v-$T6yRf;+oE`b#u-d!GMa7s zDHUdRcbC#WU95@fVJ>+wtR<%YG5>^h*h>^*5=5C>?5eBPEfY}wp3fmI&mXU>xn@uG zIDZI^S}q9O6}Trf(sK>6=JnW|luKhXQi54&=D~>S-(n&!%Z!6T1r1x+fY;Xk+gJXV zZ{S+N&`;?$%q3?uG#loI7rc$@U?*@55L=TQvOF(>--+j`I0f|=V5cgDvB1s6)57#T zk~BRJG=jUn$7GZlbV6<g8c5J;@gxZ|_PlIr94hJ&+Jb+SJkHs_XUx%AeU>Oko05Tp zO}!|WTaso~nsd^e2P1T{)?@Neih>HiTaSF6h6xXYEjSaQzvUK`6M7nQp~FrfLz+J4 
pl&RPpGb%FHdX$jekZ#C(Ba=-ANG;zb-MVR2rIZtRTC<Ub{{r#jMT7tV delta 1087 zcmaiz&ubJ(6vwNotEPX<@0m_A8U!^EjY`B_SWrZ=yIB=M1Vz+lmv&5bkJHIi&(_qS z(He4?96T+CJ}fABUjz^0Um)sV;Hhuo#oYJg$ycp{u)<F9zV)3?y?U>^tNZWfPHo}G zYSnx7vURgwM(92I=EuOVK#U&jEc|vJBThNnMnOGl@X|Vw+uRPC+zDDyJ!(d+DfyiS z?Z6Lag4t*;YA2n@hnRtwO*&Hw`^gow<nDaDjGgQvzOc1MzBH@DY~okb><va}^Mytn zLywL^6T;h^^vJ{#kAPZ0mlX!;b6U`CpbY4$B0y(&sh|R=4Kz>|(Ah7!*MT}frJ|E9 z!d=CZ&*>|$;K9NJWkdQLFBe6Q0IdLhq8y+dUMr{%v<~#Sa)IvQ4MVRXv7aL4sd7x# zJ^2<RGgO5)rx>FfNRaDD;|Yb1Rft+r-_n+*ycLreKOuji&+z|QY)$cmsYcS&1Y8MR zt8jJ?O&kap!n?bw*(&<u^rWnbE<;{Lmx{Tn?P9L!$Vu9|Vy;fR;Hri5z}1X<i2po$ z4JkjF(X41_oHNDAEZ@uFuqRHj<^q|ULet`(2brE`+R?7|VhYDSpFOsc%%jIn?pkoV z(gE3}M$FqKU69^=|Jk{B_xzQ6jjLi^q<6&rcr4abKNY+1E}$%A$FhfP-p`2`=49BH zw%KYbNjH1VmKr756+9i}_E_}eR1UI#+0VG1{Zu;Pm2RbD!ACh3*`RdM{Rhb<*!S^N z_JJ;Lg-Uk&qAz;Kc~^CNkysvX=2hd*MIw@2mC71qk8Xk>MzDh^_Q?0O@*iFyEnH*N z>R=!5wLI8KVLv4v)PX^~?2=tuJqo3BCkzL0C1EJdqml3-qNP2#9fmwz3&Wg=@nE<q z+h%PaNKgU5`*Ez&bYmpvv#0hiqpyWTcMw6L119;Pkwqhij2s5ZE&ju%oHLG%Zg+=K l@j$+jBWB4tJxKX@BRXTQXXJxWxIZ%vGv|U+HFwg&p8#U*_R;_V diff --git a/__pycache__/function_definitions.cpython-36.pyc b/__pycache__/function_definitions.cpython-36.pyc index 269157713eade6507725a3054d7e608f9638fb14..6f3cde3159d0afb4bdb6a803b32a2076e794eb02 100644 GIT binary patch delta 1352 zcmZWp%}*0S6rb5{%dl+uECs@c8$n<jAq6o8MMC6a)Nqi435mgeZ0kb5d<(KACLHw0 zrP-5zg24+R@n}4G(t{Tdds7mVCZ0WDjJ`LE6ze4O+c(oU@8kF0>#w1YBcoTMQSIP; z@1q~>j2*LcOF?-Fx3gaxO^K}jf~|FBqqY6PXe`@{hj^B+GBaL#y|l+sEsT4a)F%<1 zu`C~8TWo^|Qda3<ZFCzNU?lgm=RC&nxw>m=2Cm$sYZ^BVH)$jc%dJdKxyeLg)|fKf z8T<x1KyK<3R-GhM@s&~qRK_}($$jp=V$gy2CwIJ>GOkad84K*9RoZgAcgn!9?%2gj z!CFsm7b}j2iVm4#Ds&+p9^d70XhR8f1XDnS<S2r-tW3`Uvcr6y@_E|l8K0MY#hJl9 zD_sWANh*<4X0u>byS*>U03Y;zD&P0ctgZwZ{KjR5rETaTFR|buq)K@UC^A!({-B?J z;2XZ;{^v)_+HM)sw=>p7)RGA?RpWwpa-D;LcNt;{<Ikzz_9!Wt9%e>KH5|h*9my-S z5QGR?5zI)Cw2}iE2#nyR&mEsvnv)iJu&td`X>$ry9HNS;D*qcyi~}D#`VhXrWf*m= z7?VuMfL7|#_YhtF3&;X71A;)7Lvj;3&&_UgC~CoaTD^`h9D2dR$VUjMCG)twjh4@S 
z4Juj!`;w1g?|~S*L7_L2ZMNN3v@?%W1#8F7?A}Jrh6@e}?~VA%b?;^Cy;}_r9B3@4 zQ>T+Y-!wxICl8WI)lFA78l7{ZEfj2LD^o637T3H9ZQ`L!J_rsU`D!kec24azD3Auu z+Xbot$X--&0E+kV=ukvd66%h2+>ksv&{p}iOQb8qrCxz?ggQkWOahg|R6r_8S)3B) zUDf9%&$z-dob!3EnXO2rykiv#>t0R2z~{X~{cW1;2VHntT3LG8aa1kzb!N2LS}UZS zEwJ-;=?Q8&)I&@a9H`W0pC2yaR97MDeTptuXrN(`z#teSpcVj=AVH8Mm?t0u=bX_$ lNWh+0D3-I^1#6KAacC3;@^?8Tgy<A)A}VxM5gLBG{{X)r@MZu2 delta 1128 zcmZWo&ubG=5Z<@hjj!8mOss7+ji!*+cA>O3Em*OXQUyV!2YV0-Qa8yaZkueD+0-Tl zDM*iA%0mSI1JCvB$*Wfn@2IDMhoZNFKW27I6XU}B_WhXm&3rTSK41Tp%`cdy@#Ag! z*;G<dz9~JQit|+*!QoNf=9}qqqHG>2N3X|pwxL*;k3Os(|72@yxI=~=hG21s<4>xd zFCNBb$0ch?W=mF-tR$Hy)nFb+;k^P^(CqZSc*Snq+iHUCdahTUtL%EErrYqg%1#A$ zXd2A=Z5g*qJ2j_m;GH0apy(RY7-t5K1WPf<h)=Qab^=epL|6w3rt=3sO9zytz(oSO zi)3JXRR@U8(jZL-!*sy5rP?ZlkwtKoc8#*zYIr`lB_|`&`Vh;BN&O^CZh#!3<bNGN zR?NpU1>%)z1wwAPMk^Lb=J)gSD!9&`?d?p8m+?`SKYkO>u>+#-Xb7Q_(9wE_B8IX5 z0acJP-~^QaNL#E|+kU@at>(OF7tlDM8;Z_ZDh5+HtKtv~{L5%p;)RGm)R`6I0~2@7 zvlc4r=N_!PrRR3d*>g+#t4P;%!N3v+17DdbipHb0t_KD>+F7bpXKzZjW5r+y*HPNE zR<XU^t;ngdTyukFsZsY=9*YxW>S-uG3Jxzxtzs90Grn#LLV{jim>~~gjWyqa`U6}B zR8d8v-b%3~b=KT*b8v&*njqUST<R>CMCee&VP(RaSSo}@R3Ge#C?@ecu{eF6D|{A< zl2!V%755u^PObJtY??W?B=*gZMdFXT>#Ve4>0`7@)k0tAr2T9IU<XZbtN4D<GN2w( zG|u2Q-5Zg*w_EqykQDEdcijx`L5?6#K-Gt70*hddV2Ob4_V}T<8zqh}*Bj;CnzKT= X2u2V%rfw$2Q+$w{JfW%F!2hMcG2Yd4 diff --git a/__pycache__/neural_network.cpython-36.pyc b/__pycache__/neural_network.cpython-36.pyc index 10114fb61198a06874e849f6451145f4a83fccc7..b2623ba1c6f76ae6af823a2c10d74057107820f5 100644 GIT binary patch literal 1718 zcmbtUOK%%D5GJ`#%aWqlvSYhNVe|zEXstH689`8>PJ2n6gM%IxC>Ul*S-W0wWpdZC z1@$BYIpxrM|3Lpn5B&u_@!FIALQWlW{p__|@HNBXeDlq4pL9B{;olptUj_*MgYGmR z*55(bf5E^J#|2878RsM+AR&b}@e?06(jW<cMBL;4CE~syd9?77Sgi2igs4BAeJfwh z+ewFqNmro!f#~vRL6V+`dBo%QIC%)tHNL`Imq_#u&;l=TvJSF7U%e;WxV8p+FWD52 z#CpCZHu9|n!3cCDZ|B=$Lu|tBJD}ZJV8Dov(4f79SkU1Bo7I<!3nlnLmR6}1UN@k@ zHT@M#0wtIu2zEt4V87#cXoPvd!^=+MThNy%3Em+eB_UkiGQm$J7scVy1JfnK2!W&i zIOAMMTCfYDG@UZVCOWm`+Ggh#bZem%40G^z6027g0Iq=OIhlK$yeA{VJ=l8-1S{WZ 
zyc<n$qX`=gU~@Ld!2<t^E58cnXhHID0ifVSGzV(uh{hZN?)Y%lZ9wiikh|Xwh*`}f z9SK&|N@&wA>*{7V+rEEzm!(y9#;UB8vj_K7kN0Og_XT&(?*HfTll>(Sh6XE&z+{M& zCVY`e!Ibe|F*!DYuBSpN%an;J&16<lYFbC)v=*|;SYbMivn#3xzVWpvMk=%xq0Hp8 zu98SLhafX6l%@sKn$kbehlAm`oQNT_`8ggQJpcLmi=j9NhuH*f*F#;tDGPD-^<9}1 zwPFR8qIz4ZlfCJM>Vm5nx?Y0;c`+uqhg+_1k}jEbU4iUfqfc$9w%MTU`g`cwUNT2H zlma9V3II}sp=4DL<O5GVgvo?18f`qtjOv4=88)mg8kOoTQ~VKZv@0}#FzYm~uXF?N zP+66-<vxfkp^{C6C6v-h$?L+d+mxQxtY|bYFD^p0W!-eH=u1UWP<jgjkgnUCEf$xp z2o?@UO#0oM{7o2C7+u=rEN$puW=TtJ5!hyB7M!ExV^|ymDZligno*7`0^|_;8h#_E z{!tS>Bpm->;+&=Bn~X^lKpdhlk%j^hGXX_S_P&8Ks#qm%;X0fjT2G5iS7|w!KKm4W zY3n^i>PwjZ7b3L_H<)NFv<67Z-mN}^VL;U9cKwAN?&r@rwP&tC3_dbJRjPDsVmIwc zsq8PqtfpmA2zY#@)Xk#_WIdT)z}uOc5Yi1vbabn2dwjdGW_48_E5`RmwRE=daN=vB zY28X2=tQ0yUopuIK6mg<cm@onCnn?<bj&J8C~PDC*zFT?er|%P%A{3|*;JT#q6IBV strHJccfxgnyPA!kHbwW$reA*oLkvaR#<91y8N=*DZ{a?CUHE$VUr4s7SpWb4 delta 962 zcmY*Wy-yTD6rY*>xV>fX(>p#4>Ip<#5>R4n3?WLiQyYxM7L1#lIqueF?=rI|VlJUr z5-JO_E&qW30EGpuiGRXc8Vec|TPoiyC&*0R$GqRX_kQ!{r1Zy`e^e^j-;ZnW5*?wR z=)VPL^b|y%9?fs;e`Sc_&<-X7$18CZl$ps23Gpi2N|0BDNR3&ncz}52IZAMXy*k8h zF#B4p@h=l{HLuAhc|DlojbJJv7y(8CCz$39-h|3CfX*ZsL*GqA+=-*#bNlv~pZow0 zM;>O#BV1#c>w&=t)0qC5bO<xxY$gaiGe@*bq;QE?W2AVIvuz}}gQ1pvj1N%<vkB%L zl)M5fvGM`(tV5)uMX$Jt+)8G=>N6gGDtCK~@u(I0dt9_u&om%)hA={DgN)KGJpGo8 z`6{bWCGs|?7{z#m5)zOE5@IdTM=-Myq6S0I0>5_6%#gzY7eeJ^7NxzY7gL&%C`wPs z>@AY%l7}5p1lYw{k<xbPOG)V$bp9;0w0p}m7Ct0J+Ycf>6n;n}9`E+W`<20-n1Cz` zM5<p#R#{6!ZGE*2zN?7}n9S@&gJCQ)tv`%K4I&q+uCPL<FLr&wn&}U17Tc+zFYP;d z(#u}NeUa&J{HU8Z5ls~@joGz%7*cvw&c&jtN~@~*Lkr+YMZ+QlSFxST{leuCcHl*n zo-uV_QE$Vv^OY_cE4Y?^(;qHgCwp$Mg*_R!`#Xd6J8)I1zc7(lNN*cYw$xo$7wSv0 zLYGU~^CQuMlenu8BIcAouhjLjVgX_;ka0ctW@f~F(e7qeZdRg5IOz$HnoNHgk1Kfx zUHB2p@JG=~m(3R&R~2(FAFZ{v(`Unwud9d2Sx^>C!NHbRZ&+X*Y~lv&GVCh;3+k!M A00000 diff --git a/sobolev_training.py b/sobolev_training.py index e34560b..37dac0a 100644 --- a/sobolev_training.py +++ b/sobolev_training.py @@ -13,14 +13,15 @@ import matplotlib.pyplot as plt -EPOCHS = 50000 # Number of Epochs +EPOCHS = 1000 # 
Number of Epochs lr = 1e-3 # Learning rate number_of_batches = 1 # Number of batches per epoch -function_name = 'simple_bumps' # See datagen.py or function_definitions.py for other functions to use -number_of_data_points = 5 +#function_name = 'simple_bumps' # See datagen.py or function_definitions.py for other functions to use +function_name = 'perm' +number_of_data_points = 20 @@ -59,13 +60,13 @@ for epoch in range(EPOCHS): y_hat = network(x) - dy_hat = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] ) # Gradient of net - #d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] ) # Hessian of net + dy_hat = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] ) # Gradient of net, set create_graph = True + d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] ) # Hessian of net, set create_graph = True loss1 = torch.nn.functional.mse_loss(y_hat,y) loss2 = torch.nn.functional.mse_loss(dy_hat, dy) - loss3 = 0#torch.nn.functional.mse_loss(d2y_hat, d2y) + loss3 = torch.nn.functional.mse_loss(d2y_hat, d2y) loss = loss1 + 10*loss2 + loss3 # Can add a sobolev factor to give weight to each loss term. 
# But it does not really change anything @@ -75,16 +76,16 @@ for epoch in range(EPOCHS): batch_loss_in_value += loss1.item() batch_loss_in_der1 += loss2.item() - #batch_loss_in_der2 += loss3.item() + batch_loss_in_der2 += loss3.item() epoch_loss_in_value.append( batch_loss_in_value / number_of_batches ) epoch_loss_in_der1.append( batch_loss_in_der1 / number_of_batches ) - #epoch_loss_in_der2.append( batch_loss_in_der2 / number_of_batches ) + epoch_loss_in_der2.append( batch_loss_in_der2 / number_of_batches ) if epoch % 10 == 0: print(f"EPOCH : {epoch}") - print(f"Loss Values: {loss1.item()}, Loss Grad : {loss2.item()}") #, Loss Hessian : {loss3.item()}") + print(f"Loss Values: {loss1.item()}, Loss Grad : {loss2.item()} , Loss Hessian : {loss3.item()}") plt.ion() @@ -92,8 +93,8 @@ fig, (ax1, ax2, ax3) = plt.subplots(1,3) fig.suptitle(function_name.upper()) ax1.semilogy(range(len(epoch_loss_in_value)), epoch_loss_in_value, c = "red") -#ax2.semilogy(range(len(epoch_loss_in_der1)), epoch_loss_in_der1, c = "green") -#ax3.semilogy(range(len(epoch_loss_in_der2)), epoch_loss_in_der2, c = "orange") +ax2.semilogy(range(len(epoch_loss_in_der1)), epoch_loss_in_der1, c = "green") +ax3.semilogy(range(len(epoch_loss_in_der2)), epoch_loss_in_der2, c = "orange") ax1.set(title='Loss in Value') ax2.set(title='Loss in Gradient') @@ -113,18 +114,18 @@ fig.tight_layout() #xplt,yplt,dyplt,_ = dataGenerator(function_name, 10000) #np.save('plt2.npy',{ "x": xplt.numpy(),"y": yplt.numpy(),"dy": dyplt.numpy()}) -LOAD = np.load( 'plt2.npy',allow_pickle=True).flat[0] -xplt = torch.tensor(LOAD['x']) -yplt = torch.tensor(LOAD['y']) -dyplt = torch.tensor(LOAD['dy']) - -ypred = network(xplt) - -plt.figure() -plt.subplot(131) -plt.scatter(xplt[:,0],xplt[:,1],c=yplt[:,0]) -plt.subplot(132) -plt.scatter(xplt[:,0],xplt[:,1],c=ypred[:,0].detach()) -plt.subplot(133) -plt.scatter(xplt[:,0],xplt[:,1],c=(ypred-yplt)[:,0].detach()) -plt.colorbar() +#LOAD = np.load( 'plt2.npy',allow_pickle=True).flat[0] 
+#xplt = torch.tensor(LOAD['x']) +#yplt = torch.tensor(LOAD['y']) +#dyplt = torch.tensor(LOAD['dy']) + +#ypred = network(xplt) + +#plt.figure() +#plt.subplot(131) +#plt.scatter(xplt[:,0],xplt[:,1],c=yplt[:,0]) +#plt.subplot(132) +#plt.scatter(xplt[:,0],xplt[:,1],c=ypred[:,0].detach()) +#plt.subplot(133) +#plt.scatter(xplt[:,0],xplt[:,1],c=(ypred-yplt)[:,0].detach()) +#plt.colorbar() -- GitLab