"""One-liner DataFrame-to-scores API.

Provides ``analyze()``: feed it a pandas DataFrame and get rationality scores back.

Example::

    import prefgraph as rp
    results = rp.analyze(df, user_col="user_id",
                         cost_cols=["price_A", "price_B"],
                         action_cols=["qty_A", "qty_B"])
"""
from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Literal
# Metric names handed to ``Engine(metrics=...)`` for budget (wide/long/parquet)
# data whenever the caller leaves ``metrics`` unset.
_DEFAULT_BUDGET_METRICS = ["garp", "ccei", "mpi"]
def _detect_format(
*,
item_col: str | None,
cost_col: str | None,
action_col: str | None,
time_col: str | None,
cost_cols: list[str] | None,
action_cols: list[str] | None,
menu_col: str | None,
choice_col: str | None,
) -> Literal["long", "wide", "menu"]:
"""Detect input format from provided parameters."""
has_long = item_col is not None
has_wide = cost_cols is not None or action_cols is not None
has_menu = menu_col is not None or choice_col is not None
active = sum([has_long, has_wide, has_menu])
if active > 1:
parts = []
if has_long:
parts.append("long-format (item_col)")
if has_wide:
parts.append("wide-format (cost_cols/action_cols)")
if has_menu:
parts.append("menu (menu_col/choice_col)")
raise ValueError(
f"Conflicting format parameters: {', '.join(parts)}. "
f"Provide parameters for exactly one format."
)
if has_long:
return "long"
if has_wide:
return "wide"
if has_menu:
return "menu"
raise ValueError(
"Cannot detect data format. Provide parameters for one of:\n\n"
" Wide format (one row per observation, items as columns):\n"
" rp.analyze(df, cost_cols=['p1','p2'], action_cols=['q1','q2'])\n\n"
" Long format (one row per item per time):\n"
" rp.analyze(df, item_col='product', cost_col='price', "
"action_col='quantity', time_col='week')\n\n"
" Menu choice (one row per observation):\n"
" rp.analyze(df, menu_col='shown_items', choice_col='clicked')"
)
def _check_columns(
df: Any,
fmt: str,
*,
user_col: str,
item_col: str | None,
cost_col: str | None,
action_col: str | None,
time_col: str | None,
cost_cols: list[str] | None,
action_cols: list[str] | None,
menu_col: str | None,
choice_col: str | None,
) -> None:
"""Validate that all referenced columns exist in the DataFrame."""
available = set(df.columns)
def _check(col_name: str | None, param: str, is_default: bool = False) -> None:
if col_name is not None and col_name not in available:
default_hint = (
f" (defaulting to '{col_name}' - set {param}= explicitly)"
if is_default else ""
)
# Suggest close matches
close = [c for c in sorted(available)
if col_name.lower() in c.lower() or c.lower() in col_name.lower()]
suggestion = f" Similar: {close}." if close else ""
raise ValueError(
f"Column '{col_name}'{default_hint} not found. "
f"Available columns: {sorted(available)}.{suggestion}"
)
def _check_list(cols: list[str] | None, param: str) -> None:
if cols is not None:
missing = [c for c in cols if c not in available]
if missing:
raise ValueError(
f"Columns {missing} (from {param}=) not found. "
f"Available columns: {sorted(available)}"
)
if fmt == "wide":
_check_list(cost_cols, "cost_cols")
_check_list(action_cols, "action_cols")
elif fmt == "long":
_check(item_col, "item_col")
_check(cost_col or "price", "cost_col", is_default=cost_col is None)
_check(action_col or "quantity", "action_col", is_default=action_col is None)
_check(time_col or "time", "time_col", is_default=time_col is None)
elif fmt == "menu":
_check(menu_col or "menu", "menu_col", is_default=menu_col is None)
_check(choice_col or "choice", "choice_col", is_default=choice_col is None)
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not code)
def analyze(
    df: Any,
    *,
    user_col: str = "user_id",
    # Long format (transaction logs)
    item_col: str | None = None,
    cost_col: str | None = None,
    action_col: str | None = None,
    time_col: str | None = None,
    # Wide format (pivoted)
    cost_cols: list[str] | None = None,
    action_cols: list[str] | None = None,
    # Menu choice
    menu_col: str | None = None,
    choice_col: str | None = None,
    # Options
    metrics: list[str] | None = None,
    output: Literal["dataframe", "objects"] = "dataframe",
    nan_policy: Literal["raise", "warn", "drop"] = "raise",
    # Legacy aliases
    price_col: str | None = None,
    qty_col: str | None = None,
    price_cols: list[str] | None = None,
    qty_cols: list[str] | None = None,
) -> Any:
    """Score rationality of choices in a pandas DataFrame.

    Auto-detects whether your data is wide-format, long-format (transaction
    logs), or menu-choice based on which parameters you provide.

    Args:
        df: pandas DataFrame containing choice data. A ``str``/``Path`` that
            ends in ``.parquet`` or points at a directory is forwarded to
            ``Engine.analyze_parquet`` instead.
        user_col: Column name for user/household IDs (default ``"user_id"``).
        item_col: (Long format) Column for item/product identifiers.
        cost_col: (Long format) Column for prices/costs.
        action_col: (Long format) Column for quantities/actions.
        time_col: (Long format) Column for time/observation identifiers.
        cost_cols: (Wide format) List of column names for cost vectors.
        action_cols: (Wide format) List of column names for action vectors.
        menu_col: (Menu) Column containing sets/lists of available items.
        choice_col: (Menu) Column containing the chosen item.
        metrics: Engine metrics to compute. Default ``["garp", "ccei", "mpi"]``
            for budget data. Ignored (with a warning) for menu data.
        output: ``"dataframe"`` (default) returns a pandas DataFrame built by
            ``results_to_dataframe``. ``"objects"`` returns a list of
            ``(user_id, result)`` pairs.
        nan_policy: How to handle NaN/Inf values in wide/long data.
            ``"raise"`` (default) raises an error. ``"drop"`` silently removes
            affected rows. ``"warn"`` drops with a warning.
        price_col: Alias for ``cost_col``.
        qty_col: Alias for ``action_col``.
        price_cols: Alias for ``cost_cols``.
        qty_cols: Alias for ``action_cols``.

    Returns:
        pandas DataFrame (default) or, for ``output="objects"``, a list of
        ``(user_id, result)`` pairs. On the parquet shortcut the
        ``"objects"`` form is ``list(result.iterrows())``.

    Raises:
        ImportError: If pandas is not installed.
        TypeError: If ``df`` is not a DataFrame, or a ``*_cols`` parameter is
            a single string instead of a list.
        ValueError: If the DataFrame is empty, ``output`` or ``nan_policy`` is
            not one of the documented values, the format cannot be detected,
            or a referenced column is missing.

    Examples:
        Wide format::

            results = rp.analyze(df,
                cost_cols=["price_A", "price_B"],
                action_cols=["qty_A", "qty_B"])

        Long format (transaction logs)::

            results = rp.analyze(df,
                item_col="product", cost_col="price",
                action_col="quantity", time_col="week")

        Menu choice::

            results = rp.analyze(df,
                menu_col="shown_items", choice_col="clicked")
    """
    # --- Parquet file path shortcut ---
    if isinstance(df, (str, Path)):
        path = Path(df)
        if path.suffix == ".parquet" or path.is_dir():
            from prefgraph.engine import Engine

            engine = Engine(metrics=metrics or _DEFAULT_BUDGET_METRICS)
            result = engine.analyze_parquet(
                path,
                user_col=user_col,
                cost_cols=cost_cols or price_cols,
                action_cols=action_cols or qty_cols,
                item_col=item_col,
                cost_col=cost_col or price_col,
                action_col=action_col or qty_col,
                time_col=time_col,
            )
            if output == "objects":
                return list(result.iterrows())
            return result

    # --- Validate input type ---
    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas is required for analyze(). "
            "Install with: pip install pandas"
        ) from None

    if not isinstance(df, pd.DataFrame):
        hint = ""
        if isinstance(df, pd.Series):
            hint = " To convert a Series: pd.DataFrame(series)."
        elif isinstance(df, dict):
            hint = " To convert a dict: pd.DataFrame(your_dict)."
        elif hasattr(df, 'shape'):  # numpy array
            hint = " To convert a numpy array: pd.DataFrame(array, columns=[...])."
        raise TypeError(
            f"First argument must be a pandas DataFrame, got {type(df).__name__}.{hint}"
        )

    if len(df) == 0:
        raise ValueError("DataFrame is empty (0 rows). Nothing to analyze.")

    # --- Validate option values ---
    if output not in ("dataframe", "objects"):
        raise ValueError(
            f"output must be 'dataframe' or 'objects', got '{output}'."
        )
    # Without this check an unrecognised policy (e.g. a typo) would quietly
    # behave like "drop" further down and discard rows.
    if nan_policy not in ("raise", "warn", "drop"):
        raise ValueError(
            f"nan_policy must be 'raise', 'warn', or 'drop', got '{nan_policy}'."
        )

    # --- Catch string-instead-of-list mistake ---
    # Runs before alias resolution so the message names the parameter the
    # caller actually passed (price_cols/qty_cols rather than their targets).
    for param_name, param_val in [("cost_cols", cost_cols), ("action_cols", action_cols),
                                  ("price_cols", price_cols), ("qty_cols", qty_cols)]:
        if isinstance(param_val, str):
            raise TypeError(
                f"{param_name} must be a list of column names, got a string '{param_val}'. "
                f"Use {param_name}=['{param_val}'] instead."
            )

    # --- Resolve legacy aliases ---
    cost_col = cost_col or price_col
    action_col = action_col or qty_col
    cost_cols = cost_cols or price_cols
    action_cols = action_cols or qty_cols

    # --- Detect format ---
    fmt = _detect_format(
        item_col=item_col,
        cost_col=cost_col,
        action_col=action_col,
        time_col=time_col,
        cost_cols=cost_cols,
        action_cols=action_cols,
        menu_col=menu_col,
        choice_col=choice_col,
    )

    # --- Validate columns exist ---
    available = list(df.columns)
    if user_col not in available:
        # Try to suggest the closest match
        close = [c for c in available if "user" in c.lower() or "id" in c.lower()
                 or "customer" in c.lower() or "household" in c.lower()]
        suggestion = f" Did you mean: {close}?" if close else ""
        raise ValueError(
            f"Column '{user_col}' not found in DataFrame. "
            f"Available columns: {available}.{suggestion}"
        )
    _check_columns(df, fmt, user_col=user_col, item_col=item_col,
                   cost_col=cost_col, action_col=action_col, time_col=time_col,
                   cost_cols=cost_cols, action_cols=action_cols,
                   menu_col=menu_col, choice_col=choice_col)

    # --- Dispatch by format ---
    if fmt == "wide":
        user_ids, results = _analyze_wide(
            df, user_col=user_col,
            cost_cols=cost_cols, action_cols=action_cols,
            metrics=metrics, nan_policy=nan_policy,
        )
    elif fmt == "long":
        user_ids, results = _analyze_long(
            df, user_col=user_col,
            item_col=item_col, cost_col=cost_col,
            action_col=action_col, time_col=time_col,
            metrics=metrics, nan_policy=nan_policy,
        )
    else:  # menu
        if metrics is not None:
            # stacklevel=2 attributes the warning to the caller of analyze().
            warnings.warn(
                "metrics parameter is ignored for menu choice data. "
                "Engine.analyze_menus() always computes SARP/WARP/HM.",
                stacklevel=2,
            )
        user_ids, results = _analyze_menu(
            df, user_col=user_col,
            menu_col=menu_col, choice_col=choice_col,
        )

    # --- Return ---
    if output == "objects":
        return list(zip(user_ids, results))
    from prefgraph.engine import results_to_dataframe

    return results_to_dataframe(results, user_ids=user_ids)
