# Source code for prefgraph.datasets._dunnhumby

"""Dunnhumby grocery dataset loader.

Loads the Dunnhumby "The Complete Journey" dataset (grocery transactions from
~2,500 households over 104 weeks) and returns it as a BehaviorPanel.

Data must be downloaded separately from Kaggle.
"""

from __future__ import annotations

import os
from pathlib import Path

import numpy as np

from prefgraph.core.panel import BehaviorPanel
from prefgraph.core.session import BehaviorLog

# --- Constants ---

# Commodity categories (COMMODITY_DESC values from product.csv) kept by the
# loader. The list order fixes the column order of every price/quantity matrix.
TOP_COMMODITIES = [
    "SOFT DRINKS",
    "FLUID MILK PRODUCTS",
    "BAKED BREAD/BUNS/ROLLS",
    "CHEESE",
    "BAG SNACKS",
    "SOUP",
    "YOGURT",
    "BEEF",
    "FROZEN PIZZA",
    "LUNCHMEAT",
]
NUM_PRODUCTS = len(TOP_COMMODITIES)

# Number of weekly rows in the market price grid (weeks 1..NUM_WEEKS).
NUM_WEEKS = 104

# Inclusive bounds on the computed per-unit price; rows outside are discarded.
MIN_UNIT_PRICE = 0.01
MAX_UNIT_PRICE = 50.0


def _find_data_dir(data_dir: str | Path | None) -> Path:
    """Find dunnhumby data directory via cascade."""
    candidates = []
    if data_dir is not None:
        candidates.append(Path(data_dir))

    env = os.environ.get("PYREVEALED_DATA_DIR")
    if env:
        candidates.append(Path(env) / "dunnhumby")

    candidates.extend([
        Path.home() / ".prefgraph" / "data" / "dunnhumby",
        Path(__file__).resolve().parents[3] / "dunnhumby" / "data",
    ])

    for d in candidates:
        if d.is_dir() and (d / "transaction_data.csv").exists():
            return d

    searched = "\n  ".join(str(c) for c in candidates)
    raise FileNotFoundError(
        f"Dunnhumby data not found. Searched:\n  {searched}\n\n"
        "Download from Kaggle: https://www.kaggle.com/datasets/frtgnn/dunnhumby-the-complete-journey\n"
        "Then pass data_dir= or set PYREVEALED_DATA_DIR environment variable."
    )


