Source code for prefgraph.datasets._uci_retail
"""UCI Online Retail dataset loader.
Loads the UCI Online Retail dataset of ~1,800 UK B2B customers,
returning a BehaviorPanel.
The data must be downloaded separately from the UCI ML Repository.
"""
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
from prefgraph.core.panel import BehaviorPanel
from prefgraph.core.session import BehaviorLog
# --- Constants ---
# Inclusive bounds on UnitPrice: load_uci_retail drops every row whose
# UnitPrice falls outside [MIN_UNIT_PRICE, MAX_UNIT_PRICE].
MIN_UNIT_PRICE = 0.01
MAX_UNIT_PRICE = 500.0
# Default for load_uci_retail's ``top_n_products``: how many of the most
# frequently occurring StockCodes are kept as goods.
TOP_N_PRODUCTS = 50
# Default for load_uci_retail's ``min_transactions``: customers with fewer
# active monthly periods than this are skipped.
MIN_TRANSACTIONS = 5
def _find_data_dir(data_dir: str | Path | None) -> Path:
    """Locate the directory that holds the UCI Online Retail data file.

    Candidate directories are tried in this order; the first one that exists
    and contains a recognised data file is returned:

    1. ``data_dir``, when it is not None.
    2. ``$PYREVEALED_DATA_DIR/uci_retail``, when the variable is set.
    3. ``~/.prefgraph/data/uci_retail``.
    4. ``<ancestor>/datasets/uci_retail/data`` where ``<ancestor>`` is
       ``parents[3]`` of this module's resolved path (skipped when the
       module path is too shallow to have such an ancestor).

    Args:
        data_dir: Explicit directory to try first, or None to rely on the
            fallback locations only.

    Returns:
        The first candidate directory containing a data file.

    Raises:
        FileNotFoundError: If no candidate qualifies. The message lists
            every location that was searched.
    """
    data_files = ("online_retail.xlsx", "Online Retail.xlsx", "online_retail.csv")

    candidates: list[Path] = []
    if data_dir is not None:
        candidates.append(Path(data_dir))
    env = os.environ.get("PYREVEALED_DATA_DIR")
    if env:
        candidates.append(Path(env) / "uci_retail")
    candidates.append(Path.home() / ".prefgraph" / "data" / "uci_retail")

    # Indexing parents[3] blindly raises IndexError when this file lives
    # fewer than four levels below the filesystem root, which would abort the
    # lookup even for a perfectly valid explicit data_dir.
    ancestors = Path(__file__).resolve().parents
    if len(ancestors) > 3:
        candidates.append(ancestors[3] / "datasets" / "uci_retail" / "data")

    for directory in candidates:
        if directory.is_dir() and any(
            (directory / name).exists() for name in data_files
        ):
            return directory

    searched = "\n ".join(str(c) for c in candidates)
    raise FileNotFoundError(
        f"UCI Online Retail data not found. Searched:\n {searched}\n\n"
        "Download from: https://archive.ics.uci.edu/ml/datasets/Online+Retail\n"
        "Then pass data_dir= or set PYREVEALED_DATA_DIR environment variable."
    )
def _load_raw(data_path: Path) -> "pd.DataFrame":
    """Read the raw Online Retail table from ``data_path``.

    The Excel file names are probed first, then the CSV name; the first file
    that exists is parsed with the matching pandas reader.

    Args:
        data_path: Directory expected to contain the data file.

    Returns:
        The unprocessed table as a pandas DataFrame.

    Raises:
        FileNotFoundError: If none of the known file names exists in
            ``data_path``.
    """
    import pandas as pd

    # (file name, reader) pairs in probing order.
    readers = (
        ("online_retail.xlsx", pd.read_excel),
        ("Online Retail.xlsx", pd.read_excel),
        ("online_retail.csv", pd.read_csv),
    )
    for filename, reader in readers:
        candidate = data_path / filename
        if candidate.exists():
            return reader(candidate)

    raise FileNotFoundError(f"No online_retail file found in {data_path}")
