精华帖子

滚动训练5日选股系列之KAN

由bq93t66l创建,最终由bq93t66l 被浏览 1 用户

1. 策略概览

本策略基于KAN模型,在2018年至2025年期间对每年进行滚动训练。训练集时段为过去5年,测试集时段为未来一年,如2018年训练集采用2013-01-01至2017-12-29,测试集时段为2018-01-03至2018-12-31。数据使用当期全市场数据,聚焦于量价数据及其衍生,如5日均值比例,量价5日相关性以及量价横截面分位数排名等。标签设置为未来5日收益率的分位数排名。回测时选取预测分数top50只股票,每5日调仓。

KAN原始模型论文:Z. Liu, Y. Wang, S. Vaidya, F. Ruehle, J. Halverson, M. Soljačić, T. Y. Hou, and M. Tegmark, “KAN: Kolmogorov-Arnold Networks,” arXiv preprint arXiv:2404.19756, 2024.

源码github链接:https://github.com/KindXiaoming/pykan

2. 数据处理

数据采用cn_stock_bar1d表内数据。包括原始量价数据和构建的因子共53个模型输入特征。只对于训练集特征进行4倍标准差winsorize,训练集标签1和99分位数winsorize,测试集特征不做处理避免回测时引入未来信息泄露。具体实现代码如下:

def get_data(start_date, end_date, is_train=True): 

    sql1 = """
    WITH feature_table AS (
        /*基础特征*/
        SELECT date, instrument, close close_0, open open_0, high high_0, low low_0, amount amount_0, turn * 100 turn_0, change_ratio + 1 return_0, 

        /*均线*/
        m_AVG(close,5)/close  ma_close_5,
        m_AVG(turn * 100,5)/turn ma_turn_5, 
        m_AVG(amount,5)/amount ma_amount_5, 
        m_AVG(change_ratio + 1, 5)/(change_ratio + 1)  ma_cr_5,

        /*标准差*/
        m_STDDEV(close, 5) std_close_5,
        m_STDDEV(turn * 100,5) std_turn_5,
        m_STDDEV(amount,5) std_amount_5,
        m_STDDEV(change_ratio + 1,5)  std_cr_5,

        /*排名百分比*/
        m_rolling_rank(close, 5)/5 rank_close_5, 
        m_rolling_rank(low, 5)/5 rank_low_5, 
        m_rolling_rank(open, 5)/5 rank_open_5, 
        m_rolling_rank(high, 5)/5 rank_high_5, 
        m_rolling_rank(turn * 100, 5)/5  rank_turn_5, 
        m_rolling_rank(amount, 5)/5 rank_amount_5, 
        m_rolling_rank(change_ratio+1, 5)/5 rank_cr_5,

        /*相关系数*/
        m_CORR(volume, change_ratio+1, 5)  corr_vcr, 
        m_CORR(volume, close, 5)  corr_vc,
        m_CORR(volume, turn * 100, 5)  corr_vt,

        m_CORR(change_ratio+1, close, 5)  corr_crc,
        m_CORR(change_ratio+1, turn, 5)  corr_crt, 
        m_CORR(high, low, 5)  corr_hl,
        m_CORR(high, close, 5)  corr_hc,
        m_CORR(high, open, 5)  corr_ho,
        m_CORR(low, close, 5)   corr_lc,
        m_CORR(low, open, 5)   corr_lo,
        m_CORR(close, open, 5)  corr_co,
        m_CORR(close, turn * 100, 5)  corr_ct,


        /*截面特征*/
        c_pct_rank(turn) cross_turn,
        c_pct_rank(change_ratio + 1) cross_change_ratio,
        c_pct_rank(ma_close_5) cross_ma_close_5,
        c_pct_rank(ma_turn_5) cross_ma_turn_5,
        c_pct_rank(ma_amount_5) cross_ma_amount_5,
        c_pct_rank(ma_cr_5) cross_ma_cr_5,
        c_pct_rank(std_close_5) cross_std_close_5,
        c_pct_rank(std_turn_5) cross_std_turn_5,
        c_pct_rank(std_amount_5) cross_std_amount_5,
        c_pct_rank(std_cr_5) cross_max_cr_r,
        c_pct_rank(rank_close_5) cross_rank_close_5,
        c_pct_rank(rank_turn_5) cross_rank_turn_5,
        c_pct_rank(rank_amount_5) cross_rank_amount_5,
        c_pct_rank(rank_cr_5) cross_rank_cr_5,
        c_pct_rank(corr_vcr) cross_corr_vcr,
        c_pct_rank(corr_vc) cross_corr_vc,
        c_pct_rank(corr_vt) cross_corr_vt,
        c_pct_rank(corr_crc) cross_corr_crc,
        c_pct_rank(corr_crt) cross_corr_crt,

        FROM cn_stock_bar1d
        QUALIFY COLUMNS(*) IS NOT NULL
    )
    """

    if is_train:
        print('抽取训练集数据')
        sql2 = """
        /*标签*/
        ,
        label_table AS (
            SELECT date, instrument, 
            m_lead(close, 5) / m_lead(open, 1) - 1 AS _future_return, 
            all_quantile_cont(_future_return, 0.01) AS _future_return_1pct, 
            all_quantile_cont(_future_return, 0.99) AS _future_return_99pct, 
            clip(_future_return, _future_return_1pct, _future_return_99pct) AS _label, 
            c_pct_rank(_label) as label, 
            FROM cn_stock_bar1d
            QUALIFY COLUMNS(*) IS NOT NULL AND m_lead(high, 1) != m_lead(low, 1)
        )

        -- 移除特征标准化
        SELECT date, instrument, label, COLUMNS(feature_table.* EXCLUDE (date, instrument)) FROM feature_table
        INNER JOIN label_table USING (date, instrument)
        ORDER BY date, instrument;
        """
    
    else:
        print('抽取测试集数据')
        sql2 = """
        /*数据提取*/
        SELECT feature_table.* FROM feature_table
        ORDER BY date, instrument
        """
    sql = sql1+sql2 
    df = dai.query(sql, filters={'date': [start_date, end_date]}).df()

    df = pl.from_pandas(df)
    df = df.fill_nan(None)
    df = df.select(pl.all().forward_fill().over('instrument'))
    df = df.fill_null(0)
    if is_train:
        df = df.with_columns(pl.exclude('date','instrument').clip(
            pl.exclude('date','instrument').mean()-4*pl.exclude('date','instrument').std(),
            pl.exclude('date','instrument').mean()+4*pl.exclude('date','instrument').std()
        ))
    #     df = df.with_columns((pl.col('label')-pl.col('label').mean())/(pl.col('label').std()+1e-6))
    return df

