import { useQuery } from '@tanstack/react-query';

import api from '@/core/api';
import { usePathParameters } from '@/core/hooks/use-path-parameters/use-path-parameters';
import {
  DRIFT_SCORE_KEY,
  NON_META_COLORING_OPTIONS
} from '@/fine-tune/constants/embeddings.constants';
import { parseMetaColumnName } from '@/fine-tune/data-parsers/parse-meta-column-name/parse-meta-column-name';
import { useFilterParams } from '@/fine-tune/hooks/use-filters-params/use-filter-params';
import { useEmbeddingsStore } from '@/fine-tune/stores/embeddings-store/embeddings.store';
import {
  useComputedParameters,
  useParametersStore,
  useParametersStoreActions
} from '@/fine-tune/stores/parameters-store';
import useStore from '@/fine-tune/stores/store';
import { MetaColumn } from '@/fine-tune/types/response.types';

import { useRunGlobalMetrics } from '../use-run-global-metrics/use-run-global-metrics';

/**
 * Query-key prefix for the meta-columns query issued by `useMetaColumns`.
 * The full key is `[META_COLUMNS, { projectId, runId, split, filter_params, task, inferenceName }]`.
 */
export const META_COLUMNS = 'META_COLUMNS';

/**
 * Meta column names that are moved to the front for S2S runs when no
 * feature importances are available.
 */
const S2S_PRIORITY_META_COLUMNS = ['chat_id', 'turn_id', 'bleu', 'rouge'];

/**
 * Selects up to five user-logged meta columns and merges them into the
 * current dataframe columns.
 *
 * Columns whose name contains `galileo` are skipped, as are entries without a
 * string `name`. Candidate ordering:
 * - if `featureImportances` has at least one key: descending importance
 *   (columns missing from the map count as 0);
 * - otherwise, for S2S runs, names in `S2S_PRIORITY_META_COLUMNS` come first;
 * - ties keep the order in which `metaCols` lists them (the sort is stable).
 *
 * @param metaCols - meta columns to choose from (defaults to `[]`)
 * @param dataframeColumns - columns already selected
 * @param featureImportances - optional column name -> importance map
 * @param isS2S - enables the S2S priority ordering
 * @returns `null` when every selected meta column is already in
 *   `dataframeColumns`; otherwise the de-duplicated concatenation of
 *   `dataframeColumns` followed by the selected meta columns
 */
export const getFirstFiveUserLoggedMetaCols = (
  metaCols: MetaColumn[] = [],
  dataframeColumns: string[],
  featureImportances?: Record<string, number>,
  isS2S?: boolean
) => {
  const sortByImportance = (a: string, b: string) => {
    const aImportance = featureImportances?.[a] || 0;
    const bImportance = featureImportances?.[b] || 0;

    return bImportance - aImportance;
  };

  // A sort comparator must be consistent: compare(a, b) and compare(b, a)
  // need opposite signs (or both 0). Subtracting per-name ranks guarantees
  // that; the previous one-argument comparator returned -1 for both orders
  // of two priority columns, leaving the result implementation-defined.
  const s2sRank = (name: string) =>
    isS2S && S2S_PRIORITY_META_COLUMNS.includes(name) ? 0 : 1;
  const defaultSort = (a: string, b: string) => s2sRank(a) - s2sRank(b);

  const useFeatureImportances =
    !!featureImportances && Object.keys(featureImportances).length > 0;

  const firstFive = metaCols
    // Map to names first so null entries / missing names are dropped instead
    // of crashing the destructure or leaking `undefined` into the result.
    .map((col) => col?.name)
    .filter(
      (name): name is string =>
        typeof name === 'string' && !name.includes('galileo')
    )
    .sort(useFeatureImportances ? sortByImportance : defaultSort)
    .slice(0, 5);

  const allColumnsPresent = firstFive.every((col) =>
    dataframeColumns.includes(col)
  );

  if (allColumnsPresent) {
    return null;
  }

  return [...new Set([...dataframeColumns, ...firstFive])];
};

/**
 * Fetches the meta columns for the current project/run/split and derives
 * helpers for filtering and "color by" options.
 *
 * The query is enabled once project, run and split are known. It also waits
 * for `inferenceName` on inference runs, for `task` when `isMltc` is set, and
 * for global metrics to finish loading when `isSD` is set.
 *
 * Side effect: while the store flag `shouldAddColumns` is set, each fetch
 * calls `getFirstFiveUserLoggedMetaCols`. When that produces a new column
 * list, it is written to `dataframeColumns` and the flag is cleared.
 *
 * @returns the `useQuery` result spread together with derived names,
 *   lookup helpers, color-by flags and option lists
 */
