Selección ponderada aleatoria en Java

Quiero elegir un elemento al azar de un conjunto, pero la posibilidad de elegir cualquier artículo debe ser proporcional al peso asociado

Entradas de ejemplo:

item weight ---- ------ sword of misery 10 shield of happy 5 potion of dying 6 triple-edged sword 1 

Entonces, si tengo 4 elementos posibles, la posibilidad de obtener un artículo sin pesos sería de 1 en 4.

En este caso, un usuario debería tener 10 veces más probabilidades de obtener la espada de la miseria que la espada de triple filo.

¿Cómo hago una selección aleatoria ponderada en Java?

Yo usaría un NavigableMap

 public class RandomCollection { private final NavigableMap map = new TreeMap(); private final Random random; private double total = 0; public RandomCollection() { this(new Random()); } public RandomCollection(Random random) { this.random = random; } public RandomCollection add(double weight, E result) { if (weight <= 0) return this; total += weight; map.put(total, result); return this; } public E next() { double value = random.nextDouble() * total; return map.higherEntry(value).getValue(); } } 

Digamos que tengo una lista de animales de perros, gatos y caballos con probabilidades de 40%, 35%, 25% respectivamente

 RandomCollection rc = new RandomCollection<>() .add(40, "dog").add(35, "cat").add(25, "horse"); for (int i = 0; i < 10; i++) { System.out.println(rc.next()); } 

No encontrará un marco para este tipo de problema, ya que la funcionalidad solicitada no es más que una simple función. Haz algo como esto:

 interface Item { double getWeight(); } class RandomItemChooser { public Item chooseOnWeight(List items) { double completeWeight = 0.0; for (Item item : items) completeWeight += item.getWeight(); double r = Math.random() * completeWeight; double countWeight = 0.0; for (Item item : items) { countWeight += item.getWeight(); if (countWeight >= r) return item; } throw new RuntimeException("Should never be shown."); } } 

Ahora hay una clase para esto en Apache Commons: Distribución Enumerada

 Item selectedItem = new EnumeratedDistribution(itemWeights).sample(); 

donde itemWeights es una List> , like (asumiendo la interfaz Item en la respuesta de Arne):

 List> itemWeights = Collections.newArrayList(); for (Item i : itemSet) { itemWeights.add(new Pair(i, i.getWeight())); } 

o en Java 8:

 itemSet.stream().map(i -> new Pair(i, i.getWeight())).collect(toList()); 

Nota: El Pair aquí debe ser org.apache.commons.math3.util.Pair , no org.apache.commons.lang3.tuple.Pair .

Use un método de alias

Si vas a tirar muchas veces (como en un juego), debes usar un método de alias.

El código a continuación es bastante largo la implementación de dicho método de alias, de hecho. Pero esto se debe a la parte de inicialización. La recuperación de elementos es muy rápida (consulte los métodos next y applyAsInt que no se applyAsInt ).

Uso

 Set items = ... ; ToDoubleFunction weighter = ... ; Random random = new Random(); RandomSelector selector = RandomSelector.weighted(items, weighter); Item drop = selector.next(random); 

Implementación

Esta implementación:

  • usa Java 8 ;
  • está diseñado para ser lo más rápido posible (bueno, al menos, traté de hacerlo usando micro-benchmarking);
  • es totalmente seguro para hilos (mantenga un Random en cada hilo para un rendimiento máximo, use ThreadLocalRandom ?);
  • busca elementos en O (1) , a diferencia de lo que se encuentra principalmente en Internet o en StackOverflow, donde las implementaciones ingenuas se ejecutan en O (n) u O (log (n));
  • mantiene los artículos independientes de su peso , por lo que a un artículo se le pueden asignar varios pesos en diferentes contextos.

De todos modos, aquí está el código. (Tenga en cuenta que mantengo una versión actualizada de esta clase ).

 import static java.util.Objects.requireNonNull; import java.util.*; import java.util.function.*; public final class RandomSelector { public static  RandomSelector weighted(Set elements, ToDoubleFunction weighter) throws IllegalArgumentException { requireNonNull(elements, "elements must not be null"); requireNonNull(weighter, "weighter must not be null"); if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); } // Array is faster than anything. Use that. int size = elements.size(); T[] elementArray = elements.toArray((T[]) new Object[size]); double totalWeight = 0d; double[] discreteProbabilities = new double[size]; // Retrieve the probabilities for (int i = 0; i < size; i++) { double weight = weighter.applyAsDouble(elementArray[i]); if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); } discreteProbabilities[i] = weight; totalWeight += weight; } if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); } // Normalize the probabilities for (int i = 0; i < size; i++) { discreteProbabilities[i] /= totalWeight; } return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities)); } private final T[] elements; private final ToIntFunction selection; private RandomSelector(T[] elements, ToIntFunction selection) { this.elements = elements; this.selection = selection; } public T next(Random random) { return elements[selection.applyAsInt(random)]; } private static class RandomWeightedSelection implements ToIntFunction { // Alias method implementation O(1) // using Vose's algorithm to initialize O(n) private final double[] probabilities; private final int[] alias; RandomWeightedSelection(double[] probabilities) { int size = probabilities.length; double average = 1.0d / size; int[] small = new int[size]; int smallSize = 0; int[] large = new int[size]; int largeSize = 0; // Describe a column as either small (below average) or large (above average). for (int i = 0; i < size; i++) { if (probabilities[i] < average) { small[smallSize++] = i; } else { large[largeSize++] = i; } } // For each column, saturate a small probability to average with a large probability. while (largeSize != 0 && smallSize != 0) { int less = small[--smallSize]; int more = large[--largeSize]; probabilities[less] = probabilities[less] * size; alias[less] = more; probabilities[more] += probabilities[less] - average; if (probabilities[more] < average) { small[smallSize++] = more; } else { large[largeSize++] = more; } } // Flush unused columns. while (smallSize != 0) { probabilities[small[--smallSize]] = 1.0d; } while (largeSize != 0) { probabilities[large[--largeSize]] = 1.0d; } } @Override public int applyAsInt(Random random) { // Call random once to decide which column will be used. int column = random.nextInt(probabilities.length); // Call random a second time to decide which will be used: the column or the alias. if (random.nextDouble() < probabilities[column]) { return column; } else { return alias[column]; } } } } 

Si necesita eliminar elementos después de elegir, puede usar otra solución. Agregue todos los elementos en una ‘LinkedList’, cada elemento debe agregarse tantas veces como sea su peso, luego use Collections.shuffle() que, de acuerdo con JavaDoc

Cambia aleatoriamente la lista especificada utilizando una fuente predeterminada de aleatoriedad. Todas las permutaciones ocurren con aproximadamente la misma probabilidad.

Finalmente, obtenga y elimine elementos usando pop() o removeFirst()

 Map map = new HashMap() {{ put("Five", 5); put("Four", 4); put("Three", 3); put("Two", 2); put("One", 1); }}; LinkedList list = new LinkedList<>(); for (Map.Entry entry : map.entrySet()) { for (int i = 0; i < entry.getValue(); i++) { list.add(entry.getKey()); } } Collections.shuffle(list); int size = list.size(); for (int i = 0; i < size; i++) { System.out.println(list.pop()); } 
 public class RandomCollection { private final NavigableMap map = new TreeMap(); private double total = 0; public void add(double weight, E result) { if (weight <= 0 || map.containsValue(result)) return; total += weight; map.put(total, result); } public E next() { double value = ThreadLocalRandom.current().nextDouble() * total; return map.ceilingEntry(value).getValue(); } }