From 2c0b9cb5045776a974c42016427fffb0a52da595 Mon Sep 17 00:00:00 2001
From: xbfighting
Date: Wed, 25 Mar 2026 10:27:44 +0800
Subject: [PATCH 1/7] docs: add CLAUDE.md project guide
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 CLAUDE.md | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)
 create mode 100644 CLAUDE.md

diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..23b535a
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,66 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+tdx2db: reads A-share market data from a local TDX (TongDaXin) terminal and incrementally syncs it into a database. It is the data-ingestion entry point of a quant-analysis workstation.
+
+## Common Commands
+
+```bash
+# Install dependencies
+pip install -r requirements.txt
+
+# One-shot incremental sync (daily + 5/15/30/60-minute bars); the only command needed day to day
+python main.py sync
+
+# Sync individually
+python main.py daily --db-only --auto-start --incremental
+python main.py minutes --db-only --auto-start --incremental
+
+# Sync the stock list
+python main.py stock-list --db-only
+```
+
+There is no test suite. Verify by running `sync` and then checking the data in the database.
+
+## Architecture
+
+Four-layer pipeline with one-way data flow:
+
+```
+CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (storage.py)
+     ↓               ↓                      ↓                         ↓
+  argparse      pytdx reads local     clean + compute MAs     SQLAlchemy DB writes,
+  dispatch      .day/.lc5 files       (MA5~MA250)             incremental ON CONFLICT
+```
+
+- **config.py**: global singleton `config`, loads settings from `.env` (TDX_PATH, DB_*)
+- **logger.py**: global singleton `logger`
+
+### Key Data Flows
+
+1. **Daily bars**: read `vipdoc/{sz,sh}/lday/*.day` → `process_daily_data()` adds a date column and MAs → written to the `daily_data` table
+2. **Minute bars**: read `vipdoc/{sz,sh}/fzline/*.lc5` (raw 5-minute data) → resample to 15/30/60 minutes → `process_min_data()` computes MAs → written to the `minute{5,15,30,60}_data` tables
+3. **Incremental sync**: `save_incremental()` uses `ON CONFLICT DO NOTHING` (PostgreSQL) / `INSERT IGNORE` (MySQL) to skip duplicates. Minute bars query the latest date per stock (`get_latest_datetime_by_code`); daily bars use the global latest date.
+
+### Database Tables
+
+| Table | Unique constraint | Purpose |
+|-------|-------------------|---------|
+| `daily_data` | (code, date) | daily bars |
+| `minute{5,15,30,60}_data` | (code, datetime) | minute bars |
+| `stock_info` | code | stock list |
+| `block_stock_relation` | — | block/sector relations (not fully implemented) |
+
+Unique constraints must be added manually via `scripts/add_constraints.sql`.
+
+### Stock Code Format
+
+Codes carry a market prefix: `sz000001`, `sh600000`. Shenzhen market=0, Shanghai market=1.
+A-share filter rules: Shenzhen codes start with `000/001/002/300`, Shanghai with `60/688`.
+
+## Configuration
+
+Configured via a `.env` file. Required: `TDX_PATH`, `DB_TYPE`, `DB_HOST`, `DB_NAME`, `DB_USER`, `DB_PASSWORD`.

From b7da5cac523fda2803096200479a8c17ae608b01 Mon Sep 17 00:00:00 2001
From: xbfighting
Date: Wed, 25 Mar 2026 11:28:55 +0800
Subject: [PATCH 2/7] perf: batch INSERTs in save_incremental instead of
 executing row by row
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The SQL is now built once outside the loop, and each batch of parameters
is passed in a single call with executemany semantics, replacing the
per-row iterrows loop. Supports PostgreSQL/MySQL/SQLite.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 src/storage.py | 55 +++++++++++++++++++++++++------------------------
 1 file changed, 27 insertions(+), 28 deletions(-)

diff --git a/src/storage.py b/src/storage.py
index b92d0f0..121e5ca 100644
--- a/src/storage.py
+++ b/src/storage.py
@@ -286,8 +286,30 @@ def save_incremental(
         # Get the column names
         columns = list(df_to_save.columns)
 
-        # Build the INSERT ... ON CONFLICT DO NOTHING SQL
+        # Pre-build the SQL (constructed once, reused for every batch)
         db_type = config.db_type
+        placeholders = ', '.join([f':{col}' for col in columns])
+        columns_str = ', '.join(columns)
+
+        if db_type == 'postgresql':
+            conflict_str = ', '.join(conflict_columns)
+            sql = text(f"""
+                INSERT INTO {table_name} ({columns_str})
+                VALUES ({placeholders})
+                ON CONFLICT ({conflict_str}) DO NOTHING
+            """)
+        elif db_type == 'mysql':
+            sql = text(f"""
+                INSERT IGNORE INTO {table_name} ({columns_str})
+                VALUES ({placeholders})
+            """)
+        elif db_type == 'sqlite':
+            sql = text(f"""
+                INSERT OR IGNORE INTO {table_name} ({columns_str})
+                VALUES ({placeholders})
+            """)
+        else:
+            raise ValueError(f"Unsupported database type: {db_type}")
 
         try:
             num_batches = (total_rows + batch_size - 1) // batch_size
@@ -299,33 +321,10 @@ def save_incremental(
                 end_idx = min((i + 1) * batch_size, total_rows)
                 batch_df = df_to_save.iloc[start_idx:end_idx]
 
-                for _, row in batch_df.iterrows():
-                    values = {col: row[col] for col in columns}
-                    placeholders = ', '.join([f':{col}' for col in columns])
-                    columns_str = ', '.join(columns)
-
-                    if db_type == 'postgresql':
-                        conflict_str = ', '.join(conflict_columns)
-                        sql = f"""
-                            INSERT INTO {table_name} ({columns_str})
-                            VALUES ({placeholders})
-                            ON CONFLICT ({conflict_str}) DO NOTHING
-                        """
-                    elif db_type == 'mysql':
-                        sql = f"""
-                            INSERT IGNORE INTO {table_name} ({columns_str})
-                            VALUES ({placeholders})
-                        """
-                    elif db_type == 'sqlite':
-                        sql = f"""
-                            INSERT OR IGNORE INTO {table_name} ({columns_str})
-                            VALUES ({placeholders})
-                        """
-                    else:
-                        raise ValueError(f"Unsupported database type: {db_type}")
-
-                    result = conn.execute(text(sql), values)
-                    inserted_count += result.rowcount
+                # Batch execution: pass a list of parameter dicts; SQLAlchemy uses executemany automatically
+                params = batch_df.to_dict('records')
+                result = conn.execute(sql, params)
+                inserted_count += result.rowcount
 
             conn.commit()

From 1f53ae65046efcd65d8a05f855912c60cea30bc4 Mon Sep 17 00:00:00 2001
From: xbfighting
Date: Wed, 25 Mar 2026 11:28:59 +0800
Subject: [PATCH 3/7] refactor: give Reader a single responsibility, remove
 duplicated code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Delete process_and_store_min_data and process_single_stock_min_data
  from reader.py (~180 lines)
- Move the orchestration logic to sync_all_min_data /
  sync_single_stock_min_data in cli.py
- Extract the RESAMPLE_AGG constant and the resample_ohlcv() method into
  processor.py
- Extract _calculate_ma() to remove the duplicated daily/minute MA code
- Reader now only reads files and no longer depends on Storage
  (resampling is delegated to DataProcessor.resample_ohlcv)

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 src/cli.py       | 112 +++++++++++++++++++++-
 src/processor.py | 100 ++++++++++----------
 src/reader.py    | 235 ++---------------------------------------------
 3 files changed, 170 insertions(+), 277 deletions(-)

diff --git a/src/cli.py b/src/cli.py
index 8edab11..65fbf4e 100644
--- a/src/cli.py
+++ b/src/cli.py
@@ -6,15 +6,123 @@
 import argparse
 import sys
 from argparse import Namespace
+from typing import Optional
 from datetime import timedelta
 
+import pandas as pd
+from tqdm import tqdm
+
 from .reader import TdxDataReader
 from .processor import DataProcessor
 from .storage import DataStorage
 from .config import config
 from .logger import logger
 
+
+def sync_single_stock_min_data(
+    reader: TdxDataReader,
+    processor: DataProcessor,
+    storage: DataStorage,
+    market: int,
+    code: str,
+    start_date: Optional[str] = None,
+    incremental: bool = True,
+) -> bool:
+    """Process and store the minute data of a single stock
+
+    Args:
+        reader: data reader
+        processor: data processor
+        storage: data storage
+        market: market code
+        code: stock code
+        start_date: start date
+        incremental: whether to enable per-stock incremental sync
+    """
+    # Per-stock incremental: query this stock's latest date
+    if incremental and not start_date:
+        latest = storage.get_latest_datetime_by_code('minute5_data', code)
+        if latest:
+            start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d')
+            logger.debug(f"{code} incremental start date: {start_date}")
+
+    # Read the 5-minute data
+    df_5min = reader.read_5min_data(market, code)
+    if df_5min.empty:
+        logger.warning(f"{code} has no 5-minute data")
+        return False
+
+    # Prepare the datetime index
+    if not pd.api.types.is_datetime64_any_dtype(df_5min['datetime']):
+        df_5min['datetime'] = pd.to_datetime(df_5min['datetime'])
+    df_5min['date'] = df_5min['datetime'].dt.date
+    df_5min = df_5min.set_index('datetime')
+
+    # Resample to multiple frequencies
+    df_15min = DataProcessor.resample_ohlcv(df_5min, '15min')
+    df_30min = DataProcessor.resample_ohlcv(df_5min, '30min')
+    df_60min = DataProcessor.resample_ohlcv(df_5min, '60min')
+    df_5min = df_5min.reset_index()
+
+    # Process, filter, and store each frequency
+    freq_data = [
+        (df_5min, 5, 'minute5_data'),
+        (df_15min, 15, 'minute15_data'),
+        (df_30min, 30, 'minute30_data'),
+        (df_60min, 60, 'minute60_data'),
+    ]
+
+    has_data = False
+    for df, freq, table_name in freq_data:
+        processed = processor.process_min_data(df)
+        if start_date:
+            processed = processor.filter_data_min(processed, start_date=start_date)
+        if processed.empty:
+            continue
+        has_data = True
+        if incremental:
+            storage.save_incremental(processed, table_name)
+        else:
+            storage.save_minute_data(processed, freq=freq, to_csv=False, to_db=True)
+
+    if has_data:
+        logger.info(f"{code} minute data processed and stored in the database")
+    else:
+        logger.debug(f"{code} has no new data to sync")
+
+    return True
+
+
+def sync_all_min_data(
+    reader: TdxDataReader,
+    processor: DataProcessor,
+    storage: DataStorage,
+    start_date: Optional[str] = None,
+) -> bool:
+    """Orchestrate the minute-data sync for all stocks"""
+    try:
+        stocks = reader.get_stock_list()
+        logger.info(f"Processing minute data for all stocks, {len(stocks)} in total")
+
+        iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows()
+
+        for _, stock in iterator:
+            code = stock['code']
+            market = 1 if code.startswith('sh') else 0
+            try:
+                sync_single_stock_min_data(reader, processor, storage, market, code, start_date)
+            except FileNotFoundError:
+                continue
+            except Exception as e:
+                logger.error(f"Error processing minute data for {code}: {e}")
+                continue
+
+        return True
+    except Exception as e:
+        logger.error(f"Error processing minute data: {e}")
+        return False
+
 def parse_args() -> Namespace:
     """Parse command-line arguments
@@ -285,7 +393,7 @@ def main() -> int:
             # Fetch minute data for all stocks
             logger.info("Starting to process minute data for all stocks...")
             processor = DataProcessor()
-            success = reader.process_and_store_min_data(storage, processor, start_date)
+            success = sync_all_min_data(reader, processor, storage, start_date)
             if success:
                 logger.info("Finished processing minute data for all stocks")
             else:
@@ -355,7 +463,7 @@ def main() -> int:
                 start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d')
                 logger.info(f"Minute-bar start date: {start_date}")
 
-            success = reader.process_and_store_min_data(storage, processor, start_date)
+            success = sync_all_min_data(reader, processor, storage, start_date)
             if not success:
                 logger.error("Error syncing minute data")
                 has_error = True
diff --git a/src/processor.py b/src/processor.py
index c9cb943..a8d0fbd 100644
--- a/src/processor.py
+++ b/src/processor.py
@@ -5,6 +5,7 @@
 - handle missing values
 - detect outliers
 - compute technical indicators
+- OHLCV resampling
 """
 
 from typing import Optional, List
@@ -12,10 +13,59 @@
 from .logger import logger
 
+# Resampling aggregation rules
+RESAMPLE_AGG = {
+    'open': 'first',
+    'high': 'max',
+    'low': 'min',
+    'close': 'last',
+    'volume': 'sum',
+    'amount': 'sum',
+    'code': 'first',
+    'market': 'first',
+}
+
+# Moving-average windows
+MA_WINDOWS = [5, 10, 13, 21, 34, 55, 60, 89, 144, 233, 250]
+
 
 class DataProcessor:
     """Data processing class"""
 
+    @staticmethod
+    def resample_ohlcv(df: pd.DataFrame, freq: str) -> pd.DataFrame:
+        """Resample OHLCV data to the target frequency
+
+        Args:
+            df: DataFrame with a DatetimeIndex
+            freq: pandas resample frequency string ('15min', '30min', '60min')
+
+        Returns:
+            the resampled DataFrame (index already reset)
+        """
+        agg = dict(RESAMPLE_AGG)
+        if 'date' in df.columns:
+            agg['date'] = 'first'
+        result = df.resample(freq).agg(agg).dropna()
+        result.reset_index(inplace=True)
+        return result
+
+    @staticmethod
+    def _calculate_ma(df: pd.DataFrame) -> pd.DataFrame:
+        """Compute moving averages, grouped by stock
+
+        Args:
+            df: DataFrame containing 'close' and 'code' columns
+
+        Returns:
+            the DataFrame with MA columns added
+        """
+        for w in MA_WINDOWS:
+            df[f'ma{w}'] = df.groupby('code')['close'].transform(
+                lambda x: x.rolling(window=w).mean()
+            )
+        return df
+
     @staticmethod
     def process_daily_data(df: pd.DataFrame) -> pd.DataFrame:
         """Process daily data
@@ -50,30 +100,9 @@ def process_daily_data(df: pd.DataFrame) -> pd.DataFrame:
             # Fill missing values with the previous valid value
             processed_df[col] = processed_df[col].ffill()
 
-        # Compute some basic technical indicators
+        # Compute moving averages
         if all(col in processed_df.columns for col in ['close', 'volume']):
-            # Compute the 13-day MA
-            processed_df['ma13'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=13).mean())
-            # Compute the 21-day MA
-            processed_df['ma21'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=21).mean())
-            # Compute the 34-day MA
-            processed_df['ma34'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=34).mean())
-            # Compute the 55-day MA
-            processed_df['ma55'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=55).mean())
-            # Compute the 89-day MA
-            processed_df['ma89'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=89).mean())
-            # Compute the 144-day MA
-            processed_df['ma144'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=144).mean())
-            # Compute the 233-day MA
-            processed_df['ma233'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=233).mean())
-            # Compute the 5-day MA
-            processed_df['ma5'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=5).mean())
-            # Compute the 10-day MA
-            processed_df['ma10'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=10).mean())
-            # Compute the 60-day MA
-            processed_df['ma60'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=60).mean())
-            # Compute the 250-day MA
-            processed_df['ma250'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=250).mean())
+            processed_df = DataProcessor._calculate_ma(processed_df)
 
         return processed_df
 
@@ -127,30 +156,9 @@ def process_min_data(df: pd.DataFrame) -> pd.DataFrame:
             # Fill missing values with the previous valid value
             processed_df[col] = processed_df[col].ffill()
 
-        # Compute some basic technical indicators
+        # Compute moving averages
        if all(col in processed_df.columns for col in ['close', 'volume']):
-            # Compute the 13-day MA
-            processed_df['ma13'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=13).mean())
-            # Compute the 21-day MA
-            processed_df['ma21'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=21).mean())
-            # Compute the 34-day MA
-            processed_df['ma34'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=34).mean())
-            # Compute the 55-day MA
-            processed_df['ma55'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=55).mean())
-            # Compute the 89-day MA
-            processed_df['ma89'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=89).mean())
-            # Compute the 144-day MA
-            processed_df['ma144'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=144).mean())
-            # Compute the 233-day MA
-            processed_df['ma233'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=233).mean())
-            # Compute the 5-day MA
-            processed_df['ma5'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=5).mean())
-            # Compute the 10-day MA
-            processed_df['ma10'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=10).mean())
-            # Compute the 60-day MA
-            processed_df['ma60'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=60).mean())
-            # Compute the 250-day MA
-            processed_df['ma250'] = processed_df.groupby('code')['close'].transform(lambda x: x.rolling(window=250).mean())
+            processed_df = DataProcessor._calculate_ma(processed_df)
 
         return processed_df
 
diff --git a/src/reader.py b/src/reader.py
index 8dc2d02..f81cd70 100644
--- a/src/reader.py
+++ b/src/reader.py
@@ -6,12 +6,10 @@
 - stock list
 """
 
-from __future__ import annotations
-
 import os
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, List
+from typing import Optional, List
 
 import pandas as pd
 from pytdx.reader import TdxDailyBarReader, TdxMinBarReader, TdxLCMinBarReader
@@ -20,10 +18,7 @@
 from .config import config
 from .logger import logger
 
-
-if TYPE_CHECKING:
-    from .storage import DataStorage
-    from .processor import DataProcessor
+from .processor import DataProcessor
 
 class TdxDataReader:
     """TDX (TongDaXin) data reader class"""
@@ -163,47 +158,12 @@ def read_min_data(self, market: int, code: str) -> List[pd.DataFrame]:
         # Remember to fetch fresh quotes regularly and sync them into TDX
         logger.debug(f"Data time range: {data.index[0]} ~ {data.index[-1]}")
 
-        # Generate 15-minute data
-        data_15min = data.resample('15min').agg({
-            'open': 'first',
-            'high': 'max',
-            'low': 'min',
-            'close': 'last',
-            'amount': 'sum',
-            'volume': 'sum',
-            'code': 'first',
-            'market': 'first'
-        }).dropna()
-
-        # Generate 30-minute data
-        data_30min = data.resample('30min').agg({
-            'open': 'first',
-            'high': 'max',
-            'low': 'min',
-            'close': 'last',
-            'amount': 'sum',
-            'volume': 'sum',
-            'code': 'first',
-            'market': 'first'
-        }).dropna()
-
-        # Generate 60-minute data
-        data_60min = data.resample('60min').agg({
-            'open': 'first',
-            'high': 'max',
-            'low': 'min',
-            'close': 'last',
-            'amount': 'sum',
-            'volume': 'sum',
-            'code': 'first',
-            'market': 'first'
-        }).dropna()
+        # Resample to generate the multi-frequency data
+        data_15min = DataProcessor.resample_ohlcv(data, '15min')
+        data_30min = DataProcessor.resample_ohlcv(data, '30min')
+        data_60min = DataProcessor.resample_ohlcv(data, '60min')
 
-        # Reset the index so datetime becomes a column
         data.reset_index(inplace=True)
-        data_15min.reset_index(inplace=True)
-        data_30min.reset_index(inplace=True)
-        data_60min.reset_index(inplace=True)
 
         return [data_15min, data_30min, data_60min]
 
@@ -306,189 +266,6 @@ def read_all_daily_data(self) -> pd.DataFrame:
 
         return result_df
 
-    def process_and_store_min_data(
-        self,
-        storage: DataStorage,
-        processor: DataProcessor,
-        start_date: Optional[str] = None
-    ) -> bool:
-        """Process all stocks' 5-minute data, convert to other frequencies, compute indicators, and store in the database
-
-        Args:
-            storage: data storage object
-            processor: data processing object
-            start_date: start date in 'YYYY-MM-DD' format
-
-        Returns:
-            bool: whether processing succeeded
-        """
-        try:
-            # Get the stock list
-            stocks = self.get_stock_list()
-            logger.info(f"Processing minute data for all stocks, {len(stocks)} in total")
-
-            # Create a progress bar
-            iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows()
-
-            for _, stock in iterator:
-                code = stock['code']
-                # Determine the market
-                if code.startswith('sh'):
-                    market = 1  # Shanghai
-                else:
-                    market = 0  # Shenzhen
-
-                try:
-                    # Process this stock's minute data
-                    self.process_single_stock_min_data(market, code, storage, processor, start_date)
-                except FileNotFoundError:
-                    continue
-                except Exception as e:
-                    logger.error(f"Error processing minute data for {code}: {e}")
-                    continue
-
-            return True
-
-        except Exception as e:
-            logger.error(f"Error processing minute data: {e}")
-            return False
-
-    def process_single_stock_min_data(
-        self,
-        market: int,
-        code: str,
-        storage: DataStorage,
-        processor: DataProcessor,
-        start_date: Optional[str] = None,
-        incremental: bool = True
-    ) -> bool:
-        """Process one stock's 5-minute data, convert to other frequencies, compute indicators, and store in the database
-
-        Args:
-            market: market code
-            code: stock code
-            storage: data storage object
-            processor: data processing object
-            start_date: start date in 'YYYY-MM-DD' format
-            incremental: whether to enable per-stock incremental sync (query the latest date per stock)
-
-        Returns:
-            bool: whether processing succeeded
-        """
-        try:
-            # Per-stock incremental: query this stock's latest date
-            if incremental and not start_date:
-                from datetime import timedelta
-                latest = storage.get_latest_datetime_by_code('minute5_data', code)
-                if latest:
-                    start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d')
-                    logger.debug(f"{code} incremental start date: {start_date}")
-
-            # Read the 5-minute data
-            df_5min = self.read_5min_data(market, code)
-
-            if df_5min.empty:
-                logger.warning(f"{code} has no 5-minute data")
-                return False
-
-            # Ensure the datetime column exists and has a datetime dtype
-            if 'datetime' not in df_5min.columns:
-                if isinstance(df_5min.index, pd.DatetimeIndex):
-                    df_5min['datetime'] = df_5min.index
-                    df_5min = df_5min.reset_index()
-
-            # Convert the datetime column to a datetime dtype
-            if not pd.api.types.is_datetime64_any_dtype(df_5min['datetime']):
-                df_5min['datetime'] = pd.to_datetime(df_5min['datetime'])
-
-            # Add a date column for filtering by date
-            df_5min['date'] = df_5min['datetime'].dt.date
-
-            # Set datetime as the index for the subsequent resample operations
-            df_5min = df_5min.set_index('datetime')
-
-            # Convert to 15-minute data
-            df_15min = df_5min.resample('15min').agg({
-                'open': 'first',
-                'high': 'max',
-                'low': 'min',
-                'close': 'last',
-                'volume': 'sum',
-                'amount': 'sum',
-                'code': 'first',
-                'market': 'first',
-                'date': 'first'
-            }).dropna()
-
-            # Convert to 30-minute data
-            df_30min = df_5min.resample('30min').agg({
-                'open': 'first',
-                'high': 'max',
-                'low': 'min',
-                'close': 'last',
-                'volume': 'sum',
-                'amount': 'sum',
-                'code': 'first',
-                'market': 'first',
-                'date': 'first'
-            }).dropna()
-
-            # Convert to 60-minute data
-            df_60min = df_5min.resample('60min').agg({
-                'open': 'first',
-                'high': 'max',
-                'low': 'min',
-                'close': 'last',
-                'volume': 'sum',
-                'amount': 'sum',
-                'code': 'first',
-                'market': 'first',
-                'date': 'first'
-            }).dropna()
-
-            # Reset the indexes
-            df_5min = df_5min.reset_index()
-            df_15min = df_15min.reset_index()
-            df_30min = df_30min.reset_index()
-            df_60min = df_60min.reset_index()
-
-            # Process the data with the processor and compute technical indicators
-            processed_5min = processor.process_min_data(df_5min)
-            processed_15min = processor.process_min_data(df_15min)
-            processed_30min = processor.process_min_data(df_30min)
-            processed_60min = processor.process_min_data(df_60min)
-
-            # Filter by date
-            if start_date:
-                processed_5min = processor.filter_data_min(processed_5min, start_date=start_date)
-                processed_15min = processor.filter_data_min(processed_15min, start_date=start_date)
-                processed_30min = processor.filter_data_min(processed_30min, start_date=start_date)
-                processed_60min = processor.filter_data_min(processed_60min, start_date=start_date)
-
-            # Check whether there is any data to save
-            if processed_5min.empty and processed_15min.empty:
-                logger.debug(f"{code} has no new data to sync")
-                return True
-
-            # Store in the database (incremental save skips duplicates)
-            if incremental:
-                storage.save_incremental(processed_5min, 'minute5_data')
-                storage.save_incremental(processed_15min, 'minute15_data')
-                storage.save_incremental(processed_30min, 'minute30_data')
-                storage.save_incremental(processed_60min, 'minute60_data')
-            else:
-                storage.save_minute_data(processed_5min, freq=5, to_csv=False, to_db=True)
-                storage.save_minute_data(processed_15min, freq=15, to_csv=False, to_db=True)
-                storage.save_minute_data(processed_30min, freq=30, to_csv=False, to_db=True)
-                storage.save_minute_data(processed_60min, freq=60, to_csv=False, to_db=True)
-
-            logger.info(f"{code} minute data processed and stored in the database")
-            return True
-
-        except Exception as e:
-            logger.error(f"Error processing minute data for {code}: {e}")
-            return False
-
-
     # Block relations are not implemented yet; the block files were not found
     def get_block_stock_relation(self) -> pd.DataFrame:
         """Get the mapping between TDX blocks and stocks

From ea4671034ea2279d707571b660ca0e88745db174 Mon Sep 17 00:00:00 2001
From: xbfighting
Date: Wed, 25 Mar 2026 12:03:30 +0800
Subject: [PATCH 4/7] feat: add OHLCV data-quality validation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a _validate_ohlcv check to process_daily_data and process_min_data:
- price columns must be positive
- OHLC relationship checks (high >= max(open,close), low <= min(open,close))
- invalid rows are dropped and a warning is logged

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 src/processor.py | 44 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/src/processor.py b/src/processor.py
index a8d0fbd..5fb2e3f 100644
--- a/src/processor.py
+++ b/src/processor.py
@@ -50,6 +50,44 @@ def resample_ohlcv(df: pd.DataFrame, freq: str) -> pd.DataFrame:
         result.reset_index(inplace=True)
         return result
 
+    @staticmethod
+    def _validate_ohlcv(df: pd.DataFrame) -> pd.DataFrame:
+        """Validate OHLCV data quality and drop invalid rows
+
+        Validation rules:
+        1. Price columns (open/high/low/close) must be > 0
+        2. OHLC relationship: high >= max(open, close), low <= min(open, close)
+
+        Args:
+            df: DataFrame containing OHLCV columns
+
+        Returns:
+            the DataFrame with only the rows that passed validation
+        """
+        required = ['open', 'high', 'low', 'close']
+        if not all(col in df.columns for col in required):
+            return df
+
+        before = len(df)
+
+        # Prices must be positive
+        positive_mask = (df[required] > 0).all(axis=1)
+
+        # OHLC relationship check
+        ohlc_mask = (
+            (df['high'] >= df[['open', 'close']].max(axis=1)) &
+            (df['low'] <= df[['open', 'close']].min(axis=1))
+        )
+
+        valid_mask = positive_mask & ohlc_mask
+        df = df[valid_mask]
+
+        dropped = before - len(df)
+        if dropped > 0:
+            logger.warning(f"Validation dropped {dropped} invalid records (non-positive prices or inconsistent OHLC)")
+
+        return df
+
     @staticmethod
     def _calculate_ma(df: pd.DataFrame) -> pd.DataFrame:
         """Compute moving averages, grouped by stock
@@ -100,6 +138,9 @@ def process_daily_data(df: pd.DataFrame) -> pd.DataFrame:
             # Fill missing values with the previous valid value
             processed_df[col] = processed_df[col].ffill()
 
+        # Data-quality validation
+        processed_df = DataProcessor._validate_ohlcv(processed_df)
+
         # Compute moving averages
         if all(col in processed_df.columns for col in ['close', 'volume']):
             processed_df = DataProcessor._calculate_ma(processed_df)
@@ -156,6 +197,9 @@ def process_min_data(df: pd.DataFrame) -> pd.DataFrame:
             # Fill missing values with the previous valid value
             processed_df[col] = processed_df[col].ffill()
 
+        # Data-quality validation
+        processed_df = DataProcessor._validate_ohlcv(processed_df)
+
         # Compute moving averages
         if all(col in processed_df.columns for col in ['close', 'volume']):
             processed_df = DataProcessor._calculate_ma(processed_df)

From 7339882d2a85c24757651adf283aa66e264513fb Mon Sep 17 00:00:00 2001
From: xbfighting
Date: Wed, 25 Mar 2026 12:05:41 +0800
Subject: [PATCH 5/7] fix: resolve three P2 issues
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add the missing Minute5Data ORM model, matching the other minute tables
- The sync daily path now streams stock by stock (sync_all_daily_data)
  instead of loading everything into memory
- save_incremental validates the table name against a whitelist to
  prevent SQL string-concatenation injection

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 src/cli.py     | 69 ++++++++++++++++++++++++++++++++++++++++----------
 src/storage.py | 37 +++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 14 deletions(-)

diff --git a/src/cli.py b/src/cli.py
index 65fbf4e..70e560a 100644
--- a/src/cli.py
+++ b/src/cli.py
@@ -94,6 +94,57 @@ def sync_single_stock_min_data(
     return True
 
+def sync_all_daily_data(
+    reader: TdxDataReader,
+    processor: DataProcessor,
+    storage: DataStorage,
+    start_date: Optional[str] = None,
+) -> bool:
+    """Stream the daily-bar sync stock by stock to avoid loading everything into memory"""
+    try:
+        stocks = reader.get_stock_list()
+        logger.info(f"Syncing daily data, {len(stocks)} stocks in total")
+
+        iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows()
+        total_inserted = 0
+
+        for _, stock in iterator:
+            code = stock['code']
+            market = 1 if code.startswith('sh') else 0
+            try:
+                data = reader.read_daily_data(market, code)
+                if isinstance(data.index, pd.DatetimeIndex) or data.index.name == 'datetime':
+                    data = data.reset_index()
+                if data.empty:
+                    continue
+
+                processed = processor.process_daily_data(data)
+                filtered = processor.filter_data(processed, start_date=start_date)
+                if filtered.empty:
+                    continue
+
+                inserted = storage.save_incremental(
+                    filtered, 'daily_data',
+                    conflict_columns=('code', 'date'),
+                    batch_size=config.db_batch_size
+                )
+                total_inserted += inserted
+            except FileNotFoundError:
+                continue
+            except Exception as e:
+                logger.error(f"Error syncing daily data for {code}: {e}")
+                continue
+
+        if total_inserted > 0:
+            logger.info(f"Daily data sync finished, {total_inserted} rows inserted")
+        else:
+            logger.info("Daily data is already up to date")
+        return True
+    except Exception as e:
+        logger.error(f"Error syncing daily data: {e}")
+        return False
+
+
 def sync_all_min_data(
     reader: TdxDataReader,
     processor: DataProcessor,
@@ -436,20 +487,10 @@ def main() -> int:
                 start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d')
                 logger.info(f"Daily-bar start date: {start_date}")
 
-            data = reader.read_all_daily_data()
-            if not data.empty:
-                processed_data = processor.process_daily_data(data)
-                filtered_data = processor.filter_data(processed_data, start_date=start_date)
-                if not filtered_data.empty:
-                    storage.save_incremental(
-                        filtered_data, 'daily_data',
-                        conflict_columns=('code', 'date'),
-                        batch_size=config.db_batch_size
-                    )
-                else:
-                    logger.info("Daily data is already up to date")
-            else:
-                logger.warning("No daily data fetched")
+            success = sync_all_daily_data(reader, processor, storage, start_date)
+            if not success:
+                logger.error("Error syncing daily data")
+                has_error = True
         except Exception as e:
             logger.error(f"Error syncing daily data: {e}")
             has_error = True
diff --git a/src/storage.py b/src/storage.py
index 121e5ca..fa0ffcd 100644
--- a/src/storage.py
+++ b/src/storage.py
@@ -58,6 +58,33 @@ class DailyData(Base):
     ma60 = Column(Float)
     ma250 = Column(Float)
 
+class Minute5Data(Base):
+    """5-minute bar table model"""
+    __tablename__ = 'minute5_data'
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    code = Column(String(10), nullable=False, index=True)
+    market = Column(Integer, nullable=False)
+    datetime = Column(DateTime, nullable=False, index=True)
+    date = Column(DateTime, nullable=False, index=True)
+    open = Column(Float, nullable=False)
+    high = Column(Float, nullable=False)
+    low = Column(Float, nullable=False)
+    close = Column(Float, nullable=False)
+    volume = Column(Float, nullable=False)
+    amount = Column(Float, nullable=False)
+    ma13 = Column(Float)
+    ma21 = Column(Float)
+    ma34 = Column(Float)
+    ma55 = Column(Float)
+    ma89 = Column(Float)
+    ma144 = Column(Float)
+    ma233 = Column(Float)
+    ma5 = Column(Float)
+    ma10 = Column(Float)
+    ma60 = Column(Float)
+    ma250 = Column(Float)
+
 class Minute15Data(Base):
     """15-minute bar table model"""
     __tablename__ = 'minute15_data'
@@ -150,6 +177,13 @@ class StockInfo(Base):
     name = Column(String(50))
     market = Column(Integer)
 
+# Whitelist of table names allowed for writes
+_VALID_TABLES = frozenset({
+    'daily_data',
+    'minute5_data', 'minute15_data', 'minute30_data', 'minute60_data',
+    'stock_info', 'block_stock_relation',
+})
+
+
 class DataStorage:
     """Data storage class"""
 
@@ -271,6 +305,9 @@ def save_incremental(
         Returns:
             int: the number of rows actually inserted
         """
+        if table_name not in _VALID_TABLES:
+            raise ValueError(f"Table name not allowed for writes: {table_name}")
+
         if df.empty:
             logger.warning(f"No data to save to table {table_name}")
             return 0

From 2cd7e52cf3e4b54586698994fc53c3972b8cb618 Mon Sep 17 00:00:00 2001
From: xbfighting
Date: Wed, 25 Mar 2026 14:00:43 +0800
Subject: [PATCH 6/7] docs: bring the CLAUDE.md architecture description up
 to date
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Reflect the refactoring: orchestration lives in cli.py, per-stock
streaming, OHLCV validation, batched executemany, table-name whitelist.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 CLAUDE.md | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/CLAUDE.md b/CLAUDE.md
index 23b535a..d40708d 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -30,20 +30,24 @@ python main.py stock-list --db-only
 Four-layer pipeline with one-way data flow:
 
 ```
-CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (storage.py)
-     ↓               ↓                      ↓                         ↓
-  argparse      pytdx reads local     clean + compute MAs     SQLAlchemy DB writes,
-  dispatch      .day/.lc5 files       (MA5~MA250)             incremental ON CONFLICT
+CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (storage.py)
+     ↓               ↓                      ↓                         ↓
+  argparse      pytdx reads local     validate + resample     SQLAlchemy batched writes,
+  dispatch +    .day/.lc5 files       + MAs (OHLCV checks,    incremental ON CONFLICT +
+  sync orchestration                  resample, MA5~MA250)    table-name whitelist
 ```
 
+- **cli.py**: besides command dispatch, `sync_all_daily_data` / `sync_all_min_data` / `sync_single_stock_min_data` orchestrate the per-stock streaming sync
 - **config.py**: global singleton `config`, loads settings from `.env` (TDX_PATH, DB_*)
 - **logger.py**: global singleton `logger`
 
 ### Key Data Flows
 
-1. **Daily bars**: read `vipdoc/{sz,sh}/lday/*.day` → `process_daily_data()` adds a date column and MAs → written to the `daily_data` table
-2. **Minute bars**: read `vipdoc/{sz,sh}/fzline/*.lc5` (raw 5-minute data) → resample to 15/30/60 minutes → `process_min_data()` computes MAs → written to the `minute{5,15,30,60}_data` tables
-3. **Incremental sync**: `save_incremental()` uses `ON CONFLICT DO NOTHING` (PostgreSQL) / `INSERT IGNORE` (MySQL) to skip duplicates. Minute bars query the latest date per stock (`get_latest_datetime_by_code`); daily bars use the global latest date.
+Both daily and minute bars are **streamed stock by stock**, never loaded into memory in full:
+
+1. **Daily bars**: per stock, read `vipdoc/{sz,sh}/lday/*.day` → `process_daily_data()` validates OHLCV + computes MAs → incrementally written to the `daily_data` table
+2. **Minute bars**: per stock, read `.lc5` (5-minute) → `resample_ohlcv()` resamples to 15/30/60 minutes → `process_min_data()` validates + computes MAs → written to the `minute{5,15,30,60}_data` tables
+3. **Incremental sync**: `save_incremental()` uses batched executemany + `ON CONFLICT DO NOTHING` (PostgreSQL) / `INSERT IGNORE` (MySQL) to skip duplicates. Minute bars query the latest date per stock (`get_latest_datetime_by_code`); daily bars sync incrementally per stock.
 
 ### Database Tables

From c4ec02b4a8a86765ee74097ecf88ccd181e30677 Mon Sep 17 00:00:00 2001
From: xbfighting
Date: Wed, 25 Mar 2026 16:55:21 +0800
Subject: [PATCH 7/7] fix: resolve 3 issues found in code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. sync no longer passes a global start_date for minute bars; the
   per-stock incremental path queries its own latest date
2. The minutes --code single-stock path now goes through
   sync_single_stock_min_data, covering all of the 5/15/30/60 frequencies
3. save_incremental drops the unreliable rowcount tally and corrects the
   log wording

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 src/cli.py     | 70 ++++++++------------------------------------------
 src/storage.py |  7 +++--
 2 files changed, 14 insertions(+), 63 deletions(-)

diff --git a/src/cli.py b/src/cli.py
index 70e560a..195396b 100644
--- a/src/cli.py
+++ b/src/cli.py
@@ -387,59 +387,17 @@ def main() -> int:
         # Fetch minute data
         if args.code and args.market is not None:
-            # Fetch minute data for a single stock
-            data_list = reader.read_min_data(args.market, args.code)
-
-            logger.info(f"Fetched {len(data_list)} kinds of minute-bar records")
-            # Check the data
-
-            if data_list[0].empty:
-                logger.warning("No data fetched")
-                return 0
-
-            # [data_15min, data_30min, data_60min]
-            logger.info(f"Generated {len(data_list[0])} 15-minute bar records")
-            logger.info(f"Generated {len(data_list[1])} 30-minute bar records")
-            logger.info(f"Generated {len(data_list[2])} 60-minute bar records")
-
-            # Process the data
+            # Single stock: go through sync_single_stock_min_data, covering all of the 5/15/30/60 frequencies
             processor = DataProcessor()
-            processed_data_list = []
-            for i, data in enumerate(data_list):
-                freq = [15, 30, 60][i]  # the corresponding minute frequency
-                processed_data = processor.process_min_data(data)
-
-                # Filter by date
-                filtered_data = processor.filter_data(
-                    processed_data,
-                    start_date=start_date,
-                    end_date=args.end_date
-                )
-
-                if not filtered_data.empty:
-                    processed_data_list.append((filtered_data, freq))
-                    logger.info(f"{len(filtered_data)} {freq}-minute bar records after filtering")
-                else:
-                    logger.warning(f"No {freq}-minute data after filtering")
-
-            if not processed_data_list:
-                logger.warning("No data for any frequency after filtering")
+            success = sync_single_stock_min_data(
+                reader, processor, storage,
+                args.market, args.code,
+                start_date=start_date,
+                incremental=incremental,
+            )
+            if not success:
+                logger.warning(f"Stock {args.code} has no data to sync")
                 return 0
-
-            # Determine how to save
-            to_csv = not args.db_only
-            to_db = not args.csv_only
-
-            # Save the data
-            for filtered_data, freq in processed_data_list:
-                table_name = f'minute{freq}_data'
-                if to_csv:
-                    storage.save_to_csv(filtered_data, table_name)
-                if to_db:
-                    if incremental:
-                        storage.save_incremental(filtered_data, table_name, batch_size=config.db_batch_size)
-                    else:
-                        storage.save_to_database(filtered_data, table_name, batch_size=config.db_batch_size)
         else:
             # Fetch minute data for all stocks
             logger.info("Starting to process minute data for all stocks...")
@@ -495,16 +453,10 @@ def main() -> int:
             logger.error(f"Error syncing daily data: {e}")
             has_error = True
 
-        # 2. Sync minute data
+        # 2. Sync minute data (per-stock incremental; no global start_date)
         try:
             logger.info("=== Syncing minute data ===")
-            latest = storage.get_latest_datetime('minute15_data')
-            start_date = None
-            if latest:
-                start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d')
-                logger.info(f"Minute-bar start date: {start_date}")
-
-            success = sync_all_min_data(reader, processor, storage, start_date)
+            success = sync_all_min_data(reader, processor, storage)
             if not success:
                 logger.error("Error syncing minute data")
                 has_error = True
diff --git a/src/storage.py b/src/storage.py
index fa0ffcd..36dc7d4 100644
--- a/src/storage.py
+++ b/src/storage.py
@@ -360,16 +360,15 @@ def save_incremental(
 
                 # Batch execution: pass a list of parameter dicts; SQLAlchemy uses executemany automatically
                 params = batch_df.to_dict('records')
-                result = conn.execute(sql, params)
-                inserted_count += result.rowcount
+                conn.execute(sql, params)
 
             conn.commit()
 
             if not config.use_tqdm:
                 logger.info(f"Processed {end_idx}/{total_rows} records")
 
-            logger.info(f"Incremental save done: processed {total_rows} rows, actually inserted {inserted_count} into table {table_name}")
-            return inserted_count
+            logger.info(f"Incremental save done: processed {total_rows} rows into table {table_name} (duplicates skipped)")
+            return total_rows
 
         except Exception as e:
            logger.error(f"Error during incremental save to table {table_name}: {e}")
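
Editor's note, appended after the series (not part of any commit): patches 2 and 7 together settle on "pre-built idempotent INSERT + executemany + no rowcount tally". The sketch below is a minimal, self-contained reproduction of that pattern against a throwaway in-memory SQLite database, assuming SQLAlchemy 2.x; the table name and values are illustrative only, and the reliable way to see what landed is a follow-up SELECT, since many DBAPI drivers report rowcount as -1 under executemany (which is exactly what patch 7 works around).

```python
# Hypothetical standalone demo of the batched INSERT OR IGNORE pattern; not
# project code. Requires SQLAlchemy 2.x.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # throwaway in-memory database

with engine.connect() as conn:
    conn.execute(text(
        "CREATE TABLE daily_data ("
        "code TEXT NOT NULL, date TEXT NOT NULL, close REAL, "
        "UNIQUE (code, date))"
    ))

    # Built once, reused for every batch (the patch 2 shape).
    sql = text(
        "INSERT OR IGNORE INTO daily_data (code, date, close) "
        "VALUES (:code, :date, :close)"
    )

    # A list of parameter dicts triggers executemany-style execution.
    batch = [
        {"code": "sz000001", "date": "2026-03-24", "close": 10.5},
        {"code": "sz000001", "date": "2026-03-25", "close": 10.7},
        {"code": "sz000001", "date": "2026-03-25", "close": 10.7},  # duplicate, silently skipped
    ]
    conn.execute(sql, batch)
    conn.commit()

    # Count what actually landed instead of trusting rowcount (the patch 7 fix).
    n = conn.execute(text("SELECT COUNT(*) FROM daily_data")).scalar_one()
    print(n)  # 2
```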