def _handle_nan(
df: Any,
data_cols: list[str],
nan_policy: str,
) -> Any:
"""Handle NaN/Inf values in data columns before analysis."""
import numpy as np
subset = df[data_cols]
has_problem = subset.isnull().any(axis=1) | subset.isin([np.inf, -np.inf]).any(axis=1)
if not has_problem.any():
return df
n_bad = has_problem.sum()
if nan_policy == "raise":
raise ValueError(
f"Found {n_bad} rows with NaN or Inf values in columns {data_cols}. "
f"Options: set nan_policy='drop' to remove them, "
f"nan_policy='warn' to drop with a warning, "
f"or clean your data first with df.dropna(subset={data_cols})."
)
elif nan_policy == "warn":
warnings.warn(
f"Dropping {n_bad} rows with NaN/Inf values in {data_cols}.",
stacklevel=4,
)
# drop or warn: filter out bad rows
return df[~has_problem].copy()
def _analyze_wide(
    df: Any,
    *,
    user_col: str,
    cost_cols: list[str] | None,
    action_cols: list[str] | None,
    metrics: list[str] | None,
    nan_policy: str = "raise",
) -> tuple[list[str], list]:
    """Score wide-format data (one row per observation, one column per item).

    Rows with NaN/Inf in the cost or action columns are handled according to
    ``nan_policy``; the remaining frame is turned into a ``BehaviorPanel`` and
    its engine tuples are scored with ``metrics`` (or the default budget
    metrics when ``metrics`` is falsy).

    Returns:
        ``(panel.user_ids, results)`` where ``results`` is whatever
        ``Engine.analyze_arrays`` returns.

    Raises:
        ValueError: If ``cost_cols`` or ``action_cols`` is ``None``, or if
            panel construction fails with a "could not convert" error
            (re-raised with a friendlier message).
    """
    from prefgraph.core.panel import BehaviorPanel
    from prefgraph.engine import Engine

    required = (
        (cost_cols, "Wide format requires cost_cols (list of column names for prices)."),
        (action_cols, "Wide format requires action_cols (list of column names for quantities)."),
    )
    for value, message in required:
        if value is None:
            raise ValueError(message)

    cleaned = _handle_nan(df, cost_cols + action_cols, nan_policy)

    try:
        panel = BehaviorPanel.from_dataframe(
            cleaned,
            user_col=user_col,
            cost_cols=cost_cols,
            action_cols=action_cols,
        )
    except (ValueError, TypeError) as exc:
        # Only conversion failures get the friendlier message; everything
        # else propagates untouched.
        if "could not convert" not in str(exc).lower():
            raise
        raise ValueError(
            f"Non-numeric data in cost or action columns. "
            f"Ensure all values in {cost_cols} and {action_cols} are numeric. "
            f"Tip: df[cols].dtypes to check types, pd.to_numeric(df[col], errors='coerce') to convert."
        ) from None

    selected_metrics = metrics or _DEFAULT_BUDGET_METRICS
    scores = Engine(metrics=selected_metrics).analyze_arrays(panel.to_engine_tuples())
    return panel.user_ids, scores
