import { Flex, Text } from '@mantine/core';
import sortBy from 'lodash/sortBy';

import { TruncatedGroup } from '@/core/components/atoms/truncated-group/truncated-group';
import {
  TASK_TYPE,
  TaskType
} from '@/core/constants/tasks-and-frameworks.constants';
import { useFeatureFlagsStore } from '@/core/stores/feature-flags-store/feature-flags.store';
import LabelBadgeWithPopover from '@/fine-tune/components/label-badge/label-badge-with-popover/label-badge-with-popover';
import ModalCarouselHeader from '@/fine-tune/components/modal-carousel-header/modal-carousel-header';
import {
  ACCURACY_COLUMN,
  BASE_GOLD_COLUMN,
  BASE_PRED_COLUMN,
  BASE_TEXT_COLUMN,
  CONFIDENCE_COLUMN,
  DEP_COLUMN,
  ERROR_TYPES_COLUMN,
  GENERATED_OUTPUT_COLUMN,
  GT_OUTPUT_COLUMN,
  IMAGE_COLUMN,
  INPUT_COLUMN,
  MAX_DEP_COLUMN,
  MAX_DEP_LABEL_COLUMN,
  NER_GOLD_COLUMN,
  NER_PRED_COLUMN,
  NOISE_TYPE_COLUMN
} from '@/fine-tune/constants/base-columns.constants';
import {
  DEP_ACCESSOR,
  GOLD_ACCESSOR,
  ID_ACCESSOR,
  PRED_ACCESSOR,
  SAMPLE_ID_ACCESSOR,
  TOTAL_ERRORS_ACCESSOR
} from '@/fine-tune/constants/dataframe-accessors.constants';
import { useMetaColumns } from '@/fine-tune/hooks/query-hooks/use-meta-columns/use-meta-columns';
import { useComputerVisionStore } from '@/fine-tune/stores/computer-vision-store/computer-vision.store';
import { useLLMStore } from '@/fine-tune/stores/llm-store/llm.store';
import {
  useComputedParameters,
  useParametersStore
} from '@/fine-tune/stores/parameters-store';
import { Splits } from '@/fine-tune/stores/parameters-store/parameters.store.types';
import useStore from '@/fine-tune/stores/store';
import { InsightsRow } from '@/fine-tune/types/query.types';
import { S2S_EXTRA_COLUMNS } from '@/fine-tune/types/s2s.types';
import { getDriftScoreColor } from '@/fine-tune/utils/get-drift-score-color/get-drift-score-color';

import {
  getParsedMetaColumns,
  mapMltcLabels
} from './use-dataframe-columns.support';
import { useInsightsRows } from '../query-hooks/use-insights-rows/use-insights-rows';
import { useThresholds } from '../query-hooks/use-thresholds/use-thresholds';
import {
  formatBleu,
  formatRouge,
  getUncertaintyColor
} from '../use-colors/use-colors';
import { Modals, useModals } from '../../../core/hooks/use-modals/use-modals';

// Column constants

// DEP / ground-truth / prediction trio shared by several layouts below.
const DEP_GOLD_PRED_COLUMNS = [DEP_COLUMN, BASE_GOLD_COLUMN, BASE_PRED_COLUMN];

// Fallback layout used by the hook when no other task flag overrides it:
// text column, the shared trio, then confidence.
const TC_COLUMNS = [
  BASE_TEXT_COLUMN,
  ...DEP_GOLD_PRED_COLUMNS,
  CONFIDENCE_COLUMN
];

// Same as the fallback layout but with the image column in place of text.
const IC_COLUMNS = [IMAGE_COLUMN, ...DEP_GOLD_PRED_COLUMNS, CONFIDENCE_COLUMN];

// Layout selected by `isSS`: the shared trio followed by accuracy.
const SS_COLUMNS = [...DEP_GOLD_PRED_COLUMNS, ACCURACY_COLUMN];

// Layout selected by `isSD`: the shared trio, each pinned as sticky and
// always visible regardless of the user's column selection.
const SD_COLUMNS = DEP_GOLD_PRED_COLUMNS.map((column) => ({
  ...column,
  isSticky: true,
  isAlwaysVisible: true
}));

/** Inputs for `getS2SColumns`. */
interface S2SColumnsParams {
  /** Current `showHighDEPWords` toggle, forwarded to `GT_OUTPUT_COLUMN`. */
  showHighDEPWords: boolean;
  /** Current `showPriorityHeatmap` toggle, forwarded to `GENERATED_OUTPUT_COLUMN`. */
  showPriorityHeatmap: boolean;
  /** Setters invoked by the callbacks handed to the output columns to flip the toggles. */
  actions: {
    setShowHighDEPWords: (value: boolean) => void;
    setShowPriorityHeatmap: (value: boolean) => void;
  };
  /** Threshold forwarded to `GT_OUTPUT_COLUMN`. */
  thresh: number;
}

