package bigfft

import (
	"math/big"
)

// Arithmetic modulo 2^n+1.

// A fermat of length w+1 represents a number modulo 2^(w*_W) + 1. The last
// word is zero or one. A number has at most two representatives satisfying the
// 0-1 last word constraint.
//
// NOTE(review): _W is declared elsewhere in the package; presumably it is the
// width of a machine word in bits — confirm against its definition.
type fermat nat

func ( fermat) () string { return nat().String() }

func ( fermat) () {
	 := len() - 1
	 := []
	if  == 0 {
		return
	}
	if [0] >=  {
		[] = 0
		[0] -= 
		return
	}
	// z[0] < z[n].
	subVW(, , ) // Substract c
	if  > 1 {
		[] -=  - 1
		 = 1
	}
	// Add back c.
	if [] == 1 {
		[] = 0
		return
	} else {
		addVW(, , 1)
	}
}

// Shift computes (x << k) mod (2^n+1).
func ( fermat) ( fermat,  int) {
	if len() != len() {
		panic("len(z) != len(x) in Shift")
	}
	 := len() - 1
	// Shift by n*_W is taking the opposite.
	 %= 2 *  * _W
	if  < 0 {
		 += 2 *  * _W
	}
	 := false
	if  >= *_W {
		 -=  * _W
		 = true
	}

	,  := /_W, %_W

	[] = 1 // Add (-1)
	if ! {
		for  := 0;  < ; ++ {
			[] = 0
		}
		// Shift left by kw words.
		// x = a·2^(n-k) + b
		// x<<k = (b<<k) - a
		copy([:], [:-])
		 := subVV([:+1], [:+1], [-:])
		if [+1] > 0 {
			[+1] -= 
		} else {
			subVW([+1:], [+1:], )
		}
	} else {
		for  :=  + 1;  < ; ++ {
			[] = 0
		}
		// Shift left and negate, by kw words.
		copy([:+1], [-:+1])            // z_low = x_high
		 := subVV([:], [:], [:-]) // z_high -= x_low
		[] -= 
	}
	// Add back 1.
	if [] > 0 {
		[]--
	} else if [0] < ^big.Word(0) {
		[0]++
	} else {
		addVW(, , 1)
	}
	// Shift left by kb bits
	shlVU(, , uint())
	.norm()
}

// ShiftHalf shifts x by k/2 bits the left. Shifting by 1/2 bit
// is multiplication by sqrt(2) mod 2^n+1 which is 2^(3n/4) - 2^(n/4).
// A temporary buffer must be provided in tmp.
func ( fermat) ( fermat,  int,  fermat) {
	 := len() - 1
	if %2 == 0 {
		.Shift(, /2)
		return
	}
	 := ( - 1) / 2
	 :=  + (3*_W/4)*
	 :=  + (_W/4)*
	.Shift(, )
	.Shift(, )
	.Sub(, )
}

// Add computes addition mod 2^n+1.
func ( fermat) (,  fermat) fermat {
	if len() != len() {
		panic("Add: len(z) != len(x)")
	}
	addVV(, , ) // there cannot be a carry here.
	.norm()
	return 
}

// Sub computes substraction mod 2^n+1.
func ( fermat) (,  fermat) fermat {
	if len() != len() {
		panic("Add: len(z) != len(x)")
	}
	 := len() - 1
	 := subVV([:], [:], [:])
	 += []
	// If b > 0, we need to subtract b<<n, which is the same as adding b.
	[] = []
	if [0] <= ^big.Word(0)- {
		[0] += 
	} else {
		addVW(, , )
	}
	.norm()
	return 
}

func ( fermat) (,  fermat) fermat {
	if len() != len() {
		panic("Mul: len(x) != len(y)")
	}
	 := len() - 1
	if  < 30 {
		 = [:2*+2]
		basicMul(, , )
		 = [:2*+1]
	} else {
		var , ,  big.Int
		.SetBits()
		.SetBits()
		.SetBits()
		 := .Mul(&, &).Bits()
		if len() <=  {
			// Short product.
			copy(, )
			for  := len();  < len(); ++ {
				[] = 0
			}
			return 
		}
		 = 
	}
	// len(z) is at most 2n+1.
	if len() > 2*+1 {
		panic("len(z) > 2n+1")
	}
	// We now have
	// z = z[:n] + 1<<(n*W) * z[n:2n+1]
	// which normalizes to:
	// z = z[:n] - z[n:2n] + z[2n]
	 := big.Word(0)
	if len() > 2* {
		 = addVW([:], [:], [2*])
	}
	 := big.Word(0)
	if len() >= 2* {
		 = subVV([:], [:], [:2*])
	} else {
		 := len() - 
		 = subVV([:], [:], [:])
		 = subVW([:], [:], )
	}
	// Restore carries.
	// Substracting z[n] -= c2 is the same
	// as z[0] += c2
	 = [:+1]
	[] = 
	 := addVW(, , )
	if  != 0 {
		panic("impossible")
	}
	.norm()
	return 
}

// copied from math/big
//
// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func (, ,  fermat) {
	// initialize z
	for  := 0;  < len(); ++ {
		[] = 0
	}
	for ,  := range  {
		if  != 0 {
			[len()+] = addMulVVW([:+len()], , )
		}
	}
}