#!/usr/bin/env python3
# -*-coding:utf8-*-
import numpy as np
def max_pooling(feature_map, size=2, stride=2, dtype=np.uint8):
    """Apply 2-D max pooling independently to every channel of a feature map.

    Only windows that fit entirely inside the map are pooled (no padding), so
    each spatial output dimension is ``(dim - size) // stride + 1``, or 0 when
    the window is larger than the map.

    Args:
        feature_map: array-like of shape (channels, height, width).
        size: side length of the square pooling window; must be >= 1.
        stride: step between the top-left corners of consecutive windows;
            must be >= 1.
        dtype: dtype of the returned array. Defaults to ``np.uint8`` for
            backward compatibility; values outside the uint8 range are cast
            (and wrap around). Pass ``None`` to keep ``feature_map``'s dtype.

    Returns:
        ndarray of shape (channels, out_height, out_width) holding the maximum
        of each pooling window.

    Raises:
        ValueError: if ``size`` or ``stride`` is smaller than 1.
    """
    if size < 1 or stride < 1:
        raise ValueError(
            f"size and stride must be >= 1, got size={size}, stride={stride}"
        )
    feature_map = np.asarray(feature_map)
    channel = feature_map.shape[0]
    height = feature_map.shape[1]
    width = feature_map.shape[2]

    # Number of complete windows along each axis. Clamp at 0 so a window
    # larger than the map yields an empty result instead of a negative size.
    out_height = max(0, (height - size) // stride + 1)
    out_width = max(0, (width - size) // stride + 1)

    out_dtype = feature_map.dtype if dtype is None else dtype
    out_pooling = np.zeros((channel, out_height, out_width), dtype=out_dtype)

    # Reduce over every axis except the channel axis, so each channel gets its
    # own maximum per window.
    reduce_axes = tuple(range(1, feature_map.ndim))
    for i in range(out_height):
        top = i * stride
        for j in range(out_width):
            left = j * stride
            window = feature_map[:, top:top + size, left:left + size]
            # Iterating only over valid output indices means every window is
            # complete and non-empty, so no exception-driven loop exit is needed.
            out_pooling[:, i, j] = window.max(axis=reduce_axes)
    return out_pooling
if __name__ == "__main__":
    # Demo: three 3x3 channels pooled with a 2x2 window and stride 1.
    # "池化前" = "before pooling", "最大池化后" = "after max pooling".
    # Named `data` rather than `input` to avoid shadowing the builtin.
    data = np.arange(1, 28).reshape((3, 3, 3))
    print("池化前:", data)
    output = max_pooling(data, 2, 1)
    print("最大池化后:", output)
    print(output.dtype)
    print("out_shape", output.shape)
Output (Python 3):
池化前: [[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]

 [[10 11 12]
  [13 14 15]
  [16 17 18]]

 [[19 20 21]
  [22 23 24]
  [25 26 27]]]
最大池化后: [[[ 5  6]
  [ 8  9]]

 [[14 15]
  [17 18]]

 [[23 24]
  [26 27]]]
uint8
out_shape (3, 2, 2)