686. PhysioNet - Digitization of ECG Images | physionet-ecg-image-digitization
我想向 Kaggle 和竞赛组织者提供这次宝贵的机会表示诚挚的感谢。特别感谢 @hengck23 分享了他的强基线模型,这为我的工作奠定了 crucial 基础。
我的解决方案专注于优化 @hengck23 提供的基线的第二阶段。关键的见解是直接回归导联信号,绕过传统的“分割后处理”流程。通过将其视为直接回归任务,模型可以更有效地学习信号特征,并减少后处理阶段通常引入的累积误差。
训练数据集包含采样率多样的 ECG 信号,范围从 2.5 kHz 到 10 kHz。为了确保高质量的 ground truth 并保持最佳信噪比 (SNR),需要一致的重采样策略。我实施了一个基准测试框架来评估各种重采样算法的保真度——包括 polyphase, linear, cubic spline, and FFT-based methods。性能是通过以下变换来测量的:
关键观察结果如下:
此外,为了加速图像重采样并将其集成到训练过程中,我使用了 resample_torch:
def resample_torch(self, x, num, dim=-1):
dim = (x.dim() + dim) if dim < 0 else dim
X = torch.fft.fft(x, dim=dim)
Nx = X.shape[dim]
sl = [slice(None)] * X.ndim
newshape = list(X.shape)
newshape[dim] = num
Y = torch.zeros(newshape, dtype=X.dtype, device=X.device)
N = min(num, Nx)
sl[dim] = slice(0, (N + 1) // 2)
Y[sl] = X[sl]
sl[dim] = slice(-(N - 1) // 2, None)
Y[sl] = X[sl]
if N % 2 == 0:
if N < Nx:
sl[dim] = slice(N//2, N//2+1)
Y[sl] += X[sl]
elif N < num:
sl[dim] = slice(num-N//2, num-N//2+1)
Y[sl] /= 2
temp = Y[sl]
sl[dim] = slice(N//2, N//2+1)
Y[sl] = temp
y = torch.fft.ifft(Y, dim=dim).real * (float(num) / float(Nx))
return y
该模块通过使用 Soft-Argmax 机制估计垂直坐标,将 2D 特征嵌入转换为 ECG 导联的精确物理电压值。
# Signal Regression Head
class MaskEmbeddingToLeadSignalSoftArgmax(nn.Module):
def __init__(self, n_leads=4, embedding_dim=32, temperature=0.5):
super().__init__()
self.n_leads = n_leads
self.temperature = temperature
self.lead_y_logits = nn.Conv2d(embedding_dim, n_leads, kernel_size=1)
self.register_buffer('zero_mv', torch.tensor([703.5, 987.5, 1271.5, 1531.5]).view(1, 4, 1))
self.register_buffer('mv_to_pixel', torch.tensor(79.0))
def forward(self, masked_feat):
B, C, H, W = masked_feat.shape
device = masked_feat.device
dtype = masked_feat.dtype
y_logits = self.lead_y_logits(masked_feat)
prob = torch.softmax(y_logits / self.temperature, dim=2)
y_coord = torch.arange(H, device=device, dtype=dtype).view(1,1,H,1)
y_pixel = (prob * y_coord).sum(dim=2) # [B, L, W]
pred_mv = (self.zero_mv - y_pixel) / self.mv_to_pixel
return pred_mv, prob
| 信号长度 | TTA (水平翻转) | Epochs | LB 分数 |
|---|---|---|---|
| 2560 | False | 60 | 21.36 |
| 5120 | False | 100 | 22.20 |
| 5120 | True | 100 | 22.36 |
| 5120 | True | 150 | 22.43 |