Fast reconstruction of milling temperature field based on CNN-GRU machine learning models

1 Introduction

In the era of Industry 4.0, China’s manufacturing industry is undergoing a profound transformation, and robotics is becoming increasingly important in intelligent manufacturing. Intelligent manufacturing relies on multifunctional sensors to perceive the production environment (Cheng et al., 2016; Javaid et al., 2021). Production equipment learns autonomously through sensor-based, data-driven methods, which enables adaptive machining in changing environments; intelligent control then achieves the desired outcomes (Lee et al., 2015). Milling is one of the most important machining methods in manufacturing and offers broad prospects for robotic machining, where the key concerns are machining quality and machining accuracy. In milling, localized high temperatures and strongly time-varying temperature gradients are concentrated mainly at the boundary of the tool heat transfer system, i.e., the cutting region. Localized high temperatures shorten tool life and increase the chemical activity of the material being removed. This leads to material oxidation, rapid corrosion, and adhesion and diffusion between the material and the tool, which in turn degrade the machining accuracy and quality of the workpiece (Korkmaz and Gupta, 2024). In addition, high temperature in the cutting region causes localized thermal deformation of the tool tip, one of the main causes of reduced machining accuracy. Therefore, localized high temperatures and strongly time-varying temperature gradients in the cutting region shorten tool life and reduce both workpiece quality and machining efficiency.

Because of interference from cutting fluid and chips, existing sensing technology cannot directly measure the temperature field inside the cutting region (Alammari et al., 2024). Sensors can only be placed near the tool-chip contact area, yielding limited temperature data outside the cutting zone. Metal cutting is a thermo-mechanically coupled process involving large changes in the elastic–plastic deformation of the material and in the friction of the contact area. These changes make the tool temperature field strongly non-uniform in both time and space, so one or a few measurement points cannot accurately describe the actual machining conditions. Therefore, studying tool temperature field reconstruction during milling is crucial for extending tool life and improving machining accuracy.

Accurately measuring the temperature field of the cutting region online remains difficult. Physical methods based on infrared thermography, artificial thermocouples, and embedded thermocouples (Cichosz et al., 2023; Leonidas et al., 2022; Longbottom and Lanham, 2005) can measure only a limited number of in-situ temperatures near the cutting region, or an approximate temperature field close to it. In recent years, computational methods for reconstructing the tool temperature field have gained widespread attention, because they bypass these physical limitations and yield temperature data at any location of interest. Current modeling techniques fall mainly into three categories: analytical modeling, numerical simulation based on cutting mechanisms, and inverse heat conduction modeling, which combines physical measurements with model-solving methods. The inverse heat conduction problem (IHCP) belongs to the field of inverse problems in mathematical physics. In a forward (direct) problem, the governing equations and their parameters are known, and the output is determined from a known input; an inverse problem instead infers unknown inputs, such as boundary conditions, from measured outputs. Early studies simplified the tool and the cutting process, often treating the tool as a one- or two-dimensional body and the cutting process as steady-state. In these cases, analytical methods were combined with the IHCP to derive explicit expressions relating the unknown quantities to the measured values (Murio, 1981).

Nowadays, more and more researchers focus on modeling the complex three-dimensional structure of the tool together with the transient cutting process (Oommen and Srinivasan, 2022). For example, Liang et al. (2013) proposed a three-dimensional inverse heat transfer model based on an improved conjugate gradient method, which can quantitatively calculate the temperature of the tool-chip contact area in dry turning. Carvalho et al. (2006) used the golden section search to solve the inverse heat conduction problem and the finite volume method to build a three-dimensional model of the turning tool; their model accounts for temperature-dependent thermal properties and convective heat losses when reconstructing the temperature field.

In recent years, with advances in artificial intelligence and machine learning, data-driven artificial neural network models have been widely used to solve inverse heat transfer problems. For example, physics-informed neural networks (PINN; Qian et al., 2023), nonlinear autoregressive networks with exogenous inputs (NARX; Chen and Pan, 2023), convolutional neural networks (CNN; Kim and Lee, 2020), and multi-domain physics-informed neural networks (M-PINN; Zhang et al., 2022) have all contributed to solving the inverse heat conduction problem. Zhang and Wang (2024) used deep neural networks to represent and approximate partial differential equations (PDEs) in a forward-problem formulation. They proposed an optimization algorithm that stacks gated recurrent unit (GRU) modules in a sequence-to-sequence (Seq2Seq) architecture to capture how the solution evolves over time; this improves the solution of these equations and shows strong generalization ability.

Fusing different artificial neural network models remains a promising direction. For example, CNNs struggle to capture temporal features, whereas GRUs struggle to capture spatial features. Combining CNN and GRU may let their strengths complement each other, enabling a CNN-GRU model to capture spatio-temporal features effectively and thereby improve accuracy and generalization.

This paper proposes a CNN-GRU based milling tool heat transfer model with knowledge distillation compression acceleration. The model reconstructs the milling tool temperature field under three different working conditions. A self-built milling temperature data acquisition system collects real-time temperature data from multiple points on the back face of the milling cutter. This system uses a temperature measurement tool embedded with a thin-film thermocouple array and a multi-channel signal acquisition device. By analyzing the relationship between machining parameters, the temperature at four measurement points on the milling tool, and the temperature in the cutting area, we use machining parameters and multi-point temperatures as input features. The temperature boundary conditions in the cutting area serve as prediction labels. The GRU is introduced to the convolutional neural network (CNN) to extract multi-dimensional feature information, aiming to improve reconstruction accuracy and efficiency. We then apply a knowledge distillation strategy to compress and accelerate the CNN-GRU model. This approach reduces computation time while maintaining high prediction performance and accuracy, ensuring efficient temperature field reconstruction.

The rest of this study is organized as follows: Section 2 reviews related work, Section 3 describes the proposed method in detail, Section 4 reports the experimental results and analysis, and Section 5 concludes the study.

The contributions of this paper are as follows:

1. CNN-GRU-based model for the inverse heat transfer problem: we propose a convolutional gated recurrent network (CNN-GRU) that solves the inverse heat transfer problem by predicting the temperature boundary conditions in the cutting region of the tool. The model exploits both the machining parameters of the milling process and the multi-point temperatures on the back face of the milling tool, significantly improving the accuracy and efficiency of milling temperature field reconstruction.

2. KD-compressed model for solving the inverse heat conduction problem: the CNN-GRU model is compressed and accelerated using a knowledge distillation (KD) strategy. Compared with the model without KD, the distilled model substantially reduces training time with minimal loss of goodness-of-fit and shows strong noise immunity.

3. A transient heat conduction model of the milling tool is constructed and used to reconstruct the tool temperature field under three different working conditions; tests under these three conditions assess its applicability to milling temperature field reconstruction.

2 Related work

Currently, deep learning techniques have been successfully applied to real-world scenarios, solving challenging problems like predicting the lifetime of relays and batteries. Constructing prediction models with artificial neural networks has broad applications in solving inverse heat transfer problems (Cortés et al., 2007; Wang et al., 2023; Kamyab et al., 2022). Additionally, methods for compressing and accelerating deep learning models have enhanced the efficiency and applicability of these techniques in real-time temperature field reconstruction. Integrating these advanced algorithms significantly improves the precision and speed of temperature field predictions during milling. This makes them indispensable for optimizing machining operations.

2.1 Shallow artificial neural network approach

Shallow artificial neural network methods were among the first techniques applied to the solution of inverse heat transfer problems. These methods utilize a simple hierarchical structure for data processing and prediction by simulating the way neurons in the brain work. Despite the simplicity of their structure, shallow neural networks have demonstrated their effectiveness and feasibility in solving specific problems, such as in the area of predicting lifespan.

Back Propagation Neural Networks (BPNNs) have been combined with time-series analysis to predict the remaining life of cooling fans (Lixin et al., 2016). Time-series analysis of historical data yields the future trend of the data, and the BPNN corrects the prediction error to keep the results accurate. A single BPNN tends to become trapped in local optima of the weights and to overfit during training, so in recent years many researchers have combined BPNNs with other machine learning methods to improve model accuracy.

Radial Basis Function Networks (RBFNs) are widely used in various fields due to their advantages of having outputs independent of initial weights and shorter training times. A gray RBFN-based prediction model (Li et al., 2009) for life and reliability of constant stress accelerated life testing has been developed and compared with traditional single Backpropagation Neural Networks (BPNNs). Experimental results demonstrate that the accuracy of the gray RBFN model surpasses that of the BPNN.

The shallow artificial neural network model has a high dependence on large-scale data, and the shallow model is prone to overfitting phenomenon during the training process, especially when the training data is small or the data dimension is high. In practice, the sensitivity of shallow artificial neural networks to data quality and noise may lead to a decrease in the robustness of the model.

2.2 Deep artificial neural network methods

With the improvement of computational power and the development of deep learning technology, deep artificial neural network methods have demonstrated powerful performance in solving complex problems. Deep neural networks are able to better capture complex patterns and higher-order features in the data through multilayer nonlinear transformations, which significantly improves the predictive ability of the model.

A deep learning method combining stacked sparse autoencoders (SSAEs) with Backpropagation Neural Networks (BPNNs) has been proposed (He et al., 2021). This method uses tool temperature measurements from temperature sensors to predict tool wear. Compared with BPNN and SVM models that rely on manually extracted time-frequency domain features, this approach shows high prediction accuracy and stability.

A time window method for obtaining samples and a multivariate equipment life prediction method based on deep Convolutional Neural Networks (CNNs) have been proposed (Li et al., 2018), focusing on feature extraction. To avoid filtering out effective information by the pooling layer, the pooling layer was removed when constructing the network model. Additionally, a deep CNN method for bearing residual life prediction has been introduced (Ren et al., 2018), which combines spectral principal energy vectors into a feature map. This method extracts one-dimensional vectors and inputs them into the deep learning model through a multilayer CNN structure, demonstrating that its prediction accuracy meets the required standards.

Furthermore, deep learning has been applied to predicting the remaining life of batteries (Zhang Y. et al., 2018). Long Short-Term Memory (LSTM) networks are used to learn the long-term dependencies in the capacity degradation of lithium-ion batteries. The LSTM is optimized adaptively with backpropagation and uses dropout regularization to address overfitting. This method shows better learning and generalization than support vector machines and traditional recurrent neural networks.

The deep artificial neural network method has high accuracy for prediction problems such as lifetime prediction, but there is still a lot of room for improvement in efficiency, and there are some limitations in the application of inverse heat conduction solving problems and temperature field reconstruction.

2.3 Deep learning model compression and acceleration method

Neural network pruning is an important method for compressing and accelerating network models. It removes weights and model branches that contribute little to the network's output, yielding a smaller model and thereby achieving compression and acceleration.
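As a minimal illustration of this principle, the sketch below applies unstructured magnitude pruning to a flat list of weights, treating the smallest-magnitude weights as the unimportant ones. The function `magnitude_prune` and its `ratio` parameter are hypothetical; practical pruning methods (including ThiNet, discussed next) use more elaborate importance criteria and usually fine-tune the network afterwards.

```python
def magnitude_prune(weights, ratio):
    """Zero out the `ratio` fraction of weights with the smallest absolute values."""
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1]")
    k = int(len(weights) * ratio)  # number of weights to remove
    # Indices ordered from least to most important (by magnitude)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:k]:
        pruned[i] = 0.0
    return pruned


print(magnitude_prune([0.1, -0.5, 0.05, 2.0], 0.5))  # [0.0, -0.5, 0.0, 2.0]
```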

The ThiNet pruning method (Luo et al., 2017) differs from traditional pruning methods by treating network pruning as a reconstruction optimization problem. This approach determines the pruning strategy for the convolutional kernel of the current layer based on statistical information computed from the reconstruction differences between the inputs and outputs of the subsequent layer.

In addition to network pruning, other lightweight network design methods have been developed. Pointwise group convolution (Zhang X. et al., 2018) performs 1×1 convolutions within channel groups to reduce the computational cost of pointwise convolution. So that grouped convolutions can use features computed by other groups, a channel shuffle operation redistributes features across groups, so that each new group also contains features from the other groups.
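The channel-mixing step can be made concrete with a small sketch of a channel shuffle: channels are viewed as a (groups × channels-per-group) grid, the grid is transposed, and the result is flattened, so consecutive output channels come from different groups. Channels are represented here by plain indices, and `channel_shuffle` is a hypothetical helper rather than a library function.

```python
def channel_shuffle(channels, groups):
    """Reorder channels so that each new group draws one channel from every old group."""
    if len(channels) % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    per_group = len(channels) // groups
    # Row g holds the channels of original group g
    grid = [channels[g * per_group:(g + 1) * per_group] for g in range(groups)]
    # Transpose (read column by column) and flatten
    return [grid[g][j] for j in range(per_group) for g in range(groups)]


print(channel_shuffle([0, 1, 2, 3, 4, 5], groups=2))  # [0, 3, 1, 4, 2, 5]
```

If the next grouped convolution splits this output into groups of two, each group ([0, 3], [1, 4], [2, 5]) contains one channel from each original group.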

Automated machine learning (AutoML) has also been widely used in lightweight neural network design. He et al. (2018) proposed AutoML for Model Compression (AMC), which uses reinforcement learning to sample the design space efficiently and learn compression policies with higher compression ratios, maintaining model performance while reducing human intervention.

Because the large capacity gap between the teacher and student models creates a “generation gap” between them, Wang Y. et al. (2018) proposed a knowledge distillation method that uses the discriminator of a generative adversarial network as a teacher assistant. The student model acts as the generator and, guided by the discriminator, produces a feature distribution similar to that of the teacher model, which aids its learning. Cui et al. (2017) proposed a mutual distillation method in which two groups of untrained student models start learning and solve the task together, i.e., the teacher and student models are trained and updated simultaneously.

According to the above findings, knowledge distillation, a deep learning model compression and acceleration strategy, has been widely applied and developed, but little research has been reported on the application of knowledge distillation techniques in the field of heat conduction inverse problem solving.

2.4 Temperature field reconstruction

Temperature field reconstruction is a key step in solving inverse heat transfer problems, through which accurate reconstruction of the temperature field can lead to a better understanding of the heat transfer process and improve the thermal performance of materials and devices. In recent years, temperature field reconstruction techniques combining advanced algorithms and neural network methods have made significant progress.

An enhanced Bayesian backpropagation neural network based on Kalman filtering has been proposed (Deng and Hwang, 2007), applying the Kalman filtering algorithm to improve the weak generalization ability of the backpropagation algorithm in approximating nonlinear functions. This enhancement improves the performance of the Bayesian backpropagation network in solving the inverse heat conduction problem, and it has been compared with backpropagation networks optimized using other mature algorithms, such as GMB and LMB.

In another study, the temperature-dependent volumetric heat capacity of solid materials was determined using a backpropagation neural network combined with a radial basis function neural network based on full-history information (Czél et al., 2013). Wang H. et al. (2018) proposed a heat flux estimation algorithm based on a linear artificial neural network that identifies a finite impulse response of a linear dynamic system.

In conclusion, temperature field reconstruction plays an important role in the solution of inverse heat conduction problems. By introducing neural networks and other intelligent algorithms, researchers have made many breakthroughs in improving the reconstruction accuracy and computational efficiency. These methods not only enrich the means of solving inverse problems theoretically, but also demonstrate a strong potential in practical applications, providing new ideas for the solution of complex heat conduction problems.

3 Methods

3.1 Acquisition of data sets

In metal cutting, the tool temperature is governed mainly by the combined heat sources of the three deformation zones. Heat enters the tool mainly through the cutting region, which can therefore be regarded as the boundary of the tool heat conduction system. The cutting region generally comprises the tool-chip contact region and the tool-workpiece contact region. When the clearance (back) angle of the tool is large, the cutting time is short, the cutting speed is low, and the back face of the tool does not undergo severe wear, the tool-workpiece contact region can be regarded as part of the tool-chip contact region (Jaspers et al., 1998); the tool-chip contact region then constitutes the cutting region of the tool. The rake (front) face of the tool can be photographed with an electron microscope, and the wear area adjacent to the main cutting edge is taken as the tool-chip contact area. Figure 1 shows an image of the cutting region of the tool for a radial depth of cut (ae) of 0.2 mm and an axial depth of cut (ap) of 8 mm.

www.frontiersin.org

Figure 1. Image of the cutting area of the tool with radial depth of cut (ae) of 0.2 mm and axial depth of cut (ap) of 8 mm.

In this test, a temperature-measuring tool embedded with a thin-film thermocouple (TFTC) array, developed by our group, was used for end milling of an Inconel 718 nickel-based superalloy workpiece; to meet the test requirements, the workpiece measured 50 mm × 20 mm × 10 mm. Based on the heat transfer theory of metal cutting, the spindle speed (r/min), feed rate (mm/min), and radial depth of cut (mm), which strongly influence the milling temperature, were chosen as the test factors. A full factorial design of experiments (DOE) was then used so that every combination of factor levels was tested. Figure 2 shows the transient multi-point temperature measurement toolholder used for milling.
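A full factorial design enumerates every combination of factor levels; with three factors, the 27 runs reported for this study imply three levels per factor. The sketch below generates such a run list using coded levels (−1, 0, +1) because the physical level values are not given in this section; the dictionary keys and the helper `full_factorial` are hypothetical.

```python
from itertools import product

# Coded factor levels (-1 = low, 0 = middle, +1 = high); the physical values
# used in the experiments are not listed here and must be substituted.
FACTORS = {
    "spindle_speed_rpm": (-1, 0, 1),
    "feed_rate_mm_min": (-1, 0, 1),
    "radial_depth_mm": (-1, 0, 1),
}


def full_factorial(factors):
    """Return every combination of factor levels as a list of dicts."""
    names = list(factors)
    return [dict(zip(names, levels)) for levels in product(*factors.values())]


runs = full_factorial(FACTORS)
print(len(runs))  # 27
```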


Figure 2. Milling multi-point temperature measurement toolholder.

The end milling tests were conducted using the constructed test platform and test program, and the temperatures at the four measurement points of the milling tool were recorded and saved. For practical purposes, the tool temperature field needs to be reconstructed only during cutting, not during retraction after machining is complete. In the subsequent data processing, we therefore keep only the segment from tool entry (cut-in) through cut-out, up to the moment before the temperature begins to fall as the tool retracts. Because the tool often vibrates violently during end milling, the collected temperature data contain noise and must be filtered.
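The filter actually applied is not stated here. As a hedged sketch, the code below smooths one thermocouple channel with a centered moving average (a stand-in for whatever filter is used) and then keeps only the cutting segment. The indices `cut_in` and `cool_down_start` and both function names are hypothetical; in practice the indices would be located from the machining program or from the temperature record itself.

```python
def moving_average(signal, window=5):
    """Centered moving average; the window shrinks at the ends so the length is preserved."""
    half = window // 2
    smoothed = []
    for i in range(len(signal)):
        lo, hi = max(0, i - half), min(len(signal), i + half + 1)
        segment = signal[lo:hi]
        smoothed.append(sum(segment) / len(segment))
    return smoothed


def preprocess_channel(raw, cut_in, cool_down_start, window=5):
    """Filter the whole record first (avoids extra edge effects), then keep
    samples from tool entry up to, but excluding, the retraction cool-down."""
    return moving_average(raw, window)[cut_in:cool_down_start]


print(moving_average([1, 2, 3, 4, 5], window=3))  # [1.5, 2.0, 3.0, 4.0, 4.5]
```

A median filter would be a reasonable alternative if vibration produces isolated spikes rather than broadband noise.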

An inverse heat conduction problem for the tool arises when one of the quantities that define tool heat conduction (the governing equation, initial conditions, thermophysical parameters, or boundary conditions) is unknown and must be recovered from physical signals measured by other means. The problem in this paper is a boundary estimation inverse problem of the first kind (a temperature, or Dirichlet, boundary condition): the temperature on the boundary of the tool heat conduction system is estimated from temperature sensor measurements. This boundary temperature cannot be measured physically, or only with poor accuracy, so it is generally obtained by simulation. We therefore built a local numerical simulation model of the Inconel 718 end milling process, configured it with the full factorial test parameters and machining times, and calibrated it against the actual sensor measurements and the observed chip morphology. The calibrated cutting model, which meets the required accuracy, was then used to obtain temperature data on the boundary of the tool heat transfer system.
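For reference, a generic statement of this boundary-estimation problem can be written as follows. This is a textbook-style sketch, not the paper's model: the symbols (density ρ, specific heat c, conductivity k, convection coefficient h, ambient temperature T∞, cutting-region boundary Γc, remaining surfaces Γh, sensor positions x_m with readings Y_m and noise ε_m, cutting duration t_f) are introduced here for illustration, and the convective condition on the other surfaces is an assumption.

```latex
\begin{aligned}
&\rho c\,\frac{\partial T}{\partial t} = \nabla\cdot\bigl(k\,\nabla T\bigr), && \mathbf{x}\in\Omega,\ t>0,\\
&T(\mathbf{x},0) = T_0, && \mathbf{x}\in\Omega,\\
&T(\mathbf{x},t) = T_b(t)\quad\text{(unknown)}, && \mathbf{x}\in\Gamma_c,\\
&-k\,\frac{\partial T}{\partial n} = h\,\bigl(T - T_\infty\bigr), && \mathbf{x}\in\Gamma_h,\\
&Y_m(t) = T(\mathbf{x}_m,t) + \varepsilon_m(t), && m = 1,\dots,4,\\
&\hat{T}_b = \arg\min_{T_b}\ \sum_{m=1}^{4}\int_0^{t_f}\bigl[T(\mathbf{x}_m,t;T_b) - Y_m(t)\bigr]^2\,dt.
\end{aligned}
```

Classical IHCP solvers minimize the last functional iteratively; the data-driven approach here instead learns a direct mapping from the measured temperatures and machining parameters to the boundary temperature.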

After the simulation model has been calibrated against the test results and meets the accuracy requirements, the average temperature of the tool cutting region is extracted from the simulated temperature contour plots. The other 26 sets of temperature curves are also obtained from the simulation model, providing the sample set used to train the model for the subsequent inverse heat conduction problem.

3.2 Construction of gated convolutional recurrent network model

A conventional one-dimensional CNN may ignore temporal features in the input data and thus lose important time-series information. To address this, a GRU can be added after the one-dimensional CNN so that multi-dimensional feature information and the temporal characteristics of the series are extracted together. The gated convolutional recurrent neural network (CNN-GRU) combines the strengths of CNN and GRU models and is commonly used to process time series, text, speech, and video. Its workflow is as follows. First, the input passes through a series of convolution and pooling operations that extract information along the spatial dimension. The resulting local features are then fed into the GRU for sequence modeling: the GRU updates its hidden state dynamically according to the input feature sequence, capturing long-term dependencies. Finally, a fully connected layer maps the GRU hidden state to the output prediction. Used as a solver for the inverse heat transfer problem, the CNN-GRU directly establishes the nonlinear relationship between the machining parameters, the temperature on the back face of the milling cutter, and the temperature of the cutting region. The CNN captures spatial information and the GRU models long-term dependencies in the sequence, enabling rapid solution of the nonlinear inverse heat transfer problem.
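To make the conv → pool → GRU → fully connected workflow concrete, the following dependency-free Python sketch runs one forward pass with random, untrained weights. It is not the paper's model (whose parameters are listed in Table 1): the layer sizes, convolving along the time axis with the input features as channels, the 7 input features (three machining parameters plus four thermocouple temperatures), and all function names are assumptions for illustration. The GRU follows one common formulation, in which the reset gate is applied to the previous hidden state before its recurrent weight.

```python
import math
import random


def conv1d_relu(x, kernels, bias):
    """Valid 1-D convolution along the time axis, followed by ReLU.

    x: list of T time steps, each a list of C_in features.
    kernels: C_out kernels, each of shape [K][C_in]. Returns [T-K+1][C_out].
    """
    k_len, c_in = len(kernels[0]), len(x[0])
    out = []
    for t in range(len(x) - k_len + 1):
        row = []
        for w, b in zip(kernels, bias):
            s = b + sum(w[k][c] * x[t + k][c] for k in range(k_len) for c in range(c_in))
            row.append(max(0.0, s))
        out.append(row)
    return out


def max_pool1d(x, size=2):
    """Non-overlapping max pooling along time; a trailing partial window is dropped."""
    return [[max(x[t + j][c] for j in range(size)) for c in range(len(x[0]))]
            for t in range(0, len(x) - size + 1, size)]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def gru_step(x, h, p):
    """One GRU update with update gate z, reset gate r and candidate state n."""
    z = [sigmoid(a + b + c) for a, b, c in zip(matvec(p["Wz"], x), matvec(p["Uz"], h), p["bz"])]
    r = [sigmoid(a + b + c) for a, b, c in zip(matvec(p["Wr"], x), matvec(p["Ur"], h), p["br"])]
    rh = [ri * hi for ri, hi in zip(r, h)]
    n = [math.tanh(a + b + c) for a, b, c in zip(matvec(p["Wn"], x), matvec(p["Un"], rh), p["bn"])]
    # New state interpolates between the candidate and the previous state
    return [(1.0 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]


def rand_matrix(rng, rows, cols, scale=0.3):
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]


def init_params(c_in, c_conv, k_len, hidden, seed=0):
    """Random (untrained) weights for the hypothetical layer sizes."""
    rng = random.Random(seed)
    p = {"kernels": [rand_matrix(rng, k_len, c_in) for _ in range(c_conv)],
         "conv_b": [0.0] * c_conv,
         "W_out": rand_matrix(rng, 1, hidden), "b_out": 0.0}
    for g in ("z", "r", "n"):
        p["W" + g] = rand_matrix(rng, hidden, c_conv)
        p["U" + g] = rand_matrix(rng, hidden, hidden)
        p["b" + g] = [0.0] * hidden
    return p


def cnn_gru_forward(x, p, hidden):
    """Conv -> pool -> GRU over the pooled sequence -> linear head on the last state."""
    feats = max_pool1d(conv1d_relu(x, p["kernels"], p["conv_b"]))
    h = [0.0] * hidden
    for step in feats:
        h = gru_step(step, h, p)
    return matvec(p["W_out"], h)[0] + p["b_out"]


# One window of 10 time steps x 7 features (illustrative random inputs)
rng = random.Random(1)
window = [[rng.random() for _ in range(7)] for _ in range(10)]
params = init_params(c_in=7, c_conv=4, k_len=3, hidden=8)
y_hat = cnn_gru_forward(window, params, hidden=8)
print(y_hat)
```

In a PyTorch implementation these layers would come from `torch.nn.Conv1d`, `torch.nn.MaxPool1d`, `torch.nn.GRU`, and `torch.nn.Linear`, and the weights would be learned by minimizing the MSE loss.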

The prediction process based on the CNN-GRU model is shown in Figure 3 as follows:

1. Preprocess the original dataset with data normalization and dataset division;

2. Construct the CNN-GRU model;

3. Use the validation set to verify the model accuracy and save the model with the required accuracy;

4. Test the CNN-GRU model with the test set to obtain the final temperature prediction results on the cutting area of the tool.


Figure 3. Prediction process based on the CNN-GRU model.

In the original dataset, milling temperatures, machining parameters, and other quantities have different scales and value ranges, so features with larger values can dominate the weight updates while other features are effectively ignored; normalization removes this effect by putting all features on the same scale. Differences in feature scale also affect training and convergence: if feature scales differ greatly, the gradient magnitudes differ accordingly, distorting the update step sizes in gradient descent and slowing convergence. Normalization makes the gradient descent directions more consistent and speeds up convergence. Common normalization methods include Min-Max scaling, Z-score standardization, and Softmax normalization. Given the data types involved, Min-Max scaling is used here; it maps each feature linearly to the range [0, 1], as shown in Equation (1):

$$x^{*}=\frac{x_i-x_{\min}}{x_{\max}-x_{\min}} \qquad (1)$$

where x* represents the normalized data, xi represents the observed value at moment i, and xmin and xmax are the minimum and maximum values in the data, respectively.

The predictions of the model on the test set are restored after the model training using inverse normalization, which is formulated as Equation (2):

$$x=x^{*}\left(x_{\max}-x_{\min}\right)+x_{\min} \qquad (2)$$

where x is the inverse-normalized value, x* is the normalized prediction, and xmin and xmax are the minimum and maximum values of the data, respectively.
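A minimal sketch of Equations (1) and (2) applied to one feature column follows. Fitting x_min and x_max on the training data only and reusing them for validation and test data is a common precaution against information leakage; the text does not state which subset the extremes come from, so that choice is an assumption. Predictions must be restored with the label column's own x_min and x_max. The temperature values are illustrative only.

```python
def fit_min_max(column):
    """Return (x_min, x_max) for one feature column."""
    return min(column), max(column)


def normalize(column, x_min, x_max):
    """Equation (1): map values linearly so that x_min -> 0 and x_max -> 1."""
    span = x_max - x_min
    if span == 0:
        raise ValueError("constant column cannot be min-max scaled")
    return [(x - x_min) / span for x in column]


def denormalize(column, x_min, x_max):
    """Equation (2): invert the scaling to recover the original units."""
    return [x * (x_max - x_min) + x_min for x in column]


train_temps = [200.0, 350.0, 500.0]      # illustrative values only
lo, hi = fit_min_max(train_temps)
scaled = normalize(train_temps, lo, hi)  # [0.0, 0.5, 1.0]
restored = denormalize(scaled, lo, hi)   # [200.0, 350.0, 500.0]
```

Test values outside the training range map outside [0, 1] (e.g., 650.0 maps to 1.5 here), which is expected and harmless for a regression model.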

To determine the model structure, this study employs Mean Squared Error (MSE) and R-squared (R2) as evaluation metrics. The formulas for these metrics are as Equations (3, 4):

$$\mathrm{MSE}=\frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i-y_i\right)^{2} \qquad (3)$$

where n represents the total number of samples, i denotes the current sample, ŷi is the predicted value for the ith sample, and yi is the true value for the ith sample. A smaller MSE indicates that the model’s predictions are closer to the true values, signifying better model performance.

$$R^{2}=1-\frac{\sum_{i=1}^{n}\left(y_i-\hat{y}_i\right)^{2}}{\sum_{i=1}^{n}\left(y_i-\bar{y}\right)^{2}} \qquad (4)$$

where n is the total number of samples, i is the sample index, ŷi is the predicted value of the ith sample, yi is its true value, and ȳ is the mean of the true values yi. R2 is at most 1 (a perfect fit); it equals 0 for a model that always predicts ȳ and can be negative for a worse model. Values closer to 1 indicate better model performance.
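Equations (3) and (4) translate directly into code. The sketch below is plain Python with hypothetical function names; the three calls illustrate a nonzero MSE, the R2 = 0 baseline of always predicting the mean, and a negative R2.

```python
def mse(y_true, y_pred):
    """Equation (3): mean squared error."""
    n = len(y_true)
    return sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / n


def r2_score(y_true, y_pred):
    """Equation (4): coefficient of determination."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y = [1.0, 2.0, 3.0]
print(mse(y, [0.0, 2.0, 5.0]))       # (1 + 0 + 4) / 3 = 1.666...
print(r2_score(y, [2.0, 2.0, 2.0]))  # 0.0  (always predicting the mean)
print(r2_score(y, [3.0, 2.0, 1.0]))  # -3.0 (worse than the mean)
```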

In a deep-learning-based solver for the inverse heat transfer problem, effective data samples are key to accurately predicting the temperature boundary conditions in the cutting region. Of the 27 sets of full factorial test samples, the 24th set is held out as the test set; 80% of the remaining data form the training set and 20% the validation set. The machining parameters and the multi-point temperatures on the back face of the milling cutter are the input features, and the temperature boundary conditions in the cutting region of the milling cutter are the prediction labels. The network is implemented in Python 3.7 with the PyTorch deep learning framework, running on 64-bit Windows 10 with an NVIDIA GTX 1050Ti GPU.
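The split can be sketched as follows. Runs are numbered 1 to 27 and run 24 is held out whole. The text does not say whether the 80/20 split of the remaining data is made by sample or by run, so this sketch assumes the samples of the 26 remaining runs are pooled and shuffled; `split_dataset`, its arguments, and the seed are hypothetical, and the toy data stand in for real samples.

```python
import random


def split_dataset(runs, test_run=24, train_frac=0.8, seed=0):
    """runs: dict mapping run number -> list of samples.
    Returns (train, val, test) lists of samples."""
    test = list(runs[test_run])
    pool = [s for run_id, samples in sorted(runs.items())
            if run_id != test_run for s in samples]
    random.Random(seed).shuffle(pool)
    n_train = int(len(pool) * train_frac)
    return pool[:n_train], pool[n_train:], test


# Toy data: 27 runs with 10 (run, step) samples each
runs = {r: [(r, t) for t in range(10)] for r in range(1, 28)}
train, val, test = split_dataset(runs)
print(len(train), len(val), len(test))  # 208 52 10
```

With time-series samples, a sample-level shuffle places neighboring time steps of the same run in both the training and validation sets; splitting by run instead gives a stricter validation estimate.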

Among them, the parameters of the model are shown in Table 1, and the overall structure of the CNN-GRU model built in this paper is shown in Figure 4.


Table 1. CNN-GRU model parameters.


Figure 4. General structure of CNN-GRU model.

To verify the validity and accuracy of the model, the CNN-GRU model is compared with CNN, GRU, and LSTM networks, using MSE as the loss function and R2 as the evaluation metric. All models use the dataset split described in the previous section; the parameters of each model are listed in Table 2, and all models use the same number of training epochs and batch size to ensure a fair comparison.

The Loss function curves for the training process of each model are shown in Figure 5A. The figure illustrates that, for the same number of training iterations (150), the CNN model stabilizes its Loss value at approximately the 120th iteration, making it the slowest to converge among the models. In contrast, the LSTM and GRU models stabilize around the 90th iteration. Notably, the CNN-GRU model exhibits a smooth trend and stabilizes as early as the 45th iteration, demonstrating the fastest convergence speed among all the models.


Table 2. Parameter details for each model.


Figure 5. Performance comparison of different models. (A) The Loss function curves for each model training process. (B) Fit curves for each model on the test set.

After 150 training epochs, the final loss values are 2.57 × 10−3 for the CNN-GRU model, 6.37 × 10−3 for the CNN model, 5.3 × 10−3 for the LSTM model, and 6.82 × 10−3 for the GRU model. The CNN-GRU model therefore shows better learning ability and fitting performance; the evaluation metrics of each model are listed in Table 3.
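As a quick check of this comparison, the relative reduction in final loss achieved by the CNN-GRU model over each baseline can be computed directly from the reported values. These are final training losses (presumably on normalized data), so the percentages describe fitting during training rather than test accuracy; the helper name is hypothetical.

```python
final_loss = {"CNN-GRU": 2.57e-3, "CNN": 6.37e-3, "LSTM": 5.3e-3, "GRU": 6.82e-3}


def relative_reduction(ours, baseline):
    """Fraction by which `ours` is lower than `baseline`."""
    return 1.0 - ours / baseline


for name in ("CNN", "LSTM", "GRU"):
    reduction = relative_reduction(final_loss["CNN-GRU"], final_loss[name])
    print(f"CNN-GRU vs {name}: {reduction:.1%} lower final loss")
# CNN-GRU vs CNN: 59.7% lower final loss
# CNN-GRU vs LSTM: 51.5% lower final loss
# CNN-GRU vs GRU: 62.3% lower final loss
```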


Table 3. Evaluation indicators for each model.

The fitting curves of each model on the test set are shown in Figure 5B. The prediction curve of the CNN-GRU model is the closest to the true values: compared with the other models, it predicts the temperature trend in the cutting region of the tool more accurately and fits the peaks and valleys best. This further verifies that the predictions of the CNN-GRU model better meet practical requirements.

3.3 Temperature boundary condition estimation model based on knowledge distillation with gated convolutional recurrent networks

Knowledge distillation is a teacher-student training scheme in which a student model with a simpler network structure learns the knowledge provided by a previously trained teacher model with a more complex structure; it trades a slight loss in performance for faster computation and fewer model parameters. The student model is trained on both the predictions of the teacher model (soft labels) and the real data (hard labels), and its total loss is a weighted sum of its losses on the soft and hard labels, which effectively "migrates" the knowledge learned by the teacher model to the student model. The structure of the knowledge distillation strategy used in this paper is shown in Figure 6.


Figure 6. Structure of knowledge distillation strategy.

The specific knowledge distillation strategy process is as follows:

1. The preprocessed raw data is input to both the teacher model and the student model. The teacher model is the CNN-GRU model constructed in the previous section, and the student model is a small model with a single CNN layer and a single GRU layer.

2. The output of the teacher model is softened using the Softmax function with temperature coefficient T. The processed labels are used as soft labels.

3. The student model output is softened with the same temperature-scaled Softmax function, and the softened student output and the teacher's soft labels from the previous step are passed through the distillation loss function LOSS_soft to obtain the distillation loss between the student model and the teacher model.

4. The unsoftened student model output and the real hard labels are passed through the student loss function LOSS_hard to obtain the student loss.

5. The distillation loss and the student loss are combined as a weighted sum to obtain the total loss, whose gradient with respect to each parameter is computed by backpropagation and used to update the parameters.

The following are the calculation formulas involved in the knowledge distillation operation process:

The soft labels used in knowledge distillation are computed with Equation (5):

$$q_i = \frac{\exp(z_i / T)}{\sum_{j=0}^{k} \exp(z_j / T)} \qquad (5)$$

where z_i is the ith output of the model being softened and T is the distillation temperature coefficient, which controls the "hardness" of the soft labels: a larger T yields a more uniform, softer distribution, while a smaller T yields a distribution closer to the hard labels.
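Equation (5) is a temperature-scaled softmax. The sketch below is a plain-Python illustration, not the authors' PyTorch code; the function name `softmax_with_temperature` is hypothetical. Subtracting the largest scaled value before exponentiating leaves q_i unchanged but avoids overflow.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Eq. (5): q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    scaled = [z / T for z in logits]
    # Subtracting the maximum does not change q_i but keeps exp() finite.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, T=1)  # ≈ [0.665, 0.245, 0.090]
soft = softmax_with_temperature(logits, T=7)   # ≈ [0.382, 0.331, 0.287]
```

Raising T from 1 to 7 preserves the ranking of the outputs but flattens the distribution, which is the "softening" described above.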

The distillation loss LOSS_soft is given by Equation (6):

$$\mathrm{LOSS}_{\mathrm{soft}} = \sum_{i=0}^{k} -p_i(u_i, T)\,\log p_i(z_i, T) \qquad (6)$$

where k is the total number of samples, p_i(u_i, T) is the ith output of the teacher model at temperature coefficient T, and p_i(z_i, T) is the ith output of the student model at temperature coefficient T.

The student loss LOSS_hard is given by Equation (7):

$$\mathrm{LOSS}_{\mathrm{hard}} = \sum_{i=0}^{k} -y_i\,\log p_i(z_i, 1) \qquad (7)$$

where y_i is the ith component of the hard (ground-truth) label vector, and p_i(z_i, 1) is the ith output of the unsoftened student model (that is, at T = 1).

The total loss of knowledge distillation can be expressed as Equation (8):

$$\mathrm{LOSS}_{\mathrm{total}} = \lambda\,\mathrm{LOSS}_{\mathrm{soft}} + (1 - \lambda)\,\mathrm{LOSS}_{\mathrm{hard}} \qquad (8)$$

where λ is a weighting hyperparameter, a fixed constant that can be tuned empirically or adjusted dynamically.
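Equations (5) to (8) can be combined into a single loss computation. This is a minimal plain-Python sketch that follows the equations as written, not the authors' PyTorch implementation; the names `softmax_t` and `distillation_loss` are hypothetical, the defaults T = 7 and λ = 0.8 are the values selected later in this section, and the small `eps` guard against log(0) is an added assumption.

```python
import math


def softmax_t(logits, T=1.0):
    """Temperature-scaled softmax, Eq. (5)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, hard_labels,
                      T=7.0, lam=0.8, eps=1e-12):
    """Total loss of Eq. (8), built from Eqs. (6) and (7)."""
    p_teacher = softmax_t(teacher_logits, T)      # soft labels p_i(u_i, T)
    p_student_T = softmax_t(student_logits, T)    # softened student p_i(z_i, T)
    p_student_1 = softmax_t(student_logits, 1.0)  # unsoftened student p_i(z_i, 1)
    # Eq. (6): cross-entropy between teacher soft labels and softened student.
    loss_soft = sum(-p * math.log(q + eps) for p, q in zip(p_teacher, p_student_T))
    # Eq. (7): cross-entropy between hard labels and unsoftened student.
    loss_hard = sum(-y * math.log(q + eps) for y, q in zip(hard_labels, p_student_1))
    # Eq. (8): weighted total loss.
    return lam * loss_soft + (1.0 - lam) * loss_hard
```

Some distillation implementations additionally scale the soft term by T² to keep its gradient magnitude comparable across temperatures; Equation (8) does not include that factor, so neither does this sketch. In a PyTorch training loop the same computation would be written with tensors so that the total loss can be backpropagated.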

Following this knowledge distillation strategy, the CNN-GRU teacher model is optimized by first constructing a simple CNN-GRU student model. Whereas the teacher model has 2 CNN layers and 2 GRU layers, the student model is designed with 1 CNN layer and 1 GRU layer. To determine the optimal total number of neurons for the student model, student models with 10, 20, 30, 40, 50, 60, 70, 80, and 90% of the teacher model's total number of neurons were designed, and each was trained by gradient descent using the same training, validation, and test sets, with the same number of training epochs and batch size as the teacher model. The learning ability (agreement with the true values) and single-step training time of each student model trained without the knowledge distillation strategy are compared first; the prediction results of the student models at each percentage are shown in Figure 7A.


Figure 7. Comparison of model performance before and after adding knowledge distillation. (A) Comparison of student model training results for each scale. (B) Comparison of student model predictions for each scale after knowledge distillation. (C) Comparison of model performance improvement before and after acceleration by the knowledge distillation strategy.

As the figure shows, before receiving knowledge guidance from the teacher model, the goodness of fit of the student model increases gradually with the total number of neurons and levels off once the student model reaches 60% of the teacher model's size. This indicates that the closer the student model is to the teacher model in size, the better it learns the data. However, because of its structural limitations, the simple student model cannot accurately capture the complex nonlinear relationship between the input and output data: beyond 60%, its goodness of fit still fluctuates slightly, but its overall learning ability improves little. The single-step training time of the student model also grows with the total number of neurons, and its rate of increase rises once the student model reaches 70% of the teacher model's size. This shows that the closer the student model's neuron count is to that of the teacher model, the slower the model's inference and the more hardware resources it occupies.

Next, using the knowledge distillation strategy, the teacher model trained in the previous section guides the training of the student models of different sizes, transferring the knowledge learned by the teacher model to the student models. To avoid random errors, the distillation temperature coefficient T is searched over the integers in [1, 10] and the total-loss weighting factor λ over [0.1, 0.9] in steps of 0.1. The distillation effect under every combination of T and λ is compared one by one, and T = 7 and λ = 0.8 are finally selected. The prediction results of the student models at each percentage after distillation are shown in Figure 7B.
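The exhaustive search over T and λ amounts to a 10 × 9 grid (90 combinations). The sketch below is illustrative: `evaluate(T, lam)` is a hypothetical callback that would distill a student with the given settings and return a validation score where higher is better (for example R²); the source does not say which score was used for the comparison.

```python
import itertools


def grid_search(evaluate, T_values=range(1, 11), lam_values=None):
    """Try every (T, lambda) pair and return (best_score, T, lambda).

    evaluate is a hypothetical callback: evaluate(T, lam) -> score (higher is better).
    """
    if lam_values is None:
        # lambda in [0.1, 0.9], one decimal place.
        lam_values = [round(0.1 * k, 1) for k in range(1, 10)]
    best = None
    for T, lam in itertools.product(T_values, lam_values):
        score = evaluate(T, lam)
        if best is None or score > best[0]:
            best = (score, T, lam)
    return best
```

Because each evaluation requires a full distillation run, the cost grows with the grid size; a coarse-to-fine search is a common alternative when training is expensive.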

As the figure shows, all student models guided by the teacher model (CNN-GRU + KD) improve markedly in goodness of fit, and R² levels off at around 0.96 once the student model reaches 60% of the teacher model's size. After distillation, the single-step training time is longer than that of the corresponding original model, and above 80% the single-step training time approaches that of the teacher model. Therefore, considering both the goodness of fit and the single-step training time, the total number of neurons of the student model is set to 60% of that of the teacher model; the network parameters of the student model are shown in Table 4. The performance improvement of the model before and after acceleration by the knowledge distillation strategy is compared in Figure 7C.
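The choice of the 60% student balances goodness of fit against training time. One hypothetical way to formalize such a trade-off is to pick the smallest student whose R² lies within a tolerance of the best R² observed. The function `select_student_ratio`, the `(ratio, r2, step_time)` tuple layout, and the 0.005 tolerance are all assumptions for illustration, not the authors' procedure.

```python
def select_student_ratio(results, r2_tol=0.005):
    """Pick the smallest student/teacher ratio whose R^2 is within r2_tol
    of the best R^2. Hypothetical selection rule.

    results: list of (ratio, r2, step_time) tuples measured after distillation.
    """
    if not results:
        raise ValueError("results must be non-empty")
    best_r2 = max(r2 for _, r2, _ in results)
    near_best = [row for row in results if best_r2 - row[1] <= r2_tol]
    # Among near-best students, prefer the smallest, then the fastest.
    return min(near_best, key=lambda row: (row[0], row[2]))
```

Loosening `r2_tol` favors smaller, faster students; setting it to 0 simply returns the student with the highest R².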
