From 79dc4c87288369eee34ed83bf96ea10dabeec53a Mon Sep 17 00:00:00 2001
From: Tizian Wenzel <wenzeltn@nbanm02.mathematik.uni-stuttgart.de>
Date: Thu, 8 Jun 2023 12:56:41 +0200
Subject: [PATCH] Modified some more stuff.

---
 section_4.1_compute_visualize.py | 3 +++
 section_4.2_compute.py           | 4 ++++
 section_4.2_visualize.py         | 4 ++++
 section_4.3_compute.py           | 3 +++
 section_4.3_visualize.py         | 3 +++
 utils/main_function.py           | 7 ++++---
 6 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/section_4.1_compute_visualize.py b/section_4.1_compute_visualize.py
index 08af491..3f7e0e7 100644
--- a/section_4.1_compute_visualize.py
+++ b/section_4.1_compute_visualize.py
@@ -14,6 +14,9 @@ from scipy import io
 import os
 
 
+np.random.seed(1)
+
+
 ## Some settings
 # name_dataset = 'example_5d_faster_conv'
 name_dataset = 'example_6d_kink'
diff --git a/section_4.2_compute.py b/section_4.2_compute.py
index c60aa4a..d40c028 100644
--- a/section_4.2_compute.py
+++ b/section_4.2_compute.py
@@ -9,6 +9,10 @@ from utils.hyperparameters import dic_hyperparams
 
 from scipy import io
 import os
+import numpy as np
+
+
+np.random.seed(1)
 
 
 list_datasets = ['fried', 'sarcos', 'protein', 'ct', 'diamonds',
diff --git a/section_4.2_visualize.py b/section_4.2_visualize.py
index 1b63fa0..d78652c 100644
--- a/section_4.2_visualize.py
+++ b/section_4.2_visualize.py
@@ -9,6 +9,10 @@ from matplotlib import pyplot as plt
 from scipy import io
 import os
 import scipy
+import numpy as np
+
+
+np.random.seed(1)
 
 
 ## Some settings
diff --git a/section_4.3_compute.py b/section_4.3_compute.py
index eeed379..0618091 100644
--- a/section_4.3_compute.py
+++ b/section_4.3_compute.py
@@ -14,6 +14,9 @@ from utils.hyperparameters import dic_hyperparams
 from utils.main_function import run_everything
 
 
+np.random.seed(1)
+
+
 ## Some settings
 list_nctrs = [int(np.round(nr)) for nr in np.logspace(np.log(10) / np.log(10), np.log(1000) / np.log(10), 10)]
 
diff --git a/section_4.3_visualize.py b/section_4.3_visualize.py
index 8eb70a6..3743e45 100644
--- a/section_4.3_visualize.py
+++ b/section_4.3_visualize.py
@@ -13,6 +13,9 @@ import pickle
 import scipy
 
 
+np.random.seed(1)
+
+
 ## Some settings
 list_nctrs = [int(np.round(nr)) for nr in np.logspace(np.log(10) / np.log(10), np.log(1000) / np.log(10), 10)]
 
diff --git a/utils/main_function.py b/utils/main_function.py
index 79f5010..c1379c9 100644
--- a/utils/main_function.py
+++ b/utils/main_function.py
@@ -38,10 +38,12 @@ def run_everything(name_dataset, maxIter_vkoga, N_points, noise_level, reg_para_
 
     # Preprocessing
     if 'example' in name_dataset:
+        # No need to shuffle as dataset is randomly generated
         idx = np.arange(X.shape[0])
     else:
-        assert idx_rerun is not None, 'idx_rerun is not set!'
-        idx = np.load(path_for_indices + '_indices_{}/indices_'.format(idx_rerun) + name_dataset + '.npy')
+        # Random but fixed indices were removed, instead use randomly shuffled every time
+        idx = np.arange(X.shape[0])
+        np.random.shuffle(idx)
 
     n_train = int(.8 * X.shape[0])
 
@@ -50,7 +52,6 @@ def run_everything(name_dataset, maxIter_vkoga, N_points, noise_level, reg_para_
     X_test = X[idx[n_train:]]
     y_test = y[idx[n_train:]]
     X_train_torch, y_train_torch = torch.from_numpy(X_train).type(torch.float), torch.from_numpy(y_train).type(torch.float)
-    # ToDo: noise level removed! (was set to 0 so far!)
 
 
     ## Select kernel
-- 
GitLab