# Source code for dalib.adaptation.jan

from typing import Optional, Sequence
import torch
import torch.nn as nn
from dalib.modules.classifier import Classifier as ClassifierBase
from dalib.modules.grl import GradientReverseLayer
from dalib.modules.kernels import GaussianKernel
from .dan import _update_index_matrix


# Public API of this module. ``Theta`` is included because callers construct it
# themselves and pass it as the ``thetas`` argument of
# ``JointMultipleKernelMaximumMeanDiscrepancy`` (adversarial JAN).
__all__ = ['JointMultipleKernelMaximumMeanDiscrepancy', 'Theta', 'ImageClassifier']



class JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""The Joint Multiple Kernel Maximum Mean Discrepancy (JMMD) used in
    `Deep Transfer Learning with Joint Adaptation Networks <https://arxiv.org/abs/1605.06636>`_

    Given source domain :math:`\mathcal{D}_s` of :math:`n_s` labeled points and target domain
    :math:`\mathcal{D}_t` of :math:`n_t` unlabeled points drawn i.i.d. from P and Q respectively,
    the deep networks will generate activations in layers :math:`\mathcal{L}` as
    :math:`\{(z_i^{s1}, ..., z_i^{s|\mathcal{L}|})\}_{i=1}^{n_s}` and
    :math:`\{(z_i^{t1}, ..., z_i^{t|\mathcal{L}|})\}_{i=1}^{n_t}`. The empirical estimate of
    :math:`\hat{D}_{\mathcal{L}}(P, Q)` is computed as the squared distance between the empirical
    kernel mean embeddings as

    .. math::
        \hat{D}_{\mathcal{L}}(P, Q) &=
        \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{sl}) \\
        &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{tl}, z_j^{tl}) \\
        &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{tl}). \\

    Parameters:
        - **kernels** (tuple(tuple(`nn.Module`))): kernel functions, where `kernels[r]` corresponds
          to kernel :math:`k^{\mathcal{L}[r]}`.
        - **linear** (bool): whether to use the linear version of JAN. Default: True
        - **thetas** (list(`Theta`)): use the adversarial version of JAN if not None. Default: None

    Inputs: z_s, z_t
        - **z_s** (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`
        - **z_t** (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`

    Shape:
        - :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)` where * means any dimension
        - Outputs: scalar

    .. note::
        Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape.

    .. note::
        The kernel values will add up when there are multiple kernels for a certain layer.

    .. note::
        ``z_s``, ``z_t``, ``kernels`` and ``thetas`` must all contain one entry per layer;
        a mismatch raises ``ValueError`` rather than silently ignoring the extra layers.
        The minibatch size must be at least 2, because the loss adds ``2 / (minibatch - 1)``.

    Examples::

        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(1.), GaussianKernel(2.))
        >>> layer2_kernels = (GaussianKernel(1.), )
        >>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))
        >>> # layer1 features from source domain and target domain
        >>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # layer2 features from source domain and target domain
        >>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss((z1_s, z2_s), (z1_t, z2_t))
    """

    def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: Optional[bool] = True,
                 thetas: Optional[Sequence[nn.Module]] = None):
        super(JointMultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        # Cached index matrix; handed back to ``_update_index_matrix`` on every forward call
        # (presumably so it can be reused while the batch size is unchanged).
        self.index_matrix = None
        self.linear = linear
        if thetas:
            self.thetas = thetas
        else:
            # Non-adversarial version: each layer's features pass through unchanged.
            self.thetas = [nn.Identity() for _ in kernels]

    def forward(self, z_s: Sequence[torch.Tensor], z_t: Sequence[torch.Tensor]) -> torch.Tensor:
        """Return the JMMD between per-layer source activations ``z_s`` and target ``z_t``.

        Raises:
            ValueError: if ``z_s``, ``z_t``, ``kernels`` and ``thetas`` differ in length.
        """
        num_layers = len(self.kernels)
        # ``zip`` below would silently drop trailing layers on a length mismatch,
        # so the joint kernel would be computed over fewer layers than intended.
        if not (len(z_s) == len(z_t) == len(self.thetas) == num_layers):
            raise ValueError(
                f"JMMD expects one entry per layer: got {len(z_s)} source activations, "
                f"{len(z_t)} target activations, {num_layers} kernel groups and "
                f"{len(self.thetas)} thetas")

        batch_size = int(z_s[0].size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)

        kernel_matrix = torch.ones_like(self.index_matrix)
        for layer_z_s, layer_z_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):
            # Source rows first, then target rows.
            layer_features = torch.cat([layer_z_s, layer_z_t], dim=0)
            layer_features = theta(layer_features)
            # Sum the matrices of this layer's kernels, then multiply across layers
            # (the product over l in the JMMD formula).
            kernel_matrix *= sum([kernel(layer_features) for kernel in layer_kernels])

        # Add 2 / (n-1) to make up for the value on the diagonal
        # to ensure loss is positive in the non-linear version.
        # The offset is a constant, so it does not change gradients; it is also added
        # in the linear version, and requires batch_size >= 2.
        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
        return loss
class Theta(nn.Module):
    r"""Learnable linear transform used by the adversarial version of JAN.

    Maximize loss with respect to :math:`\theta`,
    minimize loss with respect to features.

    The transform is a ``dim -> dim`` linear layer initialized to the identity map
    (identity weight, zero bias), so it leaves features unchanged at initialization.
    The layer sits between two gradient reverse layers. Assuming ``GradientReverseLayer``
    negates gradients in the backward pass, the layer's own parameters receive reversed
    gradients (maximization). Gradients w.r.t. the input features are reversed twice,
    so they keep their sign (minimization).

    Parameters:
        - **dim** (int): feature dimension of both the input and the output.
    """

    def __init__(self, dim: int):
        super(Theta, self).__init__()
        self.grl1 = GradientReverseLayer()
        self.grl2 = GradientReverseLayer()
        self.layer1 = nn.Linear(dim, dim)
        # Start as the identity map so the initial loss equals the non-adversarial one.
        nn.init.eye_(self.layer1.weight)
        nn.init.zeros_(self.layer1.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.grl1(features)
        return self.grl2(self.layer1(features))


class ImageClassifier(ClassifierBase):
    """Image classifier used with JAN: a backbone, a bottleneck and a linear head.

    The bottleneck is ``Linear(backbone.out_features, bottleneck_dim)`` followed by
    ``BatchNorm1d``, ``ReLU`` and ``Dropout(0.5)``. The head is
    ``Linear(bottleneck_dim, num_classes)``. Both are handed to ``ClassifierBase``,
    which (presumably) applies backbone -> bottleneck -> head in its forward pass.

    Parameters:
        - **backbone** (nn.Module): feature extractor; must expose ``out_features``.
        - **num_classes** (int): number of classes.
        - **bottleneck_dim** (int): dimension of the bottleneck features. Default: 256
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        head = nn.Linear(bottleneck_dim, num_classes)
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, head)