# Source code for prefgraph.datasets._instacart

"""Instacart Market Basket dataset loader.

Loads the Instacart "Market Basket Analysis" dataset and aggregates
orders at the aisle level (134 aisles). Since individual product prices
are not available, heuristic per-aisle prices are assigned based on
aisle names (keyword matching to price tiers). This gives meaningful
price variation for revealed preference analysis.

Data must be downloaded separately from Kaggle:
  kaggle datasets download -d instacart/market-basket-analysis
  unzip market-basket-analysis.zip -d ~/.prefgraph/data/instacart/

Source: https://www.kaggle.com/c/instacart-market-basket-analysis
License: Competition-specific (research use)
"""

from __future__ import annotations

import os
from pathlib import Path

import numpy as np

from prefgraph.core.panel import BehaviorPanel
from prefgraph.core.session import BehaviorLog


# ---------------------------------------------------------------------------
# Heuristic price tiers: keyword -> $/unit
# ---------------------------------------------------------------------------
# Each entry is ``(keywords, price)``. ``_aisle_price`` lower-cases the aisle
# name and returns the price of the FIRST entry that has any keyword occurring
# as a plain substring of that name (no word-boundary matching). List order is
# therefore significant: more specific keywords must come before generic ones
# (e.g. "ice cream" before "cream", "frozen meat" before the other
# "frozen ..." tiers). Keywords must be lower-case, since only the aisle name
# is lower-cased before matching.
_PRICE_TIERS: list[tuple[list[str], float]] = [
    # Alcohol (most expensive tier)
    (["spirit"], 14.00),
    (["champagne", "specialty wine"], 12.00),
    (["red wine", "white wine"], 10.00),
    (["beer", "cooler"], 8.00),

    # Protein (meat, seafood, poultry)
    (["seafood counter"], 8.00),
    (["packaged seafood", "canned meat seafood"], 5.50),
    (["meat counter", "packaged meat"], 6.00),
    (["poultry counter", "packaged poultry"], 5.50),
    (["hot dog", "bacon", "sausage"], 5.00),
    (["lunch meat"], 4.50),
    (["frozen meat"], 6.00),

    # Health / personal care / baby
    (["vitamin", "supplement"], 8.00),
    (["baby food", "formula"], 5.00),
    (["baby accessor"], 6.00),
    (["baby bath"], 5.50),
    (["diaper", "wipe"], 7.00),
    (["first aid"], 6.00),
    (["cold flu", "allergy", "muscles joints", "pain relief"], 6.50),
    (["oral hygiene"], 4.50),
    (["hair care"], 6.00),
    (["skin care", "facial care"], 6.50),
    (["body lotion", "soap"], 5.50),
    (["deodorant"], 4.50),
    (["shave"], 5.50),
    (["feminine care"], 5.00),
    (["eye ear care"], 5.50),
    (["beauty"], 6.00),
    (["digestion"], 6.00),

    # Cheese
    (["specialty cheese"], 6.00),
    (["packaged cheese", "other cream", "cheese"], 5.00),

    # Dairy / eggs
    (["ice cream"], 4.50),
    (["cream"], 3.50),
    (["butter"], 3.50),
    (["yogurt"], 3.00),
    (["milk"], 3.00),
    (["soy lactosefree"], 3.50),
    (["egg"], 3.00),
    (["pudding"], 3.00),

    # Frozen meals / pizza / appetizers
    (["frozen pizza"], 5.00),
    (["frozen meal"], 4.50),
    (["frozen appetizer", "frozen side"], 4.00),
    (["frozen breakfast"], 3.50),
    (["frozen bread", "frozen dough"], 3.50),
    (["frozen vegan", "frozen vegetarian"], 4.50),
    (["frozen juice"], 3.00),
    (["frozen produce"], 3.00),
    (["frozen dessert"], 4.00),

    # Bakery / bread
    (["bakery dessert"], 4.00),
    (["bread"], 3.00),
    (["bun", "roll", "tortilla", "flat bread"], 3.00),
    (["breakfast bar", "pastri"], 3.50),
    (["breakfast bakery"], 3.00),

    # Bulk bins (before produce/snacks so "bulk" catches these first)
    (["bulk"], 3.00),

    # Fresh produce
    (["fresh fruit"], 2.50),
    (["fresh vegetable"], 2.00),
    (["fresh herb"], 2.50),
    (["fresh pasta"], 3.00),
    (["fresh dip", "tapenade"], 3.50),
    (["packaged produce", "packaged vegetable", "packaged fruit"], 3.00),

    # Pantry condiments (before beverages so "honey/syrup" beats "nectar")
    (["honey", "syrup"], 3.50),

    # Beverages -- specific multi-word matches before generic "energy"/"sport"
    (["energy granola"], 3.00),
    (["energy sport"], 3.00),
    (["protein", "meal replacement"], 5.00),
    (["juice", "nectar"], 3.00),
    (["water", "seltzer", "sparkling"], 2.00),
    (["soft drink"], 2.50),
    (["coffee"], 4.00),
    (["tea"], 3.00),
    (["cocoa", "drink mix"], 3.00),

    # Snacks
    (["chip", "pretzel"], 3.50),
    (["popcorn", "jerky"], 3.50),
    # " cake" keeps its leading space so it cannot fire inside a word such as
    # "pancake" -- the "pancake mix" tier below is checked after this one.
    (["cookie", " cake"], 3.50),
    (["candy", "chocolate"], 3.00),
    (["cracker"], 3.00),
    (["mint", "gum"], 2.00),
    (["trail mix", "snack mix"], 3.50),
    (["fruit vegetable snack"], 3.00),
    (["nut", "seed", "dried fruit"], 3.50),
    (["granola"], 3.50),

    # Pantry staples
    (["canned meal", "bean"], 2.00),
    (["canned jarred vegetable"], 2.00),
    (["canned fruit", "applesauce"], 2.00),
    (["prepared soup", "prepared salad"], 3.50),
    (["soup", "broth", "bouillon"], 2.50),
    (["prepared meal"], 4.50),
    (["pasta sauce"], 2.50),
    (["dry pasta"], 1.50),
    (["grain", "rice", "dried good"], 2.50),
    (["baking ingredient", "baking supplie", "baking decor"], 2.50),
    (["dough", "gelatin", "bake mix"], 2.50),
    (["spice", "season"], 3.00),
    (["condiment"], 2.50),
    (["salad dressing", "topping"], 3.00),
    (["oil", "vinegar"], 3.50),
    (["spread"], 3.00),
    (["preserved dip"], 3.00),
    (["pickle", "olive"], 3.00),
    (["marinade", "meat preparation"], 3.00),
    (["hot cereal", "pancake mix"], 3.00),
    (["cereal"], 3.50),
    (["instant food"], 2.50),
    (["tofu", "meat alternative"], 3.50),

    # International foods
    (["latino"], 2.50),
    (["asian"], 2.50),
    (["indian"], 2.50),
    (["kosher"], 3.00),

    # Household
    (["cleaning product"], 4.00),
    (["dish detergent"], 3.50),
    (["laundry"], 5.00),
    (["trash bag", "liner"], 4.00),
    (["paper good"], 4.50),
    (["air freshener", "candle"], 3.50),
    (["food storage"], 3.50),
    (["plate", "bowl", "cup", "flatware"], 3.00),
    (["kitchen supplie"], 3.00),
    (["more household"], 3.50),

    # Pets
    (["dog food", "dog care"], 5.00),
    (["cat food", "cat care"], 5.00),
    (["pet"], 5.00),

    # Catch-all for "refrigerated", "other", "missing", etc.
    (["refrigerated"], 3.00),
    (["other"], 3.00),
    (["missing"], 3.00),
]

