diff --git a/README.md b/README.md
index 59af2bff86896a74f65ed1f9f0f328497a0d9430..8b33a74ed8ede6d59c3ca6522b222a9924d92087 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,4 @@
-This is a minimal reproduction of Sobolev learning to highlight that the hessian of the approximated function does not get better. For instance, see the loss curves ( in images folder) while training different functions.
+This is a minimal reproduction of Sobolev learning, highlighting that the Hessian of the approximated function does not improve even though it is part of the training loss. For instance, see the loss curves (in the images folder) produced while training different functions.
 
+
+To run experiments, change function_name in sobolev_training.py and run the script with python3.
\ No newline at end of file
diff --git a/sobolev_training.py b/sobolev_training.py
index 49aa2dd49e9e3c36ae5ac6bd1170fc040c045449..10973cab9d17e96bb840ac5039f51a57e919050b 100644
--- a/sobolev_training.py
+++ b/sobolev_training.py
@@ -7,6 +7,12 @@ from datagen import dataGenerator
 import torch.autograd.functional as F
 import matplotlib.pyplot as plt
 
+
+
+# ------------------------- Experiment configuration --------------------------
+
+
+
 EPOCHS = 500                 # Number of Epochs
 lr = 1e-3                    # Learning rate
 number_of_batches = 10       # Number of batches per epoch
@@ -15,13 +21,21 @@ number_of_batches = 10       # Number of batches per epoch
 function_name = 'ackley'     # See datagen.py or function_definitions.py for other functions to use
 number_of_data_points = 200
+
+
+
+# --------------------- Data generation and training setup --------------------
+
+
+
+
+
 
 X,Y,DY,D2Y = dataGenerator(function_name, number_of_data_points)
 
 dataset = torch.utils.data.TensorDataset(X,Y,DY,D2Y)
 dataloader = torch.utils.data.DataLoader(dataset, batch_size = number_of_data_points // number_of_batches, shuffle=True, num_workers=4)
 
-
 network = Model()
 optimizer = torch.optim.Adam(params = network.parameters(), lr = lr)
 
@@ -44,15 +58,17 @@ for epoch in range(EPOCHS):
         x,y,dy,d2y = data
         y_hat = network(x)
-        dy_hat = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] )
-        d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] )
-
+
+        dy_hat = torch.vstack( [ F.jacobian(network, state).squeeze() for state in x ] )    # Gradient of the network at each input
+        d2y_hat = torch.stack( [ F.hessian(network, state).squeeze() for state in x ] )     # Hessian of the network at each input
+
+
         loss1 = torch.nn.functional.mse_loss(y_hat,y)
         loss2 = torch.nn.functional.mse_loss(dy_hat, dy)
         loss3 = torch.nn.functional.mse_loss(d2y_hat, d2y)
-        loss = loss1 + loss2 + loss3
-
+        loss = loss1 + loss2 + loss3     # A Sobolev factor could weight each loss term,
+                                         # but in practice it does not change the result.
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
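
datagen.py and function_definitions.py are not part of this diff, so the contract of dataGenerator has to be read off its call site: it returns inputs X together with targets Y, analytic gradients DY, and analytic Hessians D2Y. A minimal sketch of that interface, using a simple quadratic instead of the repo's 'ackley' so the derivatives are obvious:

```python
import torch

def data_generator_sketch(n_points, dim=2):
    # Hypothetical stand-in for dataGenerator: same return signature, but
    # for f(x) = ||x||^2, whose derivatives are known in closed form.
    X = torch.empty(n_points, dim).uniform_(-2.0, 2.0)
    Y = (X ** 2).sum(dim=1)                                # f(x)
    DY = 2.0 * X                                           # grad f(x) = 2x
    D2Y = 2.0 * torch.eye(dim).expand(n_points, dim, dim)  # Hess f(x) = 2I
    return X, Y, DY, D2Y

X, Y, DY, D2Y = data_generator_sketch(200)   # 200 matches number_of_data_points
```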
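
The list comprehensions in the training loop call torch.autograd.functional.jacobian and hessian once per sample, which is the main cost of each step. On PyTorch 2.x the same per-sample derivatives can be batched with torch.func; the two-layer network below is only a stand-in for the repo's Model, assumed here to map a 2-D state to a scalar (which F.hessian already requires):

```python
import torch
from torch.func import hessian, jacrev, vmap

# Stand-in for Model(); any network producing one scalar per state works.
net = torch.nn.Sequential(
    torch.nn.Linear(2, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 1),
)

def scalar_net(state):
    # torch.func.hessian needs a scalar-valued function of the input.
    return net(state).squeeze()

x = torch.randn(20, 2)                    # one batch of states

dy_hat = vmap(jacrev(scalar_net))(x)      # gradients, shape (20, 2)
d2y_hat = vmap(hessian(scalar_net))(x)    # Hessians, shape (20, 2, 2)
```

Both results stay differentiable with respect to the network parameters, so they can feed the same MSE terms as in the loop above.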
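
The comment about a Sobolev factor amounts to weighting the derivative terms of the loss. A sketch, where w1 and w2 are illustrative knobs rather than values taken from the repo:

```python
import torch.nn.functional as nnF

def sobolev_loss(y_hat, y, dy_hat, dy, d2y_hat, d2y, w1=1.0, w2=1.0):
    # w1 scales the gradient term and w2 the Hessian term;
    # w1 = w2 = 1 recovers the unweighted sum used in the script.
    return (
        nnF.mse_loss(y_hat, y)
        + w1 * nnF.mse_loss(dy_hat, dy)
        + w2 * nnF.mse_loss(d2y_hat, d2y)
    )
```

As the in-line comment in the diff notes, tuning such weights does not change the qualitative result: the Hessian term plateaus either way.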