"""Olist Brazilian E-Commerce dataset loader.
Loads the Olist dataset of ~100K orders from ~96K customers across
Brazilian marketplaces, returning a BehaviorPanel.
Key: Olist anonymizes customer_id per order. The true persistent
identifier is customer_unique_id in olist_customers_dataset.csv.
~3K customers have 2+ orders; ~250 have 3+.
Data must be downloaded separately from Kaggle:
https://www.kaggle.com/datasets/olistbr/brazilian-ecommerce
"""
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
from prefgraph.core.panel import BehaviorPanel
from prefgraph.core.session import BehaviorLog
# --- Constants ---
# Default goods: English category names (as they appear in the Olist
# translation table), in the order they are preferred when ``n_categories``
# truncates the list.
TOP_CATEGORIES = [
    "bed_bath_table",
    "health_beauty",
    "sports_leisure",
    "furniture_decor",
    "computers_accessories",
    "housewares",
    "watches_gifts",
    "telephony",
    "garden_tools",
    "auto",
    "toys",
    "cool_stuff",
    "perfumery",
    "baby",
    "electronics",
    "stationery",
    "fashion_bags_accessories",
    "pet_shop",
    "office_furniture",
    "luggage_accessories",
]
NUM_CATEGORIES = len(TOP_CATEGORIES)  # 20

# Order items whose unit price falls outside this inclusive range are
# dropped before aggregation.
MIN_UNIT_PRICE = 1e-2
MAX_UNIT_PRICE = 5_000.0
def _find_data_dir(data_dir: str | Path | None) -> Path:
"""Find Olist data directory via cascade."""
candidates = []
if data_dir is not None:
candidates.append(Path(data_dir))
env = os.environ.get("PYREVEALED_DATA_DIR")
if env:
candidates.append(Path(env) / "olist")
candidates.extend([
Path.home() / ".prefgraph" / "data" / "olist",
Path(__file__).resolve().parents[3] / "olist" / "data",
])
for d in candidates:
if d.is_dir() and (d / "olist_orders_dataset.csv").exists():
return d
searched = "\n ".join(str(c) for c in candidates)
raise FileNotFoundError(
f"Olist data not found. Searched:\n {searched}\n\n"
"Download from Kaggle: https://www.kaggle.com/datasets/olistbr/brazilian-ecommerce\n"
"Then pass data_dir= or set PYREVEALED_DATA_DIR environment variable."
)
def _select_categories(category_counts, n_categories: int) -> list[str]:
    """Choose up to ``n_categories`` product categories present in the data.

    Entries of ``TOP_CATEGORIES[:n_categories]`` that occur in
    ``category_counts.index`` are kept first, in list order. Any shortfall
    is filled from the remaining labels of ``category_counts`` in their
    existing order (most frequent first when it comes from ``value_counts``).
    """
    present = set(category_counts.index)
    selected = [c for c in TOP_CATEGORIES[:n_categories] if c in present]
    if len(selected) < n_categories:
        chosen = set(selected)
        extras = [c for c in category_counts.index if c not in chosen]
        selected.extend(extras[: n_categories - len(selected)])
    return selected[:n_categories]


def _build_price_oracle(rows, categories: list[str]):
    """Return a (months x categories) DataFrame of median unit prices.

    ``rows`` needs ``year_month``, ``category`` and ``price`` columns.
    Months with no sales in a category are filled forward, then backward
    in time. Next they fall back to that category's overall median. A
    final 1.0 is used for categories with no rows at all, so every cell
    holds a numeric price.
    """
    oracle = rows.pivot_table(
        values="price", index="year_month", columns="category",
        aggfunc="median",
    ).reindex(columns=categories)
    oracle = oracle.ffill().bfill()
    # A Series passed to DataFrame.fillna is aligned on column labels.
    global_medians = rows.groupby("category")["price"].median()
    oracle = oracle.fillna(global_medians)
    return oracle.fillna(1.0)  # absolute fallback