def _analyze_long(
    df: Any,
    *,
    user_col: str,
    item_col: str | None,
    cost_col: str | None,
    action_col: str | None,
    time_col: str | None,
    metrics: list[str] | None,
    nan_policy: str = "raise",
) -> tuple[list[str], list]:
    """Score long-format data (one row per item per time per user).

    Unset ``cost_col``/``action_col``/``time_col`` fall back to ``"price"``,
    ``"quantity"`` and ``"time"``. Rows with NaN/Inf in the cost or action
    column are handled according to ``nan_policy``. The frame is then grouped
    by ``user_col`` (sorted) and each group is converted with
    ``BehaviorLog.from_long_format`` before all users are scored in one
    ``Engine.analyze_arrays`` call.

    Returns:
        ``(user_ids, results)`` where ``user_ids`` are the group keys as
        strings, in group order.

    Raises:
        ValueError: If ``item_col`` is ``None``; if a user's conversion fails
            with a ValueError mentioning "duplicate" or "reshape"; or if it
            fails with any error mentioning "could not convert". Other
            conversion errors propagate unchanged.
    """
    from prefgraph.core.session import BehaviorLog
    from prefgraph.engine import Engine

    if item_col is None:
        raise ValueError("Long format requires item_col.")

    _cost = cost_col or "price"
    _action = action_col or "quantity"
    _time = time_col or "time"

    df = _handle_nan(df, [_cost, _action], nan_policy)

    user_ids: list[str] = []
    tuples: list[tuple] = []
    for uid, group in df.groupby(user_col, sort=True):
        uid_str = str(uid)
        try:
            log = BehaviorLog.from_long_format(
                group,
                time_col=_time,
                item_col=item_col,
                cost_col=_cost,
                action_col=_action,
                user_id=uid_str,
            )
        except Exception as exc:
            # One handler on purpose: with a separate ``except ValueError``
            # placed first, a ValueError such as "could not convert string to
            # float" was re-raised raw and never reached the non-numeric check,
            # because only the first matching except clause runs.
            reason = str(exc).lower()
            if isinstance(exc, ValueError) and ("duplicate" in reason or "reshape" in reason):
                raise ValueError(
                    f"User '{uid_str}': duplicate (time, item) pairs found. "
                    f"Each ({_time}, {item_col}) combination must be unique per user. "
                    f"Aggregate duplicates first: "
                    f"df.groupby(['{user_col}','{_time}','{item_col}']).agg(...)."
                ) from None
            if "could not convert" in reason:
                raise ValueError(
                    f"User '{uid_str}': non-numeric data in cost or action columns. "
                    f"Ensure '{_cost}' and '{_action}' contain numeric values."
                ) from None
            raise
        user_ids.append(uid_str)
        tuples.append(log.to_engine_tuple())

    engine = Engine(metrics=metrics or _DEFAULT_BUDGET_METRICS)
    results = engine.analyze_arrays(tuples)
    return user_ids, results
