// Package bigfft implements multiplication of big.Int using FFT.
//
// The implementation is based on the Schönhage-Strassen method
// using integer FFT modulo 2^n+1.
package bigfft

// NOTE(review): this part of the file had lost its identifiers and import
// paths. Function names were restored from call sites and doc comments;
// local names are free choices. The method name String on nat is an
// assumption — confirm against version control.

import (
	"math/big"
	"unsafe"
)

// _W is the size of a big.Word in bits.
const _W = int(unsafe.Sizeof(big.Word(0)) * 8)

// nat is a natural number stored as a slice of words, in the layout
// used by (*big.Int).Bits and (*big.Int).SetBits.
type nat []big.Word

// String formats n in decimal by wrapping it in a big.Int.
func (n nat) String() string {
	v := new(big.Int)
	v.SetBits(n)
	return v.String()
}

// fftThreshold is the size (in words) above which FFT is used over
// Karatsuba from math/big.
//
// TestCalibrate seems to indicate a threshold of 60kbits on 32-bit
// arches and 110kbits on 64-bit arches.
var fftThreshold = 1800

// Mul computes the product x*y and returns z.
// It can be used instead of the Mul method of
// *big.Int from math/big package.
//
// The FFT path is taken only when both operands are longer than
// fftThreshold words; otherwise the work is delegated to math/big.
func Mul(x, y *big.Int) *big.Int {
	xwords := len(x.Bits())
	ywords := len(y.Bits())
	if xwords > fftThreshold && ywords > fftThreshold {
		return mulFFT(x, y)
	}
	return new(big.Int).Mul(x, y)
}

// mulFFT multiplies x and y with fftmul on their absolute values and
// then restores the sign of the product.
func mulFFT(x, y *big.Int) *big.Int {
	var xb, yb nat = x.Bits(), y.Bits()
	zb := fftmul(xb, yb)
	z := new(big.Int)
	z.SetBits(zb)
	// The product is negative exactly when the operands have opposite signs.
	if x.Sign()*y.Sign() < 0 {
		z.Neg(z)
	}
	return z
}

// A FFT size of K=1<<k is adequate when K is about 2*sqrt(N) where
// N = x.Bitlen() + y.Bitlen().

// fftmul returns the product of the natural numbers x and y. Both are
// cut into polynomials with the parameters chosen by fftSize, the
// polynomials are multiplied, and the result is evaluated back.
func fftmul(x, y nat) nat {
	k, m := fftSize(x, y)
	xp := polyFromNat(x, k, m)
	yp := polyFromNat(y, k, m)
	rp := xp.Mul(&yp)
	return rp.Int()
}

// fftSizeThreshold[i] is the maximal size (in bits) where we should use
// fft size i.
var fftSizeThreshold = [...]int64{0, 0, 0,
	4 << 10, 8 << 10, 16 << 10, // 5
	32 << 10, 64 << 10, 1 << 18, 1 << 20, 3 << 20, // 10
	8 << 20, 30 << 20, 100 << 20, 300 << 20, 600 << 20,
}

// fftSize returns the FFT length k and m, the number of words per chunk,
// such that m << k is larger than the number of words
// in x*y.
//
// k is the first index whose threshold exceeds the combined bit size of
// the operands, or len(fftSizeThreshold) when no threshold does.
func fftSize(x, y nat) (k uint, m int) {
	words := len(x) + len(y)
	bits := int64(words) * int64(_W)
	k = uint(len(fftSizeThreshold))
	for i := range fftSizeThreshold {
		if fftSizeThreshold[i] > bits {
			k = uint(i)
			break
		}
	}
	// The 1<<k chunks of m words must have N bits so that
	// 2^N-1 is larger than x*y. That is, m<<k > words
	m = words>>k + 1
	return k, m
}

