From 0e8d6e39b33feb3274a8972afe581d5d2bf13e83 Mon Sep 17 00:00:00 2001
From: Amit Parag <aparag@laas.fr>
Date: Sat, 30 Jan 2021 22:58:26 +0100
Subject: [PATCH] Added comment set create_graph = True in calculation of
 grads. This was the problem

---
 README.md                                     |   8 ++-
 __pycache__/datagen.cpython-36.pyc            | Bin 1809 -> 2199 bytes
 .../function_definitions.cpython-36.pyc       | Bin 3709 -> 3917 bytes
 __pycache__/neural_network.cpython-36.pyc     | Bin 1649 -> 1718 bytes
 sobolev_training.py                           |  53 +++++++++---------
 5 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index 8b33a74..b581cc1 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,10 @@
 This is a minimal reproduction of Sobolev learning to highlight that the hessian of the approximated function does not get better.
 For instance, see the loss curves ( in images folder) while training different functions.
 
-To run experiments, change function_name in sobolev_training.py and run in python3.
\ No newline at end of file
+To run experiments, change function_name in sobolev_training.py and run in python3.
+
+The conclusions are:
+
+1: To use Sobolev learning to guarantee that the derivatives also become better, set create_graph = True in the calculation of the hessian and jacobian.
+
+2: To use Sobolev loss as a regularizer to the function, set create_graph = False. This will just guarantee that the function approximation is better.
\ No newline at end of file
diff --git a/__pycache__/datagen.cpython-36.pyc b/__pycache__/datagen.cpython-36.pyc
index 5e5b6d4824942b1d4869d75a79f78d9a63fa14ea..ec7461e9c21b457f4a50f7eb25423c2c801d153e 100644
GIT binary patch
literal 2199
zcmah~OKjXk7@o1cYp?f_O|p45ZA~RaBl2idiBlR?B2^2jDkNH3EGpyd%x=8#tK&(t
zyEqrN(Dnc-cZ4{CkT`PS#EBz9h|`=voK`{{IaHwV&vcUpiHEKE|8M4-m%sm?@r7#D
z{^HxIw{KSv`UM?(ihwUc$$kJskc$W=T8LdudEM2OH{1evJuF71Yetr9MJ2Zs*{&Ux
z-EvfMD^b<0Mm4t<jk#kO5rY)&Bezb+NpVMG-x1TDAeK8x>vV!n4z#xnx8Y8?)9wtN
zrH#%Uoq{qAWu`MXFkrmg!+ht2yFeE@CkF^4@TX{lXheGtcGlA7&{3_VYE#EKVt=5+
zEZXa6vD|;fewk69Q-}E6cUlQ^vScd>={<)tKZt|4?R1^x<>d{hb+)^><PhmrNf*%B
zLoDmIJp_r+^~I7X`pun?_JmD>ZW5;~Xi{OdSwQ#vct;p8cTbd~pt<dbbT4T3g^^Mg
ziDH`U(J&Mx=4YG*?I0GFEQr#OdRx0unl&}K06aA)5GbTW$pl8I_gfC4fRqay?SlyW
zI+-DQTYH4!?}0IZd5Pl@GfN7J`3e{*|7%<uF>}OJ%*Vi3z;w7iV&+F<-UX%v%=_FJ
zF*YxBifU8~xH52`a&yF;Aev&X0aF3yYi^C01yWPY8ZcwP{K!iq<|L^rW)_%nV1DH`
zFca`ar^uvaZlU;$fq0o$TH22N7_&Nt1i*+_Nn?O9dK1OkZIt6j7<N{BfF3|PkSQ{K
z-{NDrKERu3fCqTr;PuXUt^q0lnoy|NM*Aid3reZGnH!tv&icNcYq<^m%6UQgN?ugH
znwy=;ye!wuEkG58N`R^oJ;rz4TZlJ0Q@NqmlgJ$4&NMkqAj$Vl#aYlYGk`T!tpq(w
zxsjW>m6uuuB&gHhvK@GCYzLXM7lfhHe44JlO4y8YCuuoZ(5KEtC*F;=DD#q*CzIJr
zlOX2Vl1gg6O`UddkH$`GH*WGEiC5IT$}pK(LzzxP2MkllmdF4T+y<HzWK~`V3=$48
zQyL4*P<9zAwgjmCPUFifKldA#m>fS;6Uzbp`ptMJPWGOy-;~Yrk$vRElTflVVA?20
zcjosoM;+YJ9-@bspqAFcpJ*SVkF>i|7R<(><DxFKI2A?a$0Uh_xf=&?fTKlCXe;Lg
z_5c*OlHHiVDTzNvPcE)*ClOusQ=j?m)y>OqT)wtSyKwk|h=Nk9qm#dUyiF?ErtwPJ
zdt$GLLB^X&lx|#=dwsg`{&TX+6&h3K^MtLehe^{9vkd@;y@Q_)4i19frTOg75gGzG
z82$m~P_l1^P&r_y<cH0wkf=bY7#x#-0(haN!LtxBDZJ8Kly3A=dYQ3=iPCjH8o>hG
zK`|G_?To|KBnqHcn}T%2wO;j`$yVUUu6dPa834Gw(v-$T6yRf;+oE`b#u-d!GMa7s
zDHUdRcbC#WU95@fVJ>+wtR<%YG5>^h*h>^*5=5C>?5eBPEfY}wp3fmI&mXU>xn@uG
zIDZI^S}q9O6}Trf(sK>6=JnW|luKhXQi54&=D~>S-(n&!%Z!6T1r1x+fY;Xk+gJXV
zZ{S+N&`;?$%q3?uG#loI7rc$@U?*@55L=TQvOF(>--+j`I0f|=V5cgDvB1s6)57#T
zk~BRJG=jUn$7GZlbV6<g8c5J;@gxZ|_PlIr94hJ&+Jb+SJkHs_XUx%AeU>Oko05Tp
zO}!|WTaso~nsd^e2P1T{)?@Neih>HiTaSF6h6xXYEjSaQzvUK`6M7nQp~FrfLz+J4
pl&RPpGb%FHdX$jekZ#C(Ba=-ANG;zb-MVR2rIZtRTC<Ub{{r#jMT7tV

delta 1087
zcmaiz&ubJ(6vwNotEPX<@0m_A8U!^EjY`B_SWrZ=yIB=M1Vz+lmv&5bkJHIi&(_qS
z(He4?96T+CJ}fABUjz^0Um)sV;Hhuo#oYJg$ycp{u)<F9zV)3?y?U>^tNZWfPHo}G
zYSnx7vURgwM(92I=EuOVK#U&jEc|vJBThNnMnOGl@X|Vw+uRPC+zDDyJ!(d+DfyiS
z?Z6Lag4t*;YA2n@hnRtwO*&Hw`^gow<nDaDjGgQvzOc1MzBH@DY~okb><va}^Mytn
zLywL^6T;h^^vJ{#kAPZ0mlX!;b6U`CpbY4$B0y(&sh|R=4Kz>|(Ah7!*MT}frJ|E9
z!d=CZ&*>|$;K9NJWkdQLFBe6Q0IdLhq8y+dUMr{%v<~#Sa)IvQ4MVRXv7aL4sd7x#
zJ^2<RGgO5)rx>FfNRaDD;|Yb1Rft+r-_n+*ycLreKOuji&+z|QY)$cmsYcS&1Y8MR
zt8jJ?O&kap!n?bw*(&<u^rWnbE<;{Lmx{Tn?P9L!$Vu9|Vy;fR;Hri5z}1X<i2po$
z4JkjF(X41_oHNDAEZ@uFuqRHj<^q|ULet`(2brE`+R?7|VhYDSpFOsc%%jIn?pkoV
z(gE3}M$FqKU69^=|Jk{B_xzQ6jjLi^q<6&rcr4abKNY+1E}$%A$FhfP-p`2`=49BH
zw%KYbNjH1VmKr756+9i}_E_}eR1UI#+0VG1{Zu;Pm2RbD!ACh3*`RdM{Rhb<*!S^N
z_JJ;Lg-Uk&qAz;Kc~^CNkysvX=2hd*MIw@2mC71qk8Xk>MzDh^_Q?0O@*iFyEnH*N
z>R=!5wLI8KVLv4v)PX^~?2=tuJqo3BCkzL0C1EJdqml3-qNP2#9fmwz3&Wg=@nE<q
z+h%PaNKgU5`*Ez&bYmpvv#0hiqpyWTcMw6L119;Pkwqhij2s5ZE&ju%oHLG%Zg+=K
l@j$+jBWB4tJxKX@BRXTQXXJxWxIZ%vGv|U+HFwg&p8#U*_R;_V

diff --git a/__pycache__/function_definitions.cpython-36.pyc b/__pycache__/function_definitions.cpython-36.pyc
index 269157713eade6507725a3054d7e608f9638fb14..6f3cde3159d0afb4bdb6a803b32a2076e794eb02 100644
GIT binary patch
delta 1352
zcmZWp%}*0S6rb5{%dl+uECs@c8$n<jAq6o8MMC6a)Nqi435mgeZ0kb5d<(KACLHw0
zrP-5zg24+R@n}4G(t{Tdds7mVCZ0WDjJ`LE6ze4O+c(oU@8kF0>#w1YBcoTMQSIP;
z@1q~>j2*LcOF?-Fx3gaxO^K}jf~|FBqqY6PXe`@{hj^B+GBaL#y|l+sEsT4a)F%<1
zu`C~8TWo^|Qda3<ZFCzNU?lgm=RC&nxw>m=2Cm$sYZ^BVH)$jc%dJdKxyeLg)|fKf
z8T<x1KyK<3R-GhM@s&~qRK_}($$jp=V$gy2CwIJ>GOkad84K*9RoZgAcgn!9?%2gj
z!CFsm7b}j2iVm4#Ds&+p9^d70XhR8f1XDnS<S2r-tW3`Uvcr6y@_E|l8K0MY#hJl9
zD_sWANh*<4X0u>byS*>U03Y;zD&P0ctgZwZ{KjR5rETaTFR|buq)K@UC^A!({-B?J
z;2XZ;{^v)_+HM)sw=>p7)RGA?RpWwpa-D;LcNt;{<Ikzz_9!Wt9%e>KH5|h*9my-S
z5QGR?5zI)Cw2}iE2#nyR&mEsvnv)iJu&td`X>$ry9HNS;D*qcyi~}D#`VhXrWf*m=
z7?VuMfL7|#_YhtF3&;X71A;)7Lvj;3&&_UgC~CoaTD^`h9D2dR$VUjMCG)twjh4@S
z4Juj!`;w1g?|~S*L7_L2ZMNN3v@?%W1#8F7?A}Jrh6@e}?~VA%b?;^Cy;}_r9B3@4
zQ>T+Y-!wxICl8WI)lFA78l7{ZEfj2LD^o637T3H9ZQ`L!J_rsU`D!kec24azD3Auu
z+Xbot$X--&0E+kV=ukvd66%h2+>ksv&{p}iOQb8qrCxz?ggQkWOahg|R6r_8S)3B)
zUDf9%&$z-dob!3EnXO2rykiv#>t0R2z~{X~{cW1;2VHntT3LG8aa1kzb!N2LS}UZS
zEwJ-;=?Q8&)I&@a9H`W0pC2yaR97MDeTptuXrN(`z#teSpcVj=AVH8Mm?t0u=bX_$
lNWh+0D3-I^1#6KAacC3;@^?8Tgy<A)A}VxM5gLBG{{X)r@MZu2

delta 1128
zcmZWo&ubG=5Z<@hjj!8mOss7+ji!*+cA>O3Em*OXQUyV!2YV0-Qa8yaZkueD+0-Tl
zDM*iA%0mSI1JCvB$*Wfn@2IDMhoZNFKW27I6XU}B_WhXm&3rTSK41Tp%`cdy@#Ag!
z*;G<dz9~JQit|+*!QoNf=9}qqqHG>2N3X|pwxL*;k3Os(|72@yxI=~=hG21s<4>xd
zFCNBb$0ch?W=mF-tR$Hy)nFb+;k^P^(CqZSc*Snq+iHUCdahTUtL%EErrYqg%1#A$
zXd2A=Z5g*qJ2j_m;GH0apy(RY7-t5K1WPf<h)=Qab^=epL|6w3rt=3sO9zytz(oSO
zi)3JXRR@U8(jZL-!*sy5rP?ZlkwtKoc8#*zYIr`lB_|`&`Vh;BN&O^CZh#!3<bNGN
zR?NpU1>%)z1wwAPMk^Lb=J)gSD!9&`?d?p8m+?`SKYkO>u>+#-Xb7Q_(9wE_B8IX5
z0acJP-~^QaNL#E|+kU@at>(OF7tlDM8;Z_ZDh5+HtKtv~{L5%p;)RGm)R`6I0~2@7
zvlc4r=N_!PrRR3d*>g+#t4P;%!N3v+17DdbipHb0t_KD>+F7bpXKzZjW5r+y*HPNE
zR<XU^t;ngdTyukFsZsY=9*YxW>S-uG3Jxzxtzs90Grn#LLV{jim>~~gjWyqa`U6}B
zR8d8v-b%3~b=KT*b8v&*njqUST<R>CMCee&VP(RaSSo}@R3Ge#C?@ecu{eF6D|{A<
zl2!V%755u^PObJtY??W?B=*gZMdFXT>#Ve4>0`7@)k0tAr2T9IU<XZbtN4D<GN2w(
zG|u2Q-5Zg*w_EqykQDEdcijx`L5?6#K-Gt70*hddV2Ob4_V}T<8zqh}*Bj;CnzKT=
X2u2V%rfw$2Q+$w{JfW%F!2hMcG2Yd4

diff --git a/__pycache__/neural_network.cpython-36.pyc b/__pycache__/neural_network.cpython-36.pyc
index 10114fb61198a06874e849f6451145f4a83fccc7..b2623ba1c6f76ae6af823a2c10d74057107820f5 100644
GIT binary patch
literal 1718
zcmbtUOK%%D5GJ`#%aWqlvSYhNVe|zEXstH689`8>PJ2n6gM%IxC>Ul*S-W0wWpdZC
z1@$BYIpxrM|3Lpn5B&u_@!FIALQWlW{p__|@HNBXeDlq4pL9B{;olptUj_*MgYGmR
z*55(bf5E^J#|2878RsM+AR&b}@e?06(jW<cMBL;4CE~syd9?77Sgi2igs4BAeJfwh
z+ewFqNmro!f#~vRL6V+`dBo%QIC%)tHNL`Imq_#u&;l=TvJSF7U%e;WxV8p+FWD52
z#CpCZHu9|n!3cCDZ|B=$Lu|tBJD}ZJV8Dov(4f79SkU1Bo7I<!3nlnLmR6}1UN@k@
zHT@M#0wtIu2zEt4V87#cXoPvd!^=+MThNy%3Em+eB_UkiGQm$J7scVy1JfnK2!W&i
zIOAMMTCfYDG@UZVCOWm`+Ggh#bZem%40G^z6027g0Iq=OIhlK$yeA{VJ=l8-1S{WZ
zyc<n$qX`=gU~@Ld!2<t^E58cnXhHID0ifVSGzV(uh{hZN?)Y%lZ9wiikh|Xwh*`}f
z9SK&|N@&wA>*{7V+rEEzm!(y9#;UB8vj_K7kN0Og_XT&(?*HfTll>(Sh6XE&z+{M&
zCVY`e!Ibe|F*!DYuBSpN%an;J&16<lYFbC)v=*|;SYbMivn#3xzVWpvMk=%xq0Hp8
zu98SLhafX6l%@sKn$kbehlAm`oQNT_`8ggQJpcLmi=j9NhuH*f*F#;tDGPD-^<9}1
zwPFR8qIz4ZlfCJM>Vm5nx?Y0;c`+uqhg+_1k}jEbU4iUfqfc$9w%MTU`g`cwUNT2H
zlma9V3II}sp=4DL<O5GVgvo?18f`qtjOv4=88)mg8kOoTQ~VKZv@0}#FzYm~uXF?N
zP+66-<vxfkp^{C6C6v-h$?L+d+mxQxtY|bYFD^p0W!-eH=u1UWP<jgjkgnUCEf$xp
z2o?@UO#0oM{7o2C7+u=rEN$puW=TtJ5!hyB7M!ExV^|ymDZligno*7`0^|_;8h#_E
z{!tS>Bpm->;+&=Bn~X^lKpdhlk%j^hGXX_S_P&8Ks#qm%;X0fjT2G5iS7|w!KKm4W
zY3n^i>PwjZ7b3L_H<)NFv<67Z-mN}^VL;U9cKwAN?&r@rwP&tC3_dbJRjPDsVmIwc
zsq8PqtfpmA2zY#@)Xk#_WIdT)z}uOc5Yi1vbabn2dwjdGW_48_E5`RmwRE=daN=vB
zY28X2=tQ0yUopuIK6mg<cm@onCnn?<bj&J8C~PDC*zFT?er|%P%A{3|*;JT#q6IBV
strHJccfxgnyPA!kHbwW$reA*oLkvaR#<91y8N=*DZ{a?CUHE$VUr4s7SpWb4

delta 962
zcmY*Wy-yTD6rY*>xV>fX(>p#4>Ip<#5>R4n3?WLiQyYxM7L1#lIqueF?=rI|VlJUr
z5-JO_E&qW30EGpuiGRXc8Vec|TPoiyC&*0R$GqRX_kQ!{r1Zy`e^e^j-;ZnW5*?wR
z=)VPL^b|y%9?fs;e`Sc_&<-X7$18CZl$ps23Gpi2N|0BDNR3&ncz}52IZAMXy*k8h
zF#B4p@h=l{HLuAhc|DlojbJJv7y(8CCz$39-h|3CfX*ZsL*GqA+=-*#bNlv~pZow0
zM;>O#BV1#c>w&=t)0qC5bO<xxY$gaiGe@*bq;QE?W2AVIvuz}}gQ1pvj1N%<vkB%L
zl)M5fvGM`(tV5)uMX$Jt+)8G=>N6gGDtCK~@u(I0dt9_u&om%)hA={DgN)KGJpGo8
z`6{bWCGs|?7{z#m5)zOE5@IdTM=-Myq6S0I0>5_6%#gzY7eeJ^7NxzY7gL&%C`wPs
z>@AY%l7}5p1lYw{k<xbPOG)V$bp9;0w0p}m7Ct0J+Ycf>6n;n}9`E+W`<20-n1Cz`
zM5<p#R#{6!ZGE*2zN?7}n9S@&gJCQ)tv`%K4I&q+uCPL<FLr&wn&}U17Tc+zFYP;d
z(#u}NeUa&J{HU8Z5ls~@joGz%7*cvw&c&jtN~@~*Lkr+YMZ+QlSFxST{leuCcHl*n
zo-uV_QE$Vv^OY_cE4Y?^(;qHgCwp$Mg*_R!`#Xd6J8)I1zc7(lNN*cYw$xo$7wSv0
zLYGU~^CQuMlenu8BIcAouhjLjVgX_;ka0ctW@f~F(e7qeZdRg5IOz$HnoNHgk1Kfx
zUHB2p@JG=~m(3R&R~2(FAFZ{v(`Unwud9d2Sx^>C!NHbRZ&+X*Y~lv&GVCh;3+k!M
A00000

diff --git a/sobolev_training.py b/sobolev_training.py
index e34560b..37dac0a 100644
--- a/sobolev_training.py
+++ b/sobolev_training.py
@@ -13,14 +13,15 @@ import matplotlib.pyplot as plt
 
 
 
-EPOCHS                = 50000                        # Number of Epochs
+EPOCHS                = 1000                        # Number of Epochs
 lr                    = 1e-3                       # Learning rate
 number_of_batches     = 1                         # Number of batches per epoch
 
 
 
-function_name         = 'simple_bumps'                   # See datagen.py or function_definitions.py for other functions to use
-number_of_data_points = 5
+#function_name         = 'simple_bumps'                   # See datagen.py or function_definitions.py for other functions to use
+function_name          = 'perm'
+number_of_data_points  = 20
 
 
 
@@ -59,13 +60,13 @@ for epoch in range(EPOCHS):
         
         y_hat  = network(x)
         
-        dy_hat  = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] )   # Gradient of net
-        #d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] )     # Hessian of net
+        dy_hat  = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] )   # Gradient of net, set create_graph = True
+        d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] )     # Hessian of net, set create_graph = True
         
         
         loss1   = torch.nn.functional.mse_loss(y_hat,y)
         loss2   = torch.nn.functional.mse_loss(dy_hat, dy)
