import os
import csv
import glob
from openai import OpenAI
from typing import List, Tuple
import re
import pandas as pd
from collections import defaultdict
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.nn.functional as F
from tqdm import tqdm
from nrclex import NRCLex
import numpy as np
import pysentiment2 as ps
from nltk.sentiment.vader import SentimentIntensityAnalyzer
def english_sentence_split(text, max_length=60):
"""
用英文标点分句,并合并到每行不超过max_length字符,尽量不截断语义。
"""
sentences = re.split(r'(?<=[.!?])\s+', text)
result = []
buffer = ""
for sent in sentences:
sent = sent.strip()
if not sent:
continue
if len(buffer) + len(sent) + 1 <= max_length:
buffer = buffer + " " + sent if buffer else sent
else:
if buffer:
result.append(buffer)
buffer = sent
if buffer:
result.append(buffer)
return result
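# A minimal, hypothetical sketch (not part of the original pipeline) showing how
# english_sentence_split packs short sentences into lines of at most 60 characters.
# The sample text and function name below are invented for illustration only; the
# function is defined but never called.
def _demo_english_sentence_split():
    sample = "Prices rose sharply today. Inventory is still low. Demand looks weak going into next week."
    for line in english_sentence_split(sample, max_length=60):
        print(line)
    # Expected behaviour: the two short opening sentences are merged onto one line,
    # while the longer third sentence starts a new line.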
class MarketContentExtractor:
def __init__(self, api_key: str):
self.client = OpenAI(
api_key=api_key,
base_url="https://api.deepseek.com"
)
def extract_market_content(self, conversation_text: str) -> List[str]:
"""
只提取市场相关的原始中文句子,返回句子列表
"""
system_prompt = """
你是一个市场专家,只提取输出市场相关的句子,不要解释。
重要要求:
1. 与市场行情相关的内容保留
2. 有关市场涨跌的内容也要保留
3. 只输出市场行情相关的原句、对于现阶段市场涨跌判断的原句、对于某些大宗商品品类涨跌感慨、行业情景近况好坏的原句,以及所有表达盈利或亏损的原句,不要修改任何内容
4. 每个完整的话题或相关联的句子组合占一行
5. 每行长度不宜过短(不少于15字),也不宜过长(不超过60字),如有需要可适当合并或拆分句子
6. 保持内容的原始顺序
7. 如果没有市场行情相关内容、对于现阶段市场涨跌判断的内容、对于某些大宗商品品类涨跌感慨的内容或表达盈利亏损的内容,输出"无相关内容"
8. 不要添加任何解释或标注
9. 每个市场相关的完整句子单独成行,不要把多个话题或问答合并到一行。
10. 若只讨论某个品类的相关信息,而不讨论其市场行情、库存量情况或盈亏情况的,输出"无相关内容"
11. 如果某一句话完全不符合大宗商品业务员交流逻辑,且无法推断其原意,直接删除该句,不要保留无意义内容
"""
user_prompt = f"请仅保留并输出与市场相关的内容(原文中文),不改变原内容,不要翻译:\n\n{conversation_text}"
try:
response = self.client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=1.0,
max_tokens=6000,
stream=False
)
result = response.choices[0].message.content.strip()
if "无相关内容" in result or "无市场相关内容" in result:
return []
chinese_sentences = [line.strip() for line in result.split('\n') if line.strip()]
except Exception as e:
print(f"API调用失败: {e}")
return []
return chinese_sentences
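    # Illustrative note (example reply invented for this sketch): the model is asked to
    # return one market-related Chinese sentence per line, e.g.
    #   铜价最近一直在涨,库存也下来了
    #   这个月基本没什么利润,行情太差
    # so extract_market_content only needs to split the reply on newlines; a reply that
    # contains the "无相关内容" marker is treated as "no market content".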
def batch_translate(self, chinese_sentences: list, batch_size=20) -> list:
"""
批量翻译,输入中文句子列表,返回英文句子列表,顺序一一对应。
每句加编号,自动补漏。
"""
if not chinese_sentences:
return []
en_sentences = []
for base_idx in range(0, len(chinese_sentences), batch_size):
batch = chinese_sentences[base_idx:base_idx + batch_size]
            # Number each sentence, starting from 1
numbered_batch = [f"{idx + 1}. {sent}" for idx, sent in enumerate(batch)]
user_prompt = (
"请将以下每一句分别翻译成英文,每行保留编号和英文,对应原文顺序,不要合并、不要省略、不要解释:\n"
+ "\n".join(numbered_batch)
)
try:
response = self.client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": (
"你是一个专业的中英翻译专家。"
"请严格逐句翻译并保留编号,每行输出编号和英文翻译,不能合并、不能省略、不能解释。"
"如遇难以翻译的内容请照抄原文。"
)},
{"role": "user", "content": user_prompt}
],
temperature=1.3,
max_tokens=2000,
stream=False
)
result = response.choices[0].message.content.strip()
                # Parse reply lines that start with "<number>. "
batch_en = [""] * len(batch)
for line in result.split('\n'):
line = line.strip()
m = re.match(r'^(\d+)\.\s*(.+)$', line)
if m:
idx = int(m.group(1)) - 1
if 0 <= idx < len(batch):
batch_en[idx] = m.group(2).strip()
                # Check for sentences that were not translated
                missing_indices = [i for i, en in enumerate(batch_en) if not en]
                if missing_indices:
                    print(f"Found {len(missing_indices)} missing translations; re-translating them individually...")
for i in missing_indices:
single_prompt = (
f"请将以下中文句子翻译成英文,输出时保留编号,不要省略、不要解释:\n"
f"{i + 1}. {batch[i]}"
)
try:
single_response = self.client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": (
"你是一个专业的中英翻译专家。"
"请严格逐句翻译并保留编号,不能省略、不能解释。"
)},
{"role": "user", "content": single_prompt}
],
temperature=1.3,
max_tokens=200,
stream=False
)
single_result = single_response.choices[0].message.content.strip()
m2 = re.match(r'^\d+\.\s*(.+)$', single_result)
if m2:
batch_en[i] = m2.group(1).strip()
else:
batch_en[i] = single_result
except Exception as e:
print(f"单句补译失败: {e}")
batch_en[i] = ""
en_sentences.extend(batch_en)
except Exception as e:
print(f"批量翻译失败: {e}")
en_sentences.extend([""] * len(batch))
return en_sentences
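    # Sketch of the numbering protocol used by batch_translate (example values invented):
    # a batch is sent as "1. 句子一\n2. 句子二\n3. 句子三", each reply line is matched
    # against r'^(\d+)\.\s*(.+)$', and the captured number maps the translation back to
    # its position in the batch; any position left empty is retried one sentence at a time.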
def read_txt_file(self, file_path: str) -> str:
try:
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
except UnicodeDecodeError:
try:
with open(file_path, 'r', encoding='gbk') as file:
return file.read()
except Exception as e:
print(f"读取文件 {file_path} 失败: {e}")
return ""
except Exception as e:
print(f"读取文件 {file_path} 失败: {e}")
return ""
def save_to_csv(self, sentence_pairs: List[Tuple[str, str]], csv_path: str):
"""将英文句和原中文句保存到CSV文件,每行不超过60字"""
try:
with open(csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['number', 'sentence', 'origin_chinese'])
idx = 1
for en, cn in sentence_pairs:
split_sentences = english_sentence_split(en, 60)
for seg in split_sentences:
writer.writerow([idx, seg.strip(), cn])
idx += 1
print(f"成功保存 {idx-1} 条市场相关内容到 {csv_path}")
except Exception as e:
print(f"保存CSV文件 {csv_path} 失败: {e}")
def process_single_file(self, txt_path: str, output_dir: str):
print(f"正在处理: {txt_path}")
conversation_text = self.read_txt_file(txt_path)
if not conversation_text:
print(f"文件 {txt_path} 内容为空或读取失败")
return
chinese_sentences = self.extract_market_content(conversation_text)
if not chinese_sentences:
print(f"文件 {txt_path} 中未找到市场相关内容,不生成CSV文件。")
return
en_sentences = self.batch_translate(chinese_sentences)
market_sentence_pairs = list(zip(en_sentences, chinese_sentences))
base_name = os.path.splitext(os.path.basename(txt_path))[0]
csv_path = os.path.join(output_dir, f"{base_name}_market_content.csv")
self.save_to_csv(market_sentence_pairs, csv_path)
def batch_process(self, input_dir: str, output_dir: str):
os.makedirs(output_dir, exist_ok=True)
txt_files = glob.glob(os.path.join(input_dir, "*.txt"))
if not txt_files:
print(f"在目录 {input_dir} 中未找到txt文件")
return
print(f"找到 {len(txt_files)} 个txt文件")
for i, txt_file in enumerate(txt_files, 1):
print(f"\n[{i}/{len(txt_files)}] ", end="")
self.process_single_file(txt_file, output_dir)
print(f"\n批量处理完成!结果保存在: {output_dir}")
def finbert_batch_predict(df, text_col, model, tokenizer, batch_size=32, device=None):
    # Fall back to CPU automatically when CUDA is not available
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    labels = model.config.id2label
for label in labels.values():
df[label] = 0.0
all_texts = df[text_col].astype(str).tolist()
all_probs = []
model = model.to(device)
model.eval()
with torch.no_grad():
for i in range(0, len(all_texts), batch_size):
batch_texts = all_texts[i:i+batch_size]
inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt', max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model(**inputs)
logits = outputs.logits
probs = F.softmax(logits, dim=1).cpu().numpy()
all_probs.append(probs)
all_probs = np.vstack(all_probs) # shape: (N, 3)
for idx, label in enumerate(labels.values()):
df[label] = all_probs[:, idx]
    # Compute the file-level sentiment index (SSI) from the average positive/negative
    # probabilities; assumes the model's labels include 'positive' and 'negative' (as in FinBERT)
avg_pos = np.mean(df['positive'])
avg_neg = np.mean(df['negative'])
ssi = np.log(1 + (avg_pos - avg_neg) / (avg_pos + avg_neg + 1))
df['SSI'] = ssi
return df
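# A small self-contained sketch (invented numbers) of the SSI formula used above:
# SSI = ln(1 + (avg_pos - avg_neg) / (avg_pos + avg_neg + 1)). Defined but never called.
def _demo_ssi_formula():
    avg_pos, avg_neg = 0.62, 0.21  # hypothetical file-level average probabilities
    ssi = np.log(1 + (avg_pos - avg_neg) / (avg_pos + avg_neg + 1))
    print(f"SSI = {ssi:.4f}")  # roughly 0.20 for these example values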
def main():
API_KEY = "sk-8f3153cd288249358ee7db91f0d891f5" # 替换为您的DeepSeek API密钥
INPUT_DIR = r"txt/txt5" # txt文件所在目录
OUTPUT_DIR = r"outputcsv5" # CSV输出目录
extractor = MarketContentExtractor(API_KEY)
extractor.batch_process(INPUT_DIR, OUTPUT_DIR)
    # ================ Delete CSV files whose 'sentence' column contains only "No relevant content" ================
for filename in os.listdir(OUTPUT_DIR):
if filename.endswith('.csv'):
file_path = os.path.join(OUTPUT_DIR, filename)
try:
df = pd.read_csv(file_path)
if 'sentence' in df.columns:
first_row = str(df['sentence'].iloc[0]).strip()
if first_row == 'No relevant content':
os.remove(file_path)
print(f"已删除: {filename}")
else:
print(f"{filename} 没有'sentence'列")
except Exception as e:
print(f"处理{filename}时出错: {e}")
    # ==================== Group the per-file CSVs by date ====================
merge_path = r"outputcsv5/date_data/"
os.makedirs(merge_path, exist_ok=True)
date_files = defaultdict(list)
pattern = re.compile(r'_(\d{14})_market_content')
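    # The 14 digits are assumed to be a YYYYMMDDHHMMSS timestamp embedded in the source
    # filename; the first 8 digits (YYYYMMDD) are used below as the date key for merging.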
for filename in os.listdir(OUTPUT_DIR):
if filename.endswith('.csv'):
match = pattern.search(filename)
if match:
date_str = match.group(1)[:8]
date_files[date_str].append(os.path.join(OUTPUT_DIR, filename))
for date_str, files in date_files.items():
dfs = []
for file in files:
df = pd.read_csv(file)
dfs.append(df)
if dfs:
merged_df = pd.concat(dfs, ignore_index=True)
merged_df['number'] = range(1, len(merged_df) + 1)
merged_df.to_csv(os.path.join(merge_path, f"{date_str}.csv"), index=False)
print(f"已生成合并文件: {os.path.join(merge_path, f'{date_str}.csv')}")
    # ==================== FinBERT batch scoring ====================
model_path = r'model/Finbert'
finbert = BertForSequenceClassification.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
for filename in os.listdir(merge_path):
if filename.endswith('.csv'):
file_path = os.path.join(merge_path, filename)
df = pd.read_csv(file_path)
if 'sentence' not in df.columns:
print(f"{filename} 缺少'sentence'列,跳过。")
continue
tqdm.pandas(desc=f"FinBERT处理 {filename}")
df = finbert_batch_predict(df, 'sentence', finbert, tokenizer, batch_size=32, device='cuda')
df.to_csv(file_path, index=False)
print(f"已处理并覆盖保存: {file_path}")
    # ==================== NRC lexicon scoring ====================
nrc_labels = ['anger', 'disgust', 'negative', 'sadness', 'fear', 'trust', 'anticipation', 'positive']
nrc_labels_with_suffix = [label + '_nrc' for label in nrc_labels]
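    # NRCLex also reports 'joy' and 'surprise'; only the eight categories listed above are
    # kept, stored with an '_nrc' suffix so they do not overwrite the FinBERT
    # 'positive'/'negative' probability columns.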
for filename in os.listdir(merge_path):
if filename.endswith('.csv'):
file_path = os.path.join(merge_path, filename)
df = pd.read_csv(file_path)
for label in nrc_labels_with_suffix:
df[label] = 0.0
            for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"NRC scoring {filename}"):
text = str(row['sentence']).strip()
if not text:
continue
emotion = NRCLex(text)
emotion_scores = emotion.raw_emotion_scores
total = sum(emotion_scores.values())
for ori_label, label in zip(nrc_labels, nrc_labels_with_suffix):
if total > 0 and ori_label in emotion_scores:
df.at[idx, label] = emotion_scores[ori_label] / total
else:
df.at[idx, label] = 0.0
df.to_csv(file_path, index=False)
print(f"已处理并覆盖保存: {file_path}")
print("所有CSV文件处理完成。")
    # ==================== LM, VADER and HIV4 lexicon scoring ====================
lm = ps.LM()
hiv4 = ps.HIV4()
vader = SentimentIntensityAnalyzer()
def lm_score(text):
if pd.isnull(text):
return 0
tokens = lm.tokenize(str(text))
score = lm.get_score(tokens)
return score['Polarity']
def hiv4_score(text):
if pd.isnull(text):
return 0
tokens = hiv4.tokenize(str(text))
score = hiv4.get_score(tokens)
return score['Polarity']
def vader_score(text):
if pd.isnull(text):
return 0
score = vader.polarity_scores(str(text))
return score['compound']
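    # Note (per the libraries' documented outputs): pysentiment2's get_score returns a
    # 'Polarity' value in [-1, 1] derived from positive/negative word counts, and VADER's
    # 'compound' is its normalized overall sentiment score in [-1, 1].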
for filename in os.listdir(merge_path):
if filename.endswith('.csv'):
file_path = os.path.join(merge_path, filename)
df = pd.read_csv(file_path)
if 'sentence' not in df.columns:
print(f"{filename} 缺少'sentence'列,跳过。")
continue
tqdm.pandas(desc=f"词典打分 {filename}")
df['lm_score'] = df['sentence'].progress_apply(lm_score)
df['hiv4_score'] = df['sentence'].progress_apply(hiv4_score)
df['vader_score'] = df['sentence'].progress_apply(vader_score)
df.to_csv(file_path, index=False)
print(f"{filename} 处理完成并已覆盖保存。")
print("所有CSV文件处理完成。")
if __name__ == "__main__":
main()
Please modify the code above so that fields (columns) can be added or removed at any time.