{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FtHjlXDmDJAF"
      },
      "source": [
        "# EzMath: Time-Aware Bayesian Knowledge Tracing\n",
        "## Cross-Validation Pipeline for Neurodivergent Skill Acquisition\n",
        "\n",
        "**Research objective:** Compare learning trajectories of the EzMath dyscalculia intervention cohort against an ASSISTments neurotypical baseline using Standard and Time-Aware Bayesian Knowledge Tracing (BKT).\n",
        "\n",
        "**Knowledge component:** Addition and Subtraction of Fractions\n",
        "\n",
        "| Section | Description |\n",
        "|---|---|\n",
        "| 1. Setup | Environment, dependencies, reproducibility |\n",
        "| 2. Configuration | Paths, hyperparameters, constants |\n",
        "| 3. Data Loading | Load and inspect raw datasets |\n",
        "| 4. Preprocessing | Filter, normalize, align both datasets |\n",
        "| 5. Model Definition | StandardBKT and TimeAwareBKT classes |\n",
        "| 6. Training | 5-fold student-level cross-validation |\n",
        "| 7. Evaluation | RMSE, AUC, learning curve, statistical tests |\n",
        "| 8. Visualization | Publication-quality figures (BW) |\n",
        "| 9. Export | CSV results and auto-generated report |\n",
        "\n",
        "> **Limitation note:** ASSISTments baseline retains only students with ≥10 interactions on the target skill, introducing survivorship bias. This subgroup represents persistent-difficulty learners rather than the full neurotypical population."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hZRqtb-dDJAI"
      },
      "source": [
        "---\n",
        "## 1. Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "haAUZ-dCDJAI"
      },
      "outputs": [],
      "source": [
        "# Mount Google Drive\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "print('Drive mounted successfully.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LjWp4yLRDJAJ"
      },
      "outputs": [],
      "source": [
        "# ── Dependencies ──────────────────────────────────────────────────\n",
        "import os\n",
        "import random\n",
        "import warnings\n",
        "from pathlib import Path\n",
        "from typing import Dict, List, Tuple, Optional\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.ticker as mticker\n",
        "from matplotlib.lines import Line2D\n",
        "from scipy import stats, optimize\n",
        "from sklearn.metrics import roc_auc_score, roc_curve, mean_squared_error\n",
        "from IPython.display import display, Markdown\n",
        "\n",
        "warnings.filterwarnings('ignore')\n",
        "pd.set_option('display.float_format', '{:.4f}'.format)\n",
        "pd.set_option('display.max_columns', 20)\n",
        "\n",
        "print('All dependencies loaded.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DJnHDe7TDJAJ"
      },
      "outputs": [],
      "source": [
        "def set_global_seed(seed: int = 42) -> None:\n",
        "    \"\"\"Set random seed across all libraries for full reproducibility.\n",
        "\n",
        "    Args:\n",
        "        seed: Integer seed value. Default 42 (convention).\n",
        "    \"\"\"\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "\n",
        "RANDOM_SEED: int = 42\n",
        "set_global_seed(RANDOM_SEED)\n",
        "print(f'Global seed set: {RANDOM_SEED}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HYW5IvVhDJAJ"
      },
      "source": [
        "---\n",
        "## 2. Configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3j9rufOvDJAJ"
      },
      "outputs": [],
      "source": [
        "# ── Path Configuration (edit these 3 paths for your Drive) ───────\n",
        "BASE_DIR        = Path('/content/drive/MyDrive/DU_LIEU')\n",
        "ASSISTMENTS_PATH: Path = BASE_DIR / 'skill_builder_data.csv'\n",
        "EZMATH_PATH:      Path = BASE_DIR / 'ezmath_intervention.csv'\n",
        "OUTPUT_DIR:       Path = BASE_DIR / 'outputs'\n",
        "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# ── Cross-Validation Configuration ───────────────────────────────\n",
        "N_FOLDS:              int = 5    # Number of folds\n",
        "MIN_OPPORTUNITIES:    int = 10   # Minimum interactions per student (ASSISTments)\n",
        "MAX_OPPORTUNITIES:    int = 10   # Sequence truncation point\n",
        "N_COMPARISONS:        int = 10   # Time-points for Bonferroni correction\n",
        "ALPHA:              float = 0.05 # Base significance level\n",
        "\n",
        "# ── Target Knowledge Component ───────────────────────────────────\n",
        "TARGET_SKILL:  str = 'Addition and Subtraction Fractions'\n",
        "FALLBACK_KEYWORDS: List[str] = [\n",
        "    'addition and subtraction fractions',\n",
        "    'fraction addition',\n",
        "    'fraction subtraction',\n",
        "    'addition whole number and fraction',\n",
        "    'adding fractions',\n",
        "    'subtracting fractions',\n",
        "]\n",
        "\n",
        "# ── Standard BKT Initial Parameters (re-fitted via EM) ───────────\n",
        "BKT_INIT: Dict[str, float] = {\n",
        "    'prior': 0.3,   # P(L_0): initial knowledge probability\n",
        "    'learn': 0.1,   # P(T):   learning transition probability\n",
        "    'guess': 0.25,  # P(G):   guess probability (static)\n",
        "    'slip':  0.1,   # P(S):   slip probability (static)\n",
        "}\n",
        "\n",
        "# ── Time-Aware BKT Initial Parameters (tuned via Nelder-Mead) ────\n",
        "# P(G|t) = G_max / (1 + exp(k_g * (t - tau_g)))  -- decreasing sigmoid\n",
        "# P(S|t) = S_max / (1 + exp(k_s * (t - tau_s)))  -- increasing sigmoid (k_s < 0)\n",
        "TIME_AWARE_INIT: Dict[str, float] = {\n",
        "    'G_max':  0.4,     # Upper bound for guess probability\n",
        "    'S_max':  0.3,     # Upper bound for slip probability\n",
        "    'k_g':    0.001,   # Decay rate for P(G|t), must be > 0\n",
        "    'k_s':   -0.0005,  # Growth rate for P(S|t), must be < 0\n",
        "    'tau_g':  3000,    # Inflection point of P(G|t) in milliseconds\n",
        "    'tau_s':  5000,    # Inflection point of P(S|t) in milliseconds\n",
        "}\n",
        "\n",
        "# ── Visualization Style ───────────────────────────────────────────\n",
        "plt.rcParams.update({\n",
        "    'font.family':      'serif',\n",
        "    'font.size':        11,\n",
        "    'axes.titlesize':   11,\n",
        "    'axes.labelsize':   11,\n",
        "    'legend.fontsize':  9,\n",
        "    'xtick.labelsize':  10,\n",
        "    'ytick.labelsize':  10,\n",
        "    'figure.dpi':       150,\n",
        "    'savefig.dpi':      300,\n",
        "    'savefig.bbox':     'tight',\n",
        "})\n",
        "\n",
        "print('Configuration complete.')\n",
        "print(f'  Paths validated: ASSISTments={ASSISTMENTS_PATH.name}, EzMath={EZMATH_PATH.name}')\n",
        "print(f'  Output directory: {OUTPUT_DIR}')"
      ]
    },
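    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two sigmoids above can be sanity-checked before any fitting. The sketch below evaluates them at the initial `TIME_AWARE_INIT` values on a few illustrative response times. These are the starting values, not the Nelder-Mead-tuned ones. `guess_prob` and `slip_prob` are hypothetical helpers written only for this check and are not used elsewhere in the pipeline. With these initial values, `G_max + S_max = 0.7 < 1`, so the BKT constraint P(G|t) + P(S|t) < 1 holds at every t."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helpers: evaluate the time-aware emission curves defined in Section 2\n",
        "def guess_prob(t_ms: np.ndarray, G_max: float, k_g: float, tau_g: float) -> np.ndarray:\n",
        "    # P(G|t) = G_max / (1 + exp(k_g * (t - tau_g))); decreases with t when k_g > 0\n",
        "    return G_max / (1.0 + np.exp(k_g * (np.asarray(t_ms, dtype=float) - tau_g)))\n",
        "\n",
        "\n",
        "def slip_prob(t_ms: np.ndarray, S_max: float, k_s: float, tau_s: float) -> np.ndarray:\n",
        "    # P(S|t) = S_max / (1 + exp(k_s * (t - tau_s))); increases with t when k_s < 0\n",
        "    return S_max / (1.0 + np.exp(k_s * (np.asarray(t_ms, dtype=float) - tau_s)))\n",
        "\n",
        "\n",
        "p = TIME_AWARE_INIT\n",
        "t_grid = np.array([1_000, 3_000, 5_000, 10_000])  # illustrative response times (ms)\n",
        "g = guess_prob(t_grid, p['G_max'], p['k_g'], p['tau_g'])\n",
        "s = slip_prob(t_grid, p['S_max'], p['k_s'], p['tau_s'])\n",
        "for t, g_t, s_t in zip(t_grid, g, s):\n",
        "    print(f'{int(t):>6,} ms  P(G|t)={g_t:.3f}  P(S|t)={s_t:.3f}  sum={g_t + s_t:.3f}')\n",
        "# At t = tau_g (3000 ms), P(G|t) = G_max / 2 = 0.200; at t = tau_s (5000 ms), P(S|t) = S_max / 2 = 0.150"
      ]
    },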
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yb51rMPMDJAK"
      },
      "source": [
        "---\n",
        "## 3. Data Loading"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1I1w07MCDJAK"
      },
      "outputs": [],
      "source": [
        "def load_assistments(path: Path) -> pd.DataFrame:\n",
        "    \"\"\"Load the ASSISTments skill-builder dataset from CSV.\n",
        "\n",
        "    Args:\n",
        "        path: Path to the ASSISTments CSV file.\n",
        "\n",
        "    Returns:\n",
        "        Raw DataFrame with all original columns preserved.\n",
        "\n",
        "    Raises:\n",
        "        FileNotFoundError: If the file does not exist at the given path.\n",
        "    \"\"\"\n",
        "    if not path.exists():\n",
        "        raise FileNotFoundError(f'ASSISTments file not found: {path}')\n",
        "    df = pd.read_csv(path, encoding='ISO-8859-1', low_memory=False)\n",
        "    print(f'ASSISTments loaded: {df.shape[0]:,} rows x {df.shape[1]} columns')\n",
        "    print(f'  Students: {df[\"user_id\"].nunique():,} | Skills: {df[\"skill_name\"].nunique()}')\n",
        "    return df\n",
        "\n",
        "\n",
        "def load_ezmath(path: Path) -> pd.DataFrame:\n",
        "    \"\"\"Load the EzMath intervention dataset from CSV.\n",
        "\n",
        "    Args:\n",
        "        path: Path to the EzMath CSV file.\n",
        "\n",
        "    Returns:\n",
        "        Raw DataFrame with columns: student_id, knowledge_component,\n",
        "        is_correct, ms_first_response, opportunity_count.\n",
        "\n",
        "    Raises:\n",
        "        FileNotFoundError: If the file does not exist at the given path.\n",
        "    \"\"\"\n",
        "    if not path.exists():\n",
        "        raise FileNotFoundError(f'EzMath file not found: {path}')\n",
        "    df = pd.read_csv(path)\n",
        "    print(f'EzMath loaded: {df.shape[0]:,} rows x {df.shape[1]} columns')\n",
        "    print(f'  Students: {df[\"student_id\"].nunique():,} | '\n",
        "          f'Opportunities: {df[\"opportunity_count\"].min()}–{df[\"opportunity_count\"].max()}')\n",
        "    return df\n",
        "\n",
        "\n",
        "assistments_raw: pd.DataFrame = load_assistments(ASSISTMENTS_PATH)\n",
        "ezmath_raw:      pd.DataFrame = load_ezmath(EZMATH_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sTzRyQh0DJAL"
      },
      "outputs": [],
      "source": [
        "def verify_target_skill(df: pd.DataFrame, target: str) -> bool:\n",
        "    \"\"\"Verify that the target skill name exists in the ASSISTments dataset.\n",
        "\n",
        "    Prints all fraction-related skill names found, marking the target.\n",
        "\n",
        "    Args:\n",
        "        df:     ASSISTments DataFrame containing a 'skill_name' column.\n",
        "        target: Exact skill name string to search for.\n",
        "\n",
        "    Returns:\n",
        "        True if target skill exists, False otherwise.\n",
        "    \"\"\"\n",
        "    all_skills = df['skill_name'].dropna().unique()\n",
        "    fraction_skills = sorted([s for s in all_skills if 'frac' in str(s).lower()])\n",
        "    print('Fraction-related skills found in ASSISTments:')\n",
        "    for skill in fraction_skills:\n",
        "        count = (df['skill_name'] == skill).sum()\n",
        "        marker = '  <-- TARGET' if skill == target else ''\n",
        "        print(f'  [{count:>7,}]  {skill}{marker}')\n",
        "    found = target in df['skill_name'].values\n",
        "    print(f'\\nTarget skill \"{target}\" found: {found}')\n",
        "    if not found:\n",
        "        print('WARNING: Update TARGET_SKILL in Section 2 to match available names.')\n",
        "    return found\n",
        "\n",
        "\n",
        "_ = verify_target_skill(assistments_raw, TARGET_SKILL)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OBQdcjxQDJAL"
      },
      "outputs": [],
      "source": [
        "def validate_ezmath_structure(df: pd.DataFrame) -> None:\n",
        "    \"\"\"Run structural integrity checks on the EzMath dataset.\n",
        "\n",
        "    Checks for: null values, opportunity range, completeness of\n",
        "    10-opportunity sequences, and chronological ordering.\n",
        "\n",
        "    Args:\n",
        "        df: EzMath DataFrame to validate.\n",
        "    \"\"\"\n",
        "    print('=== EzMath structural validation ===')\n",
        "    null_counts = df.isnull().sum()\n",
        "    if null_counts.any():\n",
        "        print(f'Null values detected:\\n{null_counts[null_counts > 0]}')\n",
        "    else:\n",
        "        print('No null values.')\n",
        "\n",
        "    opp_range = (df['opportunity_count'].min(), df['opportunity_count'].max())\n",
        "    print(f'Opportunity range: {opp_range[0]} to {opp_range[1]}')\n",
        "\n",
        "    n_complete = (df.groupby('student_id')['opportunity_count'].count() == 10).sum()\n",
        "    print(f'Students with exactly 10 opportunities: {n_complete}/{df[\"student_id\"].nunique():,}')\n",
        "\n",
        "    sample_id = df['student_id'].iloc[0]\n",
        "    sample_seq = df[df['student_id'] == sample_id]['opportunity_count'].tolist()\n",
        "    is_ordered = (sample_seq == sorted(sample_seq))\n",
        "    print(f'Temporal ordering check ({sample_id}): {\"PASS\" if is_ordered else \"FAIL — sort required\"}')\n",
        "\n",
        "    print('\\nResponse time (ms_first_response) descriptive statistics:')\n",
        "    display(df['ms_first_response'].describe().round(0).to_frame().T)\n",
        "\n",
        "\n",
        "validate_ezmath_structure(ezmath_raw)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "udNbyACiDJAL"
      },
      "source": [
        "---\n",
        "## 4. Preprocessing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZuZjcHFUDJAL"
      },
      "outputs": [],
      "source": [
        "# ── 4.1  ASSISTments: skill extraction ───────────────────────────\n",
        "\n",
        "def extract_fraction_skill(\n",
        "    df: pd.DataFrame,\n",
        "    exact_name: str,\n",
        "    fallback_keywords: List[str]\n",
        ") -> pd.DataFrame:\n",
        "    \"\"\"Extract records matching the target fraction skill.\n",
        "\n",
        "    Attempts exact match first; falls back to keyword search if no\n",
        "    exact match is found. Validates that error rates across retained\n",
        "    skills are comparable (diff <= 0.15).\n",
        "\n",
        "    Args:\n",
        "        df:                ASSISTments raw DataFrame.\n",
        "        exact_name:        Exact skill_name string to match.\n",
        "        fallback_keywords: Lowercase substrings for fallback matching.\n",
        "\n",
        "    Returns:\n",
        "        Filtered DataFrame containing only fraction-skill records.\n",
        "    \"\"\"\n",
        "    exact_mask = df['skill_name'] == exact_name\n",
        "    if exact_mask.sum() > 0:\n",
        "        result = df[exact_mask].copy()\n",
        "        print(f'Exact match \"{exact_name}\": {len(result):,} records')\n",
        "    else:\n",
        "        def _is_fraction(name: str) -> bool:\n",
        "            if pd.isna(name):\n",
        "                return False\n",
        "            return any(kw in str(name).lower() for kw in fallback_keywords)\n",
        "        result = df[df['skill_name'].apply(_is_fraction)].copy()\n",
        "        print(f'WARNING: Exact match not found. Keyword fallback: {len(result):,} records')\n",
        "\n",
        "    # Validate difficulty homogeneity across retained skills\n",
        "    difficulty_by_skill = (\n",
        "        result.groupby('skill_name')['correct']\n",
        "        .agg(n='count', error_rate=lambda x: round(1 - x.mean(), 4))\n",
        "        .reset_index()\n",
        "        .sort_values('error_rate')\n",
        "    )\n",
        "    diff_range = difficulty_by_skill['error_rate'].max() - difficulty_by_skill['error_rate'].min()\n",
        "    status = 'PASS' if diff_range <= 0.15 else 'WARN (diff > 0.15)'\n",
        "    print(f'Skill difficulty range: {diff_range:.3f} [{status}]')\n",
        "    display(difficulty_by_skill)\n",
        "    return result\n",
        "\n",
        "\n",
        "assistments_fraction: pd.DataFrame = extract_fraction_skill(\n",
        "    assistments_raw, TARGET_SKILL, FALLBACK_KEYWORDS\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J6tbM2F-DJAM"
      },
      "outputs": [],
      "source": [
        "# ── 4.2  ASSISTments: filter original attempts only ──────────────\n",
        "\n",
        "def filter_original_attempts(df: pd.DataFrame) -> pd.DataFrame:\n",
        "    \"\"\"Retain only original problem attempts, excluding scaffolding hints.\n",
        "\n",
        "    Rows with original == 0 represent hint-assisted sub-problems that\n",
        "    distort response time distributions and inflate accuracy.\n",
        "\n",
        "    Args:\n",
        "        df: ASSISTments DataFrame with an 'original' column.\n",
        "\n",
        "    Returns:\n",
        "        Filtered DataFrame with original == 1 only.\n",
        "    \"\"\"\n",
        "    filtered = df[df['original'] == 1].copy()\n",
        "    removed = len(df) - len(filtered)\n",
        "    print(f'Original-attempt filter: {len(df):,} → {len(filtered):,} '\n",
        "          f'(removed {removed:,} scaffolding rows, {removed/len(df)*100:.1f}%)')\n",
        "    return filtered\n",
        "\n",
        "\n",
        "assistments_original: pd.DataFrame = filter_original_attempts(assistments_fraction)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "poCbRuCiDJAM"
      },
      "outputs": [],
      "source": [
        "# ── 4.3  ASSISTments: IQR-based response time normalization ──────\n",
        "\n",
        "def normalize_response_time(\n",
        "    df: pd.DataFrame,\n",
        "    rt_column: str = 'ms_first_response',\n",
        "    iqr_multiplier: float = 3.0,\n",
        "    hard_min_ms: float = 500.0,\n",
        "    hard_max_ms: float = 120_000.0\n",
        ") -> pd.DataFrame:\n",
        "    \"\"\"Remove response time outliers using IQR-based bounds.\n",
        "\n",
        "    Bounds are computed as [Q1 - k*IQR, Q3 + k*IQR], clipped to\n",
        "    [hard_min_ms, hard_max_ms] to handle extreme values.\n",
        "\n",
        "    Args:\n",
        "        df:              Input DataFrame.\n",
        "        rt_column:       Name of the response time column.\n",
        "        iqr_multiplier:  Multiplier k for IQR fence. Default 3.0.\n",
        "        hard_min_ms:     Absolute lower bound in milliseconds.\n",
        "        hard_max_ms:     Absolute upper bound in milliseconds.\n",
        "\n",
        "    Returns:\n",
        "        DataFrame with outlier response times removed.\n",
        "    \"\"\"\n",
        "    rt_series = df[rt_column].dropna()\n",
        "    q1, q3 = rt_series.quantile(0.25), rt_series.quantile(0.75)\n",
        "    iqr = q3 - q1\n",
        "    rt_low  = max(hard_min_ms, q1 - iqr_multiplier * iqr)\n",
        "    rt_high = min(hard_max_ms, q3 + iqr_multiplier * iqr)\n",
        "\n",
        "    mask = df[rt_column].between(rt_low, rt_high)\n",
        "    filtered = df[mask].copy()\n",
        "    print(f'RT normalization: bounds=[{rt_low:.0f}ms, {rt_high:.0f}ms] '\n",
        "          f'(Q1={q1:.0f}, Q3={q3:.0f}, IQR={iqr:.0f})')\n",
        "    print(f'  Records: {len(df):,} → {len(filtered):,} '\n",
        "          f'(removed {len(df)-len(filtered):,} outliers)')\n",
        "    return filtered\n",
        "\n",
        "\n",
        "assistments_rt: pd.DataFrame = normalize_response_time(assistments_original)"
      ]
    },
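    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A toy series shows how the IQR fences interact with the hard bounds. The seven response times below are made up for illustration and are not drawn from either dataset. With k = 3, the lower IQR fence is negative, so the 500 ms hard floor is what removes the implausibly fast response. The upper fence (Q3 + 3·IQR = 8750 ms) removes the 250 s response well before the 120 s hard cap would apply."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Toy data (made up): one implausibly fast and one implausibly slow response\n",
        "toy_rt = pd.DataFrame({'ms_first_response': [200, 1_500, 2_000, 2_500, 3_000, 4_000, 250_000]})\n",
        "toy_filtered = normalize_response_time(toy_rt)\n",
        "# Expected: Q1=1750, Q3=3500, IQR=1750 -> bounds=[500ms, 8750ms]; 7 -> 5 records\n",
        "assert toy_filtered['ms_first_response'].tolist() == [1_500, 2_000, 2_500, 3_000, 4_000]"
      ]
    },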
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wnbIRuBxDJAM"
      },
      "outputs": [],
      "source": [
        "# ── 4.4  ASSISTments: sequence length filter ─────────────────────\n",
        "\n",
        "def filter_by_sequence_length(\n",
        "    df: pd.DataFrame,\n",
        "    student_col: str,\n",
        "    opportunity_col: str,\n",
        "    min_opportunities: int,\n",
        "    max_opportunities: int\n",
        ") -> pd.DataFrame:\n",
        "    \"\"\"Retain only students with sufficient practice history.\n",
        "\n",
        "    Students with fewer than min_opportunities interactions are excluded\n",
        "    (survivorship bias noted). Sequences are truncated at max_opportunities\n",
        "    to match the fixed-length EzMath structure.\n",
        "\n",
        "    Args:\n",
        "        df:                Input DataFrame.\n",
        "        student_col:       Column name for student identifiers.\n",
        "        opportunity_col:   Column name for opportunity counts.\n",
        "        min_opportunities: Minimum interactions required (inclusive).\n",
        "        max_opportunities: Maximum opportunity index to retain.\n",
        "\n",
        "    Returns:\n",
        "        Filtered and truncated DataFrame.\n",
        "\n",
        "    Note:\n",
        "        Survivorship bias: students who achieved mastery before\n",
        "        min_opportunities are excluded, making this baseline\n",
        "        representative of persistent-difficulty learners only.\n",
        "    \"\"\"\n",
        "    max_opp_per_student = df.groupby(student_col)[opportunity_col].max()\n",
        "    eligible_students   = max_opp_per_student[max_opp_per_student >= min_opportunities].index\n",
        "\n",
        "    n_total   = max_opp_per_student.shape[0]\n",
        "    n_eligible = len(eligible_students)\n",
        "    n_excluded = n_total - n_eligible\n",
        "\n",
        "    filtered = df[\n",
        "        df[student_col].isin(eligible_students) &\n",
        "        (df[opportunity_col] <= max_opportunities)\n",
        "    ].copy()\n",
        "\n",
        "    print(f'Sequence filter (≥{min_opportunities} opportunities, truncated at {max_opportunities}):')\n",
        "    print(f'  Eligible students: {n_eligible:,} / {n_total:,} '\n",
        "          f'(excluded {n_excluded:,} early-mastery students)')\n",
        "    print(f'  Records retained: {len(filtered):,}')\n",
        "    print(f'  SURVIVORSHIP BIAS: {n_excluded:,} students excluded — '\n",
        "          f'acknowledge in paper Limitations.')\n",
        "    return filtered\n",
        "\n",
        "\n",
        "assistments_seq: pd.DataFrame = filter_by_sequence_length(\n",
        "    assistments_rt,\n",
        "    student_col='user_id',\n",
        "    opportunity_col='opportunity',\n",
        "    min_opportunities=MIN_OPPORTUNITIES,\n",
        "    max_opportunities=MAX_OPPORTUNITIES\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MsOPei3jDJAM"
      },
      "outputs": [],
      "source": [
        "# ── 4.5  Standardize schemas and assign group labels ─────────────\n",
        "\n",
        "def standardize_assistments(df: pd.DataFrame, opp_col: str = 'opportunity') -> pd.DataFrame:\n",
        "    \"\"\"Rename ASSISTments columns to the shared canonical schema.\n",
        "\n",
        "    Canonical schema: student_id, knowledge_component, is_correct,\n",
        "    ms_first_response, opportunity_count, group, dataset.\n",
        "\n",
        "    Args:\n",
        "        df:      ASSISTments DataFrame after all filtering steps.\n",
        "        opp_col: Name of the opportunity column ('opportunity' or\n",
        "                 'opportunity_original').\n",
        "\n",
        "    Returns:\n",
        "        Standardized DataFrame sorted by student and opportunity.\n",
        "    \"\"\"\n",
        "    result = df[['user_id', 'skill_name', 'correct', 'ms_first_response', opp_col]].copy()\n",
        "    result.columns = ['student_id', 'knowledge_component', 'is_correct',\n",
        "                      'ms_first_response', 'opportunity_count']\n",
        "    result['student_id']          = 'ASST_' + result['student_id'].astype(str)\n",
        "    result['knowledge_component'] = 'Addition_Subtraction_Fractions'\n",
        "    result['group']               = 'control'\n",
        "    result['dataset']             = 'ASSISTments'\n",
        "    result = result.sort_values(['student_id', 'opportunity_count']).reset_index(drop=True)\n",
        "    return result\n",
        "\n",
        "\n",
        "def standardize_ezmath(df: pd.DataFrame) -> pd.DataFrame:\n",
        "    \"\"\"Standardize EzMath columns to the shared canonical schema.\n",
        "\n",
        "    Preserves p_error_modelled if present for reference comparison.\n",
        "\n",
        "    Args:\n",
        "        df: EzMath raw DataFrame.\n",
        "\n",
        "    Returns:\n",
        "        Standardized DataFrame sorted by student and opportunity.\n",
        "    \"\"\"\n",
        "    cols = ['student_id', 'knowledge_component', 'is_correct',\n",
        "            'ms_first_response', 'opportunity_count']\n",
        "    result = df[cols].copy()\n",
        "    if 'p_error_modelled' in df.columns:\n",
        "        result['p_error_modelled'] = df['p_error_modelled'].values\n",
        "    result['group']   = 'intervention'\n",
        "    result['dataset'] = 'EzMath'\n",
        "    result = result.sort_values(['student_id', 'opportunity_count']).reset_index(drop=True)\n",
        "    return result\n",
        "\n",
        "\n",
        "assistments_clean: pd.DataFrame = standardize_assistments(assistments_seq)\n",
        "ezmath_clean:      pd.DataFrame = standardize_ezmath(ezmath_raw)\n",
        "\n",
        "print('=== Preprocessing summary ===')\n",
        "print(f'ASSISTments (control):    {assistments_clean[\"student_id\"].nunique():>6,} students | '\n",
        "      f'{len(assistments_clean):>7,} records')\n",
        "print(f'EzMath (intervention):    {ezmath_clean[\"student_id\"].nunique():>6,} students | '\n",
        "      f'{len(ezmath_clean):>7,} records')\n",
        "display(assistments_clean.describe())\n",
        "display(ezmath_clean.describe())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kjxl8St4DJAM"
      },
      "source": [
        "---\n",
        "## 5. Model Definition"
      ]
    },
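    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One step of the update in `StandardBKT._bayesian_update` can be worked by hand. The prediction and the two possible posteriors are\n",
        "\n",
        "$$P(\\text{correct}) = P(L_{n-1})\\,(1 - P(S)) + (1 - P(L_{n-1}))\\,P(G)$$\n",
        "\n",
        "$$P(L_{n-1} \\mid \\text{correct}) = \\frac{P(L_{n-1})\\,(1 - P(S))}{P(\\text{correct})}, \\qquad P(L_{n-1} \\mid \\text{incorrect}) = \\frac{P(L_{n-1})\\,P(S)}{1 - P(\\text{correct})}$$\n",
        "\n",
        "These are followed by the learning transition $P(L_n) = P(L_{n-1} \\mid \\text{obs}) + (1 - P(L_{n-1} \\mid \\text{obs}))\\,P(T)$. The next cell evaluates one step with the `BKT_INIT` values from Section 2. It is a sanity check only; the EM-fitted parameters will differ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# One hand-worked BKT step with the initial (pre-EM) parameters\n",
        "p_L, p_T, p_G, p_S = BKT_INIT['prior'], BKT_INIT['learn'], BKT_INIT['guess'], BKT_INIT['slip']\n",
        "p_correct = p_L * (1 - p_S) + (1 - p_L) * p_G\n",
        "post_if_correct = p_L * (1 - p_S) / p_correct\n",
        "post_if_wrong   = p_L * p_S / (1 - p_correct)\n",
        "next_if_correct = post_if_correct + (1 - post_if_correct) * p_T\n",
        "next_if_wrong   = post_if_wrong + (1 - post_if_wrong) * p_T\n",
        "print(f'P(correct)             = {p_correct:.4f}')        # 0.4450\n",
        "print(f'P(L_1) after correct   = {next_if_correct:.4f}')  # 0.6461\n",
        "print(f'P(L_1) after incorrect = {next_if_wrong:.4f}')    # 0.1486"
      ]
    },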
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hFlCOZTxDJAM"
      },
      "outputs": [],
      "source": [
        "class StandardBKT:\n",
        "    \"\"\"Standard Bayesian Knowledge Tracing with fixed emission probabilities.\n",
        "\n",
        "    Models student knowledge as a two-state Hidden Markov Model (HMM).\n",
        "    Parameters are estimated from data via Expectation-Maximization.\n",
        "\n",
        "    Attributes:\n",
        "        prior: P(L_0) — probability of knowing skill before first attempt.\n",
        "        learn: P(T)   — probability of transitioning from unknown to known.\n",
        "        guess: P(G)   — probability of correct response when skill unknown.\n",
        "        slip:  P(S)   — probability of incorrect response when skill known.\n",
        "\n",
        "    Reference:\n",
        "        Corbett & Anderson (1994). Knowledge tracing: Modeling the\n",
        "        acquisition of procedural knowledge. User Modeling, 4(4), 253–278.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        prior: float = 0.3,\n",
        "        learn: float = 0.1,\n",
        "        guess: float = 0.25,\n",
        "        slip:  float = 0.1\n",
        "    ) -> None:\n",
        "        assert 0 < guess + slip < 1, (\n",
        "            f'Constraint violated: P(G) + P(S) = {guess + slip:.3f} must be in (0, 1)'\n",
        "        )\n",
        "        self.prior = prior\n",
        "        self.learn = learn\n",
        "        self.guess = guess\n",
        "        self.slip  = slip\n",
        "\n",
        "    def _bayesian_update(\n",
        "        self,\n",
        "        p_known: float,\n",
        "        is_correct: int\n",
        "    ) -> float:\n",
        "        \"\"\"Apply Bayes' theorem to update P(L_n) after one observation.\n",
        "\n",
        "        Args:\n",
        "            p_known:    Current P(L_{n-1}) before observation.\n",
        "            is_correct: Observed outcome (1 = correct, 0 = incorrect).\n",
        "\n",
        "        Returns:\n",
        "            Updated P(L_n) incorporating the transition probability.\n",
        "        \"\"\"\n",
        "        g, s = self.guess, self.slip\n",
        "        if is_correct:\n",
        "            numerator   = p_known * (1.0 - s)\n",
        "            denominator = numerator + (1.0 - p_known) * g\n",
        "        else:\n",
        "            numerator   = p_known * s\n",
        "            denominator = numerator + (1.0 - p_known) * (1.0 - g)\n",
        "        p_posterior = numerator / (denominator + 1e-10)\n",
        "        return p_posterior + (1.0 - p_posterior) * self.learn\n",
        "\n",
        "    def predict_sequence(\n",
        "        self,\n",
        "        correctness_sequence: List[int]\n",
        "    ) -> List[float]:\n",
        "        \"\"\"Generate predicted P(correct) for each opportunity in a sequence.\n",
        "\n",
        "        Prediction is made before updating on each observation (predict\n",
        "        then update, matching standard BKT evaluation protocol).\n",
        "\n",
        "        Args:\n",
        "            correctness_sequence: Ordered list of binary outcomes (0/1).\n",
        "\n",
        "        Returns:\n",
        "            List of predicted P(correct) values, same length as input.\n",
        "        \"\"\"\n",
        "        p_known = self.prior\n",
        "        predictions: List[float] = []\n",
        "        for outcome in correctness_sequence:\n",
        "            p_correct = p_known * (1.0 - self.slip) + (1.0 - p_known) * self.guess\n",
        "            predictions.append(p_correct)\n",
        "            p_known = self._bayesian_update(p_known, outcome)\n",
        "        return predictions\n",
        "\n",
        "    def fit_em(\n",
        "        self,\n",
        "        training_sequences: List[List[int]],\n",
        "        n_iterations: int = 100,\n",
        "        convergence_tol: float = 1e-5\n",
        "    ) -> 'StandardBKT':\n",
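        "        \"\"\"Estimate BKT parameters using Expectation-Maximization (Baum-Welch).\n",
        "\n",
        "        E-step: a scaled forward-backward pass over every training sequence\n",
        "        yields posterior state probabilities and expected unknown-to-known\n",
        "        transitions. M-step: P(L_0), P(T), P(G) and P(S) are re-estimated\n",
        "        from the accumulated expected counts. As in standard BKT, forgetting\n",
        "        is assumed to be zero. Iterates until convergence or until the\n",
        "        maximum number of iterations is reached.\n",
        "\n",
        "        Args:\n",
        "            training_sequences: List of correctness sequences for training.\n",
        "            n_iterations:       Maximum EM iterations.\n",
        "            convergence_tol:    Max parameter delta for convergence.\n",
        "\n",
        "        Returns:\n",
        "            Self (fitted model) for method chaining.\n",
        "        \"\"\"\n",
        "        prior, learn, guess, slip = self.prior, self.learn, self.guess, self.slip\n",
        "\n",
        "        for _ in range(n_iterations):\n",
        "            # Hidden states: index 0 = unknown, index 1 = known (no forgetting)\n",
        "            transition = np.array([[1.0 - learn, learn], [0.0, 1.0]])\n",
        "            prior_num = n_sequences = 0.0\n",
        "            learn_num = learn_den = 0.0\n",
        "            guess_num = unknown_total = 0.0\n",
        "            slip_num = known_total = 0.0\n",
        "\n",
        "            # E-step: scaled forward-backward pass per sequence\n",
        "            for seq in training_sequences:\n",
        "                if len(seq) == 0:\n",
        "                    continue\n",
        "                obs = np.asarray(seq, dtype=float)\n",
        "                n_steps = len(obs)\n",
        "                # emission[t] = [P(o_t | unknown), P(o_t | known)]\n",
        "                emission = np.column_stack([\n",
        "                    np.where(obs == 1, guess, 1.0 - guess),\n",
        "                    np.where(obs == 1, 1.0 - slip, slip),\n",
        "                ])\n",
        "\n",
        "                # Forward pass with per-step scaling to avoid underflow\n",
        "                alpha = np.zeros((n_steps, 2))\n",
        "                scale = np.zeros(n_steps)\n",
        "                alpha[0] = np.array([1.0 - prior, prior]) * emission[0]\n",
        "                scale[0] = alpha[0].sum()\n",
        "                alpha[0] /= scale[0]\n",
        "                for t in range(1, n_steps):\n",
        "                    alpha[t] = (alpha[t - 1] @ transition) * emission[t]\n",
        "                    scale[t] = alpha[t].sum()\n",
        "                    alpha[t] /= scale[t]\n",
        "\n",
        "                # Backward pass using the same scaling factors\n",
        "                beta = np.ones((n_steps, 2))\n",
        "                for t in range(n_steps - 2, -1, -1):\n",
        "                    beta[t] = transition @ (emission[t + 1] * beta[t + 1]) / scale[t + 1]\n",
        "\n",
        "                # gamma[t, s] = P(state_t = s | whole sequence)\n",
        "                gamma = alpha * beta\n",
        "                gamma /= gamma.sum(axis=1, keepdims=True)\n",
        "\n",
        "                # xi[i, j] = P(state_t = i, state_{t+1} = j | whole sequence)\n",
        "                for t in range(n_steps - 1):\n",
        "                    xi = (alpha[t][:, None] * transition\n",
        "                          * (emission[t + 1] * beta[t + 1])[None, :])\n",
        "                    xi /= xi.sum()\n",
        "                    learn_num += xi[0, 1]\n",
        "                    learn_den += xi[0].sum()\n",
        "\n",
        "                prior_num     += gamma[0, 1]\n",
        "                n_sequences   += 1\n",
        "                guess_num     += float((gamma[:, 0] * obs).sum())\n",
        "                unknown_total += float(gamma[:, 0].sum())\n",
        "                slip_num      += float((gamma[:, 1] * (1.0 - obs)).sum())\n",
        "                known_total   += float(gamma[:, 1].sum())\n",
        "\n",
        "            # M-step: re-estimate parameters from expected counts\n",
        "            new_prior = np.clip(prior_num / (n_sequences + 1e-10), 0.01, 0.99)\n",
        "            new_learn = np.clip(learn_num / (learn_den + 1e-10), 0.01, 0.99)\n",
        "            new_guess = np.clip(guess_num / (unknown_total + 1e-10), 0.01, 0.49)\n",
        "            new_slip  = np.clip(\n",
        "                slip_num / (known_total + 1e-10),\n",
        "                0.01, min(0.49, 0.98 - new_guess))\n",
        "\n",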
        "            delta = max(\n",
        "                abs(new_prior - prior), abs(new_learn - learn),\n",
        "                abs(new_guess - guess), abs(new_slip  - slip)\n",
        "            )\n",
        "            prior, learn, guess, slip = new_prior, new_learn, new_guess, new_slip\n",
        "            self.prior, self.learn = prior, learn\n",
        "            self.guess, self.slip  = guess, slip\n",
        "\n",
        "            if delta < convergence_tol:\n",
        "                break\n",
        "        return self\n",
        "\n",
        "\n",
        "print('StandardBKT class defined.')"
      ]
    },
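    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The predict-then-update recursion in `StandardBKT.predict_sequence` can be checked by hand on a short sequence. The sketch below is a standalone re-implementation, assuming `_bayesian_update` follows the standard BKT form (posterior by Bayes' rule, then the learning transition P(T)); `bkt_predict` is a hypothetical helper name and the parameter values are illustrative, not fitted.\n",
        "\n",
        "```python\n",
        "from typing import List\n",
        "\n",
        "\n",
        "def bkt_predict(seq: List[int], prior: float, learn: float,\n",
        "                guess: float, slip: float) -> List[float]:\n",
        "    # Predict P(correct) from the current P(L_n), then update on the outcome\n",
        "    p_known = prior\n",
        "    preds = []\n",
        "    for outcome in seq:\n",
        "        p_correct = p_known * (1 - slip) + (1 - p_known) * guess\n",
        "        preds.append(p_correct)\n",
        "        # Posterior P(L_n | outcome) by Bayes' rule\n",
        "        if outcome:\n",
        "            posterior = p_known * (1 - slip) / p_correct\n",
        "        else:\n",
        "            posterior = p_known * slip / (1 - p_correct)\n",
        "        # Learning transition: an unknown skill becomes known with probability learn\n",
        "        p_known = posterior + (1 - posterior) * learn\n",
        "    return preds\n",
        "\n",
        "\n",
        "preds = bkt_predict([0, 1, 1], prior=0.3, learn=0.1, guess=0.2, slip=0.1)\n",
        "print([round(p, 4) for p in preds])  # [0.41, 0.302, 0.5436]\n",
        "```\n",
        "\n",
        "The first prediction depends only on P(L_0): 0.3 × 0.9 + 0.7 × 0.2 = 0.41. The incorrect first answer lowers P(L_n) before the second prediction, and the correct second answer raises it again."
      ]
    },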
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FUJNfUTMDJAN"
      },
      "outputs": [],
      "source": [
        "class TimeAwareBKT(StandardBKT):\n",
        "    \"\"\"Time-Aware BKT with logistic, response-time-conditioned emission probabilities.\n",
        "\n",
        "    Replaces static P(G) and P(S) with response-time-conditioned functions:\n",
        "\n",
        "        P(G|t) = G_max / (1 + exp( k_g * (t - tau_g)))   [decreasing]\n",
        "        P(S|t) = S_max / (1 + exp( k_s * (t - tau_s)))   [increasing, k_s < 0]\n",
        "\n",
        "    Cognitive rationale:\n",
        "        - Short response time (impulsive): P(G) high (guessing likely),\n",
        "          P(S) low.\n",
        "        - Long response time (effortful): P(G) decreases (guessing less\n",
        "          likely), P(S) increases (fatigue-induced slip more likely).\n",
        "\n",
        "    Constraint enforced at runtime: P(G|t) + P(S|t) < 1 for all t.\n",
        "\n",
        "    Attributes:\n",
        "        prior:  P(L_0), inherited from StandardBKT.\n",
        "        learn:  P(T), inherited from StandardBKT.\n",
        "        G_max:  Maximum asymptote for P(G|t).\n",
        "        S_max:  Maximum asymptote for P(S|t).\n",
        "        k_g:    Decay rate for P(G|t); must be > 0.\n",
        "        k_s:    Growth rate for P(S|t); must be < 0.\n",
        "        tau_g:  Inflection point of P(G|t) in milliseconds.\n",
        "        tau_s:  Inflection point of P(S|t) in milliseconds.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        prior: float = 0.3,\n",
        "        learn: float = 0.1,\n",
        "        G_max:  float = 0.4,\n",
        "        S_max:  float = 0.3,\n",
        "        k_g:    float = 0.001,\n",
        "        k_s:    float = -0.0005,\n",
        "        tau_g:  float = 3000.0,\n",
        "        tau_s:  float = 5000.0\n",
        "    ) -> None:\n",
        "        self.prior = prior\n",
        "        self.learn = learn\n",
        "        self.G_max, self.S_max = G_max, S_max\n",
        "        self.k_g,   self.k_s   = k_g,   k_s\n",
        "        self.tau_g, self.tau_s = tau_g, tau_s\n",
        "\n",
        "    def _guess_probability(self, response_time_ms: float) -> float:\n",
        "        \"\"\"Compute time-conditioned guess probability P(G|t).\n",
        "\n",
        "        Args:\n",
        "            response_time_ms: Response time in milliseconds.\n",
        "\n",
        "        Returns:\n",
        "            P(G|t) in (0, G_max).\n",
        "        \"\"\"\n",
        "        return self.G_max / (1.0 + np.exp(self.k_g * (response_time_ms - self.tau_g)))\n",
        "\n",
        "    def _slip_probability(self, response_time_ms: float) -> float:\n",
        "        \"\"\"Compute time-conditioned slip probability P(S|t).\n",
        "\n",
        "        Clips output to enforce P(G|t) + P(S|t) < 1.\n",
        "\n",
        "        Args:\n",
        "            response_time_ms: Response time in milliseconds.\n",
        "\n",
        "        Returns:\n",
        "            P(S|t) clipped to [0.01, max(0.01, 0.99 - P(G|t))].\n",
        "        \"\"\"\n",
        "        p_slip = self.S_max / (1.0 + np.exp(self.k_s * (response_time_ms - self.tau_s)))\n",
        "        p_guess = self._guess_probability(response_time_ms)\n",
        "        return np.clip(p_slip, 0.01, max(0.01, 0.99 - p_guess))\n",
        "\n",
        "    def _bayesian_update(\n",
        "        self,\n",
        "        p_known: float,\n",
        "        is_correct: int,\n",
        "        response_time_ms: float = 0.0\n",
        "    ) -> float:\n",
        "        \"\"\"Time-conditioned Bayesian update of P(L_n).\n",
        "\n",
        "        Overrides parent method to use dynamic P(G|t) and P(S|t).\n",
        "\n",
        "        Args:\n",
        "            p_known:          Current P(L_{n-1}).\n",
        "            is_correct:       Observed outcome (1/0).\n",
        "            response_time_ms: Response time for this interaction.\n",
        "\n",
        "        Returns:\n",
        "            Updated P(L_n).\n",
        "        \"\"\"\n",
        "        g = self._guess_probability(response_time_ms)\n",
        "        s = self._slip_probability(response_time_ms)\n",
        "        if is_correct:\n",
        "            numerator   = p_known * (1.0 - s)\n",
        "            denominator = numerator + (1.0 - p_known) * g\n",
        "        else:\n",
        "            numerator   = p_known * s\n",
        "            denominator = numerator + (1.0 - p_known) * (1.0 - g)\n",
        "        p_posterior = numerator / (denominator + 1e-10)\n",
        "        return p_posterior + (1.0 - p_posterior) * self.learn\n",
        "\n",
        "    def predict_sequence(\n",
        "        self,\n",
        "        correctness_sequence: List[int],\n",
        "        time_sequence:        List[float]\n",
        "    ) -> List[float]:\n",
        "        \"\"\"Generate predicted P(correct) using time-conditioned emission probs.\n",
        "\n",
        "        Args:\n",
        "            correctness_sequence: Ordered binary outcomes (0/1).\n",
        "            time_sequence:        Response times in ms, same length.\n",
        "\n",
        "        Returns:\n",
        "            Predicted P(correct) for each opportunity.\n",
        "        \"\"\"\n",
        "        p_known = self.prior\n",
        "        predictions: List[float] = []\n",
        "        for outcome, rt in zip(correctness_sequence, time_sequence):\n",
        "            g = self._guess_probability(rt)\n",
        "            s = self._slip_probability(rt)\n",
        "            predictions.append(p_known * (1.0 - s) + (1.0 - p_known) * g)\n",
        "            p_known = self._bayesian_update(p_known, outcome, rt)\n",
        "        return predictions\n",
        "\n",
        "    def tune_time_parameters(\n",
        "        self,\n",
        "        validation_sequences: List[Tuple[List[int], List[float]]]\n",
        "    ) -> 'TimeAwareBKT':\n",
        "        \"\"\"Optimize time-aware parameters to minimize RMSE on validation set.\n",
        "\n",
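        "        Uses Nelder-Mead simplex optimization over the six time parameters;\n",
        "        P(L_0) and P(T) are held fixed. Infeasible points (G_max <= 0,\n",
        "        S_max <= 0, G_max + S_max >= 1, k_g <= 0 or k_s >= 0) are rejected\n",
        "        by returning a penalty objective of 1e6.\n",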
        "\n",
        "        Args:\n",
        "            validation_sequences: List of (correctness_seq, time_seq) tuples.\n",
        "\n",
        "        Returns:\n",
        "            Self (fitted model) for method chaining.\n",
        "        \"\"\"\n",
        "        def _objective(params: np.ndarray) -> float:\n",
        "            G_max, S_max, k_g, k_s, tau_g, tau_s = params\n",
        "            if G_max <= 0 or S_max <= 0 or G_max + S_max >= 1:\n",
        "                return 1e6\n",
        "            if k_g <= 0 or k_s >= 0:\n",
        "                return 1e6\n",
        "            self.G_max, self.S_max = G_max, S_max\n",
        "            self.k_g,   self.k_s   = k_g,   k_s\n",
        "            self.tau_g, self.tau_s = tau_g, tau_s\n",
        "            squared_errors = [\n",
        "                (y - yhat) ** 2\n",
        "                for cs, ts in validation_sequences\n",
        "                for y, yhat in zip(cs, self.predict_sequence(cs, ts))\n",
        "            ]\n",
        "            return float(np.sqrt(np.mean(squared_errors)))\n",
        "\n",
        "        x0 = [self.G_max, self.S_max, self.k_g, self.k_s, self.tau_g, self.tau_s]\n",
        "        result = optimize.minimize(\n",
        "            _objective, x0, method='Nelder-Mead',\n",
        "            options={'maxiter': 800, 'xatol': 1e-5, 'fatol': 1e-5}\n",
        "        )\n",
        "        G_max, S_max, k_g, k_s, tau_g, tau_s = result.x\n",
        "        self.G_max, self.S_max = G_max, S_max\n",
        "        self.k_g,   self.k_s   = k_g,   k_s\n",
        "        self.tau_g, self.tau_s = tau_g, tau_s\n",
        "        return self\n",
        "\n",
        "\n",
        "print('TimeAwareBKT class defined.')"
      ]
    },
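    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the shape of the two logistic emission functions, the sketch below evaluates them at a few response times using the default initial values from `TimeAwareBKT.__init__` (G_max = 0.4, S_max = 0.3, k_g = 0.001, k_s = -0.0005, tau_g = 3000 ms, tau_s = 5000 ms). It is a standalone re-implementation for illustration only; `guess_prob` and `slip_prob` are hypothetical helper names, and the slip clip mirrors `_slip_probability`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "G_MAX, S_MAX = 0.4, 0.3\n",
        "K_G, K_S = 0.001, -0.0005\n",
        "TAU_G, TAU_S = 3000.0, 5000.0\n",
        "\n",
        "\n",
        "def guess_prob(rt_ms):\n",
        "    # Decreasing logistic (K_G > 0): fast responses get a high P(G|t)\n",
        "    return G_MAX / (1.0 + np.exp(K_G * (rt_ms - TAU_G)))\n",
        "\n",
        "\n",
        "def slip_prob(rt_ms):\n",
        "    # Increasing logistic (K_S < 0), clipped so that P(G|t) + P(S|t) <= 0.99\n",
        "    p_slip = S_MAX / (1.0 + np.exp(K_S * (rt_ms - TAU_S)))\n",
        "    return np.clip(p_slip, 0.01, np.maximum(0.01, 0.99 - guess_prob(rt_ms)))\n",
        "\n",
        "\n",
        "times = np.array([500.0, 3000.0, 5000.0, 10000.0])\n",
        "for rt, g, s in zip(times, guess_prob(times), slip_prob(times)):\n",
        "    print(f'{rt:>7.0f} ms  P(G|t)={g:.3f}  P(S|t)={s:.3f}')\n",
        "```\n",
        "\n",
        "Fast responses (500 ms) give a high guess probability and a low slip probability; at 10 s the pattern reverses, as described in the class docstring. Each curve sits at half its asymptote at its inflection point: P(G | 3000 ms) = 0.2 and P(S | 5000 ms) = 0.15."
      ]
    },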
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xEuy_lzpDJAN"
      },
      "source": [
        "---\n",
        "## 6. Training: 5-Fold Cross-Validation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GKmvEnzkDJAN"
      },
      "outputs": [],
      "source": [
        "# ── Utility functions for cross-validation ────────────────────────\n",
        "\n",
        "def build_student_sequences(\n",
        "    df: pd.DataFrame,\n",
        "    include_response_times: bool = False\n",
        ") -> Dict[str, Tuple[List[int], List[float]]]:\n",
        "    \"\"\"Convert a DataFrame into per-student (correctness, time) sequences.\n",
        "\n",
        "    Sequences are sorted chronologically by opportunity_count to\n",
        "    preserve temporal ordering (no shuffling within student).\n",
        "\n",
        "    Args:\n",
        "        df:                     DataFrame with columns: student_id,\n",
        "                                is_correct, ms_first_response,\n",
        "                                opportunity_count.\n",
        "        include_response_times: If True, include ms_first_response.\n",
        "                                If False, fill time list with zeros.\n",
        "\n",
        "    Returns:\n",
        "        Dict mapping student_id → (correctness_list, time_list).\n",
        "    \"\"\"\n",
        "    sequences: Dict[str, Tuple[List[int], List[float]]] = {}\n",
        "    for student_id, group in df.groupby('student_id'):\n",
        "        group_sorted  = group.sort_values('opportunity_count')\n",
        "        correctness   = group_sorted['is_correct'].tolist()\n",
        "        response_times = (\n",
        "            group_sorted['ms_first_response'].tolist()\n",
        "            if include_response_times\n",
        "            else [0.0] * len(correctness)\n",
        "        )\n",
        "        sequences[student_id] = (correctness, response_times)\n",
        "    return sequences\n",
        "\n",
        "\n",
        "def compute_fold_metrics(\n",
        "    model: StandardBKT,\n",
        "    test_sequences: Dict[str, Tuple[List[int], List[float]]],\n",
        "    is_time_aware: bool\n",
        ") -> Tuple[float, float]:\n",
        "    \"\"\"Compute RMSE and AUC-ROC on a test fold.\n",
        "\n",
        "    Args:\n",
        "        model:          Fitted BKT model instance.\n",
        "        test_sequences: Dict of student sequences for evaluation.\n",
        "        is_time_aware:  Whether to pass time sequences to the model.\n",
        "\n",
        "    Returns:\n",
        "        Tuple (rmse, auc). AUC is NaN if only one class present.\n",
        "    \"\"\"\n",
        "    y_true: List[int]   = []\n",
        "    y_pred: List[float] = []\n",
        "    for _, (correctness, times) in test_sequences.items():\n",
        "        if is_time_aware:\n",
        "            predictions = model.predict_sequence(correctness, times)\n",
        "        else:\n",
        "            predictions = model.predict_sequence(correctness)\n",
        "        y_true.extend(correctness)\n",
        "        y_pred.extend(predictions)\n",
        "\n",
        "    y_true_arr = np.array(y_true)\n",
        "    y_pred_arr = np.array(y_pred)\n",
        "    rmse = float(np.sqrt(mean_squared_error(y_true_arr, y_pred_arr)))\n",
        "    try:\n",
        "        auc = float(roc_auc_score(y_true_arr, y_pred_arr))\n",
        "    except ValueError:\n",
        "        auc = float('nan')\n",
        "    return rmse, auc\n",
        "\n",
        "\n",
        "def compute_learning_curve(\n",
        "    df_test: pd.DataFrame\n",
        ") -> Dict[int, float]:\n",
        "    \"\"\"Compute actual error rate per opportunity from a test fold.\n",
        "\n",
        "    Uses actual student outcomes (not model predictions) to compute\n",
        "    empirical learning curves for publication figures.\n",
        "\n",
        "    Args:\n",
        "        df_test: Test-fold DataFrame with is_correct and opportunity_count.\n",
        "\n",
        "    Returns:\n",
        "        Dict mapping opportunity index → mean error rate.\n",
        "    \"\"\"\n",
        "    lc = (\n",
        "        df_test\n",
        "        .groupby('opportunity_count')['is_correct']\n",
        "        .mean()\n",
        "        .reset_index()\n",
        "    )\n",
        "    return {int(row['opportunity_count']): 1.0 - row['is_correct']\n",
        "            for _, row in lc.iterrows()}\n",
        "\n",
        "\n",
        "print('Cross-validation utilities defined.')"
      ]
    },
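    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A toy illustration of the logic in `compute_learning_curve`: the error rate at each opportunity is one minus the mean correctness of all test-fold rows at that opportunity. The DataFrame below is made-up illustrative data, and `learning_curve` is a hypothetical standalone copy that iterates with `Series.items()` instead of `iterrows()`.\n",
        "\n",
        "```python\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "def learning_curve(df_test: pd.DataFrame) -> dict:\n",
        "    # Mean correctness per opportunity, converted to an error rate\n",
        "    mean_correct = df_test.groupby('opportunity_count')['is_correct'].mean()\n",
        "    return {int(opp): 1.0 - float(acc) for opp, acc in mean_correct.items()}\n",
        "\n",
        "\n",
        "toy = pd.DataFrame({\n",
        "    'student_id':        ['a', 'a', 'a', 'b', 'b', 'b'],\n",
        "    'opportunity_count': [1, 2, 3, 1, 2, 3],\n",
        "    'is_correct':        [0, 0, 1, 0, 1, 1],\n",
        "})\n",
        "print(learning_curve(toy))  # {1: 1.0, 2: 0.5, 3: 0.0}\n",
        "```\n",
        "\n",
        "Because the curve is built from observed outcomes only, it is identical whichever BKT variant is trained on the same fold."
      ]
    },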
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g9csuH0iDJAN"
      },
      "outputs": [],
      "source": [
        "def run_student_level_kfold(\n",
        "    df: pd.DataFrame,\n",
        "    model_type: str = 'standard',\n",
        "    n_folds: int = 5,\n",
        "    seed: int = 42\n",
        ") -> Tuple[pd.DataFrame, Dict[int, float], Dict[int, float], np.ndarray]:\n",
        "    \"\"\"Run student-level k-fold cross-validation.\n",
        "\n",
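        "    Partitions students (not interactions) into k contiguous folds of\n",
        "    sorted student IDs to prevent data leakage. In iteration i, fold i is\n",
        "    the test set, fold (i+1) % k is the validation set, and the remaining\n",
        "    k-2 folds form the training set (60% / 20% / 20% for k = 5). Only\n",
        "    TimeAwareBKT uses the validation set (to tune its time parameters);\n",
        "    StandardBKT is fit on the training folds alone. Temporal ordering\n",
        "    within each student's sequence is preserved.\n",
        "\n",
        "    Args:\n",
        "        df:         Input DataFrame (standardized schema).\n",
        "        model_type: 'standard' for StandardBKT or 'time_aware' for\n",
        "                    TimeAwareBKT.\n",
        "        n_folds:    Number of folds k. Default 5.\n",
        "        seed:       Seed passed to set_global_seed. Fold assignment itself\n",
        "                    is deterministic (student IDs are not shuffled).\n",
        "\n",
        "    Returns:\n",
        "        fold_results_df:      DataFrame with RMSE and AUC per fold.\n",
        "        learning_curve_mean:  Dict{opportunity → mean observed error rate} across folds.\n",
        "        learning_curve_std:   Dict{opportunity → std of observed error rate} across folds.\n",
        "        opp10_student_errors: Binary error array at opportunity 10, one entry per\n",
        "                              test-set interaction at that opportunity (normally\n",
        "                              one per student who reached it), pooled over folds\n",
        "                              (for the t-test).\n",
        "\n",
        "        The learning curves and opportunity-10 errors are computed from observed\n",
        "        outcomes, so they do not depend on model_type.\n",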
        "    \"\"\"\n",
        "    set_global_seed(seed)\n",
        "    is_time_aware = (model_type == 'time_aware')\n",
        "\n",
        "    # Partition students into k folds\n",
        "    all_student_ids = sorted(df['student_id'].unique())\n",
        "    n_students = len(all_student_ids)\n",
        "    fold_size  = n_students // n_folds\n",
        "    folds: List[List[str]] = [\n",
        "        all_student_ids[i * fold_size : (i + 1) * fold_size if i < n_folds - 1 else n_students]\n",
        "        for i in range(n_folds)\n",
        "    ]\n",
        "\n",
        "    fold_rows:           List[dict] = []\n",
        "    learning_curve_folds: List[Dict[int, float]] = []\n",
        "    opp10_student_errors: List[float] = []\n",
        "\n",
        "    for fold_idx in range(n_folds):\n",
        "        # Assign splits\n",
        "        test_students  = folds[fold_idx]\n",
        "        val_students   = folds[(fold_idx + 1) % n_folds]\n",
        "        train_students = [\n",
        "            s for i, fold in enumerate(folds)\n",
        "            if i != fold_idx and i != (fold_idx + 1) % n_folds\n",
        "            for s in fold\n",
        "        ]\n",
        "\n",
        "        df_train = df[df['student_id'].isin(train_students)]\n",
        "        df_val   = df[df['student_id'].isin(val_students)]\n",
        "        df_test  = df[df['student_id'].isin(test_students)]\n",
        "\n",
        "        train_seqs = build_student_sequences(df_train, is_time_aware)\n",
        "        val_seqs   = build_student_sequences(df_val,   is_time_aware)\n",
        "        test_seqs  = build_student_sequences(df_test,  is_time_aware)\n",
        "\n",
        "        # Fit model\n",
        "        if is_time_aware:\n",
        "            # Step 1: Estimate P(L_0) and P(T) via EM on training data\n",
        "            base_model = StandardBKT(**BKT_INIT)\n",
        "            base_model.fit_em([seq[0] for seq in train_seqs.values()])\n",
        "            # Step 2: Initialize time-aware model and tune on validation data\n",
        "            model = TimeAwareBKT(\n",
        "                prior=base_model.prior,\n",
        "                learn=base_model.learn,\n",
        "                **TIME_AWARE_INIT\n",
        "            )\n",
        "            model.tune_time_parameters(list(val_seqs.values()))\n",
        "        else:\n",
        "            model = StandardBKT(**BKT_INIT)\n",
        "            model.fit_em([seq[0] for seq in train_seqs.values()])\n",
        "\n",
        "        # Evaluate on test fold\n",
        "        fold_rmse, fold_auc = compute_fold_metrics(model, test_seqs, is_time_aware)\n",
        "        lc_fold = compute_learning_curve(df_test)\n",
        "\n",
        "        # Collect observed (model-independent) errors at opportunity 10 for statistical testing\n",
        "        df_test_opp10 = df_test[df_test['opportunity_count'] == 10]\n",
        "        opp10_student_errors.extend((1 - df_test_opp10['is_correct']).tolist())\n",
        "\n",
        "        fold_rows.append({\n",
        "            'fold':    fold_idx + 1,\n",
        "            'rmse':    fold_rmse,\n",
        "            'auc':     fold_auc,\n",
        "            'n_train': len(train_students),\n",
        "            'n_val':   len(val_students),\n",
        "            'n_test':  len(test_students),\n",
        "        })\n",
        "        learning_curve_folds.append(lc_fold)\n",
        "        print(f'  Fold {fold_idx + 1}/{n_folds} — '\n",
        "              f'RMSE: {fold_rmse:.4f} | AUC: {fold_auc:.4f} | '\n",
        "              f'n_test: {len(test_students):,}')\n",
        "\n",
        "    # Aggregate learning curve across folds\n",
        "    lc_mean = {\n",
        "        opp: float(np.nanmean([lc.get(opp, np.nan) for lc in learning_curve_folds]))\n",
        "        for opp in range(1, MAX_OPPORTUNITIES + 1)\n",
        "    }\n",
        "    lc_std = {\n",
        "        opp: float(np.nanstd([lc.get(opp, np.nan) for lc in learning_curve_folds]))\n",
        "        for opp in range(1, MAX_OPPORTUNITIES + 1)\n",
        "    }\n",
        "\n",
        "    return (\n",
        "        pd.DataFrame(fold_rows),\n",
        "        lc_mean,\n",
        "        lc_std,\n",
        "        np.array(opp10_student_errors)\n",
        "    )\n",
        "\n",
        "\n",
        "print('run_student_level_kfold defined.')"
      ]
    },
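    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The fold rotation in `run_student_level_kfold` can be checked on a small list of IDs. This standalone sketch (`assign_splits` is a hypothetical helper) reproduces the contiguous partition and the cyclic choice of validation fold, so one can confirm that every student is tested exactly once and that the three splits never overlap.\n",
        "\n",
        "```python\n",
        "from typing import List, Tuple\n",
        "\n",
        "\n",
        "def assign_splits(ids: List[str], n_folds: int\n",
        "                  ) -> List[Tuple[List[str], List[str], List[str]]]:\n",
        "    # Contiguous partition of sorted IDs; the last fold absorbs the remainder\n",
        "    ids = sorted(ids)\n",
        "    size = len(ids) // n_folds\n",
        "    folds = [ids[i * size:(i + 1) * size if i < n_folds - 1 else len(ids)]\n",
        "             for i in range(n_folds)]\n",
        "    splits = []\n",
        "    for k in range(n_folds):\n",
        "        val_idx = (k + 1) % n_folds  # validation = next fold, cyclically\n",
        "        train = [s for i, fold in enumerate(folds)\n",
        "                 if i not in (k, val_idx) for s in fold]\n",
        "        splits.append((train, folds[val_idx], folds[k]))\n",
        "    return splits\n",
        "\n",
        "\n",
        "student_ids = [f's{i:02d}' for i in range(12)]\n",
        "for k, (train, val, test) in enumerate(assign_splits(student_ids, 5), start=1):\n",
        "    print(f'fold {k}: train={len(train)} val={len(val)} test={len(test)}')\n",
        "```\n",
        "\n",
        "With 12 students and 5 folds the fold sizes are 2, 2, 2, 2 and 4, so training sets range from 6 to 8 students: the 60/20/20 split holds only approximately when the student count is not a multiple of k."
      ]
    },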
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A73o1kZzDJAN"
      },
      "outputs": [],
      "source": [
        "# ── Run 3 cross-validation experiments ───────────────────────────\n",
        "# Run 1: Standard BKT on EzMath  (baseline for RMSE comparison)\n",
        "# Run 2: Time-Aware BKT on EzMath (main model)\n",
        "# Run 3: Standard BKT on ASSISTments (learning curve baseline)\n",
        "\n",
        "print('Run 1/3 — Standard BKT on EzMath dataset')\n",
        "(\n",
        "    std_bkt_ezmath_folds,\n",
        "    std_bkt_ezmath_lc_mean,\n",
        "    std_bkt_ezmath_lc_std,\n",
        "    std_bkt_ezmath_opp10_errors\n",
        ") = run_student_level_kfold(ezmath_clean, model_type='standard', n_folds=N_FOLDS)\n",
        "display(std_bkt_ezmath_folds)\n",
        "\n",
        "print('\\nRun 2/3 — Time-Aware BKT on EzMath dataset')\n",
        "(\n",
        "    ta_bkt_ezmath_folds,\n",
        "    ta_bkt_ezmath_lc_mean,\n",
        "    ta_bkt_ezmath_lc_std,\n",
        "    ta_bkt_ezmath_opp10_errors\n",
        ") = run_student_level_kfold(ezmath_clean, model_type='time_aware', n_folds=N_FOLDS)\n",
        "display(ta_bkt_ezmath_folds)\n",
        "\n",
        "print('\\nRun 3/3 — Standard BKT on ASSISTments (control baseline)')\n",
        "(\n",
        "    std_bkt_assistments_folds,\n",
        "    std_bkt_assistments_lc_mean,\n",
        "    std_bkt_assistments_lc_std,\n",
        "    assistments_opp10_errors\n",
        ") = run_student_level_kfold(assistments_clean, model_type='standard', n_folds=N_FOLDS)\n",
        "display(std_bkt_assistments_folds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RnzctwYhDJAO"
      },
      "source": [
        "---\n",
        "## 7. Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RBhh56BJDJAO"
      },
      "outputs": [],
      "source": [
        "# ── 7.1  RMSE comparison: Standard vs Time-Aware on EzMath ───────\n",
        "\n",
        "rmse_std = std_bkt_ezmath_folds['rmse'].mean()\n",
        "rmse_ta  = ta_bkt_ezmath_folds['rmse'].mean()\n",
        "rmse_improvement_pct = (rmse_std - rmse_ta) / rmse_std * 100\n",
        "\n",
        "rmse_comparison_df = pd.DataFrame({\n",
        "    'Model': ['Standard BKT (EzMath)', 'Time-Aware BKT (EzMath)'],\n",
        "    'RMSE (mean ± SD)': [\n",
        "        f\"{rmse_std:.4f} ± {std_bkt_ezmath_folds['rmse'].std():.4f}\",\n",
        "        f\"{rmse_ta:.4f} ± {ta_bkt_ezmath_folds['rmse'].std():.4f}\",\n",
        "    ],\n",
        "    'AUC (mean ± SD)': [\n",
        "        f\"{std_bkt_ezmath_folds['auc'].mean():.4f} ± {std_bkt_ezmath_folds['auc'].std():.4f}\",\n",
        "        f\"{ta_bkt_ezmath_folds['auc'].mean():.4f} ± {ta_bkt_ezmath_folds['auc'].std():.4f}\",\n",
        "    ],\n",
        "})\n",
        "print('=== Table: RMSE and AUC comparison on EzMath dataset ===')\n",
        "display(rmse_comparison_df)\n",
        "print(f'\\nRMSE improvement (Time-Aware over Standard): {rmse_improvement_pct:.1f}%')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sH4a7VHdDJAO"
      },
      "outputs": [],
      "source": [
        "# ── 7.2  Statistical testing at Opportunity 10 ───────────────────\n",
        "\n",
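        "# Both arrays hold observed (not predicted) error indicators, so the\n",
        "# choice of BKT variant does not affect this comparison.\n",
        "control_errors:      np.ndarray = assistments_opp10_errors    # ASSISTments\n",
        "intervention_errors: np.ndarray = ta_bkt_ezmath_opp10_errors  # EzMath\n",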
        "\n",
        "n_control      = len(control_errors)\n",
        "n_intervention = len(intervention_errors)\n",
        "\n",
        "print(f'ASSISTments (control):  n={n_control:,}, '\n",
        "      f'mean={control_errors.mean():.4f}, '\n",
        "      f'SD={control_errors.std(ddof=1):.4f}')\n",
        "print(f'EzMath (intervention):  n={n_intervention:,}, '\n",
        "      f'mean={intervention_errors.mean():.4f}, '\n",
        "      f'SD={intervention_errors.std(ddof=1):.4f}')\n",
        "\n",
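        "# Normality check (Shapiro-Wilk, skipped if n > 5000). The error indicators are\n",
        "# binary (0/1), so normality will typically be rejected; Mann-Whitney U is reported below.\n",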
        "for label, arr in [('ASSISTments', control_errors), ('EzMath', intervention_errors)]:\n",
        "    if len(arr) <= 5000:\n",
        "        sw = stats.shapiro(arr)\n",
        "        print(f'Shapiro-Wilk ({label}): W={sw.statistic:.4f}, p={sw.pvalue:.4f}')\n",
        "    else:\n",
        "        print(f'Shapiro-Wilk ({label}): n > 5000 — skipped, using Welch/Mann-Whitney')\n",
        "\n",
        "# Homogeneity of variance (Levene's test), reported for information only:\n",
        "# Welch's t-test below is used in either case, as it does not assume equal variances.\n",
        "levene_result = stats.levene(control_errors, intervention_errors)\n",
        "equal_variances = levene_result.pvalue >= 0.05\n",
        "print(f'\\nLevene\\'s test: stat={levene_result.statistic:.4f}, p={levene_result.pvalue:.6f}')\n",
        "print(f\"  Equal variances (p >= 0.05): {equal_variances}; Welch's t-test is used regardless\")\n",
        "\n",
        "# Welch's independent samples t-test\n",
        "ttest_result = stats.ttest_ind(control_errors, intervention_errors, equal_var=False)\n",
        "t_statistic = ttest_result.statistic\n",
        "p_value     = ttest_result.pvalue\n",
        "\n",
        "# Welch-Satterthwaite degrees of freedom\n",
        "var_c = control_errors.var(ddof=1)\n",
        "var_i = intervention_errors.var(ddof=1)\n",
        "df_welch = (\n",
        "    (var_c / n_control + var_i / n_intervention) ** 2 /\n",
        "    ((var_c / n_control) ** 2 / (n_control - 1) +\n",
        "     (var_i / n_intervention) ** 2 / (n_intervention - 1))\n",
        ")\n",
        "\n",
        "# Cohen's d (pooled standard deviation)\n",
        "pooled_sd = np.sqrt(\n",
        "    ((n_control - 1) * var_c + (n_intervention - 1) * var_i) /\n",
        "    (n_control + n_intervention - 2)\n",
        ")\n",
        "cohens_d = (control_errors.mean() - intervention_errors.mean()) / (pooled_sd + 1e-10)\n",
        "\n",
        "# Bonferroni correction\n",
        "alpha_bonferroni = ALPHA / N_COMPARISONS\n",
        "is_significant   = bool(p_value < alpha_bonferroni)\n",
        "\n",
        "print(f'\\n=== Independent samples Welch\\'s t-test at Opportunity 10 ===')\n",
        "print(f'  t = {t_statistic:.4f}')\n",
        "print(f'  p = {p_value:.6f}  ({\"< 0.001\" if p_value < 0.001 else f\"{p_value:.4f}\"})')\n",
        "print(f'  df (Welch) = {df_welch:.2f}')\n",
        "print(f\"  Cohen's d = {cohens_d:.4f}\")\n",
        "print(f'  Bonferroni α = {ALPHA}/{N_COMPARISONS} = {alpha_bonferroni}')\n",
        "print(f'  Significant after Bonferroni correction: {is_significant}')\n",
        "\n",
        "# Mann-Whitney U (non-parametric backup)\n",
        "mwu_result = stats.mannwhitneyu(control_errors, intervention_errors, alternative='two-sided')\n",
        "print(f'\\nMann-Whitney U (non-parametric): U={mwu_result.statistic:.0f}, p={mwu_result.pvalue:.6f}')"
      ]
    },
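    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Welch-Satterthwaite degrees of freedom and the pooled-SD Cohen's d computed above can be cross-checked on synthetic data. The sketch below uses randomly generated binary error indicators (not study data) and compares the hand-computed t statistic and two-sided p-value against `scipy.stats.ttest_ind(..., equal_var=False)`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from scipy import stats\n",
        "\n",
        "# Synthetic binary error indicators (illustrative only)\n",
        "rng = np.random.default_rng(0)\n",
        "control = rng.binomial(1, 0.45, size=200).astype(float)\n",
        "intervention = rng.binomial(1, 0.30, size=80).astype(float)\n",
        "\n",
        "n_c, n_i = len(control), len(intervention)\n",
        "v_c, v_i = control.var(ddof=1), intervention.var(ddof=1)\n",
        "se_c, se_i = v_c / n_c, v_i / n_i\n",
        "\n",
        "# Welch t statistic and Welch-Satterthwaite degrees of freedom\n",
        "t_manual = (control.mean() - intervention.mean()) / np.sqrt(se_c + se_i)\n",
        "df_manual = (se_c + se_i) ** 2 / (se_c ** 2 / (n_c - 1) + se_i ** 2 / (n_i - 1))\n",
        "p_manual = 2 * stats.t.sf(abs(t_manual), df_manual)\n",
        "\n",
        "# Cohen's d with the pooled standard deviation\n",
        "pooled_sd = np.sqrt(((n_c - 1) * v_c + (n_i - 1) * v_i) / (n_c + n_i - 2))\n",
        "d = (control.mean() - intervention.mean()) / pooled_sd\n",
        "\n",
        "welch = stats.ttest_ind(control, intervention, equal_var=False)\n",
        "print(f't: manual={t_manual:.4f} scipy={welch.statistic:.4f}')\n",
        "print(f'p: manual={p_manual:.6f} scipy={welch.pvalue:.6f}')\n",
        "print(f\"df (Welch) = {df_manual:.2f}, Cohen's d = {d:.3f}\")\n",
        "```\n",
        "\n",
        "For any Welch-Satterthwaite df, min(n_c, n_i) - 1 <= df <= n_c + n_i - 2, which gives a quick sanity check on `df_welch`."
      ]
    },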
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m0YKxlLcDJAO"
      },
      "outputs": [],
      "source": [
        "# ── 7.3  Summary table ────────────────────────────────────────────\n",
        "\n",
        "summary_df = pd.DataFrame([\n",
        "    {\n",
        "        'Dataset / Model': 'ASSISTments — Standard BKT (control baseline)',\n",
        "        'RMSE': f\"{std_bkt_assistments_folds['rmse'].mean():.4f} ± {std_bkt_assistments_folds['rmse'].std():.4f}\",\n",
        "        'AUC':  f\"{std_bkt_assistments_folds['auc'].mean():.4f} ± {std_bkt_assistments_folds['auc'].std():.4f}\",\n",
        "    },\n",
        "    {\n",
        "        'Dataset / Model': 'EzMath — Standard BKT',\n",
        "        'RMSE': f\"{std_bkt_ezmath_folds['rmse'].mean():.4f} ± {std_bkt_ezmath_folds['rmse'].std():.4f}\",\n",
        "        'AUC':  f\"{std_bkt_ezmath_folds['auc'].mean():.4f} ± {std_bkt_ezmath_folds['auc'].std():.4f}\",\n",
        "    },\n",
        "    {\n",
        "        'Dataset / Model': 'EzMath — Time-Aware BKT',\n",
        "        'RMSE': f\"{ta_bkt_ezmath_folds['rmse'].mean():.4f} ± {ta_bkt_ezmath_folds['rmse'].std():.4f}\",\n",
        "        'AUC':  f\"{ta_bkt_ezmath_folds['auc'].mean():.4f} ± {ta_bkt_ezmath_folds['auc'].std():.4f}\",\n",
        "    },\n",
        "]).set_index('Dataset / Model')\n",
        "\n",
        "print('=== Model performance summary (5-fold cross-validation) ===')\n",
        "print('Note: RMSE comparison is within EzMath dataset only.')\n",
        "print('      Learning curve comparison is ASSISTments vs EzMath Time-Aware.')\n",
        "display(summary_df)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MVIAIMKwDJAO"
      },
      "source": [
        "---\n",
        "## 8. Visualization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AGvB4HiWDJAO"
      },
      "outputs": [],
      "source": [
        "# ── Figure 1: Learning curve comparison (publication BW) ─────────\n",
        "\n",
        "opportunity_axis = np.arange(1, MAX_OPPORTUNITIES + 1)\n",
        "\n",
        "control_error_mean = np.array(\n",
        "    [std_bkt_assistments_lc_mean.get(i, np.nan) for i in opportunity_axis])\n",
        "control_error_std  = np.array(\n",
        "    [std_bkt_assistments_lc_std.get(i, 0.0)    for i in opportunity_axis])\n",
        "intervention_error_mean = np.array(\n",
        "    [ta_bkt_ezmath_lc_mean.get(i, np.nan) for i in opportunity_axis])\n",
        "intervention_error_std  = np.array(\n",
        "    [ta_bkt_ezmath_lc_std.get(i, 0.0)    for i in opportunity_axis])\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(7, 4.5))\n",
        "fig.patch.set_facecolor('white')\n",
        "ax.set_facecolor('white')\n",
        "\n",
        "# ±1 SD bands\n",
        "ax.fill_between(opportunity_axis,\n",
        "                control_error_mean - control_error_std,\n",
        "                control_error_mean + control_error_std,\n",
        "                color='black', alpha=0.07)\n",
        "ax.fill_between(opportunity_axis,\n",
        "                intervention_error_mean - intervention_error_std,\n",
        "                intervention_error_mean + intervention_error_std,\n",
        "                color='black', alpha=0.07)\n",
        "\n",
        "# Control: solid line, filled circle\n",
        "ax.plot(opportunity_axis, control_error_mean,\n",
        "        color='black', linewidth=1.8, linestyle='-',\n",
        "        marker='o', markersize=6,\n",
        "        markerfacecolor='black', markeredgecolor='black',\n",
        "        label='ASSISTments baseline (control, standard BKT)')\n",
        "\n",
        "# Intervention: dashed line, open square\n",
        "ax.plot(opportunity_axis, intervention_error_mean,\n",
        "        color='black', linewidth=1.8, linestyle='--',\n",
        "        marker='s', markersize=6,\n",
        "        markerfacecolor='white', markeredgecolor='black', markeredgewidth=1.5,\n",
        "        label='EzMath intervention (time-aware BKT)')\n",
        "\n",
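        "# Annotate the first crossing of the two curves, located by linear\n",
        "# interpolation between adjacent opportunities (x-spacing = 1).\n",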
        "for i in range(len(opportunity_axis) - 1):\n",
        "    vals = [control_error_mean[i], intervention_error_mean[i],\n",
        "            control_error_mean[i+1], intervention_error_mean[i+1]]\n",
        "    if not any(np.isnan(vals)):\n",
        "        if (vals[0] - vals[1]) * (vals[2] - vals[3]) < 0:\n",
        "            d_control      = control_error_mean[i+1]      - control_error_mean[i]\n",
        "            d_intervention = intervention_error_mean[i+1] - intervention_error_mean[i]\n",
        "            x_cross = opportunity_axis[i] + (\n",
        "                (intervention_error_mean[i] - control_error_mean[i]) /\n",
        "                (d_control - d_intervention)\n",
        "            )\n",
        "            y_cross = control_error_mean[i] + d_control * (x_cross - opportunity_axis[i])\n",
        "            ax.axvline(x_cross, color='black', linestyle=':', linewidth=1.0, alpha=0.45)\n",
        "            ax.annotate(\n",
        "                f'Intersection\\n(opp. {x_cross:.1f})',\n",
        "                xy=(x_cross, y_cross),\n",
        "                xytext=(x_cross + 0.5, y_cross + 0.04),\n",
        "                fontsize=9, color='black',\n",
        "                arrowprops=dict(arrowstyle='->', color='black', lw=0.8)\n",
        "            )\n",
        "            break\n",
        "\n",
        "ax.set_xlabel('Opportunity count (number of attempts)')\n",
        "ax.set_ylabel('Probability of error')\n",
        "ax.set_title(\n",
        "    'Learning curve comparison: Addition and Subtraction of Fractions'\n",
        ")\n",
        "ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1, decimals=0))\n",
        "ax.set_xticks(opportunity_axis)\n",
        "ax.set_xlim(0.5, MAX_OPPORTUNITIES + 0.5)\n",
        "ax.set_ylim(0, min(1.0,\n",
        "    max(np.nanmax(control_error_mean), np.nanmax(intervention_error_mean)) + 0.12))\n",
        "ax.legend(frameon=True, edgecolor='black', facecolor='white')\n",
        "ax.grid(True, color='#cccccc', linewidth=0.5, linestyle='-')\n",
        "for spine in ax.spines.values():\n",
        "    spine.set_visible(True)\n",
        "    spine.set_color('black')\n",
        "    spine.set_linewidth(0.7)\n",
        "\n",
        "plt.tight_layout()\n",
        "\n",
        "fig_lc_path = OUTPUT_DIR / 'fig1_learning_curve.png'\n",
        "fig_lc_pdf  = OUTPUT_DIR / 'fig1_learning_curve.pdf'\n",
        "fig.savefig(fig_lc_path, dpi=300, facecolor='white')\n",
        "fig.savefig(fig_lc_pdf,  facecolor='white')\n",
        "plt.show()\n",
        "print(f'Figure 1 saved: {fig_lc_path.name} + {fig_lc_pdf.name}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IE4LtQqGDJAO"
      },
      "outputs": [],
      "source": [
        "# ── Figure 2: ROC curve comparison (publication BW) ──────────────\n",
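        "# Re-runs the fold splits and model training of Runs 1 & 2 (Section 6) with\n",
        "# the same seed and pools the test-set predictions across folds.\n",
        "# Note: this repeats model training; predictions are not cached from Section 6.\n",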
        "\n",
        "def collect_all_predictions(\n",
        "    df: pd.DataFrame,\n",
        "    model_type: str,\n",
        "    n_folds: int = 5,\n",
        "    seed: int = 42\n",
        ") -> Tuple[np.ndarray, np.ndarray]:\n",
        "    \"\"\"Collect pooled y_true and y_pred across all k-fold test sets.\n",
        "\n",
        "    Replicates the exact same fold splits and model training as\n",
        "    run_student_level_kfold to ensure predictions are consistent.\n",
        "\n",
        "    Args:\n",
        "        df:         Input DataFrame (standardized schema).\n",
        "        model_type: 'standard' or 'time_aware'.\n",
        "        n_folds:    Number of folds.\n",
        "        seed:       Must match the seed used in run_student_level_kfold.\n",
        "\n",
        "    Returns:\n",
        "        Tuple (y_true, y_pred) as flat numpy arrays.\n",
        "    \"\"\"\n",
        "    set_global_seed(seed)\n",
        "    is_time_aware = (model_type == 'time_aware')\n",
        "\n",
        "    all_student_ids = sorted(df['student_id'].unique())\n",
        "    n_students = len(all_student_ids)\n",
        "    fold_size  = n_students // n_folds\n",
        "    folds = [\n",
        "        all_student_ids[i * fold_size : (i + 1) * fold_size if i < n_folds - 1 else n_students]\n",
        "        for i in range(n_folds)\n",
        "    ]\n",
        "\n",
        "    all_true: List[int]   = []\n",
        "    all_pred: List[float] = []\n",
        "\n",
        "    for fold_idx in range(n_folds):\n",
        "        test_students  = folds[fold_idx]\n",
        "        val_students   = folds[(fold_idx + 1) % n_folds]\n",
        "        train_students = [\n",
        "            s for i, fold in enumerate(folds)\n",
        "            if i != fold_idx and i != (fold_idx + 1) % n_folds\n",
        "            for s in fold\n",
        "        ]\n",
        "        df_train = df[df['student_id'].isin(train_students)]\n",
        "        df_val   = df[df['student_id'].isin(val_students)]\n",
        "        df_test  = df[df['student_id'].isin(test_students)]\n",
        "\n",
        "        train_seqs = build_student_sequences(df_train, is_time_aware)\n",
        "        val_seqs   = build_student_sequences(df_val,   is_time_aware)\n",
        "        test_seqs  = build_student_sequences(df_test,  is_time_aware)\n",
        "\n",
        "        if is_time_aware:\n",
        "            base_model = StandardBKT(**BKT_INIT)\n",
        "            base_model.fit_em([seq[0] for seq in train_seqs.values()])\n",
        "            model = TimeAwareBKT(prior=base_model.prior, learn=base_model.learn,\n",
        "                                 **TIME_AWARE_INIT)\n",
        "            model.tune_time_parameters(list(val_seqs.values()))\n",
        "        else:\n",
        "            model = StandardBKT(**BKT_INIT)\n",
        "            model.fit_em([seq[0] for seq in train_seqs.values()])\n",
        "\n",
        "        for _, (cs, ts) in test_seqs.items():\n",
        "            preds = model.predict_sequence(cs, ts) if is_time_aware else model.predict_sequence(cs)\n",
        "            all_true.extend(cs)\n",
        "            all_pred.extend(preds)\n",
        "\n",
        "    return np.array(all_true), np.array(all_pred)\n",
        "\n",
        "\n",
        "print('Collecting predictions for ROC (Standard BKT)...')\n",
        "y_true_std, y_pred_std = collect_all_predictions(ezmath_clean, 'standard')\n",
        "print('Collecting predictions for ROC (Time-Aware BKT)...')\n",
        "y_true_ta,  y_pred_ta  = collect_all_predictions(ezmath_clean, 'time_aware')\n",
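        "\n",
        "# Pooled AUC over all test folds; this can differ slightly from the mean\n",
        "# per-fold AUC reported in the Section 7 summary table.\n",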
        "fpr_std, tpr_std, _ = roc_curve(y_true_std, y_pred_std)\n",
        "fpr_ta,  tpr_ta,  _ = roc_curve(y_true_ta,  y_pred_ta)\n",
        "auc_std = float(roc_auc_score(y_true_std, y_pred_std))\n",
        "auc_ta  = float(roc_auc_score(y_true_ta,  y_pred_ta))\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(5.5, 5.5))\n",
        "fig.patch.set_facecolor('white')\n",
        "ax.set_facecolor('white')\n",
        "\n",
        "ax.plot(fpr_std, tpr_std,\n",
        "        color='black', linewidth=1.6, linestyle='-',\n",
        "        label=f'Standard BKT (AUC = {auc_std:.4f})')\n",
        "ax.plot(fpr_ta, tpr_ta,\n",
        "        color='black', linewidth=1.6, linestyle='--',\n",
        "        label=f'Time-Aware BKT (AUC = {auc_ta:.4f})')\n",
        "ax.plot([0, 1], [0, 1],\n",
        "        color='black', linewidth=0.8, linestyle=':',\n",
        "        label='Chance level (AUC = 0.50)')\n",
        "\n",
        "ax.set_xlabel('False Positive Rate')\n",
        "ax.set_ylabel('True Positive Rate')\n",
        "ax.set_title(\n",
        "    'Receiver Operating Characteristic (ROC) curves\\n'\n",
        "    'Standard BKT vs. Time-Aware BKT (EzMath, 5-fold cross-validation)'\n",
        ")\n",
        "ax.set_xlim(0.0, 1.0)\n",
        "ax.set_ylim(0.0, 1.02)\n",
        "ax.legend(frameon=True, edgecolor='black', facecolor='white', loc='lower right')\n",
        "ax.grid(True, color='#cccccc', linewidth=0.5, linestyle='-')\n",
        "for spine in ax.spines.values():\n",
        "    spine.set_visible(True)\n",
        "    spine.set_color('black')\n",
        "    spine.set_linewidth(0.7)\n",
        "\n",
        "plt.tight_layout()\n",
        "\n",
        "fig_roc_path = OUTPUT_DIR / 'fig2_roc_curve.png'\n",
        "fig_roc_pdf  = OUTPUT_DIR / 'fig2_roc_curve.pdf'\n",
        "fig.savefig(fig_roc_path, dpi=300, facecolor='white')\n",
        "fig.savefig(fig_roc_pdf,  facecolor='white')\n",
        "plt.show()\n",
        "print(f'Figure 2 saved: {fig_roc_path.name} + {fig_roc_pdf.name}')\n",
        "print(f'Pooled AUC, Standard BKT   = {auc_std:.4f}')\n",
        "print(f'Pooled AUC, Time-Aware BKT = {auc_ta:.4f}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lsDHCQkNDJAO"
      },
      "source": [
        "---\n",
        "## 9. Export"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mkn8jttADJAP"
      },
      "outputs": [],
      "source": [
        "# ── Export 1: Per-fold results (all 3 runs) ───────────────────────\n",
        "# Tag each run with its dataset and model (assign() returns copies,\n",
        "# so the per-fold DataFrames used elsewhere are left unchanged).\n",
        "fold_results_df = pd.concat([\n",
        "    std_bkt_assistments_folds.assign(dataset='ASSISTments', model='Standard BKT'),\n",
        "    std_bkt_ezmath_folds.assign(dataset='EzMath', model='Standard BKT'),\n",
        "    ta_bkt_ezmath_folds.assign(dataset='EzMath', model='Time-Aware BKT'),\n",
        "], ignore_index=True)\n",
        "\n",
        "fold_results_path = OUTPUT_DIR / 'fold_results.csv'\n",
        "fold_results_df.to_csv(fold_results_path, index=False)\n",
        "print(f'Exported: {fold_results_path.name}')\n",
        "display(fold_results_df)\n",
        "\n",
        "# ── Export 2: Learning curve data ────────────────────────────────\n",
        "lc_records = []\n",
        "for opp in range(1, MAX_OPPORTUNITIES + 1):\n",
        "    lc_records += [\n",
        "        {'opportunity': opp, 'dataset': 'ASSISTments',\n",
        "         'model': 'Standard BKT',\n",
        "         'error_rate': std_bkt_assistments_lc_mean.get(opp),\n",
        "         'error_rate_sd': std_bkt_assistments_lc_std.get(opp)},\n",
        "        {'opportunity': opp, 'dataset': 'EzMath',\n",
        "         'model': 'Standard BKT',\n",
        "         'error_rate': std_bkt_ezmath_lc_mean.get(opp),\n",
        "         'error_rate_sd': std_bkt_ezmath_lc_std.get(opp)},\n",
        "        {'opportunity': opp, 'dataset': 'EzMath',\n",
        "         'model': 'Time-Aware BKT',\n",
        "         'error_rate': ta_bkt_ezmath_lc_mean.get(opp),\n",
        "         'error_rate_sd': ta_bkt_ezmath_lc_std.get(opp)},\n",
        "    ]\n",
        "lc_df = pd.DataFrame(lc_records)\n",
        "lc_path = OUTPUT_DIR / 'learning_curve.csv'\n",
        "lc_df.to_csv(lc_path, index=False)\n",
        "print(f'Exported: {lc_path.name}')\n",
        "display(lc_df)\n",
        "\n",
        "# ── Export 3: Statistical summary ────────────────────────────────\n",
        "stat_summary_records = [\n",
        "    {'metric': 'RMSE — Standard BKT (EzMath)',\n",
        "     'value': f\"{std_bkt_ezmath_folds['rmse'].mean():.4f} ± {std_bkt_ezmath_folds['rmse'].std():.4f}\"},\n",
        "    {'metric': 'RMSE — Time-Aware BKT (EzMath)',\n",
        "     'value': f\"{ta_bkt_ezmath_folds['rmse'].mean():.4f} ± {ta_bkt_ezmath_folds['rmse'].std():.4f}\"},\n",
        "    {'metric': 'RMSE reduction (%)',\n",
        "     'value': f'{rmse_improvement_pct:.1f}%'},\n",
        "    {'metric': 'AUC — Standard BKT (EzMath)',\n",
        "     'value': f\"{std_bkt_ezmath_folds['auc'].mean():.4f} ± {std_bkt_ezmath_folds['auc'].std():.4f}\"},\n",
        "    {'metric': 'AUC — Time-Aware BKT (EzMath)',\n",
        "     'value': f\"{ta_bkt_ezmath_folds['auc'].mean():.4f} ± {ta_bkt_ezmath_folds['auc'].std():.4f}\"},\n",
        "    {'metric': 'Error rate opp.10 — ASSISTments',\n",
        "     'value': f\"{control_errors.mean():.4f} (n={n_control:,})\"},\n",
        "    {'metric': 'Error rate opp.10 — EzMath',\n",
        "     'value': f\"{intervention_errors.mean():.4f} (n={n_intervention:,})\"},\n",
        "    {'metric': \"Welch's t-test\",\n",
        "     'value': f't = {t_statistic:.4f}, p = {p_value:.6f}, df = {df_welch:.1f}'},\n",
        "    {'metric': \"Cohen's d\",\n",
        "     'value': f'{cohens_d:.4f}'},\n",
        "    {'metric': f'Bonferroni correction (α = {alpha_bonferroni})',\n",
        "     'value': 'Significant' if is_significant else 'Not significant'},\n",
        "    {'metric': 'Mann-Whitney U',\n",
        "     'value': f'U = {mwu_result.statistic:.0f}, p = {mwu_result.pvalue:.6f}'},\n",
        "]\n",
        "stat_df = pd.DataFrame(stat_summary_records)\n",
        "stat_path = OUTPUT_DIR / 'statistical_summary.csv'\n",
        "stat_df.to_csv(stat_path, index=False)\n",
        "print(f'Exported: {stat_path.name}')\n",
        "display(stat_df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ttedtGmUDJAP"
      },
      "outputs": [],
      "source": [
        "# ── Auto-generated research report ───────────────────────────────\n",
        "\n",
        "report_content = f\"\"\"\n",
        "# EzMath BKT Cross-Validation — Research Report\n",
        "\n",
        "Generated automatically. Random seed: {RANDOM_SEED}.\n",
        "\n",
        "## 1. Dataset summary after preprocessing\n",
        "\n",
        "| | ASSISTments (control) | EzMath (intervention) |\n",
        "|---|---|---|\n",
        "| Students | {assistments_clean['student_id'].nunique():,} | {ezmath_clean['student_id'].nunique():,} |\n",
        "| Records | {len(assistments_clean):,} | {len(ezmath_clean):,} |\n",
        "| Learning-curve model | Standard BKT | Time-Aware BKT |\n",
        "\n",
        "## 2. Predictive performance (EzMath dataset, mean ± SD over 5 cross-validation folds)\n",
        "\n",
        "| Model | RMSE | AUC |\n",
        "|---|---|---|\n",
        "| Standard BKT | {std_bkt_ezmath_folds['rmse'].mean():.4f} ± {std_bkt_ezmath_folds['rmse'].std():.4f} | {std_bkt_ezmath_folds['auc'].mean():.4f} ± {std_bkt_ezmath_folds['auc'].std():.4f} |\n",
        "| Time-Aware BKT | {ta_bkt_ezmath_folds['rmse'].mean():.4f} ± {ta_bkt_ezmath_folds['rmse'].std():.4f} | {ta_bkt_ezmath_folds['auc'].mean():.4f} ± {ta_bkt_ezmath_folds['auc'].std():.4f} |\n",
        "| RMSE reduction (Time-Aware vs. Standard) | {rmse_improvement_pct:.1f}% | n/a |\n",
        "\n",
        "## 3. Statistical analysis at Opportunity 10 (student-level)\n",
        "\n",
        "| Group | N | Mean error | SD |\n",
        "|---|---|---|---|\n",
        "| ASSISTments (control) | {n_control:,} | {control_errors.mean():.4f} | {control_errors.std(ddof=1):.4f} |\n",
        "| EzMath (intervention) | {n_intervention:,} | {intervention_errors.mean():.4f} | {intervention_errors.std(ddof=1):.4f} |\n",
        "\n",
        "- Levene's test: p = {levene_result.pvalue:.4f} → {'unequal variances' if not equal_variances else 'equal variances'}\n",
        "- Welch's t-test: t = {t_statistic:.4f}, p = {p_value:.6f}, df = {df_welch:.1f}\n",
        "- Cohen's d = {cohens_d:.4f}\n",
        "- Bonferroni α = {alpha_bonferroni}: {'Significant' if is_significant else 'Not significant'}\n",
        "- Mann-Whitney U = {mwu_result.statistic:.0f}, p = {mwu_result.pvalue:.6f}\n",
        "\n",
        "## 4. Limitations\n",
        "\n",
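        "1. **Survivorship bias:** The ASSISTments baseline retains only students with ≥{MIN_OPPORTUNITIES}\n",
        "   practice opportunities on the target skill, excluding early-mastery learners; it therefore\n",
        "   represents persistent-difficulty learners rather than the full neurotypical population.\n",
        "2. **P(S|t) directionality:** Slip probability is modelled as a monotonically increasing\n",
        "   sigmoid of response time; this direction of effect still needs validation over\n",
        "   extended response-time ranges.\n",
        "3. **AUC interpretation:** AUC values below 1.0 are expected for student-response data.\n",
        "   Earlier reports of AUC = 1.0 typically indicate a methodological problem, such as\n",
        "   train/test leakage, and warrant re-examination.\n",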
        "\"\"\"\n",
        "\n",
        "display(Markdown(report_content))\n",
        "report_path = OUTPUT_DIR / 'auto_report.md'\n",
        "with open(report_path, 'w', encoding='utf-8') as f:\n",
        "    f.write(report_content)\n",
        "print(f'Report exported: {report_path.name}')"
      ]
    }
  ]
}