Source code for moe.tests.views.rest.bandit_epsilon_test

# -*- coding: utf-8 -*-
"""Test class for bandit_epsilon view."""
import pyramid.testing

import simplejson as json

from moe.bandit.constant import BANDIT_EPSILON_ENDPOINT, EPSILON_SUBTYPES_TO_DEFAULT_HYPERPARAMETER_INFOS, EPSILON_SUBTYPE_FIRST, EPSILON_SUBTYPE_GREEDY
from moe.tests.bandit.bandit_test_case import BanditTestCase
from moe.tests.views.rest.bandit_test import TestBanditViews
from moe.views.constant import BANDIT_EPSILON_MOE_ROUTE
from moe.views.rest.bandit_epsilon import BanditEpsilonView


[docs]class TestBanditEpsilonViews(TestBanditViews): """Integration test for the /bandit/epsilon endpoint.""" _endpoint = BANDIT_EPSILON_ENDPOINT _historical_infos = BanditTestCase.historical_infos_to_test _moe_route = BANDIT_EPSILON_MOE_ROUTE _view = BanditEpsilonView @staticmethod def _build_json_payload(subtype, historical_info, hyperparameter_info=None): """Create a json_payload to POST to the /bandit/epsilon endpoint with all needed info.""" if hyperparameter_info is None: hyperparameter_info = EPSILON_SUBTYPES_TO_DEFAULT_HYPERPARAMETER_INFOS[subtype] dict_to_dump = { 'subtype': subtype, 'historical_info': historical_info.json_payload(), 'hyperparameter_info': hyperparameter_info, } return json.dumps(dict_to_dump)
[docs] def test_epsilon_greedy_hyperparameters_passed_through(self): """Test that the hyperparameters get passed through to the epsilon-greedy endpoint.""" historical_info = self.one_arm_test_case # Test default test parameters get passed through json_payload = json.loads(self._build_json_payload(EPSILON_SUBTYPE_GREEDY, historical_info, EPSILON_SUBTYPES_TO_DEFAULT_HYPERPARAMETER_INFOS[EPSILON_SUBTYPE_GREEDY])) request = pyramid.testing.DummyRequest(post=json_payload) request.json_body = json_payload view = BanditEpsilonView(request) params = view.get_params_from_request() assert params['hyperparameter_info'] == json_payload['hyperparameter_info'] # Test arbitrary epsilons get passed through json_payload['hyperparameter_info']['epsilon'] = 1.0 request = pyramid.testing.DummyRequest(post=json_payload) request.json_body = json_payload view = BanditEpsilonView(request) params = view.get_params_from_request() assert params['hyperparameter_info'] == json_payload['hyperparameter_info']
[docs] def test_epsilon_first_hyperparameters_passed_through(self): """Test that the hyperparameters get passed through to the epsilon-first endpoint.""" historical_info = self.one_arm_test_case # Test default test parameters get passed through json_payload = json.loads(self._build_json_payload(EPSILON_SUBTYPE_FIRST, historical_info, EPSILON_SUBTYPES_TO_DEFAULT_HYPERPARAMETER_INFOS[EPSILON_SUBTYPE_FIRST])) request = pyramid.testing.DummyRequest(post=json_payload) request.json_body = json_payload view = self._view(request) params = view.get_params_from_request() assert params['hyperparameter_info'] == json_payload['hyperparameter_info'] # Test an arbitrary epsilon and total_tamples get passed through json_payload['hyperparameter_info']['epsilon'] = 1.0 json_payload['hyperparameter_info']['total_samples'] = 20000 request = pyramid.testing.DummyRequest(post=json_payload) request.json_body = json_payload view = self._view(request) params = view.get_params_from_request() assert params['hyperparameter_info'] == json_payload['hyperparameter_info']
[docs] def test_historical_info_passed_through(self): """Test that the historical info get passed through to the endpoint.""" self._test_historical_info_passed_through()
[docs] def test_interface_returns_as_expected(self): """Integration test for the /bandit/epsilon endpoint.""" self._test_interface_returns_as_expected()