def _clean_transactions(transactions, products):
    """Join, filter and price raw Dunnhumby transactions.

    Returns the transaction rows whose product belongs to ``TOP_COMMODITIES``,
    with ``COMMODITY_DESC`` joined in from ``products`` and two added columns:
    ``week`` (1-based, 7-day buckets of ``DAY``) and ``unit_price``. Rows with
    non-positive ``QUANTITY``, a unit price outside
    ``[MIN_UNIT_PRICE, MAX_UNIT_PRICE]``, or a week outside ``1..NUM_WEEKS``
    are dropped. The result is an independent frame, safe to add columns to.
    """
    merged = transactions.merge(
        products[["PRODUCT_ID", "COMMODITY_DESC"]], on="PRODUCT_ID"
    )
    # .copy(): columns are assigned below, so work on an independent frame
    # rather than a filtered view (avoids pandas' SettingWithCopyWarning).
    merged = merged[merged["COMMODITY_DESC"].isin(TOP_COMMODITIES)].copy()

    merged["week"] = ((merged["DAY"] - 1) // 7) + 1
    # NOTE(review): if RETAIL_DISC / COUPON_DISC are stored as negative numbers
    # (as the dataset's user guide suggests — confirm), subtracting them gives
    # the pre-discount shelf price rather than the price actually paid.
    merged["unit_price"] = (
        merged["SALES_VALUE"] - merged["RETAIL_DISC"] - merged["COUPON_DISC"]
    ) / merged["QUANTITY"]

    keep = (
        (merged["unit_price"] >= MIN_UNIT_PRICE)
        & (merged["unit_price"] <= MAX_UNIT_PRICE)
        & (merged["QUANTITY"] > 0)
        # The price grid has exactly NUM_WEEKS rows indexed by week - 1; other
        # weeks would index past its end (or wrap around for week 0).
        & merged["week"].between(1, NUM_WEEKS)
    )
    return merged[keep].copy()


def _build_price_grid(merged):
    """Return a ``(NUM_WEEKS, NUM_PRODUCTS)`` array of median weekly unit prices.

    Row ``w - 1`` holds week ``w``; columns follow ``TOP_COMMODITIES``. The
    median is taken over all households' purchases. Weeks with no purchase of
    a commodity are filled forward, then backward, from neighbouring weeks. A
    commodity never observed at all stays NaN.
    """
    price_pivot = merged.pivot_table(
        values="unit_price",
        index="week",
        columns="COMMODITY_DESC",
        aggfunc="median",
    ).reindex(index=range(1, NUM_WEEKS + 1), columns=TOP_COMMODITIES)
    return price_pivot.ffill().bfill().values


def _weekly_quantities(rows):
    """Sum QUANTITY per (week, commodity) over ``rows``.

    Returns a DataFrame indexed by the weeks present in ``rows``, with one
    column per entry of ``TOP_COMMODITIES`` (in that order). Commodities not
    bought in a given week are 0.
    """
    return (
        rows.pivot_table(
            values="QUANTITY",
            index="week",
            columns="COMMODITY_DESC",
            aggfunc="sum",
        )
        .reindex(columns=TOP_COMMODITIES)
        .fillna(0)
    )


def load_dunnhumby(
    data_dir: str | Path | None = None,
    n_households: int | None = None,
    min_weeks: int = 10,
    period: str | None = None,
) -> BehaviorPanel:
    """Load the Dunnhumby grocery dataset as a BehaviorPanel.

    Transactions are restricted to ``TOP_COMMODITIES`` and bucketed into 7-day
    weeks. For each household, the cost vectors are the market-wide median
    unit price of each commodity in the weeks the household shopped (not the
    prices it paid itself). The action vectors are the quantities it bought
    in those weeks.

    Args:
        data_dir: Directory containing transaction_data.csv and product.csv.
            If None, standard locations are searched (see ``_find_data_dir``).
        n_households: Maximum number of households to consider (None = all).
            The cut is applied before the activity filters below, so fewer
            logs than this may be returned.
        min_weeks: Minimum number of active shopping weeks a household needs
            to be kept (default 10). Only used when ``period`` is None.
        period: Time aggregation level.
            - None: one BehaviorLog per household across all weeks.
            - "month": one BehaviorLog per household per 4-week block.
              Blocks with fewer than 2 active weeks are skipped.

    Returns:
        BehaviorPanel with one BehaviorLog per household (user id
        ``household_<key>``) or per household-month
        (``household_<key>__period_<n>``). In "month" mode the panel's period
        map links each sub-session id to ``(household id, period number)``.

    Raises:
        ValueError: If ``period`` is neither None nor "month".
        FileNotFoundError: If the data files cannot be found.
        ImportError: If pandas is not installed.
    """
    if period not in (None, "month"):
        raise ValueError(f"period must be None or 'month', got {period!r}")

    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for dataset loaders. "
            "Install with: pip install 'prefgraph[datasets]'"
        ) from None

    data_path = _find_data_dir(data_dir)
    transactions = pd.read_csv(data_path / "transaction_data.csv")
    products = pd.read_csv(data_path / "product.csv")

    merged = _clean_transactions(transactions, products)
    price_grid = _build_price_grid(merged)  # (NUM_WEEKS, NUM_PRODUCTS)

    by_month = period == "month"
    if by_month:
        # "Months" are consecutive 4-week blocks numbered from 1.
        merged["period"] = ((merged["week"] - 1) // 4) + 1

    logs: dict[str, BehaviorLog] = {}
    period_map: dict[str, tuple[str, str]] | None = {} if by_month else None

    grouped = merged.groupby("household_key")
    hh_keys = list(grouped.groups.keys())
    if n_households is not None:
        hh_keys = hh_keys[:n_households]

    for hh_key in hh_keys:
        hh_data = grouped.get_group(hh_key)
        hh_id = f"household_{hh_key}"

        if by_month:
            for period_val, period_data in hh_data.groupby("period"):
                qty_pivot = _weekly_quantities(period_data)
                active_weeks = qty_pivot.index.tolist()
                # min_weeks is not applied per period; only require 2 weeks.
                if len(active_weeks) < 2:
                    continue
                uid = f"{hh_id}__period_{period_val}"
                logs[uid] = BehaviorLog(
                    cost_vectors=price_grid[np.array(active_weeks) - 1],
                    action_vectors=qty_pivot.values,
                    user_id=uid,
                )
                period_map[uid] = (hh_id, str(int(period_val)))
        else:
            qty_pivot = _weekly_quantities(hh_data)
            active_weeks = qty_pivot[qty_pivot.sum(axis=1) > 0].index.tolist()
            if len(active_weeks) < min_weeks:
                continue
            logs[hh_id] = BehaviorLog(
                # price_grid rows are 0-indexed; weeks start at 1.
                cost_vectors=price_grid[np.array(active_weeks) - 1],
                action_vectors=qty_pivot.loc[active_weeks].values,
                user_id=hh_id,
            )

    return BehaviorPanel(
        _logs=logs,
        metadata={
            "dataset": "dunnhumby",
            "goods": TOP_COMMODITIES,
            "min_weeks": min_weeks,
            "period": period,
        },
        _period_map=period_map,
    )