import { CustomLayer, ResponsiveLine } from "@nivo/line";
import { useOrdinalColorScale } from "@nivo/colors";
import { Defs } from "@nivo/core";
import { useMemo } from "react";
import {
  line as LineGenerator,
  area as AreaGenerator,
  curveStepAfter,
} from "d3-shape";

/** Props for the survival-analysis graph component. */
interface SurvivalAnalysisProps {
  /**
   * Chooses the x-axis: "survival" plots days on a linear scale; any other
   * value plots lifetime value on a symlog scale.
   */
  x_axis_type: string;
  /** Cohort names whose series should be drawn; all others are hidden. */
  selectedAudiences: Array<string>;
  /** Series for every cohort, including cohorts that are not currently selected. */
  data: Array<FormatedGraphData>;
}

/** One cohort's curve, as supplied to the survival-analysis graph. */
export interface FormatedGraphData {
  /**
   * Points of the curve. The graph plots `1 - y`, so `y` is treated as a
   * survival fraction (presumably in [0, 1] — confirm against the producer).
   */
  data: Array<{ x: number; y: number }>;
  /**
   * Month the cohort belongs to. Strings that do not parse as a date
   * (e.g. "default") mark the all-time average series.
   */
  month: string;
  /** Cohort name; matched against the selected audiences and used as the color key. */
  cohort: string;
  /** Series identifier handed to the chart. */
  id: string;
}

/**
 * Converts a month string to a comparable month count (`year * 12 + month`),
 * where larger means more recent. Returns NaN when the string does not parse
 * as a date, e.g. the "default" (all-time average) series.
 */
const monthIndex = (month: string): number => {
  const date = new Date(month);
  return date.getFullYear() * 12 + date.getMonth();
};

/**
 * Step-curve chart of survival-analysis cohorts, plotted as churn risk
 * (`1 - y`) against days (`x_axis_type === "survival"`) or lifetime value in
 * dollars (any other `x_axis_type`).
 *
 * Series whose `month` parses as a date are drawn dashed. Older months get
 * shorter dashes and thinner strokes. Series whose `month` does not parse
 * (e.g. "default") are drawn solid. When dated series are also shown, those
 * undated series also get a wide translucent highlight underneath.
 *
 * @param props.data - Every cohort; all are fed to the color scale so colors
 *   stay stable when cohorts are toggled.
 * @param props.selectedAudiences - Cohort names to display.
 * @param props.x_axis_type - "survival" for a linear day axis; otherwise a
 *   symlog dollar axis.
 */
