import React, {
  useState,
  useEffect,
  Fragment,
  Suspense,
  useMemo,
  useCallback,
} from 'react';
import { Menu } from '@headlessui/react';
import MyModelCard from './MyModels/MyModelCard';
import MyBaseModelCard from './MyModels/MyBaseModelCard';
import MyModelTable from './MyModels/MyModelTable';
import LaunchFineTuneModal from './MyModels/LaunchFineTuneModal';
// import { mockBaseModels } from './MyModels/mockdata';
import { useGetModels } from '../../Hooks/react-query';
import ModelsSkeleton from './MyModels/ModelsSkeleton';
import clsx from 'clsx';
import ModelsErrorBoundary from './MyModels/ModelsErrorBoundary';
import useBaseModels from '../../Hooks/useBaseModels';
import { modelStateAtom, sortCriteriaAtom } from '../../context/atoms';
import { useAtom, useSetAtom } from 'jotai';
import { useQueryClient } from '@tanstack/react-query';
import { useUser } from '../../UserContext';
import { Link } from 'react-router-dom';

// Main wrapper component
const TailorMyModels = () => {
  return (
    <ModelsErrorBoundary>
      <Suspense fallback={<ModelsSkeleton viewMode="grid" />}>
        <TailorMyModelsContent />
      </Suspense>
    </ModelsErrorBoundary>
  );
};

