import { useQueryClient, useSuspenseQuery } from '@tanstack/react-query';
import { useUser } from '../../UserContext';
import { mixEvalTests } from '../../Components/Tailor/ModelEvaluations/MixEvalTests';
import { baseModelRankingData } from '../../Components/Tailor/ModelEvaluations/BaseModelEvaluations';

/**
 * Builds the ranking rows for evaluated models and writes them, preceded by
 * the static `baseModelRankingData` rows, to the query cache under
 * `['rankingData']`.
 *
 * Only models that carry a `model_evaluation.mix_eval` value are included.
 * Every test listed in `mixEvalTests` gets an entry: the score rounded to two
 * decimals, or `null` when the score is missing or not a finite number.
 *
 * @param {Array<object>|null|undefined} data - Model list from the models
 *   endpoint; anything that is not an array is treated as an empty list.
 * @param {{ setQueryData: Function }} queryClient - Query client whose cache
 *   receives the combined rows.
 * @returns {void}
 */
const prepareModelRankingData = (data, queryClient) => {
  // Callers read the payload with optional chaining, so it may be absent.
  const models = Array.isArray(data) ? data : [];

  const rankingData = models
    .filter((model) => model.model_evaluation?.mix_eval)
    .map((model) => {
      const scores = model.model_evaluation.mix_eval;
      const evaluations = {};
      for (const test of mixEvalTests) {
        const score = scores[test];
        // Test for a real number rather than truthiness so that a legitimate
        // score of 0 is kept instead of being reported as missing.
        evaluations[test] = Number.isFinite(score)
          ? Math.round((score + Number.EPSILON) * 100) / 100
          : null;
      }
      return {
        name: model.model_name,
        displayName: model.model_name,
        id: model.model_id,
        trainingDate: model.training_ended_at_unix,
        evaluations,
      };
    });

  const combinedRankingData = [...baseModelRankingData, ...rankingData];
  queryClient.setQueryData(['rankingData'], combinedRankingData);
};

/**
 * Suspense query hook that loads the model list from `tailor/v1/models`.
 *
 * As a side effect of every successful fetch, the response is handed to
 * `prepareModelRankingData`, which refreshes the `['rankingData']` cache entry.
 *
 * @returns {object} The result of `useSuspenseQuery`; its data is
 *   `response.data.message`.
 * @throws Re-throws the original request error so the query enters its error
 *   state.
 */
export const useGetModels = () => {
  const { customAxios } = useUser();
  const queryClient = useQueryClient();

  return useSuspenseQuery({
    queryKey: ['models'],
    queryFn: async () => {
      try {
        const response = await customAxios.get('tailor/v1/models');
        const models = response?.data?.message;
        prepareModelRankingData(models, queryClient);
        return models;
      } catch (error) {
        if (import.meta.env.DEV) {
          console.error('Error fetching models:', error);
        }
        throw error;
      }
    },
    staleTime: 1000 * 60 * 2, // Consider data fresh for 2 minutes
    // `gcTime` is the TanStack Query v5 name for the former `cacheTime`; the
    // old name is ignored, which silently fell back to the default retention.
    gcTime: 1000 * 60 * 30, // Keep unused data in cache for 30 minutes
  });
};

/**
 * Suspense query hook that loads the continuous-training pipelines from
 * `/tailor/v1/continuous-training-pipeline`.
 *
 * @returns {object} The result of `useSuspenseQuery`; its data is
 *   `response.data.pipelines`.
 * @throws {Error} A wrapping error whose `cause` is the original failure.
 */
export const useGetPipelines = () => {
  const { customAxios } = useUser();

  return useSuspenseQuery({
    queryKey: ['pipelines'],
    queryFn: async () => {
      try {
        const response = await customAxios.get(
          '/tailor/v1/continuous-training-pipeline',
        );
        return response.data.pipelines;
      } catch (error) {
        if (import.meta.env.DEV) {
          console.error('Error fetching pipelines:', error);
        }
        // The second argument of `Error` is an options bag; the original
        // failure has to be passed as `cause` or it is dropped.
        throw new Error('Error fetching pipelines:', { cause: error });
      }
    },
    staleTime: 1000 * 60 * 5, // Consider data fresh for 5 minutes
    // `gcTime` is the TanStack Query v5 name for the former `cacheTime`.
    gcTime: 1000 * 60 * 30, // Keep unused data in cache for 30 minutes
  });
};

