acwing算法基础课


前言

一、第一章 基础算法

1. 排序

快速排序

  1. 确定分界点:q[1], q[(l + r) / 2], q[random]
  2. 调整区间: 使得小于等于x在左边,大于x在右边
  3. 递归处理左右两段
def quick_sort(q, l, r):
    if r <= l: return

    x, i, j = q[l + r >> 1], l, r

    while i < j:
        while(q[i] < x): i += 1
        while(q[j] > x): j -= 1
        if i < j: q[i], q[j] = q[j], q[i]
    
    quick_sort(q, l, j)#写j时不能基准值x不能取到q[r],同理写i - 1时基准值x不能取到q[l]
    quick_sort(q, j + 1, r)

def main():
    n = int(input())
    q = []
    for i in range(n):
        q.append(int(input()))
    quick_sort(q, 0, n - 1)
    for i in range(n):
        print(q[i])
        
if __name__ == "__main__":
    main()

归并排序

  1. 确定分界点
  2. 分好的两部分继续向下递归
  3. 归并
def merge_sort(q, l, r):
    if r <= l: return

    mid = l + r >> 1
    
    merge_sort(q, l, mid)
    merge_sort(q, mid + 1, r)

    i, j = l, mid + 1
    temp = []

    while i <= mid and j <= r:
        if q[i] <= q[j]: 
            temp.append(q[i])
            i += 1
        else:
            temp.append(q[j])
            j += 1
    while i <= mid:
        temp.append(q[i])
        i += 1
    while j <= r:
        temp.append(q[j])
        j += 1
    for i in range(len(temp)):
        q[l + i] = temp[i]

def main():
    n = int(input())
    q = []
    for i in range(n):
        q.append(int(input()))
    print(q)
    merge_sort(q, 0, n - 1)
    print(q)
        
if __name__ == "__main__":
    main()

2. 二分

使用情景:左边满足某一种性质,右边满足另一种性质

用哪一种模板只需要看是需要r = mid还是l = mid,确定后再看mid = l + r >> 1是否需要补上加1

模板一

[l, r]分成[l, mid]和[mid + 1, r]

def binary_search1():
	while l < r:
		mid = l + r >> 1
		if check(mid): r = mid
		else: l = mid + 1

模板二

[l, r]分成[l, mid - 1]和[mid, r]

def binary_search2():
	while l < r:
		mid = l + r + 1 >> 1
		if check(mid): l = mid
		else: r = mid - 1

总结

提示:这里对文章进行总结:
例如:以上就是今天要讲的内容,本文仅仅简单介绍了pandas的使用,而pandas提供了大量能使我们快速便捷地处理数据的函数和方法。

3. 前缀和

原数组: a1,a2,…,ai
前缀和:Si = a1 + a2 + … + ai

如何求

S[0] = 0
for i in range(n):
	S[i] = S[i - 1] + a[i]
def main():
    a, S = [0], [0]
    n = int(input())
    for i in range(1, n + 1):
        a.append(int(input()))
    for i in range(1, n + 1):
        S.append(S[i - 1] + a[i])
    print(S)
        
if __name__ == "__main__":
    main()

作用

快速求出某一段数的和
S[0]是为了同一边界,如求前十个数的和S[10] - S[0]

二维前缀和

for i in range(n):
	for j in range(m):
		S[i][j] = S[i - 1][j] + S[i][j - 1] - S[i - 1][j - 1] + a[i][j]
def main():
    n, m, q = map(int, input().split())

    a = [[0] * m for i in range(n)]
    S = [[0] * (m + 1) for i in range(n + 1)]

    for i in range(n):
        for j in range(m):
            a[i][j] = int(input())
    
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            S[i][j] = S[i - 1][j] + S[i][j - 1] - S[i - 1][j - 1] + a[i - 1][j - 1]
    
    print(S)
    
if __name__ == "__main__":
    main()

4. 差分

原数组:a1,a2,…,an
构造:b1,b2,…,bn
使得ai = b1 + b2 + … + bi(使得b的前缀和是a)

