Cálculo rápido de cuadrados bignum

Para acelerar mis divisiones bignum, necesito acelerar la operación y = x^2 para los elementos grandes que se representan como matrices dinámicas de DWORD sin signo. Para ser claro:

 DWORD x[n+1] = { LSW, ......, MSW }; 
  • donde n + 1 es el número de DWORDs usadas
  • entonces el valor del número x = x[0]+x[1]<<32 + ... x[N]<<32*(n)

La pregunta es: ¿cómo calculo y = x^2 más rápido posible sin pérdida de precisión? – Utilizando C ++ y con aritmética de enteros (32 bits con Carry) a disposición.

Mi enfoque actual es aplicar la multiplicación y = x*x y evitar multiplicaciones múltiples.

Por ejemplo:

 x = x[0] + x[1]<<32 + ... x[n]<<32*(n) 

Para simplificar, déjame reescribirlo:

 x = x0+ x1 + x2 + ... + xn 

donde index representa la dirección dentro de la matriz, entonces:

 y = x*x y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn) y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn) y0 = x0*x0 y1 = x1*x0 + x0*x1 y2 = x2*x0 + x1*x1 + x0*x2 y3 = x3*x0 + x2*x1 + x1*x2 ... y(2n-3) = xn(n-2)*x(n ) + x(n-1)*x(n-1) + x(n )*x(n-2) y(2n-2) = xn(n-1)*x(n ) + x(n )*x(n-1) y(2n-1) = xn(n )*x(n ) 

Después de una mirada más cercana, está claro que casi todos los xi*xj aparecen dos veces (no el primero y el último) lo que significa que las multiplicaciones de N*N pueden ser reemplazadas por (N+1)*(N/2) multiplicaciones. PS 32bit*32bit = 64bit por lo que el resultado de cada operación mul+add se maneja como 64+1 bit .

¿Hay una mejor manera de calcular esto rápido? Todo lo que encontré durante las búsquedas fueron algoritmos de sqrts, no sqr …

Rápido sqr

!!! Tenga en cuenta que todos los números en mi código son MSW primero, … no como en la prueba anterior (hay LSW primero por simplicidad de ecuaciones, de lo contrario sería un lío de índice).

