DALIB Algorithms
The adaptation subpackage contains definitions for the following domain adaptation algorithms: DANN, DAN, JAN, CDAN, MCD and MDD.
Besides specific algorithms, this package also provides a recommended image classifier for each algorithm.
We provide benchmarks of different domain adaptation algorithms on Office-31, Office-Home, VisDA-2017 and DomainNet as follows. Note that Origin means the accuracy reported by the original paper, while Avg is the average accuracy over all tasks obtained with DALIB (listed as DALIB in the VisDA-2017 table).
Office-31 accuracy on ResNet-50
| Methods | Origin | Avg | A → W | D → W | W → D | A → D | D → A | W → A |
| Source Only | 76.1 | 79.5 | 75.8 | 95.5 | 99.0 | 79.3 | 63.6 | 63.8 |
| DANN | 82.2 | 86.4 | 91.7 | 97.9 | 100.0 | 82.9 | 72.8 | 73.3 |
| DAN | 80.4 | 83.7 | 84.2 | 98.4 | 100.0 | 87.3 | 66.9 | 65.2 |
| JAN | 84.3 | 87.3 | 93.7 | 98.4 | 100.0 | 89.4 | 71.2 | 71.0 |
| CDAN | 87.7 | 88.7 | 93.1 | 98.6 | 100.0 | 93.4 | 75.6 | 71.5 |
| MCD | / | 85.9 | 91.8 | 98.6 | 100.0 | 89.0 | 69.0 | 66.9 |
| MDD | 88.9 | 89.2 | 93.6 | 98.6 | 100.0 | 93.6 | 76.7 | 72.9 |
Office-Home accuracy on ResNet-50
| Methods | Origin | Avg | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |
| Source Only | 46.1 | 58.2 | 41.5 | 65.8 | 73.6 | 52.2 | 59.5 | 63.6 | 51.5 | 36.4 | 71.3 | 65.2 | 42.8 | 75.4 |
| DANN | 57.6 | 65.5 | 52.7 | 61.8 | 73.4 | 57.4 | 67.2 | 69.6 | 57.2 | 55.4 | 79.0 | 71.4 | 60.0 | 81.1 |
| DAN | 56.3 | 61.6 | 45.5 | 67.9 | 73.9 | 57.6 | 63.7 | 66.2 | 55.2 | 39.7 | 74.3 | 66.8 | 49.1 | 78.7 |
| JAN | 58.3 | 65.9 | 50.4 | 71.8 | 76.7 | 60.0 | 67.7 | 68.9 | 60.4 | 49.8 | 77.0 | 71.2 | 55.6 | 81.0 |
| CDAN | 65.8 | 68.8 | 54.4 | 70.9 | 77.9 | 61.6 | 71.1 | 71.9 | 62.3 | 54.9 | 80.7 | 75.1 | 60.8 | 83.7 |
| MCD | / | 67.8 | 51.7 | 72.2 | 78.2 | 63.7 | 69.5 | 70.8 | 61.5 | 52.8 | 78.0 | 74.5 | 58.4 | 81.8 |
| MDD | 68.1 | 69.5 | 56.2 | 74.9 | 78.8 | 63.4 | 72.5 | 72.6 | 63.8 | 54.6 | 80.0 | 73.5 | 60.1 | 83.7 |
VisDA-2017 accuracy on ResNet-50 and ResNet-101
| Methods | Origin | DALIB | Origin | DALIB |
| Backbone | ResNet-50 | ResNet-50 | ResNet-101 | ResNet-101 |
| Source Only | / | 55.1 | 52.4 | 58.3 |
| DANN | / | 72.6 | 57.4 | 72.9 |
| DAN | / | 60.6 | 61.1 | 64.8 |
| JAN | 61.6 | 64.9 | / | 68.0 |
| CDAN | 66.8 | 74.6 | / | 74.5 |
| MCD | 69.2 | 69.1 | 71.9 | 77.3 |
| MDD | 74.6 | 74.9 | / | 78.5 |
DomainNet accuracy on ResNet-101
| Source Only | clp | inf | pnt | real | skt | Avg |
| clp | N/A | 18.0 | 32.7 | 50.6 | 39.4 | 35.2 |
| inf | 35.7 | N/A | 31.1 | 50.0 | 26.5 | 35.8 |
| pnt | 41.1 | 17.8 | N/A | 56.8 | 35.0 | 37.7 |
| real | 48.6 | 22.9 | 48.8 | N/A | 36.1 | 39.1 |
| skt | 49.0 | 15.3 | 34.8 | 46.1 | N/A | 36.3 |
| Avg | 43.6 | 18.5 | 36.9 | 50.9 | 34.3 | 36.8 |
| DANN | clp | inf | pnt | real | skt | Avg |
| clp | N/A | 19.7 | 35.4 | 53.9 | 44.2 | 38.3 |
| inf | 26.7 | N/A | 23.8 | 28.8 | 23.7 | 25.8 |
| pnt | 37.2 | 18.7 | N/A | 51.1 | 36.0 | 35.8 |
| real | 50.6 | 22.1 | 47.9 | N/A | 39.0 | 39.9 |
| skt | 54.0 | 19.7 | 42.7 | 52.8 | N/A | 42.3 |
| Avg | 42.1 | 20.1 | 37.5 | 46.7 | 35.7 | 36.4 |
| DAN | clp | inf | pnt | real | skt | Avg |
| clp | N/A | 17.3 | 37.9 | 54.0 | 42.6 | 38.0 |
| inf | 34.9 | N/A | 33.4 | 46.5 | 29.9 | 36.2 |
| pnt | 43.9 | 17.7 | N/A | 55.9 | 39.3 | 39.2 |
| real | 50.1 | 20.0 | 48.6 | N/A | 38.4 | 39.3 |
| skt | 54.2 | 17.5 | 44.2 | 53.4 | N/A | 42.3 |
| Avg | 45.8 | 18.1 | 41.0 | 52.5 | 37.6 | 39.0 |
| CDAN | clp | inf | pnt | real | skt | Avg |
| clp | N/A | 20.8 | 40.0 | 56.1 | 45.5 | 40.6 |
| inf | 31.2 | N/A | 30.0 | 41.4 | 24.7 | 31.8 |
| pnt | 44.6 | 20.5 | N/A | 57.0 | 39.9 | 40.5 |
| real | 55.3 | 24.1 | 52.6 | N/A | 42.4 | 43.6 |
| skt | 56.7 | 21.3 | 46.2 | 55.0 | N/A | 44.8 |
| Avg | 47.0 | 21.7 | 42.2 | 52.4 | 38.1 | 40.3 |
| MDD | clp | inf | pnt | real | skt | Avg |
| clp | N/A | 21.2 | 42.9 | 59.5 | 47.5 | 42.8 |
| inf | 35.3 | N/A | 34.0 | 49.6 | 29.4 | 37.1 |
| pnt | 48.6 | 19.7 | N/A | 59.4 | 42.6 | 42.6 |
| real | 58.3 | 24.9 | 53.7 | N/A | 46.2 | 45.8 |
| skt | 58.7 | 20.7 | 46.5 | 57.7 | N/A | 45.9 |
| Avg | 50.2 | 21.6 | 44.3 | 56.6 | 41.4 | 42.8 |
DANN
class dalib.adaptation.dann.DomainAdversarialLoss(domain_discriminator: torch.nn.modules.module.Module, reduction: Optional[str] = 'mean')
Bases: torch.nn.modules.module.Module
Domain adversarial loss measures the domain discrepancy through training a domain discriminator. Given domain discriminator \(D\), feature representation \(f\), the definition of DANN loss is
\[\begin{split}loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \log[D(f_i^s)] \\ &+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \log[1-D(f_j^t)].\\\end{split}\]
- Parameters:
  - domain_discriminator (nn.Module): A domain discriminator object, which predicts the domains of features. Its input shape is (N, F) and output shape is (N, 1)
  - reduction (string, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'
- Inputs: f_s, f_t
  - f_s (tensor): feature representations on source domain, \(f^s\)
  - f_t (tensor): feature representations on target domain, \(f^t\)
- Shape:
  - f_s, f_t: \((N, F)\) where F means the dimension of input features.
  - Outputs: scalar by default. If reduction is 'none', then \((N, )\).
- Examples:
    >>> import torch
    >>> from dalib.adaptation.dann import DomainAdversarialLoss
    >>> from dalib.modules.domain_discriminator import DomainDiscriminator
    >>> discriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)
    >>> loss = DomainAdversarialLoss(discriminator, reduction='mean')
    >>> # features from source domain and target domain
    >>> f_s, f_t = torch.randn(20, 1024), torch.randn(20, 1024)
    >>> output = loss(f_s, f_t)
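To make the expectation form of the DANN loss concrete, the following plain-Python sketch evaluates it for a small batch, assuming the discriminator outputs \(D(f)\) are already probabilities in (0, 1). The function name `dann_objective` and the sample values are hypothetical; this only illustrates the arithmetic of the formula and is not how `DomainAdversarialLoss` is implemented.

```python
import math

def dann_objective(d_s, d_t):
    """Evaluate mean(log D(f_s)) + mean(log(1 - D(f_t))).

    d_s, d_t: discriminator outputs D(f) in (0, 1) for source and
    target features (hypothetical inputs, for illustration only).
    """
    source_term = sum(math.log(p) for p in d_s) / len(d_s)
    target_term = sum(math.log(1.0 - p) for p in d_t) / len(d_t)
    return source_term + target_term

# A discriminator that separates the two domains well scores close to 0 ...
confident = dann_objective([0.9, 0.95], [0.1, 0.05])
# ... while one that cannot tell them apart scores 2 * log(0.5).
confused = dann_objective([0.5, 0.5], [0.5, 0.5])
print(confident, confused)
```

The value is largest (closest to 0) when the discriminator is confident and correct, which is why the feature extractor is trained against it.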
DAN
class dalib.adaptation.dan.MultipleKernelMaximumMeanDiscrepancy(kernels: Sequence[torch.nn.modules.module.Module], linear: Optional[bool] = False, quadratic_program: Optional[bool] = False)
Bases: torch.nn.modules.module.Module
The Multiple Kernel Maximum Mean Discrepancy (MK-MMD) used in Learning Transferable Features with Deep Adaptation Networks.
Given source domain \(\mathcal{D}_s\) of \(n_s\) labeled points and target domain \(\mathcal{D}_t\) of \(n_t\) unlabeled points drawn i.i.d. from P and Q respectively, the deep networks will generate activations as \(\{z_i^s\}_{i=1}^{n_s}\) and \(\{z_i^t\}_{i=1}^{n_t}\). The MK-MMD \(D_k (P, Q)\) between probability distributions P and Q is defined as
\[D_k(P, Q) \triangleq \| E_p [\phi(z^s)] - E_q [\phi(z^t)] \|^2_{\mathcal{H}_k},\]
where \(k\) is a kernel function in the function space
\[\mathcal{K} \triangleq \{ k=\sum_{u=1}^{m}\beta_{u} k_{u} \}\]
and \(k_{u}\) is a single kernel.
Using the kernel trick, MK-MMD can be computed as
\[\begin{split}\hat{D}_k(P, Q) &= \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} k(z_i^{s}, z_j^{s}) \\ &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} k(z_i^{t}, z_j^{t}) \\ &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} k(z_i^{s}, z_j^{t}). \\\end{split}\]
- Parameters:
  - kernels (tuple(nn.Module)): kernel functions.
  - linear (bool): whether to use the linear version of DAN. Default: False
  - quadratic_program (bool): whether to use a quadratic program to solve for \(\beta\). Default: False
- Inputs: z_s, z_t
  - z_s (tensor): activations from the source domain, \(z^s\)
  - z_t (tensor): activations from the target domain, \(z^t\)
- Shape:
  - Inputs: \((minibatch, *)\) where * means any dimension
  - Outputs: scalar
Note
Activations \(z^{s}\) and \(z^{t}\) must have the same shape.
Note
The kernel values will add up when there are multiple kernels.
- Examples:
    >>> import torch
    >>> from dalib.adaptation.dan import MultipleKernelMaximumMeanDiscrepancy
    >>> from dalib.modules.kernels import GaussianKernel
    >>> feature_dim = 1024
    >>> batch_size = 10
    >>> kernels = (GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.), GaussianKernel(alpha=2.))
    >>> loss = MultipleKernelMaximumMeanDiscrepancy(kernels)
    >>> # features from source domain and target domain
    >>> z_s, z_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
    >>> output = loss(z_s, z_t)
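The kernel-trick estimate of MK-MMD can be checked by hand on a few points. The sketch below is plain Python and assumes fixed-bandwidth Gaussian kernels that are summed with all \(\beta_u = 1\). The helper names `gaussian_kernel` and `mk_mmd` and the `sigmas` parameter are hypothetical, and the kernel is not the library's `GaussianKernel`.

```python
import math

def gaussian_kernel(x, y, sigma):
    """Gaussian kernel exp(-||x - y||^2 / (2 * sigma^2)) (illustrative form)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))

def mk_mmd(z_s, z_t, sigmas=(0.5, 1.0, 2.0)):
    """Biased MK-MMD estimate with k = sum of Gaussian kernels."""
    def k(x, y):
        # Kernel values add up when several kernels are given.
        return sum(gaussian_kernel(x, y, s) for s in sigmas)

    def mean_kernel(a, b):
        # Average of k over all pairs drawn from a and b.
        return sum(k(x, y) for x in a for y in b) / (len(a) * len(b))

    return (mean_kernel(z_s, z_s) + mean_kernel(z_t, z_t)
            - 2.0 * mean_kernel(z_s, z_t))

z_s = [[0.0, 0.0], [0.1, -0.1], [-0.2, 0.1]]
z_t = [[1.0, 1.0], [1.2, 0.9], [0.8, 1.1]]
same = mk_mmd(z_s, z_s)      # identical samples
shifted = mk_mmd(z_s, z_t)   # samples from a shifted distribution
print(same, shifted)
```

Identical samples give an estimate of zero, and the estimate grows as the two sets of activations move apart.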
JAN
class dalib.adaptation.jan.JointMultipleKernelMaximumMeanDiscrepancy(kernels: Sequence[Sequence[torch.nn.modules.module.Module]], linear: Optional[bool] = True, thetas: Sequence[torch.nn.modules.module.Module] = None)
Bases: torch.nn.modules.module.Module
The Joint Multiple Kernel Maximum Mean Discrepancy (JMMD) used in Deep Transfer Learning with Joint Adaptation Networks.
Given source domain \(\mathcal{D}_s\) of \(n_s\) labeled points and target domain \(\mathcal{D}_t\) of \(n_t\) unlabeled points drawn i.i.d. from P and Q respectively, the deep networks will generate activations in layers \(\mathcal{L}\) as \(\{(z_i^{s1}, ..., z_i^{s|\mathcal{L}|})\}_{i=1}^{n_s}\) and \(\{(z_i^{t1}, ..., z_i^{t|\mathcal{L}|})\}_{i=1}^{n_t}\). The empirical estimate \(\hat{D}_{\mathcal{L}}(P, Q)\) is computed as the squared distance between the empirical kernel mean embeddings:
\[\begin{split}\hat{D}_{\mathcal{L}}(P, Q) &= \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{sl}) \\ &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{tl}, z_j^{tl}) \\ &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{tl}). \\\end{split}\]
- Parameters:
  - kernels (tuple(tuple(nn.Module))): kernel functions, where kernels[r] corresponds to kernel \(k^{\mathcal{L}[r]}\).
  - linear (bool): whether to use the linear version of JAN. Default: True
  - thetas (list(Theta)): use the adversarial version of JAN if not None. Default: None
- Inputs: z_s, z_t
  - z_s (tuple(tensor)): multiple layers’ activations from the source domain, \(z^s\)
  - z_t (tuple(tensor)): multiple layers’ activations from the target domain, \(z^t\)
- Shape:
  - \(z^{sl}\) and \(z^{tl}\): \((minibatch, *)\) where * means any dimension
  - Outputs: scalar
Note
Activations \(z^{sl}\) and \(z^{tl}\) must have the same shape.
Note
The kernel values will add up when there are multiple kernels for a certain layer.
- Examples:
    >>> import torch
    >>> from dalib.adaptation.jan import JointMultipleKernelMaximumMeanDiscrepancy
    >>> from dalib.modules.kernels import GaussianKernel
    >>> feature_dim = 1024
    >>> batch_size = 10
    >>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(1.), GaussianKernel(2.))
    >>> layer2_kernels = (GaussianKernel(1.), )
    >>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))
    >>> # layer1 features from source domain and target domain
    >>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
    >>> # layer2 features from source domain and target domain
    >>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
    >>> output = loss((z1_s, z2_s), (z1_t, z2_t))
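The JMMD estimate differs from MK-MMD only in that the kernel between two samples is a product of per-layer kernels. A minimal plain-Python sketch of that arithmetic follows, assuming one Gaussian kernel per layer. The names `gaussian_kernel` and `jmmd` and the toy activations are hypothetical and do not reflect the library's implementation.

```python
import math

def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian kernel exp(-||x - y||^2 / (2 * sigma^2)) (illustrative form)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))

def jmmd(z_s, z_t, kernels):
    """Biased JMMD estimate.

    z_s, z_t: one list of activations per layer; z_s[l][i] is sample i at layer l.
    kernels: one kernel function per layer.
    """
    def joint_kernel(a, i, b, j):
        # Product over layers of k^l(a[l][i], b[l][j]).
        value = 1.0
        for layer, k in enumerate(kernels):
            value *= k(a[layer][i], b[layer][j])
        return value

    def mean_joint(a, b):
        n_a, n_b = len(a[0]), len(b[0])
        total = sum(joint_kernel(a, i, b, j)
                    for i in range(n_a) for j in range(n_b))
        return total / (n_a * n_b)

    return mean_joint(z_s, z_s) + mean_joint(z_t, z_t) - 2.0 * mean_joint(z_s, z_t)

# Two layers, two samples per domain (toy values).
z_s = ([[0.0, 0.0], [0.1, 0.2]], [[1.0], [0.9]])
z_t = ([[1.0, 1.0], [0.9, 1.2]], [[0.0], [0.2]])
kernels = (gaussian_kernel, lambda x, y: gaussian_kernel(x, y, sigma=2.0))
same = jmmd(z_s, z_s, kernels)
shifted = jmmd(z_s, z_t, kernels)
print(same, shifted)
```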
CDAN
class dalib.adaptation.cdan.ConditionalDomainAdversarialLoss(domain_discriminator: torch.nn.modules.module.Module, entropy_conditioning: Optional[bool] = False, randomized: Optional[bool] = False, num_classes: Optional[int] = -1, features_dim: Optional[int] = -1, randomized_dim: Optional[int] = 1024, reduction: Optional[str] = 'mean')
Bases: torch.nn.modules.module.Module
The Conditional Domain Adversarial Loss
Conditional domain adversarial loss measures the domain discrepancy through training a domain discriminator in a conditional manner. Given domain discriminator \(D\), feature representation \(f\) and classifier predictions \(g\), the definition of CDAN loss is
\[\begin{split}loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \log[D(T(f_i^s, g_i^s))] \\ &+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \log[1-D(T(f_j^t, g_j^t))],\\\end{split}\]
where \(T\) is a multi linear map or randomized multi linear map which converts two tensors to a single tensor.
- Parameters:
  - domain_discriminator (nn.Module): A domain discriminator object, which predicts the domains of features. Its input shape is (N, F) and output shape is (N, 1)
  - entropy_conditioning (bool, optional): If True, use entropy-aware weight to reweight each training example. Default: False
  - randomized (bool, optional): If True, use randomized multi linear map. Else, use multi linear map. Default: False
  - num_classes (int, optional): Number of classes. Default: -1
  - features_dim (int, optional): Dimension of input features. Default: -1
  - randomized_dim (int, optional): Dimension of features after the randomized map. Default: 1024
  - reduction (string, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'
Note
You need to provide num_classes, features_dim and randomized_dim only when randomized is set True.
- Inputs: g_s, f_s, g_t, f_t
  - g_s (tensor): unnormalized classifier predictions on source domain, \(g^s\)
  - f_s (tensor): feature representations on source domain, \(f^s\)
  - g_t (tensor): unnormalized classifier predictions on target domain, \(g^t\)
  - f_t (tensor): feature representations on target domain, \(f^t\)
- Shape:
  - g_s, g_t: \((minibatch, C)\) where C means the number of classes.
  - f_s, f_t: \((minibatch, F)\) where F means the dimension of input features.
  - Output: scalar by default. If reduction is 'none', then \((minibatch, )\).
- Examples:
    >>> import torch
    >>> from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss
    >>> from dalib.modules.domain_discriminator import DomainDiscriminator
    >>> num_classes = 2
    >>> feature_dim = 1024
    >>> batch_size = 10
    >>> # the (non-randomized) multi linear map is assumed to yield feature_dim * num_classes values per sample
    >>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
    >>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
    >>> # features from source domain and target domain
    >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
    >>> # logits output from source domain and target domain
    >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
    >>> output = loss(g_s, f_s, g_t, f_t)
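The multi linear map \(T\) is not spelled out above. Assuming it is the flattened outer product of the class predictions and the features, as in the CDAN paper, a single sample can be conditioned as in the sketch below. The function name `multi_linear_map` and the sample values are hypothetical, and the element ordering of the flattened result is an arbitrary choice made here.

```python
def multi_linear_map(f, g):
    """Flattened outer product of g (length C) and f (length F), length C * F."""
    return [g_c * f_k for g_c in g for f_k in f]

f = [0.5, -1.0, 2.0]   # hypothetical feature vector, F = 3
g = [0.2, 0.8]         # hypothetical class probabilities, C = 2
t = multi_linear_map(f, g)
print(len(t), t)
```

Under this assumption the discriminator input has dimension C * F, which grows quickly with the number of classes and motivates the randomized variant.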
class dalib.adaptation.cdan.RandomizedMultiLinearMap(features_dim: int, num_classes: int, output_dim: Optional[int] = 1024)
Bases: torch.nn.modules.module.Module
Random multi linear map
Given two inputs \(f\) and \(g\), the definition is
\[T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g),\]
where \(\odot\) is element-wise product, \(R_f\) and \(R_g\) are random matrices sampled only once and fixed in training.
- Parameters:
  - features_dim (int): dimension of input \(f\)
  - num_classes (int): dimension of input \(g\)
  - output_dim (int, optional): dimension of output tensor. Default: 1024
- Shape:
  - f: (minibatch, features_dim)
  - g: (minibatch, num_classes)
  - Outputs: (minibatch, output_dim)
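The randomized map can be sketched in plain Python for a single sample. The sketch assumes that \(d\) is the output dimension and that \(R_f\) and \(R_g\) have standard Gaussian entries; neither is stated above. The factory name `make_randomized_map` and the `seed` parameter are hypothetical.

```python
import math
import random

def make_randomized_map(features_dim, num_classes, output_dim, seed=0):
    rng = random.Random(seed)
    # R_f and R_g are sampled once here and then kept fixed.
    r_f = [[rng.gauss(0.0, 1.0) for _ in range(features_dim)]
           for _ in range(output_dim)]
    r_g = [[rng.gauss(0.0, 1.0) for _ in range(num_classes)]
           for _ in range(output_dim)]

    def matvec(m, v):
        return [sum(w * x for w, x in zip(row, v)) for row in m]

    def transform(f, g):
        # Element-wise product of the two projections, scaled by 1/sqrt(d).
        rf, rg = matvec(r_f, f), matvec(r_g, g)
        return [a * b / math.sqrt(output_dim) for a, b in zip(rf, rg)]

    return transform

T = make_randomized_map(features_dim=4, num_classes=3, output_dim=8)
f = [0.5, -1.0, 2.0, 0.0]
g = [0.2, 0.5, 0.3]
out = T(f, g)
print(len(out))
```

Because the matrices are fixed after construction, repeated calls with the same inputs return the same output, and the output size no longer depends on the product of the two input dimensions.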
MCD
dalib.adaptation.mcd.classifier_discrepancy(predictions1: torch.Tensor, predictions2: torch.Tensor) → torch.Tensor
The Classifier Discrepancy in Maximum Classifier Discrepancy for Unsupervised Domain Adaptation. The classifier discrepancy between predictions \(p_1\) and \(p_2\) can be described as:
\[d(p_1, p_2) = \dfrac{1}{K} \sum_{k=1}^K | p_{1k} - p_{2k} |,\]
where K is the number of classes.
- Parameters:
  - predictions1 (tensor): Classifier predictions \(p_1\). Expected to contain normalized scores (probabilities) for each class
  - predictions2 (tensor): Classifier predictions \(p_2\)
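For a single sample the discrepancy is the mean absolute difference between the two probability vectors. A plain-Python stand-in is shown below; the name `discrepancy` and the sample values are hypothetical.

```python
def discrepancy(p1, p2):
    """Mean absolute difference between two probability vectors of length K."""
    return sum(abs(a - b) for a, b in zip(p1, p2)) / len(p1)

p1 = [0.7, 0.2, 0.1]
p2 = [0.4, 0.4, 0.2]
d = discrepancy(p1, p2)   # (0.3 + 0.2 + 0.1) / 3
print(d)
```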
dalib.adaptation.mcd.entropy(predictions: torch.Tensor) → torch.Tensor
Entropy of N predictions \((p_1, p_2, ..., p_N)\). The definition is:
\[d(p_1, p_2, ..., p_N) = -\dfrac{1}{K} \sum_{k=1}^K \log \left( \dfrac{1}{N} \sum_{i=1}^N p_{ik} \right)\]
where K is the number of classes.
- Parameters:
  - predictions (tensor): Classifier predictions. Expected to contain normalized scores (probabilities) for each class
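The definition averages the predictions over the batch first and then takes the mean negative log over classes. The plain-Python sketch below follows the formula literally; the name `mean_prediction_entropy` and the sample batches are hypothetical, and any numerical safeguards the library may add are left out.

```python
import math

def mean_prediction_entropy(predictions):
    """-1/K * sum_k log(mean_i p_ik) for a batch of N probability vectors."""
    n, k = len(predictions), len(predictions[0])
    class_means = [sum(p[c] for p in predictions) / n for c in range(k)]
    return -sum(math.log(m) for m in class_means) / k

# Predictions spread evenly over the classes across the batch ...
balanced = mean_prediction_entropy([[0.9, 0.1], [0.1, 0.9]])
# ... versus predictions that all collapse onto the first class.
collapsed = mean_prediction_entropy([[0.9, 0.1], [0.9, 0.1]])
print(balanced, collapsed)
```

The value is smaller when the batch-averaged prediction is balanced across classes and larger when all samples are assigned to the same class.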
class dalib.adaptation.mcd.ImageClassifierHead(in_features: int, num_classes: int, bottleneck_dim: Optional[int] = 1024)
Bases: torch.nn.modules.module.Module
Classifier Head for MCD.
- Parameters:
  - in_features (int): Dimension of input features
  - num_classes (int): Number of classes
  - bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
- Shape:
  - Inputs: \((minibatch, F)\) where F = in_features.
  - Output: \((minibatch, C)\) where C = num_classes.
MDD
class dalib.adaptation.mdd.MarginDisparityDiscrepancy(margin: Optional[int] = 4, reduction: Optional[str] = 'mean')
Bases: torch.nn.modules.module.Module
The margin disparity discrepancy (MDD) is proposed to measure the distribution discrepancy in domain adaptation.
The \(y^s\) and \(y^t\) are logits output by the main classifier on the source and target domain respectively. The \(y_{adv}^s\) and \(y_{adv}^t\) are logits output by the adversarial classifier. They are expected to contain raw, unnormalized scores for each class.
The definition can be described as:
\[\mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) = \gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}} \log\left(\frac{\exp(y_{adv}^s[h_{y^s}])}{\sum_j \exp(y_{adv}^s[j])}\right) + \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}} \log\left(1-\frac{\exp(y_{adv}^t[h_{y^t}])}{\sum_j \exp(y_{adv}^t[j])}\right),\]
where \(\gamma\) is a margin hyper-parameter and \(h_y\) refers to the predicted label when the logits output is \(y\). You can see more details in Bridging Theory and Algorithm for Domain Adaptation.
- Parameters:
  - margin (float): margin \(\gamma\). Default: 4
  - reduction (string, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'
- Inputs: y_s, y_s_adv, y_t, y_t_adv
  - y_s: logits output \(y^s\) by the main classifier on the source domain
  - y_s_adv: logits output \(y_{adv}^s\) by the adversarial classifier on the source domain
  - y_t: logits output \(y^t\) by the main classifier on the target domain
  - y_t_adv: logits output \(y_{adv}^t\) by the adversarial classifier on the target domain
- Shape:
  - Inputs: \((minibatch, C)\) where C = number of classes, or \((minibatch, C, d_1, d_2, ..., d_K)\) with \(K \geq 1\) in the case of K-dimensional loss.
  - Output: scalar. If reduction is 'none', then the same size as the target: \((minibatch)\), or \((minibatch, d_1, d_2, ..., d_K)\) with \(K \geq 1\) in the case of K-dimensional loss.
- Examples:
    >>> import torch
    >>> from dalib.adaptation.mdd import MarginDisparityDiscrepancy
    >>> num_classes = 2
    >>> batch_size = 10
    >>> loss = MarginDisparityDiscrepancy(margin=4.)
    >>> # logits output from source domain and target domain
    >>> y_s, y_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
    >>> # adversarial logits output from source domain and target domain
    >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
    >>> output = loss(y_s, y_s_adv, y_t, y_t_adv)
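The MDD definition can be evaluated by hand for a small batch. The plain-Python sketch below follows the formula as written, taking \(h_y\) as the argmax of the main classifier's logits. The helper names `softmax` and `mdd` and the toy logits are hypothetical, and the sign conventions or numerical safeguards used inside `MarginDisparityDiscrepancy` may differ.

```python
import math

def softmax(logits):
    m = max(logits)                      # subtract the max for stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mdd(y_s, y_s_adv, y_t, y_t_adv, margin=4.0):
    def argmax(v):
        return max(range(len(v)), key=v.__getitem__)

    # Source term: log-probability the adversarial head gives to the
    # label predicted by the main head.
    source = [math.log(softmax(adv)[argmax(y)])
              for y, adv in zip(y_s, y_s_adv)]
    # Target term: log of one minus that probability.
    target = [math.log(1.0 - softmax(adv)[argmax(y)])
              for y, adv in zip(y_t, y_t_adv)]
    return margin * sum(source) / len(source) + sum(target) / len(target)

# One sample per domain, two classes, uninformative adversarial logits.
value = mdd(y_s=[[2.0, 0.0]], y_s_adv=[[0.0, 0.0]],
            y_t=[[0.0, 1.0]], y_t_adv=[[0.0, 0.0]])
print(value)   # equals 4 * log(0.5) + log(0.5)
```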
class dalib.adaptation.mdd.ImageClassifier(backbone: torch.nn.modules.module.Module, num_classes: int, bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024)
Bases: torch.nn.modules.module.Module
Classifier for MDD.
- Parameters:
  - backbone (nn.Module): Any backbone to extract 1-d features from data
  - num_classes (int): Number of classes
  - bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
  - width (int, optional): Feature dimension of the classifier head. Default: 1024
Note
Classifier for MDD has one backbone, one bottleneck and two classifier heads. The first classifier head is used for final predictions. The adversarial classifier head is only used when calculating MarginDisparityDiscrepancy.
Note
Remember to call step() after forward() during the training phase! For instance,
    >>> # x is inputs, classifier is an ImageClassifier
    >>> outputs, outputs_adv = classifier(x)
    >>> classifier.step()
- Inputs:
  - x (Tensor): input data
- Outputs: (outputs, outputs_adv)
  - outputs: logits outputs by the main classifier
  - outputs_adv: logits outputs by the adversarial classifier
- Shapes:
  - x: \((minibatch, *)\), same shape as the input of the backbone.
  - outputs, outputs_adv: \((minibatch, C)\), where C means the number of classes.
dalib.adaptation.mdd.shift_log(x: torch.Tensor, offset: Optional[float] = 1e-06) → torch.Tensor
First shift, then calculate log, which can be described as:
\[y = \min(\log(x+\text{offset}), 0)\]
Used to avoid the gradient explosion problem in the log(x) function when x=0.
- Parameters:
  - x: input tensor
  - offset: offset size. Default: 1e-6
Note
Input tensor falls in [0., 1.] and the output tensor falls in [log(offset), 0]
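A scalar sketch of the shifted log, assuming the result is capped at 0 from above so that inputs in [0, 1] always map to non-positive values. This reading is an assumption made for illustration; the function here works on plain floats rather than tensors.

```python
import math

def shift_log(x, offset=1e-6):
    """log(x + offset), capped at 0 so inputs near 1 never give positive values."""
    return min(math.log(x + offset), 0.0)

at_zero = shift_log(0.0)   # finite, equals log(offset), instead of -inf
at_one = shift_log(1.0)    # log(1 + offset) is slightly positive, capped to 0
print(at_zero, at_one)
```

The offset keeps the value and its gradient finite at x = 0, which is the case the plain logarithm cannot handle.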