# Fallback $/unit returned by ``_aisle_price`` when no keyword above matches.
_DEFAULT_PRICE = 3.00


def _aisle_price(aisle_name: str) -> float:
    """Map an aisle name to its heuristic $/unit price.

    The name is lower-cased and scanned against ``_PRICE_TIERS`` in order; the
    first tier that has any keyword occurring as a substring supplies the
    price. Names matching no tier get ``_DEFAULT_PRICE``.
    """
    lowered = aisle_name.lower()
    return next(
        (
            tier_price
            for tier_keywords, tier_price in _PRICE_TIERS
            if any(keyword in lowered for keyword in tier_keywords)
        ),
        _DEFAULT_PRICE,
    )


def _find_data_dir(data_dir: str | Path | None) -> Path:
    """Find the Instacart data directory via a cascade of candidates.

    Candidates are tried in this order:

    1. ``data_dir``, if given (a leading ``~`` is expanded);
    2. ``$PYREVEALED_DATA_DIR/instacart``, if that variable is set;
    3. ``~/.prefgraph/data/instacart``;
    4. ``<repo>/datasets/instacart/data``, located relative to this file
       (skipped when this file is not nested deeply enough for it to exist).

    The first candidate that is a directory containing *every* required CSV
    is returned, so an incomplete directory never shadows a complete one.

    Raises:
        FileNotFoundError: If no candidate contains all required files. The
            message lists every directory searched plus download instructions.
    """
    required = (
        "orders.csv",
        "order_products__prior.csv",
        "products.csv",
        "aisles.csv",
    )

    candidates: list[Path] = []
    if data_dir is not None:
        # expanduser() so "~/.prefgraph/data/instacart" (as documented) works.
        candidates.append(Path(data_dir).expanduser())

    env = os.environ.get("PYREVEALED_DATA_DIR")
    if env:
        candidates.append(Path(env).expanduser() / "instacart")

    candidates.append(Path.home() / ".prefgraph" / "data" / "instacart")

    # Source-checkout layout. Guarded because parents[3] raises IndexError
    # when this file sits fewer than four levels below the filesystem root.
    here = Path(__file__).resolve()
    if len(here.parents) > 3:
        candidates.append(here.parents[3] / "datasets" / "instacart" / "data")

    for d in candidates:
        if d.is_dir() and all((d / name).is_file() for name in required):
            return d

    searched = "\n  ".join(str(c) for c in candidates)
    raise FileNotFoundError(
        f"Instacart data not found. Searched:\n  {searched}\n\n"
        "Download from Kaggle:\n"
        "  kaggle datasets download -d instacart/market-basket-analysis\n"
        "  unzip market-basket-analysis.zip -d ~/.prefgraph/data/instacart/\n\n"
        "Required files: " + ", ".join(required)
    )


