euler/python/e345.py

77 lines
2.3 KiB
Python

from itertools import permutations
from heapq import heappush, heappop
# Small 5x5 sample matrix. Nothing visible in this file reads it; with only
# 5! = 120 permutations it is cheap enough to cross-check brute_force against
# a_star.
m1 = """ 7 53 183 439 863
497 383 563 79 973
287 63 343 169 583
627 343 773 959 943
767 473 103 699 303"""
# 15x15 input matrix parsed by euler_345(): one row per line, entries separated
# by whitespace. The closing quotes sit on their own line, so the text ends in a
# newline; str.splitlines() does not yield an extra empty row for it.
m2 = """ 7 53 183 439 863 497 383 563 79 973 287 63 343 169 583
627 343 773 959 943 767 473 103 699 303 957 703 583 639 913
447 283 463 29 23 487 463 993 119 883 327 493 423 159 743
217 623 3 399 853 407 103 983 89 463 290 516 212 462 350
960 376 682 962 300 780 486 502 912 800 250 346 172 812 350
870 456 192 162 593 473 915 45 989 873 823 965 425 329 803
973 965 905 919 133 673 665 235 509 613 673 815 165 992 326
322 148 972 962 286 255 941 541 265 323 925 281 601 95 973
445 721 11 525 473 65 511 164 138 672 18 428 154 448 848
414 456 310 312 798 104 566 520 302 248 694 976 430 392 198
184 829 373 181 631 101 969 613 840 740 778 458 284 760 390
821 461 843 513 17 901 711 993 293 157 274 94 192 156 574
34 124 4 878 450 476 712 914 838 669 875 299 823 329 699
815 559 813 459 522 788 168 586 966 232 308 833 251 631 107
813 883 451 509 615 77 281 613 459 205 380 274 302 35 805
"""
def brute_force(m):
    """Return the maximum sum picking one entry per column with no row reused.

    Every permutation ``perm`` of ``range(len(m))`` is scored as the sum of
    ``m[perm[col]][col]`` over all columns, so this runs in O(n! * n) time and
    is only practical for small matrices.

    Args:
        m: square matrix given as a sequence of row sequences.

    Returns:
        The largest permutation score; ``0`` for an empty matrix (the single
        empty permutation scores 0).
    """
    n = len(m)
    # Take max() over every permutation rather than a running maximum seeded
    # at 0, so matrices whose best sum is negative (e.g. negated inputs) are
    # handled correctly.
    return max(
        sum(m[row][col] for col, row in enumerate(perm))
        for perm in permutations(range(n))
    )
def a_star(m):
    """Return the negated minimum total of picking one entry per row of *m*.

    Row ``k`` is handled at search depth ``k``: the search picks one entry of
    ``m[k]`` whose position was not taken by an earlier row, minimising the
    total of the picked entries. The minimum is returned *negated*, so feeding
    a negated matrix yields the maximum sum of the original values.

    Args:
        m: sequence of rows (sequences of numbers); every row must be
            non-empty.

    Returns:
        ``-min_total`` for the best selection, ``0`` for an empty matrix, or
        ``None`` when no complete selection exists (rows have fewer than
        ``len(m)`` positions).

    Raises:
        ValueError: if some row is empty.
    """
    n = len(m)
    if n == 0:
        # Nothing to pick: the only selection is empty and costs 0.
        return 0

    # h[k] lower-bounds the cost of finishing rows k..n-1, because each
    # remaining row costs at least its own minimum. Since
    # h[k] - h[k + 1] == min(m[k]) never exceeds the cost actually paid at
    # row k, the heuristic is consistent. A* may therefore return the first
    # goal it pops and never needs to re-expand a closed state.
    h = [0] * (n + 1)
    for k in range(n - 1, -1, -1):
        h[k] = h[k + 1] + min(m[k])

    # Heap entries are (f = g + h, g, depth, used), where ``used`` is a bitmask
    # of positions already taken. Depth equals the number of set bits, so the
    # mask alone identifies a search state.
    heap = [(h[0], 0, 0, 0)]
    closed = set()
    while heap:
        _, g, depth, used = heappop(heap)
        if depth == n:
            return -g
        if used in closed:
            # This state was already expanded via a path at least as cheap.
            continue
        closed.add(used)
        for pos, cost in enumerate(m[depth]):
            bit = 1 << pos
            if used & bit:
                continue
            nxt = used | bit
            if nxt in closed:
                continue
            ng = g + cost
            heappush(heap, (ng + h[depth + 1], ng, depth + 1, nxt))
    # Every partial selection dead-ended: no complete selection exists.
    return None
def euler_345(matrix_text=None):
    """Return the maximum sum of matrix entries with no two sharing a row or column.

    Args:
        matrix_text: the matrix as text, one row per line with entries
            separated by whitespace; blank lines are ignored. Defaults to the
            module-level ``m2`` matrix.

    Returns:
        The maximum achievable sum, as computed by ``a_star``.
    """
    if matrix_text is None:
        matrix_text = m2
    # Negate every entry: a_star minimises the total and returns it negated,
    # so the minimum over negated values comes back as the maximum sum.
    negated_rows = [
        [-int(token) for token in line.split()]
        for line in matrix_text.splitlines()
        if line.strip()
    ]
    # Search column by column (transpose). The optimum is the same for a
    # matrix and its transpose.
    return a_star(list(zip(*negated_rows)))
if __name__ == "__main__":
    # Solve, report the answer, then self-check it against the known value.
    answer = euler_345()
    print(f"e345.py: {answer}")
    assert answer == 13938