export const useMetaColumns = () => {
  // Path parameters
  const { runId, projectId } = usePathParameters();

  // Store parameters
  const split = useParametersStore((state) => state.split);
  const inferenceName = useParametersStore((state) => state.inferenceName);
  const dataframeColumns = useParametersStore(
    (state) => state.dataframeColumns
  );
  const task = useParametersStore((state) => state.task);
  const shouldAddColumns = useStore((state) => state.shouldAddColumns);
  const setShouldAddColumns = useStore(
    (state) => state.actions.setShouldAddColumns
  );

  const { featureImportances, isLoading } = useRunGlobalMetrics();

  const filter_params = useFilterParams();

  const { isInference, isMltc, isSD, isS2S } = useComputedParameters();

  const { setParameters } = useParametersStoreActions();

  const colorBy = useEmbeddingsStore((s) => s.colorBy);
  const path =
    '/projects/{project_id}/runs/{run_id}/split/{split}/meta/columns';

  let enabled = Boolean(projectId && runId && split);

  if (isInference) {
    enabled = Boolean(enabled && inferenceName);
  }

  if (isMltc) {
    enabled = Boolean(enabled && task);
  }

  if (isSD) {
    // Wait for global metrics so featureImportances is available when the
    // initial columns are chosen.
    enabled = Boolean(enabled && !isLoading);
  }

  const fetchMetaColumns = async () => {
    const res = await api.POST(path, {
      body: { filter_params, task },
      params: {
        path: {
          // Non-null assertions are safe only because `enabled` requires both.
          project_id: projectId!,
          run_id: runId!,
          split
        },
        query: {
          inference_name: inferenceName
        }
      }
    });

    // NOTE(review): only `res.data` is returned. If the client reports
    // failures through a separate field (e.g. `res.error`), a failed request
    // resolves with `undefined` instead of putting the query in an error
    // state — confirm this is intended.
    return res.data;
  };

  const result = useQuery(
    [
      META_COLUMNS,
      { projectId, runId, split, filter_params, task, inferenceName }
    ],
    () =>
      // Only run this function on data change => https://github.com/TanStack/query/issues/936
      fetchMetaColumns().then((res) => {
        // We only want to add columns on initial run load
        if (!shouldAddColumns) {
          return res;
        }

        const baseAndMetaColumns = getFirstFiveUserLoggedMetaCols(
          res?.meta as MetaColumn[],
          dataframeColumns || [],
          featureImportances as Record<string, number>,
          isS2S
        );

        if (baseAndMetaColumns) {
          setParameters({
            dataframeColumns: baseAndMetaColumns
          });
          setShouldAddColumns(false);
        }

        return res;
      }),
    {
      enabled
    }
  );

  const meta = result?.data?.meta;

  // Lookup shared by the per-column helpers below.
  const findColumn = (columnName: string) =>
    meta?.find((col) => col.name === columnName);

  const metaColumnNames = meta?.map((column) => column.name) || [];

  const getColumnValues = (columnName: string) => {
    const uniqueValues = findColumn(columnName)?.unique_values;

    if (!uniqueValues) {
      return [];
    }

    // Copy before sorting: Array.prototype.sort works in place, and
    // `uniqueValues` belongs to the react-query cache shared by every
    // consumer of this query.
    return [...uniqueValues].sort().map((val) => `${val}`);
  };

  const isContinuousMeta = (columnName: string) =>
    columnName === 'confidence' || findColumn(columnName)?.is_continuous;

  const isContinuousAndNotCategorical = (columnName: string) => {
    const column = findColumn(columnName);
    return column?.is_continuous && !column?.is_categorical;
  };

  const isColorByMetaData = !NON_META_COLORING_OPTIONS.includes(colorBy);

  const continuousFilters =
    meta
      ?.filter((col) => col.is_continuous)
      ?.map((col) => ({
        label: parseMetaColumnName(col.name),
        value: col.name
      })) || [];

  // metaColumnNames is always an array, so no fallback is needed here.
  const colorByMetaOptions = metaColumnNames.map((name) => ({
    label: parseMetaColumnName(name),
    value: name
  }));

  const isColorByContinuousMeta =
    (isColorByMetaData && isContinuousMeta(colorBy)) ||
    colorBy === 'perplexity';
  const isColorByCategoricalMeta =
    isColorByMetaData &&
    !isContinuousMeta(colorBy) &&
    ![DRIFT_SCORE_KEY, 'perplexity'].includes(colorBy);

  return {
    ...result,
    metaColumnNames,
    getColumnValues,
    isContinuousMeta,
    isContinuousAndNotCategorical,
    isColorByMetaData,
    isColorByContinuousMeta,
    isColorByCategoricalMeta,
    colorByMetaOptions,
    continuousFilters
  };
};
