DPLL算法python实现

目录

1. 引言

2. DPLL 算法简介

3. Python 实现 DPLL 算法

3.1 DPLL_SATSolver

3.2 CNFParser

3.3 verify_solution

3.4 test.py


1. 引言

  • 背景介绍
    • SAT(布尔可满足性问题)是计算机科学中的经典问题,广泛应用于人工智能、硬件验证、软件测试等领域。
    • DPLL(Davis-Putnam-Logemann-Loveland)算法是解决 SAT 问题的经典方法,基于回溯搜索和启发式规则。
  • 目标
    • 本文将简要介绍 DPLL 算法的核心思想,并用 Python 实现一个完整的求解器

2. DPLL 算法简介

  • 什么是 DPLL 算法
    • DPLL 是一种递归算法,结合了简化子句单元传播分支赋值等策略。
  • 核心思想
    • 简化子句:根据当前赋值移除已满足的子句或无效的文字。
    • 单元传播:优先处理只有一个文字的子句,快速缩小搜索空间。
    • 分支赋值:对未赋值变量尝试 True 和 False,并通过回溯探索所有可能解。
  • 算法流程
    • 初始化空赋值。
    • 循环执行简化子句和单元传播。
    • 如果存在冲突,回溯并尝试其他赋值;否则返回解。

3. Python 实现 DPLL 算法

3.1 DPLL_SATSolver

class DPLL_SATSolver:
    def __init__(self,clauses,n_vars):
        self.clauses = clauses #[[1, -2, 3],...]
        self.n_vars = n_vars
        self.assignments = [0] * (n_vars + 1) #索引从1开始   self.assignments[2] = 1
        self.decision_stack = []  # 决策栈,用于回溯

    def solve(self):
        """
        使用 DPLL 算法求解 SAT 问题。
        :return: 如果可满足,返回 True;否则返回 False。
        """
        # 如果所有变量都已赋值且没有冲突,则返回 True
        if self.all_variables_assigned():
            return True
        # 选择一个未赋值变量赋值
        var = self.select_unassigned_variable()
        for value in [1, -1]:  # 尝试赋值 True (1) 和 False (-1)
            self.assignments[var] = value
            self.decision_stack.append(var)  # 记录决策变量

            if self.unit_propagation():  # 单位传播
                if self.solve():  # 递归求解
                    return True

            # 回溯:撤销最近的赋值
            self.backtrack()

        return False

    def all_variables_assigned(self):
        """
        检查是否所有变量都已被赋值。
        :return: 如果所有变量都已赋值,返回 True;否则返回 False。
        """
        return all(self.assignments[var] != 0 for var in range(1, self.n_vars + 1))

    def select_unassigned_variable(self):
        """
        选择一个未赋值的变量(简单实现:按顺序选择)。
        :return: 未赋值的变量编号。
        """
        for var in range(1, self.n_vars + 1):
            if self.assignments[var] == 0:
                return var

    def unit_propagation(self):
        """
        执行单位传播。
        :return: 如果没有冲突,返回 True;否则返回 False。
        """
        while True:
            changed = False #changed = True说明需要继续进行单位传播
            for clause in self.clauses:
                # 检查子句是否已经满足
                if any(self.assignments[abs(lit)] == (1 if lit > 0 else -1) for lit in clause):
                    continue

                # 检查子句是否被赋值为 False
                unassigned_literals = [lit for lit in clause if self.assignments[abs(lit)] == 0] #子句中所有未赋值的变量
                if not unassigned_literals: #子句中没有未赋值的变量,但是子句不可满足,则
                    return False  # 子句无法满足,发生冲突

                # 如果子句只有一个未赋值的变量,则进行单位传播
                if len(unassigned_literals) == 1:
                    lit = unassigned_literals[0]
                    self.assignments[abs(lit)] = 1 if lit > 0 else -1
                    changed = True

            if not changed:
                break

        return True

    def backtrack(self):
        """
        回溯:撤销最近的赋值。
        """
        if not self.decision_stack:
            return
        var = self.decision_stack.pop()
        # self.assignments[var] = 0  # 撤销赋值
        # 将 var 及其之后的赋值重置为 0
        self.assignments[var:] = [0] * (len(self.assignments) - var)

    def get_model(self):
        """
        获取当前的模型(变量赋值)。
        :return: 当前的变量赋值字典。
        bool(a)  a!=0 true a==0 false
        """
        return {var: True if val == 1 else False for var, val in enumerate(self.assignments[1:], start=1) if val != 0}

3.2 CNFParser

import os
import urllib.request
import tarfile
import shutil