// valueSize returns the length (in words) to use for polynomial
// coefficients, to compute a correct product of polynomials P*Q
// where deg(P*Q) < K (== 1<<k) and where coefficients of P and Q are
// less than b^m (== 1 << (m*_W)).
// The chosen length (in bits) must be a multiple of 1 << (k-extra). func ( uint, int, uint) int { // The coefficients of P*Q are less than b^(2m)*K // so we need W * valueSize >= 2*m*W+K := 2**_W + int() // necessary bits := 1 << ( - ) if < _W { = _W } = (( / ) + 1) * // round to a multiple of K return / _W } // poly represents an integer via a polynomial in Z[x]/(x^K+1) // where K is the FFT length and b^m is the computation basis 1<<(m*_W). // If P = a[0] + a[1] x + ... a[n] x^(K-1), the associated natural number // is P(b^m). type poly struct { k uint // k is such that K = 1<<k. m int // the m such that P(b^m) is the original number. a []nat // a slice of at most K m-word coefficients. } // polyFromNat slices the number x into a polynomial // with 1<<k coefficients made of m words. func ( nat, uint, int) poly { := poly{k: , m: } := len()/ + 1 .a = make([]nat, ) for := range .a { if len() < { .a[] = make(nat, ) copy(.a[], ) break } .a[] = [:] = [:] } return } // Int evaluates back a poly to its integer value. func ( *poly) () nat { := len(.a)*.m + 1 if := len(.a); > 0 { += len(.a[-1]) } := make(nat, ) := .m := for := range .a { := len(.a[]) := addVV([:], [:], .a[]) if [] < ^big.Word(0) { [] += } else { addVW([:], [:], ) } = [:] } = trim() return } func ( nat) nat { for := range { if [len()-1-] != 0 { return [:len()-] } } return nil } // Mul multiplies p and q modulo X^K-1, where K = 1<<p.k. // The product is done via a Fourier transform. func ( *poly) ( *poly) poly { // extra=2 because: // * some power of 2 is a K-th root of unity when n is a multiple of K/2. // * 2 itself is a square (see fermat.ShiftHalf) := valueSize(.k, .m, 2) , := .Transform(), .Transform() := .Mul(&) := .InvTransform() .m = .m return } // A polValues represents the value of a poly at the powers of a // K-th root of unity θ=2^(l/2) in Z/(b^n+1)Z, where b^n = 2^(K/4*l). type polValues struct { k uint // k is such that K = 1<<k. n int // the length of coefficients, n*_W a multiple of K/4. 
values []fermat // a slice of K (n+1)-word values } // Transform evaluates p at θ^i for i = 0...K-1, where // θ is a K-th primitive root of unity in Z/(b^n+1)Z. func ( *poly) ( int) polValues { := .k := make([]big.Word, (+1)<<) := make([]fermat, 1<<) // Now computed q(ω^i) for i = 0 ... K-1 := make([]big.Word, (+1)<<) := make([]fermat, 1<<) for := range { [] = [*(+1) : (+1)*(+1)] if < len(.a) { copy([], .a[]) } [] = fermat([*(+1) : (+1)*(+1)]) } fourier(, , false, , ) return polValues{, , } } // InvTransform reconstructs p (modulo X^K - 1) from its // values at θ^i for i = 0..K-1. func ( *polValues) () poly { , := .k, .n // Perform an inverse Fourier transform to recover p. := make([]big.Word, (+1)<<) := make([]fermat, 1<<) for := range { [] = fermat([*(+1) : (+1)*(+1)]) } fourier(, .values, true, , ) // Divide by K, and untwist q to recover p. := make(fermat, +1) := make([]nat, 1<<) for := range { .Shift([], -int()) copy([], ) [] = nat([]) } return poly{k: , m: 0, a: } } // NTransform evaluates p at θω^i for i = 0...K-1, where // θ is a (2K)-th primitive root of unity in Z/(b^n+1)Z // and ω = θ². func ( *poly) ( int) polValues { := .k if len(.a) >= 1<< { panic("Transform: len(p.a) >= 1<<k") } // θ is represented as a shift. := ( * _W) >> // p(x) = a_0 + a_1 x + ... + a_{K-1} x^(K-1) // p(θx) = q(x) where // q(x) = a_0 + θa_1 x + ... + θ^(K-1) a_{K-1} x^(K-1) // // Twist p by θ to obtain q. := make([]big.Word, (+1)<<) := make([]fermat, 1<<) := make(fermat, +1) for := range { [] = fermat([*(+1) : (+1)*(+1)]) if < len(.a) { for := range { [] = 0 } copy(, .a[]) [].Shift(, *) } } // Now computed q(ω^i) for i = 0 ... K-1 := make([]big.Word, (+1)<<) := make([]fermat, 1<<) for := range { [] = fermat([*(+1) : (+1)*(+1)]) } fourier(, , false, , ) return polValues{, , } } // InvTransform reconstructs a polynomial from its values at // roots of x^K+1. The m field of the returned polynomial // is unspecified. 
func ( *polValues) () poly { := .k := .n := ( * _W) >> // Perform an inverse Fourier transform to recover q. := make([]big.Word, (+1)<<) := make([]fermat, 1<<) for := range { [] = fermat([*(+1) : (+1)*(+1)]) } fourier(, .values, true, , ) // Divide by K, and untwist q to recover p. := make(fermat, +1) := make([]nat, 1<<) for := range { .Shift([], -int()-*) copy([], ) [] = nat([]) } return poly{k: , m: 0, a: } } // fourier performs an unnormalized Fourier transform // of src, a length 1<<k vector of numbers modulo b^n+1 // where b = 1<<_W. func ( []fermat, []fermat, bool, int, uint) { var func(, []fermat, uint) := make(fermat, +1) // pre-allocate temporary variables. := make(fermat, +1) // pre-allocate temporary variables. // The recursion function of the FFT. // The root of unity used in the transform is ω=1<<(ω2shift/2). // The source array may use shifted indices (i.e. the i-th // element is src[i << idxShift]). = func(, []fermat, uint) { := - := (4 * * _W) >> if { = - } // Easy cases. if len([0]) != +1 || len([0]) != +1 { panic("len(src[0]) != n+1 || len(dst[0]) != n+1") } switch { case 0: copy([0], [0]) return case 1: [0].Add([0], [1<<]) // dst[0] = src[0] + src[1] [1].Sub([0], [1<<]) // dst[1] = src[0] - src[1] return } // Let P(x) = src[0] + src[1<<idxShift] * x + ... + src[K-1 << idxShift] * x^(K-1) // The P(x) = Q1(x²) + x*Q2(x²) // where Q1's coefficients are src with indices shifted by 1 // where Q2's coefficients are src[1<<idxShift:] with indices shifted by 1 // Split destination vectors in halves. := [:1<<(-1)] := [1<<(-1):] // Transform Q1 and Q2 in the halves. (, , -1) (, [1<<:], -1) // Reconstruct P's transform from transforms of Q1 and Q2. // dst[i] is dst1[i] + ω^i * dst2[i] // dst[i + 1<<(k-1)] is dst1[i] + ω^(i+K/2) * dst2[i] // for := range { .ShiftHalf([], *, ) // ω^i * dst2[i] [].Sub([], ) [].Add([], ) } } (, , ) } // Mul returns the pointwise product of p and q. 
func ( *polValues) ( *polValues) ( polValues) { := .n .k, .n = .k, .n .values = make([]fermat, len(.values)) := make([]big.Word, len(.values)*(+1)) := make(fermat, 8*) for := range .values { .values[] = [*(+1) : (+1)*(+1)] := .Mul(.values[], .values[]) copy(.values[], ) } return }