// Sortable BLEU score column; cell formatting is delegated to `formatBleu`.
const BLEU_COLUMN = {
  label: 'BLEU',
  accessor: 'bleu',
  width: 80,
  isSortable: true,
  cell: ({ value }: { value: number }) => formatBleu(value, true)
};

// Sortable uncertainty column, coloured by `getUncertaintyColor` and shown
// with two decimals (missing values render as an empty cell).
const UNCERTAINTY_COLUMN = {
  label: 'Uncertainty',
  accessor: 'generated_uncertainty',
  width: 115,
  isSortable: true,
  cell: ({ value }: { value: number }) => (
    <Text c={getUncertaintyColor(value)} fw={700} size='sm'>
      {value?.toFixed(2)}
    </Text>
  )
};

/**
 * Builds a sortable numeric column whose cell shows the value rounded to two
 * decimals with trailing zeros dropped (e.g. 1.50 -> 1.5).
 *
 * Missing values (`null` / `undefined`) render as an empty cell — previously
 * `parseFloat(undefined)` produced a literal "NaN" in the table.
 *
 * @param label - Column header text.
 * @param accessor - Row field the column reads.
 */
const defaultNumericColumn = (label: string, accessor: string) => ({
  label,
  accessor,
  width: 115,
  isSortable: true,
  cell: ({ value }: { value?: number | null }) => (
    <Text fw={600} size='sm'>
      {value == null ? null : parseFloat(value.toFixed(2))}
    </Text>
  )
});

// Sortable ROUGE score column; cell formatting is delegated to `formatRouge`.
const ROUGE_COLUMN = {
  label: 'ROUGE',
  accessor: 'rouge',
  width: 90,
  isSortable: true,
  cell: ({ value }: { value: number }) => formatRouge(value, true)
};

/**
 * Builds the S2S column list. The toggle callbacks capture the current toggle
 * values, so each call produces callbacks that flip the state as of that call.
 */
const getS2SColumns = ({
  showHighDEPWords,
  showPriorityHeatmap,
  actions,
  thresh
}: S2SColumnsParams) => [
  INPUT_COLUMN,
  GT_OUTPUT_COLUMN(showHighDEPWords, thresh, () =>
    actions.setShowHighDEPWords(!showHighDEPWords)
  ),
  GENERATED_OUTPUT_COLUMN(showPriorityHeatmap, () =>
    actions.setShowPriorityHeatmap(!showPriorityHeatmap)
  ),
  DEP_COLUMN,
  UNCERTAINTY_COLUMN,
  defaultNumericColumn('Perplexity', 'perplexity'),
  defaultNumericColumn('Input Cutoff', 'input_cutoff'),
  defaultNumericColumn('Target Cutoff', 'target_cutoff'),
  BLEU_COLUMN,
  ROUGE_COLUMN
];

/**
 * Builds the column set for the MLTC layout.
 *
 * @param showAllLabels - When true, the ground-truth / prediction cells also
 *   render the badges produced by `mapMltcLabels` for `mltc_golds` /
 *   `mltc_preds`, and those columns use a fixed width of 250; otherwise the
 *   column width is 200 and the label container width is `'auto'`.
 */
const getMltcColumns = (showAllLabels: boolean) => {
  const labelsWidth = showAllLabels ? 250 : 'auto';
  const columnWidth = showAllLabels ? 250 : 200;

  // The ground-truth and prediction cells differ only in which row fields
  // they read, so both are produced by this factory.
  const renderLabelsCell =
    (valueKey: 'gold' | 'pred', labelsKey: 'mltc_golds' | 'mltc_preds') =>
    ({ data, isTruncated }: { data: InsightsRow; isTruncated?: boolean }) => {
      const badges = (showAllLabels && mapMltcLabels(data, labelsKey)) || [];

      return (
        <Flex className='cursor-pointer' gap='sm' w={labelsWidth} wrap='wrap'>
          <LabelBadgeWithPopover
            isFilled
            sample={data}
            value={data?.[valueKey]}
          />
          <TruncatedGroup
            elements={badges}
            // The table may omit the flag; coerce so TruncatedGroup gets a boolean.
            isTruncated={Boolean(isTruncated)}
            labels={data?.[labelsKey]}
          />
        </Flex>
      );
    };

  return [
    BASE_TEXT_COLUMN,
    DEP_COLUMN,
    MAX_DEP_LABEL_COLUMN,
    MAX_DEP_COLUMN,
    {
      label: 'Ground Truth',
      accessor: GOLD_ACCESSOR,
      isFilterable: true,
      width: columnWidth,
      cell: renderLabelsCell('gold', 'mltc_golds')
    },
    {
      label: 'Prediction',
      accessor: PRED_ACCESSOR,
      isFilterable: true,
      width: columnWidth,
      cell: renderLabelsCell('pred', 'mltc_preds')
    },
    CONFIDENCE_COLUMN
  ];
};
// Sortable "Total Errors" column used by the NER layout.
const TOTAL_ERRORS_COLUMN = {
  label: 'Total Errors',
  accessor: TOTAL_ERRORS_ACCESSOR,
  isSortable: true
};

