# Source code for moe.tests.bandit.epsilon.epsilon_test

# -*- coding: utf-8 -*-
"""Test epsilon bandit implementation (functions common to epsilon bandit).

Test functions in :class:`moe.bandit.epsilon.epsilon_interface.EpsilonInterface`

"""
import pytest

import logging

from moe.bandit.epsilon.epsilon_interface import EpsilonInterface
from moe.tests.bandit.epsilon.epsilon_test_case import EpsilonTestCase


@pytest.fixture()
def disable_logging(request):
    """Silence logging while the requesting test runs.

    ``logging.disable(logging.CRITICAL)`` suppresses every logging call at
    CRITICAL severity or below. A finalizer is registered on ``request`` so
    that ``logging.disable(logging.NOTSET)`` runs when the test finishes,
    re-enabling logging so later test cases are unaffected.
    """
    logging.disable(logging.CRITICAL)
    # Registered after disabling, so teardown always undoes exactly this change.
    request.addfinalizer(lambda: logging.disable(logging.NOTSET))
class TestEpsilon(EpsilonTestCase):

    """Check that ``get_winning_arm_names`` picks the right arms for various ``sample_arms`` inputs."""

    @pytest.mark.usefixtures("disable_logging")
    def test_empty_arm_invalid(self):
        """An empty ``sample_arms`` mapping must raise a ValueError."""
        pytest.raises(ValueError, EpsilonInterface.get_winning_arm_names, {})

    def test_two_unsampled_arms(self):
        """Both arms win when neither of the two arms has been sampled.

        Covers the case num_winning_arms == num_arms > 1.
        """
        arms_sampled = self.two_unsampled_arms_test_case.arms_sampled
        expected_winners = frozenset(("arm1", "arm2"))
        assert EpsilonInterface.get_winning_arm_names(arms_sampled) == expected_winners

    def test_three_arms_two_winners(self):
        """Exactly the two expected arms win in the three-arm, two-winner case.

        Covers the case num_arms > num_winning_arms > 1.
        """
        arms_sampled = self.three_arms_two_winners_test_case.arms_sampled
        expected_winners = frozenset(("arm1", "arm2"))
        assert EpsilonInterface.get_winning_arm_names(arms_sampled) == expected_winners