# Source code for moe.tests.bandit.data_containers_test

# -*- coding: utf-8 -*-
"""Tests for functions in SampleArm and HistoricalData."""
import pytest

import copy
import logging
import pprint

from moe.bandit.data_containers import HistoricalData, SampleArm
from moe.tests.bandit.bandit_test_case import BanditTestCase


@pytest.fixture()
def disable_logging(request):
    """Silence logging while the requesting test runs.

    Messages at ``logging.CRITICAL`` and below are disabled, and a finalizer is
    registered that resets the threshold to ``logging.NOTSET`` afterwards.

    :param request: object on which the cleanup callback is registered via ``addfinalizer``

    """
    def restore_logging():
        """Turn logging back on so other test cases are unaffected."""
        logging.disable(logging.NOTSET)

    logging.disable(logging.CRITICAL)
    request.addfinalizer(restore_logging)


class TestDataContainers(BanditTestCase):
    """Tests functions in :class:`moe.bandit.data_containers.SampleArm` and :class:`moe.bandit.data_containers.HistoricalData`."""

    def test_sample_arm_str(self):
        """Test SampleArm's __str__ overload operator."""
        for historical_info in self.historical_infos_to_test:
            # ``values()`` instead of the Python-2-only ``itervalues()`` so the test runs on Python 2 and 3.
            for arm in historical_info.arms_sampled.values():
                assert str(arm) == pprint.pformat(arm.json_payload())

    def test_sample_arm_add(self):
        """Test SampleArm's __add__ overload operator."""
        arm1 = SampleArm(win=2, loss=1, total=3)
        arm2 = SampleArm(win=3, loss=2, total=5)
        arm3 = arm1 + arm2
        assert arm3.json_payload() == SampleArm(win=5, loss=3, total=8).json_payload()
        # Verify that the + operator does not modify arm1 and arm2
        assert arm1.json_payload() == SampleArm(win=2, loss=1, total=3).json_payload()
        assert arm2.json_payload() == SampleArm(win=3, loss=2, total=5).json_payload()

        arm1 += arm2
        arm2 += arm1
        assert arm1.json_payload() == SampleArm(win=5, loss=3, total=8).json_payload()
        assert arm2.json_payload() == SampleArm(win=8, loss=5, total=13).json_payload()
        # Verify that modifying arm1 and arm2 does not change arm3
        assert arm3.json_payload() == SampleArm(win=5, loss=3, total=8).json_payload()

    def test_sample_arm_iadd(self):
        """Test SampleArm's __iadd__ overload operator.

        Verify that after x += y, x gets the new value x + y and is still the same object.

        """
        arm1 = SampleArm(win=2, loss=1, total=3)
        arm2 = SampleArm(win=3, loss=2, total=5)
        arm3 = arm1 + arm2

        # Keep a second reference so identity can be checked after the in-place add.
        arm1_before = arm1
        arm1 += arm2
        assert arm1 is arm1_before
        assert arm1.json_payload() == arm3.json_payload()

    @pytest.mark.usefixtures("disable_logging")
    def test_sample_arm_add_arm_with_variance_invalid(self):
        """Test that adding arms with variance causes a ValueError. Neither of the arms can have non-None variance."""
        # The arms are built outside ``pytest.raises`` so that only the addition itself
        # can satisfy the expected ValueError, not an error from the constructor.
        arm_with_variance = SampleArm(win=2, loss=1, total=500, variance=0.1)
        arm_without_variance = SampleArm(win=2, loss=1, total=500, variance=None)

        with pytest.raises(ValueError):
            arm_with_variance.__add__(arm_without_variance)

        with pytest.raises(ValueError):
            arm_without_variance.__add__(arm_with_variance)

    def test_historical_data_str(self):
        """Test HistoricalData's __str__ overload operator."""
        for historical_info in self.historical_infos_to_test:
            assert str(historical_info) == pprint.pformat(historical_info.json_payload())

    def test_historical_data_append_unsampled_arm(self):
        """Test that adding an unsampled arm (already exists in historical info) to HistoricalData does not change anything."""
        # Work on a copy: appending to the shared fixture itself would mutate it for other
        # tests and would make the comparison below trivially true (object compared to itself).
        historical_info = copy.deepcopy(self.two_unsampled_arms_test_case)
        historical_info.append_sample_arms(self.one_arm_test_case.arms_sampled)
        assert historical_info.json_payload() == self.two_unsampled_arms_test_case.json_payload()

    def test_historical_data_append_arms(self):
        """Test that appending arms to HistoricalData updates historical info correctly."""
        historical_info = copy.deepcopy(self.three_arms_test_case)
        historical_info.append_sample_arms(self.three_arms_two_winners_test_case.arms_sampled)
        expected_historical_info = HistoricalData(
            sample_arms={
                "arm1": SampleArm(win=4, loss=2, total=6),
                "arm2": SampleArm(win=3, loss=2, total=5),
                "arm3": SampleArm(win=0, loss=0, total=0),
            }
        )
        assert historical_info.json_payload() == expected_historical_info.json_payload()

    @pytest.mark.usefixtures("disable_logging")
    def test_historical_data_append_arms_with_variance_invalid(self):
        """Test that adding arms with variance causes a ValueError."""
        historical_info = copy.deepcopy(self.three_arms_with_variance_no_unsampled_arm_test_case)
        with pytest.raises(ValueError):
            historical_info.append_sample_arms(
                self.three_arms_with_variance_no_unsampled_arm_test_case.arms_sampled
            )