这个版本做了精简，模型更小一些。
另一篇介绍文章：
https://blog.youkuaiyun.com/jacke121/article/details/97677477
# !/usr/bin/env python
# -*- coding: utf-8 -*-
import time
import torch
import torch.nn as nn
import math
class Swish(nn.Module):
    """Swish activation: ``f(x) = x * sigmoid(x)``.

    The sigmoid is kept as a child module (``self.sigmoid``) instead of
    being called functionally; ``nn.Sigmoid`` holds no parameters, so this
    adds nothing to the state dict.
    """

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled element-wise by its own sigmoid."""
        gate = self.sigmoid(x)
        return x * gate
# Registry mapping an activation name to a ready-made activation module.
# NOTE(review): each entry is one shared instance, so every lookup returns
# the same module object; the ReLU is constructed with inplace=True and
# therefore overwrites its input tensor.
NON_LINEARITY = dict(
    ReLU=nn.ReLU(inplace=True),
    Swish=Swish(),
)
def _RoundChannels(c, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_c = max(min_value, int(c + divisor /