PyTorch skip connection in a sequential model
Your observations are correct, but you may have missed the definition of `UnetSkipConnectionBlock.forward()` (`UnetSkipConnectionBlock` being the `Module` defining the U-Net block you shared), which may clarify this implementation (from pytorch-CycleGAN-and-pix2pix/models/networks.py#L259):
```python
import torch
import torch.nn as nn

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    # ...
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # dim 1 is the channel dimension of NCHW tensors
            return torch.cat([x, self.model(x)], 1)
```
The last line is the key (it applies to all inner blocks). The skip connection is implemented simply by concatenating the input `x` with the (recursive) block output `self.model(x)` along dimension 1 (the channel dimension), where `self.model` is the list of operations you mentioned. So it is not so different from the `Functional` code you wrote.
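To make this concrete, here is a minimal, self-contained sketch of the same pattern. `SkipBlock` and its parameters are hypothetical names, not the repo's API, and the block is deliberately simplified: it omits the normalization, dropout, and activation choices of the original `UnetSkipConnectionBlock`, assumes NCHW inputs with even spatial sizes, and uses a stride-2 conv / transposed-conv pair for down- and upsampling. It shows two consequences of the concatenation: the identity path survives unchanged in the first channels of an inner block's output, and a parent block's upsampling layer must accept `inner_nc * 2` channels because its submodule returns the concatenation.

```python
import torch
import torch.nn as nn


class SkipBlock(nn.Module):
    """Hypothetical, simplified U-Net-style block (not the repo's class).

    Runs down -> [submodule] -> up inside an nn.Sequential and, unless it is
    the outermost block, concatenates its input with that result on dim 1.
    """

    def __init__(self, outer_nc, inner_nc, submodule=None, outermost=False):
        super().__init__()
        self.outermost = outermost
        # Stride-2 conv halves H and W (for even sizes); the transposed conv doubles them back.
        down = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        if submodule is None:
            # Innermost block: nothing in between, so `up` sees inner_nc channels.
            up = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [down, nn.ReLU(), up]
        else:
            # The submodule returns cat([its input, its output]): 2 * inner_nc channels.
            up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [down, nn.ReLU(), submodule, nn.ReLU(), up]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # Skip connection: identity path x stacked with the processed path.
        return torch.cat([x, self.model(x)], 1)


inner = SkipBlock(outer_nc=8, inner_nc=16)  # innermost block
net = SkipBlock(outer_nc=4, inner_nc=8, submodule=inner, outermost=True)

x = torch.randn(1, 4, 32, 32)
y = net(x)
print(y.shape)  # torch.Size([1, 4, 32, 32])

# An inner block outputs twice its input channels, with the input first.
z = torch.randn(1, 8, 16, 16)
out = inner(z)
print(out.shape)  # torch.Size([1, 16, 16, 16])
print(torch.equal(out[:, :8], z))  # True
```

Because each block is itself an `nn.Module`, it can sit inside an `nn.Sequential` like any other layer. The concatenation lives in `forward()` rather than in a module, which is why it does not show up when you print the list of operations.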