项目练习_利用tushare下载股票行情【多线程】
tushare简介
代码
import pandas as pd
import numpy as np
import sqlalchemy
from sqlalchemy import create_engine
import threading
import tushare as ts
import itertools
import threadpool
import datetime
# Default for both ends of the download range; evaluated once, at import time.
today = datetime.date.today().strftime('%Y%m%d')


class downmarket:
    """Download daily bars from tushare with a pool of worker threads.

    Calling ``downmarket(...)`` does not hand back an instance: ``__new__``
    builds one, runs :meth:`action` on it and returns the list of codes whose
    download or save step raised.

    Parameters (all forwarded by ``__new__``):
        save: ``'csv'`` to write one CSV per code, ``'sql'`` to write one
            table per code through ``mk_engine``.
        tdc: number of worker threads.
        start_date, end_date: range passed to ``pro_bar``; both default to
            the module-level ``today`` string (``YYYYMMDD``).
        codes: list of ts codes; when empty or ``None`` the codes come from
            :meth:`get_basic`.
        adj: adjustment mode passed to ``pro_bar``; also used as the prefix
            of output file/table names.
        method: ``if_exists`` argument used when writing market tables.
    """

    def __new__(cls, save='csv', tdc=1, start_date=today, end_date=today,
                codes=None, adj='qfq', method='append'):
        # Zero-argument super(): the former super(cls, cls) resolved back to
        # this very method for any subclass and recursed without end.
        new = super().__new__(cls)
        # __new__ returns a list, not an instance, so Python will not call
        # __init__ on its own; do it explicitly.
        new.__init__(adj, method)
        return new.action(save, tdc, start_date, end_date, codes)

    def __init__(self, adj, method):
        """Set up the tushare client, counters and the two database engines."""
        # NOTE(review): the API token is hard-coded in the source; consider
        # loading it from configuration or the environment instead.
        ts.set_token('1223456')
        self.ts = ts
        self.pro = ts.pro_api()
        self.adj = adj
        self.method = method
        self.ct = itertools.count(1)  # running number shown in progress lines
        self.errorcode = []  # codes whose task raised, filled by action()
        self.mk_engine = self.get_engine('stock_market')
        self.da_engine = self.get_engine('stock_data')

    def get_engine(self, db):
        """Return a SQLAlchemy engine for database ``db`` on the local MySQL."""
        # NOTE(review): database credentials are hard-coded in this URL.
        engine = create_engine(
            "mysql+pymysql://root:123456@127.0.0.1:3306/%s?charset=utf8mb4" % db)
        return engine

    def get_basic(self):
        """Return the stock basic-information table.

        Reads table ``basic`` through ``da_engine``. When that query raises
        ``ProgrammingError`` the listing is fetched from tushare
        (``list_status='L'``) and written to the ``basic`` table first.
        """
        try:
            basic = pd.read_sql('select * from basic', con=self.da_engine)
        except sqlalchemy.exc.ProgrammingError as e:
            print(e)
            basic = self.pro.stock_basic(list_status='L')
            basic.to_sql('basic', con=self.da_engine, if_exists='replace',
                         index=False)
        return basic

    def get_market(self, ts_code, start_date, end_date):
        """Download bars for ``ts_code`` and normalise their format.

        Returns a DataFrame indexed by ``trade_date`` (as datetimes) with the
        ``change`` column removed, ``pct_chg`` recomputed from
        ``close / pre_close`` and ``pct_chg``, ``vol`` and ``amount`` rounded
        to two decimals.

        Raises:
            ValueError: if ``pro_bar`` returns ``None``.
        """
        market = self.ts.pro_bar(ts_code=ts_code, pro_api=self.pro,
                                 start_date=start_date, end_date=end_date,
                                 adj=self.adj)
        # NOTE(review): pro_bar appears to return None instead of raising on
        # some failures -- confirm against the installed tushare version.
        if market is None:
            raise ValueError('no market data returned for %s' % ts_code)
        # np.str no longer exists in current NumPy; parse the YYYYMMDD
        # strings with pandas directly.
        market['trade_date'] = pd.to_datetime(
            market['trade_date'].astype(str), format='%Y%m%d')
        market = market.set_index('trade_date')
        market = market.drop(columns='change')
        market['pct_chg'] = ((market['close'] / market['pre_close'] - 1) * 100).round(2)
        market['vol'] = market['vol'].round(2)
        market['amount'] = market['amount'].round(2)
        return market

    def save_sql(self, ts_code, start_date, end_date):
        """Download ``ts_code`` and store it in a table via ``mk_engine``.

        The table is named ``<adj>_<ts_code lower-cased>``. Rows whose
        ``trade_date`` is already present in that table are skipped; the rest
        are written with ``if_exists=self.method``.
        """
        tbname = '%s_%s' % (self.adj, ts_code.lower())
        market = self.get_market(ts_code, start_date, end_date)
        try:
            # .iloc replaces the removed DataFrame.ix indexer.
            exists_date = pd.read_sql('select trade_date from `%s`' % tbname,
                                      con=self.mk_engine).iloc[:, 0]
        except sqlalchemy.exc.ProgrammingError as e:
            # Query failed (presumably the table does not exist yet), so
            # there is nothing to de-duplicate against.
            print(e)
            exists_date = None
        if exists_date is not None and not exists_date.empty:
            # One vectorised mask instead of dropping rows one at a time.
            already_stored = market.index.isin(pd.to_datetime(exists_date))
            market = market[~already_stored]
        market = market.reset_index()
        market.to_sql(tbname, con=self.mk_engine, if_exists=self.method,
                      index=False)
        print('num(%s) code(%s) updat_row(%s) is ok...'
              % (next(self.ct), ts_code, market.shape[0]))

    def save_csv(self, ts_code, start_date, end_date):
        """Download ``ts_code`` and write it to ``<adj>_<ts_code>.csv``."""
        fname = '%s_%s.csv' % (self.adj, ts_code)
        market = self.get_market(ts_code, start_date, end_date)
        market.to_csv(fname)
        print('num(%s) code(%s) is ok...' % (next(self.ct), ts_code))

    def action(self, save, tdc, start_date, end_date, codes):
        """Run the chosen save routine for every code on ``tdc`` threads.

        Returns ``self.errorcode``: the codes whose task raised, in
        submission order (empty when everything succeeded).

        Raises:
            ValueError: if ``save`` is neither ``'sql'`` nor ``'csv'``.
        """
        from concurrent.futures import ThreadPoolExecutor

        savers = {'sql': self.save_sql, 'csv': self.save_csv}
        # Reject an unknown mode before any database or network access.
        if save not in savers:
            raise ValueError("save must be 'sql' or 'csv', got %r" % (save,))
        callable_ = savers[save]
        if not codes:
            codes = self.get_basic().ts_code.tolist()
        with ThreadPoolExecutor(max_workers=tdc) as pool:
            submitted = [(code, pool.submit(callable_, code, start_date, end_date))
                         for code in codes]
            for code, future in submitted:
                try:
                    future.result()
                except Exception as e:
                    # Thread boundary: record the failure and keep going so
                    # one bad code does not abort the whole batch.
                    print('code(%s) failed: %s' % (code, e))
                    self.errorcode.append(code)
        if not self.errorcode:
            print('全部成功...')
        else:
            print('部分失败:%s' % self.errorcode)
        return self.errorcode
if __name__ == '__main__':
    # Script entry point: fetch two sample codes with two worker threads,
    # leaving every other option at its default.
    demo_codes = ['000001.SZ', '000002.SZ']
    downmarket(codes=demo_codes, tdc=2)