import logging
import os
import secrets
from datetime import datetime, timedelta
from typing import Optional, Dict, List

import akshare as ak
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Depends, Header, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field, field_validator, ValidationInfo
# Logging: INFO-level root configuration plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Metadata shown in the generated OpenAPI documentation.
_API_METADATA = {
    "title": "Stock Technical Analysis API",
    "description": "API for analyzing stock technical indicators",
    "version": "1.0.0",
}
app = FastAPI(**_API_METADATA)

# Bearer-token extractor; verify_auth_token depends on it to read the token.
security = HTTPBearer()
# Configuration
class Config:
    """Central tuning constants and settings for the analysis service."""

    # Look-back windows for the MA5 / MA20 / MA60 indicator columns.
    MA_PERIODS = {'short': 5, 'medium': 20, 'long': 60}
    RSI_PERIOD = 14
    BOLLINGER_PERIOD = 20
    BOLLINGER_STD = 2  # band half-width, in standard deviations
    VOLUME_MA_PERIOD = 20
    ATR_PERIOD = 14
    # Accepted bearer tokens. Secrets should not live in source control, so a
    # comma-separated STOCK_API_TOKENS environment variable takes precedence;
    # the literal list is only a backward-compatible fallback when it is unset
    # and should be removed once deployments provide the variable.
    VALID_TOKENS = [
        token.strip()
        for token in os.environ.get("STOCK_API_TOKENS", "").split(",")
        if token.strip()
    ] or ["xue123", "xue1234"]
    # History window (days) used when a request omits start_date.
    DEFAULT_DAYS_BACK = 365
class StockAnalysisRequest(BaseModel):
    # Request body for POST /analyze-stock/. Examples are placed in
    # json_schema_extra (the non-deprecated pydantic v2 spelling of the old
    # Field(example=...) extra kwarg), so the generated schema is unchanged.
    stock_code: str = Field(..., json_schema_extra={"example": "000001"},
                            description="Stock symbol code")
    market_type: str = Field('A', json_schema_extra={"example": "A"},
                             description="Market type (A, HK, US, ETF, LOF)")
    start_date: Optional[str] = Field(None, json_schema_extra={"example": "20230101"},
                                      description="Start date in YYYYMMDD format")
    end_date: Optional[str] = Field(None, json_schema_extra={"example": "20231231"},
                                    description="End date in YYYYMMDD format")

    @field_validator('market_type')
    @classmethod
    def validate_market_type(cls, v):
        """Accept a market code case-insensitively and normalise it to upper case."""
        if v.upper() not in ['A', 'HK', 'US', 'ETF', 'LOF']:
            raise ValueError("Invalid market type. Must be one of: A, HK, US, ETF, LOF")
        return v.upper()

    @field_validator('start_date', 'end_date', mode='before')
    @classmethod
    def validate_date_format(cls, v):
        """Require None or an 8-digit YYYYMMDD string naming a real calendar date."""
        if v is None:
            return v
        # strptime alone accepts 1-digit months/days (e.g. "2023111") and raises
        # TypeError - which pydantic does not convert, so it surfaced as an HTTP
        # 500 - for non-string input. Check the shape first.
        if not isinstance(v, str) or len(v) != 8 or not (v.isascii() and v.isdigit()):
            raise ValueError("Invalid date format. Must be YYYYMMDD")
        try:
            datetime.strptime(v, '%Y%m%d')
        except ValueError:
            raise ValueError("Invalid date format. Must be YYYYMMDD") from None
        return v

    @field_validator('end_date')
    @classmethod
    def validate_date_order(cls, v, info: ValidationInfo):
        """Reject an end_date earlier than start_date when both are given."""
        # start_date is declared first, so its validated value (if valid) is in
        # info.data. Both are 8-digit YYYYMMDD strings at this point, so string
        # comparison orders them chronologically.
        start = info.data.get('start_date')
        if v is not None and start is not None and v < start:
            raise ValueError("end_date must not be earlier than start_date")
        return v
class TechnicalSummary(BaseModel):
    # Condensed view of the latest indicator row; values as populated by analyze_stock.
    trend: str  # 'upward' when MA5 > MA20, otherwise 'downward'
    volatility: str  # ATR as a percentage of close, formatted with 2 decimals plus '%'
    volume_trend: str  # 'increasing' when Volume_Ratio > 1, otherwise 'decreasing'
    rsi_level: float  # latest RSI value
