堆
堆结构实际上是一个完全二叉树,不过它又在完全二叉树的基础上做了升级。
小顶堆:其每个节点的父节点都小于当前节点,那么根节点就是其最小的节点。
大顶堆:其正好与小顶堆相反,每个节点的父节点都大于当前节点,所以根节点就是最大的节点。
结构
在Java中,没有实际意义上的堆数据结构。不过,通常都使用数组来存储。接下来边简单概述为什么要使用数组以及数组存储的好处。
对于完全二叉树结构,它当前所在层数用 n
表示,那么每层可以存储的最大元素的数量就是 2^(n-1)
个元素。
第一层可以存储 1 个,第二层存储 2 个,第三层 4 个,第四层 8 个,不难看出,后边每一层元素个数等于前边所以元素个数的和再 + 1,例如:第三层4个 1+ (1+2) = 4
,第四层8个 1 + (1 + 2 + 4) = 8
。
因此,如果我们把堆数据放到数组中,数组index=0
放置元素空,数组index=1
放置根元素,数组index=2
放置第二层左边的元素(由于完全二叉树存元素是从左边的叶子节点往右存储,并且必须上一层存储满了才能到下一层),这样每一层的元素在索引上是连续的。
并且对于数组中任意一个节点,如果其索引 x
,那么可以计算:
- 父节点索引 =
x / 2
- 其叶子节点分别为 =
x * 2、x * 2 + 1
添加、移除元素
这里以小顶堆为例。
添加
- 在小顶堆中添加元素时,先默认将元素放置在最后的节点
- 然后对最后一个节点进行上浮操作:比较当前节点与其父节点值的大小,如果当前节点小于父节点,则交换两个节点
- 重复上述上浮操作的步骤,直到当前节点大于父节点的值或者当前节点已经是根节点
移除
对于堆结构,移除元素都是直接移除根元素
- 移除根元素后,将最后一个元素放置到根节点的位置
- 然后进行节点下沉操作:先获取当前节点的子节点中最小的节点,然后比较当前节点和子节点中最小的节点的大小,如果当前节点值大,则交换两个节点的值
- 重复上述下沉操作,直到当前节点的值比两个子节点都小或者当前节点已经是叶子节点
代码实现
这里以小顶堆为例。
class MinimumHeap {
private static final int DEFAULT_CAPACITY = 16;
/**
* 数组容量
*/
private int capacity;
/**
* 实际元素个数
*/
private int size = 0;
private final Comparator<? super Node> comparator;
/**
* 存储元素的数组,index = 0 的位置不存储元素
*/
private Node[] queue;
public MinimumHeap() {
this(DEFAULT_CAPACITY, null);
}
public MinimumHeap(int capacity) {
this(capacity, null);
}
public MinimumHeap(Comparator<? super Node> comparator) {
this(DEFAULT_CAPACITY, comparator);
}
public MinimumHeap(int capacity, Comparator<? super Node> comparator) {
this.capacity = capacity;
this.comparator = comparator == null
? (t1, t2) -> Math.toIntExact(t1.getValue() - t2.getValue())
: comparator;
queue = new Node[capacity];
}
public int getSize() {
return size;
}
/**
* 添加元素到队列中
*/
public void offer(Node e) {
if (e == null)
throw new NullPointerException();
final int s = ++size;
if (s >= capacity)
grow();
if (s == 1)
queue[1] = e;
else
siftUp(s, e);
}
/**
* 堆元素上浮
*/
private void siftUp(int index, Node element) {
int p;
while ((p = index >>> 1) > 0) {
if (comparator.compare(queue[p], element) <= 0)
break;
// 如果 父级元素 > 当前元素,那么将父级元素的值放到 index 索引的位置
queue[index] = queue[p];
index = p;
}
queue[index] = element;
}
/**
* 获取并移除队列的第一个元素
* @return 队列的第一个元素,如果队列为空,则返回 null
*/
public Node poll() {
if (size == 0)
return null;
final int s = size;
final Node result = queue[1];
final Node last = queue[s];
queue[s] = null;
size--;
if (size != 0)
siftDown(1, last);
return result;
}
/**
* 元素下沉
*/
private void siftDown(int index, Node element) {
int left;
while ((left = index << 1) <= size) {
int child = left, right = left + 1;
// 子元素取小的进行比较
if (right < size && comparator.compare(queue[left], queue[right]) > 0) {
child = right;
}
if (comparator.compare(queue[child], element) > 0)
break;
queue[index] = queue[child];
index = child;
}
queue[index] = element;
}
/**
* 数组扩容
*/
private void grow() {
final int oldCapacity = capacity;
int newCapacity = oldCapacity > 64
? oldCapacity + oldCapacity >>> 1
: oldCapacity << 1;
capacity = newCapacity;
queue = Arrays.copyOf(queue, newCapacity);
}
@Override
public String toString() {
return Arrays.toString(queue);
}
static class Node {
private int value;
public Node(int value) {
this.value = value;
}
public int getValue() {
return value;
}
@Override
public String toString() {
return "Node{value=" + getValue() + "}";
}
}
public static void main(String[] args) {
MinimumHeap queue = new MinimumHeap();
queue.offer(new Node(80000));
queue.offer(new Node(50000));
queue.offer(new Node(60000));
queue.offer(new Node(40000));
queue.offer(new Node(70000));
queue.offer(new Node(30000));
queue.offer(new Node(100000));
queue.offer(new Node(55000));
queue.offer(new Node(55400));
queue.offer(new Node(55500));
queue.offer(new Node(51000));
queue.offer(new Node(56000));
queue.offer(new Node(75000));
queue.offer(new Node(85000));
queue.offer(new Node(95000));
queue.offer(new Node(57800));
queue.offer(new Node(67800));
queue.offer(new Node(67890));
queue.offer(new Node(78900));
int size = queue.getSize();
System.out.println("size = " + size);
System.out.println(queue);
System.out.println("------------------------------");
queue.poll();
queue.poll();
queue.poll();
queue.poll();
queue.poll();
queue.poll();
queue.poll();
size = queue.getSize();
System.out.println("size = " + size);
System.out.println(queue);
}
}