Genera una matriz que contiene todas las combinaciones de elementos tomados de n vectores

Esta pregunta aparece con bastante frecuencia de una forma u otra (ver, por ejemplo, aquí o aquí ). Así que pensé en presentarlo de forma general y proporcionar una respuesta que podría servir para futuras referencias.

Dado un número arbitrario n de vectores de tamaños posiblemente diferentes, genere una matriz n columna cuyas filas describen todas las combinaciones de elementos tomados de esos vectores (producto cartesiano).

Por ejemplo,

 vectors = { [1 2], [3 6 9], [10 20] } 

debería dar

 combs = [ 1 3 10 1 3 20 1 6 10 1 6 20 1 9 10 1 9 20 2 3 10 2 3 20 2 6 10 2 6 20 2 9 10 2 9 20 ] 

La función ndgrid casi da la respuesta, pero tiene una advertencia: n las variables de salida deben definirse explícitamente para llamarlo. Como n es arbitrario, la mejor manera es usar una lista separada por comas (generada a partir de una matriz de celdas con n celdas) para servir como salida. Las n matrices resultantes se concatenan en la matriz de n columnas deseada:

 vectors = { [1 2], [3 6 9], [10 20] }; %// input data: cell array of vectors n = numel(vectors); %// number of vectors combs = cell(1,n); %// pre-define to generate comma-separated list [combs{end:-1:1}] = ndgrid(vectors{end:-1:1}); %// the reverse order in these two %// comma-separated lists is needed to produce the rows of the result matrix in %// lexicographical order combs = cat(n+1, combs{:}); %// concat the n n-dim arrays along dimension n+1 combs = reshape(combs,[],n); %// reshape to obtain desired matrix 

Un poco más simple … si tienes la caja de herramientas de Neural Network, simplemente puedes usar combvec :

 vectors = {[1 2], [3 6 9], [10 20]}; combs = combvec(vectors{:}).' % Use cells as arguments 

que devuelve una matriz en un orden ligeramente diferente:

 combs = 1 3 10 2 3 10 1 6 10 2 6 10 1 9 10 2 9 10 1 3 20 2 3 20 1 6 20 2 6 20 1 9 20 2 9 20 

Si quieres la matriz que está en la pregunta, puedes usar sortrows :

 combs = sortrows(combvec(vectors{:}).') % Or equivalently as per @LuisMendo in the comments: % combs = fliplr(combvec(vectors{end:-1:1}).') 

lo que da

 combs = 1 3 10 1 3 20 1 6 10 1 6 20 1 9 10 1 9 20 2 3 10 2 3 20 2 6 10 2 6 20 2 9 10 2 9 20 

Si miras las combvec internas de combvec (escribe edit combvec en la ventana de comandos), verás que usa un código diferente a la respuesta de @ LuisMendo. No puedo decir cuál es más eficiente en general.

Si tiene una matriz cuyas filas son similares a la matriz de celdas anterior, puede usar:

 vectors = [1 2;3 6;10 20]; vectors = num2cell(vectors,2); combs = sortrows(combvec(vectors{:}).') 

Hice algunas evaluaciones comparativas sobre las dos soluciones propuestas. El código de evaluación comparativa se basa en la función de tiempo , y se incluye al final de esta publicación.

Considero dos casos: tres vectores de tamaño n , y tres vectores de tamaños n/10 , n y n*10 respectivamente (ambos casos dan el mismo número de combinaciones). n se varía hasta un máximo de 240 (elijo este valor para evitar el uso de la memoria virtual en mi computadora portátil).

Los resultados se dan en la siguiente figura. La solución basada en ndgrid se considera que consistentemente toma menos tiempo que combvec . También es interesante notar que el tiempo que toma combvec varía un poco menos regularmente en el caso de diferentes tamaños.

enter image description here


Código de evaluación comparativa

Función para la solución basada en ndgrid :

 function combs = f1(vectors) n = numel(vectors); %// number of vectors combs = cell(1,n); %// pre-define to generate comma-separated list [combs{end:-1:1}] = ndgrid(vectors{end:-1:1}); %// the reverse order in these two %// comma-separated lists is needed to produce the rows of the result matrix in %// lexicographical order combs = cat(n+1, combs{:}); %// concat the n n-dim arrays along dimension n+1 combs = reshape(combs,[],n); 

Función para la solución combvec :

 function combs = f2(vectors) combs = combvec(vectors{:}).'; 

Script para medir el tiempo llamando al timeit sobre estas funciones:

 nn = 20:20:240; t1 = []; t2 = []; for n = nn; %//vectors = {1:n, 1:n, 1:n}; vectors = {1:n/10, 1:n, 1:n*10}; t = timeit(@() f1(vectors)); t1 = [t1; t]; t = timeit(@() f2(vectors)); t2 = [t2; t]; end 

Aquí hay un método de hágalo usted mismo que me hizo reír con deleite, utilizando nchoosek , aunque no es mejor que la solución aceptada de @Luis Mendo.

Para el ejemplo dado, después de 1,000 ejecuciones esta solución tomó mi máquina en promedio 0.00065935 s, frente a la solución aceptada 0.00012877 s. Para vectores más grandes, siguiendo la publicación de benchmarking de @Luis Mendo, esta solución es consistentemente más lenta que la respuesta aceptada. Sin embargo, decidí publicarlo con la esperanza de que tal vez encuentres algo útil al respecto:

Código:

 tic; v = {[1 2], [3 6 9], [10 20]}; L = [0 cumsum(cellfun(@length,v))]; V = cell2mat(v); J = nchoosek(1:L(end),length(v)); J(any(J>repmat(L(2:end),[size(J,1) 1]),2) | ... any(J< =repmat(L(1:end-1),[size(J,1) 1]),2),:) = []; V(J) toc 

da

 ans = 1 3 10 1 3 20 1 6 10 1 6 20 1 9 10 1 9 20 2 3 10 2 3 20 2 6 10 2 6 20 2 9 10 2 9 20 Elapsed time is 0.018434 seconds. 

Explicación:

L obtiene la longitud de cada vector usando cellfun . Aunque cellfun es básicamente un bucle, es eficiente aquí ya que su número de vectores tendrá que ser relativamente bajo para que este problema sea práctico.

V concatena todos los vectores para facilitar el acceso más adelante (esto supone que ingresó todos sus vectores como filas. V 'funcionaría para los vectores de columna).

nchoosek obtiene todas las formas de elegir n=length(v) partir del número total de elementos L(end) . Habrá más combinaciones aquí de lo que necesitamos.

 J = 1 2 3 1 2 4 1 2 5 1 2 6 1 2 7 1 3 4 1 3 5 1 3 6 1 3 7 1 4 5 1 4 6 1 4 7 1 5 6 1 5 7 1 6 7 2 3 4 2 3 5 2 3 6 2 3 7 2 4 5 2 4 6 2 4 7 2 5 6 2 5 7 2 6 7 3 4 5 3 4 6 3 4 7 3 5 6 3 5 7 3 6 7 4 5 6 4 5 7 4 6 7 5 6 7 

Como solo hay dos elementos en v(1) , debemos descartar cualquier fila donde J(:,1)>2 . De manera similar, donde J(:,2)<3 , J(:,2)>5 , etc ... Usando L y repmat podemos determinar si cada elemento de J está en su rango apropiado, y luego usar any para descartar filas que tiene algún elemento malo

Finalmente, estos no son los valores reales de v , solo índices. V(J) devolverá la matriz deseada.