Code Backup: create CSV file

This post walks through a method for generating CSV and DDL files: it extracts data from a result set and writes it to a file, covering the key steps of creating the output path, generating a temporary file, mapping fields, and handling exceptions. It suits scenarios where data must be processed in batches and exported to files in a specific format.


/**
 * <B>Method name:</B> generate CSV file<BR>
 * <B>Summary:</B> writes the result rows and the error rows into a temporary CSV file under the output path.<BR>
 * 
 * @param results result set
 * @param errors error result set
 * @param outPutPath output file path
 * @param fileName output file name
 * 
 * @return File the generated CSV file, or null if writing fails
 * 
 */
public File createCSVFile(List<Map<String, String>> results, List<Map<String, String>> errors, String outPutPath,
        String fileName) {
    File csvFile = null;
    BufferedWriter csvFileOutputStream = null;
    try {
        File file = new File(outPutPath);
        if (!file.exists()) {
            file.mkdirs();
        }
        // Create the CSV as a temporary file under the output directory (a random suffix is appended to fileName).
        csvFile = File.createTempFile(fileName, ".csv", new File(outPutPath));
        csvFileOutputStream = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(csvFile), "UTF-8"),
                1024);
        // Header row.
        csvFileOutputStream.write("标识,编号,名称,路线编号,桩号,经度,纬度,起点桩号,起点经度,起点纬度,终点桩号,终点经度,终点纬度,行政区划代码,模糊匹配线路编号,备注");
        // Result rows: a missing value still emits the separator so the columns stay aligned.
        for (Map<String, String> result : results) {
            csvFileOutputStream.newLine();
            csvFileOutputStream.write(result.get("BS") == null ? "," : result.get("BS") + ",");
            csvFileOutputStream.write(result.get("BH") == null ? "," : result.get("BH") + ",");
            csvFileOutputStream.write(result.get("MC") == null ? "," : result.get("MC") + ",");
            csvFileOutputStream.write(result.get("LXBH") == null ? "," : result.get("LXBH") + ",");
            csvFileOutputStream.write(result.get("GCZZH") == null ? "," : result.get("GCZZH") + ",");
            csvFileOutputStream.write(result.get("GCZJD") == null ? "," : result.get("GCZJD") + ",");
            csvFileOutputStream.write(result.get("GCZWD") == null ? "," : result.get("GCZWD") + ",");
            csvFileOutputStream.write(result.get("QDZH") == null ? "," : result.get("QDZH") + ",");
            csvFileOutputStream.write(result.get("QDJD") == null ? "," : result.get("QDJD") + ",");
            csvFileOutputStream.write(result.get("QDWD") == null ? "," : result.get("QDWD") + ",");
            csvFileOutputStream.write(result.get("ZDZH") == null ? "," : result.get("ZDZH") + ",");
            csvFileOutputStream.write(result.get("ZDJD") == null ? "," : result.get("ZDJD") + ",");
            csvFileOutputStream.write(result.get("ZDWD") == null ? "," : result.get("ZDWD") + ",");
            csvFileOutputStream.write(result.get("XZQH") == null ? "," : result.get("XZQH") + ",");
            csvFileOutputStream.write(result.get("fuzzyLXBH") == null ? "," : result.get("fuzzyLXBH") + ",");
            csvFileOutputStream.write(result.get("REASON") == null ? "" : result.get("REASON"));
        }
        // Separator line, then the error rows in the same column layout.
        csvFileOutputStream.newLine();
        csvFileOutputStream.write("errors");
        for (Map<String, String> error : errors) {
            csvFileOutputStream.newLine();
            csvFileOutputStream.write(error.get("BS") == null ? "," : error.get("BS") + ",");
            csvFileOutputStream.write(error.get("BH") == null ? "," : error.get("BH") + ",");
            csvFileOutputStream.write(error.get("MC") == null ? "," : error.get("MC") + ",");
            csvFileOutputStream.write(error.get("LXBH") == null ? "," : error.get("LXBH") + ",");
            csvFileOutputStream.write(error.get("GCZZH") == null ? "," : error.get("GCZZH") + ",");
            csvFileOutputStream.write(error.get("GCZJD") == null ? "," : error.get("GCZJD") + ",");
            csvFileOutputStream.write(error.get("GCZWD") == null ? "," : error.get("GCZWD") + ",");
            csvFileOutputStream.write(error.get("QDZH") == null ? "," : error.get("QDZH") + ",");
            csvFileOutputStream.write(error.get("QDJD") == null ? "," : error.get("QDJD") + ",");
            csvFileOutputStream.write(error.get("QDWD") == null ? "," : error.get("QDWD") + ",");
            csvFileOutputStream.write(error.get("ZDZH") == null ? "," : error.get("ZDZH") + ",");
            csvFileOutputStream.write(error.get("ZDJD") == null ? "," : error.get("ZDJD") + ",");
            csvFileOutputStream.write(error.get("ZDWD") == null ? "," : error.get("ZDWD") + ",");
            csvFileOutputStream.write(error.get("XZQH") == null ? "," : error.get("XZQH") + ",");
            csvFileOutputStream.write(error.get("fuzzyLXBH") == null ? "," : error.get("fuzzyLXBH") + ",");
            csvFileOutputStream.write(error.get("REASON") == null ? "" : error.get("REASON"));
        }
        csvFileOutputStream.flush();
    } catch (Exception e) {
        // Swallow the error and signal failure by returning null.
        return null;
    } finally {
        // Close the writer even when an exception is thrown, so the file handle is not leaked.
        if (csvFileOutputStream != null) {
            try {
                csvFileOutputStream.close();
            } catch (Exception ignore) {
                // Nothing useful can be done if the close itself fails.
            }
        }
    }
    return csvFile;
}
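
For reference, a minimal way to drive the method above might look like the following sketch. The owning class name (ExportHelper), the sample keys/values, and the output path are illustrative assumptions, not part of the original code.

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CsvExportDemo {
    public static void main(String[] args) {
        // One sample result row; the keys match those read by createCSVFile.
        Map<String, String> row = new HashMap<String, String>();
        row.put("BS", "1");
        row.put("BH", "G0001");
        row.put("MC", "示例桩号");
        row.put("LXBH", "G30");
        row.put("QDJD", "116.39");
        row.put("QDWD", "39.90");

        List<Map<String, String>> results = new ArrayList<Map<String, String>>();
        results.add(row);
        List<Map<String, String>> errors = new ArrayList<Map<String, String>>();

        // ExportHelper is a hypothetical class assumed to hold the methods shown in this post.
        ExportHelper helper = new ExportHelper();
        File csv = helper.createCSVFile(results, errors, "/tmp/export/", "stake_export");
        System.out.println(csv == null ? "CSV export failed" : "CSV written to " + csv.getAbsolutePath());
    }
}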

/**
 * <B>Method name:</B> generate DDL (SQL script) file<BR>
 * <B>Summary:</B> writes one UPDATE statement per complete result row into a .txt script file.<BR>
 * 
 * @param results result set
 * @param errors error result set (not used by this method)
 * @param outPutPath output file path
 * @param fileName output file name
 * 
 * @return File the generated script file, or null if writing fails
 * 
 */
public File createDDLFile(List<Map<String, String>> results, List<Map<String, String>> errors, String outPutPath,
        String fileName) {
    // Build the target path with the two-argument File constructor so the path separator is handled correctly.
    File file = new File(outPutPath, fileName + ".txt");
    PrintWriter printWriter = null;
    try {
        if (!file.getParentFile().exists()) {
            file.getParentFile().mkdirs();
        }
        printWriter = new PrintWriter(new FileOutputStream(file));
        StringBuffer sql = new StringBuffer();
        for (int i = 0; i < results.size(); i++) {
            try {
                // Rows without complete start/end coordinates are skipped.
                if (!results.get(i).containsKey("QDJD") || !results.get(i).containsKey("QDWD")
                        || !results.get(i).containsKey("ZDJD") || !results.get(i).containsKey("ZDWD")) {
                    continue;
                }
                sql.setLength(0);
                sql.append(" UPDATE STAKE_INFO SET ");
                sql.append(" SSTARTLNG = "
                        + (!results.get(i).containsKey("QDJD") ? "null" : results.get(i).get("QDJD").toString())
                        + ",");
                sql.append(" SSTARTLAT = "
                        + (!results.get(i).containsKey("QDWD") ? "null" : results.get(i).get("QDWD").toString())
                        + ",");
                sql.append(" SENDLNG = "
                        + (!results.get(i).containsKey("ZDJD") ? "null" : results.get(i).get("ZDJD").toString())
                        + ",");
                sql.append(" SENDLAT = "
                        + (!results.get(i).containsKey("ZDWD") ? "null" : results.get(i).get("ZDWD").toString())
                        + " ");
                sql.append(" WHERE GCZBS = '" + results.get(i).get("BS").toString() + "' ");
                sql.append(" AND SID = '" + results.get(i).get("BH").toString() + "' ");
                sql.append(" AND SNAME = '" + results.get(i).get("MC").toString() + "' ");
                sql.append(" AND SROADNUMBER = '" + results.get(i).get("LXBH").toString() + "';\n");
                printWriter.write(sql.toString());
                printWriter.flush();
            } catch (Exception e) {
                // A broken row is marked in the script instead of aborting the whole export.
                printWriter.write("//第" + i + "条数据异常,请检查!\n");
            }
        }
    } catch (Exception e) {
        return null;
    } finally {
        // Close the writer (and the underlying stream) even if an exception is thrown.
        if (printWriter != null) {
            printWriter.close();
        }
    }
    return file;
}
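
Calling createDDLFile follows the same pattern. Below is a short continuation of the hypothetical CsvExportDemo above (reusing its row, results, errors, and helper variables); the added coordinate values and the commented output line are illustrative assumptions showing the kind of UPDATE statement the method appends for a complete row.

        // Continuing the demo above: the row also needs ZDJD/ZDWD, otherwise createDDLFile skips it.
        row.put("ZDJD", "116.45");
        row.put("ZDWD", "39.95");

        File ddl = helper.createDDLFile(results, errors, "/tmp/export/", "stake_update");
        // For this row, /tmp/export/stake_update.txt would contain a line such as:
        //  UPDATE STAKE_INFO SET  SSTARTLNG = 116.39, SSTARTLAT = 39.90, SENDLNG = 116.45, SENDLAT = 39.95
        //  WHERE GCZBS = '1'  AND SID = 'G0001'  AND SNAME = '示例桩号'  AND SROADNUMBER = 'G30';
        System.out.println(ddl == null ? "DDL export failed" : "Script written to " + ddl.getAbsolutePath());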