Skip to content

Shape Tests #141

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 4 commits into
base: ig/normal_model
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions batchglm/train/numpy/glm_norm/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import numpy as np


def ll(scale, loc, x):
resid = loc - x
ll = -.5 * np.log(2 * math.pi) - np.log(scale) - .5 * np.power(resid / scale, 2)
Expand Down
74 changes: 74 additions & 0 deletions tests/numpy/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
import unittest

from utils import get_estimator, get_generated_model

from batchglm.train.numpy.base_glm import BaseModelContainer

logger = logging.getLogger("batchglm")

n_obs = 2000
n_vars = 100
num_batches = 4
num_conditions = 2


def _test_shape_of_model(model_container: BaseModelContainer) -> bool:
"""Check the shape of different fitted/parametrized values against what is epected"""
assert model_container.theta_scale.shape == (model_container.model.num_scale_params, n_vars)
assert model_container.theta_location.shape == (model_container.model.num_loc_params, n_vars)

assert model_container.fim_weight_location_location.shape == (n_obs, n_vars)

assert model_container.hessian_weight_location_location.shape == (n_obs, n_vars)
assert model_container.hessian_weight_location_scale.shape == (n_obs, n_vars)
assert model_container.hessian_weight_scale_scale.shape == (n_obs, n_vars)

assert model_container.jac_scale.shape == (n_vars, model_container.model.num_scale_params)
assert model_container.jac_location.shape == (n_vars, model_container.model.num_loc_params)


class TestShape(unittest.TestCase):

def _test_shape(self) -> bool:
dense_model = get_generated_model(
noise_model=self._model_name, num_conditions=num_conditions, num_batches=num_batches, sparse=False,
n_obs=n_obs, n_vars=n_vars
)
sparse_model = get_generated_model(
noise_model=self._model_name, num_conditions=num_conditions, num_batches=num_batches, sparse=True,
n_obs=n_obs, n_vars=n_vars
)
dense_estimator = get_estimator(
noise_model=self._model_name, model=dense_model, init_location="standard", init_scale="standard"
)
sparse_estimator = get_estimator(
noise_model=self._model_name, model=sparse_model, init_location="standard", init_scale="standard"
)
model_container_dense = dense_estimator.model_container
model_container_sparse = sparse_estimator.model_container
_test_shape_of_model(model_container_dense)
_test_shape_of_model(model_container_sparse)
return True


class TestShapeNB(TestShape):

def __init__(self, *args, **kwargs):
self._model_name = "nb"
super(TestShapeNB, self).__init__(*args, **kwargs)

def test_shape(self) -> bool:
return self._test_shape()

class TestShapeNorm(TestShape):

def __init__(self, *args, **kwargs):
self._model_name = "norm"
super(TestShapeNorm, self).__init__(*args, **kwargs)

def test_shape(self) -> bool:
return self._test_shape()

if __name__ == "__main__":
unittest.main()
7 changes: 4 additions & 3 deletions tests/numpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def get_model(noise_model: str) -> _ModelGLM:


def get_generated_model(
noise_model: str, num_conditions: int, num_batches: int, sparse: bool, mode: Optional[str] = None
noise_model: str, num_conditions: int, num_batches: int, sparse: bool, mode: Optional[str] = None,
n_obs: Optional[int] = 2000, n_vars: Optional[int] = 100,
) -> _ModelGLM:
model = get_model(noise_model=noise_model)

Expand Down Expand Up @@ -85,8 +86,8 @@ def const(offset: float):
raise ValueError(f"Mode {mode} not recognized.")

model.generate_artificial_data(
n_obs=2000,
n_vars=100,
n_obs=n_obs,
n_vars=n_vars,
num_conditions=num_conditions,
num_batches=num_batches,
intercept_scale=True,
Expand Down