import edu.princeton.cs.algs4.In;
import edu.princeton.cs.algs4.StdDraw;
import java.awt.Color;

/**
 * Draws a 2D data set together with the decision stump chosen by a
 * {@code WeakLearner}, using {@code StdDraw}.
 *
 * <p>The input file (path given as the first command-line argument) is read
 * in this order: the number of points {@code n}, the number of dimensions
 * {@code k} (must be 2), {@code n * k} integer coordinates, {@code n} integer
 * labels, and {@code n} double weights.
 *
 * <p>Points labelled 0 are drawn as blue squares and points labelled 1 as red
 * circles; points with any other label are not drawn. The two sides of the
 * stump are shaded light blue / light red depending on the sign predictor
 * (presumably matching the label predicted on each side; confirm against
 * {@code WeakLearner.predict}).
 */
public class WeakLearnerVisualizer {

    private static final Color LIGHT_BLUE = Color.decode("#7faac9");
    private static final Color LIGHT_RED = Color.decode("#e5b4b8");
    private static final Color BLUE = Color.decode("#005493");
    private static final Color RED = Color.decode("#8d3138");

    /** Length of the axis arrowheads, in user coordinates. */
    private static final double ARROW_SIZE = 0.3;

    /**
     * Reads the data file, trains the weak learner and renders the plot.
     *
     * @param args {@code args[0]} is the path (or URL) of the data file
     * @throws IllegalArgumentException if no data file is given or the data
     *         is not two-dimensional
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: java WeakLearnerVisualizer <datafile>");
        }

        int n;
        int[][] input;
        int[] labels;
        double[] weights;

        // Read input file; always release the underlying stream, even on a parse error.
        In datafile = new In(args[0]);
        try {
            n = datafile.readInt();     // number of points
            int k = datafile.readInt(); // dimensions
            if (k != 2) {
                throw new IllegalArgumentException("Only 2D data is supported for visualization.");
            }

            input = new int[n][k];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < k; j++) {
                    input[i][j] = datafile.readInt();
                }
            }

            labels = new int[n];
            for (int i = 0; i < n; i++) {
                labels[i] = datafile.readInt();
            }

            weights = new double[n];
            for (int i = 0; i < n; i++) {
                weights[i] = datafile.readDouble();
            }
        } finally {
            datafile.close();
        }

        WeakLearner weakLearner = new WeakLearner(input, weights, labels);

        // Decision stump parameters
        int dimension = weakLearner.dimensionPredictor(); // dp
        int value = weakLearner.valuePredictor();         // vp
        int sign = weakLearner.signPredictor();           // sp

        // Bounding box of the data, padded by 1 unit on every side
        int minX = Integer.MAX_VALUE, maxX = Integer.MIN_VALUE;
        int minY = Integer.MAX_VALUE, maxY = Integer.MIN_VALUE;
        for (int i = 0; i < n; i++) {
            minX = Math.min(minX, input[i][0]);
            maxX = Math.max(maxX, input[i][0]);
            minY = Math.min(minY, input[i][1]);
            maxY = Math.max(maxY, input[i][1]);
        }
        minX -= 1;
        maxX += 1;
        minY -= 1;
        maxY += 1;

        double accuracy = weightedAccuracy(weakLearner, input, labels, weights);

        // Canvas is one unit larger than the padded box so arrows and text fit.
        StdDraw.setCanvasSize(800, 800);
        StdDraw.setXscale(minX - 1, maxX + 1);
        StdDraw.setYscale(minY - 1, maxY + 1);
        StdDraw.clear();

        // Back-to-front: shading, grid, axes, stump line, points, caption.
        shadeRegions(dimension, value, sign, minX, maxX, minY, maxY);
        drawGrid(minX, maxX, minY, maxY);
        drawAxes(maxX, maxY);
        drawStumpLine(dimension, value, minX, maxX, minY, maxY);
        drawPoints(input, labels);
        drawCaption(accuracy, value, dimension, sign, minX, maxX, minY);

        StdDraw.show();
    }

    /**
     * Sums the weights of the points whose predicted label equals the given
     * label. The sum is not normalized by the total weight.
     */
    private static double weightedAccuracy(WeakLearner weakLearner, int[][] input,
                                           int[] labels, double[] weights) {
        double accuracy = 0.0;
        for (int i = 0; i < input.length; i++) {
            if (weakLearner.predict(input[i]) == labels[i]) {
                accuracy += weights[i];
            }
        }
        return accuracy;
    }