B->A需要O(n)

作用

当需要将A中[l, r]的所有数都加c
只需要b[l] + c,b[r + 1] - c
这样只需要O(1)的时间就可以完成对某一区间的所有数加1的目的

def insert(b, l, r, c):
    b[l] += c
    b[r + 1] -= c
    return b

def main():
    n = int(input())
    a, b = [0] * (n + 2), [0] * (n + 2)

    '''初始化'''
    for i in range(1, n + 1):
        a[i] = int(input())

    for i in range(1, n + 1):
        insert(b, i, i, a[i])#相当于a的原始全0状态下在区间[i, i]插入a[i]时对应构造的b
    
    '''某一区间加1'''
    insert(b, 1, 2, 1)

    '''由b得到a,输出a'''
    for i in range(1, n + 1):
        b[i] += b[i - 1]
        print(b[i])

if __name__ == "__main__":
    main()

二维差分

b[x1, y1] += c
b[x2 + 1, y1] -= c
b[x1, y2 + 1] -= c
b[x2 + 1, y2 + 1] += c

def insert(b, x1, y1, x2, y2, c):
    b[x1][y1] += c
    b[x2 + 1][y1] -= c
    b[x1][y2 + 1] -= c
    b[x2 + 1][y2 + 1] += c
    return b

def main():
    n, m = map(int, input().split())

    a = [[0] * (m + 2) for i in range(n + 2)]
    b = [[0] * (m + 2) for i in range(n + 2)]

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            a[i][j] = int(input())

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            insert(b, i, j, i, j, a[i][j])

    insert(b, 1, 1, 2, 2, 1)

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            b[i][j] += b[i - 1][j] + b[i][j - 1] - b[i - 1][j - 1]
            
    print(b)

if __name__ == "__main__":
    main()

5. 双指针

两种情况

def main():
    for i in range(n):
        while j < i and check(i, j):
            j += 1
        #具体问题

if __name__ == "__main__":
    main()

常见问题分类:
    (1) 对于一个序列,用两个指针维护一段区间
    (2) 对于两个序列,维护某种次序,比如归并排序中合并两个有序序列的操作

如果使用python需要再循环内改变外循环的循环变量,则需要使用两个while

def main():
    str = input()
    n = len(str)
    i = -1
    while i < n:
        i += 1
        j = i
        while j < n and str[j] != ' ':
            j += 1
        for k in range(i, j):
            print(str[k], end = '')
        print()
        i = j


if __name__ == "__main__":
    main()

核心思想

所有的双指针算法都是O(n),从而优化了O(n^2)的算法

例子

#朴素做法,O(n^2)
def main():
    for i in range(n):
        for j in range(i):
            if check(i, j):
                res = max(res, j - i + 1)
#双指针算法,O(n)
#j的含义是最远能走到哪里,所以实现剪枝
def main():
    for i in range(n):
        while j <= i and check(i, j): j += 1
        res = max(res, j - i + 1)

#关于check,可以用S[N]的哈希表来存储每一个数在窗口中的出现次数(判断数组中是否存在重复元素可以考虑哈希表)
#i++时S[a[i]]++,若S[a[i]] > 1,则说明有重复元素,需要j++,同时再判断S[a[i]]是否大于1
def main():
    for i in range(n):
        S[a[i]] += 1
        while j <= i:
            S[a[j]] += 1
            if S[a[j]] > 1:
                S[a[j]] -= 1
                j += 1
        res = max(res, i - j + 1)
        
if __name__ == "__main__":
    main()

思路

  1. 先考虑暴力做法
  2. 观察i和j之间有什么单调关系
  3. 用单调关系套模板将时间复杂度从O(n^2)降到O(n)

位运算

  1. 求n二进制的第k位数 n >> k & 1
  2. 求最低位为1的位置lowbit(n) = n & -n,因为存储方式为补码所以-n为n的反码加1。这样在最低为1的位置前面后面做与运算都为0,仅最低位为1的位置为1

二、第二章 数据结构

单链表(从头部插入)

