import { Box, Mark, Text, Tooltip } from '@mantine/core';
import { useMouse } from '@mantine/hooks';
import { IconInfoCircle } from '@tabler/icons-react';
import check from 'check-types';
import chroma from 'chroma-js';
import { useState } from 'react';

import { useColors } from '@/core/hooks/use-colors/use-colors';

import HoverTooltip from './hover-tooltip';
import TextSection from '../../atoms/text-section/text-section';

/**
 * One scored span of a model response, as consumed by `TokenLevelScore`.
 *
 * `start`/`end` are passed straight to `response.slice(start, end)`, so `end`
 * is exclusive. Which score drives the highlight colour is decided by
 * `TokenLevelScore`: a numeric `hallucination_severity` first, then a numeric
 * `generated_uncertainty`, otherwise `data_error_potential`.
 */
export type Segment = {
  // Index into the response string where this span begins (inclusive).
  start: number;
  // Index into the response string where this span ends (exclusive).
  end: number;
  // Passed as the second argument to `getSteppedSeverityColor` together with
  // `hallucination_severity`; only used when the severity is numeric.
  hallucination?: number | null;
  // When numeric, takes precedence for both the colour and the displayed value.
  hallucination_severity?: number;
  // Fallback score, coloured on a green -> yellow -> red scale whose mid-point
  // is the component's colouring threshold.
  data_error_potential?: number | null | undefined;
  // When numeric (and no severity is present), coloured using fixed
  // 0.025 / 0.05 cut-offs.
  generated_uncertainty?: number | null | undefined;
  // Candidate tokens with their probabilities; forwarded unchanged as
  // `BuiltSegment.kTokens`.
  top_k_tokens?:
    | {
        token: string;
        prob: number;
      }[]
    | undefined;
  // Forwarded unchanged as `BuiltSegment.segProbability`.
  segment_probability?: number | null | undefined;
};

/** Props for the `TokenLevelScore` component. */
export interface TokenLevelScoreProps {
  // Full text that every segment's `start`/`end` offsets index into.
  response: string;
  // Scored spans of `response`; highlighted in the order given.
  segments: Segment[];
  // Single-line mode: segments stop at the first segment that is exactly "\n"
  // (ignored when that is the very first segment), and the flag is also
  // forwarded to the text container.
  isTruncated: boolean;
  // When false, the plain response is rendered without per-segment highlights.
  showTokens: boolean;
  // Show a HoverTooltip for the segment under the pointer (defaults to false).
  withTooltip?: boolean;
  // Mid-point (yellow) of the data_error_potential colour scale; used in
  // preference to `threshold` whenever it is truthy.
  coloringThreshold?: number;
  // Trailing note rendered after the highlighted text, behind an info icon.
  // Only shown in the highlighted (showTokens) view.
  extraText?: string;
  // Tooltip label for the info icon that precedes `extraText`.
  extraTextTooltip?: string;
  // Fallback mid-point for the data_error_potential scale when
  // `coloringThreshold` is falsy.
  threshold?: number;
}

/**
 * Display-ready form of a `Segment`, built inside `TokenLevelScore` and passed
 * to `HoverTooltip` (as `seg`) while the segment is hovered.
 */
export type BuiltSegment = {
  // The slice of the response covered by this segment.
  textToHighlight: string;
  // Background colour used for the highlight mark.
  color: string;
  // Score attributed to the segment (severity, uncertainty or
  // data_error_potential, in that priority). A value of exactly -1 makes
  // TokenLevelScore render the text without a highlight.
  value: number | null | undefined;
  // Copy of the segment's `top_k_tokens`.
  // NOTE(review): typed as any[] although the source elements are
  // { token, prob } objects; consider narrowing once HoverTooltip usage is confirmed.
  kTokens?: any[];
  // Copy of the segment's `segment_probability`.
  segProbability: number | null | undefined;
  // True whenever the source segment had a numeric `generated_uncertainty`,
  // even if a hallucination severity determined the colour.
  isUncertainty?: boolean;
};

/**
 * Colour for a data-error-potential score.
 *
 * Scores are interpolated on a green (#90e790) -> yellow (#FDEEAF) -> red
 * (#FB7474) scale whose stops sit at 0, `thresh` and 1 respectively.
 *
 * @param dep - the score; `null`/`undefined` yields `'white'`.
 * @param thresh - score that maps exactly to the yellow mid-point.
 * @returns a hex colour string, or `'white'` when there is no score.
 */
const getDepScaledColor = (dep: number | null, thresh: number) => {
  if (dep != null) {
    const greenToRed = chroma
      .scale(['#90e790', '#FDEEAF', '#FB7474'])
      .domain([0, thresh, 1]);
    return greenToRed(dep).hex();
  }

  // No score: fall back to a neutral background.
  return 'white';
};

/**
 * Colour for a generated-uncertainty value, using fixed cut-offs:
 * below 0.025 is green, below 0.05 is yellow, anything else (including NaN)
 * is red.
 *
 * @param value - the uncertainty score.
 * @returns a hex colour string.
 */
