diff --git a/README.md b/README.md
index ac373827aaa41a177087fecf066978f5d8808206..42b4297878950e9756bc971b389adf7a56b8a270 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,6 @@
 
 Python implementation of the 2L-VKOGA algorithm, which uses a kernel optimization (two layered kernel) before running VKOGA with the modified kernel.
 
-
-
-
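+A minimal usage sketch (cf. `example_files/01_testfile.py`; the kernel choices and parameters are illustrative):
+
+```python
+import numpy as np
+import kernels, tkernels
+from vkoga import VKOGA_2L
+
+X = np.random.rand(500, 3)
+y = X[:, [0]]
+
+# Provide a [numpy kernel, torch kernel] pair and enable the two-layered optimization
+model = VKOGA_2L(kernel=[kernels.Matern(k=1), tkernels.Matern(k=1)],
+                 greedy_type='f_greedy', flag_2L_optimization=True)
+model.fit(X, y, maxIter=50)
+y_pred = model.predict(X)
+```
+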
 ## How to cite:
 If you use this code in your work, please cite the paper
 
diff --git a/example_files/01_testfile.py b/example_files/01_testfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e7b0fff44a841ba45a915db7ad77169aba8a75
--- /dev/null
+++ b/example_files/01_testfile.py
@@ -0,0 +1,68 @@
+# Example to show the use of validation set tracking
+
+import numpy as np
+from matplotlib import pyplot as plt
+import kernels, tkernels
+from vkoga import VKOGA_2L
+
+np.random.seed(1)
+
+
+# Create some data (dim = 3), where the target values depend only on the first coordinate
+dim = 3
+X_train = np.random.rand(500, dim)
+y_train = X_train[:, [0]]
+
+X_val = np.random.rand(100, dim)
+y_val = X_val[:, [0]]
+
+
+# Run VKOGA
+kernel = kernels.Matern(k=1)
+kernel_t = tkernels.Matern(k=1)
+
+model_1L = VKOGA_2L(kernel=kernel, greedy_type='f_greedy')
+_ = model_1L.fit(X_train, y_train, X_val=X_val, y_val=y_val, maxIter=50)
+
+model_2L = VKOGA_2L(kernel=[kernel, kernel_t], greedy_type='f_greedy', flag_2L_optimization=True)
+_ = model_2L.fit(X_train, y_train, X_val=X_val, y_val=y_val, maxIter=50)
+
+
+# Plot the training and validation histories
+fig = plt.figure(2)
+fig.clf()
+plt.plot(model_1L.train_hist['f'])
+plt.plot(model_2L.train_hist['f'])
+plt.plot(model_1L.train_hist['f val'], '--')
+plt.plot(model_2L.train_hist['f val'], '--')
+plt.legend(['1L train, f max', '2L train, f max', '1L val', '2L val'])
+plt.xlabel('training iteration')
+plt.xscale('log')
+plt.yscale('log')
+plt.draw()
+
+
+
+# Result: Since the training set is quite small, we can clearly observe overfitting
+# via the validation set tracking: the validation residuals stagnate (or increase)
+# while the training residuals keep decreasing
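+
+# Quick numerical comparison (added sketch; the history keys are defined in vkoga.py)
+print('1L: final train %.3e, final val %.3e'
+      % (model_1L.train_hist['f'][-1], model_1L.train_hist['f val'][-1]))
+print('2L: final train %.3e, final val %.3e'
+      % (model_2L.train_hist['f'][-1], model_2L.train_hist['f val'][-1]))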
diff --git a/kernels.py b/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6d209f7de7425654f81ec4f3d90f09633c2386
--- /dev/null
+++ b/kernels.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python3
+
+from abc import ABC, abstractmethod
+from scipy.spatial import distance_matrix
+import numpy as np
+import math
+import matplotlib.pyplot as plt
+
+# Abstract kernel
+class Kernel(ABC):
+    @abstractmethod    
+    def __init__(self):
+        super().__init__()
+    
+    @abstractmethod
+    def eval(self, x, y):
+        pass
+
+    def eval_prod(self, x, y, v, batch_size=100):
+        # Evaluate the matrix-vector product K(x, y) @ v in row batches, so that
+        # the full kernel matrix never needs to be assembled at once
+        N = x.shape[0]
+        num_batches = int(np.ceil(N / batch_size))
+        mat_vec_prod = np.zeros((N, 1)) 
+        for idx in range(num_batches):
+            idx_begin = idx * batch_size
+            idx_end = (idx + 1) * batch_size
+            A = self.eval(x[idx_begin:idx_end, :], y)
+            mat_vec_prod[idx_begin:idx_end] = A @ v
+        return mat_vec_prod
+
+    @abstractmethod
+    def diagonal(self, X):
+        pass
+
+    @abstractmethod
+    def __str__(self):
+        pass
+
+    @abstractmethod
+    def set_params(self, params):
+        pass
+
+# Abstract RBF
+class RBF(Kernel):
+    @abstractmethod    
+    def __init__(self):
+        super(RBF, self).__init__()
+        
+    def eval(self, x, y):
+        return self.rbf(self.ep, distance_matrix(np.atleast_2d(x), np.atleast_2d(y)))
+
+    def diagonal(self, X):
+        return np.ones(X.shape[0]) * self.rbf(self.ep, 0)
+    
+    def __str__(self):
+        return self.name + ' [gamma = %2.2e]' % self.ep
+
+    def set_params(self, par):
+        self.ep = par
+
+# Implementation of concrete RBFs
+class Gaussian(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.name = 'gauss'
+        self.rbf = lambda ep, r: np.exp(-(ep * r) ** 2)
+
+class GaussianTanh(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.name = 'gauss_tanh'
+        self.rbf = lambda ep, r: np.exp(-(ep * np.tanh(r)) ** 2)
+
+class IMQ(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.name = 'imq'
+        self.rbf = lambda ep, r: 1. / np.sqrt(1 + (ep * r) ** 2)
+
+class Matern(RBF):
+    def __init__(self, ep=1, k=0):
+        self.ep = ep
+        if k == 0:
+            self.name = 'mat0'
+            self.rbf = lambda ep, r : np.exp(-ep * r)
+        elif k == -1:
+            self.name = 'derivative kernel of quadratic matern'
+            # Note: ep is not used here, the parameters are hardcoded
+            self.rbf = lambda ep, r: np.exp(-r) * (r ** 2 - (2 * 1 + 3) * r + 1 ** 2 + 2 * 1)
+        elif k == 1:
+            self.name = 'mat1'
+            self.rbf = lambda ep, r: np.exp(-ep * r) * (1 + ep * r)
+        elif k == 2:
+            self.name = 'mat2'
+            self.rbf = lambda ep, r: np.exp(-ep * r) * (3 + 3 * ep * r + 1 * (ep * r) ** 2)
+        elif k == 3:
+            self.name = 'mat3'
+            self.rbf = lambda ep, r: np.exp(-ep * r) * (15 + 15 * ep * r + 6 * (ep * r) ** 2 + 1 * (ep * r) ** 3)
+        elif k == 4:
+            self.name = 'mat4'
+            self.rbf = lambda ep, r: np.exp(-ep * r) * (105 + 105 * ep * r + 45 * (ep * r) ** 2 + 10 * (ep * r) ** 3 + 1 * (ep * r) ** 4)
+        elif k == 5:
+            self.name = 'mat5'
+            self.rbf = lambda ep, r: np.exp(-ep * r) * (945 + 945 * ep * r + 420 * (ep * r) ** 2 + 105 * (ep * r) ** 3 + 15 * (ep * r) ** 4 + 1 * (ep * r) ** 5)
+        elif k == 6:
+            self.name = 'mat6'
+            self.rbf = lambda ep, r: np.exp(-ep * r) * (10395 + 10395 * ep * r + 4725 * (ep * r) ** 2 + 1260 * (ep * r) ** 3 + 210 * (ep * r) ** 4 + 21 * (ep * r) ** 5 + 1 * (ep * r) ** 6)
+        elif k == 7:
+            self.name = 'mat7'
+            self.rbf = lambda ep, r: np.exp(-ep * r) * (135135 + 135135 * ep * r + 62370 * (ep * r) ** 2 + 17325 * (ep * r) ** 3 + 3150 * (ep * r) ** 4 + 378 * (ep * r) ** 5 + 28 * (ep * r) ** 6 + 1 * (ep * r) ** 7)
+        else:
+            raise Exception('This Matern kernel is not implemented')
+
+class Wendland(RBF):
+    def __init__(self, ep=1, k=0, d=1):
+        self.ep = ep
+        self.name = 'wen_' + str(d) + '_' + str(k)
+        l = np.floor(d / 2) + k + 1
+        if k == 0:
+            p = lambda r: 1
+        elif k == 1:
+            p = lambda r: (l + 1) * r + 1
+        elif k == 2:
+            p = lambda r: (l + 3) * (l + 1) * r ** 2 + 3 * (l + 2) * r + 3
+        elif k == 3:
+            p = lambda r: (l + 5) * (l + 3) * (l + 1) * r ** 3 + (45 + 6 * l * (l + 6)) * r ** 2 + (15 * (l + 3)) * r + 15
+        elif k == 4:
+            p = lambda r: (l + 7) * (l + 5) * (l + 3) * (l + 1) * r ** 4 + (5 * (l + 4) * (21 + 2 * l * (8 + l))) * r ** 3 + (45 * (14 + l * (l + 8))) * r ** 2 + (105 * (l + 4)) * r + 105
+        else:
+            raise Exception('This Wendland kernel is not implemented')
+        c = math.factorial(int(l + 2 * k)) / math.factorial(int(l))
+        e = l + k
+        self.rbf = lambda ep, r: np.maximum(1 - ep * r, 0) ** e * p(ep * r) / c
+
+
+# Polynomial kernels
+class Polynomial(Kernel):
+    def __init__(self, a=0, p=1):
+        self.a = a
+        self.p = p
+            
+    def eval(self, x, y):
+        return (np.atleast_2d(x) @ np.atleast_2d(y).transpose() + self.a) ** self.p
+    
+    def diagonal(self, X):
+        return (np.linalg.norm(X, axis=1) ** 2 + self.a) ** self.p
+
+    def __str__(self):
+        return 'polynomial' + ' [a = %2.2e, p = %2.2e]' % (self.a, self.p)
+
+    def set_params(self, par):
+        self.a = par[0]
+        self.p = par[1]
+
+
+# Brownian kernels
+class BrownianBridge(Kernel):
+    def __init__(self):
+        super().__init__()
+        self.name = 'Brownian Bridge'
+
+    def eval(self, x, y):
+        return np.minimum(np.atleast_2d(x), np.transpose(np.atleast_2d(y))) - np.atleast_2d(x) * np.transpose(np.atleast_2d(y))
+
+    def diagonal(self, X):
+        return X[:, 0] - X[:, 0] ** 2
+
+    def __str__(self):
+        return 'Brownian Bridge kernel'
+
+    def set_params(self, par):
+        pass
+
+class BrownianMotion(Kernel):
+    def __init__(self):
+        super().__init__()
+        self.name = 'Brownian Motion'
+
+    def eval(self, x, y):
+        return np.minimum(np.atleast_2d(x), np.transpose(np.atleast_2d(y)))
+
+    def diagonal(self, X):
+        return X.reshape(-1)
+
+    def __str__(self):
+        return 'Brownian Motion kernel'
+
+    def set_params(self, par):
+        pass
+
+# Tensor product kernels
+class TensorProductKernel(Kernel):
+    def __init__(self, kernel):
+        super().__init__()
+
+        self.kernel_1D = kernel
+        self.name = self.kernel_1D.name
+
+    def eval(self, x, y):
+
+        x = np.atleast_2d(x)
+        y = np.atleast_2d(y)
+
+        assert x.shape[1] == y.shape[1], 'Dimensions do not match'
+
+        array_matrix = np.ones((x.shape[0], y.shape[0]))
+
+        for idx_dim in range(x.shape[1]):
+            array_matrix = array_matrix * self.kernel_1D.eval(x[:, [idx_dim]], y[:, [idx_dim]])
+
+        return array_matrix
+
+    def diagonal(self, X):
+
+        X = np.atleast_2d(X)
+
+        array_diagonal = np.ones(X.shape[0])
+
+        for idx_dim in range(X.shape[1]):
+            array_diagonal *= self.kernel_1D.diagonal(X[:, [idx_dim]])
+
+        return array_diagonal
+
+    def __str__(self):
+        return 'Tensor product kernel for ' + self.name
+
+    def set_params(self, par):
+        pass
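+
+# Note on TensorProductKernel (added comment): it implements
+#     K(x, y) = prod_{d=1}^{D} k_1D(x_d, y_d)
+# for a one-dimensional base kernel k_1D, e.g. TensorProductKernel(Matern(k=1)).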
+
+# A demo usage
+def main():
+    ker = Gaussian()
+
+    x = np.linspace(-1, 1, 100)[:, None]
+    y = np.array([[0.0]])
+    A = ker.eval(x, y)
+
+
+    fig = plt.figure(1)
+    fig.clf()
+    ax = fig.gca()
+    ax.plot(x, A)
+    ax.set_title('A kernel: ' + str(ker))
+    fig.show()
+
+
+if __name__ == '__main__':
+    main()
+
diff --git a/tkernels.py b/tkernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..26e26602f18e4819b4ba4ff29f0345e83b26bf95
--- /dev/null
+++ b/tkernels.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python3
+
+# PyTorch implementation of the kernels
+import torch
+from abc import ABC, abstractmethod
+import matplotlib.pyplot as plt
+import numpy as np
+import math
+
+
+# Abstract kernel
+class Kernel(ABC):
+    @abstractmethod
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def eval(self, x, y):
+        pass
+
+    @abstractmethod
+    def diagonal(self, X):
+        pass
+
+    @abstractmethod
+    def __str__(self):
+        pass
+
+    @abstractmethod
+    def set_params(self, params):
+        pass
+
+
+# Abstract RBF
+class RBF(Kernel):
+    @abstractmethod
+    def __init__(self):
+        super(RBF, self).__init__()
+
+    def eval(self, x, y):
+        return self.rbf(self.ep,
+                        torch.cdist(x, y))
+
+    def diagonal(self, X):
+        return torch.ones(X.shape[0], 1) * self.rbf(self.ep, torch.tensor(0.0))
+
+    def __str__(self):
+        return self.name + ' [gamma = %2.2e]' % self.ep
+
+    def set_params(self, par):
+        self.ep = par
+
+
+# Implementation of concrete RBFs
+class Gaussian(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.name = 'gauss'
+        self.rbf = lambda ep, r: torch.exp(-(ep * r) ** 2)
+
+
+class GaussianTanh(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.name = 'gauss_tanh'
+        self.rbf = lambda ep, r: torch.exp(-(ep * torch.tanh(r)) ** 2)
+
+
+class IMQ(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.name = 'imq'
+        self.rbf = lambda ep, r: 1. / torch.sqrt(1 + (ep * r) ** 2)
+
+
+class Matern(RBF):
+    def __init__(self, ep=1, k=0):
+        self.ep = ep
+        if k == 0:
+            self.name = 'mat0'
+            self.rbf = lambda ep, r: torch.exp(-ep * r)
+        elif k == 1:
+            self.name = 'mat1'
+            self.rbf = lambda ep, r: torch.exp(-ep * r) * (1 + ep * r)
+        elif k == 2:
+            self.name = 'mat2'
+            self.rbf = lambda ep, r: torch.exp(-ep * r) * (3 + 3 * ep * r + (ep * r) ** 2)
+        elif k == 3:
+            self.name = 'mat3'
+            self.rbf = lambda ep, r: torch.exp(-ep * r) * (15 + 15 * ep * r + 6 * (ep * r) ** 2 + 1 * (ep * r) ** 3)
+        elif k == 4:
+            self.name = 'mat4'
+            self.rbf = lambda ep, r: torch.exp(-ep * r) * (
+                        105 + 105 * ep * r + 45 * (ep * r) ** 2 + 10 * (ep * r) ** 3 + 1 * (ep * r) ** 4)
+        elif k == 5:
+            self.name = 'mat5'
+            self.rbf = lambda ep, r: torch.exp(-ep * r) * (
+                        945 + 945 * ep * r + 420 * (ep * r) ** 2 + 105 * (ep * r) ** 3 + 15 * (ep * r) ** 4 + 1 * (
+                            ep * r) ** 5)
+        elif k == 6:
+            self.name = 'mat6'
+            self.rbf = lambda ep, r: torch.exp(-ep * r) * (
+                        10395 + 10395 * ep * r + 4725 * (ep * r) ** 2 + 1260 * (ep * r) ** 3 + 210 * (
+                            ep * r) ** 4 + 21 * (ep * r) ** 5 + 1 * (ep * r) ** 6)
+        elif k == 7:
+            self.name = 'mat7'
+            self.rbf = lambda ep, r: torch.exp(-ep * r) * (
+                        135135 + 135135 * ep * r + 62370 * (ep * r) ** 2 + 17325 * (ep * r) ** 3 + 3150 * (
+                            ep * r) ** 4 + 378 * (ep * r) ** 5 + 28 * (ep * r) ** 6 + 1 * (ep * r) ** 7)
+        else:
+            raise Exception('This Matern kernel is not implemented')
+
+
+class HatFunction(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.rbf = lambda ep, r: torch.clamp(1 - ep*r, min=0)
+        self.name = 'Hat function'
+
+
+class PowerKernels(RBF):
+    def __init__(self, ep=1):
+        self.ep = ep
+        self.rbf = lambda ep, r: torch.abs(r)
+        self.name = 'Power kernel'
+
+
+class Wendland(RBF):
+    def __init__(self, ep=1, k=0, d=1):
+        self.ep = ep
+        self.name = 'wen_' + str(d) + '_' + str(k)
+        l = np.floor(d / 2) + k + 1
+        if k == 0:
+            p = lambda r: 1
+        elif k == 1:
+            p = lambda r: (l + 1) * r + 1
+        elif k == 2:
+            p = lambda r: (l + 3) * (l + 1) * r ** 2 + 3 * (l + 2) * r + 3
+        elif k == 3:
+            p = lambda r: (l + 5) * (l + 3) * (l + 1) * r ** 3 + (45 + 6 * l * (l + 6)) * r ** 2 + (
+                        15 * (l + 3)) * r + 15
+        elif k == 4:
+            p = lambda r: (l + 7) * (l + 5) * (l + 3) * (l + 1) * r ** 4 + (
+                        5 * (l + 4) * (21 + 2 * l * (8 + l))) * r ** 3 + (45 * (14 + l * (l + 8))) * r ** 2 + (
+                                      105 * (l + 4)) * r + 105
+        else:
+            raise Exception('This Wendland kernel is not implemented')
+        c = math.factorial(int(l + 2 * k)) / math.factorial(int(l))
+        e = l + k
+        self.rbf = lambda ep, r: torch.clamp(1 - ep * r, min=0) ** e * p(ep * r) / c
+
+
+
+# A demo usage
+def main():
+    ker = Gaussian()
+
+    x = torch.linspace(-5, 5, 10001)[:, None]
+    y = torch.zeros(1, 1)
+    A = ker.eval(x, y)
+    B = torch.exp(-x ** 2)
+
+    fig = plt.figure(1)
+    fig.clf()
+    ax = fig.gca()
+    ax.plot(x, A)
+    ax.plot(x, B)
+    ax.set_title('A kernel: ' + str(ker))
+    plt.show()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/utilities.py b/utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..715f976fb2b56774d1801704ecbbca7b2665425d
--- /dev/null
+++ b/utilities.py
@@ -0,0 +1,145 @@
+from torch import nn
+import torch
+import numpy as np
+
+
+
+def compute_cv_loss_via_rippa_ext_2(kernel_matrix, y, n_folds, reg_for_matrix_inversion):
+    """
+    Implementation without the need to provide a kernel and points: Simply provide the kernel matrix
+    """
+
+    # Some precomputations
+    kernel_matrix_reg = kernel_matrix + reg_for_matrix_inversion * torch.eye(kernel_matrix.shape[0])
+    inv_kernel_matrix = torch.inverse(kernel_matrix_reg)
+    coeffs = torch.linalg.solve(kernel_matrix_reg, y) #[0]
+
+    # Some initializations and preparations: n_folds needs to divide y.shape[0] without remainder
+    assert y.shape[0] % n_folds == 0, 'n_folds must divide the number of samples without remainder'
+    array_error = torch.zeros(y.shape[0], 1)
+    n_per_fold = int(y.shape[0] / n_folds)
+    indices = torch.arange(0, y.shape[0]).view(n_per_fold, n_folds)
+
+    # Standard Rippa's scheme
+    if n_folds == y.shape[0]:
+        array_error = coeffs / torch.diag(inv_kernel_matrix).view(-1,1)
+
+    # Extended Rippa's scheme
+    else:
+        for j in range(n_folds):
+            inv_kernel_matrix_loc1 = inv_kernel_matrix[indices[:, j], :]
+            inv_kernel_matrix_loc = inv_kernel_matrix_loc1[:, indices[:, j]]
+
+            array_error[j * n_per_fold: (j+1) * n_per_fold, 0] = \
+                (torch.linalg.solve(inv_kernel_matrix_loc, coeffs[indices[:, j]])).view(-1)
+
+    cv_error_sq = torch.sum(array_error ** 2) / array_error.numel()
+
+    return cv_error_sq, array_error
+
+
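+# A small self-check (added sketch, not part of the original module): for
+# n_folds == y.shape[0], Rippa's formula must reproduce the explicit
+# leave-one-out residuals obtained by refitting on each reduced data set.
+def _demo_rippa_loo_check(N=8, reg=1e-8):
+    torch.manual_seed(0)
+    X = torch.rand(N, 1, dtype=torch.float64)
+    y = torch.sin(4 * X)
+    K = torch.exp(-torch.cdist(X, X))             # Matern (k=0) kernel matrix
+    _, err = compute_cv_loss_via_rippa_ext_2(K, y, N, reg)
+    for i in range(N):
+        idx = [j for j in range(N) if j != i]
+        K_loc = K[idx, :][:, idx] + reg * torch.eye(N - 1, dtype=torch.float64)
+        coeffs_loc = torch.linalg.solve(K_loc, y[idx])
+        pred_i = K[i, idx] @ coeffs_loc           # prediction at the left-out point
+        assert torch.isclose(y[i] - pred_i, err[i], atol=1e-6)
+
+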
+class OptimizedKernel(torch.nn.Module):
+
+    def __init__(self, kernel, dim,
+                 reg_para=1e-5, learning_rate=1e-3, n_epochs=100, batch_size=32, n_folds=None,
+                 flag_initialize_diagonal=False, flag_symmetric_A=False):
+        super().__init__()
+
+        # Some settings, mostly optimization related
+        self.kernel = kernel
+
+        self.dim = dim
+        self.reg_para = reg_para
+        self.learning_rate = learning_rate
+        self.n_epochs = n_epochs
+        self.batch_size = batch_size
+
+        self.flag_symmetric_A = flag_symmetric_A
+
+        # Define the linear map B; A is its symmetrization if flag_symmetric_A is set
+        if torch.is_tensor(flag_initialize_diagonal):
+            self.B = nn.Parameter(flag_initialize_diagonal, requires_grad=True)
+        elif flag_initialize_diagonal:
+            self.B = nn.Parameter(torch.eye(self.dim, self.dim), requires_grad=True)
+        else:
+            self.B = nn.Parameter(torch.rand(self.dim, self.dim), requires_grad=True)
+
+        if self.flag_symmetric_A:
+            self.A = (self.B + self.B.t()) / 2
+        else:
+            self.A = self.B
+
+
+        if n_folds is None:
+            self.n_folds = self.batch_size
+        else:
+            self.n_folds = n_folds
+
+
+        # Set optimizer and scheduler
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=.7)
+
+        # Initialize lists for tracking
+        self.list_obj = []
+        self.list_parameters = []
+
+
+    def optimize(self, X, y, flag_optim_verbose=True):
+
+        assert X.shape[0] == y.shape[0], 'Data sizes do not match'
+        n_batches = X.shape[0] // self.batch_size
+
+        # Append initial parameters
+        if self.flag_symmetric_A:
+            self.list_parameters.append(torch.clone((self.B + self.B.t()) / 2).detach().numpy())
+        else:
+            self.list_parameters.append(torch.clone(self.A).detach().numpy())
+
+
+        for idx_epoch in range(self.n_epochs):
+            shuffle = np.random.permutation(X.shape[0])  # reshuffle the data set every epoch
+
+            list_obj_loc = []
+
+            for idx_batch in range(n_batches):
+
+                # Select minibatch from the data
+                ind = shuffle[idx_batch * self.batch_size : (idx_batch + 1) * self.batch_size]
+                Xb, yb = X[ind, :], y[ind, :]
+
+                # Compute kernel matrix for minibatch
+                kernel_matrix = self.kernel.eval(Xb @ self.A, Xb @ self.A)
+
+                # use cross validation loss via rippa to assess the error
+                optimization_objective, _ = compute_cv_loss_via_rippa_ext_2(
+                    kernel_matrix, yb, self.n_folds, self.reg_para)
+
+                # Keep track of optimization quantity within epoch
+                list_obj_loc.append(optimization_objective.detach().item())
+                if idx_epoch == 0 and flag_optim_verbose:
+                    print('First epoch: Iteration {:5d}: Training objective: {:.3e}'.format(
+                        idx_batch, optimization_objective.detach().item()))
+
+                # Optimization step
+                optimization_objective.backward()
+                self.optimizer.step()  # do one optimization step
+                self.optimizer.zero_grad()  # set gradients to zero
+
+                if self.flag_symmetric_A:
+                    self.A = (self.B + self.B.t()) / 2
+                else:
+                    self.A = self.B
+
+            # Keep track of some quantities and print something
+            mean_obj = np.mean(list_obj_loc)
+
+            if flag_optim_verbose:
+                print('Epoch {:5d} finished, mean training objective: {:.3e}.'.format(
+                    idx_epoch + 1, mean_obj))
+
+            self.list_obj.append(mean_obj)
+
+            self.list_parameters.append(torch.clone(self.A).detach().numpy())
+
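+
+# A minimal usage sketch (added; it mirrors how VKOGA_2L.fit in vkoga.py invokes
+# OptimizedKernel; the kernel choice and hyperparameters are illustrative):
+if __name__ == '__main__':
+    import tkernels
+
+    X = torch.rand(256, 3)
+    y = X[:, [0]]                     # target depends only on the first coordinate
+
+    model = OptimizedKernel(kernel=tkernels.Matern(k=1), dim=3,
+                            reg_para=1e-3, learning_rate=5e-3,
+                            n_epochs=5, batch_size=32,
+                            flag_initialize_diagonal=True)
+    model.optimize(X, y, flag_optim_verbose=False)
+    print('Optimized linear map A:\n', model.A.detach().numpy())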
diff --git a/vkoga.py b/vkoga.py
new file mode 100644
index 0000000000000000000000000000000000000000..6625bb22334b2b71e0fe6b56e2bb4cdd4c4e5461
--- /dev/null
+++ b/vkoga.py
@@ -0,0 +1,394 @@
+#!/usr/bin/env python3
+
+from kernels import Matern
+import numpy as np
+from sklearn.base import BaseEstimator
+from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
+from scipy.spatial import distance_matrix
+import torch
+from utilities import OptimizedKernel
+    
+# VKOGA implementation
+class VKOGA_2L(BaseEstimator):
+                                          
+    def __init__(self, kernel=Matern(k=2), flag_2L_optimization=False,
+                 verbose=True, n_report=10,
+                 greedy_type='f_greedy', reg_par=0, restr_par=0,
+                 tol_f=1e-10, tol_p=1e-10,
+                 reg_para_optim=1e-3, learning_rate=5e-3, n_epochs_optim=10, batch_size=32):
+        
+        # Set the verbosity on/off
+        self.verbose = verbose
+        
+        # Set the frequency of report
+        self.n_report = n_report
+        
+        # Set the params defining the method 
+        self.flag_2L_optimization = flag_2L_optimization
+        if self.flag_2L_optimization:
+            assert isinstance(kernel, list), 'If flag_2L_optimization, then a list of two kernels [kernel, kernel_t] needs to be provided!'
+            self.kernel_t = kernel[1]
+            self.kernel = kernel[0]
+        else:
+            self.kernel_t = None
+            self.kernel = kernel
+
+        self.greedy_type = greedy_type
+        self.reg_par = reg_par
+        self.restr_par = restr_par
+
+        self.flag_val = None
+
+        self.reg_para_optim = reg_para_optim
+        self.learning_rate = learning_rate
+        self.n_epochs_optim = n_epochs_optim
+        self.batch_size = batch_size
+
+        # Set the stopping values
+        self.tol_f = tol_f
+        self.tol_p = tol_p
+
+        # Some further settings
+        self.ctrs_ = None
+        self.Cut_ = None
+        self.c = None
+
+
+        # Initialize the convergence history
+        self.train_hist = {}
+        self.train_hist['n'] = []
+        self.train_hist['f'] = []
+        self.train_hist['p'] = []
+        self.train_hist['p selected'] = []              # list of selected power vals
+        self.train_hist['f val'] = []
+        self.train_hist['p val'] = []
+        
+    def selection_rule(self, f, p):
+        if self.restr_par > 0:
+            p_ = np.max(p)
+            restr_idx = np.nonzero(p >= self.restr_par * p_)[0]
+        else:
+            restr_idx = np.arange(len(p))
+
+        f = np.sum(f ** 2, axis=1)
+        if self.greedy_type == 'f_greedy':
+            idx = np.argmax(f[restr_idx])
+            idx = restr_idx[idx]
+            f_max = np.max(f)
+            p_max = np.max(p)
+        elif self.greedy_type == 'fp_greedy':
+            idx = np.argmax(f[restr_idx] / p[restr_idx])
+            idx = restr_idx[idx]
+            f_max = np.max(f)
+            p_max = np.max(p)
+        elif self.greedy_type == 'p_greedy':
+            f_max = np.max(f)
+            idx = np.argmax(p)
+            p_max = p[idx]
+        elif self.greedy_type == 'rand':    # pick some random point - David Holzmüller asked me about its performance
+            f_max = np.max(f)
+            p_max = np.max(p)
+            idx = np.random.randint(len(p))
+        else:
+            raise ValueError('Unknown greedy_type: %s' % self.greedy_type)
+        return idx, f_max, p_max
+
+    def fit(self, X, y, X_val=None, y_val=None, maxIter=None):
+
+        # Check the datasets
+        X, y = check_X_y(X, y, multi_output=True)
+        if len(y.shape) == 1:
+            y = y[:, None]
+
+        if X_val is None or y_val is None:
+            X_val = None
+            y_val = None
+            self.flag_val = False
+        else:
+            self.flag_val = True
+            X_val, y_val = check_X_y(X_val, y_val, multi_output=True)
+            # We will assume in the following that X_val and X are disjoint
+
+            if len(y_val.shape) == 1:
+                y_val = y_val[:, None]
+
+
+        # Check whether already fitted - restart in case we used kernel optimization
+        if self.ctrs_ is None or self.flag_2L_optimization:
+            self.ctrs_ = np.zeros((0, X.shape[1]))
+            self.Cut_ = np.zeros((0, 0))
+            self.c = np.zeros((0, y.shape[1]))
+
+
+        # Check whether "new X" and previously chosen centers overlap
+        list_truly_new_X = []
+        if self.ctrs_.shape[0] > 0:
+            for idx_x in range(X.shape[0]):
+                if min(np.linalg.norm(self.ctrs_ - X[idx_x, :], axis=1)) < 1e-10:
+                    continue
+                else:
+                    list_truly_new_X.append(idx_x)
+        else:
+            list_truly_new_X = list(range(X.shape[0]))
+        X = X[list_truly_new_X, :]
+        y = y[list_truly_new_X, :]
+
+
+        # Optimize the kernel
+        if self.flag_2L_optimization:
+            model_OptimKernel = OptimizedKernel(kernel=self.kernel_t, dim=X.shape[1],
+                                                reg_para=self.reg_para_optim, learning_rate=self.learning_rate,
+                                                n_epochs=self.n_epochs_optim, batch_size=self.batch_size,
+                                                n_folds=None, flag_initialize_diagonal=True,
+                                                flag_symmetric_A=False)
+            model_OptimKernel.optimize(torch.from_numpy(X).float(), torch.from_numpy(y).float(),
+                                       flag_optim_verbose=True)
+
+            self.A = model_OptimKernel.A.detach().numpy()
+
+            X = X @ self.A      # we can incorporate the kernel modification directly into the data
+
+        else:
+            self.A = np.eye(X.shape[1])
+
+
+        # Initialize the residual and update the given y values by subtracting the current model
+        y = np.array(y)
+        if len(y.shape) == 1:
+            y = y[:, None]
+        y = y - self.predict(X)
+        if self.flag_val:
+            y_val = y_val - self.predict(X_val)
+
+
+        # Get the data dimension
+        N, q = y.shape
+        if self.flag_val:
+            N_val = y_val.shape[0]
+
+
+        # Set the maximal number of iterations (cannot exceed the number of new points)
+        if maxIter is None:
+            self.maxIter = min(100, N)
+        else:
+            self.maxIter = min(maxIter, N)
+
+
+        # Check compatibility of restriction
+        if self.greedy_type == 'p_greedy':
+            self.restr_par = 0
+        if not self.reg_par == 0:
+            self.restr_par = 0
+
+
+        # Initialize list for the chosen and non-chosen indices
+        indI_ = []
+        notIndI = list(range(N))
+        c = np.zeros((self.maxIter, q))
+
+
+        # Compute the Newton basis values (related to the old centers) on the new X
+        Vx_new_X_old_ctrs = self.kernel.eval(X, self.ctrs_) @ self.Cut_.transpose()
+        if self.flag_val:
+            Vx_val_new_X_old_ctrs = self.kernel.eval(X_val, self.ctrs_) @ self.Cut_.transpose()
+
+
+        # Initialize arrays for the Newton basis values (related to the new centers) on the new X
+        Vx = np.zeros((N, self.maxIter))
+        if self.flag_val:
+            Vx_val = np.zeros((N_val, self.maxIter))
+
+
+        # Compute the powervals on X and X_val
+        p = self.kernel.diagonal(X) + self.reg_par
+        p = p - np.sum(Vx_new_X_old_ctrs ** 2, axis=1)
+        if self.flag_val:
+            p_val = self.kernel.diagonal(X_val) + self.reg_par
+            p_val = p_val - np.sum(Vx_val_new_X_old_ctrs ** 2, axis=1)
+
+
+        # Extend Cut_ matrix, i.e. continue to build on old self.Cut_ matrix
+        N_ctrs_so_far = self.Cut_.shape[0]
+        Cut_ = np.zeros((N_ctrs_so_far + self.maxIter, N_ctrs_so_far + self.maxIter))
+        Cut_[:N_ctrs_so_far, :N_ctrs_so_far] = self.Cut_
+
+
+        # Iterative selection of new points
+        self.print_message('begin')
+        for n in range(self.maxIter):
+            # prepare
+            self.train_hist['n'].append(self.ctrs_.shape[0] + n + 1)
+            self.train_hist['f'].append([])
+            self.train_hist['p'].append([])
+            self.train_hist['p selected'].append([])
+            if self.flag_val:
+                self.train_hist['p val'].append([])
+                self.train_hist['f val'].append([])
+
+            # select the current index
+            idx, self.train_hist['f'][-1], self.train_hist['p'][-1] = self.selection_rule(y[notIndI], p[notIndI])
+            self.train_hist['p selected'][-1] = p[notIndI[idx]]
+            if self.flag_val:
+                self.train_hist['p val'][-1] = np.max(p_val)
+                self.train_hist['f val'][-1] = np.max(np.sum(y_val ** 2, axis=1))
+
+            # add the current index
+            indI_.append(notIndI[idx])
+
+            # check if the tolerances are reached
+            if self.train_hist['f'][n] <= self.tol_f:
+                n = n - 1
+                self.print_message('end')
+                break
+            if self.train_hist['p'][n] <= self.tol_p:
+                n = n - 1
+                self.print_message('end')
+                break
+
+            # compute the nth basis (including normalization)  # ToDo: also the old Vx needs to be subtracted here!
+            Vx[notIndI, n] = self.kernel.eval(X[notIndI, :], X[indI_[n], :])[:, 0]\
+                 - Vx_new_X_old_ctrs[notIndI, :] @ Vx_new_X_old_ctrs[indI_[n], :].transpose()\
+                 - Vx[notIndI, :n+1] @ Vx[indI_[n], 0:n+1].transpose()
+            Vx[indI_[n], n] += self.reg_par
+            Vx[notIndI, n] = Vx[notIndI, n] / np.sqrt(p[indI_[n]])
+
+            if self.flag_val:
+                Vx_val[:, n] = self.kernel.eval(X_val, X[indI_[n], :])[:, 0]\
+                    - Vx_val_new_X_old_ctrs[:, :] @ Vx_new_X_old_ctrs[indI_[n], :].transpose()\
+                    - Vx_val[:, :n+1] @ Vx[indI_[n], 0:n+1].transpose()
+                Vx_val[:, n] = Vx_val[:, n] / np.sqrt(p[indI_[n]])
+
+
+            # update the change of basis
+            Cut_new_row = np.ones(N_ctrs_so_far + n + 1)
+            Cut_new_row[:N_ctrs_so_far + n] = \
+                -np.concatenate((Vx_new_X_old_ctrs[indI_[n], :], Vx[indI_[n], :n])) \
+                @ Cut_[:N_ctrs_so_far + n, :N_ctrs_so_far + n]
+            Cut_[N_ctrs_so_far + n, :N_ctrs_so_far + n + 1] = Cut_new_row / Vx[indI_[n], n]
+
+            # compute the nth coefficient
+            c[n] = y[indI_[n]] / np.sqrt(p[indI_[n]])
+
+            # update the power function
+            p[notIndI] = p[notIndI] - Vx[notIndI, n] ** 2
+            if self.flag_val:
+                p_val = p_val - Vx_val[:, n] ** 2
+
+            # update the residual
+            y[notIndI] = y[notIndI] - Vx[notIndI, n][:, None] * c[n]
+            if self.flag_val:
+                y_val = y_val - Vx_val[:, n][:, None] * c[n]
+
+            # remove the nth index from the dictionary
+            notIndI.pop(idx)
+
+            # Report some data every now and then
+            if n % self.n_report == 0:
+                self.print_message('track')
+
+        else:
+            self.print_message('end')
+
+        # Define coefficients and centers
+        self.c = np.concatenate((self.c, c[:n + 1]))
+        self.Cut_ = Cut_[:N_ctrs_so_far + n + 1, :N_ctrs_so_far + n + 1]
+        self.indI_ = indI_[:n + 1]     # Mind: These are only the indices of the latest points
+        self.coef_ = self.Cut_.transpose() @ self.c
+        self.ctrs_ = np.concatenate((self.ctrs_, X[self.indI_, :]), axis=0)
+
+
+        return self
+
+
+    def predict(self, X):
+        # Check if fit has been called
+        # check_is_fitted(self, 'coef_')     # ToDo: Remove this one!
+
+        # Validate the input
+        X = check_array(X)
+
+        if self.flag_2L_optimization:
+            X = X @ self.A
+
+        # Compute prediction   # ToDo: replace with eval_prod for large inputs
+        if self.ctrs_.shape[0] > 0:
+            prediction = self.kernel.eval(X, self.ctrs_) @ self.coef_
+        else:
+            prediction = np.zeros((X.shape[0], 1))
+
+        return prediction
+
+    def predict_P(self, X, n=None):
+        # Prediction of the power function, copied from pgreedy.py
+
+        # If nothing has been fitted so far (or n == 0), the power function is simply the kernel diagonal
+        if self.ctrs_ is None or n == 0:
+            return self.kernel.diagonal(X)
+
+        # Otherwise check if everything is ok
+        # Check if fit has been called
+        check_is_fitted(self, 'ctrs_')
+        # Validate the input
+        X = check_array(X)
+
+        # Decide how many centers to use
+        if n is None:
+            n = np.atleast_2d(self.ctrs_).shape[0]
+
+        # Evaluate the power function on the input
+        if self.flag_2L_optimization:
+            X = X @ self.A
+
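+        # Squared power function: P_n(x)^2 = K(x, x) - sum_j V_j(x)^2, where the V_j
+        # are the Newton basis functions encoded by the rows of Cut_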
+        p = self.kernel.diagonal(X) - np.sum(
+            (self.kernel.eval(X, np.atleast_2d(self.ctrs_)[:n]) @ self.Cut_[:n, :n].transpose()) ** 2, axis=1)
+
+        return p
+
+    def print_message(self, when):
+        if self.verbose and when == 'begin':
+            print('')
+            print('*' * 30 + ' [VKOGA] ' + '*' * 30)
+            print('Training model with')
+            print('       |_ kernel              : %s' % self.kernel)
+            print('       |_ regularization par. : %2.2e' % self.reg_par)
+            print('       |_ restriction par.    : %2.2e' % self.restr_par)
+            print('')
+            
+        if self.verbose and when == 'end':
+            print('Training completed with')
+            print('       |_ selected points     : %8d / %8d' % (self.train_hist['n'][-1], self.ctrs_.shape[0] + self.maxIter))
+            if self.flag_val:
+                print('       |_ train, val residual : %2.2e / %2.2e,    %2.2e' %
+                      (self.train_hist['f'][-1], self.tol_f, self.train_hist['f val'][-1]))
+                print('       |_ train, val power fun: %2.2e / %2.2e,    %2.2e' %
+                      (self.train_hist['p'][-1], self.tol_p, self.train_hist['p val'][-1]))
+            else:
+                print('       |_ train residual      : %2.2e / %2.2e' % (self.train_hist['f'][-1], self.tol_f))
+                print('       |_ train power fun     : %2.2e / %2.2e' % (self.train_hist['p'][-1], self.tol_p))
+                        
+        if self.verbose and when == 'track':
+            print('Training ongoing with')
+            print('       |_ selected points     : %8d / %8d' % (self.train_hist['n'][-1], self.ctrs_.shape[0] + self.maxIter))
+            if self.flag_val:
+                print('       |_ train, val residual : %2.2e / %2.2e,    %2.2e' %
+                      (self.train_hist['f'][-1], self.tol_f, self.train_hist['f val'][-1]))
+                print('       |_ train, val power fun: %2.2e / %2.2e,    %2.2e' %
+                      (self.train_hist['p'][-1], self.tol_p, self.train_hist['p val'][-1]))
+            else:
+                print('       |_ train residual      : %2.2e / %2.2e' % (self.train_hist['f'][-1], self.tol_f))
+                print('       |_ train power fun     : %2.2e / %2.2e' % (self.train_hist['p'][-1], self.tol_p))
+
+
+
+#%% Utilities to save and load models via pickle
+import pickle
+def save_object(obj, filename):
+    with open(filename, 'wb') as output:
+        pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL)
+
+def load_object(filename):
+    with open(filename, 'rb') as file:
+        obj = pickle.load(file)
+    return obj
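+
+# Usage sketch (the filename is illustrative):
+#   save_object(model_2L, 'vkoga_2l_model.pkl')
+#   model_2L = load_object('vkoga_2l_model.pkl')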
+
+
+