Double-mix pseudo-label framework: enhancing semi-supervised segmentation on category-imbalanced CT volumes

Our method, depicted in Fig. 3, integrates two DMP modules within a CPS-like training framework. The input to our method consists of labeled volumes \(\textbf{X}^l\), unlabeled volumes \(\textbf{X}^u\), and the ground-truth \(\textbf{Y}^l\). The category-wise weights (in this work, CDifW and DisW) are used for loss calculation and for category-mask \(\textbf{M}\) generation in the DMP module (see Fig. 4a). For the segmentation models \(f_A\) and \(f_B\) in the CPS framework, we adopt Exponential Moving Average (EMA) [11] models \(\hat{f}_A\) and \(\hat{f}_B\) that have the same structure as \(f_A\) and \(f_B\). The EMA models’ parameters \(\theta ^{\hat{f}_A}_t\) and \(\theta ^{\hat{f}_B}_t\) of \(\hat{f}_A\) and \(\hat{f}_B\) at iteration t are updated during training as \(\theta ^{\hat{f}_A}_t=\mu \theta _t^{f_A} + (1-\mu )\theta ^{\hat{f}_A}_{t-1}\) and \(\theta ^{\hat{f}_B}_t=\mu \theta _t^{f_B} + (1-\mu )\theta ^{\hat{f}_B}_{t-1}\), respectively. The DMP modules output mixed labels \(\hat{\textbf{Y}}^m_A\) and \(\hat{\textbf{Y}}^m_B\), defining new sample pairs for training \(f_A\) and \(f_B\). Framework details are described in Sect. “Double-mix Pseudo-label Framework”.
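
For concreteness, a minimal sketch of this EMA parameter update, assuming PyTorch modules; the function name and the value of \(\mu \) are illustrative, and the update follows the convention written above (\(\mu \) weights the current student parameters).

```python
import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, mu: float = 0.01):
    """theta_ema <- mu * theta_model + (1 - mu) * theta_ema, as written in the text."""
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(1.0 - mu).add_(p.detach(), alpha=mu)
```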

Distribution based weight (DisW)

We utilize two weights, DisW and CDifW, to estimate the category-wise distribution and difficulty. \(\textbf{X}^l\) and \(\textbf{X}^u\) denote the labeled and unlabeled input volumes, respectively.

Following [2], we compute the category-wise distribution weight \(w^{\text{dis}}_{t,k}\) for each category k at iteration t during training from the pseudo-labels \(\hat{\textbf{Y}}^u_t\) of \(\textbf{X}^u\): we first calculate the voxel-count ratio \(r_{t,k}\) from the category-wise voxel counts \(\psi _{t,k}\) of the pseudo-labels \(\hat{\textbf{Y}}_t\) of the input \([\textbf{X}^l,\textbf{X}^u]\), and then normalize these ratios by

$$\begin{aligned} w^{\text{dis}}_{t,k} & = \frac{\log (r_{t,k})}{\max _{j}\log (r_{t,j})},\nonumber \\ r_{t,k} & = \frac{\max _{j}\,\psi _{t,j}}{\psi _{t,k}}. \end{aligned}$$

(1)

Weights are updated using an EMA approach

$$\begin{aligned} \textbf{w}^{\text{dis}}_t = \beta \textbf{w}^{\text{dis}}_{t-1} + (1 - \beta )\hat{\textbf{w}}^{\text{dis}}_t, \quad \hat{\textbf{w}}^{\text{dis}}_t=(w^{\text{dis}}_{t,1}, \ldots , w^{\text{dis}}_{t,K}), \end{aligned}$$

(2)

where \(\beta \) is the parameter for weight smoothing, and K is the number of categories.
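
As a rough illustration of Eqs. (1)–(2) as reconstructed above, the following sketch computes DisW from the category-wise voxel counts of the current pseudo-labels; the function name, the smoothing constant, and the handling of empty categories are assumptions.

```python
import torch

def distribution_weight(voxel_counts: torch.Tensor, prev_w: torch.Tensor,
                        beta: float = 0.9, eps: float = 1.0):
    """DisW sketch: voxel_counts[k] is the number of voxels assigned to category k
    in the current pseudo-labels; prev_w is w^dis from the previous iteration."""
    counts = voxel_counts.float() + eps                        # avoid division by zero for absent categories
    r = counts.max() / counts                                  # voxel-count ratio r_{t,k} (Eq. 1, as reconstructed)
    w_hat = torch.log(r) / torch.log(r).max().clamp(min=1e-8)  # normalize to (0, 1]
    return beta * prev_w + (1.0 - beta) * w_hat                # EMA smoothing (Eq. 2)
```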

Confidence-difficulty based weight (CDifW)

We assess the difficulty of each category through two dimensions: Dice score and confidence.

The learning speed is measured by the population stability index [12] computed on the Dice score, defined as

$$\begin{aligned} d^u_{t,k}&= \sum _{i=\partial }^{t} \mathbb {1}_{u}\ln \left( \frac{\zeta _{i,k}}{\zeta _{i-1,k}}\right) , \nonumber \\ d^l_{t,k}&= \sum _{i=\partial }^{t} \mathbb {1}_{l}\ln \left( \frac{\zeta _{i,k}}{\zeta _{i-1,k}}\right) , \nonumber \\ \partial &=\max (t-\tau ,0), \end{aligned}$$

(3)

where \(\zeta _{i,k}\) is the Dice score of category k at iteration i, \(\Delta = \zeta _{i,k}-\zeta _{i-1,k}\), and \(\mathbb {1}_{u}\) and \(\mathbb {1}_{l}\) are the indicator functions

$$\begin{aligned} \mathbb {1}_{u} = {\left\{ \begin{array}{ll} 1 & \text {if } \Delta \le 0 \\ 0 & \text {if } \Delta > 0 \end{array}\right. }, \quad \mathbb {1}_{l} = {\left\{ \begin{array}{ll} 1 & \text {if } \Delta > 0 \\ 0 & \text {if } \Delta \le 0 \end{array}\right. }. \end{aligned}$$

(4)

The symbol \(\tau \) represents the number of accumulated iterations (i.e., the window length), which is empirically set to 50. Following [2], \(d^u_{t,k}\) and \(d^l_{t,k}\) are used to evaluate whether category k has been unlearned or well learned, and the difficulty is defined as \(d_{t,k} = \left( \frac{d^u_{t,k} + \epsilon }{d^l_{t,k} + \epsilon }\right) ^{\alpha }\), where \(\epsilon \) is a smoothing term and \(\alpha \) is a hyperparameter that alleviates outliers. A well-learned category should yield a low \(d_{t,k}\). The difficulty weight for category k at iteration t is then

$$\begin{aligned} w^{\text{dif}}_{t,k} = (1-\zeta _{t,k})\, d_{t,k}. \end{aligned}$$

(5)
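
A plain-Python sketch of Eqs. (3)–(5) as reconstructed above, for a single category; the absolute value on the “unlearned” term (to keep \(d^u\) non-negative) and the function name are assumptions.

```python
import math

def difficulty_weight(dice_history, t, tau=50, alpha=1.0, eps=1e-8):
    """dice_history[i] is the Dice score zeta_{i,k} of one category at iteration i.
    Returns w^dif_{t,k} = (1 - zeta_{t,k}) * d_{t,k} (Eqs. 3-5, as reconstructed)."""
    start = max(t - tau, 0)                        # window start, Eq. (3)
    d_u, d_l = 0.0, 0.0
    for i in range(start + 1, t + 1):
        delta = dice_history[i] - dice_history[i - 1]
        ratio = math.log((dice_history[i] + eps) / (dice_history[i - 1] + eps))
        if delta <= 0:                             # category is being "unlearned"
            d_u += abs(ratio)
        else:                                      # category is being learned
            d_l += ratio
    d = ((d_u + eps) / (d_l + eps)) ** alpha       # high d -> hard category
    return (1.0 - dice_history[t]) * d
```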

As mentioned in the Introduction, it is necessary to incorporate confidence into the category-wise difficulty. For labeled data \(\textbf{X}^l\), the confidence \(c_k\) for category k is computed from the logits \(\textbf{P}^l\), where \(\textbf{P}^l = \text{softmax}\{f(\textbf{X}^l)\}\) and f is the segmentation model applying CDifW, \(f\in \{f_A, f_B\}\). At iteration t, the category confidence \(\hat{c}_{t,k}\) in a mini-batch is defined as

$$\begin{aligned} \hat{c}_{t,k} = \frac{1}{B}\sum _{b=1}^{B} \frac{1}{z_k}\sum _{j} p^l_{b,k,j}, \end{aligned}$$

(6)

where B is the mini-batch size and \(p^l_{b,k,j}\) is the probability of category k at location j in sample b, with j indicating positions marked as category k in the ground-truth. The term \(z_k\) represents the number of voxels of category k in the ground-truth \(\textbf{Y}\). The EMA method updates the confidence score \(c_{t,k}\) as

$$\begin{aligned} c_{t,k} = \beta c_{t-1,k} + (1 - \beta ) \hat{c}_{t,k}. \end{aligned}$$

(7)
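
A sketch of Eq. (6) for one mini-batch, assuming PyTorch tensors with channel-first layout; the clamping of \(z_k\) for absent categories and the function name are assumptions. Eq. (7) is then the usual EMA update of the running confidence.

```python
import torch
import torch.nn.functional as F

def category_confidence(logits: torch.Tensor, labels: torch.Tensor, num_classes: int):
    """Mean softmax probability of the true category over its ground-truth voxels,
    averaged over the mini-batch (Eq. 6, as reconstructed).
    logits: (B, K, D, H, W); labels: (B, D, H, W) with integer category ids."""
    probs = F.softmax(logits, dim=1)
    conf = torch.zeros(num_classes, device=logits.device)
    for k in range(num_classes):
        mask = (labels == k).unsqueeze(1).float()           # voxels annotated as category k
        z_k = mask.sum(dim=(1, 2, 3, 4)).clamp(min=1.0)     # z_k per sample
        conf[k] = (probs[:, k:k + 1] * mask).sum(dim=(1, 2, 3, 4)).div(z_k).mean()
    return conf

# EMA update of the confidence score (Eq. 7):
# c_t = beta * c_prev + (1 - beta) * category_confidence(...)
```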

Similar to [13], we define the information scores \(\textbf{s}_t\) at iteration t as

$$\begin{aligned} \textbf{s}_t= & {} \{\, s_{t,k} \mid k = 1, 2, \dots , K \,\},\nonumber \\ s_{t,k}= & {} \frac{1-c_{t,k}}{\max _{j}\, (1-c_{t,j})}. \end{aligned}$$

(8)

Then, our proposed CDifW \(\textbf{w}^{\text{Cdif}}_t\) at iteration t can be defined as

$$\begin{aligned} \textbf{w}^{\text{Cdif}}_t=(w^{\text{Cdif}}_{t,1}, w^{\text{Cdif}}_{t,2}, \dots , w^{\text{Cdif}}_{t,K}), \quad w^{\text{Cdif}}_{t,k} = s_{t,k}^{\gamma }\, w^{\text{dif}}_{t,k}, \end{aligned}$$

(9)

where \(\gamma \) is a hyperparameter.
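
Combining the pieces, a sketch of Eqs. (8)–(9) under the reconstruction above (information score as the max-normalized complement of confidence); the names are illustrative.

```python
import torch

def cdif_weight(confidence: torch.Tensor, w_dif: torch.Tensor, gamma: float = 1.0):
    """CDifW sketch (Eqs. 8-9, as reconstructed): categories with low confidence get a
    larger information score s, which rescales the difficulty weight w^dif."""
    info = 1.0 - confidence                         # low confidence -> more informative
    s = info / info.max().clamp(min=1e-8)           # normalized information score s_{t,k}
    return (s ** gamma) * w_dif                     # w^Cdif_{t,k} = s^gamma * w^dif_{t,k}
```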

This method yields two sets of weights, \(\textbf{w}^{\text{Cdif}}\) and \(\textbf{w}^{\text{dis}}\), representing the training difficulty and the category distribution, respectively. We omit the subscript t hereafter for brevity.

Double-mix pseudo-label module

The scarcity of labeled data necessitates using unlabeled data for augmentation. ClassMix [5] blends regions using pseudo-labels but does not rectify category imbalance, especially for high-difficulty categories, as described in the Introduction. Our DMP module counters this by applying the weights \(\textbf{w}^{\text{dis}}\) and \(\textbf{w}^{\text{Cdif}}\) from Sects. “Distribution based Weight (DisW)” and “Confidence-Difficulty based Weight (CDifW)” to selectively blend categories with ClassMix for more balanced and effective augmentation.

The process of a single DMP module is shown in Fig. 4. Initially, for an input unlabeled volume \(\textbf{X}^u\), its pseudo-label \(\hat{\textbf{Y}}^u\) is computed using the EMA model. For data mixing, a binary mask is created from the selected categories \(\textbf{C}\). A probability distribution is generated from the category-wise weights \(\textbf{w}\) (in this work, \(\textbf{w} \in \{\textbf{w}^{\text{dis}}, \textbf{w}^{\text{Cdif}}\}\)), where the weight \(w_k\) for category k represents the probability of this category being sampled. We sample \(k\) times from this probability distribution, resulting in a set of selected categories denoted as \(\textbf{C}\). Using this category set \(\textbf{C}\), we generate a binary mask \(\textbf{M}\) corresponding to the unlabeled volume \(\textbf{X}^u\) as follows: for any given voxel j in the volume, if \(\hat{\textbf{Y}}^u_{j} \in \textbf{C}\), then \(\textbf{M}_{j} = 1\); otherwise, \(\textbf{M}_{j} = 0\). Therefore, the mixed sample pair \([\textbf{X}^m,\hat{\textbf{Y}}^m]\) using the unlabeled data \(\textbf{X}^u\), the labeled data \(\textbf{X}^l\), and the ground-truth \(\textbf{Y}^l\) of \(\textbf{X}^l\) can be obtained as

$$\begin{aligned} \textbf{X}^m = \textbf{X}^u \odot \textbf{M} + \textbf{X}^l \odot (\textbf{1}-\textbf{M}),\quad \hat{\textbf{Y}}^m = \hat{\textbf{Y}}^u \odot \textbf{M} + \textbf{Y}^l \odot (\textbf{1}-\textbf{M}), \end{aligned}$$

(10)

where \(\odot \) denotes the element-wise product. As shown in Fig. 3, during training we employ distinct weight distributions to perform two DMP operations, obtaining two different mixed sample pairs \([\textbf{X}^m_A, \hat{\textbf{Y}}^m_A]\) and \([\textbf{X}^m_B, \hat{\textbf{Y}}^m_B]\). This approach considers both the difficulty and the distribution of each category, focusing the augmentation on imbalanced categories.
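
A sketch of a single DMP operation (category sampling, mask generation, and Eq. (10)), assuming all tensors share the same spatial shape; `num_draws` stands in for the number of sampling draws and is a placeholder, as are the other names.

```python
import torch

def dmp_mix(x_u, y_u_pseudo, x_l, y_l, weights, num_draws):
    """Double-mix sketch: sample categories with probability proportional to the
    category-wise weights, build a binary mask from the pseudo-label, and paste the
    selected unlabeled regions onto the labeled volume (Eq. 10, as reconstructed)."""
    probs = weights / weights.sum()
    selected = torch.multinomial(probs, num_draws, replacement=True).unique()  # category set C
    mask = torch.isin(y_u_pseudo, selected).float()          # M_j = 1 where the pseudo-label is in C
    x_m = x_u * mask + x_l * (1.0 - mask)                    # mixed volume X^m
    y_m = y_u_pseudo * mask + y_l * (1.0 - mask)             # mixed label Y^m
    return x_m, y_m.long()
```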

Double-mix pseudo-label framework

The process of DMPF is shown in Fig. 3. The updating of \(\textbf{w}^{\text{dis}}\) and \(\textbf{w}^{\text{Cdif}}\), as well as the generation process of our proposed DMP, is summarized in Algorithm 1. To simultaneously consider the distribution and the difficulty of categories, we create two models, \(f_A\) and \(f_B\), with different random initializations of the model weights. As defined in Sect. “Distribution based Weight (DisW)”, \(N^L\) and \(N^U\) denote the number of samples in the labeled and unlabeled datasets. We obtain the logits \(\textbf{P}_A\) and \(\textbf{P}_B\) of the input data \([\textbf{X}^l,\textbf{X}^u]\) through \(f_A\) and \(f_B\). In the CPS framework, the supervised loss is

$$\begin{aligned} L^{l}_{\text{sup}}(\textbf{P}_A,\textbf{P}_B,\textbf{Y})&= \frac{1}{2} \frac{1}{N^L} \sum _{i=1}^{N^L} [ L_{s}( \textbf{w}^{\text{dis}}, \textbf{P}_{A,i}, \textbf{Y}_i) \nonumber \\&\quad + L_{s}(\textbf{w}^{\text{Cdif}}, \textbf{P}_{B,i}, \textbf{Y}_i)], \end{aligned}$$

(11)

where \(\textbf{Y}\) is the ground-truth of the labeled data \(\textbf{X}^l\), and the unsupervised loss component is

$$\begin{aligned} L^{\text{cps}}_u(\textbf{P}_A,\textbf{P}_B, \hat{\textbf{Y}}_A, \hat{\textbf{Y}}_B)&= \frac{1}{2} \frac{1}{N^U} \sum _{i=1}^{N^U} [ L_u(\textbf{w}^{\text{dis}}, \textbf{P}_{A,i}, \hat{\textbf{Y}}_{B,i}) \nonumber \\&\quad + L_u(\textbf{w}^{\text{Cdif}}, \textbf{P}_{B,i}, \hat{\textbf{Y}}_{A,i})], \end{aligned}$$

(12)

where \(\hat{\textbf{Y}}_A\) and \(\hat{\textbf{Y}}_B\) are the pseudo-labels calculated from \(\textbf{P}_A\) and \(\textbf{P}_B\). In our experiments, \(L_{s}(\textbf{w}, \textbf{P}, \textbf{Y}) = L_{\text{wce}}( \textbf{w}, \textbf{P}, \textbf{Y}) + \frac{1}{2} L_{\text{wdice}}(\textbf{w}, \textbf{P}, \textbf{Y})\) and \(L_u(\textbf{w}, \textbf{P},\textbf{Y}) = L_{\text{wce}}( \textbf{w}, \textbf{P}, \textbf{Y})\), where \(L_{\text{wce}}\) is the weighted cross-entropy loss and \(L_{\text{wdice}}\) is the weighted Dice loss [14].
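
The weighted losses can be sketched as follows, assuming PyTorch and a per-category weight vector \(\textbf{w}\); the exact Dice formulation and the 1/2 factor follow the reconstruction above and are not guaranteed to match the authors' implementation.

```python
import torch
import torch.nn.functional as F

def weighted_ce(logits, target, w):
    # F.cross_entropy accepts a per-class weight vector; logits: (B, K, ...), target: (B, ...)
    return F.cross_entropy(logits, target, weight=w)

def weighted_dice(logits, target, w, eps=1e-5):
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes=logits.shape[1]).movedim(-1, 1).float()
    dims = tuple(range(2, logits.dim()))
    inter = (probs * one_hot).sum(dims)
    union = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * inter + eps) / (union + eps)          # per-sample, per-category Dice
    return ((1.0 - dice) * w).sum(dim=1).mean() / w.sum()

def supervised_loss(logits, target, w):
    # L_s = weighted CE + 1/2 weighted Dice, per the text (as reconstructed)
    return weighted_ce(logits, target, w) + 0.5 * weighted_dice(logits, target, w)
```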

Algorithm 1 Double-Mix Pseudo-Label Framework

In the DMP module, \(\hat{\textbf{Y}}^u_A\) and \(\hat{\textbf{Y}}^u_B\) are the pseudo-labels of the unlabeled volumes \(\textbf{X}^u\) generated by the EMA models \(\hat{f}_A\) and \(\hat{f}_B\). \([\textbf{X}^u,\hat{\textbf{Y}}^u_A,\textbf{X}^l,\textbf{Y}]\) and \([\textbf{X}^u,\hat{\textbf{Y}}^u_B,\textbf{X}^l,\textbf{Y}]\) are, respectively, fed into the two DMP modules, which select categories by \(\textbf{w}^{\text{dis}}\) and \(\textbf{w}^{\text{Cdif}}\). This process generates new training data pairs at each iteration, denoted as \([\textbf{X}^m_A,\hat{\textbf{Y}}^m_A]\) and \([\textbf{X}^m_B,\hat{\textbf{Y}}^m_B]\). The loss for the data pairs created by the DMP modules is

$$\begin{aligned} L^{m}_{\text{sup}}(\textbf{P}_A^m,\textbf{P}_B^m,\hat{\textbf{Y}}^m_A,\hat{\textbf{Y}}^m_B) = L_{s}(\textbf{w}^{\text{dis}},\textbf{P}_A^m,\hat{\textbf{Y}}^m_A) +L_{s} (\textbf{w}^{\text{Cdif}}, \textbf{P}_B^m,\hat{\textbf{Y}}^m_B), \end{aligned}$$

(13)

where \(\textbf{P}_A^m\) and \(\textbf{P}_B^m\) are the outputs of \(f_A\) and \(f_B\) for the inputs \(\textbf{X}_A^m\) and \(\textbf{X}_B^m\). Therefore, the overall loss function can be defined as

$$\begin{aligned} L & = L^{l}_{\text{sup}}(\textbf{P}_A,\textbf{P}_B,\textbf{Y}) + L^{m}_{\text{sup}}(\textbf{P}_A^m,\textbf{P}_B^m,\hat{\textbf{Y}}^m_A,\hat{\textbf{Y}}^m_B) \nonumber \\ & \quad + \theta L^{\text{cps}}_u(\textbf{P}_A, \textbf{P}_B, \hat{\textbf{Y}}_A, \hat{\textbf{Y}}_B), \end{aligned}$$

(14)

where \(\theta \) is a hyperparameter, and the epoch-dependent Gaussian ramp-up strategy [3] is used to gradually enlarge the weight of the unsupervised loss.
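
A common form of the epoch-dependent Gaussian ramp-up [3], shown here as a sketch; the ramp-up length is illustrative.

```python
import math

def gaussian_rampup(epoch, rampup_epochs=40.0):
    """Gaussian ramp-up factor in [0, 1] used to scale the unsupervised loss weight theta."""
    if rampup_epochs == 0:
        return 1.0
    phase = 1.0 - min(epoch, rampup_epochs) / rampup_epochs
    return math.exp(-5.0 * phase * phase)

# total loss at a given epoch, following Eq. (14):
# L = L_sup_l + L_sup_m + theta * gaussian_rampup(epoch) * L_u
```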

In the inference stage, for an input volume \(\textbf{X}^p\), we calculate \(\textbf{P}^p_A = f_A(\textbf{X}^p)\) and \(\textbf{P}^p_B = f_B(\textbf{X}^p)\). The predicted logits are given by \(\textbf{P}^p = \frac{1}{2}(\textbf{P}^p_A + \textbf{P}^p_B)\). The predicted result \(\textbf{Y}^p\) is derived from \(\textbf{P}^p\) by assigning each voxel to the category with the highest predicted probability.
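
A minimal inference sketch following the description above; taking the per-voxel arg-max of the averaged logits is assumed to be equivalent to picking the category with the highest probability here.

```python
import torch

@torch.no_grad()
def predict(f_a, f_b, x):
    """Average the two co-trained models' logits and take the arg-max category per voxel."""
    logits = 0.5 * (f_a(x) + f_b(x))
    return logits.argmax(dim=1)
```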

Table 1 Segmentation outcomes of our method and other SSL segmentation methods on the 40% labeled BTCV dataset

Table 2 Segmentation outcomes of our method and other SSL segmentation methods on the 5% labeled CHD dataset

Fig. 5 Comparative experiments with other methods using the 40% labeled BTCV dataset. Some high-difficulty categories are highlighted by red frames

Fig. 6 Comparative experiments with other methods using 5% of the CHD dataset