// Log-count thresholds for language models, matched in order against
// `model_name`; the first entry with a matching keyword wins.
const LOG_REQUIREMENTS_BY_NAME = [
  { keywords: ['mistral'], min: 1_000, good: 2_000, excellent: 5_000 },
  { keywords: ['mixtral'], min: 2_000, good: 4_000, excellent: 8_000 },
  {
    keywords: ['llama', 'gemma', 'dbrx', 'phi'],
    min: 8_000,
    good: 12_000,
    excellent: 18_000,
  },
];

// Models served when the base-model request fails.
const FALLBACK_BASE_MODELS = [
  {
    id: 1,
    display_name: 'mistral-7b',
    model_name: 'mistral-7b-instruct-v0.3',
    available_for_inference: true,
    available_for_fine_tuning: true,
    model_type: 'language_model',
  },
  {
    id: 2,
    display_name: 'mixtral-8x7b',
    model_name: 'mixtral-8x7b-instruct-v0.1',
    available_for_inference: true,
    available_for_fine_tuning: true,
    model_type: 'language_model',
  },
];

// Returns a copy of a raw base model with the derived fields added: type,
// state, family, logo, model_config, default model_type and log thresholds.
const enrichBaseModel = (model) => {
  const enriched = {
    ...model,
    type: 'base_model',
    state: 'deployed',
    model_config: { base_model: model.model_name },
    model_type: model.model_type || 'language_model',
  };
  enriched.family = modelFamily(enriched);
  enriched.image_url = getImageUrlForFamily(enriched.family);

  // Thresholds are only defined for language models with a recognised name.
  if (enriched.model_type === 'language_model') {
    const requirements = LOG_REQUIREMENTS_BY_NAME.find(({ keywords }) =>
      keywords.some((keyword) => enriched.model_name.includes(keyword)),
    );
    if (requirements) {
      enriched.min_logs_required = requirements.min;
      enriched.good_number_of_logs_required = requirements.good;
      enriched.excellent_number_of_logs_required = requirements.excellent;
    }
  }
  return enriched;
};

/**
 * Suspense query hook that loads the base models from `tailor/v1/base_models`
 * and adds the derived fields the rest of the file's model objects carry.
 *
 * This query never rejects: on any failure (request error or malformed
 * payload) it resolves with a built-in fallback list. The fallback entries are
 * enriched the same way as fetched ones, so both paths yield the same shape.
 *
 * @returns {object} The result of `useSuspenseQuery`; its data is the array of
 *   enriched base models.
 */
export const useGetBaseModels = () => {
  const { customAxios } = useUser();

  return useSuspenseQuery({
    queryKey: ['baseModels'],
    queryFn: async () => {
      try {
        const response = await customAxios.get('tailor/v1/base_models');
        return response.data.models.map((model) => enrichBaseModel(model));
      } catch (error) {
        if (import.meta.env.DEV) {
          console.error('Error fetching base models:', error);
        }
        // Deliberate best effort: serve the fallback list instead of failing.
        return FALLBACK_BASE_MODELS.map((model) => enrichBaseModel(model));
      }
    },
    staleTime: 1000 * 60 * 2, // Consider data fresh for 2 minutes
    // `gcTime` is the TanStack Query v5 name for the former `cacheTime`.
    gcTime: 1000 * 60 * 30, // Keep unused data in cache for 30 minutes
  });
};

/**
 * Suspense query hook that loads the custom evaluations from
 * `/tailor/v1/custom_eval/get_eval`.
 *
 * @returns {object} The result of `useSuspenseQuery`; its data is
 *   `response.data.data`.
 * @throws {Error} A wrapping error whose `cause` is the original failure.
 */
