Source code for dalib.modules.kernels

from typing import Optional, List
import torch
import torch.nn as nn
from numpy import array, dot
import numpy as np
from qpsolvers import solve_qp


# Names exported by ``from <module> import *``; only GaussianKernel is listed,
# so other definitions in this file must be imported by name.
__all__ = ['GaussianKernel']


class GaussianKernel(nn.Module):
    r"""Gaussian Kernel Matrix

    Gaussian Kernel k is defined by

    .. math::
        k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right)

    where :math:`x_1, x_2 \in R^d` are 1-d tensors.

    Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),`

    .. math::
        K(X)_{i,j} = k(x_i, x_j)

    By default, every call to :meth:`forward` re-estimates the bandwidth from
    the current input as

    .. math::
        \sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2

    The estimate is detached from the autograd graph and overwrites the
    previously stored value (it is not averaged across calls, and the module's
    train/eval mode is not consulted). If :attr:`track_running_stats` is set to
    ``False``, the fixed :math:`\sigma` given at construction is used instead.

    Parameters:
        - sigma (float, optional): bandwidth :math:`\sigma`. Required when
          :attr:`track_running_stats` is ``False``. Default: None
        - track_running_stats (bool, optional): If ``True``, :math:`\sigma^2`
          is re-estimated from the input on every forward call. Otherwise a
          fixed :math:`\sigma^2` is always used. Default: ``True``
        - alpha (float, optional): :math:`\alpha` which decides the magnitude
          of :math:`\sigma^2` when track_running_stats is set to ``True``.
          Default: 1.0

    Inputs:
        - X (tensor): input group :math:`X`

    Shape:
        - Inputs: :math:`(minibatch, F)` where F means the dimension of input features.
        - Outputs: :math:`(minibatch, minibatch)`
    """

    def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,
                 alpha: Optional[float] = 1.):
        super(GaussianKernel, self).__init__()
        # Without tracking there is no way to obtain a bandwidth unless one is given.
        assert track_running_stats or sigma is not None, \
            "sigma must be provided when track_running_stats is False"
        # Stored as sigma^2 because that is the quantity the kernel formula uses.
        self.sigma_square: Optional[torch.Tensor] = \
            torch.tensor(sigma * sigma) if sigma is not None else None
        self.track_running_stats = track_running_stats
        self.alpha = alpha

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return the matrix of pairwise Gaussian kernel values for the rows of ``X``."""
        # Broadcasting (1, m, F) - (m, 1, F) yields all pairwise differences;
        # summing the squares over the feature axis gives an (m, m) matrix.
        l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)

        if self.track_running_stats:
            # Detached so the bandwidth estimate is treated as a constant by autograd.
            sigma_square = self.alpha * torch.mean(l2_distance_square.detach())
            # A single sample or all-identical rows give a mean distance of 0,
            # which would turn the division below into 0/0 = NaN. Clamping to
            # the smallest positive normal float keeps k(x, x) = 1 in that
            # case and leaves every non-degenerate estimate untouched.
            self.sigma_square = torch.clamp(
                sigma_square, min=torch.finfo(l2_distance_square.dtype).tiny)

        return torch.exp(-l2_distance_square / (2 * self.sigma_square))
def optimal_kernel_combinations(kernel_values: List[torch.Tensor]) -> torch.Tensor:
    """Combine kernel values using weights obtained from a quadratic program.

    Each entry of ``kernel_values`` is read as a scalar and the weights
    ``beta`` are found by solving (in qpsolvers' convention: minimize
    ``1/2 x^T P x + q^T x`` subject to ``G x <= h`` and ``A x = b``)::

        P = s * I,  q = 0,  beta >= 0,  kernel_values . beta = s

    where ``s = -1`` if every kernel value is non-positive and ``s = 1``
    otherwise. The weights are then rescaled to sum to the number of kernels.

    Parameters:
        - kernel_values (list of tensors): one single-element tensor per kernel.

    Returns:
        The weighted sum ``sum_i beta_i * kernel_values[i]``. The weights are
        plain NumPy numbers, so gradients flow only through ``kernel_values``.

    Raises:
        ValueError: if ``kernel_values`` is empty.
        RuntimeError: if the QP solver returns no solution.
    """
    num_kernel = len(kernel_values)
    if num_kernel == 0:
        raise ValueError("kernel_values must contain at least one kernel value")

    # The solver works on NumPy floats; ``item()`` extracts the scalar from
    # each tensor (detached, so no autograd history is involved).
    kernel_values_numpy = array([float(k.detach().item()) for k in kernel_values])

    # The two cases differ only in the sign of P and of the equality target.
    sign = -1.0 if np.all(kernel_values_numpy <= 0) else 1.0
    # NOTE(review): with sign = -1 the objective matrix is -I, which is not
    # positive definite; a convex QP solver may reject it. Kept as in the
    # original -- confirm the intended objective for the all-non-positive case.
    beta = solve_qp(
        P=sign * np.eye(num_kernel),
        q=np.zeros(num_kernel),
        G=-np.eye(num_kernel),  # -beta <= 0, i.e. beta >= 0
        h=np.zeros(num_kernel),
        A=kernel_values_numpy,
        b=np.array([sign]),
        # Pinned explicitly: quadprog was the implicit default in older
        # qpsolvers releases, while newer releases require the argument.
        solver="quadprog",
    )
    if beta is None:
        # solve_qp signals "no solution found" by returning None; fail with a
        # clear message instead of a TypeError in the arithmetic below.
        raise RuntimeError(
            "solve_qp found no solution for kernel values {}".format(kernel_values_numpy.tolist()))

    beta = beta / beta.sum(axis=0) * num_kernel  # normalize
    return sum(k * b for (k, b) in zip(kernel_values, beta))