# Source code for prefgraph.datasets._online_retail_ii

"""Online Retail II dataset loader.

Loads the Online Retail II dataset (~5,942 customers of a UK-based online
retailer, Dec 2009 to Dec 2011) and returns a BehaviorPanel.

Budget-based with real prices. Monthly aggregation, top-N products
by transaction frequency.

Data must be downloaded separately from UCI ML Repository.
"""

from __future__ import annotations

import os
from pathlib import Path

import numpy as np

from prefgraph.core.panel import BehaviorPanel
from prefgraph.core.session import BehaviorLog

# --- Constants ---

# Bounds on a row's unit price: load_online_retail_ii keeps only rows whose
# "Price" lies between these two values (pandas `between`, inclusive by
# default) and drops everything else as an outlier.
MIN_UNIT_PRICE = 0.01
MAX_UNIT_PRICE = 500.0
# Default for `top_n_categories`: how many of the most frequent StockCodes
# are kept as goods. Despite the name, the loader ranks individual product
# codes, not product categories.
TOP_N_CATEGORIES = 30
# Default for `min_months`: customers with fewer months containing purchases
# of the kept products are skipped.
MIN_MONTHS = 4
# Copied verbatim into the panel metadata under "cutoff_date"; the loader
# itself never splits or filters on this date.
CUTOFF_DATE = "2011-06-01"


def _find_data_dir(data_dir: str | Path | None) -> Path:
    """Locate the directory that holds the Online Retail II data file.

    Candidate directories are tried in the order below; the first one that
    is an existing directory and contains a recognised file name wins:

    1. ``data_dir``, when it is not None.
    2. ``$PYREVEALED_DATA_DIR/online_retail_ii``, when the variable is set
       to a non-empty value.
    3. ``~/.prefgraph/data/online_retail_ii``.
    4. ``datasets/online_retail_ii/data`` under the directory three levels
       above this module's own directory (a source-checkout layout). This
       candidate is skipped when the module path is too shallow to have
       such an ancestor.

    Args:
        data_dir: Explicit directory to try first, or None to rely on the
            environment variable and the default locations only.

    Returns:
        The first candidate directory containing one of the known data
        file names.

    Raises:
        FileNotFoundError: If no candidate contains a data file. The
            message lists every directory that was searched.
    """
    known_files = (
        "online_retail_II.csv",
        "online_retail_ii.csv",
        "Online Retail II.csv",
        "online_retail_II.xlsx",
    )

    candidates: list[Path] = []
    if data_dir is not None:
        candidates.append(Path(data_dir))

    env = os.environ.get("PYREVEALED_DATA_DIR")
    if env:
        candidates.append(Path(env) / "online_retail_ii")

    candidates.append(Path.home() / ".prefgraph" / "data" / "online_retail_ii")

    # Indexing ``parents[3]`` blindly raises IndexError when the module lives
    # fewer than four levels deep, which would abort the search before any
    # candidate (including an explicit ``data_dir``) has been checked.
    ancestors = Path(__file__).resolve().parents
    if len(ancestors) > 3:
        candidates.append(ancestors[3] / "datasets" / "online_retail_ii" / "data")

    for d in candidates:
        if d.is_dir() and any((d / fname).exists() for fname in known_files):
            return d

    searched = "\n  ".join(str(c) for c in candidates)
    raise FileNotFoundError(
        f"Online Retail II data not found. Searched:\n  {searched}\n\n"
        "Download from: https://archive.ics.uci.edu/dataset/502/online+retail+ii\n"
        "Place online_retail_II.csv in ~/.prefgraph/data/online_retail_ii/\n"
        "Or pass data_dir= or set PYREVEALED_DATA_DIR environment variable."
    )


def _load_raw(data_path: Path) -> "pd.DataFrame":
    """Read the raw Online Retail II table from ``data_path``.

    The known CSV file names are probed first, in a fixed order, and the
    first one present is parsed with ``pandas.read_csv``. Only when no CSV
    is present is ``online_retail_II.xlsx`` read with ``pandas.read_excel``.

    Args:
        data_path: Directory expected to contain the data file.

    Returns:
        The unfiltered table exactly as pandas parsed it.

    Raises:
        FileNotFoundError: If neither a CSV nor the workbook is present.
    """
    import pandas as pd

    csv_names = (
        "online_retail_II.csv",
        "online_retail_ii.csv",
        "Online Retail II.csv",
    )
    csv_file = next(
        (path for path in (data_path / name for name in csv_names) if path.exists()),
        None,
    )
    if csv_file is not None:
        return pd.read_csv(csv_file)

    workbook = data_path / "online_retail_II.xlsx"
    if not workbook.exists():
        raise FileNotFoundError(f"No Online Retail II file found in {data_path}")
    return pd.read_excel(workbook)


