%% IFAP 2025 High Rollers Solver
% Start from an empty workspace, a clean console and no open figures.
clear; clc; close all; 

%% graph settings
% Session-wide graphics defaults on the root object (handle 0).
% 'DefaultLineLineWidth' used to be passed twice (1, then 2); only the last
% value takes effect, so the dead first entry is dropped and 2 is kept.
set(0, ...
    'DefaultLineLineWidth', 2, ...
    'DefaultLineMarkerSize', 6, ...
    'DefaultAxesFontSize', 18, ...
    'DefaultAxesTickLabelInterpreter', 'latex', ...
    'DefaultLegendInterpreter', 'latex', ...
    'DefaultTextInterpreter', 'latex', ...
    'DefaultFigurePaperOrientation', 'portrait', ...
    'DefaultFigurePaperUnits', 'inches', ...
    'DefaultFigurePaperSize', [6, 4.5], ...
    'DefaultFigurePaperPosition', [0,0,6,4.5], ...
    'DefaultFigureRenderer', 'painters')

%% Backward Induction from no rounds left to T:
% Log-payoff version of the game: sticking on face x with t rounds left is
% worth log(t*x); rerolling is worth the average of next period's values.
    clc % clear the console but not memory
    % Set up ---
        T=100; price = 0;       % NOTE(review): price is not read anywhere in the visible script
        faces   = 1:20;         % scores of d20
        R       = zeros(T,1);   % R(t): reroll (continuation) value with t rounds left
        xstar   = zeros(T,1);   % xstar(t): smallest integer face worth sticking on
        x_exact = zeros(T,1);   % x_exact(t): real-valued break-even face, exp(R(t))/t
        V       = zeros(T,20);  % V(t,x): value with t rounds left and face x showing

    % Game is trivial at V^0

    % Last Turn, T=1 ---
        V(1,:)      = log(faces);  % V(1,x) = log(x) when one round remains
        R(1)        = 0;      % R(1) = E(V(0)) = 0
        xstar(1)    = 0;      % irrelevant, always stick

    % Build up, 2 to T ---
        for t = 2:T
            % reroll value: uniform average over the faces of last period's values
            R(t) = mean(V(t-1,:));
            % threshold: log(t*x) >= R(t)  <=>  x >= exp(R(t))/t
            x_exact(t) = exp(R(t))/t;
            xstar(t) = ceil(x_exact(t));
            % fill V for every face at once: better of sticking and rerolling
            V(t,:) = max(log(t*faces), R(t));
        end

    % Starting state: T rounds left, averaged over the first roll
        expected_total = mean(V(T,:));       % in log units, because V holds log payoffs
        avg_per_round  = expected_total / T;
        nt             = 10.5*T;             % 10.5 is the mean of a d20 roll
    % Print updates --- 
        fprintf('Expected total payoff = %.4f\n', expected_total);
        fprintf('Average per round     = %.4f\n', avg_per_round);
        fprintf('Threshold at T=%d     = %d\n', T, xstar(T));
        % %g instead of %d: 10.5*T is not an integer for odd T, and %d would
        % then fall back to exponential notation. For even T the output is unchanged.
        fprintf('E(X)xT                = %g\n', nt);

%% Plot Results --- 
% Two-panel figure: (left) stick/reroll regions separated by the integer
% thresholds, (right) expected value by rounds remaining against the
% benchmark log(10.5*t). Reads T, xstar and V from the workspace.
close all;

figure(99);

% ---------------- left panel: thresholds ----------------
subplot(1,2,1)
hold on;

rounds     = 1:T;
thresholds = xstar(rounds).';            % row vector of thresholds
y_ceiling  = max(xstar) + 1;             % top edge of the axes / Stick region

% Outline abscissae: the first and last round are repeated so each region
% can drop vertically to its outer edge.
edge_x    = [1, rounds, T];
outline_x = [edge_x, edge_x(end:-1:1)];

upper_edge = [y_ceiling, thresholds, y_ceiling];  % threshold curve, closed at the top
lower_edge = [0, thresholds, 0];                  % threshold curve, closed at the bottom

% Stick region (light blue): between the threshold curve and the ceiling
patch(outline_x, [upper_edge, repmat(y_ceiling, 1, numel(edge_x))], ...
    [0.7, 0.85, 1], 'FaceAlpha', 0.3, 'EdgeColor', 'none');
% Reroll region (light red): between zero and the threshold curve
patch(outline_x, [zeros(size(edge_x)), lower_edge(end:-1:1)], ...
    [1, 0.7, 0.7], 'FaceAlpha', 0.3, 'EdgeColor', 'none');