N = 100010
e, ne, idx = [0] * N, [0] * N, 0
head = -1


def init():
    global idx, head
    head = -1
    idx = 0

def add(k, x):
    #在k之后插入一个节点,共有四个指针操作
    global idx, head
    e[idx] = x
    ne[idx] = head
    head = idx
    idx += 1

def remove(k):
    head = ne[head]

def main():
    init()
    print(idx)

if __name__ == "__main__":
    main()

双链表

N = 100010
e, l, r, idx = [0] * N, [0] * N, [0] * N, 2


def init():
    global idx
    r[0], l[1], idx = 1, 0, 2

def add(k, x):
    #在k之后插入一个节点,共有四个指针操作
    global idx
    e[idx] = x
    l[idx] = k
    r[idx] = r[k]
    l[r[k]] = idx
    r[k] = idx
    idx += 1

def remove(k):
    r[l[k]] = r[k]
    l[r[k]] = l[k]

def main():
    init()
    print(idx)

if __name__ == "__main__":
    main()

N = 100010
stk = [0] * N#栈
tt = -1#栈顶指针

#插入
def insert(x):
    global tt
    tt += 1
    stk[tt] = x

#弹出
def delete():
    global tt
    tt -= 1

#栈是否为空
def is_empty():
    global tt
    return tt <= 0

#返回栈顶元素
def get_top():
    global stk, tt
    return stk[tt]

def main():
    pass

if __name__ == "__main__":
    main()

队列

N = 100010
q = [0] * N#队列
hh, tt = 0, -1#hh为队头,tt为队尾

#向队尾插入一个数
def insert(x):
    global tt
    tt += 1
    q[tt] = x

def delete():
    global hh
    hh += 1

def get_top():
    global hh
    return q[hh]

def is_empty():
    global hh, tt;
    return hh > tt

def main():
    pass

if __name__ == "__main__":
    main()

单调栈

定义

从栈底到栈顶递增,每次入栈时不断剔除比当前入栈元素小的数,最后将当前元素入栈
从栈底到栈顶递减,同理

题型

找出某个数左边离其最近的最大/最小的数

例子

3 4 2 7 5
每个数离其最近的数:-1 3 -1 2 2
暴力做法:

s = [-1] * n
for i in range(n):
	for j in range(i - 1, 0, -1):
		if a[i] > a[j]:
			s[i] = a[j]
			break

存在关系

如果后面的数小于前面的数,则前面的数不用被考虑。
例如当要寻找a6左边最近的比其小的数,如果a5<a4,则a4必不可能为左边最近的那个比其小的数,所以a4不用考虑。

优化

用一个栈存储左边的所有元素

  1. 如果当前元素比栈顶元素小,则踢出栈顶元素
  2. 直到栈空或者栈顶元素比当前元素大,将该元素入栈
    这样栈中元素为优化后需要和当前元素比较的数

如当寻找7这个位置符合要求的数时,栈中元素仅有2。因为3 4大于2,在处理2后都执行了出栈操作。此时7仅需与2比较

stk = [0] * 5
tt = -1
a = [3, 4, 2, 7, 5]

def main():
    global tt
    n = len(a)
    for i in range(n):
        if tt >= 0:
            for j in range(tt, -1, -1):
                if stk[j] < a[i]:
                    print(str(stk[j]) + ' ')
                    break
                if j == 0: print('-1 ')
            #栈不为空且栈顶元素大于当前准备入栈元素
            while tt + 1 and stk[tt] >= a[i]: tt -= 1
            tt += 1
            stk[tt] = a[i] 
            
        else:
            print('-1 ')
            tt += 1
            stk[tt] = a[i]

if __name__ == "__main__":
    main()

代码优化:实际上第二个循环寻找离当前元素最近且比自己小的数与构建单调栈的过程相同,可以写在一起

stk = [0] * 5
tt = -1
a = [3, 4, 2, 7, 5]

def main():
    global tt
    n = len(a)
    for i in range(n):
        while tt + 1 and a[i] < stk[tt]: tt -= 1
        if tt + 1:
            print(str(stk[tt]) + ' ')
        else:
            print('-1 ')
        tt += 1
        stk[tt] = a[i]

