#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import matplotlib.pyplot as plt
def relu(z):
    # ReLU: max(0, z), computed with an element-wise maximum.
    return torch.maximum(z, torch.as_tensor(0.))
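# Sanity check (illustrative addition, not part of the original script):
# the hand-rolled relu should agree with PyTorch's built-in element-wise.
_probe = torch.tensor([-2., -0.5, 0., 0.5, 2.])
assert torch.equal(relu(_probe), torch.nn.functional.relu(_probe))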
# Leaky ReLU: returns z where z > 0 and negative_slope * z elsewhere,
# built from boolean masks cast to float32.
def leaky_relu(z, negative_slope=0.1):
    a1 = (z > 0).to(torch.float32) * z                      # keep the positive part
    a2 = (z <= 0).to(torch.float32) * (negative_slope * z)  # scale the negative part
    return a1 + a2
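# Quick checks (illustrative additions, not part of the original script):
# 1) the mask-based leaky_relu should match torch.nn.functional.leaky_relu;
# 2) for a negative input, ReLU's gradient is 0 while Leaky ReLU's gradient
#    equals negative_slope, which is why Leaky ReLU avoids "dead" neurons.
_probe = torch.tensor([-2., -0.5, 0., 0.5, 2.])
assert torch.allclose(leaky_relu(_probe, 0.1),
                      torch.nn.functional.leaky_relu(_probe, negative_slope=0.1))
_neg = torch.tensor(-3., requires_grad=True)
relu(_neg).backward()
print(_neg.grad)   # tensor(0.)
_neg = torch.tensor(-3., requires_grad=True)
leaky_relu(_neg, 0.1).backward()
print(_neg.grad)   # tensor(0.1000)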
# Sample the input range densely and plot both activations for comparison.
z = torch.linspace(-10, 10, 10000)
plt.figure()
plt.plot(z.tolist(), relu(z).tolist(), color='#e4007f', label='ReLU Function')
plt.plot(z.tolist(), leaky_relu(z).tolist(), color='#f19ec2', linestyle='--',
         label='Leaky ReLU Function')
# Move the spines so the x- and y-axes cross at the origin.
ax = plt.gca()
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
ax.spines['left'].set_position(('data', 0))
ax.spines['bottom'].set_position(('data', 0))
plt.legend(loc='upper left', fontsize='large')
plt.savefig('fw-relu-leakyrelu.pdf')
plt.show()
