Commit 0e8d6e39 authored by Amit Parag

Added comment: set create_graph = True in the calculation of grads. This was the problem.

parent bc47b2e7
README.md
@@ -1,3 +1,6 @@
 This is a minimal reproduction of Sobolev learning, highlighting that the Hessian of the approximated function does not get better.
 For instance, see the loss curves (in the images folder) while training different functions.
 To run experiments, change function_name in sobolev_training.py and run with python3.
+The conclusion is:
+1: To use Sobolev training to guarantee that the derivatives also get better, set create_graph = True in the calculation of the Hessian and Jacobian.
+2: To use the Sobolev loss as a regularizer for the function, set create_graph = False. This will only guarantee that the function approximation gets better.
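For reference, a minimal self-contained sketch of the create_graph = True variant from point 1 (this is not the repository's script; the network architecture, the random placeholder data, and the 10x gradient weight are illustrative assumptions mirroring sobolev_training.py):

import torch
import torch.autograd.functional as F

network = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
x, y, dy = torch.randn(20, 2), torch.randn(20, 1), torch.randn(20, 2)  # placeholder data

optimizer.zero_grad()
y_hat = network(x)
# create_graph=True keeps the per-sample Jacobian attached to the autograd graph,
# so the derivative loss below produces parameter gradients; with the default
# create_graph=False, dy_hat would be a constant and only the value loss would train.
dy_hat = torch.vstack([F.jacobian(network, state, create_graph=True).squeeze() for state in x])
loss = torch.nn.functional.mse_loss(y_hat, y) + 10 * torch.nn.functional.mse_loss(dy_hat, dy)
loss.backward()
optimizer.step()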
Three binary files changed (no preview for this file type).
sobolev_training.py
@@ -13,14 +13,15 @@ import matplotlib.pyplot as plt
-EPOCHS = 50000  # Number of Epochs
+EPOCHS = 1000   # Number of Epochs
 lr = 1e-3              # Learning rate
 number_of_batches = 1  # Number of batches per epoch
-function_name = 'simple_bumps'   # See datagen.py or function_definitions.py for other functions to use
-number_of_data_points = 5
+#function_name = 'simple_bumps'  # See datagen.py or function_definitions.py for other functions to use
+function_name = 'perm'
+number_of_data_points = 20
@@ -59,13 +60,13 @@ for epoch in range(EPOCHS):
     y_hat = network(x)
-    dy_hat = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] )  # Gradient of net
-    #d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] )  # Hessian of net
+    dy_hat = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] )  # Gradient of net, set create_graph = True
+    d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] )   # Hessian of net, set create_graph = True

     loss1 = torch.nn.functional.mse_loss(y_hat, y)
     loss2 = torch.nn.functional.mse_loss(dy_hat, dy)
-    loss3 = 0 #torch.nn.functional.mse_loss(d2y_hat, d2y)
+    loss3 = torch.nn.functional.mse_loss(d2y_hat, d2y)

     loss = loss1 + 10*loss2 + loss3  # Can add a Sobolev factor to give weight to each loss term,
                                      # but it does not really change anything.
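Note: the comments added in this hunk say to set create_graph = True, but the calls shown still use the default create_graph = False, so the Jacobian and Hessian come back detached from the autograd graph. A sketch of what the corrected lines would presumably look like (assuming F is torch.autograd.functional, as the calls above suggest):

    # create_graph=True keeps the per-sample Jacobians and Hessians differentiable
    # w.r.t. the network weights, so loss2 and loss3 can actually train the network.
    dy_hat = torch.vstack( [ F.jacobian(network, state, create_graph=True).squeeze() for state in x ] )
    d2y_hat = torch.stack( [ F.hessian(network, state, create_graph=True).squeeze() for state in x ] )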
@@ -75,16 +76,16 @@ for epoch in range(EPOCHS):
     batch_loss_in_value += loss1.item()
     batch_loss_in_der1 += loss2.item()
-    #batch_loss_in_der2 += loss3.item()
+    batch_loss_in_der2 += loss3.item()

     epoch_loss_in_value.append( batch_loss_in_value / number_of_batches )
     epoch_loss_in_der1.append( batch_loss_in_der1 / number_of_batches )
-    #epoch_loss_in_der2.append( batch_loss_in_der2 / number_of_batches )
+    epoch_loss_in_der2.append( batch_loss_in_der2 / number_of_batches )

     if epoch % 10 == 0:
         print(f"EPOCH : {epoch}")
-        print(f"Loss Values: {loss1.item()}, Loss Grad : {loss2.item()}") #, Loss Hessian : {loss3.item()}")
+        print(f"Loss Values: {loss1.item()}, Loss Grad : {loss2.item()}, Loss Hessian : {loss3.item()}")

 plt.ion()
@@ -92,8 +93,8 @@ fig, (ax1, ax2, ax3) = plt.subplots(1,3)
 fig.suptitle(function_name.upper())
 ax1.semilogy(range(len(epoch_loss_in_value)), epoch_loss_in_value, c="red")
-#ax2.semilogy(range(len(epoch_loss_in_der1)), epoch_loss_in_der1, c="green")
-#ax3.semilogy(range(len(epoch_loss_in_der2)), epoch_loss_in_der2, c="orange")
+ax2.semilogy(range(len(epoch_loss_in_der1)), epoch_loss_in_der1, c="green")
+ax3.semilogy(range(len(epoch_loss_in_der2)), epoch_loss_in_der2, c="orange")
 ax1.set(title='Loss in Value')
 ax2.set(title='Loss in Gradient')
@@ -113,18 +114,18 @@ fig.tight_layout()
 #xplt,yplt,dyplt,_ = dataGenerator(function_name, 10000)
 #np.save('plt2.npy',{ "x": xplt.numpy(),"y": yplt.numpy(),"dy": dyplt.numpy()})
-LOAD = np.load( 'plt2.npy',allow_pickle=True).flat[0]
-xplt = torch.tensor(LOAD['x'])
-yplt = torch.tensor(LOAD['y'])
-dyplt = torch.tensor(LOAD['dy'])
-ypred = network(xplt)
-plt.figure()
-plt.subplot(131)
-plt.scatter(xplt[:,0],xplt[:,1],c=yplt[:,0])
-plt.subplot(132)
-plt.scatter(xplt[:,0],xplt[:,1],c=ypred[:,0].detach())
-plt.subplot(133)
-plt.scatter(xplt[:,0],xplt[:,1],c=(ypred-yplt)[:,0].detach())
-plt.colorbar()
+#LOAD = np.load( 'plt2.npy',allow_pickle=True).flat[0]
+#xplt = torch.tensor(LOAD['x'])
+#yplt = torch.tensor(LOAD['y'])
+#dyplt = torch.tensor(LOAD['dy'])
+#ypred = network(xplt)
+#plt.figure()
+#plt.subplot(131)
+#plt.scatter(xplt[:,0],xplt[:,1],c=yplt[:,0])
+#plt.subplot(132)
+#plt.scatter(xplt[:,0],xplt[:,1],c=ypred[:,0].detach())
+#plt.subplot(133)
+#plt.scatter(xplt[:,0],xplt[:,1],c=(ypred-yplt)[:,0].detach())
+#plt.colorbar()