if __name__ == "__main__":
    main()

单调队列

题型

找出滑动窗口中的最大/最小值

例子

1 3 -1 -3 5 3 6 7
需要输出窗口中的最小值
当窗口中包含3 -1 -3

存在关系

当窗口向后移动时,只要-3存在就一定不会输出3和-1,也就是不用考虑3和-1

优化

维护一个单调队列,当元素入队时,从队头开始遍历,依次剔除比入队元素大的数
如遍历到-1时,队中原本为1 3,变为-1

q = [0] * 5#用来存储单调队列的下标
hh, tt = 0, -1
a = [3, 4, 2, 7, 5]
k = 3

def main():
    global tt, hh, k
    n = len(a)
    for i in range(n):
        #如果单调队列不为空且单调队列中的元素个数大于滑动窗口大小
        #i + 1 - k为滑动窗口最左元素的下标
        while hh <= tt and i + 1 - k > q[hh]: hh += 1 
        #因为要找到滑动窗口的最小值所以肯定是从单调队列较大的那一边开始剔除
        #所以tt -= 1,直到遇到比其小的元素
        while hh <= tt and a[q[tt]] >= a[i]: tt -= 1#这是一个双向队列,一般队列不会从队尾退出元素
        tt += 1
        q[tt] = i
        if i + 1 >= k: print(str(a[q[hh]]) + ' ')
if __name__ == "__main__":
    main()

KMP(还不熟)

字符串匹配

暴力做法

s, p = [0] *n, [0] * m
for i in range(n):
    flag = True
    for j in range(m):
        if s[i] != p[j]:
            flag = False
            break

在这里插入图片描述
如图,红色模板字符串与原字符串匹配,匹配到AB段。要对暴力做法进行优化,如果AB段中的首段和尾端存在相同的子串(蓝色),那么就可以不必再匹配该部分。

用next[i]来表示当原字符串遍历到i时,首段和尾端存在相同的子串(蓝色)的最大长度,即p[1…j] = p[i + 1 - j…1]

思考

KMP就是用next保存匹配过程中产生的信息,再在已有信息的基础上进行匹配,从而避免了一些重复步骤

如图,暴力做法在匹配到AB匹配不下去后,需要j++继续进行匹配。而KMP仅需从j = ne[j]进行匹配,从而减少了匹配次数。

p, s, ne = [0] * 10010, [0] * 10010, [0] * 10010#两个字符串和next数组
n, m = 4, 2#两个字符串的长度
def main():

    '''求next数组,这一步仅操作原数组'''
    j = 0
    for i in range(2, n + 1):
        while j and p[i] != p[j + 1]: j = ne[j]#用已求得的next数组依次匹配,会出现两种情况:1.j = ne[1] = 0; 2.匹配成功
        if p[i] == p[j + 1]: j += 1#这一步是确定匹配成功,而不是j变为0
        ne[i] = j

    ''''匹配'''
    j = 0
    for i in range(1, n + 1):
        #i, j分别表示原字符串和匹配字符串当前扫描到的位置
        while j and s[i] != p[j + 1]: j = ne[j]#当前匹配到的字符串不为空且下一个字符不匹配,则尝试下一个更小子串
        if s[i] == p[j + 1]: j += 1#如果匹配,则继续匹配
        if j == n: print('匹配成功')#匹配到匹配字符串的最后一个字符,则匹配成功
    
if __name__ == "__main__":
    main()

并查集

用法

近乎O(1)完成

  1. 将两个集合合并
  2. 询问两个元素是否在一个集合中

基本原理

每一个集合都用一棵树来表示,树根的编号就是整个集合的编号。每个节点存储其父节点的编号,用p[x]表示x的父节点

问题

  1. 如何判断树根:
    人为规定根节点的在数组p中等于自身
if p[x] == x:
	print(x)
  1. 如何求x的集合编号:
while p[x] != x:
	x = p[x]
  1. 如何合并两个集合
    假设p[x]为x的集合编号,p[y]为y的集合编号,将集合x合并到y