class CNFParser:
    @staticmethod
    def download_and_extract_cnf_archive(url, dataset_name):
        """
        下载并解压 CNF 文件压缩包到指定的数据集目录。

        :param url: CNF 文件压缩包的 URL 地址
        :param dataset_name: 数据集名称(将作为子目录名)
        :return: 解压后的文件路径
        """
        # 确保 dataset 目录存在
        base_path = "./dataset"
        if not os.path.exists(base_path):
            os.makedirs(base_path)
            print(f"Created dataset directory: {base_path}")

        # 创建特定数据集的子目录
        save_path = os.path.join(base_path, dataset_name)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
            print(f"Created subdirectory for dataset: {save_path}")

        print(f"Downloading and extracting CNF archive from {url}...")
        try:
            # 下载压缩包
            compressed_file_path = os.path.join(save_path, f"{dataset_name}.tar.gz")
            urllib.request.urlretrieve(url, compressed_file_path)

            # 解压压缩包
            with tarfile.open(compressed_file_path, 'r:gz') as tar:
                tar.extractall(path=save_path, filter='data')
                print(f"Files extracted to {save_path}")

            # 删除临时压缩包
            os.remove(compressed_file_path)
            print(f"Removed temporary file: {compressed_file_path}")

            return save_path
        except Exception as e:
            print(f"Failed to download or extract file: {e}")
            return None

    @staticmethod
    def parse_cnf_file(file_path):
        """
        解析 CNF 文件并提取变量数量、子句数量以及子句列表。

        :param file_path: CNF 文件路径
        :return: (n_vars, clauses)
            n_vars: 变量数量
            clauses: 子句列表,例如 [[1, -2, 3], [-1, 4], ...]
        """
        clauses = []
        n_vars = 0
        n_clauses = 0

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('c') or line.startswith('%'):  # 忽略空行和注释行(包括以 % 开头的)
                    continue
                if line.startswith('p cnf'):
                    # 提取变量数量和子句数量
                    parts = line.split()
                    if len(parts) < 4:
                        raise ValueError("Invalid 'p cnf' line format")
                    _, _, n_vars, n_clauses = parts
                    n_vars = int(n_vars)
                    n_clauses = int(n_clauses)
                else:
                    # 跳过非数字行(如 % 或其他特殊字符)
                    if line == '0':
                        continue  # 忽略孤立的 0 行

                    try:
                        # 提取子句
                        clause = list(map(int, line.split()))[:-1]  # 去掉末尾的 0
                        clauses.append(clause)
                    except ValueError:
                        print(f"Skipping invalid line: {line}")
                        continue

        assert len(clauses) == n_clauses, "子句数量不匹配!"
        return n_vars, clauses

3.3 verify_solution

def verify_solution(clauses, model):
    """
    验证给定的模型是否满足所有子句。

    :param clauses: CNF 子句列表,例如 [[1, -2, 3], [-1, 4], ...]
    :param model: 变量赋值字典,例如 {1: True, 2: False, 3: True, ...}
    :return: 如果模型满足所有子句,返回 True;否则返回 False。
    """
    for clause in clauses:
        # 检查当前子句是否至少有一个文字为 True
        if not any(model[abs(lit)] == (lit > 0) for lit in clause):
            print(f"Clause {clause} is not satisfied.")
            return False

    print("All clauses are satisfied.")
    return True

3.4 test.py

from dpll import DPLL_SATSolver,verify_solution
from cnf_parser import CNFParser
import os
import shutil

if __name__ == "__main__":
    k = 100 # 求解的 SAT 实例数量
    correct_k = 0
    # 下载、解压和解析
    test_url = "https://www.cs.ubc.ca/~hoos/SATLIB/Benchmarks/SAT/RND3SAT/uuf50-218.tar.gz"  # 示例 CNF 文件压缩包 URL
    dataset_name = "uuf50-218"  # 数据集名称

    try:
        # 下载并解压 CNF 文件压缩包
        save_path = CNFParser.download_and_extract_cnf_archive(test_url, dataset_name)
        if save_path is None:
            raise RuntimeError("Failed to download or extract the dataset.")

        # 查找解压目录下的所有 CNF 文件
        cnf_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(save_path) for f in filenames if
                     f.endswith('.cnf')]
        if not cnf_files:
            print("No CNF files found in the directory.")
        else:
            # 对前 k 个 CNF 文件进行求解
            for i, cnf_file in enumerate(cnf_files[:k], start=1):
                filename = os.path.basename(cnf_file)
                print(f"\nProcessing CNF file {i}: {cnf_file},filename:{filename}")
                n_vars, clauses = CNFParser.parse_cnf_file(cnf_file)
                print(f"Parsed CNF file: {n_vars} variables, {len(clauses)} clauses")

                # 创建 SATSolver 对象
                solver = DPLL_SATSolver(clauses, n_vars)

                if solver.solve():
                    if filename.startswith("uf"):
                        correct_k += 1
                    print("SAT")
                    model = solver.get_model()
                    print("Model:", model)

                    # 验证赋值正确性
                    if verify_solution(clauses, model):
                        print("The solution is correct!")
                    else:
                        print("The solution is incorrect!")

                else:
                    if filename.startswith("uuf"):
                        correct_k += 1
                    print("UNSAT")
        accuracy = correct_k / k
        print("Accuracy:", accuracy)
        print("Accuracy Detail:", correct_k ,"/" , k)

    except Exception as e:
        print(f"Test failed: {e}")
    finally:
        # 清理测试文件(可选)
        base_path = "./dataset"
        if os.path.exists(base_path):
            shutil.rmtree(base_path)
            print(f"\nRemoved dataset directory: {base_path}")

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值