Recently, a problem from the USACO training pages has been bothering me. I had solved it years ago in Java, but my friend Robert Won challenged me to do it in Python. Since Python is many times slower than Java, my code has to be much smarter.

Problem

An arithmetic progression is a sequence of the form $a$, $a+b$, $a+2b$, $\ldots$, $a+nb$ where $n=0, 1, 2, 3, \ldots$. For this problem, $a$ is a non-negative integer and $b$ is a positive integer.

Write a program that finds all arithmetic progressions of length $n$ in the set $S$ of bisquares. The set of bisquares is defined as the set of all integers of the form $p^2 + q^2$ (where $p$ and $q$ are non-negative integers).

TIME LIMIT: 5 secs

PROGRAM NAME: ariprog

INPUT FORMAT

  • Line 1: $N$ ($3 \leq N \leq 25$), the length of progressions for which to search
  • Line 2: $M$ ($1 \leq M \leq 250$), an upper bound to limit the search to the bisquares with $0 \leq p,q \leq M$.

SAMPLE INPUT (file ariprog.in)

5
7

OUTPUT FORMAT

If no sequence is found, a single line reading NONE. Otherwise, output one or more lines, each with two integers: the first element in a found sequence and the difference between consecutive elements in the same sequence. The lines should be ordered with smallest-difference sequences first and smallest starting number within those sequences first.

There will be no more than 10,000 sequences.
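
The output rules above can be sketched as a small helper. This is only an illustration: `format_sequences` is a hypothetical name, and it assumes each progression is stored as a `(delta, start)` tuple, so that sorting the tuples gives the required order (smallest difference first, then smallest start).

```python
def format_sequences(sequences):
    """Render (delta, start) pairs in the required output format.

    Assumes the pairs are already sorted by delta, then by start.
    """
    if not sequences:
        return "NONE"
    # Each output line is "start difference", the reverse of the tuple order.
    return "\n".join(f"{start} {delta}" for delta, start in sequences)


print(format_sequences(sorted([(8, 2), (4, 37), (4, 1)])))
print(format_sequences([]))
```

Keeping the difference first in the tuple means a plain `sort()` is enough to satisfy the ordering requirement.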

SAMPLE OUTPUT (file ariprog.out)

1 4
37 4
2 8
29 8
1 12
5 12
13 12
17 12
5 20
2 24
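
To make the sample concrete, here is a naive brute-force sketch that checks the definition directly. It is not one of the solutions discussed below and would be far too slow for the largest inputs; `brute_force_progressions` is a name made up for this illustration.

```python
def brute_force_progressions(n, m):
    """Return (start, difference) pairs in the required output order."""
    bisquares = {p * p + q * q for p in range(m + 1) for q in range(m + 1)}
    starts = sorted(bisquares)
    largest = 2 * m * m  # Largest bisquare, reached at p = q = m.
    found = []
    # The last term is start + (n - 1) * difference, which must not exceed largest.
    for difference in range(1, largest // (n - 1) + 1):
        for start in starts:
            if all(start + k * difference in bisquares for k in range(n)):
                found.append((start, difference))
    return found


for start, difference in brute_force_progressions(5, 7):
    print(start, difference)
```

For the sample input (5 and 7), this should print the ten lines of the sample output.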

Dynamic Programming Solution

My initial solution, which I translated from C++ to Python, was not fast enough. I wrote a new solution that I thought was clever. We iterate over all possible deltas, and for each delta, we use dynamic programming to find the longest sequence with that delta ending at each bisquare.
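
The recurrence for a single delta can be sketched on a toy set first. This sketch uses a `set` and a `dict` for brevity rather than the indexed arrays in the program that follows, and `longest_runs` is a hypothetical helper name.

```python
def longest_runs(values, delta):
    """Map each x to the length of the longest progression with step delta ending at x."""
    present = set(values)
    lengths = {}
    # Visiting values in increasing order guarantees lengths[x - delta] is final.
    for x in sorted(values):
        if x - delta in present:
            lengths[x] = lengths[x - delta] + 1
        else:
            lengths[x] = 1
    return lengths


# Bisquares with 0 <= p, q <= 3: 0, 1, 2, 4, 5, 8, 9, 10, 13, 18
toy_bisquares = sorted({p * p + q * q for p in range(4) for q in range(4)})
lengths = longest_runs(toy_bisquares, 4)
print(lengths[13])  # 4, from the run 1, 5, 9, 13
print(lengths[8])   # 3, from the run 0, 4, 8
```

Any `x` whose run length reaches `N` ends a progression of length `N` that starts at `x - (N - 1) * delta`.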

def find_arithmetic_progressions(N, M):
    is_bisquare = [False] * (M * M + M * M + 1)
    bisquare_indices = [-1] * (M * M + M * M + 1)
    bisquares = []
    for p in range(0, M + 1):
        for q in range(p, M + 1):
            x = p * p + q * q
            if is_bisquare[x]: continue
            is_bisquare[x] = True
            bisquares.append(x)
    bisquares.sort()
    for i, bisquare in enumerate(bisquares):
        bisquare_indices[bisquare] = i

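    # For each delta, sequence_lengths[k] is the length of the longest
    # progression with that delta ending at bisquares[k].
    sequences, i = [], 0
    for delta in range(1, bisquares[-1] // (N - 1) + 1):
        sequence_lengths = [1] * len(bisquares)
        # Skip bisquares smaller than delta: they cannot have a predecessor.
        while bisquares[i] < delta: i += 1
        for x in bisquares[i:]:
            previous_idx = bisquare_indices[x - delta]
            if previous_idx == -1: continue
            idx, sequence_length = bisquare_indices[x], sequence_lengths[previous_idx] + 1
            sequence_lengths[idx] = sequence_length
            if sequence_length >= N:
                # x ends a run of at least N terms; its last N terms start here.
                sequences.append((delta, x - (N - 1) * delta))

    return sequences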

Too slow!

Executing...
   Test 1: TEST OK [0.011 secs, 9352 KB]
   Test 2: TEST OK [0.011 secs, 9352 KB]
   Test 3: TEST OK [0.011 secs, 9168 KB]
   Test 4: TEST OK [0.011 secs, 9304 KB]
   Test 5: TEST OK [0.031 secs, 9480 KB]
   Test 6: TEST OK [0.215 secs, 9516 KB]
   Test 7: TEST OK [2.382 secs, 9676 KB]
  > Run 8: Execution error: Your program (`ariprog') used more than
        the allotted runtime of 5 seconds (it ended or was stopped at
        5.242 seconds) when presented with test case 8. It used 12948 KB
        of memory. 

        ------ Data for Run 8 [length=7 bytes] ------
        22 
        250 
        ----------------------------
    Test 8: RUNTIME 5.242>5 (12948 KB)

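I managed to put my mathematics background to good use here: $p^2 + q^2 \not\equiv 3 \pmod 4$ and $p^2 + q^2 \not\equiv 6 \pmod 8$. This means that a bisquare arithmetic progression with more than 3 elements must have a delta divisible by 4. If $b \equiv 1 \pmod 4$ or $b \equiv 3 \pmod 4$, any 4 consecutive terms cover every residue modulo 4, so there would have to be a bisquare $p^2 + q^2 \equiv 3 \pmod 4$, which is impossible. If $b \equiv 2 \pmod 4$ and $a$ is odd, one of $a$ and $a+b$ is $\equiv 3 \pmod 4$, which is again impossible. If $b \equiv 2 \pmod 4$ and $a$ is even, the terms $a$, $a+b$, $a+2b$, $a+3b$ cover every even residue modulo 8, so there would have to be $p^2 + q^2 \equiv 6 \pmod 8$, which is also impossible.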
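
The two congruences are easy to check by enumeration, since $p^2 + q^2$ modulo $k$ depends only on $p$ and $q$ modulo $k$. The sketch below does that, then cross-checks the conclusion empirically for a small bound; `bisquare_residues` is a hypothetical helper and the bound of 12 is an arbitrary choice.

```python
def bisquare_residues(modulus):
    """All values of (p^2 + q^2) mod modulus; only p, q mod modulus matter."""
    return {(p * p + q * q) % modulus for p in range(modulus) for q in range(modulus)}


print(sorted(bisquare_residues(4)))  # [0, 1, 2]
print(sorted(bisquare_residues(8)))  # [0, 1, 2, 4, 5]

# Empirical cross-check: collect every delta that admits a 4-term progression.
M = 12
bisquares = {p * p + q * q for p in range(M + 1) for q in range(M + 1)}
deltas = {
    b
    for a in bisquares
    for b in range(1, 2 * M * M // 3 + 1)
    if all(a + k * b in bisquares for k in range(4))
}
print(all(b % 4 == 0 for b in deltas))  # True
```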

This optimization makes it fast enough.

def find_arithmetic_progressions(N, M):
    is_bisquare = [False] * (M * M + M * M + 1)
    bisquare_indices = [-1] * (M * M + M * M + 1)
    bisquares = []
    for p in range(0, M + 1):
        for q in range(p, M + 1):
            x = p * p + q * q
            if is_bisquare[x]: continue
            is_bisquare[x] = True
            bisquares.append(x)
    bisquares.sort()
    for i, bisquare in enumerate(bisquares):
        bisquare_indices[bisquare] = i

    sequences, i = [], 0
    # For N > 3 only deltas divisible by 4 can work; N == 3 still needs every delta.
    for delta in (range(1, bisquares[-1] // (N - 1) + 1) if N == 3 else
                  range(4, bisquares[-1] // (N - 1) + 1, 4)):
        sequence_lengths = [1] * len(bisquares)
        while bisquares[i] < delta: i += 1
        for x in bisquares[i:]:
            previous_idx = bisquare_indices[x - delta]
            if previous_idx == -1: continue
            idx, sequence_length = bisquare_indices[x], sequence_lengths[previous_idx] + 1
            sequence_lengths[idx] = sequence_length
            if sequence_length >= N:
                sequences.append((delta, x - (N - 1) * delta))

    return sequences

Executing...
   Test 1: TEST OK [0.010 secs, 9300 KB]
   Test 2: TEST OK [0.011 secs, 9368 KB]
   Test 3: TEST OK [0.015 secs, 9248 KB]
   Test 4: TEST OK [0.014 secs, 9352 KB]
   Test 5: TEST OK [0.045 secs, 9340 KB]
   Test 6: TEST OK [0.078 secs, 9464 KB]
   Test 7: TEST OK [0.662 secs, 9756 KB]
   Test 8: TEST OK [1.473 secs, 9728 KB]
   Test 9: TEST OK [1.313 secs, 9740 KB]

All tests OK.

Even Faster!

Not content to merely pass, I wanted to see if we could pass every test case in under 1 second (the time limit is 5 seconds). Indeed, we can. The solution in the official analysis takes advantage of the fact that the sequence length is short. The dynamic programming optimization is not that helpful; it's better to traverse the bisquares less. Instead, we take pairs of bisquares carefully: the larger one is the last term of a candidate sequence, the smaller one is the term before it, and we break out when the delta is too big. The official solution has some inefficiencies, like using a hash map. If we instead use indexed array lookups, we can be very fast.

def find_arithmetic_progressions(N, M):
    is_bisquare = [False] * (M * M + M * M + 1)
    bisquares = []
    for p in range(0, M + 1):
        for q in range(p, M + 1):
            x = p * p + q * q
            if is_bisquare[x]: continue
            is_bisquare[x] = True
            bisquares.append(x)
    bisquares.sort()

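    sequences = []
    # Treat each bisquare x as the last term of a candidate progression.
    for i in reversed(range(len(bisquares))):
        x = bisquares[i]
        max_delta = x // (N - 1)  # Larger deltas would push the first term below 0.
        for j in reversed(range(i)):
            y = bisquares[j]
            delta = x - y
            # bisquares is sorted, so delta only grows as j decreases.
            if delta > max_delta: break
            if N > 3 and delta % 4 != 0: continue
            z = x - (N - 1) * delta  # Required first term.
            # Step down from y while the next lower term is also a bisquare.
            while y > z and is_bisquare[y - delta]: y -= delta
            if z == y: sequences.append((delta, z))
    sequences.sort()
    return sequences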

Executing...
   Test 1: TEST OK [0.013 secs, 9280 KB]
   Test 2: TEST OK [0.012 secs, 9284 KB]
   Test 3: TEST OK [0.013 secs, 9288 KB]
   Test 4: TEST OK [0.012 secs, 9208 KB]
   Test 5: TEST OK [0.018 secs, 9460 KB]
   Test 6: TEST OK [0.051 secs, 9292 KB]
   Test 7: TEST OK [0.421 secs, 9552 KB]
   Test 8: TEST OK [0.896 secs, 9588 KB]
   Test 9: TEST OK [0.786 secs, 9484 KB]

All tests OK.

Yay!

