Source code for moe.tests.views.rest.bandit_test

# -*- coding: utf-8 -*-
"""Test class for bandit views."""
import pyramid.testing

import simplejson as json

from moe.bandit.linkers import BANDIT_ENDPOINTS_TO_SUBTYPES
from moe.tests.bandit.bandit_test_case import BanditTestCase
from moe.tests.views.rest_test_case import RestTestCase
from moe.views.schemas.bandit_pretty_view import BanditResponse


[docs]class TestBanditViews(BanditTestCase, RestTestCase): """Integration test for the /bandit endpoints.""" _endpoint = None # Define in a subclass _historical_infos = None # Define in a subclass _moe_route = None # Define in a subclass _view = None # Define in a subclass @staticmethod def _build_json_payload(subtype, historical_info): """Create a json_payload to POST to the /bandit/* endpoint with all needed info.""" dict_to_dump = { 'subtype': subtype, 'historical_info': historical_info.json_payload(), } return json.dumps(dict_to_dump) def _test_historical_info_passed_through(self): """Test that the historical infos get passed through to the endpoint.""" for subtype in BANDIT_ENDPOINTS_TO_SUBTYPES[self._endpoint]: for historical_info in self._historical_infos: # Test default test parameters get passed through json_payload = json.loads(self._build_json_payload(subtype, historical_info)) request = pyramid.testing.DummyRequest(post=json_payload) request.json_body = json_payload view = self._view(request) params = view.get_params_from_request() assert params['historical_info'] == json_payload['historical_info'] def _test_interface_returns_as_expected(self): """Integration test for the bandit endpoints.""" for subtype in BANDIT_ENDPOINTS_TO_SUBTYPES[self._endpoint]: for historical_info in self._historical_infos: json_payload = self._build_json_payload(subtype, historical_info) arm_names = set([arm_name for arm_name in historical_info.arms_sampled.iterkeys()]) resp = self.testapp.post(self._moe_route.endpoint, json_payload) resp_schema = BanditResponse() resp_dict = resp_schema.deserialize(json.loads(resp.body)) resp_arm_names = set([arm_name for arm_name in resp_dict['arm_allocations'].iterkeys()]) assert arm_names == resp_arm_names # The allocations should be in range [0, 1] # The sum of all allocations should be 1.0. total_allocation = 0 for allocation in resp_dict['arm_allocations'].itervalues(): assert allocation >= 0 assert allocation <= 1 total_allocation += allocation assert total_allocation == 1.0