% Threshold markers on top of the shading
scatter(rounds, xstar(rounds), 'MarkerFaceColor', 'b', 'LineWidth', 2);

% Region labels
text(T*0.01+1, max(xstar), 'Stick is Optimal', 'FontSize', 18, 'Interpreter', 'latex', ...
    'HorizontalAlignment', 'left', 'FontWeight', 'bold', 'Color', 'b');
text(T*0.5, max(xstar)*0.5, 'Reroll is Optimal', 'FontSize', 18, 'Interpreter', 'latex', ...
    'HorizontalAlignment', 'center', 'FontWeight', 'bold', 'Color', 'r');

xlabel('Rounds remaining'); 
ylabel('Threshold');
title('Optimal thresholds for d20, T');
grid on;
ylim([0 y_ceiling]); 
xlim([1 T]);
hold off;

% ---------------- right panel: expected value ----------------
subplot(1,2,2)
% Row-wise mean of V: the value averaged over all 20 faces, per rounds remaining
scatter(rounds, mean(V(rounds,:), 2), 'MarkerFaceColor', 'b', 'LineWidth', 2); hold on;
% Dashed benchmark: log of (mean roll 10.5) x (rounds remaining)
plot(rounds, log(10.5*rounds), '--k');
xlabel('Rounds remaining'); ylabel('Value');
title('Expected Value, $E(V^T(x))$');
grid on; xlim([1 T]);

% ---------------- size, trim and export ----------------
x0=10;
y0=10;
width=800;
height=400;
set(gcf, 'position', [x0, y0, width, height])
tightfig   % not defined in this file; presumably a helper on the MATLAB path
saveas(gcf, 'high_rollers1.pdf')

%% Stick value vs. reroll value at selected horizons
% One panel per horizon: the stick payoff log(t*x) across faces x against
% the flat reroll value R(t). Reads T and R from the workspace.
close all

face_axis = 1:20;
% Horizons are derived from T so the titles always match the plotted data.
% floor() keeps the half-horizon a valid index into R when T is odd.
% NOTE(review): the fixed horizon 3 still requires T >= 3.
panel_rounds = [T, floor(T/2), 3, 1];

for k = 1:numel(panel_rounds)
    n_left = panel_rounds(k);

    subplot(2,2,k)
    plot(face_axis, log(n_left*face_axis)); hold on;   % value of sticking on each face
    plot(face_axis, ones(20,1)*R(n_left));             % value of rerolling (flat in x)
    title(sprintf('$T=%d$', n_left)); grid on; xticks(0:2:20);

    % Only the first panel carries the legend.
    if k == 1
        legend('$S_{t}(x)$','$R_{t}$', 'Location','west');
    end
end

x0=10;
y0=10;
width=1000;
height=500;
% NOTE(review): here tightfig runs before the resize, whereas the other
% figures in this file resize first and then call tightfig. The original
% order is kept; confirm which one is intended.
tightfig   % not defined in this file; presumably a helper on the MATLAB path
set(gcf,'position',[x0,y0,width,height])
saveas(gcf,'high_rollers2.pdf')


%% Simulate Naive Sum and Optimal Play
% Drop any open figures, then run the panel simulation with 33 rounds and
% one million simulated players.
close all
simulate_panel_d20(33, 1e6)

%% HELPER fn: simulation calls

function simulate_d20(T, num_paths)
% SIMULATE_D20  Print individual simulated plays of the d20 stick/reroll game.
%
%   simulate_d20(T, num_paths) computes stick thresholds by backward
%   induction on the linear payoff (face x rounds left), then plays
%   num_paths games starting with T rounds. For each game it prints the
%   sticking decision and the full roll history. Nothing is returned.
%
%   The generator is reseeded with rng('shuffle'), so the printed paths
%   differ from run to run.

    die_faces = 1:20;

    % Thresholds via dynamic programming. Row 1 is the one-round-left
    % problem, where the cutoff stays 0 and any roll is kept.
    cont_value = zeros(T+1,1);
    cutoff     = zeros(T+1,1);
    value      = zeros(T+1,20);
    value(1,:) = die_faces;
    for k = 2:T
        cont_value(k) = mean(value(k-1,:));          % value of giving up a round
        cutoff(k)     = ceil(cont_value(k)/k);       % smallest face worth keeping
        value(k,:)    = max(k*die_faces, cont_value(k));
    end

    % Play the games one at a time
    rng('shuffle'); % random seed
    for path_id = 1:num_paths
        rounds_left = T;
        rolls       = randi(20);   % first roll; later rolls are appended
        while rounds_left > 0
            current = rolls(end);
            if current >= cutoff(rounds_left)
                % stick: cash in the face for every remaining round
                payout = current * rounds_left;
                fprintf('Path %d: Stick on %d with %d rounds left, payout = %.2f\n', ...
                        path_id, current, rounds_left, payout);
                break;
            end
            % reroll: spend a round and draw again
            rounds_left  = rounds_left - 1;
            rolls(end+1) = randi(20);
        end
        fprintf('Roll history: %s\n\n', mat2str(rolls));
    end