const getUncertaintyColor = (value: number) => {
  if (value < 0.025) return '#90e790';
  if (value < 0.05) return '#FDEEAF';
  return '#FB7474';
};

/**
 * Renders `response` with each scored segment highlighted by colour.
 *
 * For every segment the colour and displayed value come from, in priority
 * order: a numeric `hallucination_severity`, a numeric
 * `generated_uncertainty`, otherwise `data_error_potential` (scaled around
 * `coloringThreshold`, falling back to `threshold`). Segments whose value is
 * exactly -1 are rendered as plain, un-highlighted text.
 *
 * When `showTokens` is false, or there is nothing to highlight, the response
 * is rendered through `TextSection` instead. When `isTruncated` is set, the
 * highlighted segments stop at the first segment that is exactly "\n"
 * (a line break in the very first segment is ignored).
 */
const TokenLevelScore = ({
  response,
  segments,
  isTruncated,
  showTokens,
  withTooltip = false,
  coloringThreshold,
  extraText,
  extraTextTooltip,
  threshold
}: TokenLevelScoreProps) => {
  // Hooks
  const { getSteppedSeverityColor } = useColors();
  const { ref, x, y } = useMouse();

  const [hoveredSegment, setHoveredSegment] = useState<
    BuiltSegment | null | undefined
  >(null);

  // Utils
  // The dedicated colouring threshold wins over the generic one when truthy.
  const depThreshold = coloringThreshold || (threshold as number);

  // Turns a raw segment into its display-ready, coloured form.
  const buildSegment = ({
    start,
    end,
    hallucination,
    hallucination_severity: severity,
    generated_uncertainty: uncertainty,
    data_error_potential: dep,
    top_k_tokens: kTokens,
    segment_probability: segProbability
  }: Segment): BuiltSegment => {
    const hasSeverity = check.number(severity);
    // Reported regardless of which score ends up driving the colour.
    const isUncertainty = check.number(uncertainty);

    let color: string;
    let value: BuiltSegment['value'];
    if (hasSeverity) {
      color = getSteppedSeverityColor(
        severity as number,
        hallucination as number
      );
      value = severity;
    } else if (isUncertainty) {
      color = getUncertaintyColor(uncertainty as number);
      value = uncertainty;
    } else {
      color = getDepScaledColor(dep as number, depThreshold);
      value = dep;
    }

    return {
      textToHighlight: response?.slice(start, end),
      color,
      value,
      kTokens,
      segProbability,
      isUncertainty
    };
  };

  const builtSegments = (segments ?? []).map(buildSegment);

  // Single-line mode: keep only what precedes the first standalone line break.
  const lineBreakIndex = isTruncated
    ? builtSegments.findIndex(
        ({ textToHighlight }) => textToHighlight === '\n'
      )
    : -1;
  const visibleSegments =
    lineBreakIndex > 0 ? builtSegments.slice(0, lineBreakIndex) : builtSegments;

  const showHighlights = showTokens && visibleSegments.length > 0;
  if (!showHighlights) {
    return (
      <Box
        maw='100%'
        ref={ref}
        style={{
          overflowWrap: 'anywhere'
        }}
      >
        <TextSection isTruncated={isTruncated} value={response} />
      </Box>
    );
  }

  // Shift the pointer coordinates by the element's on-screen position.
  // NOTE(review): assumes useMouse reports element-relative coordinates.
  const rect = ref.current?.getBoundingClientRect();
  const tooltipX = (rect?.x || 0) + x;
  const tooltipY = (rect?.y || 0) + y;

  // Hover tracking is a no-op unless tooltips are enabled.
  const updateHovered = (seg: BuiltSegment | null) => {
    if (withTooltip) setHoveredSegment(seg);
  };

  const renderSegment = (seg: BuiltSegment, i: number) =>
    seg.value === -1 ? (
      <Text component='span' key={`space-${i}`} mb='xs'>
        {seg.textToHighlight}
      </Text>
    ) : (
      <Mark
        bg={seg.color}
        key={`${seg.color}${seg.textToHighlight}${i}`}
        mb='sm'
        onMouseEnter={() => updateHovered(seg)}
        onMouseLeave={() => updateHovered(null)}
      >
        {seg.textToHighlight}
      </Mark>
    );

  return (
    <>
      {hoveredSegment && (
        <HoverTooltip seg={hoveredSegment} x={tooltipX} y={tooltipY} />
      )}
      <Text
        data-testid='token-level-score'
        ref={ref}
        size='sm'
        ta='left'
        truncate={isTruncated || undefined}
      >
        {visibleSegments.map(renderSegment)}
        {extraText && (
          <Text c='#9B98AE' component='span' fw={400} ml={2}>
            <Tooltip multiline withArrow label={extraTextTooltip} w={300}>
              <IconInfoCircle size={14} style={{ marginBottom: -2 }} />
            </Tooltip>
            {extraText}
          </Text>
        )}
      </Text>
    </>
  );
};

export default TokenLevelScore;