p[x] = y

优化

路径压缩,每当找到一个节点的根节点时,将该节点直接指向根节点

实现

p = [0] * 10010

def find(x):#加入路径压缩
    global p
    if p[x] != x:
        p[x] = find(p[x])
    return p[x]

def main():
    global p
    n = int(input('输入点的个数:'))
    for i in range(n):
        p[i] = i
    x, y, z = map(int,input().split())
    if x == 1:#插入操作,z所在的集合插入y所在的集合
        p[find(z)] = find(y)
    else:
        if find(y) == find(z): print('y和z在同一个集合中')
        else: print('NO')

if __name__ == "__main__":
    main()

性质

  1. 堆是一棵完全二叉树
  2. 根小于等于两个儿子(小根堆)

基本操作

用一维数组来存储,假设根是x,则左儿子为2x,右儿子为2x + 1

  1. down(x)
    向下调整
在这里插入代码片
  1. up(x)
    向上调整

手写一个堆

  1. 插入一个数
size += 1
heap[size] = x
up(size)
  1. 求集合中的一个最小值
heap[1]
  1. 删除最小值
heap[1] = heap[size]
size -= 1
down(1)
  1. 删除任意一个元素
heap[k] = heap[size]
size -= 1
up(k)
down(k)
#因为不好判断最后一个元素移到k后需要向上还是向下调整
  1. 修改任意一个元素
heap[k] = x
down(k)
up(k)

实现

h = [0] * 10010

def down(u, size):
    t = u
    if u * 2 <= size and h[u * 2] < h[t]: t = u * 2
    if u * 2 + 1 <= size and h[u * 2 + 1] < h[t]: t = u * 2 + 1
    if u != t:#找到根节点和两个儿子节点中最小的那个的编号,如果该编号不是根节点,则交换该节点和根节点
        h[t], h[u] = h[u], h[t]

def up(u):
    #up操作时,只需要和其根节点比较,所以比down操作更为简单
    while u / 2 and h[u / 2] > h[u]:
        h[u / 2], h[u] = h[u], h[u / 2]
        u /= 2

def main():
    global h
    n = int(input('输入点的个数:'))
    for i in range(n):
        h[i] = i
    size = n

    for i in range(n / 2, 0, -1):
    	#第1个节点到第n/2个节点就是除了最底层的所有节点
        down(i, size)#构建堆,复杂度为1

if __name__ == "__main__":
    main()

哈希表

用法

将很大的一个值域映射到0~N(较小的范围)

例子

将-109 ~ 109映射到0 ~ 105

  1. x mod 105
  2. 冲突:可能两个数映射到同一个位置

存储方式(处理冲突)

  1. 开放寻址法

  2. 拉链法
    有点像邻接表,在重复的位置上用一条链表存储

N = 100003
h = [-1] * N
e, ne = [0] * N, [0] * N
idx = 0

def insert(x):
    global idx, e, ne, h
    k = (x % N + N) % N#c++中处理负数的余数还是负数,但实际上需要其为正数
    e[idx] = x
    ne[idx] = h[k]
    h[k] = idx#相当于把最新的值放到哈希表中,其余的移至链表 
    idx += 1

def find(x):
    global idx, e, ne, h
    k = (x % N + N) % N
    i = h[k]
    while e[i] != x and i != -1: i = ne[i]
    if i == -1: return False
    else: return True

def find1(x):
    #如果x在哈希表中,返回x的下标;如果x不在哈希表中,返回x应该插入的位置
    global idx, e, ne, h
    k = (x % N + N) % N
    while h[k] < 0 and h[k] != x:
        k += 1
        if k == N: k = 0
    return k

def main():
    n = int(input())

    while n:
        n -= 1;
        op, x = map(int, input().split())
        if op == 0:
            insert(x)
        else:
            if find(x):
                print('Yes')
            else:
                print('No')


if __name__ == "__main__":
    main()

第三章 搜索与图论