class StockReport(BaseModel):
    # Headline report for one analysis run; values as populated by analyze_stock.
    stock_code: str
    market_type: str
    analysis_date: str  # date the analysis ran, formatted 'YYYY-MM-DD'
    score: float  # composite score from calculate_score (capped at 100)
    price: float  # latest close
    price_change: float  # percent change of the latest close vs. the previous row
    ma_trend: str  # 'UP' when MA5 > MA20, otherwise 'DOWN'
    rsi: Optional[float]  # latest RSI, or None when it is NaN
    macd_signal: str  # 'BUY' when MACD > Signal, otherwise 'SELL'
    volume_status: str  # 'HIGH' when Volume_Ratio > 1.5, otherwise 'NORMAL'
    recommendation: str  # label produced by get_recommendation(score)
class AnalysisResponse(BaseModel):
    # Response body of POST /analyze-stock/.
    technical_summary: TechnicalSummary
    # Last 14 rows of the indicator frame, one dict per row (DataFrame.to_dict('records')).
    recent_data: List[Dict]
    report: StockReport
# Authentication
async def verify_auth_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency accepting only bearer tokens listed in Config.VALID_TOKENS.

    Returns:
        The presented token string.

    Raises:
        HTTPException: 403 if the token is not one of the configured tokens.
    """
    presented = credentials.credentials.encode('utf-8')
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed token matched. Every candidate is checked (no short-circuit).
    # Bytes are compared because compare_digest rejects non-ASCII str input.
    matched = False
    for valid in Config.VALID_TOKENS:
        matched |= secrets.compare_digest(presented, valid.encode('utf-8'))
    if not matched:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or expired token",
        )
    return credentials.credentials
# Helper functions
def calculate_ema(series: pd.Series, period: int) -> pd.Series:
    """Return the exponential moving average of *series* over *period* bars.

    Uses the recursive (``adjust=False``) form of pandas' ``ewm`` with
    ``span=period``. Empty input or a non-positive period yields an empty
    float64 Series.
    """
    nothing_to_do = series.empty or period <= 0
    if nothing_to_do:
        return pd.Series(dtype='float64')
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()
def calculate_rsi(series: pd.Series, period: int) -> pd.Series:
    """Return the Relative Strength Index of *series* (simple-average variant).

    Average gain and average loss are plain rolling means over *period* price
    changes (not Wilder's exponential smoothing). The first ``period`` values
    are NaN. A window with gains but no losses gives 100, one with losses but
    no gains gives 0, and one with neither gives NaN (0/0).

    Returns an empty float64 Series for empty input or a non-positive period.
    """
    if series.empty or period <= 0:
        return pd.Series(dtype='float64')
    delta = series.diff()
    # clip() keeps the leading NaN produced by diff(), so the first window holds
    # `period` genuine price changes. The previous where(..., 0) turned that NaN
    # into a fake zero change and emitted a value one bar too early.
    gain = delta.clip(lower=0).rolling(window=period).mean()
    loss = (-delta).clip(lower=0).rolling(window=period).mean()
    rs = gain / loss
    return 100 - (100 / (1 + rs))
def calculate_macd(series: pd.Series, fast_period: int = 12, slow_period: int = 26,
                   signal_period: int = 9) -> tuple:
    """Return ``(macd, signal, histogram)`` for *series*.

    ``macd`` is EMA(fast_period) - EMA(slow_period) of the series, ``signal``
    is EMA(signal_period) of ``macd`` and ``histogram`` is ``macd - signal``.
    All EMAs use pandas ``ewm(span=..., adjust=False)``. The defaults are the
    conventional 12/26/9 settings, so existing callers are unaffected.
    Three empty float64 Series are returned for empty input.
    """
    if series.empty:
        return pd.Series(dtype='float64'), pd.Series(dtype='float64'), pd.Series(dtype='float64')
    fast_ema = series.ewm(span=fast_period, adjust=False).mean()
    slow_ema = series.ewm(span=slow_period, adjust=False).mean()
    macd = fast_ema - slow_ema
    signal = macd.ewm(span=signal_period, adjust=False).mean()
    hist = macd - signal
    return macd, signal, hist
def calculate_bollinger_bands(series: pd.Series, period: int, std_dev: int) -> tuple:
    """Compute Bollinger Bands for *series*.

    Returns ``(upper, middle, lower)``: ``middle`` is the *period*-bar rolling
    mean, and the outer bands sit ``std_dev`` rolling standard deviations
    (pandas default, ddof=1) above and below it. Empty input or a non-positive
    period yields three empty float64 Series.
    """
    if series.empty or period <= 0:
        return tuple(pd.Series(dtype='float64') for _ in range(3))
    window = series.rolling(window=period)
    middle_band = window.mean()
    half_width = window.std() * std_dev
    return middle_band + half_width, middle_band, middle_band - half_width
def calculate_atr(df: pd.DataFrame, period: int) -> pd.Series:
    """Return the Average True Range of an OHLC frame as a rolling simple mean.

    True range per row is the largest of high-low, |high - previous close| and
    |low - previous close|. The first row has no previous close, so its true
    range is just high-low. Reads the ``high``, ``low`` and ``close`` columns.
    Empty input or a non-positive period yields an empty float64 Series.
    """
    if df.empty or period <= 0:
        return pd.Series(dtype='float64')
    prev_close = df['close'].shift(1)
    candidates = pd.DataFrame({
        'hl': df['high'] - df['low'],
        'hc': (df['high'] - prev_close).abs(),
        'lc': (df['low'] - prev_close).abs(),
    })
    true_range = candidates.max(axis=1)
    return true_range.rolling(window=period).mean()
def validate_stock_code(stock_code: str, market_type: str):
    """Check that *stock_code* has the expected format for *market_type*.

    Only 'A', 'HK' and 'US' codes are checked; any other market type
    (e.g. 'ETF', 'LOF') is accepted without validation.

    Raises:
        ValueError: if the code does not match the expected format.
    """
    if market_type == 'A':
        # A-share codes are exactly six ASCII digits. '688' codes already match
        # the '6' prefix; 688 is kept in the message only for readability.
        is_valid = (
            len(stock_code) == 6
            and stock_code.isascii()
            and stock_code.isdigit()
            and stock_code.startswith(('0', '3', '6', '8'))
        )
        if not is_valid:
            raise ValueError(
                f"Invalid A-share stock code: {stock_code}. "
                "A-share codes should be 6 digits starting with 0, 3, 6, 688 or 8"
            )
    elif market_type == 'HK':
        # isascii() rules out non-ASCII digit characters that isdigit() accepts.
        if not (len(stock_code) == 5 and stock_code.isascii() and stock_code.isdigit()):
            raise ValueError(
                f"Invalid HK stock code: {stock_code}. "
                "HK stock codes should be 5 digits"
            )
    elif market_type == 'US':
        # At most one '.' is tolerated; everything else must be ASCII letters/digits.
        if not (stock_code.isascii() and stock_code.replace('.', '', 1).isalnum()):
            raise ValueError(
                f"Invalid US stock code: {stock_code}. "
                "US stock codes should be alphanumeric"
            )
def get_stock_data(stock_code: str, market_type: str = 'A',
                   start_date: str = None, end_date: str = None) -> pd.DataFrame:
    """Fetch daily stock data from AKShare and normalise it to English OHLCV columns.

    Args:
        stock_code: Symbol to fetch; checked by validate_stock_code first.
        market_type: 'A', 'HK', 'US', 'ETF' or 'LOF'.
        start_date: Inclusive start date, YYYYMMDD. Defaults to
            Config.DEFAULT_DAYS_BACK days before *end_date*.
        end_date: Inclusive end date, YYYYMMDD. Defaults to today.

    Returns:
        DataFrame sorted by ``date`` containing at least ``date, open, close,
        high, low, volume`` (plus ``amount`` / ``change_pct`` when the source
        provides them), restricted to the requested date range, with numeric
        columns coerced to numbers and any row containing NaN dropped.

    Raises:
        ValueError: invalid stock code or unsupported market type.
        HTTPException: 400 if the fetch fails or required columns are missing;
            404 if no rows are available for the requested range.
    """
    if end_date is None:
        end_date = datetime.now().strftime('%Y%m%d')
    if start_date is None:
        # Anchor the default window on end_date. Anchoring on "now" produced an
        # inverted (empty) range whenever only a historical end_date was given.
        start_date = (datetime.strptime(end_date, '%Y%m%d')
                      - timedelta(days=Config.DEFAULT_DAYS_BACK)).strftime('%Y%m%d')

    validate_stock_code(stock_code, market_type)

    fetch_functions = {
        'A': lambda: ak.stock_zh_a_hist(symbol=stock_code, start_date=start_date, end_date=end_date, adjust="qfq"),
        'HK': lambda: ak.stock_hk_daily(symbol=stock_code, adjust="qfq"),
        'US': lambda: ak.stock_us_hist(symbol=stock_code, start_date=start_date, end_date=end_date, adjust="qfq"),
        'ETF': lambda: ak.fund_etf_hist_em(symbol=stock_code, period="daily", start_date=start_date,
                                           end_date=end_date, adjust="qfq"),
        'LOF': lambda: ak.fund_lof_hist_em(symbol=stock_code, period="daily", start_date=start_date,
                                           end_date=end_date, adjust="qfq"),
    }
    # Resolve the fetcher outside the try block: a KeyError raised *inside* the
    # AKShare call must not be misreported as an unsupported market type.
    fetch = fetch_functions.get(market_type)
    if fetch is None:
        raise ValueError(f"Unsupported market type: {market_type}")
    try:
        df = fetch()
    except Exception as e:
        logger.error("Failed to fetch data: %s", e)
        raise HTTPException(status_code=400, detail=f"Data fetch failed: {str(e)}") from e
    if df is None or df.empty:
        raise HTTPException(status_code=404, detail=f"No data found for {stock_code} in {market_type} market")

    # Standardize column names; rename() ignores mapping keys that are absent.
    column_mappings = {
        "日期": "date",
        "开盘": "open",
        "收盘": "close",
        "最高": "high",
        "最低": "low",
        "成交量": "volume",
        "成交额": "amount",
        "涨跌幅": "change_pct"
    }
    df = df.rename(columns=column_mappings)

    required_columns = ['date', 'open', 'close', 'high', 'low', 'volume']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise HTTPException(status_code=400, detail=f"Missing required columns: {', '.join(missing_columns)}")

    df['date'] = pd.to_datetime(df['date'])
    numeric_cols = ['open', 'close', 'high', 'low', 'volume']
    numeric_cols += [col for col in ('amount', 'change_pct') if col in df.columns]
    df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors='coerce')

    # The HK fetcher above is not given start/end dates and so returns its full
    # history; apply the window here so every market honours the request. The
    # upper bound is the following midnight so the whole end day is kept.
    window_start = pd.to_datetime(start_date, format='%Y%m%d')
    window_end = pd.to_datetime(end_date, format='%Y%m%d') + pd.Timedelta(days=1)
    df = df.loc[(df['date'] >= window_start) & (df['date'] < window_end)].copy()
    if df.empty:
        raise HTTPException(status_code=404, detail=f"No data found for {stock_code} in {market_type} market")

    return df.dropna().sort_values('date')
def calculate_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Calculate all technical indicators for an OHLCV frame.

    Reads the ``close``, ``high``, ``low`` and ``volume`` columns and returns a
    NEW DataFrame with these columns added, after dropping every row that
    contains a NaN (e.g. the rolling-window warm-up rows):

    * ``MA5`` / ``MA20`` / ``MA60`` - EMAs of close using Config.MA_PERIODS
      (the column names are fixed even if those periods are changed),
    * ``RSI``, ``MACD``, ``Signal``,
    * ``BB_upper`` / ``BB_middle`` / ``BB_lower``,
    * ``ATR`` and ``Volatility`` (ATR as a percentage of close),
    * ``Volume_MA`` and ``Volume_Ratio`` (volume / its rolling mean),
    * ``ROC`` - 10-bar percent rate of change of close.

    The caller's frame is not modified.

    Raises:
        ValueError: if *df* is empty.
    """
    if df.empty:
        raise ValueError("Cannot calculate indicators on empty DataFrame")
    # Work on a copy: the function already returns a new (filtered) frame, so
    # it must not also silently add columns to the caller's object.
    df = df.copy()
    try:
        # Moving averages (exponential, despite the "MA" column names)
        df['MA5'] = calculate_ema(df['close'], Config.MA_PERIODS['short'])
        df['MA20'] = calculate_ema(df['close'], Config.MA_PERIODS['medium'])
        df['MA60'] = calculate_ema(df['close'], Config.MA_PERIODS['long'])
        # Oscillators
        df['RSI'] = calculate_rsi(df['close'], Config.RSI_PERIOD)
        df['MACD'], df['Signal'], _ = calculate_macd(df['close'])
        # Volatility
        df['BB_upper'], df['BB_middle'], df['BB_lower'] = calculate_bollinger_bands(
            df['close'], Config.BOLLINGER_PERIOD, Config.BOLLINGER_STD
        )
        df['ATR'] = calculate_atr(df, Config.ATR_PERIOD)
        df['Volatility'] = df['ATR'] / df['close'] * 100
        # Volume
        df['Volume_MA'] = df['volume'].rolling(window=Config.VOLUME_MA_PERIOD).mean()
        df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
        # Momentum
        df['ROC'] = df['close'].pct_change(periods=10) * 100
        return df.dropna()
    except Exception as e:
        # Same message as before, now with the traceback attached.
        logger.error("Indicator calculation failed: %s", e, exc_info=True)
        raise
def calculate_score(df: pd.DataFrame) -> int:
    """Score the most recent row of an indicator frame on a 0-100 scale.

    Points: trend up to 30 (15 each for MA5 > MA20 and MA20 > MA60); RSI
    20 when within [30, 70], 15 when below 30; MACD above Signal 20; volume
    ratio above 1.5 earns 30, above 1 earns 15. Empty input scores 0.
    """
    if df.empty:
        return 0

    row = df.iloc[-1]
    rsi = row['RSI']
    volume_ratio = row['Volume_Ratio']

    trend_points = (15 if row['MA5'] > row['MA20'] else 0) + \
                   (15 if row['MA20'] > row['MA60'] else 0)

    if 30 <= rsi <= 70:
        rsi_points = 20
    elif rsi < 30:  # oversold still earns partial credit
        rsi_points = 15
    else:
        rsi_points = 0

    macd_points = 20 if row['MACD'] > row['Signal'] else 0

    if volume_ratio > 1.5:
        volume_points = 30
    elif volume_ratio > 1:
        volume_points = 15
    else:
        volume_points = 0

    total = trend_points + rsi_points + macd_points + volume_points
    return min(100, total)
def get_recommendation(score: int) -> str:
    """Map a technical score onto one of five recommendation labels.

    Thresholds (inclusive lower bounds): 80 Strong Buy, 60 Buy, 40 Neutral,
    20 Sell; anything lower is Strong Sell.
    """
    thresholds = (
        (80, 'Strong Buy'),
        (60, 'Buy'),
        (40, 'Neutral'),
        (20, 'Sell'),
    )
    return next((label for floor, label in thresholds if score >= floor), 'Strong Sell')
# Main endpoint
def _build_analysis_response(request: StockAnalysisRequest) -> AnalysisResponse:
    """Fetch data, compute indicators and assemble the response (blocking work).

    Propagates the ValueError / HTTPException raised by the helpers it calls,
    and raises HTTPException(400) when fewer than two rows survive indicator
    warm-up.
    """
    stock_data = get_stock_data(
        request.stock_code,
        request.market_type,
        request.start_date,
        request.end_date,
    )
    stock_data = calculate_indicators(stock_data)
    # price_change needs the two most recent rows, and calculate_indicators
    # drops every row containing a NaN, so short histories can leave fewer.
    # Previously a single surviving row hit IndexError and became an HTTP 500.
    if len(stock_data) < 2:
        raise HTTPException(status_code=400, detail="Insufficient data for analysis")

    score = calculate_score(stock_data)
    latest, prev = stock_data.iloc[-1], stock_data.iloc[-2]
    ma_rising = latest['MA5'] > latest['MA20']

    technical_summary = TechnicalSummary(
        trend='upward' if ma_rising else 'downward',
        volatility=f"{latest['Volatility']:.2f}%",
        volume_trend='increasing' if latest['Volume_Ratio'] > 1 else 'decreasing',
        rsi_level=latest['RSI'],
    )
    report = StockReport(
        stock_code=request.stock_code,
        market_type=request.market_type,
        analysis_date=datetime.now().strftime('%Y-%m-%d'),
        score=score,
        price=latest['close'],
        price_change=(latest['close'] - prev['close']) / prev['close'] * 100,
        ma_trend='UP' if ma_rising else 'DOWN',
        rsi=latest['RSI'] if not pd.isna(latest['RSI']) else None,
        macd_signal='BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
        volume_status='HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
        recommendation=get_recommendation(score),
    )
    return AnalysisResponse(
        technical_summary=technical_summary,
        recent_data=stock_data.tail(14).to_dict('records'),
        report=report,
    )


@app.post("/analyze-stock/", response_model=AnalysisResponse)
async def analyze_stock(
    request: StockAnalysisRequest,
    token: str = Depends(verify_auth_token)
):
    """Analyse one symbol and return a technical summary, recent rows and a report.

    Error mapping: ValueError -> 400 with the message; HTTPException raised by
    the pipeline passes through unchanged; anything else -> 500.
    """
    try:
        logger.info("Analyzing stock: %s (%s)", request.stock_code, request.market_type)
        # The AKShare fetch is a blocking HTTP call and the indicator maths is
        # synchronous pandas work; running it directly inside this coroutine
        # would stall the event loop for every other request.
        return await run_in_threadpool(_build_analysis_response, request)
    except ValueError as e:
        logger.warning("Validation error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Analysis failed: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error") from e
if __name__ == '__main__':
    # Serve the app directly when the file is executed as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8020}
    uvicorn.run(app, **server_options)