export const useGetCustomEvals = () => {
  const { customAxios } = useUser();

  return useSuspenseQuery({
    queryKey: ['customEvals'],
    queryFn: async () => {
      try {
        const response = await customAxios.get(
          '/tailor/v1/custom_eval/get_eval',
        );
        return response.data.data;
      } catch (error) {
        if (import.meta.env.DEV) {
          console.error('Error fetching custom evaluations:', error);
        }
        // The second argument of `Error` is an options bag; the original
        // failure has to be passed as `cause` or it is dropped.
        throw new Error('Error fetching custom evaluations:', { cause: error });
      }
    },
    staleTime: 1000 * 60 * 2, // Consider data fresh for 2 minutes
    // `gcTime` is the TanStack Query v5 name for the former `cacheTime`.
    gcTime: 1000 * 60 * 30, // Keep unused data in cache for 30 minutes
  });
};

/**
 * Suspense query hook that exposes the `['rankingData']` cache entry.
 *
 * The entry is not fetched directly: it is written by
 * `prepareModelRankingData` whenever the model list is loaded. When the cache
 * holds no ranking data yet, this hook fetches `['models']` itself so that the
 * entry gets populated, then returns whatever the cache holds.
 *
 * @returns {object} The result of `useSuspenseQuery`; its data is the cached
 *   ranking data.
 * @throws Propagates any error raised while fetching the model list.
 */
export const useGetRankingData = () => {
  const queryClient = useQueryClient();
  const { customAxios } = useUser();

  return useSuspenseQuery({
    queryKey: ['rankingData'],
    queryFn: async () => {
      const existingData = queryClient.getQueryData(['rankingData']);
      if (!existingData) {
        // Loading the models writes the ranking data as a side effect.
        await queryClient.fetchQuery({
          queryKey: ['models'],
          queryFn: async () => {
            const response = await customAxios.get('tailor/v1/models');
            const models = response?.data?.message;
            prepareModelRankingData(models, queryClient);
            return models;
          },
        });
      }
      return queryClient.getQueryData(['rankingData']);
    },
    staleTime: 1000 * 60 * 2, // Consider data fresh for 2 minutes
    // `gcTime` is the TanStack Query v5 name for the former `cacheTime`.
    gcTime: 1000 * 60 * 30, // Keep unused data in cache for 30 minutes
  });
};

// Logo asset path per model family. Mistral and Mixtral share one logo.
const FAMILY_IMAGE_URLS = new Map([
  ['Mistral', '/mistral.png'],
  ['Mixtral', '/mistral.png'],
  ['LLaMA', '/meta.png'],
  ['Gemma', '/googleicon.png'],
  ['Phi', '/microsoft.svg'],
  ['Salesforce', '/salesforce-logo.png'],
  ['Databricks', '/databricks.png'],
]);

/**
 * Looks up the logo path for a model family.
 *
 * @param {string} family - Family label; the match is exact and case-sensitive.
 * @returns {string} The logo path, or an empty string for an unknown family.
 */
const getImageUrlForFamily = (family) => FAMILY_IMAGE_URLS.get(family) ?? '';

/**
 * Derives the family label of a model from its `display_name`.
 *
 * The name is lower-cased and searched for known keywords in a fixed order;
 * the first keyword found decides the family. A missing model or a missing or
 * empty `display_name` is treated as an empty name.
 *
 * @param {{ display_name?: string }|null|undefined} model - Model to classify.
 * @returns {string} The family label, or `'Others'` when no keyword matches.
 */
const modelFamily = (model) => {
  const lowerCasedName = model?.display_name?.toLowerCase() || '';

  // Order matters: a name containing several keywords takes the first match.
  const familyByKeyword = [
    ['mistral', 'Mistral'],
    ['mixtral', 'Mixtral'],
    ['llama', 'LLaMA'],
    ['gemma', 'Gemma'],
    ['phi', 'Phi'],
    ['dbrx', 'Databricks'],
    ['sfr', 'Salesforce'],
  ];

  for (const [keyword, family] of familyByKeyword) {
    if (lowerCasedName.includes(keyword)) {
      return family;
    }
  }
  return 'Others';
};
