Source code for dalib.adaptation.dan

from typing import Optional, Sequence
import torch
import torch.nn as nn
from dalib.modules.classifier import Classifier as ClassifierBase
from dalib.modules.kernels import optimal_kernel_combinations


__all__ = ['MultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']


class MultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""The Multiple Kernel Maximum Mean Discrepancy (MK-MMD) used in
    `Learning Transferable Features with Deep Adaptation Networks <https://arxiv.org/pdf/1502.02791>`_

    Given a source domain :math:`\mathcal{D}_s` of :math:`n_s` labeled points and a target domain
    :math:`\mathcal{D}_t` of :math:`n_t` unlabeled points drawn i.i.d. from P and Q respectively,
    the deep networks will generate activations :math:`\{z_i^s\}_{i=1}^{n_s}` and
    :math:`\{z_i^t\}_{i=1}^{n_t}`. The MK-MMD :math:`D_k(P, Q)` between probability
    distributions P and Q is defined as

    .. math::
        D_k(P, Q) \triangleq \| E_p [\phi(z^s)] - E_q [\phi(z^t)] \|^2_{\mathcal{H}_k},

    where :math:`k` is a kernel function in the function space

    .. math::
        \mathcal{K} \triangleq \{ k=\sum_{u=1}^{m}\beta_{u} k_{u} \}

    and :math:`k_{u}` is a single kernel.

    Using the kernel trick, MK-MMD can be computed as

    .. math::
        \hat{D}_k(P, Q) &= \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} k(z_i^{s}, z_j^{s}) \\
        &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} k(z_i^{t}, z_j^{t}) \\
        &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} k(z_i^{s}, z_j^{t}).

    Parameters:
        - **kernels** (tuple(`nn.Module`)): kernel functions.
        - **linear** (bool): whether to use the linear version of DAN. Default: False
        - **quadratic_program** (bool): whether to use a quadratic program to solve :math:`\beta`. Default: False

    Inputs: z_s, z_t
        - **z_s** (tensor): activations from the source domain, :math:`z^s`
        - **z_t** (tensor): activations from the target domain, :math:`z^t`

    Shape:
        - Inputs: :math:`(minibatch, *)` where * means any dimension
        - Outputs: scalar

    .. note::
        Activations :math:`z^{s}` and :math:`z^{t}` must have the same shape.

    .. note::
        The kernel values will add up when there are multiple kernels.

    Examples::

        >>> from dalib.modules.kernels import GaussianKernel
        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> kernels = (GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.), GaussianKernel(alpha=2.))
        >>> loss = MultipleKernelMaximumMeanDiscrepancy(kernels)
        >>> # features from source domain and target domain
        >>> z_s, z_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss(z_s, z_t)
    """

    def __init__(self, kernels: Sequence[nn.Module], linear: Optional[bool] = False,
                 quadratic_program: Optional[bool] = False):
        super(MultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        self.index_matrix = None
        self.linear = linear
        self.quadratic_program = quadratic_program

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        features = torch.cat([z_s, z_t], dim=0)
        batch_size = int(z_s.size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s.device)

        if not self.quadratic_program:
            kernel_matrix = sum([kernel(features) for kernel in self.kernels])  # add up the matrix of each kernel
            # Add 2 / (n - 1) to make up for the value on the diagonal
            # to ensure the loss is positive in the non-linear version.
            loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
        else:
            kernel_values = [(kernel(features) * self.index_matrix).sum() + 2. / float(batch_size - 1)
                             for kernel in self.kernels]
            loss = optimal_kernel_combinations(kernel_values)
        return loss
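
# Illustrative sketch, not part of the library: how the MK-MMD loss is
# typically combined with the source classification loss in a DAN training
# step. The names `classifier`, `mkmmd_loss`, `optimizer`, `x_s`, `labels_s`,
# `x_t` and `trade_off` are assumptions for this example; it relies on the
# classifier returning (predictions, bottleneck features) in training mode,
# as `ImageClassifier` below does via `ClassifierBase`.
def _dan_step_sketch(classifier, mkmmd_loss, optimizer, x_s, labels_s, x_t,
                     trade_off=1.):
    import torch.nn.functional as F
    y_s, f_s = classifier(x_s)  # source predictions and bottleneck features
    _, f_t = classifier(x_t)    # target features (target labels are unavailable)
    cls_loss = F.cross_entropy(y_s, labels_s)    # supervised loss on source data
    transfer_loss = mkmmd_loss(f_s, f_t)         # MK-MMD between the two domains
    loss = cls_loss + trade_off * transfer_loss  # joint objective of DAN
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
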
def _update_index_matrix(batch_size: int, index_matrix: Optional[torch.Tensor] = None,
                         linear: Optional[bool] = True) -> torch.Tensor:
    r"""
    Update the `index_matrix`, which converts the `kernel_matrix` to the loss.
    If `index_matrix` is already a tensor with shape (2 x batch_size, 2 x batch_size), return it as-is.
    Otherwise, return a new tensor with shape (2 x batch_size, 2 x batch_size).
    """
    if index_matrix is None or index_matrix.size(0) != batch_size * 2:
        index_matrix = torch.zeros(2 * batch_size, 2 * batch_size)
        if linear:
            # Linear-time estimator: each sample is paired only with its neighbour.
            for i in range(batch_size):
                s1, s2 = i, (i + 1) % batch_size
                t1, t2 = s1 + batch_size, s2 + batch_size
                index_matrix[s1, s2] = 1. / float(batch_size)
                index_matrix[t1, t2] = 1. / float(batch_size)
                index_matrix[s1, t2] = -1. / float(batch_size)
                index_matrix[s2, t1] = -1. / float(batch_size)
        else:
            # Quadratic-time estimator: within-domain pairs get weight 1 / (n(n-1)),
            # cross-domain pairs get weight -1 / n^2.
            for i in range(batch_size):
                for j in range(batch_size):
                    if i != j:
                        index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1))
                        index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1))
            for i in range(batch_size):
                for j in range(batch_size):
                    index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size)
                    index_matrix[i + batch_size][j] = -1. / float(batch_size * batch_size)
    return index_matrix


class ImageClassifier(ClassifierBase):
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, None)
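
# Illustrative sanity check, not part of the library: the non-linear index
# matrix built by `_update_index_matrix` above carries the weights of the
# unbiased MMD estimator, 1 / (n(n-1)) for within-domain pairs and -1 / n^2
# for cross-domain pairs, so the entries cancel out overall and a constant
# kernel yields zero discrepancy. A minimal check, assuming batch_size=2:
def _index_matrix_sanity_check():
    m = _update_index_matrix(2, None, linear=False)
    assert m[0, 1].item() == 1. / 2.   # within-source pair: 1 / (2 * (2 - 1))
    assert m[0, 2].item() == -1. / 4.  # cross-domain pair: -1 / (2 * 2)
    assert abs(m.sum().item()) < 1e-6  # weights sum to zero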