end


%%

function simulate_panel_d20(T, N)
% SIMULATE_PANEL_D20  Compare threshold play with a naive sum of rolls.
%
%   simulate_panel_d20(T, N) simulates N players following the
%   backward-induction stick thresholds over T rounds (payoff = face x
%   rounds left when sticking) and N naive totals, each the sum of T
%   independent d20 rolls. It overlays both histograms with their means,
%   saves the figure as high_rollers3.pdf and prints summary statistics.
%   The generator is seeded with rng(124). Nothing is returned.

    die_faces = 1:20;

    % --- Thresholds by backward induction ---
    cont_value = zeros(T+1,1);
    cutoff     = zeros(T+1,1);
    value      = zeros(T+1,20);
    value(1,:) = die_faces;
    for k = 2:T
        cont_value(k) = mean(value(k-1,:));
        cutoff(k)     = ceil(cont_value(k)/k);
        value(k,:)    = max(k*die_faces, cont_value(k));
    end

    % --- Threshold players, simulated one by one ---
    payouts = zeros(N,1);
    rng(124); % random seed

    for player = 1:N
        rounds_left = T;
        roll        = randi(20); % initial roll
        while rounds_left > 0
            if roll >= cutoff(rounds_left)
                payouts(player) = roll * rounds_left;   % stick and cash in
                break;
            end
            rounds_left = rounds_left - 1;              % reroll costs a round
            roll        = randi(20);
        end
    end

    % --- Naive totals: every player just adds up T rolls ---
    naive_scores = sum(randi(20, N, T), 2);

    % --- Sample means, shown on the figure ---
    mean_payouts = mean(payouts);
    mean_naive   = mean(naive_scores);

    % --- Overlaid histograms ---
    figure;
    histogram(payouts, 20, 'Normalization', 'probability', ...
        'FaceColor', [0.2 0.6 0.8], 'EdgeAlpha', 0);
    hold on;
    histogram(naive_scores, 20, 'Normalization', 'probability', ...
        'FaceColor', [1 0 0], 'EdgeAlpha', 0);

    % Vertical markers at the two means
    xline(mean_payouts, 'b-', 'LineWidth', 2);
    xline(mean_naive, 'r-', 'LineWidth', 2);

    % Numeric labels placed at fixed fractions of the current y-range
    y_range = ylim;
    text(mean_payouts, y_range(2)*0.9, sprintf('%.2f', mean_payouts), ...
        'HorizontalAlignment', 'center', 'Color', 'k', 'FontSize', 18);
    text(mean_naive, y_range(2)*0.8, sprintf('%.2f', mean_naive), ...
        'HorizontalAlignment', 'center', 'Color', 'k', 'FontSize', 18);

    hold off;

    % --- Annotation ---
    xlabel('Total payout');
    ylabel('Probability');
    title(sprintf('Payoffs: %d players, T=%d', N, T));
    legend('Optimal policy', 'Naive sum of rolls', 'Location','northwest');
    grid on;

    % --- Size, trim and export ---
    x0=10;
    y0=10;
    width=800;
    height=400;
    set(gcf, 'position', [x0, y0, width, height])
    tightfig   % not defined in this file; presumably a helper on the MATLAB path
    saveas(gcf, 'high_rollers3.pdf')

    % --- Summary statistics ---
    fprintf('Optimal policy:\n');
    fprintf('Mean payout: %.2f\n', mean(payouts));
    fprintf('Median payout: %.2f\n', median(payouts));
    fprintf('Min payout: %.2f\n', min(payouts));
    fprintf('Max payout: %.2f\n', max(payouts));

    fprintf('\nNaive sum of rolls:\n');
    fprintf('Mean: %.2f\n', mean(naive_scores));
    fprintf('Median: %.2f\n', median(naive_scores));
    fprintf('Min: %.2f\n', min(naive_scores));
    fprintf('Max: %.2f\n', max(naive_scores));
end


