目录
1. 引言
- 背景介绍:
- SAT(布尔可满足性问题)是计算机科学中的经典问题,广泛应用于人工智能、硬件验证、软件测试等领域。
- DPLL(Davis-Putnam-Logemann-Loveland)算法是解决 SAT 问题的经典方法,基于回溯搜索和启发式规则。
- 目标:
- 本文将简要介绍 DPLL 算法的核心思想,并用 Python 实现一个完整的求解器
2. DPLL 算法简介
- 什么是 DPLL 算法:
- DPLL 是一种递归算法,结合了简化子句、单元传播和分支赋值等策略。
- 核心思想:
- 简化子句:根据当前赋值移除已满足的子句或无效的文字。
- 单元传播:优先处理只有一个文字的子句,快速缩小搜索空间。
- 分支赋值:对未赋值变量尝试
True
和False
,并通过回溯探索所有可能解。
- 算法流程:
- 初始化空赋值。
- 循环执行简化子句和单元传播。
- 如果存在冲突,回溯并尝试其他赋值;否则返回解。
3. Python 实现 DPLL 算法
3.1 DPLL_SATSolver
class DPLL_SATSolver:
def __init__(self,clauses,n_vars):
self.clauses = clauses #[[1, -2, 3],...]
self.n_vars = n_vars
self.assignments = [0] * (n_vars + 1) #索引从1开始 self.assignments[2] = 1
self.decision_stack = [] # 决策栈,用于回溯
def solve(self):
"""
使用 DPLL 算法求解 SAT 问题。
:return: 如果可满足,返回 True;否则返回 False。
"""
# 如果所有变量都已赋值且没有冲突,则返回 True
if self.all_variables_assigned():
return True
# 选择一个未赋值变量赋值
var = self.select_unassigned_variable()
for value in [1, -1]: # 尝试赋值 True (1) 和 False (-1)
self.assignments[var] = value
self.decision_stack.append(var) # 记录决策变量
if self.unit_propagation(): # 单位传播
if self.solve(): # 递归求解
return True
# 回溯:撤销最近的赋值
self.backtrack()
return False
def all_variables_assigned(self):
"""
检查是否所有变量都已被赋值。
:return: 如果所有变量都已赋值,返回 True;否则返回 False。
"""
return all(self.assignments[var] != 0 for var in range(1, self.n_vars + 1))
def select_unassigned_variable(self):
"""
选择一个未赋值的变量(简单实现:按顺序选择)。
:return: 未赋值的变量编号。
"""
for var in range(1, self.n_vars + 1):
if self.assignments[var] == 0:
return var
def unit_propagation(self):
"""
执行单位传播。
:return: 如果没有冲突,返回 True;否则返回 False。
"""
while True:
changed = False #changed = True说明需要继续进行单位传播
for clause in self.clauses:
# 检查子句是否已经满足
if any(self.assignments[abs(lit)] == (1 if lit > 0 else -1) for lit in clause):
continue
# 检查子句是否被赋值为 False
unassigned_literals = [lit for lit in clause if self.assignments[abs(lit)] == 0] #子句中所有未赋值的变量
if not unassigned_literals: #子句中没有未赋值的变量,但是子句不可满足,则
return False # 子句无法满足,发生冲突
# 如果子句只有一个未赋值的变量,则进行单位传播
if len(unassigned_literals) == 1:
lit = unassigned_literals[0]
self.assignments[abs(lit)] = 1 if lit > 0 else -1
changed = True
if not changed:
break
return True
def backtrack(self):
"""
回溯:撤销最近的赋值。
"""
if not self.decision_stack:
return
var = self.decision_stack.pop()
# self.assignments[var] = 0 # 撤销赋值
# 将 var 及其之后的赋值重置为 0
self.assignments[var:] = [0] * (len(self.assignments) - var)
def get_model(self):
"""
获取当前的模型(变量赋值)。
:return: 当前的变量赋值字典。
bool(a) a!=0 true a==0 false
"""
return {var: True if val == 1 else False for var, val in enumerate(self.assignments[1:], start=1) if val != 0}
3.2 CNFParser
import os
import urllib.request
import tarfile
import shutil
class CNFParser:
@staticmethod
def download_and_extract_cnf_archive(url, dataset_name):
"""
下载并解压 CNF 文件压缩包到指定的数据集目录。
:param url: CNF 文件压缩包的 URL 地址
:param dataset_name: 数据集名称(将作为子目录名)
:return: 解压后的文件路径
"""
# 确保 dataset 目录存在
base_path = "./dataset"
if not os.path.exists(base_path):
os.makedirs(base_path)
print(f"Created dataset directory: {base_path}")
# 创建特定数据集的子目录
save_path = os.path.join(base_path, dataset_name)
if not os.path.exists(save_path):
os.makedirs(save_path)
print(f"Created subdirectory for dataset: {save_path}")
print(f"Downloading and extracting CNF archive from {url}...")
try:
# 下载压缩包
compressed_file_path = os.path.join(save_path, f"{dataset_name}.tar.gz")
urllib.request.urlretrieve(url, compressed_file_path)
# 解压压缩包
with tarfile.open(compressed_file_path, 'r:gz') as tar:
tar.extractall(path=save_path, filter='data')
print(f"Files extracted to {save_path}")
# 删除临时压缩包
os.remove(compressed_file_path)
print(f"Removed temporary file: {compressed_file_path}")
return save_path
except Exception as e:
print(f"Failed to download or extract file: {e}")
return None
@staticmethod
def parse_cnf_file(file_path):
"""
解析 CNF 文件并提取变量数量、子句数量以及子句列表。
:param file_path: CNF 文件路径
:return: (n_vars, clauses)
n_vars: 变量数量
clauses: 子句列表,例如 [[1, -2, 3], [-1, 4], ...]
"""
clauses = []
n_vars = 0
n_clauses = 0
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, 'r') as f:
for line in f:
line = line.strip()
if not line or line.startswith('c') or line.startswith('%'): # 忽略空行和注释行(包括以 % 开头的)
continue
if line.startswith('p cnf'):
# 提取变量数量和子句数量
parts = line.split()
if len(parts) < 4:
raise ValueError("Invalid 'p cnf' line format")
_, _, n_vars, n_clauses = parts
n_vars = int(n_vars)
n_clauses = int(n_clauses)
else:
# 跳过非数字行(如 % 或其他特殊字符)
if line == '0':
continue # 忽略孤立的 0 行
try:
# 提取子句
clause = list(map(int, line.split()))[:-1] # 去掉末尾的 0
clauses.append(clause)
except ValueError:
print(f"Skipping invalid line: {line}")
continue
assert len(clauses) == n_clauses, "子句数量不匹配!"
return n_vars, clauses
3.3 verify_solution
def verify_solution(clauses, model):
"""
验证给定的模型是否满足所有子句。
:param clauses: CNF 子句列表,例如 [[1, -2, 3], [-1, 4], ...]
:param model: 变量赋值字典,例如 {1: True, 2: False, 3: True, ...}
:return: 如果模型满足所有子句,返回 True;否则返回 False。
"""
for clause in clauses:
# 检查当前子句是否至少有一个文字为 True
if not any(model[abs(lit)] == (lit > 0) for lit in clause):
print(f"Clause {clause} is not satisfied.")
return False
print("All clauses are satisfied.")
return True
3.4 test.py
from dpll import DPLL_SATSolver,verify_solution
from cnf_parser import CNFParser
import os
import shutil
if __name__ == "__main__":
k = 100 # 求解的 SAT 实例数量
correct_k = 0
# 下载、解压和解析
test_url = "https://www.cs.ubc.ca/~hoos/SATLIB/Benchmarks/SAT/RND3SAT/uuf50-218.tar.gz" # 示例 CNF 文件压缩包 URL
dataset_name = "uuf50-218" # 数据集名称
try:
# 下载并解压 CNF 文件压缩包
save_path = CNFParser.download_and_extract_cnf_archive(test_url, dataset_name)
if save_path is None:
raise RuntimeError("Failed to download or extract the dataset.")
# 查找解压目录下的所有 CNF 文件
cnf_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(save_path) for f in filenames if
f.endswith('.cnf')]
if not cnf_files:
print("No CNF files found in the directory.")
else:
# 对前 k 个 CNF 文件进行求解
for i, cnf_file in enumerate(cnf_files[:k], start=1):
filename = os.path.basename(cnf_file)
print(f"\nProcessing CNF file {i}: {cnf_file},filename:{filename}")
n_vars, clauses = CNFParser.parse_cnf_file(cnf_file)
print(f"Parsed CNF file: {n_vars} variables, {len(clauses)} clauses")
# 创建 SATSolver 对象
solver = DPLL_SATSolver(clauses, n_vars)
if solver.solve():
if filename.startswith("uf"):
correct_k += 1
print("SAT")
model = solver.get_model()
print("Model:", model)
# 验证赋值正确性
if verify_solution(clauses, model):
print("The solution is correct!")
else:
print("The solution is incorrect!")
else:
if filename.startswith("uuf"):
correct_k += 1
print("UNSAT")
accuracy = correct_k / k
print("Accuracy:", accuracy)
print("Accuracy Detail:", correct_k ,"/" , k)
except Exception as e:
print(f"Test failed: {e}")
finally:
# 清理测试文件(可选)
base_path = "./dataset"
if os.path.exists(base_path):
shutil.rmtree(base_path)
print(f"\nRemoved dataset directory: {base_path}")