def _analyze_menu(
    df: Any,
    *,
    user_col: str,
    menu_col: str | None,
    choice_col: str | None,
) -> tuple[list[str], list]:
    """Score menu-choice data (sets of items and which one was chosen).

    Unset ``menu_col``/``choice_col`` fall back to ``"menu"`` and
    ``"choice"``. The frame is turned into a ``MenuChoicePanel``, each
    per-user log is converted to an engine tuple, and the tuples are scored
    with ``Engine().analyze_menus``.

    Returns:
        ``(panel.user_ids, results)``.

    Raises:
        TypeError: If panel construction raises a TypeError. The message
            explains the integer-index requirement and the original error is
            attached as ``__cause__``.
    """
    from prefgraph.core.panel import MenuChoicePanel
    from prefgraph.engine import Engine

    _menu = menu_col or "menu"
    _choice = choice_col or "choice"

    try:
        panel = MenuChoicePanel.from_dataframe(
            df, user_col=user_col, menu_col=_menu, choice_col=_choice
        )
    except TypeError as exc:
        # Every TypeError is reported as an item-index problem, so keep the
        # original chained: if the real cause is something else it stays
        # visible in the traceback instead of being discarded.
        raise TypeError(
            f"Menu choice data requires integer item indices. "
            f"If your items are strings, map them to integers first: "
            f"item_map = {{v: i for i, v in enumerate(all_items)}}; "
            f"df['{_menu}'] = df['{_menu}'].apply(lambda m: [item_map[x] for x in m])"
        ) from exc

    tuples = [log.to_engine_tuple() for _, log in panel]
    engine = Engine()  # metrics param only applies to budget data
    results = engine.analyze_menus(tuples)
    return panel.user_ids, results