Source code for moe.tests.bandit.utils_test

# -*- coding: utf-8 -*-
"""Tests for functions in utils."""
import pytest

import logging

from moe.bandit.utils import get_winning_arm_names_from_payoff_arm_name_list, get_equal_arm_allocations
from moe.tests.bandit.bandit_test_case import BanditTestCase


@pytest.fixture()
[docs]def disable_logging(request): """Disable logging (for the duration of this test case).""" logging.disable(logging.CRITICAL) def finalize(): """Re-enable logging (so other test cases are unaffected).""" logging.disable(logging.NOTSET) request.addfinalizer(finalize)
[docs]class TestUtils(BanditTestCase): """Tests :func:`moe.bandit.utils.get_winning_arm_names_from_payoff_arm_name_list` and :func:`moe.bandit.utils.get_equal_arm_allocations`.""" @pytest.mark.usefixtures("disable_logging")
[docs] def test_get_winning_arm_names_from_payoff_arm_name_list_empty_list_invalid(self): """Test empty ``payoff_arm_name_list`` causes an ValueError.""" with pytest.raises(ValueError): get_winning_arm_names_from_payoff_arm_name_list([])
[docs] def test_get_winning_arm_names_from_payoff_arm_name_list_one_winner(self): """Test winning arm name matches the winner.""" assert get_winning_arm_names_from_payoff_arm_name_list([(0.5, "arm1"), (0.0, "arm2")]) == frozenset(["arm1"])
[docs] def test_get_winning_arm_names_from_payoff_arm_name_list_two_winners(self): """Test winning arm names match the winners.""" assert get_winning_arm_names_from_payoff_arm_name_list([(0.5, "arm1"), (0.5, "arm2"), (0.4, "arm3")]) == frozenset(["arm1", "arm2"])
@pytest.mark.usefixtures("disable_logging")
[docs] def test_get_equal_arm_allocations_empty_arm_invalid(self): """Test empty ``arms_sampled`` causes an ValueError.""" with pytest.raises(ValueError): get_equal_arm_allocations({})
[docs] def test_get_equal_arm_allocations_no_winner(self): """Test allocations split among all sample arms when there is no winner.""" assert get_equal_arm_allocations(self.two_unsampled_arms_test_case.arms_sampled) == {"arm1": 0.5, "arm2": 0.5}
[docs] def test_get_equal_arm_allocations_one_winner(self): """Test all allocation given to the winning arm.""" assert get_equal_arm_allocations(self.three_arms_test_case.arms_sampled, frozenset(["arm1"])) == {"arm1": 1.0, "arm2": 0.0, "arm3": 0.0}
[docs] def test_get_equal_arm_allocations_two_winners(self): """Test allocations split between two winning arms.""" assert get_equal_arm_allocations(self.three_arms_two_winners_test_case.arms_sampled, frozenset(["arm1", "arm2"])) == {"arm1": 0.5, "arm2": 0.5, "arm3": 0.0}