import { MediaSplitName } from '@/constants/stats_card';
import { useGetDefaultEvaluationSetsCountQuery } from '@/serverStore/evaluationSets';
import { useGetModelPerformanceSummaryQuery } from '@/serverStore/modelAnalysis';
import {
  LabelType,
  ModelStatus,
  RegisteredModel,
  SplitConfusionMatrices,
} from '@clef/shared/types';
import { useMemo } from 'react';

/**
 * Renders a metric as a translated percentage label.
 *
 * @param metrics - The metric value, or `null` when none is available. It is scaled
 *   by 100 before display, so it is presumably a fraction in [0, 1] — confirm with
 *   the summary API.
 * @returns The translated placeholder `'--'` for `null`. Otherwise the value is
 *   scaled by 100, rounded to the nearest integer and formatted through
 *   `'{{performance}}%'`.
 */
const getPerformanceText = (metrics: number | null) => {
  if (metrics !== null) {
    // `?? 0` keeps a stray runtime `undefined` rendering as "0%" rather than "NaN%".
    const percentage = Math.round((metrics ?? 0) * 100);
    return t('{{performance}}%', { performance: percentage });
  }
  return t('--');
};

/**
 * Flattens one split's confusion-matrix buckets into a list of display entries.
 *
 * Each entry has the following fields:
 * - `count`: the bucket's count, or 0 when unavailable.
 * - `name`: the translated bucket name.
 * - `isCorrect`: whether the bucket represents correct predictions.
 * - `data`: the bucket's raw `data`.
 *
 * @param confusionMatrices - Buckets for a single split, or `undefined` when absent.
 * @param labelType - When equal to `LabelType.Classification`, the False Positive and
 *   False Negative entries are omitted.
 * @returns Entries in this order: False Positive and False Negative (only for
 *   non-classification label types), then Misclassified, then Correct.
 */
const getFormattedConfusionMatricesArray = (
  confusionMatrices: SplitConfusionMatrices | undefined,
  labelType?: LabelType | null,
) => {
  const showFalsePosNeg = labelType !== LabelType.Classification;

  const falsePosNegEntries = showFalsePosNeg
    ? [
        {
          count: confusionMatrices?.falsePositive.count ?? 0,
          name: t('False Positive'),
          isCorrect: false,
          data: confusionMatrices?.falsePositive.data,
        },
        {
          count: confusionMatrices?.falseNegative.count ?? 0,
          name: t('False Negative'),
          isCorrect: false,
          data: confusionMatrices?.falseNegative.data,
        },
      ]
    : [];

  const misclassifiedEntry = {
    count: confusionMatrices?.misClassification.count ?? 0,
    name: t('Misclassified'),
    isCorrect: false,
    data: confusionMatrices?.misClassification.data,
  };

  const correctEntry = {
    count: confusionMatrices?.correct.count ?? 0,
    name: t('Correct'),
    isCorrect: true,
    data: confusionMatrices?.correct.data,
  };

  return [...falsePosNegEntries, misclassifiedEntry, correctEntry];
};

/**
 * Provides per-split F1 labels from a model's performance summary.
 *
 * When the model's status is `ModelStatus.Terminated`, the summary query receives
 * `undefined` instead of the model id.
 *
 * @param model - The registered model whose performance summary is fetched.
 * @returns An object with two fields:
 * - `isLoading`: the summary query's loading flag.
 * - `performanceArray`: one `{ f1, setName }` row each for the train, dev and test
 *   sets. `f1` is formatted by `getPerformanceText`, which yields the `'--'`
 *   placeholder when the metric is missing.
 */
export const useModelPerformanceData = (model: RegisteredModel) => {
  const summaryModelId = model.status === ModelStatus.Terminated ? undefined : model.id;
  const { data: summary, isLoading } = useGetModelPerformanceSummaryQuery(summaryModelId);
  const performance = summary?.performance;

  const performanceArray = useMemo(() => {
    const toRow = (f1: number | null | undefined, setName: string) => ({
      f1: getPerformanceText(f1 ?? null),
      setName,
    });

    return [
      toRow(performance?.train?.f1, t('Train set')),
      toRow(performance?.dev?.f1, t('Dev set')),
      toRow(performance?.test?.f1, t('Test set')),
    ];
  }, [performance]);

  return { isLoading, performanceArray };
};

