Introduction to Domain Adversarial Neural Networks

Authors:

Trent Bradberry

Chris Hase

Date Published:
July 2, 2021

Domain Adaptation (DA) is a process for enhancing model training when there is a shift, often referred to as a covariate shift or data shift, between the input distributions of different datasets. Figure 1 illustrates some simple examples: (a) shows a major shift in the mean of one variable, (b) shows a small shift in the mean of one variable, and (c) shows a smaller shift in the mean and a major shift in the variance of one variable.
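As a purely illustrative sketch, shifts like those in Figure 1 can be simulated by sampling one input variable from Gaussians whose mean and variance differ between datasets; the specific numbers below are our own assumptions, not the values behind the figure:

```python
import numpy as np

# Illustrative covariate shift in a single input variable
# (means and variances are assumptions, not the values behind Figure 1).
rng = np.random.default_rng(42)
source   = rng.normal(loc=0.0, scale=1.0, size=10_000)  # original training inputs
target_a = rng.normal(loc=3.0, scale=1.0, size=10_000)  # (a) major shift in the mean
target_b = rng.normal(loc=0.5, scale=1.0, size=10_000)  # (b) small shift in the mean
target_c = rng.normal(loc=0.5, scale=2.5, size=10_000)  # (c) small mean shift, major variance shift
```

A model fit to `source` sees inputs concentrated in a different region than `target_a` or `target_c`, which is exactly the situation DA is meant to address.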

Figure 1. Data shift example

DA has two competing objectives:

  • Discriminativeness – the ability to discriminate between data coming from different classes within a particular domain
  • Domain invariance – the ability to produce representations that look the same across domains, so that it is hard to tell which domain an input came from

For example, for a classification model we want to discern the differences among classes, and so maintain the discriminativeness component. At the same time, if there is a data shift, we want to increase the domain invariance of our classifier so that it performs well on inputs from different domains. (For a rigorous theoretical treatment, we recommend “A theory of learning from different domains” by Ben-David et al.)

An example use case of Domain Adaptation involves the MNIST dataset, which consists of images of handwritten digits. This dataset is ubiquitous in the literature and is frequently used as a benchmark when testing models. There is also a dataset called MNIST-M, which adds different backgrounds and digit colors, as shown in Figure 2.

Figure 2. MNIST and MNIST-M example

There are obvious similarities, but the variations in the MNIST-M data set lead to different distributions of input features. In this case, DA could be employed to help a model perform well on both MNIST and MNIST-M, even without the labels from MNIST-M. This use case is described in “Unsupervised Domain Adaptation by Backpropagation” by Ganin and Lempitsky.

Another useful DA example is addressing how biological signals, in particular neural signals, change over time or across subjects. Let’s say we want to develop a brain-computer interface and need to classify a human subject’s intention to move a specific arm or leg in order to control an external prosthetic device. Using EEG signals from a Source Subject (Figure 3), we would typically develop a classifier. However, applying that classifier to our Target Subject might result in poor performance. Domain adaptation would allow us to achieve higher performance on our Target Subject without having to collect data from every subject.

Figure 3. Using DA to address how neural signals change over time

One of the most common methods for implementing DA is called Sample Reweighting. When using this method, we develop a domain classifier with the following steps:

  1. Label all source domain samples as “0” and all target domain samples as “1”.
  2. Train a binary classifier that returns predicted probabilities p_i (such as logistic regression or a random forest) to discriminate between the source data and the target data.
  3. Use the resulting probabilities to obtain sample weights for the source domain samples when model fitting on the source domain:

     w_i = p_i / (1 - p_i)
This causes the source samples that look most like target samples to receive higher weight. While we generally think positively of this method, it has some drawbacks. One challenge is deciding how accurate the domain classifier should be: if it is too accurate, the weights won’t be useful because there is little overlap between the source and target domains.
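The three steps above can be sketched with scikit-learn (the data, variable names, and the logistic-regression choice are our own assumptions for illustration):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Sketch of sample reweighting with a logistic-regression domain classifier.
# The synthetic source/target data below are assumptions for illustration.
rng = np.random.default_rng(0)
X_source = rng.normal(0.0, 1.0, size=(500, 2))   # source domain inputs
X_target = rng.normal(1.0, 1.0, size=(500, 2))   # shifted target domain inputs

# Steps 1-2: label source as 0 and target as 1, then fit a domain classifier.
X = np.vstack([X_source, X_target])
d = np.concatenate([np.zeros(len(X_source)), np.ones(len(X_target))])
clf = LogisticRegression().fit(X, d)

# Step 3: p_i = P(target | x_i); weight each source sample by the odds p_i / (1 - p_i).
p = clf.predict_proba(X_source)[:, 1]
weights = p / (1.0 - p)  # source points that look like target points get higher weight
```

These `weights` would then be passed to the downstream model fit on the source domain, e.g. via a `sample_weight` argument.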

Domain Adversarial Neural Networks

What if we had a way to use DA and learn label classification at the same time? One method with this capability is the Domain Adversarial Neural Network (DANN). It employs source data that has class labels and target data that is unlabeled. The goal is to predict class labels for the target data by using both the source and target data in an adversarial training process.

DANN Model Architecture

The DANN Model of Figure 4 includes the following components:

  • Label Predictor (Blue): Predicts class labels
  • Domain Classifier (Pink): Predicts the domain of the inputs
  • Feature Extractor (Green): Produces features that are used as inputs to the Label Predictor and the Domain Classifier in the training process. Ideally, it produces features that can predict class labels for either the source or the target domains. To achieve this, the features need to minimize the error of the Label Predictor (so that the model is good at discriminating) but also maximize the error of the Domain Classifier. If the Domain Classifier cannot distinguish between the features being produced when source and target domain inputs are fed to the Feature Extractor, then the features can be said to be domain invariant.

 

Figure 4. Domain Adversarial Neural Network architecture (Figure reproduced from Ganin and Lempitsky)

DANN Training Process

The DANN training process is shown in Figure 4:

  • Input features from either source or targets are fed to the Feature Extractor.
  • Features produced are fed to the:
    • Label Predictor and Domain Classifier if the input was from the source domain (because only that data has labels).
    • Domain Classifier if the input was from the target domain (as there are no labels on that data).
  • The Label Predictor and Domain Classifier are optimized to minimize the error associated with their respective classification problems using a loss function like cross entropy.
  • ‘Special’ optimization (described below) is performed for the Feature Extractor that is specific to a DANN.

The optimization of the Feature Extractor can be viewed as finding the best tradeoff between producing features that are domain invariant and also useful for the Label Predictor. The parameters of the Feature Extractor are optimized to minimize the loss of the Label Predictor and maximize the loss of the Domain Classifier (which involves the use of a gradient reversal layer).
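Numerically, the gradient reversal layer acts as the identity on the forward pass and multiplies the incoming gradient by -λ on the backward pass. A toy sketch of a single update step (the names, λ, learning rate, and gradient values are our own illustrative assumptions, not the Ganin and Lempitsky implementation):

```python
def gradient_reversal(upstream_grad, lam=1.0):
    """Backward pass of a gradient reversal layer: identity on the forward
    pass, but the gradient flowing back is multiplied by -lam."""
    return -lam * upstream_grad

# Pretend we have already backpropagated both losses to a scalar
# feature-extractor parameter theta_f (values are illustrative).
grad_label_loss = 0.8    # d(label loss)   / d(theta_f)
grad_domain_loss = 0.5   # d(domain loss)  / d(theta_f)
lam, lr = 0.3, 0.1

# The domain-loss gradient arrives through the reversal layer with its sign
# flipped, so a plain gradient-descent step on the total gradient descends
# the label loss while ASCENDING the domain loss.
total_grad = grad_label_loss + gradient_reversal(grad_domain_loss, lam)

theta_f = 1.0
theta_f -= lr * total_grad
```

In an autograd framework the same effect is obtained by inserting the reversal layer between the Feature Extractor and the Domain Classifier, so a single ordinary backward pass applies both objectives at once.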

In production, we feed inputs from the target domain to the Feature Extractor, which creates features fed to the Label Predictor to make our label predictions. The domain classifier is not used, so we can ignore the pink part of Figure 4 when deploying the model.

DANN Examples

Let’s look at a few examples. The first, in Figure 5, is a toy example with synthetic data produced by the make_blobs function from the scikit-learn Python package. The data on the left is the source data and that on the right is the target data. Class 0 is red and class 1 is green.

Observe how the data has been shifted between the source and target domains. Within each domain the classes can be linearly separated, but the shift means that a classifier trained on one domain will not readily generalize to the other.

Figure 5. Source and Target data for our toy problem
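Shifted blobs of this kind can be generated with make_blobs; the centers, spread, and shift below are our assumptions for illustration, not the exact values behind Figure 5:

```python
import numpy as np
from sklearn.datasets import make_blobs

# Two linearly separable classes per domain (centers/std are assumptions).
X_src, y_src = make_blobs(n_samples=200, centers=[[0, 0], [3, 3]],
                          cluster_std=0.6, random_state=0)
X_tgt, y_tgt = make_blobs(n_samples=200, centers=[[0, 0], [3, 3]],
                          cluster_std=0.6, random_state=1)

shift = np.array([4.0, -1.0])  # covariate shift between the domains
X_tgt = X_tgt + shift          # same two classes, shifted input distribution
```

Here `y_src` would be used for training, while `y_tgt` would be held out and used only to compute the target-domain performance metric.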

Note that we only use labels from the source domain for training and only use labels from the target domain to calculate the performance metric. Because target domain labels are not used in the training process, they are grayed out in Figure 6.

 

Figure 6. Label usage for our toy problem

A conventional neural network trained on the source domain and then tested on the target domain achieved an accuracy of 55%. However, if we include the Domain Classifier and use the DANN training process, the resulting target domain accuracy goes up to 95%, demonstrating the value of the DANN process.

A more real-world example comes from a graduate school project for a natural language processing (NLP) course. There, the goal was to identify whether pairs of questions drawn from an Android forum were similar. The labeled training pairs, though, all came from the AskUbuntu forum. This problem is ideally suited to the DANN architecture and training process.

When the DANN training process was not used (i.e., only the AskUbuntu forum data was used in training and then the model was tested on the Android data) an Area Under the Curve (AUC) of 0.61 was achieved. When the DANN framework and its training process were used (AskUbuntu data inputs and labels as well as Android data inputs were used in training; Android labels were not used), the AUC increased to 0.69. When a small number of labels from the Android forum data were added into the DANN training process, the AUC increased to 0.76, a large improvement.

In the real world, we might not have labels for our target domain on which to compute metrics. Since we would not want to put a model into production without computing out-of-sample performance metrics, one could hand-label a small amount of the data from the target domain to use for evaluation.

We believe that a method like this could be integrated into the training of models that are frequently used as pre-trained starting points, such as image models trained on ImageNet or language models like ULMFiT. Using a DANN training process would likely create more domain-invariant models better tuned for a given application.

Recently, improvements to the DANN architecture have been published, and we recommend interested readers explore newer developments such as those involving Generative Adversarial Networks (GANs). Still, even in the latest work, the adversarial training process described here remains a key component of DA.