// Layout selected by `isNer`.
const NER_COLUMNS = [
  NER_GOLD_COLUMN,
  NER_PRED_COLUMN,
  DEP_COLUMN,
  TOTAL_ERRORS_COLUMN,
  ERROR_TYPES_COLUMN
];

// Inference layouts omit the ground-truth and DEP columns.
const INFERENCE_COLUMNS = [BASE_TEXT_COLUMN, BASE_PRED_COLUMN, CONFIDENCE_COLUMN];

const INFERENCE_NER_COLUMNS = [NER_PRED_COLUMN, CONFIDENCE_COLUMN];

/**
 * useDataframeColumns
 *
 * Builds the column configuration for the dataframe table from the computed
 * task/split flags, the co-occurrence feature flag, drift state and the
 * run's meta columns.
 *
 * @returns
 *  - `stickyColumns` / `nonStickyColumns`: visible columns (listed in
 *    `dataframeColumns` or flagged `isAlwaysVisible`), split by `isSticky`.
 *  - `allColumns`: task columns plus meta columns, sorted by `isSticky`,
 *    `order`, then position in the stored `columnsOrder`.
 *  - `listStateColumns`: id column (+ text/spans for NER) and all columns in
 *    `{ label, accessor, checked, renamedValue }` form.
 *  - `metaColumnAccessors`, `categoricalColumns`, `isContinuous`,
 *    `isCategorical`.
 *  - `onRowClick`: opens the S2S row-detail modal for S2S, else `undefined`.
 */