    /**
     * Shades the two sides of the stump inside the padded bounding box.
     * With {@code sign == 0} the low side is light blue and the high side
     * light red; any other sign swaps the two colours. Dimensions other than
     * 0 and 1 shade nothing.
     */
    private static void shadeRegions(int dimension, int value, int sign,
                                     int minX, int maxX, int minY, int maxY) {
        Color lowColor = (sign == 0) ? LIGHT_BLUE : LIGHT_RED;
        Color highColor = (sign == 0) ? LIGHT_RED : LIGHT_BLUE;

        if (dimension == 0) { // Vertical split: left / right of x = value
            double centerY = (minY + maxY) / 2.0;
            double halfHeight = (maxY - minY) / 2.0;
            StdDraw.setPenColor(lowColor);
            StdDraw.filledRectangle((minX + value) / 2.0, centerY, (value - minX) / 2.0, halfHeight);
            StdDraw.setPenColor(highColor);
            StdDraw.filledRectangle((value + maxX) / 2.0, centerY, (maxX - value) / 2.0, halfHeight);
        } else if (dimension == 1) { // Horizontal split: below / above y = value
            double centerX = (minX + maxX) / 2.0;
            double halfWidth = (maxX - minX) / 2.0;
            StdDraw.setPenColor(lowColor);
            StdDraw.filledRectangle(centerX, (minY + value) / 2.0, halfWidth, (value - minY) / 2.0);
            StdDraw.setPenColor(highColor);
            StdDraw.filledRectangle(centerX, (value + maxY) / 2.0, halfWidth, (maxY - value) / 2.0);
        }
    }

    /** Draws light-gray grid lines at every integer coordinate of the padded box. */
    private static void drawGrid(int minX, int maxX, int minY, int maxY) {
        StdDraw.setPenColor(StdDraw.LIGHT_GRAY);
        for (int x = minX; x <= maxX; x++) {
            StdDraw.line(x, minY, x, maxY);
        }
        for (int y = minY; y <= maxY; y++) {
            StdDraw.line(minX, y, maxX, y);
        }
    }

    /**
     * Draws the positive x- and y-axes from the origin with arrowheads and
     * the "x" / "y" labels. Sets the font size to 20 as a side effect.
     */
    private static void drawAxes(int maxX, int maxY) {
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01); // Thicker than the grid

        double tipX = maxX + 0.5;
        double tipY = maxY + 0.5;

        // x-axis: shaft stops where the arrowhead begins
        StdDraw.line(0, 0, tipX - ARROW_SIZE, 0);
        StdDraw.filledPolygon(
            new double[]{tipX, tipX - ARROW_SIZE, tipX - ARROW_SIZE},
            new double[]{0, ARROW_SIZE / 2, -ARROW_SIZE / 2}
        );

        // y-axis
        StdDraw.line(0, 0, 0, tipY - ARROW_SIZE);
        StdDraw.filledPolygon(
            new double[]{0, ARROW_SIZE / 2, -ARROW_SIZE / 2},
            new double[]{tipY, tipY - ARROW_SIZE, tipY - ARROW_SIZE}
        );

        // Axis labels
        StdDraw.setFont(StdDraw.getFont().deriveFont(20f));
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.text(maxX + 0.7, -0.3, "x");
        StdDraw.text(-0.3, maxY + 0.7, "y");
    }

    /** Draws the stump boundary, extended 0.2 units beyond the padded box. */
    private static void drawStumpLine(int dimension, int value,
                                      int minX, int maxX, int minY, int maxY) {
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.005); // Thinner than axes
        if (dimension == 0) {
            StdDraw.line(value, minY - 0.2, value, maxY + 0.2); // Vertical line
        } else if (dimension == 1) {
            StdDraw.line(minX - 0.2, value, maxX + 0.2, value); // Horizontal line
        }
    }

    /** Draws label-0 points as blue squares and label-1 points as red circles. */
    private static void drawPoints(int[][] input, int[] labels) {
        StdDraw.setPenRadius(0.01);
        for (int i = 0; i < input.length; i++) {
            if (labels[i] == 0) {
                StdDraw.setPenColor(BLUE);
                StdDraw.filledSquare(input[i][0], input[i][1], 0.1);
            } else if (labels[i] == 1) {
                StdDraw.setPenColor(RED);
                StdDraw.filledCircle(input[i][0], input[i][1], 0.1);
            }
        }
    }

    /** Writes the accuracy and the stump parameters in the margin below the plot. */
    private static void drawCaption(double accuracy, int value, int dimension, int sign,
                                    int minX, int maxX, int minY) {
        StdDraw.setFont(StdDraw.getFont().deriveFont(18f));
        StdDraw.setPenColor(StdDraw.BLACK);
        double centerX = (minX + maxX) / 2.0;
        double textHeight = minY - 0.3; // Just below the padded box, inside the canvas margin
        StdDraw.text(centerX, textHeight, String.format("Accuracy: %.2f", accuracy));
        StdDraw.text(centerX, textHeight - 0.5,
                     String.format("vp: %d, dp: %d, sp: %d", value, dimension, sign));
    }
}