From 79dc4c87288369eee34ed83bf96ea10dabeec53a Mon Sep 17 00:00:00 2001 From: Tizian Wenzel <wenzeltn@nbanm02.mathematik.uni-stuttgart.de> Date: Thu, 8 Jun 2023 12:56:41 +0200 Subject: [PATCH] Modified some more stuff. --- section_4.1_compute_visualize.py | 3 +++ section_4.2_compute.py | 4 ++++ section_4.2_visualize.py | 4 ++++ section_4.3_compute.py | 3 +++ section_4.3_visualize.py | 3 +++ utils/main_function.py | 7 ++++--- 6 files changed, 21 insertions(+), 3 deletions(-) diff --git a/section_4.1_compute_visualize.py b/section_4.1_compute_visualize.py index 08af491..3f7e0e7 100644 --- a/section_4.1_compute_visualize.py +++ b/section_4.1_compute_visualize.py @@ -14,6 +14,9 @@ from scipy import io import os +np.random.seed(1) + + ## Some settings # name_dataset = 'example_5d_faster_conv' name_dataset = 'example_6d_kink' diff --git a/section_4.2_compute.py b/section_4.2_compute.py index c60aa4a..d40c028 100644 --- a/section_4.2_compute.py +++ b/section_4.2_compute.py @@ -9,6 +9,10 @@ from utils.hyperparameters import dic_hyperparams from scipy import io import os +import numpy as np + + +np.random.seed(1) list_datasets = ['fried', 'sarcos', 'protein', 'ct', 'diamonds', diff --git a/section_4.2_visualize.py b/section_4.2_visualize.py index 1b63fa0..d78652c 100644 --- a/section_4.2_visualize.py +++ b/section_4.2_visualize.py @@ -9,6 +9,10 @@ from matplotlib import pyplot as plt from scipy import io import os import scipy +import numpy as np + + +np.random.seed(1) ## Some settings diff --git a/section_4.3_compute.py b/section_4.3_compute.py index eeed379..0618091 100644 --- a/section_4.3_compute.py +++ b/section_4.3_compute.py @@ -14,6 +14,9 @@ from utils.hyperparameters import dic_hyperparams from utils.main_function import run_everything +np.random.seed(1) + + ## Some settings list_nctrs = [int(np.round(nr)) for nr in np.logspace(np.log(10) / np.log(10), np.log(1000) / np.log(10), 10)] diff --git a/section_4.3_visualize.py b/section_4.3_visualize.py index 8eb70a6..3743e45 100644 --- a/section_4.3_visualize.py +++ b/section_4.3_visualize.py @@ -13,6 +13,9 @@ import pickle import scipy +np.random.seed(1) + + ## Some settings list_nctrs = [int(np.round(nr)) for nr in np.logspace(np.log(10) / np.log(10), np.log(1000) / np.log(10), 10)] diff --git a/utils/main_function.py b/utils/main_function.py index 79f5010..c1379c9 100644 --- a/utils/main_function.py +++ b/utils/main_function.py @@ -38,10 +38,12 @@ def run_everything(name_dataset, maxIter_vkoga, N_points, noise_level, reg_para_ # Preprocessing if 'example' in name_dataset: + # No need to shuffle as dataset is randomly generated idx = np.arange(X.shape[0]) else: - assert idx_rerun is not None, 'idx_rerun is not set!' - idx = np.load(path_for_indices + '_indices_{}/indices_'.format(idx_rerun) + name_dataset + '.npy') + # Random but fixed indices were removed, instead use randomly shuffled every time + idx = np.arange(X.shape[0]) + np.random.shuffle(idx) n_train = int(.8 * X.shape[0]) @@ -50,7 +52,6 @@ def run_everything(name_dataset, maxIter_vkoga, N_points, noise_level, reg_para_ X_test = X[idx[n_train:]] y_test = y[idx[n_train:]] X_train_torch, y_train_torch = torch.from_numpy(X_train).type(torch.float), torch.from_numpy(y_train).type(torch.float) - # ToDo: noise level removed! (was set to 0 so far!) ## Select kernel -- GitLab