nn.Sequential函式裡面必須是a Module subclass,不能是一個列表或者是其他的迭代器,雖然這裡麵包含了Module的子類
阿新 • • 發佈:2018-12-18
class RES(nn.Module): def __init__(self): super(RES, self).__init__() self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) self.bn1=nn.BatchNorm2d(64) self.relu=nn.ReLU(inplace=True) self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1) self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False) self.bn2=nn.BatchNorm2d(128) def forward(self,x): x=self.conv1(x) x=self.bn1(x) x=self.relu(x) x=self.maxpool(x) x=self.conv2(x) x=self.bn2(x) return x model=RES() glb = nn.Sequential(*list(model.children())[:4])
有兩點資料的說明:這個類繼承了Module一定要用super函式
nn.Sequential函式裡面的引數一定是Module的子類,而list:list is not a Module subclass。所以不能當做引數,當然model.children()也是一樣:Module.children is not a Module subclass。這裡的*就起了作用,將list或者children的內容迭代的一個一個的傳進去,效果如下:
當然,我們還可以像最上面的那樣,選取裡面的幾個Module,例如[:4]也就是第0個到第3個.