plot_interactive

Interactive Plotly plotting factory for MMM.

This module provides MMMPlotlyFactory, which creates interactive Plotly visualizations from MMM summary data produced by MMMSummaryFactory.

The factory supports:

  • Contributions: Bar charts showing channel/control/seasonality contributions

  • ROAS: Return on Ad Spend analysis with credible intervals

  • Posterior Predictive: Time series with HDI bands comparing actual vs predicted

  • Saturation Curves: Visualize diminishing returns per channel

  • Adstock Curves: Show carryover effects over time

  • Automatic faceting based on custom dimensions (e.g., geo, brand)

  • Both pandas and Polars DataFrames via Narwhals
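The saturation and adstock curves listed above plot the model's fitted transformations. To give a rough idea of the shapes involved, the sketch below assumes one common form of each: geometric adstock (carryover weights decaying as alpha**lag) and logistic saturation. `geometric_adstock_weights` and `logistic_saturation` are hypothetical helpers written for this illustration, not part of this module; the curves the factory actually draws depend on the transformations configured in your model.

```python
import math


def geometric_adstock_weights(alpha: float, l_max: int) -> list[float]:
    """Hypothetical helper: carryover weights alpha**lag for lags 0..l_max-1, normalized to sum to 1."""
    raw = [alpha**lag for lag in range(l_max)]
    total = sum(raw)
    return [w / total for w in raw]


def logistic_saturation(x: float, lam: float) -> float:
    """Hypothetical helper: concave response in [0, 1), (1 - exp(-lam*x)) / (1 + exp(-lam*x))."""
    return (1 - math.exp(-lam * x)) / (1 + math.exp(-lam * x))


# Carryover decays geometrically: each lag keeps a fraction alpha of the previous weight.
weights = geometric_adstock_weights(alpha=0.5, l_max=4)
print([round(w, 3) for w in weights])

# Diminishing returns: equal spend increments add less and less response.
spends = [0.0, 1.0, 2.0, 3.0]
responses = [logistic_saturation(s, lam=1.0) for s in spends]
increments = [b - a for a, b in zip(responses, responses[1:])]
print([round(d, 3) for d in increments])
```

Each successive entry in `increments` is smaller than the last, which is the "diminishing returns" pattern a saturation curve makes visible.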

Examples

Basic Usage via MMM Model

Access the plotting factory directly from a fitted MMM model:

>>> # Posterior predictive with actual vs predicted
>>> fig = mmm.plot_interactive.posterior_predictive()
>>> fig.show()
>>> # Channel contributions over time
>>> fig = mmm.plot_interactive.contributions()
>>> fig.show()
>>> # ROAS analysis aggregated by year
>>> fig = mmm.plot_interactive.roas(frequency="yearly")
>>> fig.show()
>>> # Saturation curves showing diminishing returns
>>> fig = mmm.plot_interactive.saturation_curves()
>>> fig.show()
>>> # Adstock curves showing carryover effects
>>> fig = mmm.plot_interactive.adstock_curves()
>>> fig.show()
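The posterior predictive plot summarizes posterior draws with HDI bands. The highest density interval (HDI) is the narrowest interval containing a given share of the draws. The stdlib sketch below shows the usual sorted-samples estimate; `hdi` is a hypothetical helper written for illustration and is not taken from this module.

```python
import math


def hdi(samples: list[float], prob: float = 0.94) -> tuple[float, float]:
    """Hypothetical helper: narrowest interval containing ceil(prob * n) of the sorted samples."""
    if not samples:
        raise ValueError("samples must be non-empty")
    if not 0 < prob <= 1:
        raise ValueError("prob must be in (0, 1]")
    s = sorted(samples)
    n = len(s)
    # round() guards against floating-point noise in prob * n before ceil().
    k = max(1, math.ceil(round(prob * n, 9)))
    # Width of every window of k consecutive sorted samples.
    widths = [s[i + k - 1] - s[i] for i in range(n - k + 1)]
    # Pick the narrowest window (the first one on ties).
    start = min(range(len(widths)), key=widths.__getitem__)
    return s[start], s[start + k - 1]


# Right-skewed draws: the HDI stays in the dense region and skips the long tail.
draws = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0]
print(hdi(draws, prob=0.8))
```

Unlike an equal-tailed interval, the HDI is not forced to cut the same share from each tail, so for skewed posteriors it can sit well away from the extreme draws.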

Customizing Plots

Control faceting and styling with kwargs:

>>> # ROAS colored by date, grouped by channel
>>> fig = mmm.plot_interactive.roas(frequency="yearly", color="date", x="channel")
>>> fig.show()
>>> # Disable auto-faceting and manually set facet column
>>> fig = mmm.plot_interactive.contributions(
...     facet_col="country", title="Channel Effects by Country"
... )
>>> fig.show()
>>> # Saturation curves faceted by brand
>>> fig = mmm.plot_interactive.saturation_curves(
...     facet_row="brand",
... )
>>> fig.show()

Working with Filtered/Aggregated Data

Create custom factories with filtered or aggregated data:

>>> from pymc_marketing.mmm.summary import MMMSummaryFactory
>>> from pymc_marketing.mmm.plot_interactive import MMMPlotlyFactory
>>> # Aggregate multiple geos into one
>>> agg_data = mmm.data.aggregate_dims(
...     dim="geo", values=["geo_a", "geo_b"], new_label="all_geos"
... )
>>> agg_summary = MMMSummaryFactory(agg_data, mmm)
>>> agg_factory = MMMPlotlyFactory(summary=agg_summary)
>>> fig = agg_factory.roas(frequency="yearly", color="channel", x="date")
>>> fig.show()
>>> # Filter to specific geo
>>> filtered_data = mmm.data.filter_dims(geo="geo_a")
>>> filtered_summary = MMMSummaryFactory(filtered_data, mmm, validate_data=False)
>>> filtered_factory = MMMPlotlyFactory(summary=filtered_summary)
>>> fig = filtered_factory.roas(frequency="yearly", color="channel", x="date")
>>> fig.show()
>>> # Filter by date range
>>> filtered_data = mmm.data.filter_dates(start_date="2024-01-01")
>>> filtered_summary = MMMSummaryFactory(filtered_data, mmm)
>>> filtered_factory = MMMPlotlyFactory(summary=filtered_summary)
>>> fig = filtered_factory.roas(frequency="quarterly", color="channel", x="date")
>>> fig.show()
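Aggregating geos before summarizing need not give the same ROAS as averaging per-geo results. The stdlib sketch below assumes that a pooled ROAS is total contribution divided by total spend. That is an assumption made for illustration, not a statement of what `aggregate_dims` or `MMMSummaryFactory` do internally. The per-geo totals are made up, and `roas` and `pooled_roas` are hypothetical helpers.

```python
# Made-up per-geo totals for one channel (illustrative only, not model output).
geo_totals = {
    "geo_a": {"contribution": 300.0, "spend": 100.0},
    "geo_b": {"contribution": 50.0, "spend": 50.0},
}


def roas(contribution: float, spend: float) -> float:
    """Return on ad spend: contribution generated per unit of spend."""
    return contribution / spend


def pooled_roas(totals: dict[str, dict[str, float]], geos: list[str]) -> float:
    """Hypothetical pooled ROAS: sum contributions and spend first, then divide."""
    contribution = sum(totals[g]["contribution"] for g in geos)
    spend = sum(totals[g]["spend"] for g in geos)
    return roas(contribution, spend)


geos = ["geo_a", "geo_b"]
per_geo = {g: roas(**geo_totals[g]) for g in geos}
mean_of_ratios = sum(per_geo.values()) / len(per_geo)
pooled = pooled_roas(geo_totals, geos)
print(per_geo, mean_of_ratios, pooled)
```

Under this assumption, pooling weights each geo by its spend, so the higher-spend geo pulls the combined figure toward its own ROAS. The mean of per-geo ratios treats every geo equally.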

Classes

MMMPlotlyFactory(summary)

Factory for creating interactive Plotly plots from MMM summary data.