Friday, May 30, 2014

Exhaustive List of the Sums of Four-Squares

I had recently been asked to assist with some research on equilateral triangles in $\mathbb{Z}^4$. The project required a list of all of the lattice points $(i,j,k,l)$ on a hypersphere of radius $\sqrt{d}$ for $d\in\mathbb{N}$. Basically, this reduces to finding all ordered 4-tuples of integers $(i,j,k,l)$ such that $i^2+j^2+k^2+l^2=d$.

This is an old problem: in 1770 Lagrange proved that every natural number $d$ can be represented as the sum of four squares[1], so we know we should always find at least one point. Additionally, in 1834 Jacobi determined exactly how many solutions exist[2]: $$N(d)=8\sum_{4\nmid m|d}m,$$ that is, eight times the sum of the divisors $m$ of $d$ that are not divisible by 4 (for $d\geq 1$). We can use this to check whether our program finds all the points (or at least the right number of them). Of course, for smaller values of $d$, an exhaustive search of the lattice can be performed from $-\sqrt{d}$ to $\sqrt{d}$ along each of the four coordinate directions. However, since each of the four nested loops runs over about $2\sqrt{d}$ values, this has a complexity of $O(d^2)$.
from math import sqrt

def sum_squares(d):

   sqrt_d = int(sqrt(d))

   ## This list will hold all of the points
   points = []

   ## Nested for loops to check
   ## Add 1 because of the exclusivity of the bound of range() function
   for i in range(-sqrt_d,sqrt_d + 1):
      for j in range(-sqrt_d,sqrt_d + 1):
         for k in range(-sqrt_d,sqrt_d + 1):
            for l in range(-sqrt_d,sqrt_d + 1):
               if i**2 + j**2 + k**2 + l**2 == d:
                  points.append([i,j,k,l])

   return points
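Jacobi's formula gives a cheap way to test any of these functions. Below is a minimal sketch; `jacobi_count` is a hypothetical helper name, not part of the original code. It sums the divisors of $d$ that are not divisible by 4 by walking divisor pairs $(m, d/m)$ up to $\sqrt{d}$, and it assumes $d\geq 1$, where the formula applies.

```python
def jacobi_count(d):
    """Number of integer 4-tuples (i, j, k, l) with i^2 + j^2 + k^2 + l^2 == d,
    from Jacobi's formula: 8 * (sum of divisors m of d with 4 not dividing m)."""
    if d < 1:
        raise ValueError("Jacobi's formula applies to d >= 1")
    total = 0
    m = 1
    # Visit each divisor pair (m, d // m) once, stopping at sqrt(d)
    while m * m <= d:
        if d % m == 0:
            # A set avoids counting sqrt(d) twice when d is a perfect square
            for divisor in {m, d // m}:
                if divisor % 4 != 0:
                    total += divisor
        m += 1
    return 8 * total
```

For example, `len(sum_squares(6))` should equal `jacobi_count(6)`, which is $8(1+2+3+6)=96$.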

However, this gets slow very quickly, especially if I want to run it for $d$ in the tens of thousands. Better results can be obtained when we realize that the $l$ loop is completely unnecessary. Since we already have guesses for $i$, $j$, and $k$, and we know $d$, we can solve for $l$ and check whether it is an integer. This reduces the complexity to $O(d^{1.5})$. We have to make sure that the sqrt() function only receives a non-negative number and that we include both the positive and negative roots, but the updated code is as follows.
from math import sqrt

def sum_squares(d):

   ## Largest integer whose square does not exceed d
   sqrt_d = int(sqrt(d))

   ## This list will hold all of the points
   points = []

   ## Nested for loops to check
   ## Add 1 because of the exclusivity of the bound of range() function
   for i in range(-sqrt_d,sqrt_d + 1):
      for j in range(-sqrt_d,sqrt_d + 1):
         for k in range(-sqrt_d,sqrt_d + 1):
            l_squared = d - i**2 - j**2 - k**2
            if l_squared >= 0:
               l = sqrt(l_squared)
               if l.is_integer():
                  points.append([i,j,k,int(l)])
                  if l != 0:
                     points.append([i,j,k,-int(l)])

   return points
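One caveat: sqrt() works in floating point, and integers larger than $2^{53}$ are no longer converted to floats exactly, so for very large $d$ the `is_integer()` test could in principle be fooled. A sketch of the same $O(d^{1.5})$ search using `math.isqrt` (Python 3.8 and later), which computes the integer square root exactly; the name `sum_squares_isqrt` is hypothetical.

```python
from math import isqrt

def sum_squares_isqrt(d):
    """All integer points (i, j, k, l) with i^2 + j^2 + k^2 + l^2 == d,
    looping over i, j, k and solving for l with exact integer arithmetic."""
    r = isqrt(d)  # floor(sqrt(d)), computed without floating point
    points = []
    for i in range(-r, r + 1):
        for j in range(-r, r + 1):
            for k in range(-r, r + 1):
                l_squared = d - i * i - j * j - k * k
                if l_squared < 0:
                    continue
                l = isqrt(l_squared)
                # l_squared is a perfect square exactly when l*l recovers it
                if l * l == l_squared:
                    points.append([i, j, k, l])
                    if l != 0:
                        points.append([i, j, k, -l])
    return points
```

The number of points returned should match Jacobi's $N(d)$; for instance $d=25$ gives $8(1+5+25)=248$ points.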

Even better results can be obtained if we only check the first orthant and then permute the signs of the results into the other 15 orthants. Because only $i$, $j$, and $k$ are looped over, the search covers roughly $1/8$ as many $(i,j,k)$ triples as before, but the complexity is unchanged. However, depending on how many first-orthant points we find, the sign permutation can take a significant amount of time; I will address permutations more in a different post. Since we look in the first orthant only, we drop the negative root for $l$ in the last line of the nested loops.

from math import sqrt

## This list contains the 16 sign permutations
signs = [[ 1, 1, 1, 1], [ 1, 1, 1,-1], [ 1, 1,-1, 1], [ 1, 1,-1,-1],
         [ 1,-1, 1, 1], [ 1,-1, 1,-1], [ 1,-1,-1, 1], [ 1,-1,-1,-1],
         [-1, 1, 1, 1], [-1, 1, 1,-1], [-1, 1,-1, 1], [-1, 1,-1,-1],
         [-1,-1, 1, 1], [-1,-1, 1,-1], [-1,-1,-1, 1], [-1,-1,-1,-1]]

def sum_squares(d):

   ## Add 1 because of the exclusivity of the bound of range() function
   sqrt_d = int(sqrt(d)) + 1

   ## This list will hold the points in the first orthant
   points = []

   ## This list will hold the points from all orthants
   final = []

   ## Nested for loops to check
   for i in range(sqrt_d):
      for j in range(sqrt_d):
         for k in range(sqrt_d):
            l_squared = d - i**2 - j**2 - k**2
            if l_squared >= 0:
               l = sqrt(l_squared)
               if l.is_integer():
                  points.append([i,j,k,int(l)])

   ## For each point found in the first orthant
   for m in range(len(points)):

      ## This list will hold all sign permutations for the mth point
      ## in the first orthant
      point_signs = []

      ## For each of the possible sign permutations
      for i in range(len(signs)):
         tmp = [points[m][0] * signs[i][0], points[m][1] * signs[i][1],
                points[m][2] * signs[i][2], points[m][3] * signs[i][3]]

         ## To avoid duplicates when a coordinate is zero
         ## i.e. 0 * 1 = 0 * -1
         if tmp not in point_signs:
            point_signs.append(tmp)

      final = final + point_signs

   return final
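The sign step above builds all 16 sign patterns and then discards repeats with a linear `not in` check. A sketch of an alternative (the name `sign_variants` is hypothetical) flips signs only on the nonzero coordinates using `itertools.product`, so no duplicates are produced: a point with $z$ nonzero entries yields exactly $2^z$ variants.

```python
from itertools import product

def sign_variants(point):
    """Every distinct sign pattern of `point`.
    Zero coordinates are left alone, since -0 == 0 would only create repeats."""
    nonzero = [idx for idx, x in enumerate(point) if x != 0]
    variants = []
    # One tuple of +1/-1 factors per nonzero coordinate
    for signs in product((1, -1), repeat=len(nonzero)):
        p = list(point)
        for idx, s in zip(nonzero, signs):
            p[idx] = s * p[idx]
        variants.append(p)
    return variants
```

For example, `sign_variants([1, 0, 0, 0])` returns `[[1, 0, 0, 0], [-1, 0, 0, 0]]`, and `[2, 1, 1, 0]` yields 8 variants rather than 16 candidates.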

However, this still has the same complexity. For larger values of $d$, something different must be done! One way to reduce the work is to search only for the basic representations of the lattice points in the first orthant. The basic representation is the representation $(i,j,k,l)$ such that $i\geq j\geq k\geq l \geq 0$. For example, rather than finding $\{(2,1,1,0), (2,1,0,1), (2,0,1,1), \dots\}$, we only need to find $(2,1,1,0)$ and permute the order to produce the rest.
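As a quick sanity check of this idea, a small sketch (the helper `basic_representation` is hypothetical and not used by the code in this post) reduces any lattice point to its basic representation by taking absolute values and sorting them in decreasing order. Brute-forcing $d=6$ shows that every one of its points collapses to $(2,1,1,0)$.

```python
from math import isqrt

def basic_representation(point):
    """Reduce a lattice point to its basic representation: the absolute
    values of its coordinates sorted so that i >= j >= k >= l >= 0."""
    return tuple(sorted((abs(x) for x in point), reverse=True))

# Brute-force every lattice point on the hypersphere of radius sqrt(6)
# and collect the distinct basic representations they reduce to.
d = 6
r = isqrt(d)
reps = {basic_representation((i, j, k, l))
        for i in range(-r, r + 1)
        for j in range(-r, r + 1)
        for k in range(-r, r + 1)
        for l in range(-r, r + 1)
        if i * i + j * j + k * k + l * l == d}
print(reps)  # {(2, 1, 1, 0)}
```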

This further restricts the bounds on our for loops. For instance, we know that $\sqrt{d/4} \leq i \leq \sqrt{d}$. The upper bound is the same as before, but we have now bounded $i$ from below! To see why, note that $i$ is the largest of the four values, so it is smallest when all four are equal: set $i=j=k=l=i_{min}$. This yields $$d=i^2+j^2+k^2+l^2 = 4{i_{min}}^2\implies i_{min} = \sqrt{d/4}.$$ Each subsequent index is then bounded above both by the previous index and by the square root of the difference between $d$ and the sum of the squares of the previous indices. For instance, $k \leq j$ and $k^2 \leq d - (i^2 + j^2)$, so together $k \leq \min\left(j, \sqrt{d - (i^2 + j^2)}\right)$.
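To make these bounds concrete, here is a worked instance for $d=6$ (my own example, applying the bounds just derived):
$$\sqrt{6/4}\approx 1.22 \leq i \leq \sqrt{6}\approx 2.45 \implies i=2$$
$$j \leq \min\left(2, \sqrt{6-2^2}\right) = \sqrt{2}\approx 1.41 \implies j\in\{0,1\}$$
$$j=1:\quad k \leq \min\left(1, \sqrt{6-2^2-1^2}\right) = 1 \implies k\in\{0,1\}$$
$$j=1,\ k=1:\quad l^2 = 6-2^2-1^2-1^2 = 0 \implies l=0$$
The other branches fail: $j=0, k=0$ leaves $l^2=2$, which is not a perfect square, and $j=1, k=0$ leaves $l=1>k$. So the only basic representation is $(2,1,1,0)$. The code rounds the lower bound down with int(), so it also tries $i=1$; those branches are simply rejected by the remaining checks.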

from math import sqrt

signs = [[ 1, 1, 1, 1], [ 1, 1, 1,-1], [ 1, 1,-1, 1], [ 1, 1,-1,-1],
         [ 1,-1, 1, 1], [ 1,-1, 1,-1], [ 1,-1,-1, 1], [ 1,-1,-1,-1],
         [-1, 1, 1, 1], [-1, 1, 1,-1], [-1, 1,-1, 1], [-1, 1,-1,-1],
         [-1,-1, 1, 1], [-1,-1, 1,-1], [-1,-1,-1, 1], [-1,-1,-1,-1]]

orders = [[0,1,2,3], [0,1,3,2], [0,2,1,3], [0,2,3,1], [0,3,1,2], [0,3,2,1],
          [1,0,2,3], [1,0,3,2], [1,2,0,3], [1,2,3,0], [1,3,0,2], [1,3,2,0],
          [2,0,1,3], [2,0,3,1], [2,1,0,3], [2,1,3,0], [2,3,0,1], [2,3,1,0],
          [3,0,1,2], [3,0,2,1], [3,1,0,2], [3,1,2,0], [3,2,0,1], [3,2,1,0]]

def sum_squares(d):
   
   ## Add 1 because of the exclusivity of the bound of range() function
   sqrt_d = int(sqrt(d)) + 1
   sqrt_d4 = int(sqrt(d/4))
   
   ## This list will hold the basic points in the first orthant
   points = []

   ## This list will hold the points from all orthants
   final = []
   
   ## Nested for loops to check
   for i in range(sqrt_d4, sqrt_d):
      sum_i = d - i**2
      for j in range(min(i,int(sqrt(sum_i)))+1):
         sum_j = sum_i - j**2
         for k in range(min(j,int(sqrt(sum_j)))+1):
            sum_k = sum_j - k**2
            l = int(sqrt(sum_k))
            if l**2 == sum_k and l <= k:
               points.append([i,j,k,l])

   ## For each basic point found in the first orthant
   for m in range(len(points)):
      ## This list will hold all order permutations of the
      ## mth basic point
      point_orders = []

      ## For each of the possible order permutations
      for n in range(len(orders)):
         tmp = [points[m][orders[n][0]],
                points[m][orders[n][1]],
                points[m][orders[n][2]],
                points[m][orders[n][3]]]

         ## Skip duplicate orderings caused by repeated values,
         ## e.g. swapping the two 1s in [2,1,1,0] gives [2,1,1,0] again
         if tmp not in point_orders:
            point_orders.append(tmp)

      for n in range(len(point_orders)):
         
         ## This list will hold all sign permutations for the nth
         ## ordering of the mth basic point
         point_signs = []
 
         ## For each of the possible sign permutations
         for o in range(len(signs)):
            tmp = [point_orders[n][0] * signs[o][0],
                   point_orders[n][1] * signs[o][1],
                   point_orders[n][2] * signs[o][2],
                   point_orders[n][3] * signs[o][3]]
 
            ## To avoid duplicates with points on the axes
            ## i.e. 0 * 1 = 0 * -1
            if tmp not in point_signs:
               point_signs.append(tmp)

         final = final + point_signs
            
   return final
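The permutation step above tries all 24 orderings and 16 sign patterns and filters repeats with list lookups. A sketch of an alternative (the function name `expand_basic` is hypothetical) lets a set remove the repeated orderings produced by `itertools.permutations`, and flips signs only on nonzero coordinates with `itertools.product`, so sign duplicates are never generated.

```python
from itertools import permutations, product

def expand_basic(basic):
    """All lattice points whose basic representation is `basic`:
    every distinct ordering of the coordinates, then every distinct
    sign pattern of each ordering."""
    points = []
    # set() drops repeated orderings such as the two equal 1s in (2, 1, 1, 0)
    for order in set(permutations(basic)):
        nonzero = [idx for idx, x in enumerate(order) if x != 0]
        for signs in product((1, -1), repeat=len(nonzero)):
            p = list(order)
            for idx, s in zip(nonzero, signs):
                p[idx] = s * p[idx]
            points.append(p)
    return points
```

For $(2,1,1,0)$ this gives 12 orderings times 8 sign patterns, or 96 points, matching Jacobi's $N(6)=8(1+2+3+6)=96$. The output order depends on set iteration order, so sort the result if a stable order matters.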

But this is still essentially the same complexity, and I wanted something faster. I searched online but was unable to find much. The closest I came was an article in Communications on Pure and Applied Mathematics by Rabin and Shallit[3] that used randomized algorithms to produce a single lattice point such that $x^2+y^2+z^2+w^2=d$. While fast (complexity of $O(\log^2 d)$), this method produces only one lattice point, and there is no way to guarantee which point you get.

Finally, I stumbled upon a post on the Computer Science Stack Exchange[4] where a user suggested this nearly $O(d)$ algorithm. It populates a list of length $d+1$ in which each element holds the pairs of numbers whose squares sum to that element's index. So the $0$th element holds the pair $[0,0]$ because $0^2+0^2=0$, the $5$th element holds the pair $[2,1]$ because $2^2+1^2=5$, and the $25$th element holds the pairs $[5,0]$ and $[4,3]$ because $5^2+0^2=4^2+3^2=25$. Now, any pair at index $n$ combined with any pair at index $d-n$ gives four numbers whose squares sum to $d$: for example, the pairs at index $0$ with those at index $d$, or the pairs at index $5$ with those at index $d-5$. A meet-in-the-middle loop over $n$ from $0$ to $d/2$ then gives us our basic representations!

For example, consider $d=6$. We have the following list of pairs, which we can generate in $O(d)$ time. $$\begin{array}{|c|c|c|c|c|c|c|c|} \hline \text{Index} & 0 & 1 & 2 & 3 & 4 & 5 & 6 \\ \hline \text{Pairs} & [0,0] & [1,0] & [1,1] & \phantom{[0,0]} & [2,0] & [2,1] & \phantom{[0,0]}\\ \hline \end{array}$$ Now look at the pair at index $1$, $[1,0]$, and the pair at index $5$, $[2,1]$. Because $1+5=6$, the squares of the four numbers in these two pairs sum to $6$: $2^2+1^2+1^2+0^2=6$. A wonderfully constructed list! The code for this method is below.

from math import sqrt

signs = [[ 1, 1, 1, 1], [ 1, 1, 1,-1], [ 1, 1,-1, 1], [ 1, 1,-1,-1],
         [ 1,-1, 1, 1], [ 1,-1, 1,-1], [ 1,-1,-1, 1], [ 1,-1,-1,-1],
         [-1, 1, 1, 1], [-1, 1, 1,-1], [-1, 1,-1, 1], [-1, 1,-1,-1],
         [-1,-1, 1, 1], [-1,-1, 1,-1], [-1,-1,-1, 1], [-1,-1,-1,-1]]

orders = [[0,1,2,3], [0,1,3,2], [0,2,1,3], [0,2,3,1], [0,3,1,2], [0,3,2,1],
          [1,0,2,3], [1,0,3,2], [1,2,0,3], [1,2,3,0], [1,3,0,2], [1,3,2,0],
          [2,0,1,3], [2,0,3,1], [2,1,0,3], [2,1,3,0], [2,3,0,1], [2,3,1,0],
          [3,0,1,2], [3,0,2,1], [3,1,0,2], [3,1,2,0], [3,2,0,1], [3,2,1,0]]

def sum_squares(d):
   ## Element i of sum_pairs will be populated with all pairs [a,b]
   ## such that sqrt(d) >= a >= b >= 0 and a*a + b*b = i
   sum_pairs = [[] for i in range(d+1)]

   ## Fill sum_pairs
   for a in range(0,int(sqrt(d)) + 1):
      for b in range(0,min(a + 1, int(sqrt(d - a**2)) + 1)):
         i = a*a + b*b
         if i <= d:
            sum_pairs[i].append([a,b])
         else:
            break

   ## This list will be populated with all points [i,j,k,l]
   ## such that i >= j >= k >= l >= 0 whose squares sum to d
   points = []

   ## This list will hold the points from all orthants
   final = []
   
   for i in range(0,d//2 + 1):
      for j in range(0,len(sum_pairs[i])):
         for k in range(0,len(sum_pairs[d-i])):
            tmp = [sum_pairs[i][j][0], sum_pairs[i][j][1],
                   sum_pairs[d-i][k][0], sum_pairs[d-i][k][1]]
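            ## Sort in descending order so that tmp is the basic
            ## representation with i >= j >= k >= l
            tmp.sort(reverse=True)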
            if tmp not in points:
               points.append(tmp)

   ## For each basic point found in the first orthant
   for m in range(len(points)):
      ## This list will hold all order permutations of the
      ## mth basic point
      point_orders = []

      ## For each of the possible order permutations
      for n in range(len(orders)):
         tmp = [points[m][orders[n][0]],
                points[m][orders[n][1]],
                points[m][orders[n][2]],
                points[m][orders[n][3]]]

         ## Skip duplicate orderings caused by repeated values,
         ## e.g. swapping the two 1s in [2,1,1,0] gives [2,1,1,0] again
         if tmp not in point_orders:
            point_orders.append(tmp)

      for n in range(len(point_orders)):
         
         ## This list will hold all sign permutations for the nth
         ## ordering of the mth basic point
         point_signs = []
 
         ## For each of the possible sign permutations
         for o in range(len(signs)):
            tmp = [point_orders[n][0] * signs[o][0],
                   point_orders[n][1] * signs[o][1],
                   point_orders[n][2] * signs[o][2],
                   point_orders[n][3] * signs[o][3]]
 
            ## To avoid duplicates with points on the axes
            ## i.e. 0 * 1 = 0 * -1
            if tmp not in point_signs:
               point_signs.append(tmp)

         final = final + point_signs
            
   return final
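For comparison, here is a compact sketch of the same pair-table and meet-in-the-middle idea. Two choices are mine rather than the original code's: basic representations are collected in a set instead of being checked with `tmp not in points` (a linear scan of the list), and each combined 4-tuple is sorted in decreasing order to match the $i\geq j\geq k\geq l\geq 0$ convention. The names `pair_table` and `basic_from_pairs` are hypothetical, and `math.isqrt` needs Python 3.8 or later.

```python
from math import isqrt

def pair_table(d):
    """sum_pairs[n] lists every pair [a, b] with a >= b >= 0 and
    a*a + b*b == n, for 0 <= n <= d."""
    sum_pairs = [[] for _ in range(d + 1)]
    for a in range(isqrt(d) + 1):
        # b <= a stores each unordered pair once; b*b <= d - a*a keeps n <= d
        for b in range(min(a, isqrt(d - a * a)) + 1):
            sum_pairs[a * a + b * b].append([a, b])
    return sum_pairs

def basic_from_pairs(d):
    """Meet in the middle: a pair at index n and a pair at index d - n
    give four numbers whose squares sum to d. Every basic representation
    splits into two pairs, one of which sums to at most d // 2."""
    sum_pairs = pair_table(d)
    basics = set()
    for n in range(d // 2 + 1):
        for p in sum_pairs[n]:
            for q in sum_pairs[d - n]:
                basics.add(tuple(sorted(p + q, reverse=True)))
    return sorted(basics, reverse=True)
```

Here `pair_table(6)` reproduces the table for $d=6$ shown earlier, `pair_table(25)[25]` is `[[4, 3], [5, 0]]`, and `basic_from_pairs(6)` returns only `[(2, 1, 1, 0)]`.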
Ultimately, we've found a very quick way to find all lattice points on the hypersphere of radius $\sqrt{d}$. For the methods that we had to permute (either sign or position), most of the computation time is spent permuting. Therefore, I will spend some time working on a faster way to permute these basic representations while being mindful of repeats. Those methods will be discussed in another post.
[1] https://en.wikipedia.org/wiki/Lagrange%27s_four-square_theorem
[2] https://en.wikipedia.org/wiki/Jacobi%27s_four-square_theorem
[3] M.O. Rabin, J.O. Shallit, "Randomized Algorithms in Number Theory," Communications on Pure and Applied Mathematics (1986), no. S1, pp. S239-S256.
[4] http://cs.stackexchange.com/questions/2988/how-fast-can-we-find-all-four-square-combinations-that-sum-to-n
