Start by importing PyTorch and Matplotlib. I could not make the LaTeX interpreter work in Binder, so the corresponding lines are commented out.
```python
import torch
import torch.nn as nn
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['svg.fonttype'] = 'none'

#%% matplotlib setup
# LaTeX rendering is disabled because it does not work in Binder
# matplotlib.rcParams["text.usetex"] = True
matplotlib.rcParams["font.family"] = "serif"
matplotlib.rcParams["font.size"] = 14

# Use Apple's MPS backend when it is available, otherwise fall back to the CPU
mps_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
```
Create the functions
The separability properties of several functions are investigated. The code first shows the separability properties of different ways of splitting the 1D domain into two regions. Two types of functions are used:
A sharp step function (Heaviside),
A smooth one (Tanh).
In both cases, the position of the jump, \(x = (1-\alpha)L\), is parametrised by a scalar parameter \(\alpha \in [0, 1]\).
Further investigations are conducted on moving fronts with:
A Gaussian function moving with \(\alpha\) (Gauss),
Two Gaussian functions moving at different rates (Gauss_sum).
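All four families share the same construction: each column of the snapshot matrix is one field sampled on the grid, with a front at \(x = (1-\alpha)L\) (and a second one at \(x = (1-2\alpha)L\) for Gauss_sum). A minimal sketch wrapping the four cases in one function; the helper name `make_snapshots` is hypothetical and the grid sizes are copied from the notebook:

```python
import torch

def make_snapshots(kind, L=10.0, n_x=2000, n_alpha=100):
    """Return (x, alpha, F); F has shape (n_x, n_alpha) and column j is the field for alpha[j]."""
    alpha = torch.linspace(0, 1, n_alpha)
    x = torch.linspace(0, L, n_x)
    # Broadcasting (n_x, 1) against (1, n_alpha): signed distance to the front at (1 - alpha) * L
    shift = x[:, None] - (1 - alpha[None, :]) * L
    if kind == "Heaviside":
        return x, alpha, torch.heaviside(shift, torch.tensor(1.0))  # value 1 exactly at the jump
    if kind == "Tanh":
        return x, alpha, torch.tanh(shift)
    if kind == "Gauss":
        return x, alpha, torch.exp(-shift**2)
    if kind == "Gauss_sum":
        # Second Gaussian centred at (1 - 2 * alpha) * L, i.e. moving twice as fast
        shift2 = x[:, None] - (1 - 2 * alpha[None, :]) * L
        return x, alpha, torch.exp(-shift**2) + torch.exp(-shift2**2)
    raise ValueError(f"unknown function family: {kind}")

x, alpha, F = make_snapshots("Heaviside")
j = 30
first_one = int(torch.nonzero(F[:, j])[0])  # first grid point where the step equals 1
print(F.shape, float(x[first_one]), float((1 - alpha[j]) * 10.0))
```

The printed grid point is the first one at or beyond the jump position \((1-\alpha_j)L\), which is how the parameter \(\alpha\) moves the separation between the two regions.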
```python
L = 10                                  # length of the space domain
Alpha_vect = torch.linspace(0, 1, 100)  # vector of alphas
x_vect = torch.linspace(0, L, 2000)     # vector of x

# Select one family by leaving exactly one line uncommented
# Function = 'Heaviside'  # alpha-parameterised step function
Function = 'Tanh'         # smooth alpha-parameterised step function
# Function = 'Gauss'      # alpha-parameterised front function
# Function = 'Gauss_sum'  # two alpha-parameterised front functions

# Broadcasting gives F of shape (len(x_vect), len(Alpha_vect)):
# column j is the field for Alpha_vect[j]
if Function == 'Heaviside':
    F = torch.heaviside(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L, torch.tensor(1.0))
elif Function == 'Tanh':
    F = torch.tanh(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L)
elif Function == 'Gauss':
    F = torch.exp(-(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L)**2)
elif Function == 'Gauss_sum':
    F = (torch.exp(-(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L)**2)
         + torch.exp(-(x_vect[:, None] - (1 - 2 * Alpha_vect[None, :]) * L)**2))
```
Define the autoencoder
We know that the field is (non-linearly) parametrised by a single parameter, yet the SVD required approximately \(15\) modes to represent it. We therefore create an autoencoder with a single latent variable, i.e. \(N_3 = 1\), to see whether the non-linear interpolation can retrieve the single-parameter dependency of the function.
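The number of linear modes can be estimated from the singular values of the snapshot matrix. A minimal sketch for the Tanh family; the notebook does not state which criterion gave "15 modes", so the energy threshold `tol` below is an assumption and the counts it prints may differ from that figure:

```python
import torch

L = 10.0
Alpha_vect = torch.linspace(0, 1, 100)
x_vect = torch.linspace(0, L, 2000)
F = torch.tanh(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L)

def n_modes(F, tol=1e-4):
    """Smallest rank r whose truncated SVD keeps at least a fraction 1 - tol of sum(sigma**2)."""
    s = torch.linalg.svdvals(F.double())  # singular values in descending order
    energy = torch.cumsum(s**2, dim=0) / torch.sum(s**2)
    # energy is non-decreasing, so the number of entries below the threshold
    # is the index of the first entry that reaches it
    return min(int((energy < 1 - tol).sum()) + 1, len(s))

for tol in (1e-2, 1e-4, 1e-6):
    print(tol, n_modes(F, tol))
```

The stricter the threshold, the more modes are needed, whereas the autoencoder below tries to describe the same data with a single latent value.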
```python
class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Layer widths: N_0 = number of grid points, N_3 = latent dimension
        self.N_0 = 2000
        self.N_1 = 1000
        self.N_2 = 200
        self.N_3 = 1
        self.encoder = nn.Sequential(
            nn.Linear(self.N_0, self.N_1),
            nn.Tanh(),
            nn.Linear(self.N_1, self.N_2),
            nn.Tanh(),
            nn.Linear(self.N_2, self.N_3),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.N_3, self.N_2),
            nn.ReLU(),
            nn.Linear(self.N_2, self.N_1),
            nn.ReLU(),
            nn.Linear(self.N_1, self.N_0),
        )

    def forward(self, x, role="decode"):
        if self.training:
            # Training: full encode-decode pass (reconstruction)
            e = self.encoder(x)
            d = self.decoder(e)
            return d
        else:
            # Evaluation: use the encoder or the decoder on its own
            # (match/case requires Python 3.10+)
            match role:
                case "encode":
                    e = self.encoder(x)
                    return e
                case "decode":
                    d = self.decoder(x)
                    return d

ROM = AutoEncoder()
MSE = nn.MSELoss()
optimizer = torch.optim.Adam(ROM.parameters(), lr=1e-3)
```
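Because `nn.Linear` acts on the last dimension, each snapshot has to be a row: the network is fed `F.T` of shape `(100, 2000)` rather than `F`. A minimal usage sketch of the two evaluation roles, assuming the Tanh family and a compact re-declaration of the class with the same layer widths (so the block runs on its own); the network is untrained here, so only the shapes are meaningful:

```python
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Compact re-declaration of the notebook's AutoEncoder (same layer widths)."""

    def __init__(self, N0=2000, N1=1000, N2=200, N3=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(N0, N1), nn.Tanh(), nn.Linear(N1, N2),
                                     nn.Tanh(), nn.Linear(N2, N3))
        self.decoder = nn.Sequential(nn.Linear(N3, N2), nn.ReLU(), nn.Linear(N2, N1),
                                     nn.ReLU(), nn.Linear(N1, N0))

    def forward(self, x, role="decode"):
        if self.training:
            return self.decoder(self.encoder(x))
        return self.encoder(x) if role == "encode" else self.decoder(x)

# Snapshot matrix for the Tanh family: F has shape (n_x, n_alpha) = (2000, 100)
L = 10.0
Alpha_vect = torch.linspace(0, 1, 100)
x_vect = torch.linspace(0, L, 2000)
F = torch.tanh(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L)

ROM = AutoEncoder()  # untrained stand-in
ROM.eval()
with torch.no_grad():
    z = ROM(F.T, role="encode")      # snapshots as rows -> one latent value each: (100, 1)
    F_rec = ROM(z, role="decode").T  # decoded fields back in the (2000, 100) layout
    # New fields can be generated by decoding latent values not seen in the data
    z_new = torch.linspace(float(z.min()), float(z.max()), 5)[:, None]  # (5, 1)
    F_new = ROM(z_new, role="decode")                                 # (5, 2000)

print(z.shape, F_rec.shape, F_new.shape)
# torch.Size([100, 1]) torch.Size([2000, 100]) torch.Size([5, 2000])
```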
Train the model
epoch 1 loss = 8.2155e-01
epoch 2 loss = 7.4085e-01
epoch 3 loss = 6.4534e-01
epoch 4 loss = 4.8344e-01
epoch 5 loss = 4.8184e-01
epoch 6 loss = 4.5387e-01
epoch 7 loss = 3.9579e-01
epoch 8 loss = 3.7095e-01
epoch 9 loss = 3.3925e-01
epoch 10 loss = 3.0197e-01
epoch 11 loss = 2.8023e-01
epoch 12 loss = 2.5817e-01
epoch 13 loss = 2.4296e-01
epoch 14 loss = 2.3631e-01
epoch 15 loss = 2.271e-01
epoch 16 loss = 2.1174e-01
epoch 17 loss = 2.0173e-01
epoch 18 loss = 1.9554e-01
epoch 19 loss = 1.8771e-01
epoch 20 loss = 1.7975e-01
epoch 21 loss = 1.6976e-01
epoch 22 loss = 1.6103e-01
epoch 23 loss = 1.5026e-01
epoch 24 loss = 1.3971e-01
epoch 25 loss = 1.2882e-01
epoch 26 loss = 1.1729e-01
epoch 27 loss = 1.065e-01
epoch 28 loss = 9.6526e-02
epoch 29 loss = 8.6951e-02
epoch 30 loss = 7.8627e-02
epoch 31 loss = 7.0401e-02
epoch 32 loss = 6.3924e-02
epoch 33 loss = 5.8911e-02
epoch 34 loss = 5.5006e-02
epoch 35 loss = 5.1995e-02
epoch 36 loss = 5.0566e-02
epoch 37 loss = 4.8634e-02
epoch 38 loss = 4.7474e-02
epoch 39 loss = 4.7193e-02
epoch 40 loss = 4.7940e-02
epoch 41 loss = 5.2896e-02
epoch 42 loss = 5.9132e-02
epoch 43 loss = 4.8067e-02
epoch 44 loss = 4.2177e-02
epoch 45 loss = 4.8879e-02
epoch 46 loss = 3.9875e-02
epoch 47 loss = 4.2604e-02
epoch 48 loss = 4.1113e-02
epoch 49 loss = 3.8979e-02
epoch 50 loss = 3.9686e-02
epoch 51 loss = 3.7978e-02
epoch 52 loss = 3.7426e-02
epoch 53 loss = 3.6782e-02
epoch 54 loss = 3.5568e-02
epoch 55 loss = 3.4666e-02
epoch 56 loss = 3.3975e-02
epoch 57 loss = 3.2523e-02
epoch 58 loss = 3.2201e-02
epoch 59 loss = 3.0572e-02
epoch 60 loss = 3.0382e-02
epoch 61 loss = 2.8892e-02
epoch 62 loss = 2.8115e-02
epoch 63 loss = 2.7071e-02
epoch 64 loss = 2.5838e-02
epoch 65 loss = 2.4633e-02
epoch 66 loss = 2.369e-02
epoch 67 loss = 2.225e-02
epoch 68 loss = 2.1503e-02
epoch 69 loss = 2.0067e-02
epoch 70 loss = 1.9367e-02
epoch 71 loss = 1.8188e-02
epoch 72 loss = 1.7298e-02
epoch 73 loss = 1.6423e-02
epoch 74 loss = 1.5630e-02
epoch 75 loss = 1.4770e-02
epoch 76 loss = 1.4277e-02
epoch 77 loss = 1.3544e-02
epoch 78 loss = 1.3142e-02
epoch 79 loss = 1.266e-02
epoch 80 loss = 1.2276e-02
epoch 81 loss = 1.1928e-02
epoch 82 loss = 1.1651e-02
epoch 83 loss = 1.132e-02
epoch 84 loss = 1.1146e-02
epoch 85 loss = 1.0812e-02
epoch 86 loss = 1.0640e-02
epoch 87 loss = 1.0346e-02
epoch 88 loss = 1.0115e-02
epoch 89 loss = 9.8678e-03
epoch 90 loss = 9.5919e-03
epoch 91 loss = 9.3518e-03
epoch 92 loss = 9.0759e-03
epoch 93 loss = 8.8265e-03
epoch 94 loss = 8.5736e-03
epoch 95 loss = 8.3326e-03
epoch 96 loss = 8.0972e-03
epoch 97 loss = 7.8782e-03
epoch 98 loss = 7.6577e-03
epoch 99 loss = 7.4597e-03
epoch 100 loss = 7.2663e-03
epoch 101 loss = 7.0822e-03
epoch 102 loss = 6.9131e-03
epoch 103 loss = 6.7434e-03
epoch 104 loss = 6.5875e-03
epoch 105 loss = 6.4347e-03
epoch 106 loss = 6.2886e-03
epoch 107 loss = 6.1507e-03
epoch 108 loss = 6.0154e-03
epoch 109 loss = 5.8867e-03
epoch 110 loss = 5.7649e-03
epoch 111 loss = 5.6437e-03
epoch 112 loss = 5.5332e-03
epoch 113 loss = 5.4245e-03
epoch 114 loss = 5.3208e-03
epoch 115 loss = 5.2267e-03
epoch 116 loss = 5.1309e-03
epoch 117 loss = 5.0453e-03
epoch 118 loss = 4.9602e-03
epoch 119 loss = 4.8775e-03
epoch 120 loss = 4.8005e-03
epoch 121 loss = 4.7221e-03
epoch 122 loss = 4.6473e-03
epoch 123 loss = 4.5739e-03
epoch 124 loss = 4.5008e-03
epoch 125 loss = 4.4290e-03
epoch 126 loss = 4.3590e-03
epoch 127 loss = 4.2876e-03
epoch 128 loss = 4.2188e-03
epoch 129 loss = 4.1503e-03
epoch 130 loss = 4.0826e-03
epoch 131 loss = 4.0159e-03
epoch 132 loss = 3.9502e-03
epoch 133 loss = 3.8851e-03
epoch 134 loss = 3.8207e-03
epoch 135 loss = 3.7579e-03
epoch 136 loss = 3.6948e-03
epoch 137 loss = 3.6329e-03
epoch 138 loss = 3.5718e-03
epoch 139 loss = 3.5115e-03
epoch 140 loss = 3.4516e-03
epoch 141 loss = 3.3928e-03
epoch 142 loss = 3.3345e-03
epoch 143 loss = 3.2768e-03
epoch 144 loss = 3.2201e-03
epoch 145 loss = 3.1643e-03
epoch 146 loss = 3.1091e-03
epoch 147 loss = 3.0549e-03
epoch 148 loss = 3.0017e-03
epoch 149 loss = 2.9495e-03
epoch 150 loss = 2.8982e-03
epoch 151 loss = 2.848e-03
epoch 152 loss = 2.7988e-03
epoch 153 loss = 2.7507e-03
epoch 154 loss = 2.7036e-03
epoch 155 loss = 2.6574e-03
epoch 156 loss = 2.6121e-03
epoch 157 loss = 2.5678e-03
epoch 158 loss = 2.5244e-03
epoch 159 loss = 2.4818e-03
epoch 160 loss = 2.4402e-03
epoch 161 loss = 2.3995e-03
epoch 162 loss = 2.3597e-03
epoch 163 loss = 2.321e-03
epoch 164 loss = 2.2831e-03
epoch 165 loss = 2.2461e-03
epoch 166 loss = 2.2100e-03
epoch 167 loss = 2.1748e-03
epoch 168 loss = 2.1405e-03
epoch 169 loss = 2.1068e-03
epoch 170 loss = 2.074e-03
epoch 171 loss = 2.042e-03
epoch 172 loss = 2.0107e-03
epoch 173 loss = 1.9802e-03
epoch 174 loss = 1.9505e-03
epoch 175 loss = 1.9215e-03
epoch 176 loss = 1.8931e-03
epoch 177 loss = 1.8654e-03
epoch 178 loss = 1.8384e-03
epoch 179 loss = 1.8120e-03
epoch 180 loss = 1.7862e-03
epoch 181 loss = 1.7610e-03
epoch 182 loss = 1.7363e-03
epoch 183 loss = 1.7122e-03
epoch 184 loss = 1.6885e-03
epoch 185 loss = 1.6653e-03
epoch 186 loss = 1.6427e-03
epoch 187 loss = 1.6205e-03
epoch 188 loss = 1.5989e-03
epoch 189 loss = 1.5778e-03
epoch 190 loss = 1.5573e-03
epoch 191 loss = 1.5379e-03
epoch 192 loss = 1.5206e-03
epoch 193 loss = 1.5071e-03
epoch 194 loss = 1.5021e-03
epoch 195 loss = 1.5185e-03
epoch 196 loss = 1.5856e-03
epoch 197 loss = 1.7829e-03
epoch 198 loss = 2.2682e-03
epoch 199 loss = 3.3796e-03
epoch 200 loss = 5.2142e-03
epoch 201 loss = 6.7241e-03
epoch 202 loss = 4.8273e-03
epoch 203 loss = 1.7260e-03
epoch 204 loss = 2.0891e-03
epoch 205 loss = 3.6923e-03
epoch 206 loss = 2.2138e-03
epoch 207 loss = 1.4608e-03
epoch 208 loss = 2.8068e-03
epoch 209 loss = 1.9051e-03
epoch 210 loss = 1.4191e-03
epoch 211 loss = 2.3544e-03
epoch 212 loss = 1.4942e-03
epoch 213 loss = 1.4848e-03
epoch 214 loss = 1.9594e-03
epoch 215 loss = 1.2159e-03
epoch 216 loss = 1.5578e-03
epoch 217 loss = 1.5672e-03
epoch 218 loss = 1.137e-03
epoch 219 loss = 1.5317e-03
epoch 220 loss = 1.2545e-03
epoch 221 loss = 1.1778e-03
epoch 222 loss = 1.3977e-03
epoch 223 loss = 1.0851e-03
epoch 224 loss = 1.2199e-03
epoch 225 loss = 1.2228e-03
epoch 226 loss = 1.0328e-03
epoch 227 loss = 1.2000e-03
epoch 228 loss = 1.0747e-03
epoch 229 loss = 1.0323e-03
epoch 230 loss = 1.1263e-03
epoch 231 loss = 9.8404e-04
epoch 232 loss = 1.0296e-03
epoch 233 loss = 1.0345e-03
epoch 234 loss = 9.4081e-04
epoch 235 loss = 1.0055e-03
epoch 236 loss = 9.5542e-04
epoch 237 loss = 9.2134e-04
epoch 238 loss = 9.6229e-04
epoch 239 loss = 8.9963e-04
epoch 240 loss = 9.0467e-04
epoch 241 loss = 9.1333e-04
epoch 242 loss = 8.6303e-04
epoch 243 loss = 8.8175e-04
epoch 244 loss = 8.6798e-04
epoch 245 loss = 8.3659e-04
epoch 246 loss = 8.5344e-04
epoch 247 loss = 8.2997e-04
epoch 248 loss = 8.1350e-04
epoch 249 loss = 8.2295e-04
epoch 250 loss = 7.9848e-04
epoch 251 loss = 7.9071e-04
epoch 252 loss = 7.9310e-04
epoch 253 loss = 7.7155e-04
epoch 254 loss = 7.6778e-04
epoch 255 loss = 7.6541e-04
epoch 256 loss = 7.4756e-04
epoch 257 loss = 7.4496e-04
epoch 258 loss = 7.3986e-04
epoch 259 loss = 7.2529e-04
epoch 260 loss = 7.2263e-04
epoch 261 loss = 7.1634e-04
epoch 262 loss = 7.0429e-04
epoch 263 loss = 7.0100e-04
epoch 264 loss = 6.9452e-04
epoch 265 loss = 6.8423e-04
epoch 266 loss = 6.8027e-04
epoch 267 loss = 6.7406e-04
epoch 268 loss = 6.6498e-04
epoch 269 loss = 6.605e-04
epoch 270 loss = 6.5475e-04
epoch 271 loss = 6.4654e-04
epoch 272 loss = 6.4166e-04
epoch 273 loss = 6.3633e-04
epoch 274 loss = 6.2891e-04
epoch 275 loss = 6.2364e-04
epoch 276 loss = 6.1873e-04
epoch 277 loss = 6.1203e-04
epoch 278 loss = 6.0647e-04
epoch 279 loss = 6.0177e-04
epoch 280 loss = 5.9577e-04
epoch 281 loss = 5.9010e-04
epoch 282 loss = 5.8548e-04
epoch 283 loss = 5.8009e-04
epoch 284 loss = 5.7451e-04
epoch 285 loss = 5.6982e-04
epoch 286 loss = 5.6495e-04
epoch 287 loss = 5.596e-04
epoch 288 loss = 5.5484e-04
epoch 289 loss = 5.503e-04
epoch 290 loss = 5.4535e-04
epoch 291 loss = 5.4053e-04
epoch 292 loss = 5.3614e-04
epoch 293 loss = 5.3158e-04
epoch 294 loss = 5.2687e-04
epoch 295 loss = 5.2250e-04
epoch 296 loss = 5.1821e-04
epoch 297 loss = 5.1374e-04
epoch 298 loss = 5.0938e-04
epoch 299 loss = 5.0527e-04
epoch 300 loss = 5.0108e-04
epoch 301 loss = 4.9686e-04
epoch 302 loss = 4.9278e-04
epoch 303 loss = 4.8881e-04
epoch 304 loss = 4.8479e-04
epoch 305 loss = 4.8081e-04
epoch 306 loss = 4.7694e-04
epoch 307 loss = 4.7312e-04
epoch 308 loss = 4.6928e-04
epoch 309 loss = 4.6549e-04
epoch 310 loss = 4.6178e-04
epoch 311 loss = 4.5811e-04
epoch 312 loss = 4.5443e-04
epoch 313 loss = 4.5081e-04
epoch 314 loss = 4.4727e-04
epoch 315 loss = 4.4373e-04
epoch 316 loss = 4.4021e-04
epoch 317 loss = 4.3674e-04
epoch 318 loss = 4.3334e-04
epoch 319 loss = 4.2996e-04
epoch 320 loss = 4.2660e-04
epoch 321 loss = 4.2328e-04
epoch 322 loss = 4.2001e-04
epoch 323 loss = 4.1677e-04
epoch 324 loss = 4.1357e-04
epoch 325 loss = 4.1038e-04
epoch 326 loss = 4.0725e-04
epoch 327 loss = 4.0414e-04
epoch 328 loss = 4.0106e-04
epoch 329 loss = 3.9801e-04
epoch 330 loss = 3.9499e-04
epoch 331 loss = 3.9202e-04
epoch 332 loss = 3.8906e-04
epoch 333 loss = 3.8614e-04
epoch 334 loss = 3.8324e-04
epoch 335 loss = 3.8037e-04
epoch 336 loss = 3.7753e-04
epoch 337 loss = 3.7472e-04
epoch 338 loss = 3.7196e-04
epoch 339 loss = 3.6919e-04
epoch 340 loss = 3.6647e-04
epoch 341 loss = 3.6377e-04
epoch 342 loss = 3.611e-04
epoch 343 loss = 3.5846e-04
epoch 344 loss = 3.5584e-04
epoch 345 loss = 3.5326e-04
epoch 346 loss = 3.5068e-04
epoch 347 loss = 3.4814e-04
epoch 348 loss = 3.4562e-04
epoch 349 loss = 3.4313e-04
epoch 350 loss = 3.4066e-04
epoch 351 loss = 3.3822e-04
epoch 352 loss = 3.3582e-04
epoch 353 loss = 3.3341e-04
epoch 354 loss = 3.3104e-04
epoch 355 loss = 3.2868e-04
epoch 356 loss = 3.2635e-04
epoch 357 loss = 3.2404e-04
epoch 358 loss = 3.2174e-04
epoch 359 loss = 3.1948e-04
epoch 360 loss = 3.1722e-04
epoch 361 loss = 3.1499e-04
epoch 362 loss = 3.1276e-04
epoch 363 loss = 3.1055e-04
epoch 364 loss = 3.0836e-04
epoch 365 loss = 3.0618e-04
epoch 366 loss = 3.0402e-04
epoch 367 loss = 3.0186e-04
epoch 368 loss = 2.9972e-04
epoch 369 loss = 2.9759e-04
epoch 370 loss = 2.9548e-04
epoch 371 loss = 2.9336e-04
epoch 372 loss = 2.9124e-04
epoch 373 loss = 2.8913e-04
epoch 374 loss = 2.8703e-04
epoch 375 loss = 2.8494e-04
epoch 376 loss = 2.8285e-04
epoch 377 loss = 2.8079e-04
epoch 378 loss = 2.7875e-04
epoch 379 loss = 2.7671e-04
epoch 380 loss = 2.7469e-04
epoch 381 loss = 2.7273e-04
epoch 382 loss = 2.7078e-04
epoch 383 loss = 2.6885e-04
epoch 384 loss = 2.6693e-04
epoch 385 loss = 2.6504e-04
epoch 386 loss = 2.6316e-04
epoch 387 loss = 2.6129e-04
epoch 388 loss = 2.5944e-04
epoch 389 loss = 2.576e-04
epoch 390 loss = 2.5578e-04
epoch 391 loss = 2.5398e-04
epoch 392 loss = 2.522e-04
epoch 393 loss = 2.5043e-04
epoch 394 loss = 2.4868e-04
epoch 395 loss = 2.4696e-04
epoch 396 loss = 2.4524e-04
epoch 397 loss = 2.4356e-04
epoch 398 loss = 2.419e-04
epoch 399 loss = 2.4025e-04
epoch 400 loss = 2.3863e-04
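The cell that produced this log is not shown in this section. A minimal full-batch training sketch consistent with the names used later (`F_train`, `F_val`, `loss_t_vect`, `loss_v_vect`), assuming the Tanh family, a hypothetical random 80/20 split of the snapshots (stored as rows), CPU execution and a shortened run. The log does not say whether it reports the training or the validation loss; this sketch prints the training loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class AutoEncoder(nn.Module):
    """Compact re-declaration of the notebook's AutoEncoder (encode -> decode only)."""

    def __init__(self, N0=2000, N1=1000, N2=200, N3=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(N0, N1), nn.Tanh(), nn.Linear(N1, N2),
                                     nn.Tanh(), nn.Linear(N2, N3))
        self.decoder = nn.Sequential(nn.Linear(N3, N2), nn.ReLU(), nn.Linear(N2, N1),
                                     nn.ReLU(), nn.Linear(N1, N0))

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Snapshot matrix (Tanh case), one snapshot per column
L = 10.0
Alpha_vect = torch.linspace(0, 1, 100)
x_vect = torch.linspace(0, L, 2000)
F = torch.tanh(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L)

# Hypothetical 80/20 random split; rows are samples for nn.Linear
perm = torch.randperm(F.shape[1])
F_train = F[:, perm[:80]].T.contiguous()  # (80, 2000)
F_val = F[:, perm[80:]].T.contiguous()    # (20, 2000)

ROM = AutoEncoder()
MSE = nn.MSELoss()
optimizer = torch.optim.Adam(ROM.parameters(), lr=1e-3)

n_epochs = 50  # the log above corresponds to 400 epochs
loss_t_vect, loss_v_vect = [], []
for epoch in range(1, n_epochs + 1):
    ROM.train()
    optimizer.zero_grad()
    loss_t = MSE(ROM(F_train), F_train)  # full-batch reconstruction loss
    loss_t.backward()
    optimizer.step()

    ROM.eval()
    with torch.no_grad():
        # This compact model always encodes then decodes; with the notebook's class,
        # eval mode needs an explicit encode followed by decode
        loss_v = MSE(ROM(F_val), F_val)
    loss_t_vect.append(loss_t.detach())
    loss_v_vect.append(loss_v)
    print(f"epoch {epoch} loss = {loss_t.item():.4e}")
```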
Evaluate the model
```python
ROM.eval()
F_train = F_train.cpu()
F_val = F_val.cpu()
ROM.cpu()
torch.save(ROM, 'FullModel.pt')  # to save a full coarse model
```
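`torch.save(ROM, ...)` pickles the whole module, so reloading it requires the `AutoEncoder` class to be importable, and PyTorch 2.6 and later default `torch.load` to `weights_only=True`, which rejects such pickles unless `weights_only=False` is passed. A commonly used alternative is to save only the `state_dict`. A minimal sketch, with a hypothetical file name in a temporary directory and an untrained compact re-declaration of the model:

```python
import os
import tempfile

import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Compact re-declaration of the notebook's AutoEncoder (same layer widths)."""

    def __init__(self, N0=2000, N1=1000, N2=200, N3=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(N0, N1), nn.Tanh(), nn.Linear(N1, N2),
                                     nn.Tanh(), nn.Linear(N2, N3))
        self.decoder = nn.Sequential(nn.Linear(N3, N2), nn.ReLU(), nn.Linear(N2, N1),
                                     nn.ReLU(), nn.Linear(N1, N0))

    def forward(self, x, role="decode"):
        if self.training:
            return self.decoder(self.encoder(x))
        return self.encoder(x) if role == "encode" else self.decoder(x)

ROM = AutoEncoder().cpu()
ROM.eval()

# Save only the parameters (a dict of tensors), not the pickled module
path = os.path.join(tempfile.mkdtemp(), "ROM_state_dict.pt")  # hypothetical file name
torch.save(ROM.state_dict(), path)

# Loading: rebuild the architecture, then copy the saved parameters into it
ROM_loaded = AutoEncoder()
ROM_loaded.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
ROM_loaded.eval()

with torch.no_grad():
    z = torch.tensor([[0.5]])  # an arbitrary latent value
    same = torch.equal(ROM(z, role="decode"), ROM_loaded(z, role="decode"))
print(same)  # True
```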
Plots
```python
# Move the recorded losses to the CPU (detached from the graph) before plotting
loss_t_vect = [loss_t.detach().cpu() for loss_t in loss_t_vect]
loss_v_vect = [loss_v.detach().cpu() for loss_v in loss_v_vect]

plt.plot(loss_t_vect, label='training set')
plt.plot(loss_v_vect, label='validation set')
plt.legend(loc="upper right")
plt.show()

# Same curves on a logarithmic scale
plt.semilogy(loss_t_vect, label='training set')
plt.semilogy(loss_v_vect, label='validation set')
plt.legend(loc="upper right")
plt.xlabel('Epochs')
plt.ylabel('Loss')
# plt.savefig(f'Results/loss_training_{Function}.pdf', transparent=True)
plt.show()
```
Plot the comparison of the latent space and the natural parameter used to generate the data
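The plotting cell for this comparison is not shown. A minimal sketch, assuming the Tanh family and an untrained stand-in model (replace it with the trained `ROM`): each snapshot is encoded to its latent value and plotted against \(\alpha\). The monotonicity flag is an added heuristic, not part of the original notebook: for a single-parameter family, a useful latent variable should ideally be a one-to-one, hence monotonic, function of \(\alpha\).

```python
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Compact re-declaration of the notebook's AutoEncoder (same layer widths)."""

    def __init__(self, N0=2000, N1=1000, N2=200, N3=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(N0, N1), nn.Tanh(), nn.Linear(N1, N2),
                                     nn.Tanh(), nn.Linear(N2, N3))
        self.decoder = nn.Sequential(nn.Linear(N3, N2), nn.ReLU(), nn.Linear(N2, N1),
                                     nn.ReLU(), nn.Linear(N1, N0))

    def forward(self, x, role="decode"):
        if self.training:
            return self.decoder(self.encoder(x))
        return self.encoder(x) if role == "encode" else self.decoder(x)

L = 10.0
Alpha_vect = torch.linspace(0, 1, 100)
x_vect = torch.linspace(0, L, 2000)
F = torch.tanh(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L)

ROM = AutoEncoder()  # stand-in: load the trained model here instead
ROM.eval()
with torch.no_grad():
    z = ROM(F.T, role="encode").squeeze(1)  # one latent value per alpha: shape (100,)

# Heuristic check: is the latent variable a monotonic function of alpha?
dz = torch.diff(z)
monotonic = bool((dz > 0).all() or (dz < 0).all())

fig, ax = plt.subplots()
ax.plot(Alpha_vect.numpy(), z.numpy(), "o-", markersize=3)
ax.set_xlabel(r"$\alpha$")
ax.set_ylabel("latent variable")
ax.set_title(f"monotonic in alpha: {monotonic}")
plt.show()
```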
With only one latent parameter, the errors are reasonably low. Non-linear interpolation is therefore more appropriate when the separability using linear subspaces is low.
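To put "reasonably low" in perspective, the autoencoder's reconstruction error can be compared with the best linear approximation of the same rank. By the Eckart-Young theorem, the truncated SVD gives the smallest Frobenius-norm error among all rank-\(r\) approximations. A minimal sketch for the Tanh family; the choice of family and of the ranks printed are assumptions:

```python
import torch

L = 10.0
Alpha_vect = torch.linspace(0, 1, 100)
x_vect = torch.linspace(0, L, 2000)
F = torch.tanh(x_vect[:, None] - (1 - Alpha_vect[None, :]) * L).double()

def rank_r_rel_error(F, r):
    """Relative Frobenius error of the best rank-r approximation (truncated SVD)."""
    U, S, Vh = torch.linalg.svd(F, full_matrices=False)
    F_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]
    return float(torch.linalg.norm(F - F_r) / torch.linalg.norm(F))

# A one-dimensional linear subspace (r = 1) is the linear counterpart of N_3 = 1
for r in (1, 2, 5, 15):
    print(r, rank_r_rel_error(F, r))
```

The gap between the rank-1 linear error and the error of the one-parameter autoencoder measures how much the non-linear mapping gains on this poorly separable field.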