# readme.md
# requirements.txt
# 源代码:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import talib as ta
import akshare as ak
import baostock as bs
import datetime
from matplotlib.gridspec import GridSpec
# Configure matplotlib for Chinese text: try these CJK fonts in order, and
# disable the Unicode minus glyph so negative numbers render with these fonts.
plt.rcParams.update({
    "font.family": ["SimHei", "Microsoft YaHei", "SimSun", "KaiTi", "FangSong"],
    "axes.unicode_minus": False,
})
class StockTechnicalAnalyzer:
    """Fetch daily stock bars, compute technical indicators, derive simple
    trading signals and plot the results.

    Typical flow: ``get_stock_data`` -> ``calculate_indicators`` ->
    ``generate_signals`` -> ``plot_*``. Each step stores its result on
    ``self.data`` (a DataFrame indexed by date) so the next step builds on it.
    """

    # Price/volume columns both data sources return; TA-Lib needs them as float.
    _OHLCV_COLUMNS = ['open', 'high', 'low', 'close', 'volume']

    def __init__(self, data_source='baostock'):
        """Initialise the analyzer.

        Args:
            data_source: ``'baostock'`` or ``'akshare'``. It is validated when
                ``get_stock_data`` is called, not here.
        """
        self.data_source = data_source
        self.data = None  # DataFrame indexed by date once data has been fetched
        self.stock_code = None
        self.start_date = None
        self.end_date = None

    # ------------------------------------------------------------------ #
    # Data acquisition
    # ------------------------------------------------------------------ #
    def get_stock_data(self, stock_code, start_date, end_date):
        """Fetch daily bars for *stock_code* from the configured data source.

        Args:
            stock_code: baostock-style code such as ``'sz.000651'``.
            start_date: ``'YYYY-MM-DD'`` start date.
            end_date: ``'YYYY-MM-DD'`` end date.

        Returns:
            DataFrame indexed by date (also stored on ``self.data``), or None
            if fetching failed.

        Raises:
            ValueError: if ``data_source`` is neither 'baostock' nor 'akshare'.
        """
        self.stock_code = stock_code
        self.start_date = start_date
        self.end_date = end_date
        print(f"正在从{self.data_source}获取{stock_code}从{start_date}到{end_date}的数据...")
        if self.data_source == 'baostock':
            return self._get_baostock_data(stock_code, start_date, end_date)
        if self.data_source == 'akshare':
            return self._get_akshare_data(stock_code, start_date, end_date)
        raise ValueError("数据源必须是 'baostock' 或 'akshare'")

    def _get_baostock_data(self, stock_code, start_date, end_date):
        """Fetch forward-adjusted daily bars from baostock.

        Returns:
            DataFrame indexed by date with float open/high/low/close/volume
            columns, or None on login/query failure or when no rows return.
        """
        lg = bs.login()
        if lg.error_code != '0':
            print(f"登录失败:{lg.error_msg}")
            return None
        # Log out even if the query or the row iteration raises.
        try:
            rs = bs.query_history_k_data_plus(
                stock_code,
                "date,open,high,low,close,volume",
                start_date=start_date,
                end_date=end_date,
                frequency="d",
                adjustflag="2",  # forward-adjusted prices (前复权)
            )
            data_list = []
            # Short-circuit `and`: do not call rs.next() once an error is reported.
            while rs.error_code == '0' and rs.next():
                data_list.append(rs.get_row_data())
            if rs.error_code != '0':
                print(f"查询失败:{rs.error_msg}")
                return None
            fields = rs.fields
        finally:
            bs.logout()

        if not data_list:
            print("没有获取到数据")
            return None

        df = pd.DataFrame(data_list, columns=fields)
        df['date'] = pd.to_datetime(df['date'])
        df[self._OHLCV_COLUMNS] = df[self._OHLCV_COLUMNS].astype(float)
        df.set_index('date', inplace=True)
        self.data = df
        print(f"成功获取{len(df)}条数据")
        return df

    @staticmethod
    def _to_akshare_symbol(stock_code):
        """Convert a baostock code ('sh.600000') to akshare form ('sh600000').

        Codes without a ``sh.``/``sz.`` prefix are returned unchanged.
        """
        for exchange in ('sh', 'sz'):
            if stock_code.startswith(exchange + '.'):
                return exchange + stock_code[3:]
        return stock_code

    def _get_akshare_data(self, stock_code, start_date, end_date):
        """Fetch forward-adjusted ('qfq') daily bars from akshare.

        Args:
            stock_code: baostock-style ('sh.600000') or akshare-style
                ('sh600000') code.
            start_date, end_date: ``'YYYY-MM-DD'`` strings; dashes are stripped
                before being passed to akshare.

        Returns:
            DataFrame indexed by date, or None if no rows came back or the
            request failed. Errors are printed rather than raised.
        """
        try:
            ak_code = self._to_akshare_symbol(stock_code)
            df = ak.stock_zh_a_daily(symbol=ak_code,
                                     start_date=start_date.replace('-', ''),
                                     end_date=end_date.replace('-', ''),
                                     adjust="qfq")
            if df is None or df.empty:
                print("没有获取到数据")
                return None
            df['date'] = pd.to_datetime(df['date'])
            df.set_index('date', inplace=True)
            self.data = df
            print(f"成功获取{len(df)}条数据")
            return df
        except Exception as e:
            # Deliberate best-effort boundary: report network/API failures, don't crash.
            print(f"获取数据时出错: {str(e)}")
            return None

    # ------------------------------------------------------------------ #
    # Indicators and signals
    # ------------------------------------------------------------------ #
    def calculate_indicators(self, rsi_period=14, macd_fast=12, macd_slow=26, macd_signal=9,
                             bb_period=20, bb_std=2, atr_period=14, cci_period=14,
                             roc_period=12, wr_period=14, bias_period=6,
                             ene_period=10, vi_period=14, dmi_period=14, psy_period=12):
        """Compute technical indicators and store them on ``self.data``.

        Adds the columns rsi, macd/macd_signal/macd_hist, bb_upper/bb_middle/
        bb_lower, atr/atr_pct, ma5/ma10/ma20/ma60, volume_ma5/volume_ma10,
        k/d/j, obv, cci, roc, wr, bias, ene_middle/ene_upper/ene_lower,
        vi_positive/vi_negative, adx/plus_di/minus_di and psy. Leading rows
        are NaN until each indicator's window fills.

        Returns:
            The augmented DataFrame, or None if no data has been fetched.
        """
        if self.data is None:
            print("请先获取股票数据")
            return None

        df = self.data.copy()
        # TA-Lib functions require float64 input; cast in case a source returned ints.
        df[self._OHLCV_COLUMNS] = df[self._OHLCV_COLUMNS].astype(float)
        close, high, low, open_ = df['close'], df['high'], df['low'], df['open']

        df['rsi'] = ta.RSI(close, timeperiod=rsi_period)
        df['macd'], df['macd_signal'], df['macd_hist'] = ta.MACD(
            close, fastperiod=macd_fast, slowperiod=macd_slow, signalperiod=macd_signal
        )
        df['bb_upper'], df['bb_middle'], df['bb_lower'] = ta.BBANDS(
            close, timeperiod=bb_period, nbdevup=bb_std, nbdevdn=bb_std
        )
        df['atr'] = ta.ATR(high, low, close, timeperiod=atr_period)
        df['atr_pct'] = df['atr'] / close * 100

        for period in (5, 10, 20, 60):
            df[f'ma{period}'] = ta.MA(close, timeperiod=period)
        for period in (5, 10):
            df[f'volume_ma{period}'] = ta.MA(df['volume'], timeperiod=period)

        # KDJ from TA-Lib's slow stochastic (default STOCH parameters); J = 3K - 2D.
        df['k'], df['d'] = ta.STOCH(high, low, close)
        df['j'] = 3 * df['k'] - 2 * df['d']

        df['obv'] = ta.OBV(close, df['volume'])
        df['cci'] = ta.CCI(high, low, close, timeperiod=cci_period)
        df['roc'] = ta.ROC(close, timeperiod=roc_period)

        # Williams %R: 0 when close is at the window high, -100 at the window low.
        highest = high.rolling(window=wr_period).max()
        lowest = low.rolling(window=wr_period).min()
        df['wr'] = -100 * ((highest - close) / (highest - lowest))

        # BIAS: percentage deviation of close from its moving average.
        bias_ma = close.rolling(window=bias_period).mean()
        df['bias'] = (close - bias_ma) / bias_ma * 100

        # ENE envelope: moving average with +/-10% bands.
        df['ene_middle'] = close.rolling(window=ene_period).mean()
        df['ene_upper'] = df['ene_middle'] * 1.1
        df['ene_lower'] = df['ene_middle'] * 0.9

        # NOTE(review): despite the "VI" name this is not the standard Vortex
        # Indicator; it is the SMA of the intraday move above the open (VI+)
        # and below the open (VI-).
        df['vi_positive'] = ta.SMA(high - open_, timeperiod=vi_period)
        df['vi_negative'] = ta.SMA(open_ - low, timeperiod=vi_period)

        # DMI: ADX plus the directional indicators.
        df['adx'] = ta.ADX(high, low, close, timeperiod=dmi_period)
        df['plus_di'] = ta.PLUS_DI(high, low, close, timeperiod=dmi_period)
        df['minus_di'] = ta.MINUS_DI(high, low, close, timeperiod=dmi_period)

        # PSY: percentage of up-closes over the window (first diff is NaN -> not up).
        up_days = (close.diff() > 0).astype(int)
        df['psy'] = up_days.rolling(window=psy_period).sum() / psy_period * 100

        self.data = df
        return df

    def _has_columns(self, columns):
        """Return True if data has been fetched and contains every name in *columns*."""
        return self.data is not None and all(col in self.data.columns for col in columns)

    def generate_signals(self):
        """Combine RSI, MACD-cross and Bollinger-breakout votes into signals.

        Adds ``signal`` (sum of +1 buy / -1 sell votes) and ``final_signal``
        (its sign: 1 buy, -1 sell, 0 none). Requires ``calculate_indicators``
        to have run.

        Returns:
            The augmented DataFrame, or None if data or indicators are missing.
        """
        if not self._has_columns(['close', 'rsi', 'macd', 'macd_signal', 'bb_upper', 'bb_lower']):
            print("请先获取股票数据并计算指标")
            return None

        df = self.data.copy()
        close, macd, macd_sig = df['close'], df['macd'], df['macd_signal']
        df['signal'] = 0

        # RSI oversold (<30) votes buy, overbought (>70) votes sell.
        df.loc[df['rsi'] < 30, 'signal'] += 1
        df.loc[df['rsi'] > 70, 'signal'] += -1

        # MACD golden cross votes buy, death cross votes sell.
        cross_up = (macd > macd_sig) & (macd.shift(1) <= macd_sig.shift(1))
        cross_down = (macd < macd_sig) & (macd.shift(1) >= macd_sig.shift(1))
        df.loc[cross_up, 'signal'] += 1
        df.loc[cross_down, 'signal'] += -1

        # Bollinger breakout (momentum reading): break above upper band votes
        # buy, break below lower band votes sell.
        upper_break = (close > df['bb_upper']) & (close.shift(1) <= df['bb_upper'].shift(1))
        lower_break = (close < df['bb_lower']) & (close.shift(1) >= df['bb_lower'].shift(1))
        df.loc[upper_break, 'signal'] += 1
        df.loc[lower_break, 'signal'] += -1

        df['final_signal'] = np.sign(df['signal'])
        self.data = df
        return df

    # ------------------------------------------------------------------ #
    # Plotting
    # ------------------------------------------------------------------ #
    def _plot_frame(self, columns):
        """Return the rows of ``self.data`` where every plotted column is non-NaN.

        Only the plotted *columns* are considered, so warm-up NaNs of unrelated
        indicators do not remove rows. Prints a hint and returns None if data
        or any required column is missing.
        """
        if not self._has_columns(columns):
            print("请先获取股票数据并计算指标")
            return None
        return self.data.dropna(subset=columns)

    @staticmethod
    def _plot_signals(ax, df):
        """Mark ``final_signal`` buys (green ^) and sells (red v) at the close price."""
        if 'final_signal' not in df.columns:
            return
        buys = df[df['final_signal'] == 1]
        sells = df[df['final_signal'] == -1]
        ax.scatter(buys.index, buys['close'], marker='^', color='g',
                   label='买入信号', s=100, zorder=3)
        ax.scatter(sells.index, sells['close'], marker='v', color='r',
                   label='卖出信号', s=100, zorder=3)

    @staticmethod
    def _format_date_axis(ax):
        """Put year-month ticks every 3 months on *ax*'s x axis, rotated 45 degrees."""
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=3))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
        ax.tick_params(axis='x', labelrotation=45)

    def plot_moving_averages(self):
        """Plot close with MA5/10/20/60, Bollinger Bands, ENE bands and signals.

        Draws into matplotlib figure number 2 and blocks until the window is
        closed. Returns the Figure, or None if indicators are missing.
        """
        columns = ['close', 'ma5', 'ma10', 'ma20', 'ma60', 'bb_upper', 'bb_middle',
                   'bb_lower', 'ene_upper', 'ene_middle', 'ene_lower']
        df = self._plot_frame(columns)
        if df is None:
            return None

        fig = plt.figure(2, figsize=(16, 8))
        fig.clf()  # reuse figure number 2 but start from an empty canvas
        ax = fig.add_subplot(1, 1, 1)

        ax.plot(df.index, df['close'], label='收盘价', linewidth=2)
        for period in (5, 10, 20, 60):
            ax.plot(df.index, df[f'ma{period}'], label=f'MA{period}', alpha=0.7)

        # Bollinger Bands
        ax.plot(df.index, df['bb_upper'], 'r--', label='上轨', alpha=0.6)
        ax.plot(df.index, df['bb_middle'], 'g--', label='中轨', alpha=0.6)
        ax.plot(df.index, df['bb_lower'], 'r--', label='下轨', alpha=0.6)
        ax.fill_between(df.index, df['bb_upper'], df['bb_lower'], alpha=0.1, color='gray')

        # ENE envelope
        ax.plot(df.index, df['ene_upper'], 'c-.', label='ENE上轨', alpha=0.8)
        ax.plot(df.index, df['ene_middle'], 'm-.', label='ENE中轨', alpha=0.8)
        ax.plot(df.index, df['ene_lower'], 'c-.', label='ENE下轨', alpha=0.8)

        self._plot_signals(ax, df)

        ax.set_title(f'{self.stock_code} 价格与移动平均线 ({self.start_date} 至 {self.end_date})')
        ax.set_ylabel('价格')
        ax.grid(True, alpha=0.3)
        ax.legend()
        self._format_date_axis(ax)
        fig.tight_layout()
        plt.show(block=True)  # block until the user closes the window
        return fig

    def plot_all_indicators(self):
        """Plot the oscillator and volume indicators on a 9-panel figure.

        Panels, top to bottom: close, volume with MA5/MA10, RSI, MACD, KDJ,
        CCI, W&R with BIAS on a twin y axis, VI, OBV. Moving averages are drawn
        by ``plot_moving_averages`` instead. Uses figure number 3 and blocks
        until closed. Returns the Figure, or None if indicators are missing.
        """
        columns = ['close', 'volume', 'volume_ma5', 'volume_ma10', 'rsi', 'macd',
                   'macd_signal', 'macd_hist', 'k', 'd', 'j', 'cci', 'wr', 'bias',
                   'vi_positive', 'vi_negative', 'obv']
        df = self._plot_frame(columns)
        if df is None:
            return None

        fig = plt.figure(3, figsize=(16, 30))
        fig.clf()
        gs = GridSpec(9, 1, figure=fig, height_ratios=[1, 2, 2, 2, 2, 2, 2, 2, 2])

        # 1. Price
        ax_price = fig.add_subplot(gs[0])
        ax_price.plot(df.index, df['close'], label='收盘价', linewidth=2, color='blue')
        ax_price.set_title(f'{self.stock_code} 价格')
        ax_price.set_ylabel('价格')
        ax_price.grid(True, alpha=0.3)

        # 2. Volume
        ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
        ax_vol.bar(df.index, df['volume'], label='成交量', alpha=0.6, color='gray')
        ax_vol.plot(df.index, df['volume_ma5'], label='成交量MA5', color='blue')
        ax_vol.plot(df.index, df['volume_ma10'], label='成交量MA10', color='red')
        ax_vol.set_ylabel('成交量')
        ax_vol.legend()
        ax_vol.grid(True, alpha=0.3)

        # 3. RSI
        ax_rsi = fig.add_subplot(gs[2], sharex=ax_price)
        ax_rsi.plot(df.index, df['rsi'], label='RSI', color='purple', linewidth=2)
        ax_rsi.axhline(70, color='r', linestyle='--', alpha=0.7, label='超买线(70)')
        ax_rsi.axhline(30, color='g', linestyle='--', alpha=0.7, label='超卖线(30)')
        ax_rsi.axhline(50, color='b', linestyle='--', alpha=0.5)
        ax_rsi.set_ylabel('RSI值')
        ax_rsi.set_ylim(0, 100)
        ax_rsi.legend()
        ax_rsi.grid(True, alpha=0.3)

        # 4. MACD
        ax_macd = fig.add_subplot(gs[3], sharex=ax_price)
        ax_macd.plot(df.index, df['macd'], label='MACD', color='blue')
        ax_macd.plot(df.index, df['macd_signal'], label='信号线', color='red')
        ax_macd.bar(df.index, df['macd_hist'], label='柱状图', alpha=0.5, color='gray')
        ax_macd.axhline(0, color='black', linestyle='-', alpha=0.3)
        ax_macd.set_ylabel('MACD值')
        ax_macd.legend()
        ax_macd.grid(True, alpha=0.3)

        # 5. KDJ (the 0-100 y-limit clips J when it leaves that range)
        ax_kdj = fig.add_subplot(gs[4], sharex=ax_price)
        ax_kdj.plot(df.index, df['k'], label='K线', color='blue')
        ax_kdj.plot(df.index, df['d'], label='D线', color='red')
        ax_kdj.plot(df.index, df['j'], label='J线', color='green')
        ax_kdj.axhline(80, color='r', linestyle='--', alpha=0.7, label='超买线(80)')
        ax_kdj.axhline(20, color='g', linestyle='--', alpha=0.7, label='超卖线(20)')
        ax_kdj.set_ylabel('KDJ值')
        ax_kdj.set_ylim(0, 100)
        ax_kdj.legend()
        ax_kdj.grid(True, alpha=0.3)

        # 6. CCI
        ax_cci = fig.add_subplot(gs[5], sharex=ax_price)
        ax_cci.plot(df.index, df['cci'], label='CCI', color='orange', linewidth=2)
        ax_cci.axhline(100, color='r', linestyle='--', alpha=0.7, label='超买线(100)')
        ax_cci.axhline(-100, color='g', linestyle='--', alpha=0.7, label='超卖线(-100)')
        ax_cci.axhline(0, color='b', linestyle='--', alpha=0.5)
        ax_cci.set_ylabel('CCI值')
        ax_cci.legend()
        ax_cci.grid(True, alpha=0.3)

        # 7. W&R on the left axis, BIAS on a twin right axis
        ax_wr = fig.add_subplot(gs[6], sharex=ax_price)
        wr_color = 'tab:red'
        ax_wr.set_ylabel('W&R(%)', color=wr_color)
        ax_wr.plot(df.index, df['wr'], color=wr_color, label='W&R')
        ax_wr.axhline(-20, color='r', linestyle='--', alpha=0.7, label='超买线(-20)')
        ax_wr.axhline(-80, color='g', linestyle='--', alpha=0.7, label='超卖线(-80)')
        ax_wr.tick_params(axis='y', labelcolor=wr_color)
        ax_wr.grid(True, alpha=0.3)

        ax_bias = ax_wr.twinx()
        bias_color = 'tab:blue'
        ax_bias.set_ylabel('BIAS(%)', color=bias_color)
        ax_bias.plot(df.index, df['bias'], color=bias_color, label='BIAS')
        ax_bias.tick_params(axis='y', labelcolor=bias_color)

        # One legend covering both y axes.
        wr_handles, wr_labels = ax_wr.get_legend_handles_labels()
        bias_handles, bias_labels = ax_bias.get_legend_handles_labels()
        ax_wr.legend(wr_handles + bias_handles, wr_labels + bias_labels, loc='upper left')

        # 8. VI
        ax_vi = fig.add_subplot(gs[7], sharex=ax_price)
        ax_vi.plot(df.index, df['vi_positive'], 'g-', label='VI+', alpha=0.7)
        ax_vi.plot(df.index, df['vi_negative'], 'r-', label='VI-', alpha=0.7)
        ax_vi.set_ylabel('VI值')
        ax_vi.legend()
        ax_vi.grid(True, alpha=0.3)

        # 9. OBV
        ax_obv = fig.add_subplot(gs[8], sharex=ax_price)
        ax_obv.set_ylabel('OBV')
        ax_obv.plot(df.index, df['obv'], color='tab:blue', label='OBV')
        ax_obv.legend()
        ax_obv.grid(True, alpha=0.3)

        self._format_date_axis(ax_obv)
        fig.suptitle(f'{self.stock_code} 技术指标分析 ({self.start_date} 至 {self.end_date})', fontsize=16)
        fig.tight_layout()
        fig.subplots_adjust(top=0.97)  # leave room for the suptitle
        plt.show(block=True)  # block until the user closes the window
        return fig

    def plot_simplified(self):
        """Plot a 3-panel overview: close with Bollinger Bands and signals, RSI, MACD.

        Uses figure number 1 and blocks until closed. Returns the Figure, or
        None if indicators are missing.
        """
        columns = ['close', 'bb_upper', 'bb_middle', 'bb_lower', 'rsi',
                   'macd', 'macd_signal', 'macd_hist']
        df = self._plot_frame(columns)
        if df is None:
            return None

        fig = plt.figure(1, figsize=(16, 15))
        fig.clf()
        ax_price = fig.add_subplot(3, 1, 1)
        ax_rsi = fig.add_subplot(3, 1, 2, sharex=ax_price)
        ax_macd = fig.add_subplot(3, 1, 3, sharex=ax_price)

        # 1. Price and Bollinger Bands
        ax_price.plot(df.index, df['close'], label='收盘价', linewidth=2)
        ax_price.plot(df.index, df['bb_upper'], 'r--', label='上轨')
        ax_price.plot(df.index, df['bb_middle'], 'g--', label='中轨')
        ax_price.plot(df.index, df['bb_lower'], 'r--', label='下轨')
        ax_price.fill_between(df.index, df['bb_upper'], df['bb_lower'], alpha=0.1, color='gray')
        self._plot_signals(ax_price, df)
        ax_price.set_title(f'{self.stock_code} 价格与布林带')
        ax_price.set_ylabel('价格')
        ax_price.legend()
        ax_price.grid(True, alpha=0.3)

        # 2. RSI
        ax_rsi.plot(df.index, df['rsi'], label='RSI', color='purple', linewidth=2)
        ax_rsi.axhline(70, color='r', linestyle='--', alpha=0.7)
        ax_rsi.axhline(30, color='g', linestyle='--', alpha=0.7)
        ax_rsi.axhline(50, color='b', linestyle='--', alpha=0.5)
        ax_rsi.set_ylabel('RSI值')
        ax_rsi.set_ylim(0, 100)
        ax_rsi.legend()
        ax_rsi.grid(True, alpha=0.3)

        # 3. MACD
        ax_macd.plot(df.index, df['macd'], label='MACD', color='blue')
        ax_macd.plot(df.index, df['macd_signal'], label='信号线', color='red')
        ax_macd.bar(df.index, df['macd_hist'], label='柱状图', alpha=0.5)
        ax_macd.axhline(0, color='black', linestyle='-', alpha=0.3)
        ax_macd.set_ylabel('MACD值')
        ax_macd.legend()
        ax_macd.grid(True, alpha=0.3)

        self._format_date_axis(ax_macd)
        fig.tight_layout()
        plt.show(block=True)  # block until the user closes the window
        return fig