def _olist_metadata(
    categories: list[str],
    min_months: int,
    min_orders: int,
    n_customers: int,
    total_months: int,
) -> dict:
    """Build the panel-level metadata dict shared by every return path."""
    return {
        "dataset": "olist",
        "goods": categories,
        "n_categories": len(categories),
        "min_months": min_months,
        "min_orders": min_orders,
        "n_customers": n_customers,
        "total_months": total_months,
    }


def load_olist(
    data_dir: str | Path | None = None,
    n_customers: int | None = None,
    min_months: int = 2,
    min_orders: int = 3,
    n_categories: int = NUM_CATEGORIES,
) -> BehaviorPanel:
    """Load Olist Brazilian E-Commerce dataset as a BehaviorPanel.

    Joins orders, order items, products, and the customer identity table
    to build monthly budget vectors (price x quantity) across product
    categories per customer.

    Olist anonymizes customer_id per order; the true repeat-buyer key is
    customer_unique_id from olist_customers_dataset.csv. ~3K customers
    have 2+ orders, ~250 have 3+.

    Data selection and pricing rules:
      * Only orders with ``order_status == "delivered"`` are used.
      * An item is kept only if its unit price lies in
        ``[MIN_UNIT_PRICE, MAX_UNIT_PRICE]`` and its category has an
        English translation.
      * Each order-item row counts as quantity 1.
      * Prices come from a market-wide oracle: the median unit price per
        category per month over *all* customers' retained items, not only
        the returned customers' items.

    Args:
        data_dir: Path to directory containing Olist CSV files.
            If None, searches standard locations.
        n_customers: Maximum number of customers in the returned panel
            (None = all). Candidates are visited in sorted
            customer_unique_id order. Customers dropped by ``min_months``
            do not count toward the limit.
        min_months: Minimum active months per customer (default 2).
        min_orders: Minimum number of distinct orders per customer,
            counted over items in the selected categories (default 3).
            Most Olist customers have only 1 order.
        n_categories: Number of product categories to use (default 20).
            TOP_CATEGORIES entries present in the data come first. Any
            shortfall is filled with the most frequent remaining categories.

    Returns:
        BehaviorPanel with one BehaviorLog per customer (rows = active
        months, cols = product categories). The panel metadata
        ``total_months`` is the number of months covered by the price
        oracle. It is 0 when no customer passes the filters.

    Raises:
        ValueError: If ``n_categories`` is negative, or if ``n_customers``
            is negative.
        FileNotFoundError: If data files cannot be found.
        ImportError: If pandas is not installed.
    """
    # Negative values would otherwise hit Python slice semantics
    # (e.g. TOP_CATEGORIES[:-1]) and silently return odd results.
    if n_categories < 0:
        raise ValueError(f"n_categories must be >= 0, got {n_categories}")
    if n_customers is not None and n_customers < 0:
        raise ValueError(
            f"n_customers must be >= 0 or None, got {n_customers}"
        )

    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for dataset loaders. "
            "Install with: pip install 'prefgraph[datasets]'"
        ) from None

    data_path = _find_data_dir(data_dir)

    # --- Load CSVs ---
    orders = pd.read_csv(
        data_path / "olist_orders_dataset.csv",
        usecols=["order_id", "customer_id", "order_status",
                 "order_purchase_timestamp"],
        parse_dates=["order_purchase_timestamp"],
    )
    customers = pd.read_csv(
        data_path / "olist_customers_dataset.csv",
        usecols=["customer_id", "customer_unique_id"],
    )
    items = pd.read_csv(
        data_path / "olist_order_items_dataset.csv",
        usecols=["order_id", "product_id", "price"],
    )
    products = pd.read_csv(
        data_path / "olist_products_dataset.csv",
        usecols=["product_id", "product_category_name"],
    )
    translation = pd.read_csv(
        data_path / "product_category_name_translation.csv",
        encoding="utf-8-sig",
    )

    # --- Join and filter ---
    orders = orders.loc[orders["order_status"] == "delivered"]
    # Resolve the persistent customer identity for each order.
    orders = orders.merge(customers, on="customer_id", how="left")
    merged = (
        items.merge(products, on="product_id", how="left")
        .merge(translation, on="product_category_name", how="left")
        .merge(
            orders[["order_id", "customer_unique_id",
                    "order_purchase_timestamp"]],
            on="order_id", how="inner",
        )
    )
    # Use English category names and drop items without a translation.
    merged["category"] = merged["product_category_name_english"]
    merged = merged.dropna(subset=["category"])
    merged = merged.loc[
        merged["price"].between(MIN_UNIT_PRICE, MAX_UNIT_PRICE)
    ]

    # --- Select categories and build monthly period key ---
    categories = _select_categories(
        merged["category"].value_counts(), n_categories,
    )
    # .copy() so the column assignment below targets an owned frame,
    # not a view of the filtered slice.
    merged = merged.loc[merged["category"].isin(categories)].copy()
    merged["year_month"] = (
        merged["order_purchase_timestamp"].dt.to_period("M")
    )
    # All delivered, in-category rows (every customer) feed the price oracle.
    all_rows = merged

    # --- Keep only repeat customers (by distinct orders) ---
    order_counts = (
        merged.groupby("customer_unique_id")["order_id"].nunique()
    )
    repeat_customers = order_counts[order_counts >= min_orders].index
    buyers = merged.loc[merged["customer_unique_id"].isin(repeat_customers)]
    if buyers.empty:
        return BehaviorPanel(
            _logs={},
            metadata=_olist_metadata(
                categories, min_months, min_orders, 0, 0,
            ),
        )

    # --- Price oracle: median price per category per month ---
    price_oracle = _build_price_oracle(all_rows, categories)

    # --- Aggregate quantity per customer-month-category ---
    # Each order item counts as quantity 1 (marketplace items).
    agg = (
        buyers.assign(quantity=1)
        .groupby(
            ["customer_unique_id", "year_month", "category"], observed=True,
        )
        .agg(
            total_qty=("quantity", "sum"),
            # "min" (not "first") so this is the earliest purchase,
            # independent of CSV row order.
            first_timestamp=("order_purchase_timestamp", "min"),
        )
        .reset_index()
    )

    # --- Build per-customer BehaviorLogs ---
    logs: dict[str, BehaviorLog] = {}
    for cust_id, cust_data in agg.groupby("customer_unique_id"):
        # Limit counts customers actually emitted, so customers rejected by
        # min_months do not eat into the n_customers budget.
        if n_customers is not None and len(logs) >= n_customers:
            break

        # Quantity matrix (months x categories).
        qty_pivot = cust_data.pivot_table(
            values="total_qty", index="year_month", columns="category",
            aggfunc="sum",
        ).reindex(columns=categories).fillna(0)
        active = qty_pivot.loc[qty_pivot.sum(axis=1) > 0]
        if len(active) < min_months:
            continue
        active_months = active.index.tolist()

        # Label-based lookup keeps price rows aligned with active months.
        price_matrix = price_oracle.loc[active_months].to_numpy(
            dtype=np.float64,
        )

        # Earliest purchase timestamp in each active month.
        month_first = cust_data.groupby("year_month")["first_timestamp"].min()
        timestamps = [str(month_first.loc[m]) for m in active_months]

        uid = f"customer_{cust_id}"
        logs[uid] = BehaviorLog(
            cost_vectors=price_matrix,
            action_vectors=active.to_numpy(dtype=np.float64),
            user_id=uid,
            metadata={
                "order_purchase_timestamps": timestamps,
                "active_months": [str(m) for m in active_months],
            },
        )

    return BehaviorPanel(
        _logs=logs,
        metadata=_olist_metadata(
            categories, min_months, min_orders, len(logs),
            len(price_oracle.index),
        ),
    )