def get_train_test(start_year:str='2023'):
    '''
    默认5年训练,一年测试
    start_year: 测试集开始年份,训练集自动后选5年
    默认从1月1到12月31
    '''
    train_start_date = str(int(start_year)-5)+'-01-01'
    train_end_date = str(int(start_year)-1)+'-12-29'

    test_start_date = start_year+'-01-03'
    test_end_date = start_year+'-12-31'
    train_df = get_data(train_start_date, train_end_date, is_train=True)
    test_df = get_data(test_start_date, test_end_date, is_train=False)
    return train_df, test_df

模型

模型使用KANLayer,暂未结构调优,模型结构代码如下:

from kan.KANLayer import KANLayer

class KAN(nn.Module):
    """
    仿 DNN 结构的浅层 KAN 网络:
    """
    def __init__(self, input_dim, dropout=0.2, num=5, k=3, device=None, **kan_kwargs):
        super().__init__()
        self.bn = nn.BatchNorm1d(input_dim)
        self.lr = nn.Linear(input_dim, 128)
        self.dropout = nn.Dropout(dropout)

        self.kans = nn.ModuleList([
            KANLayer(128, 64, num=num, k=k, device=device, **kan_kwargs),
            KANLayer(64, 1, num=num, k=k, device=device, **kan_kwargs)
        ])

        if device is not None:
            self.to(device)

    def forward(self, x):
        x = self.bn(x)
        x = self.lr(x)
        x = self.dropout(x)
        x = self.kans[0](x)[0]
        x = self.dropout(x)
        x = self.kans[1](x)[0]
        return x

    @torch.no_grad()
    def update_grid(self, X, mode='sample'):
        """
        基于传入的样本集 X 更新所有 KANLayer 的网格。
        X 应该是代表该层输入分布的较大样本集,通常用整个训练集。
        建议在训练若干 epoch 后周期性调用。
        """
        # 保存当前训练/评估状态,结束后恢复
        was_training = self.training
        self.eval()                     # 避免 BN 等层在无梯度模式下行为异常

        inputs = self.dropout(self.lr(self.bn(X)))
        for layer in self.kans:
            # 更新当前层的网格 —— 传入的是这一层真实接收到的输入
            layer.update_grid_from_samples(inputs, mode=mode)
            # 计算下一层的输入,并 detach 切断梯度图
            y, *_ = layer(inputs)
            inputs = y.detach()

        if was_training:
            self.train()                # 恢复原来的训练模式

训练

训练时在训练集内部按4:1再次划分训练集和验证集。损失函数为mse,优化器为Adam。参数设置如下:

learning rate 0.0001
max_epochs 30
网格更新时间 20th epoch

训练曲线显示存在一定的过拟合,可能建议未来的工作调整模型结构或尝试对kan模型剪枝,以提高泛化能力。

结果

回测将8年滚动预测分数合并,进行st和停牌过滤,然后每次选择分数最高的50只股票每5日调仓。回测结果如下:

SHAP 分析

以下内容存在局限性,仅供演示参考。

全局特征摘要图

该图展示了各输入因子对模型输出预测收益的全局贡献程度。纵轴因子排序代表因子重要性,越上方因子影响力越强;横轴 SHAP 值表示因子对预测值的增减效应,正值提升收益,负值抑制收益;散点颜色反映因子自身大小,可直观识别因子的单调正向 / 负向驱动规律。

  1. 每个小点:一个样本
  2. 越靠上方的因子,对模型预测收益的整体影响力越大(全局特征重要性)。
  3. SHAP > 0:该因子正向拉升模型预测收益;SHAP < 0:该因子负向压低模型预测收益
  4. 红点(特征数值高)→ 落在右边:**因子越高,预测收益越高;**红点(特征数值高)→ 落在左边:因子越高,预测收益越低

\

特征重要性排序图

条形图以各因子平均绝对 SHAP 值排序,直观量化因子全局重要性。条形长度越长,代表该因子对模型预测结果的整体贡献越大,可快速筛选出策略核心有效因子与冗余弱因子。

waterfall 瀑布图(单个样本解释)

瀑布图实现单样本局部可解释性,展示从模型全局基线预测值开始,各个因子依次叠加正负贡献,最终得到该样本模型输出值的完整分解过程,清晰定位单条样本预测结果的关键驱动因子与拖累因子。

\

特征依赖图

该特征依赖图以 close_0 为分析因子,横轴为 close_0 原始取值,左侧纵轴为其对应的 SHAP 贡献值,反映因子自身对股票收益预测的影响规律;右侧色阶代表算法自动识别出的交互最强因子 corr_ct,通过散点颜色展示两因子间的交互效应。

\

{link}