/**
 * Provides per-split F1 / precision / recall labels together with each split's
 * default evaluation-set image count.
 *
 * NOTE(review): unlike the other hooks in this file, this one passes `model.id` to
 * the summary query even when the model is `ModelStatus.Terminated`. Confirm whether
 * that difference is intentional.
 *
 * @param model - The registered model whose summary and evaluation-set counts are fetched.
 * @returns An object with two fields:
 * - `isLoading`: true while either query is loading.
 * - `performanceArrayWithImageCount`: one row each for the train, dev and test sets,
 *   with fields `f1`, `precision`, `recall`, `setName` and `imageCount`. The three
 *   metric fields are formatted by `getPerformanceText`.
 */
export const useModelPerformanceWithImageCountData = (model: RegisteredModel) => {
  const { data: summary, isLoading: isSummaryLoading } = useGetModelPerformanceSummaryQuery(
    model.id,
  );
  const { data: evaluationSetCounts, isLoading: isCountsLoading } =
    useGetDefaultEvaluationSetsCountQuery(model);
  const performance = summary?.performance;

  const performanceArrayWithImageCount = useMemo(() => {
    // Formats the three metrics of one split; a missing split or metric becomes the placeholder.
    const metricsText = (
      split:
        | { f1?: number | null; precision?: number | null; recall?: number | null }
        | null
        | undefined,
    ) => ({
      f1: getPerformanceText(split?.f1 ?? null),
      precision: getPerformanceText(split?.precision ?? null),
      recall: getPerformanceText(split?.recall ?? null),
    });

    return [
      {
        ...metricsText(performance?.train),
        setName: t('Train set'),
        imageCount: evaluationSetCounts?.[MediaSplitName.Train],
      },
      {
        ...metricsText(performance?.dev),
        setName: t('Dev set'),
        imageCount: evaluationSetCounts?.[MediaSplitName.Dev],
      },
      {
        ...metricsText(performance?.test),
        setName: t('Test set'),
        imageCount: evaluationSetCounts?.[MediaSplitName.Test],
      },
    ];
  }, [performance, evaluationSetCounts]);

  return {
    isLoading: isSummaryLoading || isCountsLoading,
    performanceArrayWithImageCount,
  };
};

/**
 * Provides display-ready confusion-matrix entries for every split of a model.
 *
 * When the model's status is `ModelStatus.Terminated`, the summary query receives
 * `undefined` instead of the model id.
 *
 * @param model - The registered model whose performance summary is fetched.
 * @param labelType - Forwarded to `getFormattedConfusionMatricesArray`. For
 *   `LabelType.Classification`, the False Positive and False Negative entries are
 *   omitted.
 * @returns An object with the following fields:
 * - `isLoading`: the summary query's loading flag.
 * - `allConfusionMatricesArray`, `trainConfusionMatricesArray`,
 *   `devConfusionMatricesArray`, `testConfusionMatricesArray`: one entry array per split.
 */
export const useModelPerformanceConfusionMatricesData = (
  model: RegisteredModel,
  labelType?: LabelType | null,
) => {
  const { data: modelPerformanceSummary, isLoading: isModelPerformanceSummaryLoading } =
    useGetModelPerformanceSummaryQuery(
      model.status !== ModelStatus.Terminated ? model.id : undefined,
    );
  const { confusionMatrices } = modelPerformanceSummary ?? {};

  // `labelType` is a dependency of every memo below because it decides whether the
  // False Positive/False Negative entries are included. Without it, a change in
  // `labelType` (e.g. it resolving after the summary has loaded) would leave stale arrays.
  const allConfusionMatricesArray = useMemo(
    () => getFormattedConfusionMatricesArray(confusionMatrices?.all, labelType),
    [confusionMatrices, labelType],
  );

  const trainConfusionMatricesArray = useMemo(
    () => getFormattedConfusionMatricesArray(confusionMatrices?.train, labelType),
    [confusionMatrices, labelType],
  );

  const devConfusionMatricesArray = useMemo(
    () => getFormattedConfusionMatricesArray(confusionMatrices?.dev, labelType),
    [confusionMatrices, labelType],
  );

  const testConfusionMatricesArray = useMemo(
    () => getFormattedConfusionMatricesArray(confusionMatrices?.test, labelType),
    [confusionMatrices, labelType],
  );

  return {
    isLoading: isModelPerformanceSummaryLoading,
    allConfusionMatricesArray,
    trainConfusionMatricesArray,
    devConfusionMatricesArray,
    testConfusionMatricesArray,
  };
};
