from typing import Optional, List
import torch
import torch.nn as nn
import numpy as np
from qpsolvers import solve_qp
__all__ = ['GaussianKernel']
class GaussianKernel(nn.Module):
r"""Gaussian Kernel Matrix
Gaussian Kernel k is defined by
.. math::
k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right)
    where :math:`x_1, x_2 \in \mathbb{R}^d` are 1-d tensors.
    Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m)`,
.. math::
K(X)_{i,j} = k(x_i, x_j)
    Also by default, during training this layer keeps running estimates of the
    mean of squared L2 distances, which are then used to set the bandwidth :math:`\sigma`.
    Mathematically, the estimation is :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`.
    If :attr:`track_running_stats` is set to ``False``, this layer does not
    keep running estimates and uses a fixed :math:`\sigma` instead.
Parameters:
- sigma (float, optional): bandwidth :math:`\sigma`. Default: None
- track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`.
      Otherwise, it won't track such statistics and always uses a fixed :math:`\sigma^2`. Default: ``True``
    - alpha (float, optional): :math:`\alpha`, which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True``. Default: 1.
Inputs:
- X (tensor): input group :math:`X`
Shape:
    - Inputs: :math:`(minibatch, F)` where :math:`F` is the dimension of the input features.
- Outputs: :math:`(minibatch, minibatch)`
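
    Example (a minimal usage sketch; the batch size and feature dimension are arbitrary)::

        >>> kernel = GaussianKernel(alpha=1.)
        >>> X = torch.randn(8, 16)  # minibatch of 8 feature vectors
        >>> K = kernel(X)           # kernel matrix of shape (8, 8)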
"""
    def __init__(self, sigma: Optional[float] = None, track_running_stats: bool = True,
                 alpha: float = 1.):
super(GaussianKernel, self).__init__()
        assert track_running_stats or sigma is not None, 'sigma must be given when track_running_stats is False'
self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None
self.track_running_stats = track_running_stats
self.alpha = alpha
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # pairwise squared L2 distances between all rows of X, shape (minibatch, minibatch)
        l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)

        if self.track_running_stats:
            # re-estimate the squared bandwidth from the statistics of the current batch
            self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach())

        return torch.exp(-l2_distance_square / (2 * self.sigma_square))
def optimal_kernel_combinations(kernel_values: List[torch.Tensor]) -> torch.Tensor:
    # solve a quadratic program for the optimal combination weights (beta) of the given scalar kernel values
    num_kernel = len(kernel_values)
    kernel_values_numpy = np.array([float(k.detach().cpu().item()) for k in kernel_values])
    if np.all(kernel_values_numpy <= 0):
        # all kernel values are non-positive: solve
        # minimize -0.5 * beta^T beta  s.t.  kernel_values^T beta = -1, beta >= 0
        beta = solve_qp(
            P=-np.eye(num_kernel),
            q=np.zeros(num_kernel),
            A=kernel_values_numpy,
            b=np.array([-1.]),
            G=-np.eye(num_kernel),
            h=np.zeros(num_kernel),
            solver='quadprog',  # recent qpsolvers releases require an explicit solver; quadprog was the historical default
        )
    else:
        # solve minimize 0.5 * beta^T beta  s.t.  kernel_values^T beta = 1, beta >= 0
        beta = solve_qp(
            P=np.eye(num_kernel),
            q=np.zeros(num_kernel),
            A=kernel_values_numpy,
            b=np.array([1.]),
            G=-np.eye(num_kernel),
            h=np.zeros(num_kernel),
            solver='quadprog',  # recent qpsolvers releases require an explicit solver; quadprog was the historical default
        )
    beta = beta / beta.sum(axis=0) * num_kernel  # normalize the weights so they sum to num_kernel
return sum([k * b for (k, b) in zip(kernel_values, beta)])
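
# A minimal usage sketch (hypothetical bandwidths and input sizes): combine three
# fixed-bandwidth Gaussian kernels through the quadratic program above. This assumes
# the `quadprog` backend for qpsolvers is installed.
if __name__ == '__main__':
    kernels = [GaussianKernel(sigma=s, track_running_stats=False) for s in (0.5, 1., 2.)]
    X = torch.randn(8, 16)  # minibatch of 8 feature vectors
    # reduce each kernel matrix to one scalar statistic per kernel
    kernel_values = [kernel(X).mean() for kernel in kernels]
    print(optimal_kernel_combinations(kernel_values))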