def main():
    """Demo entry point: fetch about two years of daily data for one stock,
    compute indicators and signals, show the three charts, then print the
    latest indicator values."""
    # Data source may be 'baostock' or 'akshare'.
    analyzer = StockTechnicalAnalyzer(data_source='baostock')

    # Stock code and date range.
    # stock_code = 'sh.600938'  # CNOOC
    stock_code = 'sz.000651'  # Gree Electric
    # Read the clock once so both dates derive from the same instant
    # (two separate now() calls could straddle midnight).
    now = datetime.datetime.now()
    end_date = now.strftime("%Y-%m-%d")
    start_date = (now - datetime.timedelta(days=365 * 2)).strftime("%Y-%m-%d")  # past 2 years

    data = analyzer.get_stock_data(stock_code, start_date, end_date)
    # Require at least 60 rows, presumably so the 60-day MA has valid values.
    if data is None or len(data) < 60:
        print("数据不足,无法进行分析")
        return

    analyzer.calculate_indicators()
    analyzer.generate_signals()

    print("正在绘制简化图表...")
    analyzer.plot_simplified()
    print("正在绘制移动平均线图表...")
    analyzer.plot_moving_averages()
    print("正在绘制完整指标图表...")
    analyzer.plot_all_indicators()

    # Print the indicator values for the most recent trading day.
    latest = analyzer.data.iloc[-1]
    print(f"\n最新指标值({latest.name.date()}):")
    print(f"RSI(14): {latest['rsi']:.2f}")
    print(f"MACD: {latest['macd']:.4f}, 信号线: {latest['macd_signal']:.4f}, 柱状图: {latest['macd_hist']:.4f}")
    print(f"KDJ: K={latest['k']:.2f}, D={latest['d']:.2f}, J={latest['j']:.2f}")
    print(f"ATR百分比: {latest['atr_pct']:.2f}%")
    print(f"CCI: {latest['cci']:.2f}")
    print(f"W&R: {latest['wr']:.2f}%")
    print(f"BIAS: {latest['bias']:.2f}%")
    print(f"ROC: {latest['roc']:.2f}%")
    print(f"PSY: {latest['psy']:.2f}%")
    print(f"ENE: 上轨={latest['ene_upper']:.2f}, 中轨={latest['ene_middle']:.2f}, 下轨={latest['ene_lower']:.2f}")
    print(f"VI: VI+={latest['vi_positive']:.2f}, VI-={latest['vi_negative']:.2f}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