-        loss3   = 0#torch.nn.functional.mse_loss(d2y_hat, d2y) 
+        loss3   = torch.nn.functional.mse_loss(d2y_hat, d2y) 
 
         loss    = loss1 + 10*loss2 + loss3                            # Can add a sobolev factor to give weight to each loss term.
                                                                    # But it does not really change anything     
@@ -75,16 +76,16 @@ for epoch in range(EPOCHS):
 
         batch_loss_in_value += loss1.item()
         batch_loss_in_der1  += loss2.item()
-        #batch_loss_in_der2  += loss3.item()
+        batch_loss_in_der2  += loss3.item()
 
     epoch_loss_in_value.append( batch_loss_in_value / number_of_batches )
     epoch_loss_in_der1.append( batch_loss_in_der1 / number_of_batches )
-    #epoch_loss_in_der2.append( batch_loss_in_der2 / number_of_batches )
+    epoch_loss_in_der2.append( batch_loss_in_der2 / number_of_batches )
 
     
     if epoch % 10 == 0:
             print(f"EPOCH : {epoch}")
-            print(f"Loss Values:  {loss1.item()}, Loss Grad : {loss2.item()}") #, Loss Hessian : {loss3.item()}")
+            print(f"Loss Values:  {loss1.item()}, Loss Grad : {loss2.item()} , Loss Hessian : {loss3.item()}")
 
 plt.ion()
 