// Content component that contains the main logic
const TailorMyModelsContent = () => {
  const [viewMode, setViewMode] = useState('grid');
  const [isModalOpen, setIsModalOpen] = useState(false);
  const [sortedModels, setSortedModels] = useState([]);
  const [sortCriteria, setSortCriteria] = useAtom(sortCriteriaAtom);
  const [modelState, setModelState] = useAtom(modelStateAtom);
  const { user } = useUser();
  const { data: modelsData, error } = useGetModels();
  const { getInferenceBaseModels } = useBaseModels();
  const [filteredModels, setFilteredModels] = useState([]);

  if (error) {
    throw error;
  }

  const baseModels = useMemo(() => {
    const models = getInferenceBaseModels();
    return models.filter(
      (model) =>
        model.available_for_inference &&
        (user.location_preference === 'default' ||
          model.supported_locations.includes(user.location_preference)),
    );
  }, [getInferenceBaseModels, user.location_preference]);

  useEffect(() => {
    setFilteredModels(
      [...modelsData].filter((model) => {
        return (
          user?.location_preference === 'default' ||
          model?.base_model_data?.supported_locations?.includes(
            user?.location_preference,
          )
        );
      }),
    );
  }, [modelsData, user.location_preference, user]);

  const handleSort = useCallback(
    (criteria) => {
      if (!modelsData) {
        return;
      }
      setSortCriteria(criteria);

      let sortedList;

      if (criteria === 'state') {
        sortedList = filteredModels.sort((a, b) => {
          const stateOrder = [
            'deployed',
            'undeploying',
            'failed_undeploy',
            'failed_deploy',
            'deploying',
            'dormant',
            'training',
            'start_training',
            'training-cancelled',
            'failed_training',
            'failed',
          ];
          const stateIndexA = stateOrder.indexOf(a.state);
          const stateIndexB = stateOrder.indexOf(b.state);
          if (stateIndexA !== stateIndexB) {
            return stateIndexA - stateIndexB;
          }
          return b.created_at_unix - a.created_at_unix;
        });
      } else if (criteria === 'lastUpdated') {
        sortedList = filteredModels.sort((a, b) => {
          return new Date(b.last_used_unix) - new Date(a.last_used_unix);
        });
      }
      setSortedModels(sortedList);
    },
    [filteredModels, modelsData, setSortCriteria],
  );

  useEffect(() => {
    handleSort(sortCriteria);
  }, [sortCriteria, filteredModels, handleSort]);

  useEffect(() => {
    for (const model of modelsData) {
      if (modelState[model?.model_id]) {
        model.state = modelState[model?.model_id];
      }
    }
    handleSort(sortCriteria);
  }, [modelState]);

  const toggleViewMode = () => {
    setViewMode((prevMode) => (prevMode === 'grid' ? 'table' : 'grid'));
  };

  return (
    <div className="flex flex-col min-h-screen bg-zinc-50 font-dmSans">
      <header className="sticky top-0 bg-zinc-50 z-50">
        <div
          className="flex items-center justify-between h-16 lg:p-4 pr-0 pl-4 text-xl font-medium text-zinc-800"
          role="banner"
        >
          Models
          <div className="flex items-center lg:gap-6 gap-2 relative">
            <LaunchFineTuneModal
              isOpen={isModalOpen}
              onClose={() => setIsModalOpen(false)}
            />
          </div>
        </div>
        <hr className="border-t border-zinc-300" />
        <div className="w-full h-10 flex justify-between items-baseline px-8 pb-8 pt-4 space-x-4">
          <div className="flex items-baseline">
            <h2 className="text-xl font-medium text-zinc-800">
              Your fine-tuned models
            </h2>
            <div className="flex items-baseline ml-4 text-sm text-zinc-700">
              Showing results for
              <span className="capitalize text-zinc-600 ml-1 underline underline-offset-4">
                {user.location_preference === 'uk'
                  ? 'UK'
                  : user.location_preference === 'default'
                    ? 'everywhere'
                    : user.location_preference}
              </span>
              <Link
                to="/tailor/settings#location"
                className="text-indigo-600 text-xs ml-2"
              >
                (Change)
              </Link>
            </div>
          </div>
          <div className="flex items-center gap-1">
            <button
              onClick={toggleViewMode}
              className="h-8 w-8 shadow-sm rounded flex items-center justify-center"
            >
              {viewMode === 'grid' ? (
                // Icon for 'Toggle to Table'
                <svg
                  xmlns="http://www.w3.org/2000/svg"
                  fill="none"
                  viewBox="0 0 24 24"
                  strokeWidth={1.5}
                  stroke="currentColor"
                  className="size-5 text-zinc-500"
                >
                  <path
                    strokeLinecap="round"
                    strokeLinejoin="round"
                    d="M3.375 19.5h17.25m-17.25 0a1.125 1.125 0 0 1-1.125-1.125M3.375 19.5h7.5c.621 0 1.125-.504 1.125-1.125m-9.75 0V5.625m0 12.75v-1.5c0-.621.504-1.125 1.125-1.125m18.375 2.625V5.625m0 12.75c0 .621-.504 1.125-1.125 1.125m1.125-1.125v-1.5c0-.621-.504-1.125-1.125-1.125m0 3.75h-7.5A1.125 1.125 0 0 1 12 18.375m9.75-12.75c0-.621-.504-1.125-1.125-1.125H3.375c-.621 0-1.125.504-1.125 1.125m19.5 0v1.5c0 .621-.504 1.125-1.125 1.125M2.25 5.625v1.5c0 .621.504 1.125 1.125 1.125m0 0h17.25m-17.25 0h7.5c.621 0 1.125.504 1.125 1.125M3.375 8.25c-.621 0-1.125.504-1.125 1.125v1.5c0 .621.504 1.125 1.125 1.125m17.25-3.75h-7.5c-.621 0-1.125.504-1.125 1.125m8.625-1.125c.621 0 1.125.504 1.125 1.125v1.5c0 .621-.504 1.125-1.125 1.125m-17.25 0h7.5m-7.5 0c-.621 0-1.125.504-1.125 1.125v1.5c0 .621.504 1.125 1.125 1.125M12 10.875v-1.5m0 1.5c0 .621-.504 1.125-1.125 1.125M12 10.875c0 .621.504 1.125 1.125 1.125m-2.25 0c.621 0 1.125.504 1.125 1.125M13.125 12h7.5m-7.5 0c-.621 0-1.125.504-1.125 1.125M20.625 12c.621 0 1.125.504 1.125 1.125v1.5c0 .621-.504 1.125-1.125 1.125m-17.25 0h7.5M12 14.625v-1.5m0 1.5c0 .621-.504 1.125-1.125 1.125M12 14.625c0 .621.504 1.125 1.125 1.125m-2.25 0c.621 0 1.125.504 1.125 1.125m0 1.5v-1.5m0 0c0-.621.504-1.125 1.125-1.125m0 0h7.5"
                  />
                </svg>
              ) : (
                <svg
                  xmlns="http://www.w3.org/2000/svg"
                  fill="none"
                  viewBox="0 0 24 24"
                  strokeWidth={1.5}
                  stroke="currentColor"
                  className="size-5 text-zinc-500"
                >
                  <path
                    strokeLinecap="round"
                    strokeLinejoin="round"
                    d="M3.75 6A2.25 2.25 0 0 1 6 3.75h2.25A2.25 2.25 0 0 1 10.5 6v2.25a2.25 2.25 0 0 1-2.25 2.25H6a2.25 2.25 0 0 1-2.25-2.25V6ZM3.75 15.75A2.25 2.25 0 0 1 6 13.5h2.25a2.25 2.25 0 0 1 2.25 2.25V18a2.25 2.25 0 0 1-2.25 2.25H6A2.25 2.25 0 0 1 3.75 18v-2.25ZM13.5 6a2.25 2.25 0 0 1 2.25-2.25H18A2.25 2.25 0 0 1 20.25 6v2.25A2.25 2.25 0 0 1 18 10.5h-2.25a2.25 2.25 0 0 1-2.25-2.25V6ZM13.5 15.75a2.25 2.25 0 0 1 2.25-2.25H18a2.25 2.25 0 0 1 2.25 2.25V18A2.25 2.25 0 0 1 18 20.25h-2.25A2.25 2.25 0 0 1 13.5 18v-2.25Z"
                  />
                </svg>
              )}
            </button>
            <Menu as="div" className="relative">
              <Menu.Button className="h-8 w-8  shadow-sm rounded flex items-center justify-center">
                <svg
                  xmlns="http://www.w3.org/2000/svg"
                  fill="none"
                  viewBox="0 0 24 24"
                  strokeWidth={1.5}
                  stroke="currentColor"
                  className="size-5 text-zinc-500"
                >
                  <path
                    strokeLinecap="round"
                    strokeLinejoin="round"
                    d="M3 7.5 7.5 3m0 0L12 7.5M7.5 3v13.5m13.5 0L16.5 21m0 0L12 16.5m4.5 4.5V7.5"
                  />
                </svg>
              </Menu.Button>
              <Menu.Items className="absolute right-0 mt-2 w-48 origin-top-right bg-white border border-gray-200 divide-y divide-gray-100 rounded-md shadow-lg focus:outline-none">
                <Menu.Item>
                  {({ active }) => (
                    <div
                      className={clsx(
                        active && 'bg-zinc-50',
                        'group flex rounded-md items-center w-full px-2 py-2 text-sm h-12 lg:h-auto',
                        sortCriteria === 'state' && 'font-bold',
                      )}
                      onClick={() => handleSort('state')}
                    >
                      Sort by State
                    </div>
                  )}
                </Menu.Item>
                <Menu.Item>
                  {({ active }) => (
                    <div
                      className={clsx(
                        active && 'bg-zinc-50',
                        'group flex rounded-md items-center w-full px-2 py-2 text-sm h-12 lg:h-auto',
                        sortCriteria === 'lastUpdated' && 'font-bold',
                      )}
                      onClick={() => handleSort('lastUpdated')}
                    >
                      Sort by Last Updated
                    </div>
                  )}
                </Menu.Item>
              </Menu.Items>
            </Menu>
          </div>
        </div>
      </header>

      <main className="p-6">
        {viewMode === 'grid' ? (
          <>
            <div className="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
              {sortedModels.map((model) => (
                <Fragment key={model.model_id}>
                  <MyModelCard model={model} />
                </Fragment>
              ))}
            </div>
            <h2 className="text-xl font-medium mt-8 mb-4 text-zinc-800">
              Base models
            </h2>
            <div className="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
              {baseModels
                .sort((a, b) => a.family.localeCompare(b.family))
                .map((baseModel) => (
                  <Fragment key={baseModel.model_id}>
                    <MyBaseModelCard model={baseModel} />
                  </Fragment>
                ))}
            </div>
          </>
        ) : (
          <MyModelTable models={sortedModels} baseModels={baseModels} />
        )}
      </main>
    </div>
  );
};

// Default export: the wrapped page (error boundary + Suspense around the content).
export default TailorMyModels;
