import check from 'check-types';
import { shallow } from 'zustand/shallow';
import { createWithEqualityFn } from 'zustand/traditional';

import {
  COMPARISON_POINT_COLOR_OPTIONS,
  GOLD_KEY,
  INITIAL_VIEW_STATE
} from '@/fine-tune/constants/embeddings.constants';
import {
  Bounds,
  ColorBy,
  ComparedToColor,
  DragMode,
  Point,
  PointRow,
  TooltipPosition,
  ViewState
} from '@/fine-tune/types/embeddings.types';

import { EmbeddingsState, ShapeCoords } from './embeddings.store.types';

/** Tooltip anchor with all four edges unset (`null`). */
const EMPTY_TOOLTIP_POSITION = {
  top: null,
  bottom: null,
  left: null,
  right: null
};

/**
 * Initial values for every non-action field of the embeddings store.
 *
 * This object is spread into the store when it is created and is merged back
 * in by `actions.resetStore`. The nested arrays/objects (e.g. `hiddenTraces`,
 * `shapeIds`, `metaColorMap`, the tooltip anchor) are shared by reference
 * with the state built from them, so they must never be mutated in place.
 *
 * NOTE(review): cast to `any` (see TODO). It presumably should satisfy
 * `Omit<EmbeddingsState, 'actions'>`; confirm against
 * `./embeddings.store.types` before tightening.
 */
const defaultValues = {
  activePoint: null,
  brushRadius: 0.3,
  chartBounds: [
    [0, 0],
    [0, 0]
  ],
  colorBy: GOLD_KEY,
  comparedToColor: COMPARISON_POINT_COLOR_OPTIONS.gray,
  continuousMetaRange: undefined,
  bucketRange: null,
  dragMode: 'pan',
  hiddenTraces: [],
  hideInactivePoints: false,
  hoverCluster: null,
  inactivePoints: null,
  isDefaultView: true,
  isTooltipClicked: false,
  metaColorMap: {},
  radiusMinPixels: 1,
  activePoints: null,
  shapeSelection: null,
  shapeIds: [],
  tabSelection: undefined,
  tooltipPosition: EMPTY_TOOLTIP_POSITION,
  viewState: INITIAL_VIEW_STATE
} as any; // TODO: Fix

/**
 * Zustand store holding the embeddings chart state.
 *
 * State fields start from `defaultValues`. Every mutator is grouped under
 * `actions`. Every mutator uses the updater-function form of `set`, so its
 * result is shallow-merged into the current state (the `replace` flag is
 * never passed). `shallow` is supplied as the store's default equality
 * function for selectors.
 */
export const useEmbeddingsStore = createWithEqualityFn<EmbeddingsState>()(
  (setState) => ({
    ...defaultValues,
    actions: {
      // `actions` is not a key of `defaultValues`, so this merge restores
      // every state field while leaving the actions themselves in place.
      resetStore: () => setState(() => defaultValues),
      resetHiddenTraces: () => setState(() => ({ hiddenTraces: [] })),
      setActivePoint: (point: Point | null) =>
        setState(() => ({ activePoint: point })),
      setBucketRange: (nextRange) =>
        setState(() => ({ bucketRange: nextRange })),
      setBrushRadius: (nextRadius: number) =>
        setState(() => ({ brushRadius: nextRadius })),
      setChartBounds: (nextBounds: Bounds) =>
        setState(() => ({ chartBounds: nextBounds })),
      setColorBy: (colorKey: ColorBy) =>
        setState(() => ({ colorBy: colorKey })),
      setComparedToColor: (colorKey: ComparedToColor) =>
        setState(() => ({ comparedToColor: colorKey })),
      setContinuousMetaRange: (nextRange) =>
        setState(() => ({ continuousMetaRange: nextRange })),
      setDragMode: (nextMode: DragMode) =>
        setState(() => ({ dragMode: nextMode })),
      // With an explicit (non-null, non-undefined) argument, use it as-is;
      // otherwise flip the current value.
      setHideInactivePoints: (explicit?: boolean) =>
        setState(({ hideInactivePoints }) => ({
          hideInactivePoints: check.assigned(explicit)
            ? explicit
            : !hideInactivePoints
        })),
      setInactivePoints: (rows: PointRow[] | null) =>
        setState(() => ({ inactivePoints: rows })),
      setIsDefaultView: (isDefault: boolean) =>
        setState(() => ({ isDefaultView: isDefault })),
      setHiddenTraces: (nextTraces: string[]) =>
        setState(() => ({ hiddenTraces: nextTraces })),
      setHoverCluster: (clusterId: number | null) =>
        setState(() => ({ hoverCluster: clusterId })),
      setViewState: (nextView: ViewState) =>
        setState(() => ({ viewState: nextView })),
      setIsTooltipClicked: (clicked: boolean) =>
        setState(() => ({ isTooltipClicked: clicked })),
      setMetaColorMap: (colorMap) =>
        setState(() => ({ metaColorMap: colorMap })),
      setRadiusMinPixels: (minPixels: number) =>
        setState(() => ({ radiusMinPixels: minPixels })),
      setActivePoints: (rows: PointRow[] | null) =>
        setState(() => ({ activePoints: rows })),
      // `tabSelection` is always overwritten, so omitting `tab` clears it.
      setShapeSelection: (coords: ShapeCoords | null, tab?: string) =>
        setState(() => ({ shapeSelection: coords, tabSelection: tab })),
      setShapeIds: (nextIds: number[]) =>
        setState(() => ({ shapeIds: nextIds })),
      setTooltipPosition: (position: TooltipPosition | null) =>
        setState(() => ({ tooltipPosition: position })),
      // Remove `trace` if it is currently hidden, otherwise append it.
      // Always produces a new array; the previous one is never mutated.
      toggleHiddenTrace: (trace: string) =>
        setState(({ hiddenTraces }) => ({
          hiddenTraces: hiddenTraces.includes(trace)
            ? hiddenTraces.filter((hidden) => hidden !== trace)
            : [...hiddenTraces, trace]
        })),
      // Shallow-merge an arbitrary partial update into the store.
      updateEmbeddingsStore: (patch) => setState(() => ({ ...patch }))
    }
  }),
  shallow
);
