# -*- coding: utf-8 -*-
"""Utilities for bandit."""
def get_winning_arm_names_from_payoff_arm_name_list(payoff_arm_name_list):
    r"""Compute the set of winning arm names based on the given ``payoff_arm_name_list``.

    An arm wins when its payoff equals the largest payoff in the list, so a tie
    on the best payoff yields several winning arms.

    Throws an exception when payoff_arm_name_list is empty.

    :param payoff_arm_name_list: a list of (payoff, arm name) tuples
    :type payoff_arm_name_list: list of (float64, str) tuples
    :return: the set of names of the winning arms
    :rtype: frozenset(str)
    :raise: ValueError when ``payoff_arm_name_list`` is empty.

    """
    if not payoff_arm_name_list:
        raise ValueError('payoff_arm_name_list is empty!')

    # Take the maximum over payoffs only. Comparing whole (payoff, name) tuples would
    # fall back to comparing arm names whenever payoffs tie, which is unnecessary and
    # fails for names that cannot be ordered against each other.
    best_payoff = max(payoff for payoff, _ in payoff_arm_name_list)

    # Keep the name of every arm whose payoff matches the best payoff.
    return frozenset(
        arm_name
        for payoff, arm_name in payoff_arm_name_list
        if payoff == best_payoff
    )
def get_equal_arm_allocations(arms_sampled, winning_arm_names=None):
    r"""Split allocations equally among the given ``winning_arm_names``. If no ``winning_arm_names`` given, split allocations among ``arms_sampled``.

    Every arm in ``arms_sampled`` appears in the result: winning arms each get
    ``1.0 / len(winning_arm_names)`` and all other arms get ``0.0``.

    Throws an exception when arms_sampled is empty.

    :param arms_sampled: a dictionary of arm name to :class:`moe.bandit.data_containers.SampleArm`
    :type arms_sampled: dictionary of (str, SampleArm()) pairs
    :param winning_arm_names: a set of names of the winning arms
    :type winning_arm_names: frozenset(str)
    :return: the dictionary of (arm, allocation) key-value pairs
    :rtype: a dictionary of (str, float64) pairs
    :raise: ValueError when ``arms_sampled`` is empty.

    """
    if not arms_sampled:
        raise ValueError('arms_sampled is empty!')

    # If no ``winning_arm_names`` given, split allocations among ``arms_sampled``.
    # Iterating the dict directly yields its keys on both Python 2 and Python 3.
    if winning_arm_names is None:
        winning_arm_names = frozenset(arms_sampled)

    # NOTE: an explicitly passed empty ``winning_arm_names`` makes this division
    # raise ZeroDivisionError; only ``None`` triggers the "all arms" fallback above.
    winning_arm_allocation = 1.0 / len(winning_arm_names)

    # Split allocation among winning arms, all other arms get allocation of 0.
    return {
        arm_name: winning_arm_allocation if arm_name in winning_arm_names else 0.0
        for arm_name in arms_sampled
    }