python - Using zip() on two nn.ModuleList - Stack Overflow

Is using two different nn.ModuleList() zipped lists correct to build the computational graph for training a neural net in PyTorch? nn.ModuleList is a wrapper around Python's list that registers each contained module, so its parameters are tracked for training.
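
For context, that registration is the whole point of nn.ModuleList: submodules stored in a plain Python list are invisible to parameters() and thus to the optimizer. A minimal illustration (Demo is just a throwaway module for this check):

    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.plain = [nn.Linear(4, 4)]                       # plain list: NOT registered
            self.registered = nn.ModuleList([nn.Linear(4, 4)])  # registered submodule

    d = Demo()
    # Only the ModuleList's Linear is visible to the optimizer:
    print(sum(p.numel() for p in d.parameters()))  # 20 (16 weights + 4 biases)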

I'm building a network whose encoder consists of two alternating types of blocks, created in __init__:

    def __init__(self, in_channels):
        super().__init__()

        self.encoder_conv_blocks = nn.ModuleList()
        self.downsample_blocks = nn.ModuleList()

        for out_channels in _FILTERS:
            conv_block = _ConvBlock(in_channels, _CONV_BLOCK_LEN, _CONV_BLOCK_GROWTH_RATE)
            downsample_block = _DownsampleBlock(conv_block.out_channels, out_channels)

            self.encoder_conv_blocks.append(conv_block)
            self.downsample_blocks.append(downsample_block)

            in_channels = out_channels

Later, in forward, I zip the two lists, since I need the outputs of the first type of block for skip connections:

    def forward(self, x):
        skip_connections = []

        for conv_block, downsample_block in zip(self.encoder_conv_blocks,
                                                self.downsample_blocks):
            x = conv_block(x)
            skip_connections.append(x)
            x = downsample_block(x)
        # ... (decoder consuming skip_connections omitted)

However, when printing the model summary with torchinfo, the summary of the registered modules using two zipped nn.ModuleLists looks different from the summary where a single nn.ModuleList was used. I suspect that this could cause issues for training and inference down the line.

zip(nn.ModuleList(), nn.ModuleList()):

========================================================================================================================
Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
========================================================================================================================
MyNet                                     [16, 4, 128, 256]         [16, 3, 128, 256]         --
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-1                        [16, 4, 128, 256]         [16, 84, 128, 256]        26,360
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-2                  [16, 84, 128, 256]        [16, 64, 64, 128]         48,448
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-3                        [16, 64, 64, 128]         [16, 144, 64, 128]        70,160
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-4                  [16, 144, 64, 128]        [16, 128, 32, 64]         166,016
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-5                        [16, 128, 32, 64]         [16, 208, 32, 64]         116,880
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-6                  [16, 208, 32, 64]         [16, 128, 16, 32]         239,744
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-7                        [16, 128, 16, 32]         [16, 208, 16, 32]         116,880
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-8                  [16, 208, 16, 32]         [16, 128, 8, 16]          239,744
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-9                        [16, 128, 8, 16]          [16, 208, 8, 16]          116,880
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-10                 [16, 208, 8, 16]          [16, 256, 4, 8]           479,488
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-11                       [16, 256, 4, 8]           [16, 336, 4, 8]           210,320
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-12                 [16, 336, 4, 8]           [16, 256, 2, 4]           774,400
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-13                       [16, 256, 2, 4]           [16, 336, 2, 4]           210,320
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-14                 [16, 336, 2, 4]           [16, 512, 1, 2]           1,548,800

single nn.ModuleList():

========================================================================================================================
Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
========================================================================================================================
MyNet                                     [16, 4, 128, 256]         [16, 3, 128, 256]         --
├─ModuleList: 1-1                             --                        --                        --
│    └─_ConvBlock: 2-1                        [16, 4, 128, 256]         [16, 84, 128, 256]        26,360
│    └─_DownsampleBlock: 2-2                  [16, 84, 128, 256]        [16, 64, 64, 128]         48,448
│    └─_ConvBlock: 2-3                        [16, 64, 64, 128]         [16, 144, 64, 128]        70,160
│    └─_DownsampleBlock: 2-4                  [16, 144, 64, 128]        [16, 128, 32, 64]         166,016
│    └─_ConvBlock: 2-5                        [16, 128, 32, 64]         [16, 208, 32, 64]         116,880
│    └─_DownsampleBlock: 2-6                  [16, 208, 32, 64]         [16, 128, 16, 32]         239,744
│    └─_ConvBlock: 2-7                        [16, 128, 16, 32]         [16, 208, 16, 32]         116,880
│    └─_DownsampleBlock: 2-8                  [16, 208, 16, 32]         [16, 128, 8, 16]          239,744
│    └─_ConvBlock: 2-9                        [16, 128, 8, 16]          [16, 208, 8, 16]          116,880
│    └─_DownsampleBlock: 2-10                 [16, 208, 8, 16]          [16, 256, 4, 8]           479,488
│    └─_ConvBlock: 2-11                       [16, 256, 4, 8]           [16, 336, 4, 8]           210,320
│    └─_DownsampleBlock: 2-12                 [16, 336, 4, 8]           [16, 256, 2, 4]           774,400
│    └─_ConvBlock: 2-13                       [16, 256, 2, 4]           [16, 336, 2, 4]           210,320
│    └─_DownsampleBlock: 2-14                 [16, 336, 2, 4]           [16, 512, 1, 2]           1,548,800

asked Mar 26 at 18:02 by Ivan Tishchenko
  • zip is just simultaneously iterating over both module lists and the summary looks correct to me. Is there something particular about the summary that makes you think there's an issue? – jodag Commented Mar 27 at 2:24
  • @jodag I'm concerned about the indexes of ModuleList in the first table and about 'recursive' in the parameter column in the first table (need to scroll the table horizontally to see the last column). In the second table the indexes are sequential and there is no mention of 'recursive'. The overall number of parameters and modules is the same in both cases though. – Ivan Tishchenko Commented Mar 27 at 9:17

1 Answer

Both methods are equivalent; the difference in the printout is just an artifact of how torchinfo crawls the model.

torchinfo traces the model's forward pass, recording every module involved. If the same module shows up more than once, it is labeled recursive. For nn.ModuleList objects, using items from the same ModuleList at non-consecutive points of the forward pass gets flagged as recursive simply because the ModuleList container shows up more than once in different places. Here's a simple example:

Example 1:

import torch.nn as nn
from torchinfo import summary

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.ModuleList([nn.Linear(8, 8) for i in range(2)])
        self.l2 = nn.Linear(8, 8)

    def forward(self, x):
        x = self.l1[0](x)  # both ModuleList items used back to back
        x = self.l1[1](x)
        x = self.l2(x)
        return x

m = MyModel()
summary(m, (1, 8), depth=5)

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MyModel                                  [1, 8]                    --
├─ModuleList: 1-1                        --                        --
│    └─Linear: 2-1                       [1, 8]                    72
│    └─Linear: 2-2                       [1, 8]                    72
├─Linear: 1-2                            [1, 8]                    72
==========================================================================================
Total params: 216
Trainable params: 216
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

Example 2:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.ModuleList([nn.Linear(8, 8) for i in range(2)])
        self.l2 = nn.Linear(8, 8)

    def forward(self, x):
        x = self.l1[0](x)  # ModuleList items used at non-consecutive points
        x = self.l2(x)
        x = self.l1[1](x)
        return x

m = MyModel()
summary(m, (1, 8), depth=5)

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MyModel                                  [1, 8]                    --
├─ModuleList: 1-3                        --                        (recursive)
│    └─Linear: 2-1                       [1, 8]                    72
├─Linear: 1-2                            [1, 8]                    72
├─ModuleList: 1-3                        --                        (recursive)
│    └─Linear: 2-2                       [1, 8]                    72
==========================================================================================
Total params: 216
Trainable params: 216
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

In the first example, all layers in the ModuleList are used consecutively, and no recursive flag appears. In the second, the layers in the ModuleList are used at different times, and the ModuleList object itself gets the recursive flag. This is just an artifact of how torchinfo crawls the model.
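
You can also confirm the equivalence directly: registration happens in __init__, not in forward, so both toy models expose identical parameter sets regardless of the call order inside forward. A quick sanity check on either MyModel above:

m = MyModel()
print(sum(p.numel() for p in m.parameters()))  # 216, matching both summaries
print([name for name, _ in m.named_parameters()])
# ['l1.0.weight', 'l1.0.bias', 'l1.1.weight', 'l1.1.bias', 'l2.weight', 'l2.bias']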

As a purely stylistic note, there's nothing wrong with zipping ModuleLists, but if you know each _ConvBlock will always be paired one-to-one with a _DownsampleBlock, you might consider putting them into a combined module:

class CombinedBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # constructor arguments follow the signatures from the question
        self.conv_block = _ConvBlock(in_channels, _CONV_BLOCK_LEN, _CONV_BLOCK_GROWTH_RATE)
        self.down_block = _DownsampleBlock(self.conv_block.out_channels, out_channels)

    def forward(self, x):
        x = self.conv_block(x)
        skip = x  # pre-downsample activation, kept for the skip connection
        x = self.down_block(x)
        return x, skip
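
With that refactor the encoder needs only a single ModuleList, which should also make the recursive flag disappear, since the container is then used consecutively. A sketch of how the question's __init__ and forward might look (reusing _FILTERS from the question; the decoder is still omitted):

class MyNet(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.encoder_blocks = nn.ModuleList()
        for out_channels in _FILTERS:
            self.encoder_blocks.append(CombinedBlock(in_channels, out_channels))
            in_channels = out_channels

    def forward(self, x):
        skip_connections = []
        for block in self.encoder_blocks:
            x, skip = block(x)
            skip_connections.append(skip)
        # ... decoder consuming skip_connections goes here ...
        return x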
