import { experiment_status } from "@prisma/client";
import { getQueryKey } from "@trpc/react-query";
import {
  UseTRPCInfiniteQueryResult,
  UseTRPCQueryResult,
} from "@trpc/react-query/shared";
import { InternalTRPCOutput } from "src/backend/internal-api/internalTrpcRouter";
import TrpcClient from "src/frontend/api/TrpcClient";
import useToast from "src/frontend/components/ui/useToast";
import useOptimisticUpdate, {
  OptimisticUpdateContextType,
} from "src/frontend/hooks/useOptimisitcUpdate";
import {
  optimisticApprove,
  optimisticReject,
  optimisticRevert,
} from "src/frontend/pages/pricing/details/optimisticUpdatesForForecastTable";

/** Output type of the internal `getExperimentsForecast` tRPC procedure. */
export type GetExperimentsForecastOutput =
  InternalTRPCOutput["getExperimentsForecast"];

/**
 * Type of the `data` field of an infinite-query result whose pages are
 * {@link GetExperimentsForecastOutput}. This is the cache shape patched
 * optimistically for the forecast table.
 */
export type GetExperimentsForecastOutputInfinite = UseTRPCInfiniteQueryResult<
  GetExperimentsForecastOutput,
  Error
>["data"];

/** Output type of the internal `getExperimentIDsAndStatusesForForecast` tRPC procedure. */
export type GetExperimentIDsAndStatusesForForecastOutput =
  InternalTRPCOutput["getExperimentIDsAndStatusesForForecast"];

/**
 * Type of the `data` field of a regular (non-infinite) query result for
 * {@link GetExperimentIDsAndStatusesForForecastOutput}. This is the cache
 * shape patched optimistically when experiment statuses change.
 */
export type GetExperimentIDsAndStatusesForForecast = UseTRPCQueryResult<
  GetExperimentIDsAndStatusesForForecastOutput,
  Error
>["data"];

/**
 * Mutations for approving, rejecting and reverting experiments from the
 * pricing forecast table.
 *
 * Before each request is sent, `onMutate` optimistically patches two tRPC
 * query caches:
 *  - `getExperimentsForecast` (the forecast table, cached as an infinite
 *    query), using `optimisticApprove` / `optimisticReject` /
 *    `optimisticRevert`;
 *  - `getExperimentIDsAndStatusesForForecast`, by setting
 *    `experiment_status` on every experiment whose id is in
 *    `input.experimentIds`.
 *
 * On error, an error toast is shown and each cache is restored from the
 * `snapshot` in the context returned by its `updateCache` call. On settle
 * (success or error), both queries are invalidated.
 *
 * @returns The tRPC mutation objects
 *   `{ approveExperiment, rejectExperiment, revertExperiment }`.
 */
function useUpdateExperiment() {
  const t = useToast();

  // Forecast table cache (`getExperimentsForecast`, infinite query).
  const queryKeyForForecastTable = getQueryKey(
    TrpcClient.internal.getExperimentsForecast,
  );

  const {
    invalidate: invalidateForecastTable,
    revertCache: revertCacheForForecastTable,
    updateCache: updateCacheForForecastTable,
  } = useOptimisticUpdate<GetExperimentsForecastOutputInfinite>(
    queryKeyForForecastTable,
  );

  // Id/status list cache (`getExperimentIDsAndStatusesForForecast`).
  const queryKeyForStatuses = getQueryKey(
    TrpcClient.internal.getExperimentIDsAndStatusesForForecast,
  );

  const {
    invalidate: invalidateStatuses,
    revertCache: revertCacheForStatuses,
    updateCache: updateCacheForStatuses,
  } = useOptimisticUpdate<GetExperimentIDsAndStatusesForForecast>(
    queryKeyForStatuses,
  );

  /**
   * Optimistically sets `experiment_status` to `newStatus` for each cached
   * experiment whose `experiment_id` is in `input.experimentIds`. All other
   * entries are returned unchanged.
   *
   * @returns The context produced by `updateCache` for the id/status cache.
   */
  const updateCacheForExperimentsStatuses = (
    input: { experimentIds: string[] },
    newStatus: experiment_status,
  ) =>
    updateCacheForStatuses((prev: GetExperimentIDsAndStatusesForForecast) => {
      const experimentsBeingModified = new Set(input.experimentIds);
      // `prev` may be undefined (nothing cached); it then stays undefined.
      return prev?.map((experiment) =>
        experimentsBeingModified.has(experiment.experiment_id)
          ? { ...experiment, experiment_status: newStatus }
          : experiment,
      );
    });

  // Context returned by every `onMutate`, in cache-update order:
  // [forecast table context, id/status list context].
  type OptimisticUpdateContextArray = [
    OptimisticUpdateContextType<GetExperimentsForecastOutputInfinite>,
    OptimisticUpdateContextType<GetExperimentIDsAndStatusesForForecast>,
  ];

  /** Shared `onError`: notify the user, then roll back both optimistic updates. */
  const handleError = (
    _err: unknown,
    _input: unknown,
    context: OptimisticUpdateContextArray | undefined,
  ) => {
    t.errorToast("Failed to save edits.");
    if (context == null) {
      return;
    }
    const [forecastTableContext, statusesContext] = context;

    // Restore a cache only if a snapshot was captured. A `[]` fallback would
    // write an array into the infinite-query cache (which expects
    // `{ pages, pageParams }`) and would replace a missing status list with an
    // empty one. `handleSettled` runs after this and invalidates both queries,
    // so a cache that is not restored here is refetched anyway.
    const forecastTableSnapshot = forecastTableContext?.snapshot;
    if (forecastTableSnapshot != null) {
      revertCacheForForecastTable(forecastTableSnapshot);
    }

    const statusesSnapshot = statusesContext?.snapshot;
    if (statusesSnapshot != null) {
      revertCacheForStatuses(statusesSnapshot);
    }
  };

  /** Shared `onSettled`: invalidate both caches whether the mutation succeeded or failed. */
  const handleSettled = () => {
    void invalidateForecastTable();
    void invalidateStatuses();
  };

  const approveExperiment =
    TrpcClient.internal.approveExperiments.useMutation<OptimisticUpdateContextArray>(
      {
        onError: handleError,
        onMutate: (input) => {
          return [
            updateCacheForForecastTable(optimisticApprove(input)),
            updateCacheForExperimentsStatuses(
              input,
              experiment_status.APPROVED,
            ),
          ];
        },
        onSettled: handleSettled,
      },
    );

  const rejectExperiment =
    TrpcClient.internal.rejectExperiments.useMutation<OptimisticUpdateContextArray>(
      {
        onError: handleError,
        onMutate: (input) => {
          return [
            updateCacheForForecastTable(optimisticReject(input)),
            updateCacheForExperimentsStatuses(
              input,
              experiment_status.REJECTED,
            ),
          ];
        },
        onSettled: handleSettled,
      },
    );

  const revertExperiment =
    TrpcClient.internal.revertExperiments.useMutation<OptimisticUpdateContextArray>(
      {
        onError: handleError,
        onMutate: (input) => {
          return [
            updateCacheForForecastTable(optimisticRevert(input)),
            updateCacheForExperimentsStatuses(
              input,
              experiment_status.REVERTED,
            ),
          ];
        },
        onSettled: handleSettled,
      },
    );

  return { approveExperiment, rejectExperiment, revertExperiment };
}

export default useUpdateExperiment;