def load_online_retail_ii(
    data_dir: str | Path | None = None,
    n_customers: int | None = None,
    min_months: int = MIN_MONTHS,
    top_n_categories: int = TOP_N_CATEGORIES,
) -> BehaviorPanel:
    """Load Online Retail II dataset as a BehaviorPanel.

    The raw table must provide the columns ``Invoice``, ``Customer ID``,
    ``Quantity``, ``Price``, ``InvoiceDate`` and ``StockCode``. Rows are
    cleaned (cancellations, missing customer IDs, non-positive quantities,
    out-of-range prices and unparseable dates are dropped), restricted to
    the most frequent stock codes and aggregated by calendar month.

    Prices come from a shared grid: the median price of each product in
    each month over all customers, with gaps filled forward, then backward.
    Each customer's log pairs their monthly quantities with the matching
    rows of that grid.

    Args:
        data_dir: Path to directory containing online_retail_II.csv.
            If None, searches standard locations.
        n_customers: Max number of customers to include (None = all).
            The cap is applied to the customer list before the
            ``min_months`` filter, so the panel can contain fewer
            customers than requested.
        min_months: Minimum active months per customer (default 4). A
            month is active when the customer bought any kept product.
        top_n_categories: Number of top products by frequency (default 30).

    Returns:
        BehaviorPanel with one BehaviorLog per customer, keyed
        ``customer_<id>``. Metadata includes 'cutoff_date' for train/test
        splitting and 'goods', the sorted list of kept stock codes that
        defines the column order of every log.

    Raises:
        FileNotFoundError: If data files cannot be found.
        ImportError: If pandas is not installed.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for dataset loaders. "
            "Install with: pip install 'prefgraph[datasets]'"
        ) from None

    data_path = _find_data_dir(data_dir)
    df = _load_raw(data_path)

    # --- Filtering ---
    # Remove cancellations (Invoice starts with 'C')
    df["Invoice"] = df["Invoice"].astype(str)
    df = df[~df["Invoice"].str.startswith("C")]

    # Drop null Customer IDs. The explicit copy gives us a frame we own, so
    # the column assignment below does not write into a filtered slice
    # (which pandas reports as SettingWithCopyWarning).
    df = df.dropna(subset=["Customer ID"]).copy()
    df["Customer ID"] = df["Customer ID"].astype(int)

    # Filter negative/zero quantities and price outliers
    keep = (df["Quantity"] > 0) & (
        df["Price"].between(MIN_UNIT_PRICE, MAX_UNIT_PRICE)
    )
    df = df[keep].copy()

    # Parse dates; unparseable values become NaT and are dropped
    df["InvoiceDate"] = pd.to_datetime(df["InvoiceDate"], errors="coerce")
    df = df.dropna(subset=["InvoiceDate"]).copy()

    # --- Monthly periods ---
    df["period"] = df["InvoiceDate"].dt.to_period("M").astype(str)

    # --- Select top products by transaction count ---
    df["StockCode"] = df["StockCode"].astype(str)
    top_products = (
        df["StockCode"].value_counts().head(top_n_categories).index.tolist()
    )
    df = df[df["StockCode"].isin(top_products)]
    products = sorted(str(p) for p in top_products)

    # --- Build price oracle: median price per month per product ---
    periods = sorted(df["period"].unique())
    price_pivot = df.pivot_table(
        values="Price",
        index="period",
        columns="StockCode",
        aggfunc="median",
    ).reindex(index=periods, columns=products)
    # Months in which a product was not sold have no observed price: carry
    # the last known price forward, then the first known price backward.
    price_pivot = price_pivot.ffill().bfill().fillna(price_pivot.median())
    price_grid = price_pivot.values
    period_to_idx = {p: i for i, p in enumerate(periods)}

    # --- Build per-customer BehaviorLogs ---
    logs: dict[str, BehaviorLog] = {}
    grouped = df.groupby("Customer ID")
    customer_ids = list(grouped.groups.keys())
    if n_customers is not None:
        # Applied before the min_months check below, so this is an upper
        # bound on the panel size rather than an exact count.
        customer_ids = customer_ids[:n_customers]

    for cid in customer_ids:
        cust_data = grouped.get_group(cid)

        # Aggregate quantity per month per product
        qty_pivot = cust_data.pivot_table(
            values="Quantity",
            index="period",
            columns="StockCode",
            aggfunc="sum",
        ).reindex(columns=products).fillna(0)

        active_periods = qty_pivot[qty_pivot.sum(axis=1) > 0].index.tolist()
        if len(active_periods) < min_months:
            continue

        qty_matrix = qty_pivot.loc[active_periods].values

        # Row i of the price matrix must describe the same month as row i
        # of the quantity matrix; skip the customer if any month is missing
        # from the shared price grid.
        price_indices = [
            period_to_idx[p] for p in active_periods if p in period_to_idx
        ]
        if len(price_indices) != len(active_periods):
            continue
        price_matrix = price_grid[price_indices]

        uid = f"customer_{cid}"
        logs[uid] = BehaviorLog(
            cost_vectors=price_matrix,
            action_vectors=qty_matrix,
            user_id=uid,
        )

    return BehaviorPanel(
        _logs=logs,
        metadata={
            "dataset": "online_retail_ii",
            "goods": products,
            "min_months": min_months,
            "top_n_categories": top_n_categories,
            "cutoff_date": CUTOFF_DATE,
            "date_range": "2009-12 to 2011-12",
        },
    )