1. 程式人生 > >nn.Sequential函式裡面必須是a Module subclass,不能是一個列表或者是其他的迭代器,雖然這裡麵包含了Module的子類

nn.Sequential函式裡面必須是a Module subclass,不能是一個列表或者是其他的迭代器,雖然這裡麵包含了Module的子類

class RES(nn.Module):
    def __init__(self):
        super(RES, self).__init__()
        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1=nn.BatchNorm2d(64)
        self.relu=nn.ReLU(inplace=True)
        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn2=nn.BatchNorm2d(128)
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.maxpool(x)
        x=self.conv2(x)
        x=self.bn2(x)
        return x

model=RES()
glb = nn.Sequential(*list(model.children())[:4])

有兩點資料的說明:這個類繼承了Module一定要用super函式

nn.Sequential函式裡面的引數一定是Module的子類,而list:list is not a Module subclass。所以不能當做引數,當然model.children()也是一樣:Module.children is not a Module subclass。這裡的*就起了作用,將list或者children的內容迭代的一個一個的傳進去,效果如下:

 當然,我們還可以像最上面的那樣,選取裡面的幾個Module,例如[:4]也就是第0個到第3個.