[docs]
def load_uci_retail(
    data_dir: str | Path | None = None,
    n_customers: int | None = None,
    min_transactions: int = MIN_TRANSACTIONS,
    top_n_products: int = TOP_N_PRODUCTS,
) -> BehaviorPanel:
    """Load the UCI Online Retail dataset as a BehaviorPanel.

    Processing steps:

    1. Drop cancelled invoices (``InvoiceNo`` starting with "C") and rows
       without a ``CustomerID``.
    2. Keep rows with ``UnitPrice`` inside
       ``[MIN_UNIT_PRICE, MAX_UNIT_PRICE]`` and ``Quantity > 0``.
    3. Bucket rows into calendar months derived from ``InvoiceDate``.
    4. Keep only the ``top_n_products`` most frequently occurring stock codes.
    5. Build one shared price grid (median unit price per month and product,
       gaps filled forward, then backward, then with the column median).
    6. For each customer, sum quantities per month and product and pair the
       months with positive total quantity with the matching price rows.

    Args:
        data_dir: Path to directory containing online_retail.xlsx (or one of
            the other recognised file names). None falls back to the default
            search locations.
        n_customers: Cap on the number of customer IDs considered (None =
            all). The cap is applied *before* the ``min_transactions``
            filter, so the returned panel may hold fewer customers.
        min_transactions: Minimum number of active months (months with a
            positive total quantity among the selected products) a customer
            needs in order to be included (default 5).
        top_n_products: Number of top products to include (default 50).

    Returns:
        BehaviorPanel with one BehaviorLog per retained customer, keyed
        ``"customer_<CustomerID>"``. Metadata records the dataset name, the
        sorted list of goods and the ``min_transactions`` value used.

    Raises:
        ImportError: If pandas is not installed.
        FileNotFoundError: If no data file can be located.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for dataset loaders. "
            "Install with: pip install 'prefgraph[datasets]'"
        ) from None

    data_path = _find_data_dir(data_dir)
    df = _load_raw(data_path)

    # Filter cancelled orders and missing customers.
    # Each filtered frame is copied before a column is assigned on it:
    # assigning on a boolean-indexed slice triggers pandas'
    # SettingWithCopyWarning (or an error in "raise" mode).
    is_cancelled = df["InvoiceNo"].astype(str).str.startswith("C")
    df = df[~is_cancelled].dropna(subset=["CustomerID"]).copy()
    df["CustomerID"] = df["CustomerID"].astype(int)

    # Filter prices and quantities
    is_valid = (
        df["UnitPrice"].between(MIN_UNIT_PRICE, MAX_UNIT_PRICE)
        & (df["Quantity"] > 0)
    )
    df = df[is_valid].copy()

    # Create monthly period; unparseable dates become NaT and are dropped.
    df["InvoiceDate"] = pd.to_datetime(df["InvoiceDate"], errors="coerce")
    df = df.dropna(subset=["InvoiceDate"]).copy()
    df["period"] = df["InvoiceDate"].dt.to_period("M").astype(str)

    # Select top products by transaction count
    top_products = df["StockCode"].value_counts().head(top_n_products).index.tolist()
    df = df[df["StockCode"].isin(top_products)].copy()
    products = sorted(str(p) for p in top_products)
    # Ensure StockCode is string for consistent handling
    df["StockCode"] = df["StockCode"].astype(str)

    # Build price oracle: one row per period, one column per product.
    periods = sorted(df["period"].unique())
    price_pivot = df.pivot_table(
        values="UnitPrice",
        index="period",
        columns="StockCode",
        aggfunc="median",
    ).reindex(index=periods, columns=products)
    price_pivot = price_pivot.ffill().bfill().fillna(price_pivot.median())
    price_grid = price_pivot.values
    period_to_idx = {p: i for i, p in enumerate(periods)}

    # Build per-customer sessions
    logs: dict[str, BehaviorLog] = {}
    grouped = df.groupby("CustomerID")
    customer_ids = list(grouped.groups.keys())
    if n_customers is not None:
        customer_ids = customer_ids[:n_customers]

    for cid in customer_ids:
        cust_data = grouped.get_group(cid)
        qty_pivot = cust_data.pivot_table(
            values="Quantity",
            index="period",
            columns="StockCode",
            aggfunc="sum",
        ).reindex(columns=products).fillna(0)

        active_periods = qty_pivot[qty_pivot.sum(axis=1) > 0].index.tolist()
        if len(active_periods) < min_transactions:
            continue

        qty_matrix = qty_pivot.loc[active_periods].values
        # Every customer period comes from the same frame that produced
        # `periods`, so the lookup cannot miss.
        price_indices = [period_to_idx[p] for p in active_periods]
        price_matrix = price_grid[price_indices]

        uid = f"customer_{cid}"
        logs[uid] = BehaviorLog(
            cost_vectors=price_matrix,
            action_vectors=qty_matrix,
            user_id=uid,
        )

    return BehaviorPanel(
        _logs=logs,
        metadata={
            "dataset": "uci_retail",
            "goods": products,
            "min_transactions": min_transactions,
        },
    )