export const useDataframeColumns = () => {
  // Computed Params
  const {
    isMltc,
    isNer,
    isInference,
    isInferenceNer,
    isSD,
    isSS,
    isIc,
    isS2S
  } = useComputedParameters();

  // Store
  const allMltcLabelsVisible = useStore((s) => s.allMltcLabelsVisible);
  const columnsOrder = useStore((s) => s.columnsOrder);

  const thresholds = useThresholds();
  const { hasS2SExtraColumns } = useInsightsRows();

  // Param Store
  const dataframeColumns = useParametersStore((s) => s.dataframeColumns);
  const isDrifted = useParametersStore((s) => s.isDrifted);

  // Computer Vision Store
  const setModalIndex = useComputerVisionStore((s) => s.actions.setModalIndex);

  const { openModal } = useModals();

  // LLM Store — one selector per field. An object-literal selector returns a
  // new object on every call; assuming the store compares selector results
  // by reference (zustand-style — confirm), that re-rendered this hook on
  // every LLM store update.
  const showHighDEPWords = useLLMStore((s) => s.showHighDEPWords);
  const showPriorityHeatmap = useLLMStore((s) => s.showPriorityHeatmap);
  const actions = useLLMStore((s) => s.actions);

  // Feature Flags
  const { getCoOccurrenceFlag } = useFeatureFlagsStore((s) => s.computed);
  const enableCoOccurrence = getCoOccurrenceFlag();

  // Hooks
  const { data: metaColsData } = useMetaColumns();

  const meta = metaColsData ? getParsedMetaColumns(metaColsData) : [];

  const handleS2SRowClick = (rowIndex: number) => {
    setModalIndex(rowIndex);
    openModal(
      Modals.S2S_ROW_DETAIL,
      {},
      () => {},
      <ModalCarouselHeader width='65vw' withTitle={false} />
    );
  };

  // Each later branch overrides the earlier selection. Note that the SD, SS
  // and S2S branches run after the drift-score append and replace it.
  // TODO: Update typing
  let renderColumns: any[] = TC_COLUMNS;

  if (isNer) {
    renderColumns = NER_COLUMNS;
  }

  if (isIc) {
    renderColumns = IC_COLUMNS;
  }

  if (enableCoOccurrence && isMltc) {
    renderColumns = getMltcColumns(allMltcLabelsVisible);
  }

  if (isInference) {
    renderColumns = isInferenceNer ? INFERENCE_NER_COLUMNS : INFERENCE_COLUMNS;
  }

  if (isDrifted) {
    renderColumns = [
      ...renderColumns,
      {
        // OOC is the same as drift score, just called a different name in
        // test / validation.
        label: isInference ? 'Drift Score' : 'OOC Score',
        accessor: 'drift_score',
        isSortable: true,
        cell: ({ value }: { value: number }) => (
          <Text c={getDriftScoreColor(value)}>{value?.toFixed(3)}</Text>
        ),
        order: 1
      }
    ];
  }

  if (isSD) {
    renderColumns = [...SD_COLUMNS, NOISE_TYPE_COLUMN];
  }

  if (isSS) {
    renderColumns = SS_COLUMNS;
  }

  if (isS2S) {
    renderColumns = getS2SColumns({
      showHighDEPWords,
      showPriorityHeatmap,
      actions,
      thresh: thresholds?.data?.easy_samples_threshold || 0
    });
  }

  // For S2S, drop meta columns whose accessor is already rendered above.
  const extraMetaColumns = isS2S
    ? meta.filter(
        (metaColumn) =>
          !renderColumns.some((col) => col.accessor === metaColumn.accessor)
      )
    : meta;

  // Sorted by `isSticky`, then `order` (if present), then saved column order.
  const allColumns = sortBy(
    [...renderColumns, ...extraMetaColumns],
    ['isSticky', 'order', (o) => columnsOrder.indexOf(o.accessor)]
  );

  const visibleColumns = allColumns.filter(
    ({ accessor, isAlwaysVisible }) =>
      dataframeColumns?.includes(accessor) || isAlwaysVisible
  );

  // Id column changes for NER - sample_id is id of given span
  const idColumn = isNer
    ? { label: 'Sample Id', accessor: SAMPLE_ID_ACCESSOR }
    : { label: 'Id', accessor: ID_ACCESSOR };

  // For NER, text and spans are not rendered in the table, so they are
  // offered as export options directly after the id column.
  const nerExportColumns = isNer
    ? [
        { label: 'Text', accessor: 'text' },
        { label: 'Spans', accessor: 'spans' }
      ]
    : [];

  const listState = [idColumn, ...nerExportColumns, ...allColumns];

  const isContinuous = (columnName: string) =>
    ['confidence', 'dep'].includes(columnName) ||
    allColumns?.find((col) => col.accessor === columnName)?.isSortable;

  const isCategorical = (columnName: string) =>
    allColumns?.find((col) => col.accessor === columnName)?.isFilterable;

  return {
    stickyColumns: visibleColumns.filter(({ isSticky }) => isSticky),
    nonStickyColumns: visibleColumns.filter(({ isSticky }) => !isSticky),
    allColumns,
    listStateColumns: listState
      .filter(({ accessor }) => {
        // NER exposes gold/pred/DEP through span columns instead.
        if (isNer) {
          return ![GOLD_ACCESSOR, PRED_ACCESSOR, DEP_ACCESSOR].includes(
            accessor
          );
        }

        // S2S extra columns are only listed when the rows actually have them.
        if (isS2S && S2S_EXTRA_COLUMNS.includes(accessor)) {
          return hasS2SExtraColumns;
        }
        return true;
      })
      .map(({ label, accessor }) => ({
        label,
        accessor,
        checked: true,
        renamedValue: ''
      })),
    metaColumnAccessors: meta.map(
      ({ accessor }: { accessor: string }) => accessor
    ),
    categoricalColumns: allColumns.filter(({ isFilterable }) => isFilterable),
    isContinuous,
    isCategorical,
    onRowClick: isS2S ? handleS2SRowClick : undefined
  };
};

/**
 * Returns the accessors of the columns shown by default.
 *
 * S2S, NER and SD have fixed defaults regardless of split. For every other
 * task type the inference split wins, then the MLTC / SS specific sets, and
 * finally a shared fallback set.
 *
 * @param taskType - Task type of the current run.
 * @param split - Optional split; `'inference'` selects the inference defaults.
 * @returns A freshly allocated array of column accessors.
 */
export const getDefaultColumns = (taskType: TaskType, split?: Splits) => {
  // Split-independent defaults.
  switch (taskType) {
    case TASK_TYPE.S2S:
      return ['input', 'target', 'data_error_potential', 'perplexity'];
    case TASK_TYPE.NER:
      return [
        'confidence',
        'data_error_potential',
        'gold',
        'pred',
        'total_errors',
        'error_types'
      ];
    case TASK_TYPE.SD:
      return ['data_error_potential', 'gold', 'pred', 'noise_type'];
    default:
      break;
  }

  const isInferenceSplit = split === 'inference';
  if (isInferenceSplit) {
    return ['pred', 'confidence'];
  }

  switch (taskType) {
    case TASK_TYPE.MLTC:
      return [
        'confidence',
        'data_error_potential',
        'gold',
        'max_dep',
        'max_dep_label',
        'pred',
        'text'
      ];
    case TASK_TYPE.SS:
      return [];
    default:
      return ['confidence', 'data_error_potential', 'gold', 'pred', 'text'];
  }
};
