augment_regime_detection

augment_regime_detection(
    data,
    date_column,
    close_column,
    window=252,
    n_regimes=2,
    method='hmm',
    step_size=1,
    n_iter=100,
    n_jobs=-1,
    reduce_memory=False,
    hmm_backend='auto',
    engine='auto',
)

Detect regimes in a financial time series using a specified method (e.g., HMM).

Parameters

Name Type Description Default
data DataFrame or GroupBy (pandas or polars) Input time-series data. Grouped inputs are processed per group before the regime labels are appended. required
date_column str or ColumnSelector Column (or selector) containing dates or timestamps. required
close_column str, ColumnSelector, or list Column(s) with closing prices used for regime detection. Must resolve to exactly one column. required
window Union[int, Tuple[int, int], List[int]] Number of observations in the rolling window used to fit the regime detection model. Passing a list of integers adds one regime column per window. Default is 252. 252
n_regimes int Number of regimes to detect (e.g., 2 for bull/bear). Default is 2. 2
method str Method for regime detection. Currently supports 'hmm'. Default is 'hmm'. 'hmm'
step_size int Step size between HMM fits (e.g., 10 refits the model every 10 rows). Default is 1. 1
n_iter int Maximum number of iterations for each HMM fit. Default is 100. 100
n_jobs int Number of parallel jobs for group processing (-1 uses all cores). Default is -1. -1
reduce_memory bool If True, reduces memory usage. Default is False. False
hmm_backend {"auto", "pomegranate", "hmmlearn"} Backend library used for the HMM implementation. "auto" (default) prefers the faster pomegranate backend when it is installed and otherwise falls back to hmmlearn. "auto"
engine {"auto", "pandas", "polars"} Execution engine. "auto" (default) infers the engine from the type of the input data; pass "pandas" or "polars" to override it. "auto"
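To make the interaction between `window` and `step_size` concrete, here is a minimal sketch of one plausible refit schedule. It assumes (this is not taken from the pytimetk source) that each model is fit on the `window` rows ending at a given row, that the first fit happens once a full window is available, and that later fits are `step_size` rows apart. The helper `fit_end_positions` is hypothetical.

```python
def fit_end_positions(n_rows: int, window: int, step_size: int = 1) -> list[int]:
    """Return the 0-based row indices at which a model would be refit.

    Hypothetical schedule: the first fit covers rows 0..window-1 (ending at
    row window-1); each later fit ends step_size rows after the previous one.
    """
    if window < 1 or step_size < 1:
        raise ValueError("window and step_size must be positive integers")
    return list(range(window - 1, n_rows, step_size))


# One stock in stocks_daily has 2699 rows; compare step_size=1 with step_size=10.
every_row = fit_end_positions(2699, window=252, step_size=1)
every_10 = fit_end_positions(2699, window=252, step_size=10)
print(len(every_row), len(every_10))  # 2448 245
```

Under this assumed schedule, `step_size=10` cuts the number of model fits by roughly a factor of ten, which is why it is the main lever for runtime on long histories.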

Returns

Name Type Description
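DataFrame The input data with one regime column added per window, named {close_column}_regime_{window} (e.g., close_regime_252). Values are regime labels (e.g., 0, 1) stored as floats; rows without a label are missing (NaN in pandas, null in polars).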
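The naming rule can be read off the examples below (`close_regime_252`, plus `close_regime_504` when `window=[252, 504]`). A small sketch of that rule, using the hypothetical helper `regime_column_names`; tuple-valued `window` arguments are left out because their meaning is not described on this page.

```python
from typing import List, Union


def regime_column_names(close_column: str, window: Union[int, List[int]]) -> List[str]:
    """Build '{close_column}_regime_{window}' for each requested window.

    Hypothetical helper mirroring the column names seen in the examples;
    it only handles an int or a list of ints.
    """
    windows = [window] if isinstance(window, int) else list(window)
    return [f"{close_column}_regime_{w}" for w in windows]


print(regime_column_names("close", 252))         # ['close_regime_252']
print(regime_column_names("close", [252, 504]))  # ['close_regime_252', 'close_regime_504']
```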

Notes

  • Uses a Hidden Markov Model (HMM) to identify latent regimes based on log returns.
  • Regimes reflect distinct statistical states (e.g., high/low volatility, trending).
  • Requires an HMM backend: hmmlearn (pip install hmmlearn) or the faster, optional pomegranate backend (pip install 'pytimetk[regime_backends]', equivalent to pip install 'pomegranate<1.0').
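The notes describe the model only at a high level. The sketch below illustrates the decoding half of the idea in plain Python: it turns prices into log returns and assigns each return to the most likely hidden state of a two-state Gaussian HMM using the Viterbi algorithm. The transition matrix, means, and standard deviations are hand-picked assumptions; pytimetk estimates such parameters from the data through hmmlearn or pomegranate, so this is an illustration of the technique, not its implementation.

```python
import math
from typing import List, Sequence


def log_returns(prices: Sequence[float]) -> List[float]:
    """Log returns r_t = ln(P_t / P_{t-1}); one value fewer than prices."""
    return [math.log(cur / prev) for prev, cur in zip(prices[:-1], prices[1:])]


def gaussian_logpdf(x: float, mean: float, std: float) -> float:
    """Log density of N(mean, std**2) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * std**2) - (x - mean) ** 2 / (2 * std**2)


def viterbi(obs, start_p, trans_p, means, stds) -> List[int]:
    """Most likely hidden-state sequence for a Gaussian-emission HMM (log space)."""
    n = len(start_p)
    # score[s]: best log-probability of any state path ending in s at time t
    score = [math.log(start_p[s]) + gaussian_logpdf(obs[0], means[s], stds[s])
             for s in range(n)]
    backptrs = []
    for x in obs[1:]:
        new_score, ptr = [], []
        for s in range(n):
            # Best predecessor state for landing in state s at this step
            prev = max(range(n), key=lambda p: score[p] + math.log(trans_p[p][s]))
            new_score.append(score[prev] + math.log(trans_p[prev][s])
                             + gaussian_logpdf(x, means[s], stds[s]))
            ptr.append(prev)
        score = new_score
        backptrs.append(ptr)
    # Trace back from the highest-scoring final state
    state = max(range(n), key=lambda s: score[s])
    path = [state]
    for ptr in reversed(backptrs):
        state = ptr[state]
        path.append(state)
    return path[::-1]


# Synthetic prices: 10 calm days (+/-0.1%) followed by 10 volatile days (+/-5%)
moves = [0.001, -0.001] * 5 + [0.05, -0.05] * 5
prices = [100.0]
for r in moves:
    prices.append(prices[-1] * math.exp(r))

returns = log_returns(prices)
path = viterbi(
    returns,
    start_p=[0.5, 0.5],
    trans_p=[[0.95, 0.05], [0.05, 0.95]],  # regimes tend to persist
    means=[0.0, 0.0],
    stds=[0.005, 0.04],                     # state 0 = calm, state 1 = volatile
)
print(path)  # ten 0s followed by ten 1s
```

Here state 0 means "calm" only because the parameters were written that way; in a fitted HMM, which integer a state receives is generally arbitrary.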

Examples

import pandas as pd
import polars as pl
import pytimetk as tk

df = tk.load_dataset("stocks_daily", parse_dates=["date"])

df
symbol date open high low close volume adjusted
0 META 2013-01-02 27.440001 28.180000 27.420000 28.000000 69846400 28.000000
1 META 2013-01-03 27.879999 28.469999 27.590000 27.770000 63140600 27.770000
2 META 2013-01-04 28.010000 28.930000 27.830000 28.760000 72715400 28.760000
3 META 2013-01-07 28.690001 29.790001 28.650000 29.420000 83781800 29.420000
4 META 2013-01-08 29.510000 29.600000 28.860001 29.059999 45871300 29.059999
... ... ... ... ... ... ... ... ...
16189 GOOG 2023-09-15 138.800003 139.360001 137.179993 138.300003 48947600 138.300003
16190 GOOG 2023-09-18 137.630005 139.929993 137.630005 138.960007 16233600 138.960007
16191 GOOG 2023-09-19 138.250000 139.175003 137.500000 138.830002 15479100 138.830002
16192 GOOG 2023-09-20 138.830002 138.839996 134.520004 134.589996 21473500 134.589996
16193 GOOG 2023-09-21 132.389999 133.190002 131.089996 131.360001 22042700 131.360001

16194 rows × 8 columns

# Regime detection - pandas single stock (requires hmm backend)
regime_single = (
    df
    .query("symbol == 'AAPL'")
    .augment_regime_detection(
        date_column="date",
        close_column="close",
        window=252,
        n_regimes=2,
    )
)

regime_single.glimpse()
Model is not converging.  Current: 673.2557554510333 is not greater than 673.2740563365. Delta is -0.018300885466715044
<class 'pandas.core.frame.DataFrame'>: 2699 rows of 9 columns
symbol:            object            ['AAPL', 'AAPL', 'AAPL', 'AAPL', 'A ...
date:              datetime64[ns]    [Timestamp('2013-01-02 00:00:00'),  ...
open:              float64           [19.779285430908203, 19.56714248657 ...
high:              float64           [19.821428298950195, 19.63107109069 ...
low:               float64           [19.343929290771484, 19.32142829895 ...
close:             float64           [19.608213424682617, 19.36071395874 ...
volume:            int64             [560518000, 352965200, 594333600, 4 ...
adjusted:          float64           [16.791179656982422, 16.57924079895 ...
close_regime_252:  float64           [nan, nan, nan, nan, nan, nan, nan, ...
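The glimpse shows that `close_regime_252` is a float column whose first rows are NaN. Because a numeric label alone does not say what a regime looks like, a common follow-up is to describe each regime by the mean and volatility of its log returns. The sketch below does this in plain Python on a short, made-up series so it runs without pytimetk. `summarize_regimes` is a hypothetical helper; with real output you would pass the `close` and `close_regime_252` columns as lists. It assumes each return is attributed to the label of the later of its two rows.

```python
import math
import statistics
from collections import defaultdict
from typing import Dict, Optional, Sequence, Tuple


def summarize_regimes(
    closes: Sequence[float], labels: Sequence[Optional[float]]
) -> Dict[int, Tuple[float, float]]:
    """Mean and population std of log returns, grouped by regime label.

    Return r_t = ln(close_t / close_{t-1}) gets the label of row t;
    rows whose label is missing (NaN or None) are skipped.
    """
    by_regime = defaultdict(list)
    for prev, cur, label in zip(closes[:-1], closes[1:], labels[1:]):
        if label is None or (isinstance(label, float) and math.isnan(label)):
            continue
        by_regime[int(label)].append(math.log(cur / prev))
    return {k: (statistics.fmean(v), statistics.pstdev(v))
            for k, v in sorted(by_regime.items())}


# Toy data: two unlabeled rows, a quiet regime 0, then a choppy regime 1
closes = [100.0, 100.1, 100.0, 100.1, 95.0, 101.0, 94.0]
labels = [float("nan"), float("nan"), 0.0, 0.0, 1.0, 1.0, 1.0]

summary = summarize_regimes(closes, labels)
for regime, (mean, vol) in summary.items():
    print(f"regime {regime}: mean={mean:+.4f}, vol={vol:.4f}")
```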
# Regime detection - pandas grouped (requires hmm backend)
regime_grouped = (
    df
    .groupby("symbol")
    .augment_regime_detection(
        date_column="date",
        close_column="close",
        window=[252, 504],
        n_regimes=3,
    )
)

regime_grouped.groupby("symbol").tail(1)
Model is not converging.  Current: 684.9925091106471 is not greater than 685.1489625220825. Delta is -0.15645341143533642
Model is not converging.  Current: 776.8672485008882 is not greater than 776.911249432062. Delta is -0.04400093117385495
Model is not converging.  Current: 729.2937336409235 is not greater than 729.5471636929799. Delta is -0.25343005205638747
Model is not converging.  Current: 1509.7302235564714 is not greater than 1509.8392116273237. Delta is -0.10898807085231965
Model is not converging.  Current: 1202.7528839438414 is not greater than 1202.754819166896. Delta is -0.0019352230544882332
Model is not converging.  Current: 1430.7922312026346 is not greater than 1430.9316817660072. Delta is -0.13945056337252026
Model is not converging.  Current: 1413.7782205823746 is not greater than 1413.9113377218375. Delta is -0.13311713946291093
symbol date open high low close volume adjusted close_regime_252 close_regime_504
2698 META 2023-09-21 295.700012 300.260010 293.269989 295.730011 21300500 295.730011 NaN NaN
5397 AMZN 2023-09-21 131.940002 132.240005 129.309998 129.330002 70234800 129.330002 NaN NaN
8096 AAPL 2023-09-21 174.550003 176.300003 173.860001 173.929993 63047900 173.929993 NaN NaN
10795 NFLX 2023-09-21 386.500000 395.899994 383.420013 384.149994 5547900 384.149994 NaN NaN
13494 NVDA 2023-09-21 415.829987 421.000000 409.799988 410.170013 44893000 410.170013 NaN NaN
16193 GOOG 2023-09-21 132.389999 133.190002 131.089996 131.360001 22042700 131.360001 NaN NaN
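With `n_regimes=3` and two windows, the raw labels in different columns (and for different symbols) are not guaranteed to line up: in a generic HMM the integer assigned to a state is arbitrary, and this page does not say whether pytimetk reorders states. One way to make labels comparable, sketched here with the hypothetical helper `relabel_by_volatility`, is to rank regimes by a volatility estimate (for example, the per-regime standard deviation of log returns) so that 0 always means the calmest regime. The labels and volatilities below are made up for illustration.

```python
import math
from typing import Dict, List, Optional, Sequence


def relabel_by_volatility(
    labels: Sequence[Optional[float]], vol_by_regime: Dict[int, float]
) -> List[Optional[int]]:
    """Replace regime ids with their volatility rank (0 = calmest).

    Missing labels (NaN or None) stay missing and are returned as None.
    """
    order = sorted(vol_by_regime, key=vol_by_regime.get)
    rank = {regime: i for i, regime in enumerate(order)}
    out: List[Optional[int]] = []
    for label in labels:
        if label is None or (isinstance(label, float) and math.isnan(label)):
            out.append(None)
        else:
            out.append(rank[int(label)])
    return out


# Hypothetical 3-regime output in which id 2 happens to be the calmest regime
labels = [float("nan"), 2.0, 0.0, 1.0, 2.0]
vols = {0: 0.030, 1: 0.015, 2: 0.008}
relabeled = relabel_by_volatility(labels, vols)
print(relabeled)  # [None, 0, 2, 1, 0]
```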
# Regime detection - polars engine (requires hmm backend)
pl_single = pl.from_pandas(df.query("symbol == 'AAPL'"))
regime_polars = (
    pl_single
    .tk.augment_regime_detection(
        date_column="date",
        close_column="close",
        window=252,
        n_regimes=2,
    )
)

regime_polars.glimpse()
Model is not converging.  Current: 672.6488254279976 is not greater than 672.6571367046666. Delta is -0.008311276669019207
Rows: 2699
Columns: 9
$ symbol                    <str> 'AAPL', 'AAPL', 'AAPL', 'AAPL', 'AAPL', 'AAPL', 'AAPL', 'AAPL', 'AAPL', 'AAPL'
$ date             <datetime[ns]> 2013-01-02 00:00:00, 2013-01-03 00:00:00, 2013-01-04 00:00:00, 2013-01-07 00:00:00, 2013-01-08 00:00:00, 2013-01-09 00:00:00, 2013-01-10 00:00:00, 2013-01-11 00:00:00, 2013-01-14 00:00:00, 2013-01-15 00:00:00
$ open                      <f64> 19.779285430908203, 19.567142486572266, 19.177499771118164, 18.64285659790039, 18.90035629272461, 18.66071319580078, 18.876785278320312, 18.60714340209961, 17.952856063842773, 17.796428680419922
$ high                      <f64> 19.821428298950195, 19.63107109069824, 19.236785888671875, 18.9035701751709, 18.996070861816406, 18.750356674194336, 18.882856369018555, 18.761428833007812, 18.125, 17.82107162475586
$ low                       <f64> 19.343929290771484, 19.321428298950195, 18.77964210510254, 18.399999618530273, 18.616071701049805, 18.428213119506836, 18.41142845153809, 18.53642845153809, 17.80392837524414, 17.26357078552246
$ close                     <f64> 19.608213424682617, 19.36071395874023, 18.821428298950195, 18.71071434020996, 18.761070251464844, 18.467857360839844, 18.696786880493164, 18.582143783569336, 17.91964340209961, 17.354286193847656
$ volume                    <i64> 560518000, 352965200, 594333600, 484156400, 458707200, 407604400, 601146000, 350506800, 734207600, 876772400
$ adjusted                  <f64> 16.791179656982422, 16.579240798950195, 16.1174373626709, 16.02262306213379, 16.065746307373047, 15.814659118652344, 16.010698318481445, 15.912524223327637, 15.345203399658203, 14.86106777191162
$ close_regime_252          <f64> None, None, None, None, None, None, None, None, None, None
# Pomegranate backend with column selectors
from pytimetk.utils.selection import contains

selector_demo = (
    df
    .groupby("symbol")
    .augment_regime_detection(
        date_column=contains("dat"),
        close_column=contains("clos"),
        window=252,
        n_regimes=4,
        hmm_backend="pomegranate",  # requires pomegranate<1.0
    )
)

selector_demo.groupby("symbol").tail(1)
symbol date open high low close volume adjusted close_regime_252
2698 META 2023-09-21 295.700012 300.260010 293.269989 295.730011 21300500 295.730011 -1.0
5397 AMZN 2023-09-21 131.940002 132.240005 129.309998 129.330002 70234800 129.330002 -1.0
8096 AAPL 2023-09-21 174.550003 176.300003 173.860001 173.929993 63047900 173.929993 -1.0
10795 NFLX 2023-09-21 386.500000 395.899994 383.420013 384.149994 5547900 384.149994 -1.0
13494 NVDA 2023-09-21 415.829987 421.000000 409.799988 410.170013 44893000 410.170013 -1.0
16193 GOOG 2023-09-21 132.389999 133.190002 131.089996 131.360001 22042700 131.360001 -1.0
# Visualizing regimes
SYMBOL = 'NVDA'  # one of: META, AMZN, AAPL, NFLX, NVDA, GOOG

(
    selector_demo
    .query(f"symbol == '{SYMBOL}'")
    .plot_timeseries(
        date_column="date",
        value_column="close",
        color_column=contains("regime_"),
        smooth=False,
        title=f"{SYMBOL} Close Price with Detected Regimes",
    )
)
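The chart colors each point by its regime label. If you instead want to shade contiguous date ranges (for example, with a plotting library's span or rectangle annotations), the label column first needs to be collapsed into runs. A minimal sketch in plain Python follows; `regime_spans` is a hypothetical helper, missing labels are skipped, and the dates and labels are toy values.

```python
import math
from itertools import groupby
from typing import List, Optional, Sequence, Tuple


def _normalize(label: Optional[float]) -> Optional[int]:
    """Map NaN/None to None so missing rows group together; else cast to int."""
    if label is None or (isinstance(label, float) and math.isnan(label)):
        return None
    return int(label)


def regime_spans(
    dates: Sequence[str], labels: Sequence[Optional[float]]
) -> List[Tuple[str, str, int]]:
    """Collapse consecutive identical labels into (start_date, end_date, label)."""
    spans = []
    for label, group in groupby(zip(dates, labels), key=lambda dl: _normalize(dl[1])):
        rows = list(group)
        if label is not None:  # skip stretches with no regime label
            spans.append((rows[0][0], rows[-1][0], label))
    return spans


dates = ["2023-01-02", "2023-01-03", "2023-01-04", "2023-01-05", "2023-01-06"]
labels = [float("nan"), 0.0, 0.0, 1.0, 1.0]
spans = regime_spans(dates, labels)
print(spans)  # [('2023-01-03', '2023-01-04', 0), ('2023-01-05', '2023-01-06', 1)]
```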