Here is my blog post for the paper "Global-Reasoned Multi-Task Learning Model for Surgical Scene Understanding" by Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, and Hongliang Ren (Senior Member, IEEE).
INTRODUCTION
Fig. 1: Enhancing surgical scene understanding (tool interaction detection and instrument segmentation) through global-local relational reasoning.
Global and local relational reasoning enable scene understanding models to perform human-like scene analysis. Inspired by the graph-based global reasoning network, which performs global reasoning in the latent space to efficiently capture global relations, we propose a globally-reasoned multi-task surgical scene understanding model that performs instrument segmentation and detects tool-tissue interactions. Globally-reasoned surgical scene understanding is critical for surgical skill assessment, real-time and post-surgical analysis, augmented tactile feedback, and automated surgical report generation. By combining the GloRe unit, which reasons in the latent space, with multi-scale feature decoder aggregation, which captures local relations at multiple scales, the semantic segmentation model is expected to perform better scene reasoning. To detect tool actions, we improve upon the visual-semantic graph attention network (VS-GAT) [20] and introduce a globally-reasoned VS-GAT: by embedding globally-reasoned latent features into VS-GAT, we hypothesize that the model detects globally-reasoned node-to-node interactions. By sharing the feature encoder and the GloRe unit, we also reduce the computational cost compared to running two independent single-task models.
Key Contributions
Propose a globally-reasoned multi-task learning (MTL) surgical scene understanding model that performs instrument segmentation and tool-tissue interaction detection.
Improve the MTL model’s segmentation performance by incorporating latent global interaction reasoning and introducing multi-scale local reasoning.
Utilize the MTL model setup to enhance interaction detection performance by sharing a generalized feature extractor for visual feature extraction and incorporating globally-reasoned features from the segmentation module into the scene graph (tool interaction detection) model.
Study sequential and knowledge distillation (KD) based optimization techniques for training MTL models toward optimal convergence.
RELATED WORK
Surgical Instrument Segmentation
To address the spatial inconsistency problem, instance-based segmentation has been proposed [17, 15]. Current state-of-the-art (SOTA) models in instrument segmentation include MF-TAPNet [15] and ISINet [8]. MF-TAPNet [15] employs an attention mechanism and utilizes temporal optical flow. Built on top of Mask R-CNN [10], ISINet [8] employs a temporal consistency strategy to take advantage of the temporal frame sequence.
A refined attention-based network, RASNet [23], was also proposed; it uses an attention mechanism for semantic segmentation to leverage the global context of high-level features and focus on key regions of the image.
As an alternative to these prior works, we propose a simple and efficient global- and local-reasoned model that achieves competitive performance against existing SOTA models.
Surgical Tool Interaction Detection
Initially, human-to-object interaction detection was achieved by employing Fast-RCNN [21] and Faster-RCNN [7].
The limited robustness of these approaches was addressed by formulating the interaction detection task in non-Euclidean space and employing graph networks to detect interactions [24, 20].
The graph parsing neural network (GPNN) [24] models each scene as a sparse graph, with nodes being the objects and edges denoting the presence of interaction. While GPNN relies mainly on visual features to detect object-to-object interactions, the visual-semantic graph attention network (VS-GAT) [20] additionally utilizes spatial and semantic features on top of visual features to detect interactions.
Here, we further improve the VS-GAT in detecting interaction by including globally-reasoned features.
Multi-Task Learning
A single MTL model offers a computational advantage over multiple single-task learning (STL) models.
GradNorm [5] helps balance the learning of independent task sub-modules, thereby balancing independent task influence on the shared module and improving synchronization in model convergence.
An attention-based MTL optimization technique [12] has also been proposed that enables the model's independent tasks to converge sequentially.
In this work, we also implement and study the performance of (i) sequential, (ii) vanilla, and (iii) KD-based MTL optimization in training our proposed MTL model.
METHODOLOGY
Fig. 2: The proposed network architecture. The proposed globally-reasoned multi-task scene understanding model consists of a shared feature extractor. The segmentation module performs latent global reasoning (GloRe [4] unit) and local reasoning (multi-scale local reasoning) to segment instruments. To detect tool interaction, the scene graph (tool interaction detection) model incorporates the global interaction space features to further improve the performance of the visual-semantic graph attention network [20].
Global And Local Reasoning For Instrument Segmentation
Fig. 3: Multi-scale global reasoning for instrument segmentation
A simple encoder-decoder pair incorporating global reasoning is employed to achieve competitive performance with the SOTA models in instrument segmentation.
The GloRe unit [4] is employed to reason about global interactions in the latent space. Since this reasoning is limited to the latent interaction space, we also include a Multi-Scale Local Reasoning (MSLR) module.
Here, multi-scale decoder aggregation is performed to capture multi-scale local (neighborhood) relations in coordinate space. For the decoder block, we design a lightweight decoder comprising (a) a conv block (conv-BatchNorm-ReLU), (b) dropout, and (c) a final conv layer.
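A minimal sketch of such a lightweight decoder block is shown below; the channel sizes and dropout rate are illustrative assumptions, not values from the paper:

```python
import torch.nn as nn

class LightDecoderBlock(nn.Module):
    """Lightweight decoder block as described above:
    conv-BatchNorm-ReLU, dropout, then a final conv layer."""
    def __init__(self, in_ch, out_ch, p_drop=0.1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p_drop),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.block(x)
```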
To improve instrument segmentation, three variants of global reasoning, (i) vanilla GR, (ii) Multi-scale global reasoning (MSGR), and (iii) multi-scale local reasoning and GR (MSLRGR), have been studied.
(i) Vanilla GR: the GloRe unit [4] is naively applied to reason on the encoder's latent features.
(ii) Multi-scale global reasoning (MSGR): the GloRe unit is employed to reason on multi-scale interactions, as shown in Fig. 3.
(iii) Multi-scale local reasoning and GR (MSLRGR): vanilla GR is combined with multi-scale local (neighborhood) reasoning (MSLR); a sketch of this variant's forward pass is shown after this list.
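As an illustration of the MSLRGR variant, a possible forward pass is sketched below, assuming PyTorch. The interfaces are assumptions: `encoder` returns a list of multi-scale feature maps, `glore` is the GloRe unit, `decoders` is one lightweight decoder per scale, and `fuse_head` maps the aggregated features to per-pixel class logits.

```python
import torch
import torch.nn.functional as F

def mslrgr_forward(encoder, glore, decoders, fuse_head, image):
    """Hypothetical MSLRGR forward pass: latent global reasoning (GloRe) on the
    deepest features plus multi-scale local reasoning via lightweight decoders
    whose outputs are upsampled and aggregated."""
    feats = encoder(image)               # list of feature maps, shallow -> deep
    feats[-1] = glore(feats[-1])         # global reasoning in the latent interaction space
    h, w = feats[0].shape[-2:]
    outs = []
    for f, dec in zip(feats, decoders):  # local (neighborhood) reasoning at each scale
        o = dec(f)
        outs.append(F.interpolate(o, size=(h, w), mode="bilinear", align_corners=False))
    fused = torch.cat(outs, dim=1)       # multi-scale decoder aggregation
    return fuse_head(fused)              # per-pixel segmentation logits
```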
Global Reasoning For Interaction Detection
The VS-GAT [20] network employs two sub-graphs: (a) a visual graph (Gv) and (b) a semantic graph (Gs), embedded with visual features (Fvf) and semantic features (Fsemf), respectively. The two graphs are propagated and fused to form a combined graph (Gc), whose edges are embedded with spatial features (Fsf, derived from the bounding-box locations). We append the global interaction space features (GISF, FGISF) from the segmentation module's GloRe unit to the combined graph's edges. This allows the model to predict interactions based on both node-to-node and global latent interaction reasoning: Y = G(Fvf, Fsemf, Fsf, FGISF).
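Below is a minimal sketch of this edge-feature augmentation, assuming PyTorch tensors and a simple mean-pooling of the GloRe latent nodes into a single GISF vector; the pooling choice and dimensions are assumptions, not the authors' exact implementation.

```python
import torch

def embed_gisf_on_edges(edge_feats, spatial_feats, glore_latent_nodes):
    """Append the pooled global interaction space feature (GISF) from the GloRe
    unit to every edge of the combined graph, alongside the spatial features.
    edge_feats: (E, d_e), spatial_feats: (E, d_s), glore_latent_nodes: (N, d_g)."""
    gisf = glore_latent_nodes.mean(dim=0)                          # pool latent nodes -> (d_g,) [assumed]
    gisf_per_edge = gisf.unsqueeze(0).expand(edge_feats.shape[0], -1)
    return torch.cat([edge_feats, spatial_feats, gisf_per_edge], dim=-1)
```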
Multi-Task Optimization
To address the asynchronous convergence problem, we explore three different optimization techniques.
The first one, Vanilla-MTL (V-MTL) optimization, naively combines the loss of both tasks during the training.
L_V-MTL = (α ∗ Lsg) + ((1 − α) ∗ Lseg)
In the second variant, KD-based MTL (KD-MTL) optimization [19] is explored.
The KD-MTL favors the segmentation task in training the feature encoder. Here, the task losses are combined with a Kullback-Leibler divergence (KLD) [18] loss between the feature encoder outputs of the STL segmentation model and the MTL model. By reducing the KLD loss between the two encoders' outputs, we aim to improve the convergence of the segmentation task within the MTL model.
L_KD-MTL = (α ∗ Lsg) + Lseg + L_KLD-seg
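As a concrete illustration, here is a minimal sketch of the two loss variants, assuming PyTorch; α = 0.5 and the softmax-over-flattened-features form of the KLD term are illustrative assumptions (the post only states that a KLD loss between encoder outputs is used).

```python
import torch.nn.functional as F

def v_mtl_loss(loss_sg, loss_seg, alpha=0.5):
    """V-MTL: naively combine the scene-graph and segmentation losses."""
    return alpha * loss_sg + (1.0 - alpha) * loss_seg

def kd_mtl_loss(loss_sg, loss_seg, feat_mtl, feat_stl_seg, alpha=0.5):
    """KD-MTL: task losses plus a KL-divergence term that pulls the MTL encoder's
    output towards the (frozen) STL segmentation encoder's output.
    feat_mtl, feat_stl_seg: encoder feature maps of shape (B, C, H, W)."""
    log_p_mtl = F.log_softmax(feat_mtl.flatten(1), dim=1)
    p_stl = F.softmax(feat_stl_seg.flatten(1), dim=1)
    l_kld_seg = F.kl_div(log_p_mtl, p_stl, reduction="batchmean")
    return alpha * loss_sg + loss_seg + l_kld_seg
```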
The final optimization technique involves optimizing the MTL model sequentially (S-MTL).
As shown in Algorithm 1, the MTL model's feature encoder and segmentation module are first trained based on the segmentation loss. Upon convergence, the weights of the feature encoder and segmentation blocks are frozen. The scene graph is then trained to detect interactions until convergence.
Algorithm 1: Sequential multi-task (S-MTL) optimization
1: Initialize model weights: shared feature extractor (Wsh), scene segmentation (Wseg), scene graph (Wsg)
2: Set gradient accumulators to zero: dWsh ← 0, dWseg ← 0, dWsg ← 0
3: Optimize feature extractor and segmentation network:
   while segmentation task not converged do
      (segmentor and feature extractor gradients w.r.t. segmentation loss Lseg)
      dWsh ← dWsh + Σ_i δ_i ∇Wsh L(Wsh, Wseg)
      dWseg ← dWseg + Σ_i δ_i ∇Wseg L(Wsh, Wseg)
   end while
4: Optimize scene graph:
   while scene graph task not converged do
      (scene graph block gradients w.r.t. scene graph loss Lsg)
      dWsg ← dWsg + Σ_i δ_i ∇Wsg L(Wsh, Wseg, Wsg)
   end while
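A minimal PyTorch sketch of this two-stage schedule is given below. The module and loader names (encoder, seg_head, scene_graph, seg_loader, sg_loader) are assumptions, and fixed epoch counts stand in for the convergence checks.

```python
import torch

def train_s_mtl(encoder, seg_head, scene_graph, seg_loader, sg_loader,
                seg_criterion, sg_criterion, epochs=130, lr=1e-4):
    """Sketch of Algorithm 1 (S-MTL): train encoder + segmentation first,
    freeze them, then train the scene-graph (interaction) module."""
    # Stage 1: optimize the shared feature extractor and segmentation network
    opt = torch.optim.Adam(list(encoder.parameters()) + list(seg_head.parameters()), lr=lr)
    for _ in range(epochs):
        for images, masks in seg_loader:
            loss = seg_criterion(seg_head(encoder(images)), masks)
            opt.zero_grad(); loss.backward(); opt.step()

    # Freeze the feature extractor and segmentation weights
    for p in list(encoder.parameters()) + list(seg_head.parameters()):
        p.requires_grad = False

    # Stage 2: optimize the scene graph (interaction detection) module
    opt_sg = torch.optim.Adam(scene_graph.parameters(), lr=lr)
    for _ in range(epochs):
        for graphs, labels in sg_loader:
            loss = sg_criterion(scene_graph(graphs), labels)
            opt_sg.zero_grad(); loss.backward(); opt_sg.step()
```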
EXPERIMENTS
Dataset
The model is trained and evaluated for interaction detection and instrument segmentation on the MICCAI Endoscopic Vision Challenge 2018 [1] dataset.
Implementation Details
We employ cross-entropy loss to calculate the segmentation loss and multi-label loss to calculate the interaction detection loss. The models are trained using the Adam optimizer [16]. The feature extractor is initially loaded with ImageNet pre-trained weights. The learning rate at epoch = 0 is set to xxx and is decayed by 0.98 every 10 epochs. Our models are trained for 130 epochs with a batch size of 4.
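A sketch of this training setup follows; the base learning rate is left unspecified in the post, so 1e-4 below is purely an assumption, while the losses, decay schedule, epoch count, and batch size follow the text.

```python
import torch
import torch.nn as nn

def build_training_setup(model, base_lr=1e-4):
    seg_criterion = nn.CrossEntropyLoss()        # segmentation loss
    sg_criterion = nn.BCEWithLogitsLoss()        # multi-label interaction loss (assumed form)
    optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
    # Decay the learning rate by a factor of 0.98 every 10 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.98)
    return seg_criterion, sg_criterion, optimizer, scheduler

# Training then runs for 130 epochs with a batch size of 4,
# calling scheduler.step() once per epoch.
```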
Multi-Task Model Improving Single-Task Performance
Fig. 4: Variants of feature sharing between the segmentation and scene graph modules in multi-task setting to improve single-task performance
(In the table below, Acc, mAP, and Recall measure tool interaction detection; mIoU, P-Acc, and the class-wise IoU columns T0-T7 measure instrument segmentation.)

Model | Acc | mAP | Recall | mIoU | P-Acc | T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
SOTA (Surgical scene graph) |||||||||||||
GPNN [24] | 0.5500 | 0.1934 | - | - | - | - | - | - | - | - | - | - | - |
Islam et al. [13] | 0.4802 | 0.2157 | - | - | - | - | - | - | - | - | - | - | - |
G-Hpooling [28] | 0.3321 | 0.1523 | - | - | - | - | - | - | - | - | - | - | - |
VS-GAT[20] | 0.6537 | 0.2560 | 0.2666 | - | - | - | - | - | - | - | - | - | - |
SOTA (Surgical scene segmentation) | |||||||||||||
LinkNet34 [2] | - | - | - | 0.2610 | 0.93 | 0.9193 | 0.3581 | 0.1481 | 0.0062 | 0.6488 | 0.0004 | 0.0071 | 0.0000 |
AlbUNet [26] | - | - | - | 0.2471 | 0.91 | 0.9090 | 0.3610 | 0.0923 | 0.0064 | 0.6082 | 0.0000 | 0.0000 | 0.0000 |
Ternaus-UNet11 [26] | - | - | - | 0.2406 | 0.917 | 0.8904 | 0.3267 | 0.0741 | 0.0055 | 0.6283 | 0.0000 | 0.0000 | 0.0000 |
Ternaus-UNet16 [26] | - | - | - | 0.2329 | 0.918 | 0.8811 | 0.3069 | 0.0923 | 0.0062 | 0.5763 | 0.0000 | 0.0003 | 0.0000 |
MF-TAPNet [15] | - | - | - | 0.2489 | 0.931 | 0.9310 | 0.2961 | 0.0225 | 0.0000 | 0.7420 | 0.0000 | 0.0000 | 0.0000 |
MF-TAPNet11 [15] | - | - | - | 0.3568 | 0.955 | 0.9729 | 0.6142 | 0.2338 | 0.0100 | 0.8420 | 0.0030 | 0.1634 | 0.0153 |
MF-TAPNet34 [15] | - | - | - | 0.3543 | 0.952 | 0.9767 | 0.6636 | 0.3435 | 0.0284 | 0.8222 | 0.0000 | 0.0000 | 0.0000 |
ResNet18 [11] | - | - | - | 0.3858 | 0.9487 | 0.9533 | 0.5764 | 0.3810 | 0.0008 | 0.8353 | 0.0073 | 0.2763 | 0.0557 |
ResNet18 + GloRe [4] | - | - | - | 0.3926 | 0.9483 | 0.9524 | 0.5764 | 0.3842 | 0.0009 | 0.8256 | 0.0720 | 0.2847 | 0.0448 |
S-MTL (Ours) | |||||||||||||
MSLRGR | 0.7003 | 0.2885 | 0.3096 | 0.4354 | 0.9638 | 0.9714 | 0.6966 | 0.4356 | 0.0015 | 0.8716 | 0.1203 | 0.3471 | 0.0387 |
MSLRGR-GISFSG | 0.6994 | 0.3131 | 0.3157 | 0.4354 | 0.9638 | 0.9714 | 0.6966 | 0.4356 | 0.0015 | 0.8716 | 0.1203 | 0.3471 | 0.0387 |
KD-MTL (Ours) | |||||||||||||
MSLRGR | 0.6434 | 0.2992 | 0.2818 | 0.4105 | 0.9617 | 0.9704 | 0.6775 | 0.3650 | 0.0002 | 0.8598 | 0.0709 | 0.3330 | 0.0069 |
MSLRGR-GISFSG | 0.6710 | 0.2808 | 0.3072 | 0.4105 | 0.9611 | 0.9713 | 0.6670 | 0.3785 | 0.0028 | 0.8603 | 0.0458 | 0.3184 | 0.0401 |
TABLE I: Comparison of our proposed globally-reasoned multi-task scene understanding model (S-MTL-MSLRGR-GISFSG) and its variants' performance against state-of-the-art models in segmentation and tool-tissue interaction detection. T0-T7 are tool classes as stated in Section IV-A.
Fig. 5: Qualitative analysis. Top: comparison of our proposed model and its variants' instrument segmentation against select benchmark models and the ground truth (GT). Bottom: comparison of our proposed model's interaction detection against vanilla VS-GAT [20] and the ground truth (GT). Here, our proposed model refers to S-MTL-MSLRGR-GISFSG (a sequentially trained multi-task learning model with multi-scale local reasoning and global reasoning, whose scene graph is enhanced with global interaction space features).
(In the table below, GR, MSGR, and MSLR indicate the reasoning used in the feature encoder / segmentation branch; PF and GISFSG indicate the feature sharing (SF) with the scene graph; Acc, mAP, and Recall measure tool interaction detection; mIoU and P-Acc measure segmentation.)

Model | GR [4] | MSGR | MSLR | PF | GISFSG | Acc | mAP | Recall | mIoU | P-Acc |
---|---|---|---|---|---|---|---|---|---|---|
STL | | | | | | | | | | |
VS-GAT [20] | | | | | | 0.6537 | 0.2560 | 0.2666 | - | - |
SEG | | | | | | - | - | - | 0.3858 | 0.9487 |
SEG-GR | ✓ | | | | | - | - | - | 0.3926 | 0.9483 |
SEG-MSGR | ✓ | ✓ | | | | - | - | - | 0.4350 | 0.9628 |
SEG-MSLRGR | ✓ | | ✓ | | | - | - | - | 0.4354 | 0.9638 |
S-MTL | | | | | | | | | | |
GR | ✓ | | | | | 0.6787 | 0.2578 | 0.3042 | 0.3926 | 0.9483 |
MSGR | ✓ | ✓ | | | | 0.6813 | 0.2906 | 0.3040 | 0.4350 | 0.9628 |
MSLRGR | ✓ | | ✓ | | | 0.7003 | 0.2885 | 0.3096 | 0.4354 | 0.9638 |
MSLRGR-PF | ✓ | | ✓ | ✓ | | 0.6848 | 0.2960 | 0.3157 | 0.4354 | 0.9638 |
MSLRGR-GISFSG | ✓ | | ✓ | | ✓ | 0.6994 | 0.3131 | 0.3157 | 0.4354 | 0.9638 |
TABLE II: Ablation study highlighting the importance of multi-scale local and global reasoning (MSLRGR) and use of global interaction space feature in the scene graph (GISFSG) in improving sequentially optimized multi-task learning (S-MTL) model.
(Columns are grouped into three blocks: "best in tool interaction detection", "best in instrument segmentation", and "balanced performance". Within each block, Acc and mAP measure tool interaction detection, while mIoU and P-Acc measure segmentation.)

Model | Acc | mAP | mIoU | P-Acc | Acc | mAP | mIoU | P-Acc | Acc | mAP | mIoU | P-Acc |
---|---|---|---|---|---|---|---|---|---|---|---|---|
V-MTL-GR | 0.6193 | 0.2303 | 0.3521 | 0.9420 | 0.5073 | 0.2580 | 0.3731 | 0.9455 | 0.6064 | 0.2327 | 0.3621 | 0.9447 |
KD-MTL-GR | 0.6615 | 0.2531 | 0.3609 | 0.9449 | 0.6391 | 0.2472 | 0.3730 | 0.9453 | 0.6555 | 0.2522 | 0.3713 | 0.9458 |
KD-MTL-MSLRGR | 0.6649 | 0.2644 | 0.4022 | 0.9610 | 0.6322 | 0.2724 | 0.4165 | 0.9622 | 0.6434 | 0.2992 | 0.4105 | 0.9617 |
KD-MTL-MSLRGR-SGFSEG | 0.6589 | 0.2636 | 0.3974 | 0.9593 | 0.6184 | 0.2829 | 0.4188 | 0.9607 | 0.6503 | 0.2600 | 0.4111 | 0.9608 |
KD-MTL-MSLRGR-GISFSG | 0.6830 | 0.2818 | 0.4034 | 0.9613 | 0.6339 | 0.2819 | 0.4169 | 0.9617 | 0.6710 | 0.2808 | 0.4105 | 0.9611 |
TABLE III: Ablation Study on multi-task learning (MTL) model optimized using Vanilla-MTL (V-MTL) and Knowledge Distillation-based MTL (KD-MTL) optimization techniques.
To study whether the multi-task setup improves performance over the STL models, we experiment with three variants of feature sharing between the two task modules:
i) the vanilla MTL model, in which only the feature encoder is shared, aiming to improve the interaction detection model;
ii) Fig. 4 (i): the interaction features from the edges of VS-GAT's combined graph (GC) are appended to the latent interaction space features in the segmentation module's GloRe unit (SGFSEG); a minimal sketch of this fusion follows this list;
iii) the final variant, used in our proposed model: global interaction space features from the GloRe unit are used to improve the scene graph's interaction detection (GISFSG, Fig. 4 (ii)).
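For completeness, here is a hypothetical sketch of variant (ii) (SGFSEG), assuming the combined-graph edge features are mean-pooled and linearly projected before being fused with the GloRe unit's latent interaction-space nodes; the pooling, projection, and tensor shapes are illustrative assumptions rather than the authors' implementation.

```python
import torch
import torch.nn as nn

class SceneGraphToGloReFusion(nn.Module):
    """Fuse pooled scene-graph edge features into the GloRe latent nodes (SGFSEG sketch)."""
    def __init__(self, latent_dim, edge_dim):
        super().__init__()
        self.proj = nn.Linear(latent_dim + edge_dim, latent_dim)  # project back to latent size

    def forward(self, latent_nodes, sg_edge_feats):
        # latent_nodes: (B, N, latent_dim); sg_edge_feats: (E, edge_dim)
        pooled = sg_edge_feats.mean(dim=0)                        # (edge_dim,) [assumed pooling]
        B, N, _ = latent_nodes.shape
        pooled = pooled.view(1, 1, -1).expand(B, N, -1)
        return self.proj(torch.cat([latent_nodes, pooled], dim=-1))
```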
RESULTS AND EVALUATION
Quantitatively, the model's performance in segmenting instruments and detecting interactions is benchmarked against the respective single-task SOTA models. Segmentation performance is quantified using the mean intersection over union (mIoU), class-wise IoU, and pixel accuracy (P-Acc) metrics. Interaction detection is quantified using accuracy (Acc), mean average precision (mAP), and recall. We observe that our globally-reasoned multi-task model (S-MTL-MSLRGR-GISFSG) performs on par with, and in most cases outperforms, STL models in both instrument segmentation (mIoU and P-Acc) and interaction detection (Acc, mAP, and Recall).
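For reference, a standard per-class IoU / mIoU computation looks like the sketch below; the exact averaging and class handling used by the challenge's evaluation scripts may differ.

```python
import numpy as np

def class_iou_and_miou(pred, target, num_classes):
    """pred, target: integer label maps of the same shape."""
    ious = []
    for c in range(num_classes):
        intersection = np.logical_and(pred == c, target == c).sum()
        union = np.logical_or(pred == c, target == c).sum()
        ious.append(intersection / union if union > 0 else np.nan)  # skip classes absent from both
    return ious, float(np.nanmean(ious))
```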
Qualitatively, it is also observed that the model’s performance with global reasoning in latent space is further enhanced by incorporating multi-scale local reasoning.
The segmentation performance of the SOTA models differs significantly from that reported in their original works due to three main changes: (i) the train and test sets, (ii) the number and type of classes, and (iii) the resolution of the input image.
DISCUSSION AND CONCLUSION
In this paper, a globally-reasoned multi-task surgical scene understanding model is proposed to perform instrument segmentation and tool-tissue interaction detection.
The model’s performance is improved by (i) introducing multi-scale local (neighborhood) reasoning and incorporating latent global reasoning and (ii) introducing global interaction space features into the scene graph.
The detailed study also shows that the proposed model performs on par with, and in most cases outperforms, existing SOTA single-task models on the MICCAI Endoscopic Vision Challenge 2018 dataset.
STUDENT'S REVIEW
In this paper, the performance of a globally-reasoned multi-task surgical scene understanding model for instrument segmentation and interaction detection is improved by incorporating global relational reasoning in the latent interaction space and by introducing multi-scale local (neighborhood) reasoning in the coordinate space to improve segmentation.
The S-MTL optimization algorithm: the MTL model's feature encoder and segmentation module are trained on the segmentation loss → upon convergence, the weights of the feature encoder and segmentation blocks are frozen → the scene graph is then trained to detect interactions until convergence.
The ablation study on the MTL model shows that, while V-MTL and KD-MTL optimization also work, sequential training results in optimal convergence; however, further improving the segmentation task by incorporating scene graph features remains future work because of the asynchronous convergence of the model.
Remaining Question:
REFERENCES
- [1] 2018 Robotic Scene Segmentation Challenge. arXiv preprint arXiv:2001.11190, 2020.
- [2] LinkNet: Exploiting encoder representations for efficient semantic segmentation. arXiv preprint arXiv:1707.03718, 2017.
- [3] Image categorization by learning and reasoning with regions. Journal of Machine Learning Research, 5, pp. 913–939, 2004.
- [4] Graph-based global reasoning networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 433–442, 2019.
- [5] GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks. arXiv preprint arXiv:1711.02257, 2017.
- [6] ToolNet: Holistically-nested real-time segmentation of robotic surgical tools. arXiv preprint arXiv:1706.08126, 2017.
- [7] Detecting and recognizing human-object interactions. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8359–8367, 2018.
- [8] ISINet: An instance-based approach for surgical instrument segmentation. arXiv preprint arXiv:2007.05533, 2020.
- [9] Inductive representation learning on large graphs. arXiv preprint arXiv:1706.02216, 2017.
- [10] Mask R-CNN. arXiv preprint arXiv:1703.06870, 2017.
- [11] Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778, 2016.
- [12] Learning where to look while tracking instruments in robot-assisted surgery. arXiv preprint arXiv:1907.00214, 2019.
- [13] Learning and reasoning with the graph structure representation in robotic surgery. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 627–636, 2020.
- [14] AP-MTL: Attention pruned multi-task learning model for real-time instrument detection and segmentation in robot-assisted surgery. In 2020 IEEE International Conference on Robotics and Automation (ICRA), pp. 8433–8439, 2020.
- [15] Incorporating temporal prior from motion flow for instrument segmentation in minimally invasive surgery video. arXiv preprint arXiv:1907.07899, 2019.
- [16] Adam: A method for stochastic optimization. In 3rd International Conference on Learning Representations (ICLR), San Diego, CA, USA, 2015.
- [17] Identifying surgical instruments in laparoscopy using deep learning instance segmentation. In 2019 International Conference on Content-Based Multimedia Indexing (CBMI), pp. 1–6, 2019.
- [18] On information and sufficiency. The Annals of Mathematical Statistics, 22(1), pp. 79–86, 1951.
- [19] Knowledge distillation for multi-task learning. In European Conference on Computer Vision, pp. 163–176, 2020.
- [20] Visual-semantic graph attention networks for human-object interaction detection. arXiv preprint arXiv:2001.02302, 2020.
- [21] Learning models for actions and person-object interactions with transfer to question answering. arXiv preprint arXiv:1604.04808, 2016.
- [22] When does label smoothing help? arXiv preprint arXiv:1906.02629, 2019.
- [23] RASNet: Segmentation for tracking surgical instruments in surgical videos using refined attention segmentation network. In 2019 41st Annual International Conference of the IEEE Engineering in Medicine and Biology Society (EMBC), pp. 5735–5738, 2019.
- [24] Learning human-object interactions by graph parsing neural networks. arXiv preprint arXiv:1808.07962, 2018.
- [25] U-Net: Convolutional networks for biomedical image segmentation. arXiv preprint arXiv:1505.04597, 2015.
- [26] Automatic instrument segmentation in robot-assisted surgery using deep learning.
- [27] Non-local neural networks. arXiv preprint arXiv:1711.07971, 2017.
- [28] Hierarchical graph pooling with structure learning. arXiv preprint arXiv:1911.05954, 2019.
- [29] Pyramid scene parsing network. arXiv preprint arXiv:1612.01105, 2016.