# Source code for moe.tests.bandit.ucb.ucb_test_case

# -*- coding: utf-8 -*-
"""Base test case class for UCB tests; includes different cases where unsampled arms are winners."""
from moe.tests.bandit.bandit_test_case import BanditTestCase


class UCBTestCase(BanditTestCase):

    """Base test case for the UCB bandit library.

    The ``_test_*`` methods read ``self.bandit_class`` and several
    ``*_test_case`` historical-data fixtures that are not defined in this
    class. Presumably ``BanditTestCase`` or a concrete subclass provides
    them; verify against those classes.
    """

    def _check_allocation(self, historical_info, expected_allocation, case_name):
        """Build a bandit from ``historical_info`` and assert its arm allocation.

        :param historical_info: historical data passed to ``self.bandit_class``
        :param expected_allocation: dict mapping arm name to expected allocation
        :param case_name: label for the fixture, used only in the failure message
        """
        bandit = self.bandit_class(historical_info)
        allocation = bandit.allocate_arms()
        # Name the fixture and show both dicts so a failure inside a loop over
        # several fixtures is immediately attributable.
        assert allocation == expected_allocation, (
            "{0}: expected allocation {1!r}, got {2!r}".format(
                case_name, expected_allocation, allocation,
            )
        )

    def _test_two_unsampled_arms(self):
        """Check that the two-unsampled-arms case allocates each arm equally.

        The allocation is 0.5 for both arms. This tests
        num_unsampled_arms == num_arms > 1.
        """
        self._check_allocation(
            self.two_unsampled_arms_test_case,
            {"arm1": 0.5, "arm2": 0.5},
            "two_unsampled_arms_test_case",
        )

    def _test_three_arms_one_unsampled_arm(self):
        """Check the three-arms cases where arm3 is the only unsampled arm.

        Covers the integer-payoff, float-payoff and two-winners fixtures.
        Because arm3 is the only unsampled arm in each of them, all of the
        allocation is expected to go to arm3.
        """
        cases = (
            ("three_arms_test_case", self.three_arms_test_case),
            ("three_arms_float_payoffs_test_case", self.three_arms_float_payoffs_test_case),
            ("three_arms_two_winners_test_case", self.three_arms_two_winners_test_case),
        )
        for case_name, historical_info in cases:
            self._check_allocation(
                historical_info,
                {"arm1": 0.0, "arm2": 0.0, "arm3": 1.0},
                case_name,
            )

    def _test_three_arms_two_winners(self):
        """Check that the three-arms case with two winners splits allocation between them.

        Uses the fixture with no unsampled arm. arm1 and arm2 each receive
        0.5 and arm3 receives 0.0. This tests
        num_arms > num_winning_arms > 1.
        """
        self._check_allocation(
            self.three_arms_two_winners_no_unsampled_arm_test_case,
            {"arm1": 0.5, "arm2": 0.5, "arm3": 0.0},
            "three_arms_two_winners_no_unsampled_arm_test_case",
        )