# -*- coding: utf-8 -*-
"""Test class for bandit views."""
import pyramid.testing
import simplejson as json
from moe.bandit.linkers import BANDIT_ENDPOINTS_TO_SUBTYPES
from moe.tests.bandit.bandit_test_case import BanditTestCase
from moe.tests.views.rest_test_case import RestTestCase
from moe.views.schemas.bandit_pretty_view import BanditResponse
class TestBanditViews(BanditTestCase, RestTestCase):

    """Integration test for the /bandit endpoints.

    This class is a template: the class attributes below are ``None`` here and
    must be provided by a subclass. The ``_test_*`` helpers are prefixed with an
    underscore; subclasses are expected to invoke them from their own test
    methods once the attributes are set.

    """

    # Key into ``BANDIT_ENDPOINTS_TO_SUBTYPES`` selecting the subtypes to exercise.
    _endpoint = None  # Define in a subclass
    # Iterable of historical-info objects; each must provide ``json_payload()``
    # and an ``arms_sampled`` mapping keyed by arm name.
    _historical_infos = None  # Define in a subclass
    # Route object whose ``endpoint`` attribute is the URL that gets POSTed to.
    _moe_route = None  # Define in a subclass
    # View class: called with a request, must provide ``get_params_from_request()``.
    _view = None  # Define in a subclass

    # Absolute tolerance used when checking that the allocations sum to 1.0.
    # Summing floats accumulates rounding error, so exact equality is too strict.
    _allocation_sum_tolerance = 1e-9

    @staticmethod
    def _build_json_payload(subtype, historical_info):
        """Create a json_payload to POST to the /bandit/* endpoint with all needed info.

        :param subtype: bandit subtype to request
        :param historical_info: object whose ``json_payload()`` result is embedded
          under the ``'historical_info'`` key
        :return: the payload serialized as a JSON string

        """
        dict_to_dump = {
            'subtype': subtype,
            'historical_info': historical_info.json_payload(),
        }
        return json.dumps(dict_to_dump)

    def _test_historical_info_passed_through(self):
        """Test that the historical infos get passed through to the endpoint.

        For every (subtype, historical_info) combination, build a dummy request
        carrying the payload and check that the view's parsed params contain the
        same ``historical_info`` that was sent.

        """
        for subtype in BANDIT_ENDPOINTS_TO_SUBTYPES[self._endpoint]:
            for historical_info in self._historical_infos:
                # Round-trip through JSON so the request carries exactly what a
                # client would send (serialized, then decoded again).
                json_payload = json.loads(self._build_json_payload(subtype, historical_info))
                request = pyramid.testing.DummyRequest(post=json_payload)
                request.json_body = json_payload
                view = self._view(request)
                params = view.get_params_from_request()

                assert params['historical_info'] == json_payload['historical_info']

    def _test_interface_returns_as_expected(self):
        """Integration test for the bandit endpoints.

        For every (subtype, historical_info) combination, POST the payload to the
        endpoint and check that:

        * the response allocates over exactly the arms that were sampled,
        * every allocation lies in [0, 1],
        * the allocations sum to 1.0 (within ``_allocation_sum_tolerance``).

        """
        for subtype in BANDIT_ENDPOINTS_TO_SUBTYPES[self._endpoint]:
            for historical_info in self._historical_infos:
                json_payload = self._build_json_payload(subtype, historical_info)
                # keys()/values() (rather than iterkeys()/itervalues()) behave the
                # same on Python 2 and Python 3 for this use.
                arm_names = set(historical_info.arms_sampled.keys())

                resp = self.testapp.post(self._moe_route.endpoint, json_payload)
                resp_schema = BanditResponse()
                resp_dict = resp_schema.deserialize(json.loads(resp.body))
                arm_allocations = resp_dict['arm_allocations']

                resp_arm_names = set(arm_allocations.keys())
                assert arm_names == resp_arm_names

                # The allocations should be in range [0, 1]
                # The sum of all allocations should be 1.0.
                total_allocation = 0
                for allocation in arm_allocations.values():
                    assert allocation >= 0
                    assert allocation <= 1
                    total_allocation += allocation

                # Compare with a tolerance: a float sum of valid allocations
                # (e.g. ten arms at 0.1 each) need not equal 1.0 exactly.
                assert abs(total_allocation - 1.0) <= self._allocation_sum_tolerance