export const SurvivalAnalysisGraph = (props: SurvivalAnalysisProps) => {
  const isSurvival = props.x_axis_type === "survival"; // Only alternative is lifetime_value

  const colors = useOrdinalColorScale({ scheme: "category10" }, "cohort");

  const data = useMemo(() => {
    // Run every cohort (not only the visible ones) through the ordinal scale
    // first. d3 ordinal scales assign colors in first-seen order, so this keeps
    // a cohort's color fixed when other cohorts are unchecked.
    props.data.forEach((value) => colors(value));

    return (
      props.data
        // Filters out unchecked legend items
        .filter((value) => props.selectedAudiences.includes(value.cohort))
        // Inverse the survival function: plot churn risk = 1 - y
        .map((value) => ({
          ...value,
          data: value.data.map(({ x, y }) => ({ x, y: 1 - y })),
        }))
    );
  }, [props.data, props.selectedAudiences, colors]);

  // x-domain forced from the data, because the library did not compute it
  // correctly. A single pass avoids spreading every point into Math.min/max.
  // It falls back to "auto" when nothing is selected, where min/max of an
  // empty list would give an inverted Infinity/-Infinity domain.
  const xExtent = useMemo(() => {
    let min = Infinity;
    let max = -Infinity;
    for (const serie of data) {
      for (const { x } of serie.data) {
        if (x < min) min = x;
        if (x > max) max = x;
      }
    }
    return min <= max
      ? { min, max }
      : { min: "auto" as const, max: "auto" as const };
  }, [data]);

  const DashedLine: CustomLayer = (_props) => {
    // d3 generators are factory functions, so they are called without `new`.
    // NOTE(review): the cast bypasses d3-shape typings; consider
    // `LineGenerator<[number, number]>()` if @types/d3-shape is installed.
    const lineGenerator = (LineGenerator as any)().curve(curveStepAfter);

    // Month index of every dated series. Undated ones ("default") are NaN and
    // are dropped, so `dates` only spans real months.
    const dates = _props.series
      .map((serie) => monthIndex(serie.month))
      .filter((d) => !isNaN(d));
    const hasDatedSeries = dates.length !== 0;
    const minDate = Math.min(...dates);
    const maxDate = Math.max(...dates);

    return _props.series.flatMap((serie) => {
      const monthAsNumber = monthIndex(serie.month);
      const isDated = !isNaN(monthAsNumber);
      // 0 for the oldest month, 1 for the newest, 0.5 when only one month exists.
      // Only used for dated series, so the empty-`dates` case never reaches it.
      const posBetweenDates =
        maxDate === minDate
          ? 0.5
          : (monthAsNumber - minDate) / (maxDate - minDate);

      // Main line: solid for the undated average, dashed for months. Dash
      // length (3-25) and stroke width (0.2-2) grow with recency.
      const mainLine = (
        <path
          key={serie.id}
          d={lineGenerator(
            serie.data.map((d) => [
              _props.xScale(d.data.x as number) as number,
              _props.yScale(d.data.y as number) as number,
            ])
          )}
          stroke={serie.color}
          fill="none"
          style={{
            strokeDasharray: isDated
              ? `${Math.floor(22 * posBetweenDates + 3)}, 3`
              : undefined,
            strokeWidth: isDated ? posBetweenDates * 1.8 + 0.2 : 2,
          }}
        />
      );

      if (!hasDatedSeries || isDated) return [mainLine];

      // Wide translucent highlight drawn beneath the undated line when months
      // are also on the graph. Appending "66" adds a hex alpha channel, which
      // assumes `serie.color` is a #rrggbb string (true for category10).
      return [
        <path
          key={`${serie.id}_0`}
          d={lineGenerator(
            serie.data.map((d) => [d.position.x, d.position.y])
          )}
          stroke={serie.color + "66"}
          fill="none"
          style={{ strokeWidth: 10 }}
        />,
        mainLine,
      ];
    });
  };

  // nivo.rocks default area layer was a bit buggy when changing the x-axis scale,
  // so creating our own custom area layer
  const AreaLayer: CustomLayer = (_props) => {
    // Step-shaped area from the curve down to the bottom of the plot.
    // NOTE(review): cast bypasses d3-shape typings, as with the line generator.
    const areaGenerator = (AreaGenerator as any)()
      .curve(curveStepAfter)
      .x((d: { data: { x: number } }) => _props.xScale(d.data.x))
      .y0(_props.innerHeight)
      .y1((d: { data: { y: number } }) => _props.yScale(d.data.y));

    // Opacity at the top of each gradient. It is 0.4 for a single line and
    // drops by 0.05 per extra line, bottoming out at 0.1 from 7 lines on.
    const topOpacity =
      data.length === 0
        ? 0.4
        : (Math.max(0, 7 - data.length) / 6) * 0.3 + 0.1;

    return _props.series.map((serie, i) => {
      // NOTE(review): SVG ids are document-global, so two graphs on one page
      // would share these gradient ids.
      const gradientId = `gradient-${i}`;
      // A keyed <g> replaces the unkeyed fragment so each list item has a key.
      return (
        <g key={serie.id}>
          <Defs
            defs={[
              {
                id: gradientId,
                type: "linearGradient",
                colors: [
                  {
                    offset: 0,
                    color: serie.color as string,
                    opacity: topOpacity,
                  },
                  { offset: 100, color: serie.color as string, opacity: 0 },
                ],
              },
            ]}
          />
          <path d={areaGenerator(serie.data)} fill={`url(#${gradientId})`} />
        </g>
      );
    });
  };

  return (
    <ResponsiveLine
      pointLabel={function (e) {
        return e.x + ": " + e.y;
      }}
      data={data}
      curve="stepAfter"
      enablePoints={false}
      enablePointLabel={true}
      enableSlices="x"
      margin={{
        top: 20,
        right: 20,
        bottom: 60,
        left: 80,
      }}
      animate={true}
      pointLabelYOffset={-20}
      colors={colors}
      crosshairType="cross"
      xScale={{
        ...(isSurvival
          ? {
              type: "linear",
            }
          : {
              type: "symlog",
            }),
        // For some reason these are not computed correctly by the library, so we force compute it here
        min: xExtent.min,
        max: xExtent.max,
      }}
      xFormat={isSurvival ? " >.0f" : ">$,.2f"}
      yScale={{
        type: "linear",
        stacked: false,
        min: 0,
        max: 1,
      }}
      yFormat=" >.2%"
      axisLeft={{
        tickSize: 5,
        tickPadding: 5,
        tickRotation: 0,
        legend: "Churn Risk",
        legendOffset: -50,
        legendPosition: "middle",
        // Round rather than floor, so float error (0.57 * 100 = 56.999…)
        // does not turn a 57% tick into "56%".
        format: function (value) {
          return Math.round(value * 100) + "%";
        },
      }}
      axisBottom={{
        tickSize: 5,
        tickPadding: 5,
        tickRotation: 0,
        legend: isSurvival ? "Days" : "Lifetime Value ($)",
        legendOffset: 36,
        legendPosition: "middle",
      }}
      layers={[
        "grid",
        AreaLayer,
        "markers",
        "crosshair",
        DashedLine,
        "slices",
        "points",
        "axes",
        "legends",
      ]}
      markers={[
        {
          axis: "y",
          value: 0.5,
          lineStyle: {
            stroke: "red",
          },
          textStyle: {
            fill: "red",
          },
        },
      ]}
      sliceTooltip={({ slice }) => {
        type SlicePoint = (typeof slice.points)[number];

        // Header shows the smallest x in the slice. Points sharing an x share
        // its formatted string, so the first match is enough.
        const xVal = Math.min(
          ...slice.points.map((point) => point.data.x as number)
        );
        const xFormatted = slice.points.find(
          (point) => point.data.x === xVal
        )?.data.xFormatted;

        // One row per series: the last point the slice lists for that series.
        const lastPointIdBySerie = new Map<
          SlicePoint["serieId"],
          SlicePoint["id"]
        >();
        for (const point of slice.points) {
          lastPointIdBySerie.set(point.serieId, point.id);
        }
        const rows = slice.points
          .filter((point) => lastPointIdBySerie.get(point.serieId) === point.id)
          // Smaller pixel y is higher on the chart, so higher churn risk is listed first.
          .sort((point1, point2) => point1.y - point2.y);

        return (
          <div
            style={{
              background: "white",
              padding: "9px 12px",
              border: "1px solid #ccc",
            }}
          >
            <div>
              {isSurvival ? "Day" : "LTV ($)"}: {xFormatted}
            </div>
            {rows.map((point) => (
              <div
                key={point.id}
                style={{
                  color: point.serieColor,
                  padding: "3px 0",
                }}
              >
                <strong>{point.serieId}</strong> {point.data.yFormatted}
              </div>
            ))}
          </div>
        );
      }}
    />
  );
};
