/*

Magic squares in Picat.

Model created by Hakan Kjellerstrand, hakank@gmail.com

*/

% Licensed under CC-BY-4.0: http://creativecommons.org/licenses/by/4.0/

import util.
import cp.

% Program entry point: delegates to the default demo run go/0.
main =>
    go.

%
% Simple run: solve (and print) a single magic square of order 6.
% The resulting square is not used further here.
%
go =>
    N = 6,
    magic(N,_Square).

%
% Run for different sizes: calls magic/2 for every order N in 3..10,
% wrapping each call in time2/1.
%
% Benchmark noted by the author:
%   ff/updown solves N= 6 in 0.06s and 7099 backtracks
%                      10 in 2.99s and 312671 backtracks
%
go2 =>
    foreach(N in 3..10)
        time2(magic(N,_Square))
    end.

%
% All solutions: collects every square that magic/2 yields for N = 4
% (under its symmetry-breaking constraints) and prints how many there are.
% Note that magic/2 itself prints each square it finds.
%
go3 =>
    N = 4,
    Solutions = findall(Square,magic(N,Square)),
    Count = Solutions.length,
    printf("Len: %d\n",Count).

%
% magic(N,Square): builds an N x N array Square holding the distinct values
% 1..N*N such that every row, every column and both main diagonals add up
% to the magic sum N*(N*N+1)//2, labels it, and prints the result.
%
% N      : order of the square.
% Square : bound to the N x N array of the solution.
%
% Prints N and the magic sum before solving, and the square afterwards.
%
magic(N,Square) =>

    writef("\n\nN: %d\n", N),
    MaxVal = N*N,
    MagicSum = N*(MaxVal+1)//2, % the common sum of each line
    writef("Sum = %d\n", MagicSum),

    % decision variables: one cell per position, each in 1..N*N
    Square = new_array(N,N),
    Square :: 1..MaxVal,

    Cells = Square.vars(),
    all_different(Cells),

    % row K and column K must both reach the magic sum
    foreach(K in 1..N)
        MagicSum #= sum([Square[K,C] : C in 1..N]), % row K
        MagicSum #= sum([Square[R,K] : R in 1..N])  % column K
    end,

    % main diagonal and anti-diagonal
    MagicSum #= sum([Square[D,D] : D in 1..N]),
    MagicSum #= sum([Square[D,N-D+1] : D in 1..N]),

    % Symmetry breaking: the top-left corner is the smallest of the four
    % corners, and the top-right corner is smaller than the bottom-left.
    Square[1,1] #< Square[1,N],
    Square[1,1] #< Square[N,N],
    Square[1,1] #< Square[N,1],
    Square[1,N] #< Square[N,1],

    solve([ffd,updown],Square),

    print_square(Square).

%
% Alternative using rows(), columns(), diagonal1(), and diagonal2()
% from the util module.
%
% magic2(N,Square): states the same constraints as magic/2, but obtains the
% rows, columns and diagonals of the array through the util helpers instead
% of explicit index comprehensions.
%
% N      : order of the square.
% Square : bound to the N x N array of the solution.
%
% Prints N and the magic sum before solving, and the square afterwards.
%
magic2(N,Square) =>

    printf("\n\nN: %d\n", N),
    NN = N*N,
    Sum = N*(NN+1)//2, % magical sum
    printf("Sum = %d\n", Sum),

    % decision variables: one cell per position, each in 1..N*N
    Square = new_array(N,N),
    Square :: 1..NN,

    all_different(Square.vars()),

    % Note that sum/1 requires a list, so Row and Column must be converted
    % to a list with .to_list().
    foreach(Row in Square.rows()) Sum #= sum(Row.to_list()) end,
    foreach(Column in Square.columns()) Sum #= sum(Column.to_list()) end,

    % diagonal sums
    Sum #= sum(Square.diagonal1()),
    Sum #= sum(Square.diagonal2()),

    % (A leftover debug statement that printed the bare atom `diagonals`
    % used to sit here; it has been removed so that the output matches
    % that of magic/2.)

    % Symmetry breaking: the top-left corner is the smallest of the four
    % corners, and the top-right corner is smaller than the bottom-left.
    Square[1,1] #< Square[1,N],
    Square[1,1] #< Square[N,N],
    Square[1,1] #< Square[N,1],
    Square[1,N] #< Square[N,1],

    solve([ffd,updown],Square),

    print_square(Square).

%
% print_square(Square): writes the square one row per line, each cell
% right-aligned in a 3-character field followed by a space, and ends with
% an empty line. The side length is taken from the number of rows.
%
print_square(Square) =>
    Size = Square.length,
    foreach(R in 1..Size)
        foreach(C in 1..Size)
            writef("%3d ", Square[R,C])
        end,
        writef("\n")
    end,
    writef("\n").