def load_instacart(
    data_dir: str | Path | None = None,
    max_users: int | None = None,
    min_orders: int = 10,
) -> BehaviorPanel:
    """Load Instacart dataset as a BehaviorPanel.

    Aggregates products at the aisle level (134 aisles). Uses heuristic
    per-aisle prices based on keyword matching of aisle names (e.g. fresh
    produce ~$2, meat/seafood ~$6, alcohol ~$10).

    Only orders with ``eval_set == "prior"`` are used. Each user becomes one
    BehaviorLog whose rows are that user's orders (ascending
    ``order_number``) and whose columns are the aisles observed among all
    selected users (ascending ``aisle_id``). A quantity is the number of
    product rows from that aisle in that order; the price matrix repeats the
    same heuristic per-aisle price vector for every order.

    Args:
        data_dir: Path to directory containing Instacart CSV files.
        max_users: Maximum number of users (None = all). Users are taken in
            ascending ``user_id`` order among those with enough orders.
        min_orders: Minimum orders per user (default 10).

    Returns:
        BehaviorPanel with one BehaviorLog per user.

    Raises:
        ImportError: If pandas is not installed.
        FileNotFoundError: If the Instacart data directory cannot be found.
        ValueError: If no users/order products remain after selection
            (e.g. ``max_users=0`` or ``min_orders`` too large).
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for dataset loaders. "
            "Install with: pip install 'prefgraph[datasets]'"
        ) from None

    data_path = _find_data_dir(data_dir)
    print(f" Loading Instacart data from {data_path}...")

    # Load orders -- use only "prior" (main historical data).
    orders = pd.read_csv(
        data_path / "orders.csv",
        usecols=["order_id", "user_id", "eval_set", "order_number"],
    )
    prior_orders = orders[orders["eval_set"] == "prior"]

    # Select users BEFORE touching the large order-products table, so all
    # later grouping only covers orders of users we keep. groupby sorts its
    # keys, so max_users keeps the lowest qualifying user_ids.
    user_order_counts = prior_orders.groupby("user_id")["order_id"].nunique()
    qualifying_users = user_order_counts[user_order_counts >= min_orders].index
    if max_users is not None:
        qualifying_users = qualifying_users[:max_users]
    if qualifying_users.empty:
        raise ValueError(
            f"No Instacart users with at least {min_orders} prior orders "
            f"(max_users={max_users})."
        )
    prior_orders = prior_orders[prior_orders["user_id"].isin(qualifying_users)]

    # Order-product details, restricted to the selected users' orders.
    order_products = pd.read_csv(
        data_path / "order_products__prior.csv",
        usecols=["order_id", "product_id"],
    )
    order_products = order_products[
        order_products["order_id"].isin(prior_orders["order_id"])
    ]

    # Product -> aisle mapping. Products whose aisle is absent from
    # aisles.csv are dropped (same effect as an inner merge with aisles).
    products = pd.read_csv(
        data_path / "products.csv", usecols=["product_id", "aisle_id"]
    )
    aisles = pd.read_csv(data_path / "aisles.csv")
    products = products[products["aisle_id"].isin(aisles["aisle_id"])]

    # Heuristic price lookup (aisle_id -> $/unit) and aisle names.
    aisle_price_map = {
        aisle_id: _aisle_price(name)
        for aisle_id, name in zip(aisles["aisle_id"], aisles["aisle"])
    }
    aisle_name_map = dict(zip(aisles["aisle_id"], aisles["aisle"]))

    # Count product rows per (order, aisle), then attach user/order number.
    order_products = order_products.merge(products, on="product_id")
    aisle_counts = (
        order_products.groupby(["order_id", "aisle_id"])
        .size()
        .reset_index(name="quantity")
        .merge(
            prior_orders[["order_id", "user_id", "order_number"]],
            on="order_id",
        )
    )
    if aisle_counts.empty:
        raise ValueError(
            "No Instacart order products found for the selected users."
        )

    # Build column index from observed aisles.
    aisle_ids = sorted(aisle_counts["aisle_id"].unique())
    n_cols = len(aisle_ids)
    # Constant price vector (same for every observation).
    price_vector = np.array([aisle_price_map[a] for a in aisle_ids])
    # Column position of every row, computed once for the whole frame
    # (aisle_ids is sorted, so searchsorted gives the exact position).
    aisle_counts = aisle_counts.assign(
        col=np.searchsorted(
            np.asarray(aisle_ids), aisle_counts["aisle_id"].to_numpy()
        )
    )

    # Build per-user BehaviorLogs.
    logs: dict[str, BehaviorLog] = {}
    for user_id, user_data in aisle_counts.groupby("user_id"):
        order_numbers = user_data["order_number"].to_numpy()
        unique_orders = np.unique(order_numbers)  # sorted ascending
        T = len(unique_orders)
        if T < min_orders:
            continue

        # Scatter-add quantities into their (order row, aisle column) cells.
        qty_matrix = np.zeros((T, n_cols))
        np.add.at(
            qty_matrix,
            (
                np.searchsorted(unique_orders, order_numbers),
                user_data["col"].to_numpy(),
            ),
            user_data["quantity"].to_numpy(),
        )

        uid = f"user_{user_id}"
        logs[uid] = BehaviorLog(
            # Replicate price vector across all observations.
            cost_vectors=np.tile(price_vector, (T, 1)),
            action_vectors=qty_matrix,
            user_id=uid,
        )

    # Aisle name list for metadata.
    aisle_names = [aisle_name_map.get(a, f"aisle_{a}") for a in aisle_ids]

    price_range = (price_vector.min(), price_vector.max())
    print(
        f" Built {len(logs)} BehaviorLog objects "
        f"({n_cols} aisles, prices ${price_range[0]:.2f}-${price_range[1]:.2f}/unit)"
    )

    return BehaviorPanel(
        _logs=logs,
        metadata={
            "dataset": "instacart",
            "goods": aisle_ids,
            "aisle_names": aisle_names,
            "n_aisles": n_cols,
            "price_type": "heuristic_per_aisle",
            "price_range": price_range,
        },
    )