# Source code for dalib.modules.domain_discriminator

from typing import List, Dict
import torch.nn as nn
import torch

# Public API of this module: only the domain discriminator is exported by `from ... import *`.
__all__ = ['DomainDiscriminator']


class DomainDiscriminator(nn.Module):
    r"""Domain discriminator model from
    `"Domain-Adversarial Training of Neural Networks" <https://arxiv.org/abs/1505.07818>`_

    Distinguish whether the input features come from the source domain or the target domain.
    The source domain label is 1 and the target domain label is 0.

    The network is ``Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d -> ReLU -> Linear -> Sigmoid``,
    so each output is a probability in :math:`[0, 1]`.

    Parameters:
        - **in_feature** (int): dimension of the input feature
        - **hidden_size** (int): dimension of the hidden features

    Shape:
        - Inputs: :math:`(minibatch, F)` where :math:`F` is ``in_feature``
        - Outputs: :math:`(minibatch, 1)`

    .. note::
        ``nn.BatchNorm1d`` in training mode requires more than one sample per batch.
    """

    def __init__(self, in_feature: int, hidden_size: int):
        super().__init__()
        self.layer1 = nn.Linear(in_feature, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.relu1 = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.relu2 = nn.ReLU()
        # Single logit squashed by a sigmoid: the probability that the input is from the source domain.
        self.layer3 = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-sample probability (shape ``(minibatch, 1)``) that ``x`` comes from the source domain."""
        x = self.relu1(self.bn1(self.layer1(x)))
        x = self.relu2(self.bn2(self.layer2(x)))
        y = self.sigmoid(self.layer3(x))
        return y

    def get_parameters(self) -> List[Dict]:
        """Return optimizer parameter groups for this module.

        A single group holds every parameter of the discriminator together with an
        ``lr_mult`` of 1.0 (presumably read by a learning-rate scheduler elsewhere —
        TODO confirm against callers).

        Returns:
            A one-element list of parameter-group dicts with keys ``"params"`` and ``"lr_mult"``.
        """
        # Materialize the generator: a bare ``self.parameters()`` generator can only be
        # iterated once, so the group would silently be empty on any second pass.
        return [{"params": list(self.parameters()), "lr_mult": 1.}]