def ResidualBlock(width):
    """Build a residual block with swish activation.

    Args:
        width: number of filters used by every convolution in the block; the
            block's output has this many channels.

    Returns:
        A callable ``apply(x)`` that normalizes ``x`` (BatchNormalization with
        no learned center/scale), runs it through a 3x3 swish convolution and a
        plain 3x3 convolution, and adds the result to a shortcut branch.
    """

    def apply(x):
        # Channel count is read from axis 3, so this presumably expects a
        # channels-last 4D tensor (batch, height, width, channels) — TODO confirm.
        in_channels = ops.shape(x)[3]

        # Shortcut branch: project with a 1x1 convolution when the channel
        # count differs from `width`, so the final Add() sees matching shapes;
        # otherwise pass the input through unchanged.
        if in_channels != width:
            shortcut = layers.Conv2D(width, kernel_size=1)(x)
        else:
            shortcut = x

        # Main branch. Layers are created in the same order as before so that
        # Keras' auto-generated layer names are unaffected.
        h = layers.BatchNormalization(center=False, scale=False)(x)
        h = layers.Conv2D(
            width, kernel_size=3, padding="same", activation="swish"
        )(h)
        h = layers.Conv2D(width, kernel_size=3, padding="same")(h)
        return layers.Add()([h, shortcut])

    return apply
def DownBlock(width, block_depth):
    """Build a downsampling stage made of residual blocks.

    Args:
        width: number of channels for each residual block in the stage.
        block_depth: how many residual blocks to stack before pooling.

    Returns:
        A callable ``apply(x)`` where ``x`` is a ``(features, skips)`` pair.
        It passes ``features`` through ``block_depth`` fresh residual blocks,
        appending each block's output to ``skips`` in place, then applies
        2x2 average pooling and returns the pooled tensor.
    """

    def apply(x):
        features, skip_store = x
        for _ in range(block_depth):
            features = ResidualBlock(width)(features)
            # Mutates the caller's list; presumably drained later by UpBlock,
            # which pops from the same kind of list.
            skip_store.append(features)
        return layers.AveragePooling2D(pool_size=2)(features)

    return apply
def UpBlock(width, block_depth):
    """Build an upsampling stage made of residual blocks.

    Args:
        width: number of channels for each residual block in the stage.
        block_depth: how many skip tensors to consume, and how many residual
            blocks to apply.

    Returns:
        A callable ``apply(x)`` where ``x`` is a ``(features, skips)`` pair.
        It upsamples ``features`` by 2 with bilinear interpolation, then
        ``block_depth`` times pops the most recent entry off ``skips``
        (mutating it), concatenates that entry with the current features
        (Keras ``Concatenate`` default axis, -1), and applies a residual block.
        If ``skips`` is a list with too few entries, ``pop`` raises IndexError.
    """

    def apply(x):
        features, skip_store = x
        features = layers.UpSampling2D(size=2, interpolation="bilinear")(features)
        for _ in range(block_depth):
            merged = layers.Concatenate()([features, skip_store.pop()])
            features = ResidualBlock(width)(merged)
        return features

    return apply