454. G2Net Gravitational Wave Detection | g2net-gravitational-wave-detection
所有模型均使用 TensorFlow 在 TPU 上实现。以下前端被所有模型使用,维度 dim=128。
input = L.Input(shape=(3, 4096)) # 3个观测值
x = tf.reshape(input,(-1,4096,1)) # 将3个观测值折叠进batch
x = wavenet(x, dim, dilations=12, kernel_size=5)
x = L.Dense(dim//4)(x)
x = wavenet(x, dim, dilations=12, kernel_size=5)
x = L.Dense(dim//4)(x)
x = L.Dense(dim)(x)
x = tf.reshape(x, (-1,3,4096,dim))
x = tf.transpose(x,(0,2,3,1))
x = L.BatchNormalization()(x)
x = L.Activation('gelu')(x)
模型1:
Efficient B3 (尺寸=128*128), 交叉验证分数 87.79-88.04
模型2:
Wavenet + GRU, 交叉验证分数 87.80-88.05
模型3:
Wavenet, 交叉验证分数 87.83-88.07