返回列表

6th place solution

686. PhysioNet - Digitization of ECG Images | physionet-ecg-image-digitization

开始: 2025-10-21 结束: 2026-01-22 医学影像分析 数据算法赛
第 6 名解决方案 - 直接回归导联信号就是你所需的一切

第 6 名解决方案

直接回归导联信号就是你所需的一切

作者: AnnieGo
竞赛排名: 6
发布时间: 2026-01-23

致谢

我想向 Kaggle 和竞赛组织者提供这次宝贵的机会表示诚挚的感谢。特别感谢 @hengck23 分享了他的强基线模型,这为我的工作奠定了 crucial 基础。

总结

我的解决方案专注于优化 @hengck23 提供的基线的第二阶段。关键的见解是直接回归导联信号,绕过传统的“分割后处理”流程。通过将其视为直接回归任务,模型可以更有效地学习信号特征,并减少后处理阶段通常引入的累积误差。

整体流程

整体流程图

重采样 (Resample)

训练数据集包含采样率多样的 ECG 信号,范围从 2.5 kHz10 kHz。为了确保高质量的 ground truth 并保持最佳信噪比 (SNR),需要一致的重采样策略。我实施了一个基准测试框架来评估各种重采样算法的保真度——包括 polyphase, linear, cubic spline, and FFT-based methods。性能是通过以下变换来测量的:

  1. 上/下采样:将原始信号重采样到目标长度(例如 2560、5120 或 10250)。
  2. 恢复:将信号重采样回其原始长度。
  3. 评估:通过将“恢复”的信号与“原始”ground truth 进行比较来计算 SNR。

关键观察结果如下:

  • 对于这个信号处理任务,scipy.signal.resample (基于 FFT) 产生的结果明显优于 torch.nn.functional.interpolate (线性/双线性)。
  • 保真度与中间采样密度呈正相关;较高的中间长度 (10250 > 5120 > 2560) 导致 substantially 较低的信息损失。
重采样性能对比

此外,为了加速图像重采样并将其集成到训练过程中,我使用了 resample_torch

def resample_torch(self, x, num, dim=-1):
        dim = (x.dim() + dim) if dim < 0 else dim
        X = torch.fft.fft(x, dim=dim)
        Nx = X.shape[dim]

        sl = [slice(None)] * X.ndim
        newshape = list(X.shape)
        newshape[dim] = num
        Y = torch.zeros(newshape, dtype=X.dtype, device=X.device)

        N = min(num, Nx)
        sl[dim] = slice(0, (N + 1) // 2)
        Y[sl] = X[sl]
        sl[dim] = slice(-(N - 1) // 2, None)
        Y[sl] = X[sl]

        if N % 2 == 0:
            if N < Nx:
                sl[dim] = slice(N//2, N//2+1)
                Y[sl] += X[sl]
            elif N < num:
                sl[dim] = slice(num-N//2, num-N//2+1)
                Y[sl] /= 2
                temp = Y[sl]
                sl[dim] = slice(N//2, N//2+1)
                Y[sl] = temp

        y = torch.fft.ifft(Y, dim=dim).real * (float(num) / float(Nx))
        return y

信号回归头 (Signal Regression Head)

该模块通过使用 Soft-Argmax 机制估计垂直坐标,将 2D 特征嵌入转换为 ECG 导联的精确物理电压值。

# Signal Regression Head
class MaskEmbeddingToLeadSignalSoftArgmax(nn.Module):
    def __init__(self, n_leads=4, embedding_dim=32, temperature=0.5):
        super().__init__()
        
        self.n_leads = n_leads
        self.temperature = temperature
        self.lead_y_logits = nn.Conv2d(embedding_dim, n_leads, kernel_size=1)
        
        self.register_buffer('zero_mv', torch.tensor([703.5, 987.5, 1271.5, 1531.5]).view(1, 4, 1))
        self.register_buffer('mv_to_pixel', torch.tensor(79.0))

    def forward(self, masked_feat):
        B, C, H, W = masked_feat.shape
        device = masked_feat.device
        dtype = masked_feat.dtype

        y_logits = self.lead_y_logits(masked_feat)
        prob = torch.softmax(y_logits / self.temperature, dim=2)

        y_coord = torch.arange(H, device=device, dtype=dtype).view(1,1,H,1)
        y_pixel = (prob * y_coord).sum(dim=2)  # [B, L, W]
        
        pred_mv = (self.zero_mv - y_pixel) / self.mv_to_pixel
        return pred_mv, prob

实验结果

信号长度 TTA (水平翻转) Epochs LB 分数
2560 False 60 21.36
5120 False 100 22.20
5120 True 100 22.36
5120 True 150 22.43

代码

同比赛其他方案