// SPDX-FileCopyrightText: 2023 TRUMPF Laser GmbH
//
// SPDX-License-Identifier: LicenseRef-TRUMPF
import { ImageEvaluationMessage } from '@tls/sw91-communication/types/com.api_predictor';
import { HubContext } from 'components/common/emitter/HubContext';
import { ProjectVersion } from 'model/ProjectMetaMessageExtensions';
import { useContext, useEffect, useRef, useState } from 'react';
import sortBy from 'lodash/sortBy';
import { ModelEvaluationProgressResult } from 'model/ModelEvaluationProgressResult';

/**
 * Summary statistics needed to draw a single box plot.
 *
 * Produced by `calculateBoxPlotData` in this module. The quartiles are medians of
 * the lower/upper halves of the sorted sample. The whiskers are the most extreme
 * sample values lying within 1.5 × IQR of the box.
 */
export type BoxPlotType = {
  /** Lower end of the lower whisker. */
  lowerWhisker: number;
  /** First quartile – lower edge of the box. */
  firstQuartile: number;
  /** Median of the sample. */
  median: number;
  /** Third quartile – upper edge of the box. */
  thirdQuartile: number;
  /** Upper end of the upper whisker. */
  upperWhisker: number;
};

/**
 * Associates one evaluated image (keyed by its `imageMd5`) with a single score.
 *
 * In this module the score is the image's first IoU value, or -1 when the
 * message carries no IoU value (see `GetIdWithValue`).
 */
export type IdWithValue = {
  /** Score of the image; -1 marks a missing IoU value. */
  value: number;
  /** Image key; empty string when the source message had none. */
  imageMd5: string;
};

/**
 * Subscribes to `onEvaluationCompleted` hub events and accumulates the per-image
 * results of one model evaluation.
 *
 * Only messages whose project id, strategy version and evaluation id match the
 * arguments are collected. Whenever one of these identifiers changes, all
 * accumulated results and derived state are reset. This ensures results from
 * different evaluations are never mixed.
 *
 * @param projectId - project whose evaluation messages are collected
 * @param version - strategy version the evaluation results must match
 * @param evaluationId - evaluation to follow; compared against `message.id`
 * @returns `evaluation`: one entry per received result, sorted by `imageMd5`;
 *   `boxPlot`: IoU box-plot statistics, `undefined` until at least one IoU value is available;
 *   `progress`: the latest reported progress, or 100 once the evaluation is completed.
 */
export default function useModelEvaluationApi(projectId: string, version: ProjectVersion, evaluationId?: string) {
  const [evaluation, setEvaluation] = useState<IdWithValue[]>([]);
  const [progress, setProgress] = useState(0);
  const [completed, setCompleted] = useState(false);
  const [boxPlot, setBoxPlot] = useState<BoxPlotType>();
  // Raw messages live in a ref so the event handler can append to them without re-subscribing.
  const data = useRef<ImageEvaluationMessage[]>([]);

  const { events } = useContext(HubContext);

  // Start from a clean slate whenever a different evaluation is followed.
  // Without this, results of the previous evaluation would be merged into the
  // new one, and an earlier completion would keep progress pinned at 100.
  // The updater form avoids a needless re-render on mount when the list is already empty.
  useEffect(() => {
    data.current = [];
    setEvaluation(previous => (previous.length === 0 ? previous : []));
    setBoxPlot(undefined);
    setProgress(0);
    setCompleted(false);
  }, [projectId, version, evaluationId]);

  useEffect(() => {
    function getData(message: ModelEvaluationProgressResult) {
      const result = message.evaluationResult;
      if (!result) return;
      if (message.messageHeader?.projectId !== projectId) return;
      if (result.strategyVersion !== version) return;
      if (message.id !== evaluationId) return;

      data.current.push(result);

      const sorted = sortBy(data.current, ['imageMd5']);
      const values = getValues(sorted);

      // With no IoU values yet there is nothing to summarize; leave the box plot empty
      // instead of publishing NaN/Infinity statistics.
      setBoxPlot(values.length > 0 ? calculateBoxPlotData(values) : undefined);
      setEvaluation(sorted.map(GetIdWithValue));

      setProgress(message.progress ?? 0);
      setCompletedIfDone(completed, setCompleted, message.progress);
    }

    events.on('onEvaluationCompleted', getData);

    return () => {
      events.off('onEvaluationCompleted', getData);
    };
  }, [completed, evaluationId, events, version, projectId]);

  return { evaluation, boxPlot, progress: completed ? 100 : progress };
}