@@ -92,8 +93,8 @@ fig, (ax1, ax2, ax3) = plt.subplots(1,3)
 fig.suptitle(function_name.upper())
 
 ax1.semilogy(range(len(epoch_loss_in_value)), epoch_loss_in_value, c = "red")
-#ax2.semilogy(range(len(epoch_loss_in_der1)), epoch_loss_in_der1, c = "green")
-#ax3.semilogy(range(len(epoch_loss_in_der2)), epoch_loss_in_der2, c = "orange")
+ax2.semilogy(range(len(epoch_loss_in_der1)), epoch_loss_in_der1, c = "green")
+ax3.semilogy(range(len(epoch_loss_in_der2)), epoch_loss_in_der2, c = "orange")
 
 ax1.set(title='Loss in Value')
 ax2.set(title='Loss in Gradient')
@@ -113,18 +114,18 @@ fig.tight_layout()
 
 #xplt,yplt,dyplt,_            = dataGenerator(function_name, 10000)
 #np.save('plt2.npy',{ "x": xplt.numpy(),"y": yplt.numpy(),"dy": dyplt.numpy()})
-LOAD = np.load( 'plt2.npy',allow_pickle=True).flat[0]
-xplt = torch.tensor(LOAD['x'])
-yplt = torch.tensor(LOAD['y'])
-dyplt = torch.tensor(LOAD['dy'])
-
-ypred = network(xplt)
-
-plt.figure()
-plt.subplot(131)
-plt.scatter(xplt[:,0],xplt[:,1],c=yplt[:,0])
-plt.subplot(132)
-plt.scatter(xplt[:,0],xplt[:,1],c=ypred[:,0].detach())
-plt.subplot(133)
-plt.scatter(xplt[:,0],xplt[:,1],c=(ypred-yplt)[:,0].detach())
-plt.colorbar()
+#LOAD = np.load( 'plt2.npy',allow_pickle=True).flat[0]
+#xplt = torch.tensor(LOAD['x'])
+#yplt = torch.tensor(LOAD['y'])
+#dyplt = torch.tensor(LOAD['dy'])
+
+#ypred = network(xplt)
+
+#plt.figure()
+#plt.subplot(131)
+#plt.scatter(xplt[:,0],xplt[:,1],c=yplt[:,0])
+#plt.subplot(132)
+#plt.scatter(xplt[:,0],xplt[:,1],c=ypred[:,0].detach())
+#plt.subplot(133)
+#plt.scatter(xplt[:,0],xplt[:,1],c=(ypred-yplt)[:,0].detach())
+#plt.colorbar()
-- 
GitLab