From 4e0ee0530ec2bea0d4979173778ff45a479f070d Mon Sep 17 00:00:00 2001
From: Tizian Wenzel <tizian.wenzel@uni-hamburg.de>
Date: Wed, 9 Oct 2024 15:53:12 +0200
Subject: [PATCH] Added missing class OptimizedKernel.

---
 notebooks/utils/utilities.py | 141 +++++++++++++++++++++++++++++++++++
 1 file changed, 141 insertions(+)

diff --git a/notebooks/utils/utilities.py b/notebooks/utils/utilities.py
index a213ad2..bc807a1 100644
--- a/notebooks/utils/utilities.py
+++ b/notebooks/utils/utilities.py
@@ -1,7 +1,9 @@
 # +
 import math
 import torch
+import numpy as np
+from torch import nn
 
 from utils import kernels
 
@@ -46,3 +48,142 @@ class ActivFunc(torch.nn.Module):
         return x, centers
 
 
+
+
+def compute_cv_loss_via_rippa_ext_2(kernel_matrix, y, n_folds, reg_for_matrix_inversion):
+    """
+    Implementation without the need to provide a kernel and points: simply provide the kernel matrix.
+    """
+
+    # Some precomputations
+    kernel_matrix_reg = kernel_matrix + reg_for_matrix_inversion * torch.eye(kernel_matrix.shape[0])
+    inv_kernel_matrix = torch.inverse(kernel_matrix_reg)
+    coeffs = torch.linalg.solve(kernel_matrix_reg, y)
+
+    # Some initializations and preparations: n_folds must divide y.shape[0] without remainder
+    array_error = torch.zeros(y.shape[0], 1)
+    n_per_fold = y.shape[0] // n_folds
+    indices = torch.arange(0, y.shape[0]).view(n_per_fold, n_folds)
+
+    # Standard Rippa scheme (leave-one-out, n_folds == number of samples)
+    if n_folds == y.shape[0]:
+        array_error = coeffs / torch.diag(inv_kernel_matrix).view(-1, 1)
+
+    # Extended Rippa scheme (leave-one-fold-out)
+    else:
+        for j in range(n_folds):
+            inv_kernel_matrix_loc1 = inv_kernel_matrix[indices[:, j], :]
+            inv_kernel_matrix_loc = inv_kernel_matrix_loc1[:, indices[:, j]]
+
+            array_error[j * n_per_fold: (j + 1) * n_per_fold, 0] = \
+                (torch.linalg.solve(inv_kernel_matrix_loc, coeffs[indices[:, j]])).view(-1)
+
+    cv_error_sq = torch.sum(array_error ** 2) / array_error.numel()
+
+    return cv_error_sq, array_error
+
+
+class OptimizedKernel(torch.nn.Module):
+
+    def __init__(self, kernel, dim,
+                 reg_para=1e-5, learning_rate=1e-3, n_epochs=100, batch_size=32, n_folds=None,
+                 flag_initialize_diagonal=False, flag_symmetric_A=False):
+        super().__init__()
+
+        # Some settings, mostly optimization related
+        self.kernel = kernel
+
+        self.dim = dim
+        self.reg_para = reg_para
+        self.learning_rate = learning_rate
+        self.n_epochs = n_epochs
+        self.batch_size = batch_size
+
+        self.flag_symmetric_A = flag_symmetric_A
+
+        # Define the trainable linear map B: a user-provided matrix, the identity or a random matrix
+        if torch.is_tensor(flag_initialize_diagonal):
+            self.B = nn.Parameter(flag_initialize_diagonal, requires_grad=True)
+        elif flag_initialize_diagonal:
+            self.B = nn.Parameter(torch.eye(self.dim, self.dim), requires_grad=True)
+        else:
+            self.B = nn.Parameter(torch.rand(self.dim, self.dim), requires_grad=True)
+
+        if self.flag_symmetric_A:
+            self.A = (self.B + self.B.t()) / 2
+        else:
+            self.A = self.B
+
+        if n_folds is None:
+            self.n_folds = self.batch_size
+        else:
+            self.n_folds = n_folds
+
+        # Set optimizer and scheduler
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=.7)
+
+        # Initialize lists for tracking
+        self.list_obj = []
+        self.list_parameters = []
+
+    def optimize(self, X, y, flag_optim_verbose=True):
+
+        assert X.shape[0] == y.shape[0], 'Data sizes do not match'
+        n_batches = X.shape[0] // self.batch_size
+
+        # Append initial parameters
+        if self.flag_symmetric_A:
+            self.list_parameters.append(torch.clone((self.B + self.B.t()) / 2).detach().numpy())
+        else:
+            self.list_parameters.append(torch.clone(self.A).detach().numpy())
+
+        for idx_epoch in range(self.n_epochs):
+            shuffle = np.random.permutation(X.shape[0])  # reshuffle the data set every epoch
+
+            list_obj_loc = []
+
+            for idx_batch in range(n_batches):
+
+                # Select minibatch from the data
+                ind = shuffle[idx_batch * self.batch_size: (idx_batch + 1) * self.batch_size]
+                Xb, yb = X[ind, :], y[ind, :]
+
+                # Compute the kernel matrix for the minibatch, mapped through the linear layer A
+                kernel_matrix = self.kernel.eval(Xb @ self.A, Xb @ self.A)
+
+                # Use the cross-validation loss via Rippa's scheme to assess the error
+                optimization_objective, _ = compute_cv_loss_via_rippa_ext_2(
+                    kernel_matrix, yb, self.n_folds, self.reg_para)
+
+                # Keep track of the optimization objective within the epoch
+                list_obj_loc.append(optimization_objective.detach().item())
+                if idx_epoch == 0 and flag_optim_verbose:
+                    print('First epoch: Iteration {:5d}: Training objective: {:.3e}'.format(
+                        idx_batch, optimization_objective.detach().item()))
+
+                # Optimization step
+                optimization_objective.backward()
+                self.optimizer.step()       # do one optimization step
+                self.optimizer.zero_grad()  # set gradients to zero
+
+                if self.flag_symmetric_A:
+                    self.A = (self.B + self.B.t()) / 2
+                else:
+                    self.A = self.B
+
+            self.scheduler.step()  # decay the learning rate once per epoch
+
+            # Keep track of some quantities and print something
+            mean_obj = np.mean(list_obj_loc)
+
+            if flag_optim_verbose:
+                print('Epoch {:5d} finished, mean training objective: {:.3e}.'.format(
+                    idx_epoch + 1, mean_obj))
+
+            self.list_obj.append(mean_obj)
+
+            self.list_parameters.append(torch.clone(self.A).detach().numpy())
+
--
GitLab
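
A quick way to sanity-check the new helper is to compare Rippa's closed-form
leave-one-out residuals against explicit refits. The sketch below is
illustrative and not part of the patch; it assumes utilities.py is importable
as utils.utilities (the import path depends on how the notebooks are run) and
uses an exponential kernel chosen only for its benign conditioning:

    import torch

    # Hypothetical import path; adjust to match the repository layout
    from utils.utilities import compute_cv_loss_via_rippa_ext_2

    torch.manual_seed(0)
    n = 20
    X = torch.rand(n, 2, dtype=torch.float64)
    y = torch.sin(X.sum(dim=1, keepdim=True))
    K = torch.exp(-torch.cdist(X, X))  # exponential kernel matrix
    reg = 1e-6

    # n_folds == n triggers the standard (leave-one-out) Rippa scheme
    cv_error_sq, array_error = compute_cv_loss_via_rippa_ext_2(K, y, n, reg)

    # Brute force: for each i, refit without sample i and predict it
    for i in range(n):
        mask = torch.arange(n) != i
        coeffs_loc = torch.linalg.solve(
            K[mask][:, mask] + reg * torch.eye(n - 1, dtype=torch.float64), y[mask])
        residual = y[i, 0] - K[i, mask] @ coeffs_loc.view(-1)
        assert torch.isclose(residual, array_error[i, 0], atol=1e-8)

Each entry of array_error should coincide with the residual of a model fitted
without that sample, at a fraction of the cost of n separate solves.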
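
A minimal end-to-end sketch of OptimizedKernel itself follows. The
GaussianKernel class is a stand-in that only mimics the .eval(X1, X2)
interface the optimizer expects from the kernels in utils.kernels; the import
path, kernel shape and hyperparameters are illustrative assumptions:

    import numpy as np
    import torch

    # Hypothetical import path; adjust to match the repository layout
    from utils.utilities import OptimizedKernel


    class GaussianKernel:
        # Stand-in for a kernel from utils.kernels: only .eval(X1, X2) is required
        def __init__(self, ep=2.0):
            self.ep = ep

        def eval(self, X1, X2):
            return torch.exp(-(self.ep * torch.cdist(X1, X2)) ** 2)


    torch.manual_seed(0)
    np.random.seed(0)

    X = torch.rand(256, 2)  # batch_size=32 divides 256 evenly
    y = torch.sin(2 * X[:, [0]]) + X[:, [1]] ** 2

    model = OptimizedKernel(GaussianKernel(), dim=2, reg_para=1e-5,
                            learning_rate=5e-3, n_epochs=10, batch_size=32,
                            flag_initialize_diagonal=True)
    model.optimize(X, y, flag_optim_verbose=False)

    print('mean CV objective per epoch:', model.list_obj)
    print('learned linear map A:')
    print(model.A.detach().numpy())

Since n_folds defaults to batch_size, every minibatch objective is the
leave-one-out error of the batch via the standard Rippa branch, at the cost of
one matrix inverse per batch; gradients flow into B through the kernel matrix
of the transformed inputs Xb @ A.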