DFS与DFS

  1. 数据结构:stack与queue
  2. 所需空间:O(h)与O(2h)
  3. 特性:“不具有最短性”与“最短路,先搜到的点到根节点的距离更近”
DFS
回溯
N = 10
path = [None] * N#用于枚举的数组
st = [False] * N#记录枚举数组对应位置是否被填过
n = None#枚举的长度

def dfs(u):
    global n
    if u == n:
        for i in range(n):
            print(path[i], end='')
        print("")
        return
    
    for i in range(1, n + 1):
        if not st[i]:
            path[u] = i
            st[i] = True
            dfs(u + 1)
            #path[u] = 0
            st[i] = False
    
def main():
    global n
    n = int(input())

    dfs(0)

if __name__ == "__main__":
    main()

思路
写dfs先写出口,当u == n时输出结果
再写递归
1. 用for遍历枚举数组
2. 用st判断枚举数组当前位置是否被枚举
3. 如果没有,枚举该位置,将st置为True
4. 向下递归,dfs(u + 1)
5. 回溯,还原现场
1. 修改节点值
2. 修改st为False

剪枝

n皇后问题
该问题n为2,3时无解

N = 20
col, dg, udg = [False] * N, [False] * N, [False] * N#行,正反对角线,因为是从上向下遍历,所以不用列对角线
n = None#枚举的长度
g = [['*'] * N for i in range(N)]


def dfs(u):
    global n
    if u == n:
        for i in range(n):
            print(g[i], end='')
            print("")
        print("")
        return
    
    for i in range(n):
        if (not col[i]) and (not dg[i - u + n]) and (not udg[i + u]):
            #以正对角线为例,y = x - b -> b = x - y,y = x + b -> b = y - x
            #这里x是i,y是u,b可以唯一标识一条对角线,加n是为了避免出现负数
            g[u][i] = 'Q'
            col[i] = dg[i - u + n] = udg[i + u] = True
            dfs(u + 1)
            col[i] = dg[i - u + n] = udg[i + u] = False
            g[u][i] = '*'
    
def main():
    global n
    n = int(input())

    dfs(0)

if __name__ == "__main__":
    main()

BFS

BFS没有回溯,所以不需要额外的数组来记录是否遍历过
在这里插入图片描述

N = 110
g, d = [[0] * N for i in range(N)], [[-1] * N for i in range(N)]
q = [[-1, -1] for i in range(N * N)]
n, m = None, None

def bfs():
    global n, m
    hh, tt = 0, 0#因为q中已经有一个元素(0,0)所以tt(队尾)为0
    q[0] = [0, 0]

    d[0][0] = 0

    dx, dy = [-1, 0, 1, 0], [0, 1, 0, -1]

    while hh <= tt:
        t = q[hh]
        hh += 1
        #遍历q[hh]的四个方向
        for i in range(4):
            x, y = t[0] + dx[i], t[1] + dy[i]
            if x >=0 and x < n and y >= 0 and y < n and g[x][y] == 0 and d[x][y] == -1:
                d[x][y] = d[t[0]][t[1]] + 1
                #pre[x][y] = t
                tt += 1
                q[tt] = [x, y]
    return d[n - 1][m - 1]
    
def main():
    global n, m
    n, m = map(int, input().split())
    for i in range(n):
            g[i] = list(map(int, input().split()))
    print(bfs())
if __name__ == "__main__":
    main()
'''
5 5
0 1 0 0 0
0 1 0 1 0
0 0 0 0 0
0 1 1 1 0
0 0 0 1 0
'''

如果要记录路径,创一个二维数组pre[N][N]记录,当前走到节点的上一个位置。
最后由后往前推即可得到完整路径

树和图的遍历

有向图:

  1. 邻接矩阵
  2. 邻接表
    每个节点存储的值表示边的末端端点
DFS
N = 1000010
M = N * 2
h, e, ne = [-1] * N, [-1] * M, [-1] * M
idx = 0
st = [False] * N

def add(a, b):
    global idx
    e[idx] = b
    ne[idx] = h[a]
    h[a] = idx
    idx += 1