/**
 * Marks the evaluation as completed once the reported progress reaches 1.
 *
 * A missing progress is treated as 0. Nothing is done when the evaluation is
 * already marked as completed.
 * NOTE(review): progress is presumably a 0..1 fraction — confirm with the message producer.
 *
 * @param completed - current completion flag
 * @param setCompleted - setter invoked with `true` when the evaluation just finished
 * @param progress - progress reported by the latest message
 */
function setCompletedIfDone(completed: boolean, setCompleted: (isCompleted: boolean) => void, progress?: number) {
  // Written as a negated `< 1` so that a NaN progress behaves exactly like the original guard.
  const reachedEnd = !((progress ?? 0) < 1);
  if (reachedEnd && !completed) {
    setCompleted(true);
  }
}

/**
 * Extracts the first `IoU` value of every message, preserving input order.
 *
 * Messages without an `IoU` result, or whose `IoU` result has no values, are skipped.
 *
 * @param sorted - evaluation messages (the caller passes them sorted by `imageMd5`)
 * @returns the collected IoU values in the same order as `sorted`
 */
function getValues(sorted: ImageEvaluationMessage[]): number[] {
  const values: number[] = [];
  for (const message of sorted) {
    // Collect explicitly instead of mapping misses to a sentinel (-Number.MAX_VALUE)
    // and filtering it out, which would also drop a genuine value equal to the sentinel.
    const firstIou = message.evaluationResults.find(e => e.name === 'IoU')?.values?.[0];
    if (firstIou !== undefined) values.push(firstIou);
  }
  return values;
}

/**
 * Converts one evaluation message into an `{ imageMd5, value }` pair.
 *
 * @param message - evaluation result of a single image
 * @returns the image key (empty string when absent) and its first `IoU` value,
 *   or -1 when there is no `IoU` result or it has no values
 */
function GetIdWithValue(message: ImageEvaluationMessage): IdWithValue {
  const iouResult = message.evaluationResults.find(result => result.name === 'IoU');
  const firstIou = iouResult?.values[0];

  const imageMd5 = message.imageMd5 ?? '';
  const value = firstIou ?? -1;
  return { imageMd5, value };
}

/**
 * Returns the median of `values` without modifying the input array.
 *
 * For an even count, the mean of the two middle values is returned.
 * An empty input yields NaN.
 */
function getMedian(values: number[]) {
  const ordered = [...values].sort((a, b) => a - b);
  const count = ordered.length;
  const mid = Math.floor(count / 2);
  if (count % 2 === 1) {
    return ordered[mid];
  }

  return 0.5 * ordered[mid] + 0.5 * ordered[mid - 1];
}

/**
 * Computes box-plot statistics for `data` using Tukey's hinges.
 *
 * The quartiles are the medians of the lower and upper halves of the sorted
 * sample. For an odd count, the median belongs to both halves. The whiskers
 * extend to the most extreme sample values lying within 1.5 × IQR of the box.
 *
 * @param data - sample values; not modified
 * @returns the box-plot statistics, or `undefined` for an empty sample
 */
function calculateBoxPlotData(data: number[]): BoxPlotType | undefined {
  // An empty sample has no meaningful statistics (they would be NaN/±Infinity).
  if (data.length === 0) return undefined;

  const values = [...data].sort((a, b) => a - b);
  const median = getMedian(values);

  // Both halves get ceil(n/2) elements, so for odd n the median is shared by them.
  // This keeps Q1/Q3 symmetric; putting the median into one half only would bias
  // that quartile and leave the other half empty (NaN) for a single value.
  const halfLength = Math.ceil(values.length / 2);
  const firstQuartile = getMedian(values.slice(0, halfLength));
  const thirdQuartile = getMedian(values.slice(values.length - halfLength));

  const step = (thirdQuartile - firstQuartile) * 1.5;
  const upperFence = thirdQuartile + step;
  const lowerFence = firstQuartile - step;
  // reduce instead of Math.max(...spread): spreading a large sample can exceed engine argument limits.
  // The fences always enclose the quartiles, so neither filtered list is empty.
  const upperWhisker = values.filter(v => v <= upperFence).reduce((max, v) => Math.max(max, v), -Infinity);
  const lowerWhisker = values.filter(v => v >= lowerFence).reduce((min, v) => Math.min(min, v), Infinity);

  return { firstQuartile, lowerWhisker, median, thirdQuartile, upperWhisker };
}
