# Utilities: dataset and dataloader helpers plus the kernel activation-function layer of the SDKN
import math
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from utils import kernels
from utils import settings
class TorchDataset(Dataset):
    '''
    Implementation of a custom dataset class.
    '''

    def __init__(self, data_input, data_output):
        self.data_input = data_input
        self.data_output = data_output

    def __len__(self):
        return len(self.data_input)

    def __getitem__(self, idx):
        return (self.data_input[idx], self.data_output[idx])
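
# Usage sketch (illustrative only): wrap two equally long tensors and index them together.
# Any pair of tensors sharing the same first dimension would work here.
#   dataset = TorchDataset(data_input=torch.rand(100, 5), data_output=torch.rand(100, 2))
#   x0, y0 = dataset[0]   # (input, output) pair at index 0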


def get_DataLoader(some_unused_imput=0):
    '''
    Implementation of a function that returns dataloaders, which can be used with pytorch-lightning.

    :param some_unused_imput: Placeholder argument; any meaningful input can be implemented here
    :return: train-, validation- and test-loader
    '''
    # Just use some random points and learn the projection onto the first two dimensions squared
    train_dataset = torch.rand(10000, 5)
    train_labels = train_dataset[:, :2] ** 2
    test_dataset = torch.rand(2000, 5)
    test_labels = test_dataset[:, :2] ** 2

    # Create TorchDataset instances
    dataset_train0 = TorchDataset(data_input=train_dataset, data_output=train_labels)
    dataset_test = TorchDataset(data_input=test_dataset, data_output=test_labels)

    # Split off a validation set from the training data
    train_size = int((1 - settings.val_split) * len(dataset_train0))
    val_size = len(dataset_train0) - train_size
    dataset_train, dataset_val = torch.utils.data.random_split(dataset_train0, [train_size, val_size])

    # Define dataloaders
    train_loader = DataLoader(dataset_train, batch_size=settings.batch_size,
                              num_workers=settings.num_workers)
    val_loader = DataLoader(dataset_val, batch_size=settings.batch_size,
                            num_workers=settings.num_workers)
    test_loader = DataLoader(dataset_test, batch_size=settings.batch_size,
                             num_workers=settings.num_workers)

    return train_loader, val_loader, test_loader
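
# Usage sketch (illustrative only): build the three loaders and fetch one batch.
# Assumes utils.settings defines val_split, batch_size and num_workers, as used above.
#   train_loader, val_loader, test_loader = get_DataLoader()
#   inputs, targets = next(iter(train_loader))   # inputs: (batch_size, 5), targets: (batch_size, 2)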


class ActivFunc(torch.nn.Module):
    '''
    Implementation of the single-dimensional kernel layers of the SDKN, which can be viewed
    as optimizable activation function layers.
    '''

    def __init__(self, in_features, nr_centers, kernel=None):
        super().__init__()

        # Layer dimensions
        self.in_features = in_features
        self.nr_centers = nr_centers
        self.nr_centers_id = nr_centers  # number of centers (possibly plus an extra dimension for the identity)

        # Define kernel if not given
        if kernel is None:
            self.kernel = kernels.Wendland_order_0(ep=1)
        else:
            self.kernel = kernel

        # Weight parameters
        self.weight = torch.nn.Parameter(torch.Tensor(self.in_features, self.nr_centers_id))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.weight.data += .2 * torch.ones(self.weight.data.shape)  # provide some positive mean

    def forward(self, x, centers):
        # Stack centers and inputs so both are propagated through the same kernel expansion
        cx = torch.cat((centers, x), 0)
        # Pairwise 1-D distances per feature dimension: (nr_centers + batch, in_features, nr_centers)
        dist_matrix = torch.abs(torch.unsqueeze(cx, 2) - centers.t().view(1, centers.shape[1], self.nr_centers))
        kernel_matrix = self.kernel.rbf(self.kernel.ep, dist_matrix)
        # Weighted sum of kernel values over the centers, separately for every feature dimension
        cx = torch.sum(kernel_matrix * self.weight, dim=2)
        # Split the result back into propagated centers and propagated inputs
        centers = cx[:self.nr_centers, :]
        x = cx[self.nr_centers:, :]
        return x, centers
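

if __name__ == '__main__':
    # Minimal smoke test (sketch): push one random batch through a single ActivFunc layer
    # using the default Wendland kernel. Checks tensor shapes only; no claims about training behaviour.
    layer = ActivFunc(in_features=5, nr_centers=10)
    x = torch.rand(32, 5)        # batch of 32 points in 5 dimensions
    centers = torch.rand(10, 5)  # 10 kernel centers in the same 5-dimensional space
    x_out, centers_out = layer(x, centers)
    print(x_out.shape, centers_out.shape)  # expected: torch.Size([32, 5]) torch.Size([10, 5])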