def dfs(u):
    st[u] = True

    i = h[u]
    while i != -1:
        j = e[i]
        if not st[j]:
            dfs(j)
        i = ne[i]

def main():
    dfs(1)
    
if __name__ == "__main__":
    main()

在这里插入图片描述

N = 1000010
M = N * 2
h, e, ne = [-1] * N, [-1] * M, [-1] * M
idx = 0
st = [False] * N
ans, n, m = float('inf'), None, None

def add(a, b):
    global idx
    e[idx] = b
    ne[idx] = h[a]
    h[a] = idx
    idx += 1

def dfs(u):
    global ans, n
    st[u] = True
    sum, res = 1, 0#每次遍历一个新的块
    i = h[u]
    while i != -1:
        j = e[i]
        if not st[j]:
            s = dfs(j)
            res = max(res, s)#先与该点的子树比较
            sum += s
        i = ne[i]
    res = max(res, n - sum)#再与父节点往上连接的部分比较

    ans = min(ans, res)

    return sum

def main():
    global n
    n = int(input())

    for i in range(n - 1):
        a, b = map(int, input().split())
        add(a, b)
        add(b, a)

    dfs(1)
    
    print(ans)

if __name__ == "__main__":
    main()

BFS

在这里插入图片描述
因为所有点的权重都为1,所以可以通过宽搜来求最短距离
BFS框架

  1. 将第一个元素插入queue
  2. while queue不空
    1. t = 队头元素
    2. 拓展x的临点。如果x没有被遍历,
      1. 将x入队
      2. d[x] = d[t] + 1
N = 1000010
h, e, ne = [-1] * N, [-1] * N, [-1] * N
idx = 0
n, m = None, None
d, q = [-1] * N, [-1] * N

def add(a, b):
    global idx
    e[idx] = b
    ne[idx] = h[a]
    h[a] = idx
    idx += 1

def bfs():
    hh, tt = 0, 0
    q[0] = 1
    d[1] = 0#这两个初始化别忘了
    while hh <= tt:
        t = q[hh]
        hh += 1
        i = h[t]
        while i != -1:
            j = e[i]
            if d[j] == -1:
                d[j] = d[t] + 1
                tt += 1
                q[tt] = j              
            i = ne[i]
            
    return d[n]

def main():
    global n, m
    n, m = map(int, input().split())

    for i in range(m):
        a, b = map(int, input().split())
        add(a, b)
    print(bfs())

if __name__ == "__main__":
    main()

最短路问题

在这里插入图片描述
朴素Dijkstra算法适用于稠密图,堆优化版适合稀疏图

朴素Dijkstra

基于贪心

步骤
  1. 初始化dist,
    dist[1] = 0(第一个点到第一个点的距离为0);
    dist[2…n] = float(‘inf’)(到其他点的距离为正无穷)
  2. S为当前已确定最短距离的点
    for i in range(1, n + 1):
    t <- 不在s中的距离最近的点
    s <- t
    用t更新其他点的距离
    (每循环一次可以确定一个点)
n, m = map(int, input().split())
st = [False] * n
dist = [float('inf')] * n 
g = [[float('inf')] * n for i in range(n)]
for i in range(m):
    x, y, z = map(int, input().split())   
    g[x - 1][y - 1] = min(g[x - 1][y - 1], z)

def dijkstra():
    dist[0] = 0
    for i in range(n - 1): 
        t = -1
        for j in range(n):
            if not st[j] and (t == -1 or dist[t] > dist[j]):
                t = j#找到离当前确定集合中点最近的
        for j in range(n):
            dist[j] = min(dist[j], dist[t] + g[t][j])#更新每个点到源点的最小距离,min(原来到j的距离, 先到t再到j的距离)
        st[t] = True
    if dist[n - 1] == float('inf'):
        return -1
    print(g)
    return dist[n - 1]
print(dijkstra())

'''
5 10
1 2 2
5 3 3
4 1 8
2 4 3
4 5 7
5 2 3
3 4 1
1 2 9
3 2 3
1 2 8
'''
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值