diff --git a/python/e345.py b/python/e345.py new file mode 100644 index 0000000..a2cb454 --- /dev/null +++ b/python/e345.py @@ -0,0 +1,76 @@ +from itertools import permutations +from heapq import heappush, heappop + + +m1 = """ 7 53 183 439 863 + 497 383 563 79 973 + 287 63 343 169 583 + 627 343 773 959 943 + 767 473 103 699 303""" + + +m2 = """ 7 53 183 439 863 497 383 563 79 973 287 63 343 169 583 +627 343 773 959 943 767 473 103 699 303 957 703 583 639 913 +447 283 463 29 23 487 463 993 119 883 327 493 423 159 743 +217 623 3 399 853 407 103 983 89 463 290 516 212 462 350 +960 376 682 962 300 780 486 502 912 800 250 346 172 812 350 +870 456 192 162 593 473 915 45 989 873 823 965 425 329 803 +973 965 905 919 133 673 665 235 509 613 673 815 165 992 326 +322 148 972 962 286 255 941 541 265 323 925 281 601 95 973 +445 721 11 525 473 65 511 164 138 672 18 428 154 448 848 +414 456 310 312 798 104 566 520 302 248 694 976 430 392 198 +184 829 373 181 631 101 969 613 840 740 778 458 284 760 390 +821 461 843 513 17 901 711 993 293 157 274 94 192 156 574 + 34 124 4 878 450 476 712 914 838 669 875 299 823 329 699 +815 559 813 459 522 788 168 586 966 232 308 833 251 631 107 +813 883 451 509 615 77 281 613 459 205 380 274 302 35 805 +""" + + +def brute_force(m): + idxs = [i for i in range(len(m))] + max_sum = 0 + for idx in permutations(idxs): + temp_sum = 0 + for i, j in enumerate(idx): + temp_sum += m[j][i] + if temp_sum > max_sum: + max_sum = temp_sum + return max_sum + + +def a_star(m): + # init heuristic function + h = {} + for i in range(len(m)): + h[i] = sum(map(min, m[i:])) + h[i + 1] = 0 + + states = [] + for i, r in enumerate(m[0]): + state = (r + h[1], r, 1, [i]) + heappush(states, state) + + while states: + _, gscore, index, seen = heappop(states) + if index == len(m): + return -gscore + + for row_index, row_score in enumerate(m[index]): + if row_index in seen: + continue + ngscore = gscore + row_score + nfscore = ngscore + h[index + 1] + nstate = (nfscore, ngscore, index + 1, seen + [row_index]) + heappush(states, nstate) + + +def euler_345(): + m = list(zip(*map(lambda line: list(map(lambda n: -int(n), line.split())), m2.splitlines()))) + return a_star(m) + + +if __name__ == "__main__": + solution = euler_345() + print("e345.py: " + str(solution)) + assert(solution == 13938)