
multinomial_dirichlet_inputs should handle categorical/string data #289

@williambdean

Summary

The multinomial_dirichlet_inputs helper currently does no processing: it simply wraps its input in a dict under the key 'x'. The documentation, however, suggests it can handle raw categorical data such as string labels.

Current Behavior

from conjugate.helpers import multinomial_dirichlet_inputs

# This is what the helper actually does:
counts = [5, 3, 8, 2]
inputs = multinomial_dirichlet_inputs(counts)
# Returns: {'x': [5, 3, 8, 2]}  # Just wraps in dict

# This does NOT work despite documentation suggesting it:
responses = ['Excellent', 'Good', 'Fair', 'Good', 'Excellent']
inputs = multinomial_dirichlet_inputs(responses)
# Returns: {'x': ['Excellent', 'Good', 'Fair', 'Good', 'Excellent']}
# This will FAIL when passed to the model!
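The following is a minimal sketch of why raw labels fail downstream. It assumes the model applies the standard multinomial-Dirichlet conjugate update, where the posterior concentration equals the prior concentration plus the per-category counts. `dirichlet_update` is a hypothetical stand-in and is not part of conjugate's API.

```python
def dirichlet_update(prior_alpha, counts):
    # Hypothetical stand-in for the model's conjugate update:
    # alpha_post[k] = alpha[k] + counts[k], one count per category
    if len(prior_alpha) != len(counts):
        raise ValueError("expected one count per category")
    return [a + c for a, c in zip(prior_alpha, counts)]


prior = [1, 1, 1, 1]

# Pre-counted data: one count per category, so the update works
print(dirichlet_update(prior, [5, 3, 8, 2]))  # [6, 4, 9, 3]

# Raw labels: 5 observations for 4 categories, and int + str is undefined anyway
responses = ['Excellent', 'Good', 'Fair', 'Good', 'Excellent']
try:
    dirichlet_update(prior, responses)
except ValueError as err:
    print(err)  # expected one count per category
```

Even when the lengths happen to match, `1 + 'Good'` raises `TypeError`. The counting therefore has to happen before the update.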

Documentation Shows Different Usage

From docs/examples/raw-data-workflow.md:

responses = ['Excellent', 'Good', 'Fair', 'Good', ...]
inputs = multinomial_dirichlet_inputs(responses)  # Implied to work

# But then manually counts anyway:
from collections import Counter
counts = Counter(responses)
response_counts = [responses.count(cat) for cat in categories]

This contradicts the "automatic extraction" promise of the helpers module.
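The two counting lines in that excerpt do not agree on order. `Counter` keeps keys in first-appearance order, so `list(Counter(responses).values())` depends on how the data happens to arrive. A comprehension over an explicit category list follows the intended order and yields 0 for categories that never appear. The category list below is an assumption, because the excerpt does not show how `categories` is defined.

```python
from collections import Counter

responses = ['Good', 'Excellent', 'Fair', 'Good', 'Excellent', 'Excellent']
# Assumed category order; the docs excerpt does not define `categories`
categories = ['Excellent', 'Good', 'Fair', 'Poor']

counts = Counter(responses)

# First-appearance order: Good, Excellent, Fair
by_appearance = list(counts.values())
print(by_appearance)  # [2, 3, 1]

# Explicit order; Counter returns 0 for the unseen 'Poor'
by_category = [counts[cat] for cat in categories]
print(by_category)  # [3, 2, 1, 0]
```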

Expected Behavior

The helper should optionally accept categorical data and convert it to counts:

# Option 1: Auto-detect and count unique values
responses = ['A', 'B', 'A', 'C', 'B', 'A']
inputs = multinomial_dirichlet_inputs(responses)
# Returns: {'x': [3, 2, 1]}  # Counts for A, B, C

# Option 2: Specify categories explicitly for ordering
inputs = multinomial_dirichlet_inputs(responses, categories=['A', 'B', 'C'])
# Returns: {'x': [3, 2, 1]}

Suggested Fix

from collections import Counter


def multinomial_dirichlet_inputs(x, *, categories=None):
    """Extract sufficient statistics for the multinomial_dirichlet model.

    Args:
        x: Either a counts array or categorical observations (strings/labels).
        categories: If x contains categorical data, the category order to use.
            If None and x contains non-numeric data, categories are
            auto-detected.

    Returns:
        Dict with key 'x' containing the count for each category.
    """
    if categories is not None:
        # Convert categorical observations to counts in the given order;
        # Counter returns 0 for categories that never appear
        counts = Counter(x)
        x = [counts[cat] for cat in categories]
    elif hasattr(x, "__iter__") and len(x) > 0:
        # Auto-detect labels: treat x as categorical if its first element is a string
        first = next(iter(x))
        if isinstance(first, str):
            # Counter keeps first-appearance order, so counts follow that order
            x = list(Counter(x).values())

    return {"x": x}
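Two behaviors of the suggested fix deserve attention in review. First, the auto-detect branch returns counts in first-appearance order without the labels, so callers cannot tell which count belongs to which category. Second, the `categories` branch silently ignores any observation not listed in `categories`, such as a typo like 'good' vs 'Good'. The sketch below addresses only the second point. `counts_for_categories` is a hypothetical helper and is not part of conjugate.

```python
from collections import Counter


def counts_for_categories(x, categories):
    """Hypothetical strict variant of the `categories` branch.

    Counts labels in `categories` order and raises on labels that are not
    listed, instead of silently dropping them.
    """
    counts = Counter(x)
    unknown = set(counts) - set(categories)
    if unknown:
        # key=str keeps the message sortable even for mixed label types
        raise ValueError(f"labels not in categories: {sorted(unknown, key=str)}")
    return [counts[cat] for cat in categories]


responses = ['A', 'B', 'A', 'C', 'B', 'A']
print(counts_for_categories(responses, ['A', 'B', 'C']))  # [3, 2, 1]

# A stray label such as a lowercase typo now raises instead of vanishing
try:
    counts_for_categories(responses + ['a'], ['A', 'B', 'C'])
except ValueError as err:
    print(err)  # labels not in categories: ['a']
```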

Alternative

If the helper should remain simple, update the documentation to clearly show that:

  1. The helper only accepts pre-counted arrays
  2. Manual counting is the expected pattern

Also Affects

  • categorical_dirichlet_inputs - same issue

Metadata

Assignees: No one assigned

Labels: docs (Improvements or additions to documentation), enhancement (New feature or request), models (Has to do with the models)

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