Implementación fsqr funcional actual

 void arbnum::sqr(const arbnum &x) { // O((N+1)*N/2) arbnum c; DWORD h, l; int N, nx, nc, i, i0, i1, k; c._alloc(x.siz + x.siz + 1); nx = x.siz - 1; nc = c.siz - 1; N = nx + nx; for (i=0; i<=nc; i++) c.dat[i]=0; for (i=1; i<N; i++) for (i0=0; (i0<=nx) && (i0= i1) break; if (i1 > nx) continue; h = x.dat[nx-i0]; if (!h) continue; l = x.dat[nx-i1]; if (!l) continue; alu.mul(h, l, h, l); k = nc - i; if (k >= 0) alu.add(c.dat[k], c.dat[k], l); k--; if (k>=0) alu.adc(c.dat[k], c.dat[k],h); k--; for (; (alu.cy) && (k>=0); k--) alu.inc(c.dat[k]); } c.shl(1); for (i = 0; i >1; h = x.dat[nx-i0]; if (!h) continue; alu.mul(h, l, h, h); k = nc - i; if (k >= 0) alu.add(c.dat[k], c.dat[k],l); k--; if (k>=0) alu.adc(c.dat[k], c.dat[k], h); k--; for (; (alu.cy) && (k >= 0); k--) alu.inc(c.dat[k]); } c.bits = c.siz<<5; c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1; c.sig = sig; *this = c; } 

Uso de la multiplicación de Karatsuba

(gracias a Calpis)

Implementé la multiplicación de Karatsuba, pero los resultados son mucho más lentos incluso que con el uso de la simple multiplicación O(N^2) , probablemente debido a esa recursión horrible que no veo ninguna forma de evitar. La compensación debe ser en números realmente grandes (más grandes que cientos de dígitos) … pero incluso así, hay muchas transferencias de memoria. ¿Hay alguna manera de evitar las llamadas de recursión (variante no recursiva, … Casi todos los algoritmos recursivos se pueden hacer de esa manera). Aún así, intentaré modificar las cosas y ver qué pasa (evite las normalizaciones, etc., también podría ser un error tonto en el código). De todos modos, después de resolver Karatsuba para el caso x*x no hay mucha ganancia de rendimiento.

Multiplicación optimizada de Karatsuba

Prueba de rendimiento para y = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits :

 x = 0.98765588997654321000000009876... | 98*32 bits sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr mul1[ 363.472 ms ] ... O(N^2) classic multiplication mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication x = 0.98765588997654321000... | 195*32 bits sqr [ 883.01 ms ] mul1[ 1427.02 ms ] mul2[ 1089.84 ms ] x = 0.98765588997654321000... | 389*32 bits sqr [ 3189.19 ms ] mul1[ 5553.23 ms ] mul2[ 3159.07 ms ] 

Después de las optimizaciones para Karatsuba, el código es masivamente más rápido que antes. Aún así, para números más pequeños, es un poco menos de la mitad de la velocidad de mi multiplicación de O(N^2) . Para números más grandes, es más rápido con la proporción dada por las complejidades de las multiplicaciones de Booth. El umbral para la multiplicación es de alrededor de 32 * 98 bits y para sqr alrededor de 32 * 389 bits, por lo que si la sum de los bits de entrada cruza este umbral, la multiplicación de Karatsuba se utilizará para acelerar la multiplicación y también para sqr.

Por cierto, optimizaciones incluidas:

  • Minimice la destrucción de montón por argumento de recursión demasiado grande
  • En su lugar, se utiliza la evitación de cualquier aritmética de bignum (+, -) ALU de 32 bits con acarreo.
  • Ignorando 0*y o x*0 o 0*0 casos
  • Reformateo de los tamaños de los números de entrada x,y a la potencia de dos para evitar la reasignación
  • Implementar la multiplicación de módulo para z1 = (x0 + x1)*(y0 + y1) para minimizar la recursión

Modificación de la multiplicación de Schönhage-Strassen a la implementación de sqr

He probado el uso de transformaciones FFT y NTT para acelerar el cálculo de sqr. Los resultados son estos:

  1. FFT

    Perder precisión y, por lo tanto, necesita números complejos de alta precisión. Esto en realidad ralentiza considerablemente las cosas por lo que no hay aceleración. El resultado no es preciso (se puede redondear incorrectamente), por lo que FFT no se puede usar (por ahora)

  2. NTT

    NTT es campo finito DFT y por lo tanto no se produce pérdida de precisión. Necesita aritmética modular en enteros sin signo: modpow, modmul, modadd y modsub .

    Yo uso DWORD (números enteros sin signo de 32 bits). ¡El tamaño del vector NTT input / otput es limitado debido a problemas de desbordamiento! Para la aritmética modular de 32 bits, N está limitado a (2^32)/(max(input[])^2) por lo que bigint debe dividirse en fragmentos más pequeños (yo uso BYTES para que el tamaño máximo de bigint procesado sea

     (2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs) 

    El sqr usa solo 1xNTT + 1xINTT lugar de 2xNTT + 1xINTT para la multiplicación, pero el uso de NTT es demasiado lento y el tamaño del número de umbral es demasiado grande para el uso práctico en mi implementación (para mul y también para sqr ).

    Es posible que esté incluso por encima del límite de desbordamiento, por lo que se deben usar aritméticas modulares de 64 bits que pueden ralentizar aún más las cosas. Entonces NTT también es inutilizable para mis propósitos.

Algunas medidas:

 a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.177 ms ] fast sqr sqr2[ 720.419 ms ] NTT sqr mul1[ 5.588 ms ] simpe mul mul2[ 3.172 ms ] karatsuba mul mul3[ 1053.382 ms ] NTT mul 

Mi implementación:

 void arbnum::sqr_NTT(const arbnum &x) { // O(N*log(N)*(log(log(N)))) - 1x NTT // Schönhage-Strassen sqr // To prevent NTT overflow: n  result siz  x.siz + y.siz <= 12K!!! int i, j, k, n; int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2; i = x.siz; for (n = 1; n < i; n< 0x3000) { _error(_arbnum_error_TooBigNumber); zero(); return; } n <<= 3; DWORD *xx, *yy, q, qq; xx = new DWORD[n+n]; #ifdef _mmap_h if (xx) mmap_new(xx, (n+n) <= 0; i--) { q = x.dat[i]; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; } for (;k<n;k++) xx[k] = 0; //NTT fourier_NTT ntt; ntt.NTT(yy,xx,n); // init NTT for n // Convolution for (i=0; i<n; i++) yy[i] = modmul(yy[i], yy[i], ntt.p); //INTT ntt.INTT(xx, yy); //suma q=0; for (i = 0, j = 0; i>=8; qq>>=8; q+=qq; } // Merge WORDs to DWORDs and copy them to result _alloc(n>>2); for (i = 0, j = 0; i<siz; i++) { q =(yy[j]<<24)&0xFF000000; j++; q |=(yy[j]<<16)&0x00FF0000; j++; q |=(yy[j]<< 8)&0x0000FF00; j++; q |=(yy[j] )&0x000000FF; j++; dat[i] = q; } #ifdef _mmap_h if (xx) mmap_del(xx); #endif delete xx; bits = siz<<5; sig = s; exp = exp0 + (siz<<5) - 1; // _normalize(); } 

Conclusión

Para números más pequeños, es la mejor opción mi enfoque rápido sqr , y después del umbral la multiplicación de Karatsuba es mejor. Pero todavía creo que debería haber algo trivial que hemos pasado por alto. ¿Alguien tiene otras ideas?

Optimización NTT

Después de optimizaciones intensamente intensas (en su mayoría NTT ): pregunta sobre desbordamiento de stack Aritmética modular y optimizaciones NTT (campo finito DFT) .

Algunos valores han cambiado:

 a = 0.98765588997654321000 | 1553*32bits looped 10x times mul2[ 28.585 ms ] Karatsuba mul mul3[ 26.311 ms ] NTT mul 

Así que ahora la multiplicación de NTT es finalmente más rápida que Karatsuba después de un umbral de aproximadamente 1500 * 32 bits.

Algunas medidas y errores detectados

 a = 0.99991970486 | 1553*32 bits looped: 10x sqr1[ 58.656 ms ] fast sqr sqr2[ 13.447 ms ] NTT sqr mul1[ 102.563 ms ] simpe mul mul2[ 28.916 ms ] Karatsuba mul Error mul3[ 19.470 ms ] NTT mul 

Descubrí que mi Karatsuba (más / menos) fluye el LSB de cada segmento DWORD de bignum. Cuando investigué, actualizaré el código …

Además, después de nuevas optimizaciones NTT los umbrales cambiaron, por lo que para NTT sqr es 310*32 bits = 9920 bits de operando , y para NTT mul es 1396*32 bits = 44672 bits de resultado (sum de bits de operandos).

Código de Karatsuba reparado gracias a @greybeard

 //--------------------------------------------------------------------------- void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n) { // Recursion for Karatsuba // z[2n] = x[n]*y[n]; // n=2^m int i; for (i=0; i<n; i++) if (x[i]) { i=-1; break; } // x==0 ? if (i < 0) for (i = 0; i= 0) { for (i = 0; i < n + n; i++) z[i]=0; return; } // 0.? = 0 if (n == 1) { alu.mul(z[0], z[1], x[0], y[0]); return; } if (n>1; _mul_karatsuba(z+n, x+n2, y+n2, n2); // z0 = x0.y0 _mul_karatsuba(z , x , y , n2); // z2 = x1.y1 DWORD *q = new DWORD[n<=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0] #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0] qq = q; q0 = x + n2; q1 = x; i = n2 - 1; _add; cx = alu.cy; // =x0+x1 qq = q + n2; q0 = y + n2; q1 = y; i = n2 - 1; _add; cy = alu.cy; // =y0+y1 _mul_karatsuba(q + n, q + n2, q, n2); // =(x0+x1)(y0+y1) mod ((2^N)-1) if (cx) { qq = q + n; q0 = qq; q1 = q + n2; i = n2 - 1; _add; cx = alu.cy; }// += cx*(y0 + y1) << n2 if (cy) { qq = q + n; q0 = qq; q1 = q; i = n2 -1; _add; cy = alu.cy; }// +=cy*(x0+x1)<=0; i--) if (alu.cy) alu.inc(z[i]); else break; } delete[] q; #undef _add #undef _sub } //--------------------------------------------------------------------------- void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y) { // O(3*(N)^log2(3)) ~ O(3*(N^1.585)) // Karatsuba multiplication // int s = x.sig*y.sig; arbnum a, b; a = x; b = y; a.sig = +1; b.sig = +1; int i, n; for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1) ; a._realloc(n); b._realloc(n); _alloc(n + n); for (i=0; i < siz; i++) dat[i]=0; _mul_karatsuba(dat, a.dat, b.dat, n); bits = siz << 5; sig = s; exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1; // _normalize(); } //--------------------------------------------------------------------------- 

Mi representación numérica arbnum :

 // dat is MSDW first ... LSDW last DWORD *dat; int siz,exp,sig,bits; 
  • dat[siz] es la mantisa. LSDW significa DWORD menos significativo.
  • exp es el exponente de MSB de dat[0]
  • ¡El primer bit distinto de cero está presente en la mantisa!

     // |-----|---------------------------|---------------|------| // | sig | MSB mantisa LSB | exponent | bits | // |-----|---------------------------|---------------|------| // | +1 | 0.(0 ... 0) | 2^0 | 0 | +zero // | -1 | 0.(0 ... 0) | 2^0 | 0 | -zero // |-----|---------------------------|---------------|------| // | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | +number // | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | -number // |-----|---------------------------|---------------|------| // | +1 | 1.0 | 2^+0x7FFFFFFE | 1 | +infinity // | -1 | 1.0 | 2^+0x7FFFFFFE | 1 | -infinity // |-----|---------------------------|---------------|------| 

Si entiendo tu algoritmo correctamente, parece O(n^2) donde n es el número de dígitos.

¿Has mirado Algoritmo Karatsuba ? Acelera la multiplicación usando el enfoque de dividir y conquistar. Puede valer la pena echarle un vistazo.

Si está buscando escribir un nuevo exponente mejor, es posible que deba escribirlo en ensamblaje. Este es el código de golang.

https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s