From 771bd3e9e3973a0dd17db4e1a2acc3b49afe9a44 Mon Sep 17 00:00:00 2001
From: Mansard <nmansard@laas.fr>
Date: Fri, 29 Jan 2021 10:57:28 +0100
Subject: [PATCH] Add input shape argument in the network.

---
 neural_network.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/neural_network.py b/neural_network.py
index 0ae0148..27573ae 100644
--- a/neural_network.py
+++ b/neural_network.py
@@ -6,15 +6,17 @@ import numpy as np
 
 class Model(nn.Module):
     def __init__(self,
-        hidden_layers_params:OrderedDict = OrderedDict([
-        ('hidden layer 1', nn.Linear(in_features= 3 ,out_features=256)),
-        ('hidden layer 1 activation', nn.Tanh()),
-        ('hidden layer 2:', nn.Linear(in_features=256,out_features=256)),
-        ('hidden layer 2 activation:', nn.Tanh()),
-        ('hidden layer 3:', nn.Linear(in_features=256,out_features=1)),
-        ])):
-        super(Model, self).__init__()
+                 hidden_layers_params:OrderedDict = None,ninput = 3,nhidden = 256):
 
+        if hidden_layers_params is None:
+            hidden_layers_params = OrderedDict([
+                ('hidden layer 1', nn.Linear(in_features= ninput ,out_features=nhidden)),
+                ('hidden layer 1 activation', nn.Tanh()),
+                ('hidden layer 2:', nn.Linear(in_features=nhidden,out_features=nhidden)),
+                ('hidden layer 2 activation:', nn.Tanh()),
+                ('hidden layer 3:', nn.Linear(in_features=nhidden,out_features=1)),
+            ])
+        super(Model, self).__init__()
         
         self.hidden_layers = nn.Sequential(hidden_layers_params)
 
@@ -26,11 +28,12 @@ class Model(nn.Module):
 
 if __name__=='__main__':
     import torch.autograd.functional as F
-    x = torch.rand(100,3)
-    model = Model()
+    nx = 6 # 100
+    x = torch.rand(nx,2)
+    model = Model(ninput=2)
     dy_hat = torch.vstack( [ F.jacobian(model, state).squeeze() for state in x ] )
     d2y_hat = torch.stack( [ F.hessian(model, state).squeeze() for state in x ] )
-    xx = torch.rand(100,3,3)
+    xx = torch.rand(nx,2,2)
     print(dy_hat.shape,d2y_hat.shape)
     mse = torch.nn.functional.mse_loss(d2y_hat,xx)
-    print(mse)
\ No newline at end of file
+    print(mse)
-- 
GitLab