227-0391-00: Medical Image Analysis
Section 6
Interpretability, Uncertainty, and Data-Efficient Methods
Swiss Federal Institute of Technology Zurich
Eidgenössische Technische Hochschule Zürich
Last Edit Date: 07/09/2025
Disclaimer and Term of Use:
We do not guarantee the accuracy and completeness of the summary content. Some of the course material may not be included, and some of the content in the summary may not be correct. You should use this file properly and legally. We are not responsible for any results from using this file.
This personal note is adapted from Professor Ender Konukoglu, Professor Mauricio Reyes, and Doctor Ertunc Erdil. Please contact us to delete this file if you think your rights have been violated.
This work is licensed under a Creative Commons Attribution 4.0 International License.
Pixel-wise predictions with deep neural networks¶
Segmentation¶
Dataset¶
A training set needs high-quality images, and many of them. We also need a validation set and a test set.
How big a dataset do we need?
It depends on the problem, in particular on:
- How variable is the structure?
- How variable is the background?
- How variable is the intensity profile?
Example: for brain MRI with T1w MPRAGE images, 3-5 labeled volumes can be enough to train a good UNet for segmenting large anatomical structures, e.g., white matter, gray matter, hippocampus. Such numbers are difficult to generalize to other problems.
Ambiguity and uncertainty in segmentation labels is very common¶
Different people might create different labels.
Missing data in problems requiring multiple sequences is very common¶
We may need to impute missing data during training or testing.
Variations between similar images are very common¶
Even images acquired with similar sequences (e.g., T1w MPRAGE) can look different: small differences in scanners and acquisition protocols (sequence parameters) can change the contrast between tissues considerably.
Such variations can occur within the training data, and also between the training data and the images seen at inference.
Cost function¶
Basic classification loss (binary cross entropy)¶
Pixel-wise extension of the binary classification loss. The basic binary classification loss for a single sample is
$$BCE(y, f(x; \theta)) = -\mathbb{1}(y=1) \log f(x; \theta) - \mathbb{1}(y = 0) \log(1 - f(x; \theta))$$
assuming $f(x; \theta)$ represents the probability of being of class $y = 1$ predicted by the network.
To extend the loss to all pixels, we define it at each pixel and sum over pixels
$$L(y, f(x; \theta)) = \sum_{r=1}^{D} BCE(y(r), f(x; \theta)(r))$$
where $y(r)$ and $f(x; \theta)(r)$ are the ground truth labels and predictions at pixel $r$, respectively.
Note that the network output $f(x; \theta)$ is assumed to be an image.
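A minimal NumPy sketch of the pixel-wise BCE above, assuming the prediction is already a probability map (e.g., after a sigmoid) with the same shape as the binary ground-truth mask. The clipping constant `eps` is an implementation choice to keep $\log$ finite, not part of the definition.

```python
import numpy as np


def pixelwise_bce(y, p, eps=1e-7):
    """Sum over pixels r of BCE(y(r), f(x; theta)(r)).

    y: binary ground-truth mask (values in {0, 1})
    p: predicted foreground probabilities, same shape as y
    """
    p = np.clip(p, eps, 1.0 - eps)  # keep log() finite
    per_pixel = -(y * np.log(p) + (1 - y) * np.log(1 - p))
    return float(per_pixel.sum())


y = np.array([[1, 0], [0, 0]])
p = np.array([[0.9, 0.2], [0.1, 0.1]])
print(pixelwise_bce(y, p))  # = -3*log(0.9) - log(0.8)
```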
Multi-class case (cross entropy)¶
Just like classification problems, segmentation problems often have multiple output labels. We then extend the multi-class classification loss
$$CE(y, f(x; \theta)) = -\sum_{k} p(y = k) \log f_k(x; \theta)$$
Here we use a different notation:
The network assigns a probability $f_k(x; \theta)$ to each output class $k$, indicating the probability that the given sample $x$ belongs to class $k$.
If the ground truth has no uncertainty, as in most applications, then $p(y = k)$ is 1 for one class and 0 for all others (a one-hot label).
To extend the loss to all pixels, we again define it at each pixel and sum over pixels
$$L(y, f(x; \theta)) = \sum_{r=1}^{D} CE(y(r), f(x; \theta)(r))$$
Loss over a set of samples is defined similarly for both cases: $\mathcal{L}(\mathcal{D}; \theta) = \sum_{n=1}^{N} L(y_n, f(x_n; \theta))$.
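A sketch of the multi-class version, assuming a channel-first layout `(K, H, W)` for both the target probabilities $p(y(r) = k)$ and the predictions $f_k(x; \theta)(r)$; the `one_hot` helper is a hypothetical convenience for the no-uncertainty case.

```python
import numpy as np


def one_hot(labels, num_classes):
    """(H, W) integer labels -> (K, H, W) one-hot targets p(y(r) = k)."""
    return np.moveaxis(np.eye(num_classes)[labels], -1, 0)


def pixelwise_ce(p_true, p_pred, eps=1e-7):
    """Sum over pixels of CE(y(r), f(x; theta)(r)); both inputs have shape (K, H, W)."""
    per_pixel = -(p_true * np.log(np.clip(p_pred, eps, 1.0))).sum(axis=0)
    return float(per_pixel.sum())


def dataset_loss(labels_list, preds_list, num_classes):
    """Loss over a set of samples: sum of the per-image losses."""
    return sum(pixelwise_ce(one_hot(lab, num_classes), pred)
               for lab, pred in zip(labels_list, preds_list))


labels = np.array([[0, 2]])                     # one image, 1 x 2 pixels
preds = np.array([[[0.7, 0.1]],                 # class 0
                  [[0.2, 0.3]],                 # class 1
                  [[0.1, 0.6]]])                # class 2
print(pixelwise_ce(one_hot(labels, 3), preds))  # = -log(0.7) - log(0.6)
```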
Challenge with the (binary) cross entropy¶
When there are small structures, either compared to the background or to other classes (in the multi-class problem), the losses associated with small structures are dominated by those associated with larger structures.
$$ \begin{align} L(y, f(x; \theta)) &= \sum_{r = 1}^{D}CE(y(r), f(x; \theta)(r))\\ &= \sum_{r \in \Omega_{small}} CE(y(r), f(x; \theta)(r)) + \sum_{r \in \bar{\Omega_{small}}} CE(y(r), f(x; \theta)(r)) \\ &= L_{\Omega_{small}} + L_{\bar{\Omega_{small}}} \end{align} $$
It is very likely that $L_{\Omega_{small}} \ll L_{\bar{\Omega_{small}}}$ because $|\Omega_{small}| \ll |\bar{\Omega_{small}}|$.
Alternative: Sorensen-Dice Coefficient¶
$$DSC = \frac{2 |A \cap B|}{|A| + |B|}$$
where $A$ and $B$ are masks.
Used for evaluating quality of segmentation predictions
DSC = 1 when the sets perfectly overlap
DSC = 0 when there is no overlap
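Computed on binary masks, the DSC is not differentiable and cannot be used directly as a training loss. A soft, differentiable version replaces the masks with the ground-truth labels and the predicted probabilities.
For binary segmentation
$$DSC_{soft}(y, f(x; \theta)) = \frac{2 \sum_{r = 1}^{D} y(r) f(x; \theta)(r)}{\sum_{r=1}^{D}y(r)^2 + \sum_{r=1}^{D}f(x; \theta)(r)^2}$$
For multi-class segmentation with $K$ classes
$$DSC_{soft}(y, f(x; \theta)) = \frac{1}{K}\sum_{k = 1}^{K} \frac{2 \sum_{r = 1}^{D} p(y(r) = k) f_k(x; \theta)(r)}{\sum_{r=1}^{D}p(y(r) = k)^2 + \sum_{r=1}^{D}f_k(x; \theta)(r)^2}$$
The soft DSC should be maximized, so the training loss is $L(y, f(x; \theta)) = 1 - DSC_{soft}(y, f(x; \theta))$.
Unlike cross entropy, each class contributes a term of the same scale (between 0 and 1) regardless of its size.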
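To make the class-imbalance argument concrete, the sketch below compares the mean BCE with the soft Dice loss for a hypothetical $4 \times 4$ structure in a $64 \times 64$ image that the prediction misses entirely. The small `eps` in the denominator is an added safeguard against empty masks, not part of the formula.

```python
import numpy as np


def soft_dice(y, p, eps=1e-7):
    """Soft DSC between a binary mask y and probabilities p of the same shape."""
    return 2.0 * np.sum(y * p) / (np.sum(y ** 2) + np.sum(p ** 2) + eps)


def dice_loss(y, p):
    return 1.0 - soft_dice(y, p)


def multiclass_dice_loss(p_true, p_pred):
    """p_true, p_pred: (K, H, W); 1 minus the mean of the per-class soft DSCs."""
    return 1.0 - np.mean([soft_dice(t, q) for t, q in zip(p_true, p_pred)])


def mean_bce(y, p):
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))


# A 4x4 structure in a 64x64 image, and a prediction that misses it entirely
y = np.zeros((64, 64))
y[30:34, 30:34] = 1.0
p = np.full_like(y, 0.01)
print(mean_bce(y, p))   # small: dominated by the correctly predicted background
print(dice_loss(y, p))  # close to 1: the missed structure dominates
```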
Architecture¶
Binary segmentation¶
Given the following deep neural network, where $\rightarrow$ represents a convolutional link, $\Rightarrow$ represents a fully connected link, $I$ is the input image, $L\#$ are the intermediate layers, and $O \in [0, 1]$ is the output.
$$I \rightarrow L1 \rightarrow L2 \rightarrow L3 \rightarrow L4 \Rightarrow L5 \Rightarrow O$$
Keep in mind that
We would like an output of image size (output size equal to input size)
The spatial size of the feature maps decreases from $L1$ to $L4$ because of pooling
$$I \rightarrow L1 \rightarrow L2 \rightarrow L3 \rightarrow L4 \rightarrow L5 \stackrel{\uparrow s}{\rightarrow} O$$
We upsample within the network and convert an intermediate layer $L5$ to the required image size. The upsampling factor $s$ depends on the spatial size of the feature maps at $L5$.
We can use convolution layers that keep the spatial size of the incoming feature maps the same, e.g., padded convolutions without pooling. Moreover, the last layer needs a sigmoid activation to make sure the output is in $[0, 1]$.
$$I \rightarrow 32 \rightarrow 128 \rightarrow 256 \stackrel{\uparrow s}{\rightarrow} 128 {\color{red}~\rightarrow~} 32 {\color{blue}~\rightarrow~} O$$
$\rightarrow$: 1-padding with $3 \times 3$ convolutions, ReLU activations and pooling
${\color{red}\rightarrow}$: 1-padding with $3 \times 3$ convolutions, ReLU activations
${\color{blue}\rightarrow}$: 1-padding with $3 \times 3$ convolutions, Sigmoid activations
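The shape bookkeeping behind this diagram can be checked with a few lines of Python. The sketch assumes each black arrow is one $3 \times 3$ convolution with 1-padding and stride 1 followed by $2 \times 2$ pooling with stride 2; the pooling size is an assumption, since the diagram does not state it.

```python
def conv_out(n, k=3, pad=1, stride=1):
    """Spatial size after a convolution along one axis."""
    return (n + 2 * pad - k) // stride + 1


def trace_sizes(n, channels=(32, 128, 256), pool=2):
    """Follow the spatial size through I -> 32 -> 128 -> 256 -(up s)-> 128 -> 32 -> O."""
    assert n % pool ** len(channels) == 0, "input size must be divisible by the total pooling"
    size = n
    for _ in channels:                 # black arrows: conv (size kept) + pooling (size / pool)
        size = conv_out(size) // pool
    s = n // size                      # upsampling factor needed to return to the input size
    size *= s
    size = conv_out(conv_out(size))    # red and blue arrows: padded convs keep the size
    return s, size


print(trace_sizes(256))  # (8, 256): three 2x2 poolings require s = 8
```

Channel counts do not affect the spatial size, which is why the loop only uses the number of pooled blocks.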
Multi-class segmentation¶
$$I \rightarrow L1 \rightarrow L2 \rightarrow L3 \rightarrow L4 \stackrel{\uparrow s}{\rightarrow} L5 {\color{blue}~\rightarrow~} O$$
The output $O$ has multiple channels in this case. Each channel is an image indicating probabilities of each pixel belonging to one class.
Softmax type non-linear activation
$$f_k(x; \theta)(r) = O_k(r) = \frac{\exp(a_k(r))}{\sum_{j} \exp(a_j(r))}$$
where $a_j(r)$ is the activation in channel $j$ at pixel $r$.
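A sketch of the pixel-wise softmax over the channel axis, assuming activations are stored as `(K, H, W)`. Subtracting the per-pixel maximum is a standard numerical-stability trick that leaves the result unchanged.

```python
import numpy as np


def pixelwise_softmax(a):
    """a: activations a_j(r) with shape (K, H, W) -> probabilities O_k(r), same shape."""
    a = a - a.max(axis=0, keepdims=True)   # per-pixel shift; does not change the result
    e = np.exp(a)
    return e / e.sum(axis=0, keepdims=True)


a = np.random.default_rng(0).normal(scale=1000.0, size=(4, 8, 8))  # large activations
o = pixelwise_softmax(a)
print(np.allclose(o.sum(axis=0), 1.0))                     # True: a distribution at every pixel
print(np.array_equal(o.argmax(axis=0), a.argmax(axis=0)))  # True: argmax preserved
```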
Challenges with this idea¶
This simple idea provides a segmentation, but at reduced resolution. The problem is that by $L5$ so many pooling operations have been applied that details are lost.
The key question is how to use both contextual information and retain high-resolution information.
Alternative 1: Fully convolutional networks (FCN) with hierarchical¶
Predictions from different scales are combined to retain high-resolution details in the segmentation maps. Note that this method also upsamples at intermediate layers.
Alternative 2: UNet¶
$\rightarrow_{\downarrow 2}$: convolution (same), non-linear activation, pooling
$\rightarrow_{\uparrow 2}$: bilinear upsampling, convolution (same), non-linear activation OR
$\rightarrow_{\uparrow 2}$: transposed convolution, non-linear activation
${\color{red}\rightarrow}$: convolution (same), non-linear activation
${\color{blue}\rightarrow}$: convolution (same), sigmoid or softmax activation
Remark
Skip connections pass high-resolution features from the encoding path directly to the decoding path, which helps retain details.
Some other alternatives¶
DeepMedic
3D extensions are available
Transformer blocks within UNet structures have been very successful
Integrating transformer blocks in the encoding path
Architecture design is an active research area
Restoration¶
Dataset¶
The dataset issues are similar to those in segmentation:
Missing data in multimodal cases
Images may be coming from different "domains," e.g., centers, scanners, sequences
Ground-truth labels may not be available for all problems
- It may not be possible to acquire ground truth images because
- Difficult to acquire such images due to physical limitations on the acquisition system
- Unethical to acquire additional ground truth images
- Using synthetic data is a solution
Cost function¶
Basic cost functions¶
Mean absolute error (MAE)
$$MAE = \frac{1}{D}\sum_{r=1}^{D} |y(r) - f(x; \theta)(r)|$$
Mean squared error (MSE)
$$MSE = \frac{1}{D}\sum_{r=1}^{D} \| y(r) - f(x; \theta)(r) \|_2^2$$
Normalized mean squared error (NMSE)
$$NMSE = \frac{\sum_{r}\| y(r) - f(x; \theta)(r) \|_2^2}{\sum_{r}\| y(r) \|_2^2}~~~~\text{or}~~~~NMSE=\frac{MSE}{(\max(y) - \min(y))^2}$$
Peak signal to noise ratio (PSNR)
$$PSNR = 10 \log_{10} \frac{\max(y)^2}{MSE}$$
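A direct NumPy transcription of these metrics, as a sketch. It follows the formulas above and uses $\max(y)$ as the peak value in PSNR; some implementations use the intensity range of the data instead.

```python
import numpy as np


def mae(y, f):
    return float(np.mean(np.abs(y - f)))


def mse(y, f):
    return float(np.mean((y - f) ** 2))


def nmse(y, f):
    """Error energy normalised by the energy of the target."""
    return float(np.sum((y - f) ** 2) / np.sum(y ** 2))


def nmse_range(y, f):
    """MSE normalised by the squared intensity range of the target."""
    return mse(y, f) / float(y.max() - y.min()) ** 2


def psnr(y, f):
    return 10.0 * np.log10(float(y.max()) ** 2 / mse(y, f))


y = np.array([0.0, 1.0, 2.0, 3.0])
f = y + 0.5
print(mae(y, f), mse(y, f), nmse(y, f), psnr(y, f))
```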
Structural similarity index measure (SSIM)
$$SSIM(a, b) = \frac{(2 \mu_a \mu_b + c_1)(2 \sigma_{ab} + c_2)}{(\mu_a^2 + \mu_b^2 + c_1)(\sigma_a^2 + \sigma_b^2 + c_2)}$$
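Here $\mu_a, \mu_b$ are the patch means, $\sigma_a^2, \sigma_b^2$ the patch variances, $\sigma_{ab}$ the covariance, and $c_1, c_2$ small constants that stabilize the division. SSIM is computed over many corresponding patches $a$ and $b$ from $y$ and $f(x; \theta)$ and averaged. It combines local luminance, contrast and structure comparisons.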
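A simplified SSIM sketch following the formula above. Assumptions: square uniform windows with a hypothetical size and stride (the original SSIM uses an $11 \times 11$ Gaussian window), and $c_1 = (0.01 L)^2$, $c_2 = (0.03 L)^2$ with $L$ the data range, which are the commonly used defaults.

```python
import numpy as np


def ssim_patch(a, b, data_range=1.0, k1=0.01, k2=0.03):
    """SSIM(a, b) for one pair of patches, as in the formula above."""
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    mu_a, mu_b = a.mean(), b.mean()
    var_a, var_b = a.var(), b.var()
    cov_ab = np.mean((a - mu_a) * (b - mu_b))
    num = (2 * mu_a * mu_b + c1) * (2 * cov_ab + c2)
    den = (mu_a ** 2 + mu_b ** 2 + c1) * (var_a + var_b + c2)
    return float(num / den)


def ssim(y, f, win=8, step=4, data_range=1.0):
    """Average of ssim_patch over square windows sliding over both images."""
    h, w = y.shape
    values = [ssim_patch(y[i:i + win, j:j + win], f[i:i + win, j:j + win], data_range)
              for i in range(0, h - win + 1, step)
              for j in range(0, w - win + 1, step)]
    return float(np.mean(values))


rng = np.random.default_rng(0)
y = rng.random((32, 32))
noisy = np.clip(y + 0.1 * rng.standard_normal(y.shape), 0.0, 1.0)
print(ssim(y, y), ssim(y, noisy))
```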
Advanced cost function 1: Perceptual loss¶
While it is unclear what it means to capture "perceptual differences," one can use neural networks to do this.
$L1$, $L2$, $L3$, and $L4$ are layers of a previously trained CNN
Perceptual loss is defined as
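$$\sum_{j} MSE(\phi_j^y, \phi_{j}^f)$$
where $\phi_j^y$ and $\phi_j^f$ are the feature maps at layer $j$ of the pretrained CNN for the inputs $y$ and $f(x; \theta)$, respectively.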
The CNN can be a well-established network trained on natural images, e.g., VGG, or a network trained on CT images for another task
The CNN can be much deeper
Distances between "deep features" measure differences in contextual information, going beyond simple pixel-wise differences
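A sketch of the structure of the perceptual loss. To stay self-contained, a small fixed random network (linear layers with ReLU on the flattened image) stands in for the pretrained CNN; this stand-in is purely hypothetical, and in practice $\phi_j$ would be feature maps of, e.g., VGG.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical frozen "pretrained" network for 8x8 images: 64 -> 32 -> 16 -> 8 features.
WEIGHTS = [rng.standard_normal((64, 32)),
           rng.standard_normal((32, 16)),
           rng.standard_normal((16, 8))]


def features(img):
    """Return the list of layer activations phi_j for one image."""
    h = img.reshape(-1)
    phis = []
    for w in WEIGHTS:
        h = np.maximum(h @ w, 0.0)   # linear layer + ReLU
        phis.append(h)
    return phis


def perceptual_loss(y, f):
    """sum_j MSE(phi_j^y, phi_j^f); the network weights are never updated."""
    return float(sum(np.mean((a - b) ** 2) for a, b in zip(features(y), features(f))))


y = rng.random((8, 8))
print(perceptual_loss(y, y))                                      # 0.0
print(perceptual_loss(y, y + 0.1 * rng.standard_normal((8, 8))))  # > 0
```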
Advanced cost function 2: Adversarial loss¶
Another way to capture perceptual differences is adversarial losses.
Distributional distance - not between two samples but over sets of samples
Discriminator identifies whether the input image is real or generated by the "generator" network $f(\cdot ; \theta)$
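The discriminator $D(\cdot; \psi)$ is trained to maximize the objective below, while the generator network $f(\cdot; \theta)$ is trained to minimize it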
$$\min_{\theta} \max_{\psi} \mathbb{E}_{y\sim p(y)} [\log D(y, \psi)] + \mathbb{E}_{x \sim p(x)} [\log(1 - D(f(x; \theta); \psi))]$$
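The objective can be estimated from a mini-batch of discriminator outputs. A minimal sketch, assuming $D$ outputs probabilities in $(0, 1)$; the clipping is an implementation choice. When the discriminator cannot tell real from generated images and outputs $1/2$ everywhere, the value is $-\log 4$.

```python
import numpy as np


def gan_value(d_real, d_fake, eps=1e-7):
    """Monte Carlo estimate of E_y[log D(y)] + E_x[log(1 - D(f(x)))].

    d_real: discriminator outputs on real images y
    d_fake: discriminator outputs on generated images f(x; theta)
    The discriminator ascends this value, the generator descends it.
    """
    d_real = np.clip(d_real, eps, 1 - eps)
    d_fake = np.clip(d_fake, eps, 1 - eps)
    return float(np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake)))


print(gan_value(np.full(16, 0.5), np.full(16, 0.5)))    # -log(4), about -1.386
print(gan_value(np.full(16, 0.99), np.full(16, 0.01)))  # near 0: discriminator wins
```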
Residual architectures¶
Restorative information is added as a residual to the original or a naively restored image. For example, in super-resolution the network predicts residuals that are added to a linearly upsampled image to restore high-resolution details. This design is well established and used in other restoration problems as well.
Synthesis¶
Synthesizing a target image from a source image. Applications include:
Data imputation
Reducing need for irradiation
Reducing need for additional imaging
Dataset¶
Similar issues
Ambiguity in labeling - source image may not uniquely identify a target
Missing data in multimodal cases
Images may come from different "domains", e.g., centers, scanners, sequences
Labels and features may not be paired
Cost function¶
Similar cost functions as restoration
MAE, MSE, NMSE, PSNR, SSIM, Perceptual distance, Adversarial loss
Distributional losses, e.g., Adversarial loss, are particularly useful for unpaired datasets
Architecture¶
Basic architecture¶
The UNet architecture described above is often used for the basic (paired) problem.
Architecture for unpaired problems¶
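The total loss combines adversarial terms and a "cycle-consistency" loss, as in CycleGAN. Here $f(\cdot; \theta)$ maps source to target images, $g(\cdot; \psi)$ maps target to source images, and $D_y$, $D_x$ are discriminators in the target and source domains.
$$\underbrace{L_{GAN}(f, D_y, x, y) + L_{GAN}(g, D_x, y, x)}_{\text{Adversarial}} + \underbrace{\| x - g(f(x; \theta); \psi) \|_1 + \| y - f(g(y; \psi); \theta) \|_1}_{\text{cycle consistency}}$$
Each adversarial term is defined as in restoration, e.g., for $f$ with discriminator parameters $\phi_y$
$$L_{GAN}(f, D_y, x, y) = \mathbb{E}_{y \sim p(y)}[\log D_y(y; \phi_y)] + \mathbb{E}_{x \sim p(x)}[\log(1 - D_y(f(x; \theta); \phi_y))]$$
which is minimized with respect to $\theta$ and maximized with respect to $\phi_y$; $L_{GAN}(g, D_x, y, x)$ is defined analogously with the roles of $x$ and $y$ swapped.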
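The cycle-consistency term only needs the two mappings. A sketch with hypothetical toy intensity mappings standing in for the generators $f$ and $g$: when $g$ inverts $f$ the term vanishes, otherwise it is positive, even though $x$ and $y$ are unpaired.

```python
import numpy as np


def cycle_consistency(x, y, f, g):
    """||x - g(f(x))||_1 + ||y - f(g(y))||_1, summed over pixels.
    f: source -> target generator, g: target -> source generator."""
    return float(np.abs(x - g(f(x))).sum() + np.abs(y - f(g(y))).sum())


def f(img):          # hypothetical source -> target mapping
    return 2.0 * img + 1.0


def g(img):          # its exact inverse
    return (img - 1.0) / 2.0


def g_bad(img):      # a mapping that does not invert f
    return img / 2.0


rng = np.random.default_rng(0)
x, y = rng.random((16, 16)), rng.random((16, 16))  # unpaired source and target images
print(cycle_consistency(x, y, f, g))      # ~0 up to floating-point error
print(cycle_consistency(x, y, f, g_bad))  # > 0
```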
Training-free models and learning with few labeled examples¶
Problem setup¶
Few labeled examples
- Obtaining a few labeled images is often possible
- For certain problems unpaired datasets can be useful, e.g., CT synthesis from MRI
- Being more accurate with fewer labeled images is always better
Many unlabeled examples:
- Unlabeled examples are simply images without labels
- For example, for the segmentation problem, unlabeled samples would correspond to images without any ground truth segmentations associated to them
- Unlabeled examples are often available in large numbers
- There are problems where even unlabeled examples are not available in large numbers
Training-free models¶
Principle¶
- Training-free models do not require any training set.
- The model is not trained with examples.
- The model $f(x; \theta)$ is fit to the observed data.
- Fitting means minimizing a loss defined purely using the observed data.
- Examples: segmentation with Expectation-Maximization (e.g., fitting a Gaussian mixture model to the intensities) and K-means segmentation.
- Here we use the same strategy when $f(x; \theta)$ is a network.
- Such models are mostly available for restoration problems: denoising, super-resolution, reconstruction, in-painting, etc.
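As an illustration of a training-free model, a minimal K-means intensity segmentation: the cluster centres are fit to the single observed image, with no training set. The quantile initialisation and the synthetic three-tissue image are assumptions for the sketch.

```python
import numpy as np


def kmeans_segment(img, k=3, iters=100):
    """Cluster the intensities of one image into k classes.
    With sorted initial centres, labels stay ordered by intensity (1-D case)."""
    x = img.reshape(-1).astype(float)
    centres = np.quantile(x, np.linspace(0.1, 0.9, k))   # deterministic initialisation

    def assign(c):
        return np.argmin(np.abs(x[:, None] - c[None, :]), axis=1)

    for _ in range(iters):
        labels = assign(centres)
        new = np.array([x[labels == j].mean() if np.any(labels == j) else centres[j]
                        for j in range(k)])
        if np.allclose(new, centres):
            break
        centres = new
    return assign(centres).reshape(img.shape), centres


rng = np.random.default_rng(0)
img = np.zeros((30, 30))
img[:, 10:20] = 0.5                             # three "tissues": 0, 0.5 and 1
img[:, 20:] = 1.0
img += 0.03 * rng.standard_normal(img.shape)    # noise
labels, centres = kmeans_segment(img)
print(np.round(centres, 2))
```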
Deep image prior¶
- $x$: corrupted image
- $y$: clean image
- Normal learning procedure with the $L_2$ loss and labeled samples: $$\theta^* = \arg \min_{\theta} \sum_{n=1}^{N} \| y_n - f(x_n; \theta) \|_2^2, y \approx f(x; \theta^*)$$
- DIP fitting procedure with the $L_2$ loss using the "test" sample $x$: $$\theta^* = \arg \min_{\theta} \| x - f(z; \theta) \|_2^2, y \approx f(z; \theta^*)$$ where $z$ is a random image of the same size as $x$.
- No clean image is used, no training image is used, only fitting to the data.
What happens?
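- If optimized long enough, the model fits the input perfectly, including the corruption.
- Before doing so, it first reconstructs a clean image.
- It is more difficult for the network to fit noise and corruption than natural image structure.
- This is due to the inductive bias of the network architecture; in practice, the optimization is stopped early, before the corruption is fitted.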
UNet architecture is used in the original article.
- However, skip connections are often omitted depending on the application, since results may be worse when skip connections are added.
Applications for reconstruction¶
Radon transform $$Ax(s, \phi) = \int x(L_{s, \phi}(t)) dt$$ where $L_{s, \phi}(t)$ parametrizes the line at distance $s$ from the origin and angle $\phi$
Noisy observations $$y = Ax + \tau$$
DIP integration follows as $$\min_{\theta} \| Af(z; \theta) - y \|^2$$
Combination with known regularizers, e.g., total variation $\| \nabla f(z; \theta) \|_1$ $$\min_{\theta} \| Af(z; \theta) - y \|^2 + \lambda R(f(z; \theta))$$
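For reference, the regularized DIP objective can be evaluated as below. Assumptions: the forward operator $A$ is represented as a matrix acting on the flattened image (a stand-in for a discretized Radon transform, which is not implemented here), and total variation uses forward differences (anisotropic TV). In DIP, `img` would be the network output $f(z; \theta)$ and this value would be minimized over $\theta$.

```python
import numpy as np


def total_variation(img):
    """Anisotropic TV ||grad img||_1 with forward differences."""
    return float(np.abs(np.diff(img, axis=0)).sum() + np.abs(np.diff(img, axis=1)).sum())


def recon_objective(img, A, y, lam):
    """||A img - y||^2 + lam * TV(img) for a linear forward operator A (matrix)."""
    residual = A @ img.reshape(-1) - y
    return float(residual @ residual) + lam * total_variation(img)


rng = np.random.default_rng(0)
A = rng.standard_normal((20, 16))        # hypothetical forward operator for 4x4 images
img = np.zeros((4, 4))
img[:, 2:] = 1.0                         # one vertical edge
y = A @ img.reshape(-1)                  # noise-free measurements of img
print(total_variation(img))              # 4.0: one unit jump in each of the 4 rows
print(recon_objective(img, A, y, lam=0.1))
```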
Advantages and disadvantages¶
- It is quite impressive how well this algorithm works.
- The fact that it does not require any training sample is great.
- Since it is optimized per image, there is also no domain adaptation problem.
- Results are not as good as those we would get with supervised learning.
- For denoising it works well, but results are worse if the corruption is of a low-frequency nature.
- Extensions to other problems exist, e.g., segmentation, but the inductive bias seems to be best leveraged for predicting image intensities.
Learning from unlabeled examples¶
Labeled examples are expensive to acquire for various reasons. On the other hand, unlabeled examples are easier to collect. The question is whether we can use unlabeled examples for learning the parameters of a neural network.
Noise2noise: learning to denoise without clean target data¶
Learn a restoration model without any ground-truth labels, using only corrupted samples.
The usual way of training deep learning models uses clean images $y_n$
$$\arg \min_{\theta} \sum_{n} L(y_n, f(x_n; \theta))$$
Assume we do not have a clean image but a distribution of noisy images for the same sample $x$, i.e., $p(\hat{y})$. Treating the output $f(x; \theta)$ as a free variable, the minimizer of $\mathbb{E}_{\hat{y}}[\| f(x; \theta) - \hat{y} \|^2]$ is $$f(x; \theta) = \mathbb{E}_{\hat{y}}[\hat{y}]$$
If we now also average over the $x$ samples, in the limit computing the expectation with respect to $p(x)$, and minimize with respect to the parameters of the network
$$\arg \min_{\theta} \mathbb{E}_x \left\{ \mathbb{E}_{\hat{y} | x} \{\| \hat{y} - f(x; \theta) \|^2\} \right\}$$
This is minimizing the distance between the output of the network and a noisy sample from $p(\hat{y} | x)$ and averages over $\hat{y}$ and over $x$. It does not use a clean image.
If $\mathbb{E}\{\hat{y} | x\} = y$, i.e., the average of noisy samples is the clean sample, given infinite data the following minimizations should give the same solution $$\arg \min_{\theta} \mathbb{E}_x \left\{ \mathbb{E}_{\hat{y} | x} \{\| \hat{y} - f(x; \theta) \|^2\} \right\}$$ $$\arg \min_{\theta} \mathbb{E}_x \left\{ \| y - f(x; \theta) \|^2 \right\}$$
Simple cost function¶
Instead of minimizing
$$\arg \min_{\theta} \sum_{n} \| y_n - f(x_n; \theta) \|^2 $$
One can approximate it with
$$\arg \min_{\theta} \sum_{n} \sum_{m} \| y^m_n - f(x_n; \theta) \|^2$$
where $y_n^m$ are noisy versions of an ideal $y_n$.
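A small numerical check of the noise2noise argument, with a linear least-squares denoiser standing in for the network (an assumption to keep the sketch tiny). Signals lie in a low-dimensional subspace, and inputs and targets carry independent zero-mean noise, so the expected noisy target given $x$ equals the expected clean target given $x$; with large $N$ the two fits should nearly coincide.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, sigma = 20000, 8, 0.5
basis = rng.standard_normal((2, d))                  # clean signals lie in a 2-D subspace
y = rng.standard_normal((N, 2)) @ basis              # clean images (never used by noise2noise)
x = y + sigma * rng.standard_normal((N, d))          # noisy inputs
y_hat = y + sigma * rng.standard_normal((N, d))      # independent noisy targets, zero-mean noise

# Linear denoiser f(x; W) = x @ W, fitted by least squares (the L2 loss)
W_clean = np.linalg.lstsq(x, y, rcond=None)[0]       # supervised: clean targets
W_n2n = np.linalg.lstsq(x, y_hat, rcond=None)[0]     # noise2noise: noisy targets only

print(np.abs(W_clean - W_n2n).max())                 # small, shrinks as N grows
print(np.mean((x - y) ** 2), np.mean((x @ W_n2n - y) ** 2))  # noise before vs after denoising
```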
Self-training¶
Iterative algorithm¶
Train network with labeled examples
$$\theta_0^* = \arg \min_{\theta} \sum_{n} L(y_n, f(x_n; \theta))$$
Predict pseudo-labels for unlabeled images in the i-th iteration
$$\hat{y}_m = f(x_m; \theta_i^*)$$
Train network with labeled and unlabeled images
$$\theta_{i+1}^* = \arg \min_{\theta} \sum_{n} L(y_n, f(x_n; \theta)) + \lambda \sum_{m} L(\hat{y}_m, f(x_m; \theta))$$
where $\lambda$ is a weighting factor.
Iterate over steps 2 and 3 (pseudo-labeling and retraining).
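A toy version of the iteration, assuming a nearest-centroid classifier whose parameters $\theta$ are the class centroids (so the loss is the squared distance to the assigned centroid) and two well-separated Gaussian classes. The synthetic data, the five iterations and $\lambda = 0.5$ are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample(labels):
    """Two Gaussian classes centred at (-3, -3) and (3, 3)."""
    centres = np.where(labels[:, None] == 1, 3.0, -3.0)
    return centres + rng.standard_normal((len(labels), 2))


def fit(x, y, w):
    """theta = weighted class centroids (minimises the weighted squared-distance loss)."""
    return np.stack([np.average(x[y == k], axis=0, weights=w[y == k]) for k in (0, 1)])


def predict(theta, x):
    return np.argmin(((x[:, None, :] - theta[None, :, :]) ** 2).sum(axis=-1), axis=1)


y_lab = np.array([0, 0, 0, 1, 1, 1])           # few labeled examples
x_lab = sample(y_lab)
y_unl = rng.integers(0, 2, 500)                # used only to evaluate at the end
x_unl = sample(y_unl)
lam = 0.5

theta = fit(x_lab, y_lab, np.ones(len(y_lab)))           # step 1: labeled data only
for i in range(5):
    y_pseudo = predict(theta, x_unl)                      # step 2: pseudo-labels
    x_all = np.concatenate([x_lab, x_unl])
    y_all = np.concatenate([y_lab, y_pseudo])
    w_all = np.concatenate([np.ones(len(y_lab)), lam * np.ones(len(y_pseudo))])
    theta = fit(x_all, y_all, w_all)                      # step 3: labeled + pseudo-labeled
accuracy = np.mean(predict(theta, x_unl) == y_unl)
print(accuracy)
```

Because the labeled examples stay in every refit, each class always keeps at least a few anchors, which is one simple guard against the divergence discussed next.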
Things to consider¶
If the initial model's estimates are reasonably good, self-training works well.
Self-training acts like entropy minimization, which assumes that most class assignments are correct.
If the initial model's estimates are not good, self-training quickly diverges.
Additional regularizations on the final estimates can prevent divergence.
Even with the regularization terms, the model can diverge quite often.
Teacher-student models¶
A new type of regularization: train two interacting networks for generating pseudo-labels, one evolving more slowly than the other. The student model is trained with the labels and to be consistent with the teacher model.
Student model
$$f_s(x; \theta)$$
Teacher model
$$f_t(x; \phi)$$
Cost for labeled examples
$$L(y, f_s(x;\theta))$$
Consistency loss
$$L_c(f_s(x;\theta), f_t(x; \phi))$$
This can be applied to both labeled and unlabeled examples. Total loss for the student network
$$\sum_{n} L(y_n, f_s(x_n; \theta)) + \lambda \left( \sum_{n} L_c(f_s(x_n; \theta), f_t(x_n; \phi)) + \sum_{m} L_c(f_s(x_m; \theta), f_t(x_m; \phi)) \right)$$
with the consistency loss on both labeled (subscript $n$) and unlabeled examples (subscript $m$).
Mean teacher learning¶
Teacher and student networks are trained in an alternating fashion. In the best-known mean teacher model, the teacher network is an exponential moving average of the student network. Alternating optimization steps:
For a given teacher network, train the student network
$$\theta^i = \arg \min_{\theta} \sum_{n} L(y_n, f_s(x_n; \theta)) + \lambda \left( \sum_{n} L_c(f_s(x_n; \theta), f_t(x_n; \phi^{i-1})) + \sum_{m} L_c(f_s(x_m; \theta), f_t(x_m; \phi^{i-1})) \right)$$
Update teacher network's parameters
$$\phi^i = \alpha \phi^{i - 1} + (1 - \alpha)\theta^i$$
Iterate between steps 1 and 2.
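The teacher update is an exponential moving average of the student's parameters. A sketch, assuming parameters are stored as a dict of NumPy arrays (a hypothetical layout); only the student is updated by gradients, the teacher only through this rule.

```python
import numpy as np


def ema_update(phi, theta, alpha=0.99):
    """phi^i = alpha * phi^{i-1} + (1 - alpha) * theta^i, for every parameter array."""
    return {name: alpha * phi[name] + (1.0 - alpha) * theta[name] for name in phi}


# Toy check: a student frozen at 1 pulls a teacher starting at 0 towards it.
student = {"w": np.ones(3)}
teacher = {"w": np.zeros(3)}
for step in range(10):
    teacher = ema_update(teacher, student, alpha=0.9)
print(teacher["w"])  # 1 - 0.9**10, about 0.651, in every entry
```

A larger $\alpha$ makes the teacher evolve more slowly, which is what keeps its pseudo-labels stable.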
Auxiliary losses: Semi-supervised¶
In the teacher-student networks, we had a total loss of
$$\underbrace{\sum_{n} L(y_n, f_s(x_n; \theta))}_{\text{supervised loss}} + \underbrace{\lambda \left( \sum_{n} L_c(f_s(x_n; \theta), f_t(x_n; \phi)) + \sum_{m} L_c(f_s(x_m; \theta), f_t(x_m; \phi)) \right)}_{\text{auxiliary loss}}$$
we can generalize this form and construct different auxiliary loss functions that we believe will be useful for the task at hand
Semi-supervised type - defined over unlabeled images during task-specific optimization
Self-supervised type - defined over unlabeled images independent of the task
Semi-supervised 1: Similar representations for similar classes¶
Recall self-training: this class of loss functions also aims for a compact representation of all the samples.
$$I \rightarrow L1 \rightarrow L2 \rightarrow L3 \rightarrow O$$
$$x \rightarrow h^1(x) \rightarrow h^2(x) \rightarrow h^3(x) \rightarrow f(x)$$
They have a task-specific loss for labeled examples, $L(y, f(x; \theta))$
They enforce representations of "close" samples to be close to each other as well
For segmentation problem, samples are pixels, since this is pixel-wise class prediction
There are two aspects we need to figure out
- What are "close" samples?
- How do we enforce their representations are close to each other?
The generic way to represent closeness is an affinity matrix
$$ \left[\mathbf{A}\right]_{ij} = \begin{cases} 1 & x_i \text{ and } x_j \text{ are close} \\ 0 & x_i \text{ and } x_j \text{ are not close} \end{cases} $$
This matrix is also used for describing graphs. There are different ways to define this matrix:
For labeled examples, if pixels have the same label they can be considered to be close.
For unlabeled examples, if pixels' surrounding image patches are similar, they can be considered close. This can be computed by template matching with normalized cross correlation.
Compact representations can be enforced by minimizing the distance between "close" samples and pushing samples that are "not close" at least a margin $\rho$ apart.
$$L^I_c(\{x\}, \mathbf{A}) = \sum_{i} \sum_{j} \begin{cases} d(h'(x_i), h'(x_j)) & [\mathbf{A}]_{ij} = 1\\ \max(0, \rho - d(h'(x_i), h'(x_j))) & [\mathbf{A}]_{ij} = 0 \\ \end{cases}$$
where $d(\cdot, \cdot)$ is a distance of choice, e.g., $L_2$ distance.
The goal of this compactness loss can be seen in two ways:
They replicate closeness in one space in the representation space. So that originally close samples, according to $\mathbf{A}$, will stay close and get the same label in the end.
The closeness in the original space acts as a surrogate label.
The final cost is simply the sum of labeled and compactness losses
$$\sum_{n} L(y_n, f(x_n; \theta)) + \lambda \sum_{I}L^I_c(\{x\}, \mathbf{A})$$
There are variations to this auxiliary loss function.
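A sketch of both ingredients: an affinity matrix from normalized cross correlation between patches, and the margin-based compactness loss with $d$ the $L_2$ distance. The NCC threshold of 0.9, the margin $\rho = 1$ and the tiny example patches and representations are hypothetical choices.

```python
import numpy as np


def ncc(p, q, eps=1e-8):
    """Normalized cross correlation between two patches of equal shape."""
    p = p - p.mean()
    q = q - q.mean()
    return float((p * q).sum() / (np.sqrt((p ** 2).sum() * (q ** 2).sum()) + eps))


def affinity(patches, threshold=0.9):
    """[A]_ij = 1 if the patches around pixels i and j are similar, else 0."""
    n = len(patches)
    return np.array([[int(ncc(patches[i], patches[j]) > threshold) for j in range(n)]
                     for i in range(n)])


def compactness_loss(h, A, rho=1.0):
    """h: (n, c) representations h'(x_i). Pull close pairs together, push others >= rho apart."""
    d = np.linalg.norm(h[:, None, :] - h[None, :, :], axis=-1)   # pairwise L2 distances
    return float(d[A == 1].sum() + np.maximum(0.0, rho - d[A == 0]).sum())


edge = np.array([[0.0, 1.0], [0.0, 1.0]])
patches = [edge, 2.0 * edge + 0.5, edge.T]   # same vertical edge twice, then a horizontal edge
A = affinity(patches)
print(A)                                     # expected rows: [1 1 0], [1 1 0], [0 0 1]
good = np.array([[0.0, 0.0], [0.0, 0.0], [5.0, 0.0]])
bad = np.zeros((3, 2))
print(compactness_loss(good, A), compactness_loss(bad, A))  # 0.0 4.0
```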
Semi-supervised 2: Consistency under transformations¶
In pixel-wise predictions, spatial transformations of the input should be matched at the output. If $\phi$ is a spatial transformation, e.g., an affine transformation or a non-linear deformation, then
$$x \Leftrightarrow y \Rightarrow \phi \circ x \Leftrightarrow \phi \circ y$$
This is the underlying principle of data augmentation.
Leverage this for constructing an auxiliary loss for unlabeled examples for segmentation.
Auxiliary loss to enforce consistency under transformation¶
For every image, two transformations $\phi_{1}$ and $\phi_{2}$ can be sampled. Applying them to the sample yields
$$\phi_1 \circ x, \phi_2 \circ x, \phi_1 \circ y, \phi_2 \circ y, f(\phi_1 \circ x; \theta), f(\phi_2 \circ x; \theta)$$
Consistency can be enforced by
$$L_{cons}(x, \phi_1, \phi_2; \theta) = L(\phi_2 \circ \phi_1^{-1} \circ f(\phi_1 \circ x; \theta), f(\phi_2 \circ x; \theta))$$
where $L$ is a usual loss for pixel-wise predictions.
Notice that the consistency loss does not require any labels. Combining it with the usual loss gives the final semi-supervised cost
$$\sum_{n}L(y_n, f(x_n; \theta)) + \sum_{m} \mathbb{E}_{\phi_1, \phi_2} \left[ L(\phi_2 \circ \phi_1^{-1} \circ f(\phi_1 \circ x_m; \theta), f(\phi_2 \circ x_m; \theta)) \right]$$
with the second sum going over both labeled and unlabeled examples.
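A sketch of the consistency loss using rotations by multiples of 90 degrees as the transformations $\phi$ (chosen because they are exactly invertible on the pixel grid; the text allows general affine or non-linear deformations) and mean squared error as $L$. A purely pixel-wise model such as a sigmoid commutes with these rotations, so its loss is zero; a hypothetical position-dependent model is penalized.

```python
import numpy as np


def rot(img, k):
    """phi: rotation by k * 90 degrees; rot(img, -k) is its inverse."""
    return np.rot90(img, k)


def consistency_loss(f, x, k1, k2):
    """L(phi2 o phi1^{-1} o f(phi1 o x), f(phi2 o x)) with mean squared error as L."""
    a = rot(rot(f(rot(x, k1)), -k1), k2)
    b = f(rot(x, k2))
    return float(np.mean((a - b) ** 2))


def expected_consistency(f, x, n_samples=8, seed=0):
    """Monte Carlo estimate of the expectation over random (phi1, phi2)."""
    rng = np.random.default_rng(seed)
    ks = rng.integers(0, 4, size=(n_samples, 2))
    return float(np.mean([consistency_loss(f, x, int(k1), int(k2)) for k1, k2 in ks]))


def sigmoid(img):                  # pixel-wise model: commutes with rotations
    return 1.0 / (1.0 + np.exp(-img))


def position_dependent(img):       # hypothetical model that weights columns differently
    return img * np.arange(img.shape[1])[None, :]


x = np.random.default_rng(1).standard_normal((6, 6))
print(expected_consistency(sigmoid, x))               # 0.0
print(consistency_loss(position_dependent, x, 0, 1))  # > 0
```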
Self-supervised learning¶
Main principle: train a network on a pretext task that only requires images. The trained network can then be fine-tuned for other tasks with very few labeled examples.
The crucial problem is to identify a pre-text task that will lead to useful network parameters.
The underlying assumption is that such tasks are good for learning "generalizable" parameters.
Contrastive learning¶
General idea: forcing compact representations with surrogate labels.
Surrogate labeling: random transformations of the same image are "close." We force them to be close in the representation space as well.
One cost function to this end¶
$$L(x, \{\hat{x}\}, \phi_1, \phi_2) = -\log \frac{\exp(\rho(h(\phi_1 \circ x), h(\phi_2 \circ x)) / \tau)}{\exp(\rho(h(\phi_1 \circ x), h(\phi_2 \circ x)) / \tau) + \sum_{n} \exp(\rho(h(\phi_1 \circ x), h(\hat{x}_n)) / \tau)}$$
similarity function is defined via cosine similarity
$$\rho(a, b) = \frac{a^T b}{\| a \| \| b \|}$$
and $\tau > 0$ is a temperature parameter. Transformations $\phi$ are randomly chosen and are not restricted to geometric transformations; they also include intensity changes. $\{\hat{x}\}$ is a set of samples that should not have the same representation as $x$, called the "negative set."
Minimizing the loss
Maximizes the cosine similarity in the numerator, between samples with the same content, i.e., $\phi_1 \circ x$ and $\phi_2 \circ x$.
Minimizes the cosine similarities in the sum of the denominator, between $\phi_1 \circ x$ and samples with different content, i.e., the negative set $\{\hat{x}\}$.
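A direct transcription of the contrastive loss for one anchor, assuming the representations $h(\cdot)$ have already been computed as vectors; the temperature $\tau = 0.1$ is an illustrative value. When every cosine similarity is zero, the loss equals $\log(1 + N)$ for $N$ negatives.

```python
import numpy as np


def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def contrastive_loss(h1, h2, negatives, tau=0.1):
    """h1 = h(phi1 o x), h2 = h(phi2 o x), negatives = [h(x_hat_n) for n]."""
    pos = np.exp(cosine(h1, h2) / tau)
    neg = sum(np.exp(cosine(h1, hn) / tau) for hn in negatives)
    return float(-np.log(pos / (pos + neg)))


e = np.eye(3)
print(contrastive_loss(e[0], e[1], [e[2], e[1]]))   # all similarities 0: log(3)
print(contrastive_loss(e[0], e[0], [e[1], e[2]]))   # aligned positive: much smaller
```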
Global contrastive learning¶
Training an encoder with the contrastive loss shown above.
By minimizing the contrastive loss at $h(x)$, the image-level representation at $L5$ is forced to be compact with respect to the surrogate labels and the corresponding notion of "closeness" defined through random transformations.
Local contrastive loss¶
In addition to the global compact representations, one can force local representations to be compact as well. This is particularly useful for pixel-wise predictions.
Training a decoder¶
Train a decoder attached to the fixed encoder from the global model.
$$L(x, \{\hat{x}\}, \phi_1, \phi_2) = -\sum_{i}\log \frac{\exp(\rho(P_i[h(\phi_1 \circ x)], P_i[h(\phi_2 \circ x)]) / \tau)}{\exp(\rho(P_i[h(\phi_1 \circ x)], P_i[h(\phi_2 \circ x)]) / \tau) + \sum_{n} \exp(\rho(P_i[h(\phi_1 \circ x)], P_i[h(\hat{x}_n)]) / \tau)}$$
In this case, $P_i[x]$ would correspond to the i-th patch in an image. However, we need to pay attention to the transformations.
Transformations should not change correspondence between patches, e.g., $P_i[h(\phi_1 \circ x)]$ and $P_i[h(\phi_2 \circ x)]$ should correspond to each other
We need to take into account the transformations induced by $\phi$ to compute correspondences, e.g., $P_i[\phi^{-1}_{1} \circ h(\phi_1 \circ x)]$ and $P_i[\phi^{-1}_{2} \circ h(\phi_2 \circ x)]$.
Foundation models¶
Working with fixed encoders and small task specific additions.
Fine-tuning the fixed encoder along with task specific component.
Efficient fine-tuning strategies.
Already demonstrated for segmentation, classification and image registration.
Uncertainty estimation in deep learning models¶
Problem setup¶
Predictions from machine learning models are often uncertain, and quantifying this uncertainty is crucial, especially in high-stakes fields like medical imaging.
In Classification¶
In medical diagnosis, a model might be uncertain about its prediction, and this uncertainty often correlates with incorrect predictions. For example, when classifying retinal OCT scans, a subtle sign of disease might lead the model to make a wrong prediction, but with high uncertainty, flagging the case for expert review.
Importance: Overly confident but wrong AI predictions can negatively sway the judgment of human experts. Studies have shown that clinicians are more influenced by high-confidence AI suggestions. If an AI is confidently wrong, it can mislead a clinician into making an error. Therefore, it is vital that models are not highly confident when they are incorrect.
In Segmentation¶
The boundaries of anatomical structures or lesions in medical images are often ambiguous due to image resolution or contrast limitations. This leads to variability even among expert human raters.
Importance: Uncertainty in segmentation translates directly to uncertainty in quantifying disease volume or load. This is critical for monitoring disease progression and evaluating treatment effectiveness, making it essential to account for this uncertainty.
In Detection¶
A key part of diagnosis is the detection of abnormalities. Models that can confidently identify both clearly normal and clearly abnormal regions can significantly improve clinical workflows. For subtle abnormalities that are difficult to discern, a model should produce a high uncertainty score, signaling that the case requires careful human inspection.
In Reconstruction¶
Image reconstruction from undersampled data (e.g., in fast MRI) is an ill-posed problem where missing information must be inferred. Multiple valid reconstructions can often be generated from the same undersampled data.
Importance: Uncertainty in the reconstruction process can affect subsequent morphological measurements. A trustworthy model should not generate confident predictions for parts of the image that were inferred and not present in the source data.
In General Model Fitting¶
The choice of model architecture and its parameters is a fundamental source of uncertainty. This uncertainty exists for predictions within the training data's range (interpolation) and is significantly amplified for predictions outside that range (extrapolation).
Mathematical treatment¶
A Bayesian approach provides a formal framework for understanding and modeling uncertainty.
Sources of Uncertainty¶
Uncertainty in a deep learning setup can arise from multiple sources:
- Inherent Ambiguity: The input features $x$ may not be sufficient to uniquely determine the label $y$.
- Parameter Uncertainty: The finite training data $\mathcal{D}$ may not be sufficient to uniquely determine the optimal model parameters $\theta^*$.
- Data Uncertainty: The training set $\mathcal{D}$ is just one random sample from a larger, unseen population of data.
- Model Uncertainty: The chosen model architecture $\mathcal{M}$ is likely not the only or "true" model for the problem.
The Posterior Distribution: The End-Goal¶
The ultimate goal of uncertainty estimation is to characterize the posterior distribution $p(y|x)$. This distribution captures all possible labels $y$ that could correspond to a given input $x$.
To derive this, we must consider all sources of uncertainty by defining a full joint probability distribution $p(y,x,\theta,\mathcal{D},\mathcal{M})$ and then marginalizing out the variables we are not interested in ($\mathcal{M}$, $\mathcal{D}$, and $\theta$). This is expressed by the integral:
$$ p(y|x)=\int_{\mathcal{M}}\int_{\mathcal{D}}\int_{\theta}p(y|x,\theta,\mathcal{M})\,dp(\theta|\mathcal{D},\mathcal{M})\,dp(\mathcal{D})\,dp(\mathcal{M}) $$
However, solving these integrals is highly challenging and generally intractable.
Approximations¶
Since exact calculation is infeasible, we rely on approximations.
Standard Training as an Approximation¶
Typical deep learning training can be seen as a very simple approximation. It involves:
- Fixing the model and data: Choosing a single model $M$ and using a single dataset $D$.
- Point estimate for parameters: Finding a single optimal parameter set $\theta^*$ by minimizing a loss function. This is equivalent to approximating the parameter posterior $p(\theta|\mathcal{D},\mathcal{M})$ with a Dirac delta function, $\delta(\theta-\theta^*)$, which assumes the posterior is a single, infinitely sharp peak.
- Deterministic prediction: Assuming the output is a single, deterministic function of the input, $f(x;\theta^*)$.
This process yields a single prediction with no associated uncertainty. To capture uncertainty, we need more sophisticated approximations.
Approximating the Components¶
We can approximate the different parts of the main integral:
Approximating $p(\mathcal{D})$ (Data Uncertainty): Since we typically have only one dataset $\mathcal{D}'$, we can either assume it's the only one that matters ($p(\mathcal{D})\approx\delta(\mathcal{D}-\mathcal{D}')$) or use resampling techniques like bootstrapping (sampling subsets from $\mathcal{D}'$ with replacement) to create multiple training sets. This is the basis for deep ensembling.
Approximating $p(\mathcal{M})$ (Model Uncertainty): We can either fix a single model architecture ($p(\mathcal{M})\approx\delta(\mathcal{M}-\mathcal{M}')$) or approximate sampling by training a set of different models.
Approximating $p(\theta|\mathcal{D},\mathcal{M})$ (Parameter Uncertainty): The posterior over parameters is complex due to the non-convex nature of deep learning loss functions. Three main approximation strategies are:
- Dirac Approximation: The standard training approach of using a single point estimate.
- Approximate Sampling: Training the same model multiple times with different random initializations can explore different modes in the parameter posterior. This is another way to perform deep ensembling.
- Distributional Approximation: Approximate the posterior around a solution with a known distribution (e.g., a Gaussian). This is the core idea behind methods like Laplace approximation and variational inference (e.g., using dropout as a Bayesian approximation).
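Whichever approximation supplies the parameter samples (ensemble members, dropout masks, or draws from a Gaussian around $\theta^*$), the predictive distribution is then estimated the same way: average $p(y|x,\theta_t,\mathcal{M})$ over the $T$ samples, a Monte Carlo estimate of the inner integral. The sketch below is a minimal illustration under stated assumptions: the model is a hypothetical linear softmax classifier, and its "posterior samples" are simply Gaussian perturbations of one weight matrix. It also splits the predictive entropy into an expected-entropy part and a mutual-information part, a common decomposition that is an addition here, not part of the original notes.

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the last axis
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mc_predictive(x, weight_samples):
    """Monte Carlo estimate of p(y|x) = E_{theta ~ p(theta|D,M)}[p(y|x,theta,M)].

    weight_samples: list of (W, b) pairs for a toy linear softmax classifier;
    each pair stands in for one ensemble member or one dropout mask (assumption).
    """
    probs = np.stack([softmax(x @ W + b) for W, b in weight_samples])  # (T, N, C)
    mean = probs.mean(axis=0)                                          # (N, C)
    eps = 1e-12
    total = -(mean * np.log(mean + eps)).sum(axis=-1)                  # predictive entropy
    expected = -(probs * np.log(probs + eps)).sum(axis=-1).mean(axis=0)  # mean per-sample entropy
    return mean, total, total - expected  # last term: mutual information (epistemic part)

rng = np.random.default_rng(0)
W0, b0 = rng.normal(size=(4, 3)), np.zeros(3)
# Toy "posterior samples": the point estimate (W0, b0) plus Gaussian perturbations
samples = [(W0 + 0.5 * rng.normal(size=W0.shape), b0) for _ in range(20)]
x = rng.normal(size=(5, 4))
mean, entropy, mutual_info = mc_predictive(x, samples)
```

With a single sample (the Dirac approximation) the mutual-information term is exactly zero, which mirrors the statement that standard training yields no parameter uncertainty.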
Modeling $p(y|x, \theta, \mathcal{M})$ (Inherent/Aleatoric Uncertainty): This component models the inherent ambiguity in the data itself. Instead of a deterministic output, the model can be designed to:
- Predict a distribution: The network outputs the parameters of a probability distribution (e.g., the mean and variance of a Gaussian) from which the label $y$ is drawn.
- Generate samples: The model acts as a generator (e.g., a GAN), taking the input $x$ and a random noise vector $z$ to produce diverse yet realistic outputs.
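For the "predict a distribution" option, a common (assumed here) choice for regression is a network with two output maps, a mean and a log-variance, trained with the Gaussian negative log-likelihood. The sketch below only shows the loss; the fixed arrays stand in for hypothetical network outputs.

```python
import numpy as np

def gaussian_nll(y, mean, log_var):
    """Per-pixel negative log-likelihood of y under N(mean, exp(log_var)).

    Predicting the log-variance keeps the variance positive without constraints.
    """
    var = np.exp(log_var)
    return 0.5 * (np.log(2 * np.pi) + log_var + (y - mean) ** 2 / var)

y = np.array([1.0, 2.0, 3.0])
mean = np.array([1.0, 2.5, 0.0])  # toy predicted means; the last pixel is badly off
nll_confident = gaussian_nll(y, mean, log_var=np.zeros(3)).mean()
# Raising the predicted variance only on the badly predicted pixel lowers the loss
nll_hedged = gaussian_nll(y, mean, log_var=np.log(np.array([1.0, 1.0, 9.0]))).mean()
```

For a fixed residual $r$ the loss is minimized at variance $r^2$, so the network is rewarded for predicting high variance exactly where the labels are inherently ambiguous (e.g., fuzzy boundaries), which is the aleatoric uncertainty described above.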
Evaluation¶
Evaluating the quality of the estimated uncertainty is challenging. The two main methods are:
Comparison with Observed Samples: If a dataset with multiple ground-truth labels per input is available, we can measure the distance between the model's predicted distribution and the empirical distribution of the true labels using metrics like Generalized Energy Distance (GED) or Maximum Mean Discrepancy (MMD).
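As an illustration, the squared Generalized Energy Distance between model samples $S$ and annotations $Y$ is $D^2_{GED} = 2\,\mathbb{E}[d(S,Y)] - \mathbb{E}[d(S,S')] - \mathbb{E}[d(Y,Y')]$. The sketch below uses $d = 1 - \mathrm{IoU}$ for binary masks, the distance used with the Probabilistic U-Net; averaging over all pairs (including identical pairs) is a simplifying assumption.

```python
import numpy as np
from itertools import product

def iou_distance(a, b):
    """d(a, b) = 1 - IoU for binary masks; two empty masks count as identical."""
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return 1.0 - np.logical_and(a, b).sum() / union

def mean_pairwise(xs, ys):
    # Average distance over all (x, y) pairs
    return np.mean([iou_distance(a, b) for a, b in product(xs, ys)])

def generalized_energy_distance(samples, labels):
    """Squared GED: 2 E[d(S, Y)] - E[d(S, S')] - E[d(Y, Y')]."""
    return (2 * mean_pairwise(samples, labels)
            - mean_pairwise(samples, samples)
            - mean_pairwise(labels, labels))

a = np.zeros((4, 4), dtype=bool); a[:2] = True     # annotator 1
b = np.zeros((4, 4), dtype=bool); b[:, :2] = True  # annotator 2
ged_collapsed = generalized_energy_distance([a, a], [a, b])  # model ignores annotator 2 -> 1/3
ged_matching = generalized_energy_distance([a, b], [a, b])   # model reproduces both -> 0.0
```

The middle term rewards sample diversity, so a model that always outputs the same "average" segmentation is penalized even if each output overlaps well with one annotation.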
Calibration Error: In the more common case where only one label is available per input, we can assess if the model's confidence aligns with its accuracy. A well-calibrated model that is 80% confident in its predictions should be correct 80% of the time. The Expected Calibration Error (ECE) is a metric used to quantify this.
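A minimal ECE sketch, assuming the usual equal-width confidence bins: predictions are grouped by confidence, and the gaps between mean accuracy and mean confidence in each bin are averaged, weighted by bin size. The function name and bin count are illustrative choices.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)| over equal-width bins in (0, 1]."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap  # in_bin.mean() is the fraction of samples in the bin
    return ece

# 10 predictions at 80% confidence, 8 correct (calibrated);
# 10 predictions at 90% confidence, only 6 correct (overconfident)
conf = [0.8] * 10 + [0.9] * 10
hits = [1] * 8 + [0] * 2 + [1] * 6 + [0] * 4
ece = expected_calibration_error(conf, hits)  # ≈ 0.5 * 0.0 + 0.5 * 0.3 = 0.15
```

The confidences would typically be the maximum softmax probability of each prediction, which is always above zero, so the open lower edge of the first bin does not drop samples in practice.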
Examples¶
- Variational Dropout: Combines an approximation for parameter uncertainty ($p(\theta|\mathcal{D},\mathcal{M})$) with a model for inherent uncertainty ($p(y|x,\theta,\mathcal{M})$) for tasks like dMRI super-resolution.
- Probabilistic U-Net: A generative model designed for segmentation that explicitly models the distribution of possible segmentations ($p(y|x,\theta,\mathcal{M})$) for an ambiguous image, trained on datasets with multiple expert annotations.
- Adversarial Sampler: Uses a generative model with an adversarial loss to produce diverse and realistic image reconstructions from undersampled data.
Interpretability methodologies for machine learning in medical imaging¶
Introduction to interpretability in AI¶
As deep learning (DL) models become increasingly complex and accurate, their inner workings become more opaque, creating a trade-off between prediction accuracy and explainability. This has given rise to the field of eXplainable AI (XAI).
What is Interpretability?¶
Interpretability is the ability for a human to understand the link between the features a machine learning model uses and the predictions it makes. The goal is to produce these explanations without sacrificing the model's performance. While simpler models are easier to understand, they are often less accurate and therefore less interesting from a practical standpoint.
Why is Interpretability Needed?¶
In high-stakes domains like healthcare, understanding why an AI makes a certain decision is crucial for:
- Trust and Accountability: To build trust with patients and providers, and to determine responsibility when an algorithm makes a mistake.
- Quality Control and Assurance: To audit models, ensure they are working as intended, and identify potential failures.
- Ethics and Fairness: To ensure models are not making decisions based on undesirable or discriminatory factors, such as patient race or gender, and to avoid harming patients (non-maleficence).
- Data Exploration and Knowledge Discovery: To find new patterns in data and understand if the model has discovered genuinely new insights or is exploiting flaws in the data.
The Problem of "Shortcut Learning"¶
Deep neural networks often follow a principle of "least effort," learning the simplest possible solution to a problem. This can lead them to learn "shortcuts"—spurious correlations in the training data that do not generalize to new, out-of-distribution data.
- Examples of Shortcuts:
- A model classifies a husky as a wolf because it associates the background snow with wolves, not the animal's features.
- A COVID-19 detection model learns to identify the disease based on text markers or patient positioning in the X-ray, which are artifacts of the data collection process at specific hospitals, rather than actual pathological signals in the lungs.
- ImageNet-trained CNNs are often heavily biased towards recognizing texture over shape. Humans can easily identify objects from silhouettes or edge drawings, while these models perform much worse on them, relying instead on textural patterns.
These shortcuts highlight the danger of deploying models without understanding what they have truly learned.
A taxonomy of interpretability methods¶
Interpretability methods can be categorized in several ways:
By Access to Model Internals:
- White-box Methods: Require access to the model's internal structure, such as its layers, weights, and gradients. Gradient-based methods are a primary example.
- Black-box (Model-Agnostic) Methods: Treat the model as an opaque box and work by perturbing inputs and observing the change in outputs. They do not need access to the model's internals.
By Type of Output:
- Visualization / Saliency Maps: The most common output, these are heatmaps that highlight which pixels or regions of an input image were most important for a given prediction.
- Feature Importance: These methods assign a score to each input feature, indicating its influence on the model's decision.
- Concepts: The explanation is given in terms of higher-level, human-understandable concepts (e.g., "the model used the 'striped' concept to identify the zebra").
There is a "choice overload" of available methods, each with its own strengths and weaknesses.
Model-agnostic (black-box) methods¶
These methods can be applied to any classifier, regardless of its architecture.
Partial Dependence Plot (PDP)¶
A PDP shows the marginal effect of one or two features on the predicted outcome of a model. It is a global method, meaning it describes the average effect of a feature across all instances in the dataset. By plotting the feature's value against the average prediction, PDPs can reveal if the relationship is linear, monotonic, or more complex. However, PDPs can be misleading if features are correlated and can hide heterogeneous effects by averaging over all data points.
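The PDP computation can be sketched as follows: for each grid value $v$, overwrite the chosen feature with $v$ in every row of the dataset and average the model's predictions. The model below is a hypothetical black box with an interaction term, chosen to show what averaging hides.

```python
import numpy as np

def partial_dependence(predict, X, feature, grid):
    """PDP: for each grid value v, set column `feature` of every row to v
    and average the model's predictions over the whole dataset."""
    pdp = []
    for v in grid:
        X_mod = X.copy()
        X_mod[:, feature] = v
        pdp.append(predict(X_mod).mean())
    return np.array(pdp)

# Hypothetical black-box model with an interaction between features 0 and 1
def predict(X):
    return 2.0 * X[:, 0] + X[:, 0] * X[:, 1]

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
grid = np.linspace(-2, 2, 5)
pdp = partial_dependence(predict, X, feature=0, grid=grid)
```

The resulting curve has slope close to 2, although individual instances have slopes $2 + x_1$ that vary widely; this is the heterogeneity a PDP can mask.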
Individual Conditional Expectation (ICE)¶
ICE plots disaggregate the average effect shown in a PDP and instead display one line per instance, showing how that specific instance's prediction changes as a feature is varied. The PDP is simply the average of all ICE curves. This makes ICE useful for discovering heterogeneous relationships that are masked by PDPs.
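A minimal ICE sketch: the same feature sweep as a PDP, but one prediction curve is kept per instance, and the PDP is recovered as their mean. The model is a hypothetical example in which the effect of feature 0 flips sign depending on feature 1.

```python
import numpy as np

def ice_curves(predict, X, feature, grid):
    """ICE: row i holds the predictions for instance i with `feature`
    swept over `grid` while its other features stay fixed."""
    curves = np.empty((X.shape[0], len(grid)))
    for k, v in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = v
        curves[:, k] = predict(X_mod)
    return curves

# Hypothetical model: feature 0 helps when feature 1 > 0 and hurts otherwise
def predict(X):
    return X[:, 0] * np.sign(X[:, 1])

X = np.array([[0.0, 1.0], [0.0, -1.0], [0.0, 2.0], [0.0, -2.0]])
grid = np.linspace(-1.0, 1.0, 3)
curves = ice_curves(predict, X, feature=0, grid=grid)
pdp = curves.mean(axis=0)  # the PDP is the average of the ICE curves
```

Here the PDP is flat, suggesting feature 0 is irrelevant, while the ICE curves show two groups with opposite, equally strong effects.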
Local Interpretable Model-agnostic Explanations (LIME)¶
LIME explains an individual prediction by creating a local, interpretable model (like a linear model) that approximates the behavior of the black-box model in the vicinity of that prediction.
The process works as follows:
- Take the instance to be explained (e.g., an image).
- Generate a neighborhood of perturbed instances around it (e.g., by turning superpixels on or off).
- Get predictions for these perturbed instances from the complex model.
- Fit a simple, interpretable model (e.g., a weighted linear model) to this new dataset of perturbations and their predictions.
- The explanation is the set of features from the simple model that are most important (e.g., the superpixels with the highest weights).
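The steps above can be sketched for superpixel-style on/off features as below. This is a simplified illustration, not the reference LIME implementation: the black box receives binary masks directly (assumption), proximity uses a cosine distance to the all-on instance with an exponential kernel (loosely following the LIME paper), and the feature selection and ridge regularization of real LIME are omitted.

```python
import numpy as np

def lime_explain(predict, n_features, n_samples=500, kernel_width=0.25, seed=0):
    """Weights of a locally weighted linear surrogate around the all-on instance.

    predict: black box mapping an (n, n_features) array of 0/1 masks to scores.
    """
    rng = np.random.default_rng(seed)
    Z = rng.integers(0, 2, size=(n_samples, n_features)).astype(float)  # perturbations
    Z[0] = 1.0                                       # include the unperturbed instance
    y = predict(Z)                                   # query the black box
    # Cosine similarity between each mask and the all-ones vector
    cos_sim = np.sqrt(Z.sum(axis=1) / n_features)
    weights = np.exp(-((1.0 - cos_sim) ** 2) / kernel_width ** 2)  # proximity kernel
    # Weighted least squares (with intercept) via sqrt(weight) row scaling
    A = np.hstack([np.ones((n_samples, 1)), Z])
    sw = np.sqrt(weights)[:, None]
    coef, *_ = np.linalg.lstsq(A * sw, y * sw[:, 0], rcond=None)
    return coef[1:]                                  # drop the intercept

true_w = np.array([3.0, 0.0, 0.5, 0.0, -1.0])        # hypothetical linear black box
explanation = lime_explain(lambda Z: Z @ true_w, n_features=5)
top = np.argsort(-np.abs(explanation))[:2]           # most influential "superpixels"
```

For a linear black box the surrogate recovers the true weights exactly; for a real classifier it only approximates the model near the explained image, which is why the kernel concentrates weight on masks close to the original.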
LIME can reveal a model's decision-making process, including when it relies on shortcuts, such as classifying a "husky" as a "wolf" based on snow in the background.
Model-Specific (White-Box) Methods for Deep Learning¶
These methods leverage the internal architecture of neural networks, most often by using gradients.
Gradient-Based Methods & Saliency Maps¶
The core idea is that the magnitude of the gradient of the output score with respect to the input pixels indicates how important those pixels are for the decision.
- Vanilla Saliency: The simplest method, which just visualizes the gradient magnitude.
- Deconvnet & Guided Backpropagation: These methods modify the backpropagation process to produce sharper, more visually interpretable maps. They do this by filtering which gradients are passed backward through the ReLUs: Deconvnet passes only positive gradients, and Guided Backpropagation additionally zeroes positions where the forward activation was negative.
- Grad-CAM (Gradient-weighted Class Activation Mapping): A widely used technique that produces a coarse, class-specific localization map highlighting important regions in the image. It averages the gradients of the class score over the spatial positions of each feature map in the final convolutional layer, uses these averages as channel weights, and applies a ReLU to the weighted sum of the feature maps. The resulting heatmap shows where the model is "looking" to make its prediction. It is applicable to a wide range of CNN-based models.
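To make vanilla saliency and Grad-CAM concrete without a deep learning framework, the sketch below uses hand-derived gradients for two hypothetical toy models: a one-hidden-layer ReLU network for vanilla saliency, and a global-average-pooling plus linear head on given feature maps for Grad-CAM. In practice an autograd framework would supply the gradients, and the Grad-CAM map would be upsampled to the input resolution.

```python
import numpy as np

def vanilla_saliency(x, W1, w2):
    """|d score / d x| for score = w2 . relu(W1 @ x), via manual backpropagation."""
    h_pre = W1 @ x
    grad_h = w2 * (h_pre > 0)   # backprop through the ReLU
    grad_x = W1.T @ grad_h      # backprop through the first linear layer
    return np.abs(grad_x)

def grad_cam(A, class_weights):
    """Grad-CAM for a toy head: score = sum_k w_k * mean_ij A[k, i, j].

    For this head, d score / d A[k, i, j] = w_k / (H * W) at every position.
    """
    K, H, W = A.shape
    grads = np.broadcast_to(class_weights[:, None, None] / (H * W), A.shape)
    alpha = grads.mean(axis=(1, 2))                           # channel weights: GAP of gradients
    return np.maximum(0.0, np.tensordot(alpha, A, axes=1))    # ReLU(sum_k alpha_k A_k)

# Two feature maps: channel 0 fires top-left, channel 1 fires bottom-right
A = np.zeros((2, 4, 4))
A[0, :2, :2] = 1.0
A[1, 2:, 2:] = 1.0
cam = grad_cam(A, class_weights=np.array([1.0, -1.0]))  # class favors channel 0
```

The ReLU keeps only regions that increase the class score, so the bottom-right region, which argues against the class, is zeroed out.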
Perturbation-Based Methods¶
These methods find an explanation by identifying the smallest perturbation to an input that maximally changes the output. Meaningful Perturbation methods learn an optimal "information deletion" mask that identifies the most salient regions by removing them (e.g., by blurring) and observing the drop in the class prediction score.
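Learning an optimal deletion mask requires an optimization loop, so the sketch below instead shows occlusion sensitivity, a simpler perturbation method that uses a fixed sliding patch rather than a learned mask. Each patch is replaced by the image mean (blurring would be an alternative fill), and the drop in the class score is recorded. The score function is a hypothetical stand-in for a classifier output.

```python
import numpy as np

def occlusion_map(score_fn, image, patch=4, stride=4, fill=None):
    """Drop in the class score when each patch is replaced by `fill`
    (the image mean by default). A larger drop marks a more important region."""
    H, W = image.shape
    fill = image.mean() if fill is None else fill
    base = score_fn(image)
    heat = np.zeros((H, W))
    counts = np.zeros((H, W))
    for i in range(0, H - patch + 1, stride):
        for j in range(0, W - patch + 1, stride):
            occluded = image.copy()
            occluded[i:i + patch, j:j + patch] = fill
            heat[i:i + patch, j:j + patch] += base - score_fn(occluded)
            counts[i:i + patch, j:j + patch] += 1
    return heat / np.maximum(counts, 1)  # average over overlapping patches

image = np.zeros((8, 8))
image[:4, :4] = 1.0
# Hypothetical "classifier" that only looks at the top-left quadrant
heat = occlusion_map(lambda img: img[:4, :4].mean(), image)
```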
Layer-wise Relevance Propagation (LRP)¶
LRP is a decomposition technique designed to address the "shattered gradient" problem, where gradients in deep networks become less reliable. It explains a prediction by propagating the output score backward through the network, layer by layer, conserving the total relevance at each step. This results in a heatmap that shows which pixels contributed positively or negatively to the final decision.
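A minimal sketch of the LRP-$\epsilon$ rule, assuming a bias-free ReLU MLP with a single output: relevance starts as the output score and is redistributed as $R_i = \sum_j \frac{a_i w_{ij}}{z_j + \epsilon\,\mathrm{sign}(z_j)} R_j$. Without biases, the input relevances sum (up to the small $\epsilon$) to the output score, which is the conservation property described above; with biases, part of the relevance would be absorbed by the bias terms.

```python
import numpy as np

def lrp_epsilon(x, weights, eps=1e-9):
    """LRP-epsilon for a bias-free MLP with ReLU on all but the last layer.

    weights: list of (in, out) matrices. Returns one relevance score per input.
    """
    activations = [x]
    for k, W in enumerate(weights):
        z = activations[-1] @ W
        activations.append(z if k == len(weights) - 1 else np.maximum(0.0, z))
    relevance = activations[-1]                     # start from the output score
    for a, W in zip(reversed(activations[:-1]), reversed(weights)):
        z = a @ W                                   # pre-activations of the upper layer
        stabilizer = eps * np.where(z >= 0, 1.0, -1.0)
        s = relevance / (z + stabilizer)            # per-neuron scaling factor
        relevance = a * (W @ s)                     # R_i = a_i * sum_j w_ij * s_j
    return relevance

rng = np.random.default_rng(0)
x = rng.normal(size=6)
weights = [rng.normal(size=(6, 8)), rng.normal(size=(8, 1))]
R = lrp_epsilon(x, weights)
output = np.maximum(0.0, x @ weights[0]) @ weights[1]
```

Positive entries of `R` mark inputs that pushed the score up and negative entries those that pushed it down, matching the signed heatmaps LRP produces.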
Testing with Concept Activation Vectors (TCAV)¶
TCAV moves beyond pixel-level importance to provide explanations in terms of human-friendly concepts. It quantifies how important a user-defined concept (e.g., "striped" or "female") was for a classification decision, even if that concept was not a training label.
The process involves:
- Gathering examples of the concept (e.g., images with stripes) and random counterexamples.
- Extracting activations for these images from a chosen layer of the network.
- Training a linear classifier to separate the concept activations from the random activations.
- The vector orthogonal to this decision boundary is the Concept Activation Vector (CAV), which points in the direction of the concept in the model's activation space.
- The TCAV score is the fraction of images in a class whose prediction was positively influenced by moving in the direction of the CAV.
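The procedure above can be sketched on synthetic activations. Everything here is an assumption for illustration: a 3-dimensional "layer", a concept that shifts the first activation, a logistic-regression separator trained by plain gradient descent, and a hypothetical class head $\mathrm{logit}(h) = v \cdot \tanh(h)$ with an analytic gradient. The published method also repeats this with many random counterexample sets and applies a significance test, which is omitted.

```python
import numpy as np

def train_cav(concept_acts, random_acts, lr=0.1, steps=500):
    """Logistic regression (concept = 1, random = 0); the CAV is the unit
    normal of its decision boundary."""
    X = np.vstack([concept_acts, random_acts])
    y = np.concatenate([np.ones(len(concept_acts)), np.zeros(len(random_acts))])
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
        w -= lr * X.T @ (p - y) / len(y)
        b -= lr * np.mean(p - y)
    return w / np.linalg.norm(w)

def tcav_score(grad_fn, class_acts, cav):
    """Fraction of class examples whose logit has a positive directional
    derivative along the CAV."""
    derivs = np.array([grad_fn(h) @ cav for h in class_acts])
    return np.mean(derivs > 0)

rng = np.random.default_rng(0)
random_acts = rng.normal(size=(100, 3))
concept_acts = rng.normal(size=(100, 3)) + np.array([3.0, 0.0, 0.0])  # concept shifts h[0]
cav = train_cav(concept_acts, random_acts)

v = np.array([2.0, 0.5, -0.5])                    # hypothetical class head
grad_fn = lambda h: v * (1.0 - np.tanh(h) ** 2)   # gradient of v . tanh(h)
class_acts = 0.5 * rng.normal(size=(50, 3))       # activations of class examples
score = tcav_score(grad_fn, class_acts, cav)
```

Because the head's logit grows with $h_0$, the direction the concept occupies, almost every class example gets a positive directional derivative and the score is close to 1.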
TCAV can be used to uncover biases, such as a model associating "red" with "fire engine" or relying on gender for certain classifications.
Evaluating and Applying Interpretability¶
A key challenge in XAI is objectively evaluating the faithfulness of an explanation.
Sanity Checks for Saliency Maps¶
Some interpretability methods might produce visually appealing heatmaps that are not actually faithful to what the model has learned. Sanity checks test this by randomizing the model's parameters, layer by layer, from top to bottom. If the explanation map does not change as the model's learned weights are destroyed, the method is not acting as an explanation but is likely just functioning as an edge detector. Studies have shown that methods like Grad-CAM and vanilla gradients pass these sanity checks, while methods like Guided Backpropagation and Guided Grad-CAM fail.
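A sketch of the cascading randomization test, under stated assumptions: the model is a toy two-layer ReLU network with random (not trained) weights, the explanation is the absolute input gradient, and similarity is measured with Spearman rank correlation, one of the similarity measures used for this test. The point is the procedure: re-randomize layers from the output downward and check that the map stops resembling the original.

```python
import numpy as np

def saliency(x, W1, w2):
    """|d score / d x| for score = w2 . relu(W1 @ x) (manual gradient)."""
    return np.abs(W1.T @ (w2 * (W1 @ x > 0)))

def spearman(a, b):
    """Spearman rank correlation (no tie handling; fine for continuous maps)."""
    ra, rb = np.argsort(np.argsort(a)), np.argsort(np.argsort(b))
    return np.corrcoef(ra, rb)[0, 1]

def cascading_randomization(x, W1, w2, seed=0):
    """Randomize layers from the top down and record similarity to the original map."""
    rng = np.random.default_rng(seed)
    original = saliency(x, W1, w2)
    w2_rand = rng.normal(size=w2.shape)   # step 1: randomize the output layer
    s1 = spearman(original, saliency(x, W1, w2_rand))
    W1_rand = rng.normal(size=W1.shape)   # step 2: also randomize the layer below
    s2 = spearman(original, saliency(x, W1_rand, w2_rand))
    return s1, s2

rng = np.random.default_rng(42)
x = rng.normal(size=64)
W1, w2 = rng.normal(size=(32, 64)), rng.normal(size=32)
s1, s2 = cascading_randomization(x, W1, w2)
```

A faithful method should show correlations dropping well below 1 as layers are randomized; a method whose map stays nearly unchanged is reflecting the input (e.g., its edges) rather than the model.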
Applications of "Interpretability-Boosted AI"¶
Interpretability is not just for explaining decisions; it can be used to improve the entire ML pipeline.
- Quality Control & Bias Detection: Saliency maps can reveal when a model is using shortcuts or has learned biases from the data preparation process. This insight allows for data correction and model retraining, leading to improved performance. The INFORMER framework uses the "fingerprint" of an explanation to detect model errors at test time.
- Inductive Bias: Explanations can be used as a form of supervision to guide the model's learning. SIBNet introduces a loss function that encourages saliency maps for different classes to be distinct and spatially coherent, improving both performance and the quality of the explanations.
- Active Learning: Interpretability can guide the selection of the most informative samples to label next. The GESTALT method analyzes the disagreement in saliency maps across different classes to identify samples where the model is most uncertain, prioritizing them for annotation.
- Image Retrieval: Saliency maps, which highlight the pathological information a classifier uses, can serve as better features for content-based image retrieval than the raw image data alone.
Human-AI Alignment and Fairness¶
A key question is whether to force an AI's reasoning to align with a human's. Studies show that enforcing alignment can improve fairness (i.e., reduce performance gaps across demographic groups), but there is a trade-off. A moderate amount of alignment often yields the best balance of accuracy and fairness, suggesting that graduated guidance is better than forcing complete alignment or letting the AI learn entirely on its own.
Interpretability is no longer an optional add-on but a critical component for the safe and effective translation of machine learning into medical applications. The goal is not to understand every single weight in a neural network, but to obtain enough information to audit the model, build trust, ensure fairness, and discover new knowledge. While many visualization techniques exist, benchmarking and objectively evaluating these methods remains an active and important area of research.