uniform dirichlet prior with stickbreaking transform + ADVI  #4733

Open
@harrig12

Description of your problem

My Dirichlet prior does not behave as I would expect when using ADVI. With a uniform concentration (a=1), the posterior density of the last element is way off, and there are lots of divergences in the traceplot.

Reading through #4129, I wonder if it has to do with Km1, because the effect is noticeably exacerbated when the size of the Dirichlet is increased (here I set it to 30 to demonstrate).

I noticed this because my "uniform" prior does not actually look uniform at all. The unexpected behavior lessens over the course of training as my model learns, but in some settings training is greatly hampered by the very biased prior that is apparently being created. I've fit ADVI for only a single step here so as to show this.
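
One way to see how an asymmetry in the last element could arise is sketched below in plain NumPy. It rests on two assumptions that are not confirmed in this issue and should be checked against the installed PyMC3 source: (1) that the transform rewritten in #4129 maps the K-1 free values to the simplex by appending minus their sum and applying a softmax, and (2) that an ADVI approximation fit for a single step is still close to an independent N(0, 1) on each free value. The helper name simplex_backward is hypothetical.

```python
import numpy as np


def simplex_backward(y):
    """Map (n, K-1) unconstrained values to (n, K) points on the simplex.

    ASSUMPTION: the last coordinate is minus the sum of the free ones,
    followed by a softmax. This is not taken from the PyMC3 source.
    """
    y_full = np.concatenate([y, -y.sum(axis=-1, keepdims=True)], axis=-1)
    e = np.exp(y_full - y_full.max(axis=-1, keepdims=True))  # stable softmax
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
K = 30

# ASSUMPTION: stand-in for an untrained mean-field approximation.
y = rng.normal(0.0, 1.0, size=(5000, K - 1))
x = simplex_backward(y)

# The implied last coordinate is a sum of K-1 independent terms,
# so its spread scales like sqrt(K - 1) rather than 1.
implied_last_sd = (-y.sum(axis=1)).std()
frac_first_dominates = (x[:, 0] > 0.5).mean()
frac_last_dominates = (x[:, -1] > 0.5).mean()

print("sd of free coordinates:      ", y.std())
print("sd of implied last coordinate:", implied_last_sd)
print("P(first element > 0.5):", frac_first_dominates)
print("P(last element > 0.5): ", frac_last_dominates)
```

Under these assumptions the implied last coordinate has standard deviation sqrt(K-1), so the imbalance grows with the size of the Dirichlet, which would be consistent with the effect being worse at 30 elements. This is a sketch of a possible mechanism, not a diagnosis.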

Works as expected with NUTS

import pymc3 as pm
import numpy as np
import pandas as pd

with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(30), shape=30)
    trace1 = pm.sample(100, return_inferencedata=False)
    pm.plot_trace(trace1, var_names = 'decomp');
       
pd.DataFrame(trace1['decomp_stickbreaking__']).plot.kde(figsize=(10,4), legend=False);
pd.DataFrame(trace1['decomp']).plot.kde(figsize=(10,4), legend=False);

image

Strange result for the last Dirichlet element when using ADVI

with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(30), shape=30)
    trace2 = pm.ADVI()  # note: this is the ADVI inference object, not a trace
    trace2.fit(1)  # a single step only, to show the starting behavior
    samples = trace2.approx.sample(100)  # draw once so all plots show the same samples
    pm.plot_trace(samples, var_names='decomp');

pd.DataFrame(samples['decomp_stickbreaking__']).plot.kde(figsize=(10,4), legend=False);
pd.DataFrame(samples['decomp']).plot.kde(figsize=(10,4), legend=False);

image
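
To put a number on how far off the last element is, draws can be compared against the exact marginal of a flat Dirichlet. The sketch below uses only NumPy and relies on a standard fact: each component of Dirichlet(1, ..., 1) with K elements is Beta(1, K-1), with mean 1/K and 95th percentile 1 - 0.05**(1/(K-1)). The helper name summarize is hypothetical, and the reference draws come from NumPy rather than PyMC3.

```python
import numpy as np


def summarize(samples):
    """Per-component mean and 95th percentile of (n, K) simplex samples."""
    samples = np.asarray(samples)
    return samples.mean(axis=0), np.quantile(samples, 0.95, axis=0)


rng = np.random.default_rng(0)
K = 30

# Exact draws from the flat Dirichlet, used as a baseline.
ref = rng.dirichlet(np.ones(K), size=5000)
ref_mean, ref_q95 = summarize(ref)

# Closed-form values for a Beta(1, K - 1) marginal.
exact_mean = 1.0 / K
exact_q95 = 1.0 - 0.05 ** (1.0 / (K - 1))

print("exact mean:", exact_mean, " sampled range:", ref_mean.min(), ref_mean.max())
print("exact q95: ", exact_q95, " sampled range:", ref_q95.min(), ref_q95.max())
```

The same helper can be applied to the ADVI draws from the snippet above, for example summarize(trace2.approx.sample(5000)['decomp']), and the last entry of each returned array compared with the others and with the closed-form values.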

Versions and main components

  • PyMC3 Version: 3.11.2
  • Aesara Version: n/a
  • Python Version: 3.8.8
  • Operating system: CentOS Linux 7 (Core)
  • How did you install PyMC3: conda

Labels: VI (Variational Inference), question
