import { QueryKey, useQueryClient } from "@tanstack/react-query";
import { z } from "zod";

/**
 * Point-in-time copy of a single cached query: the key it is stored under and
 * the data it held when the snapshot was taken (`undefined` if it had none).
 */
type CacheSnapshotType<CachedDataType> = {
  queryKey: QueryKey;
  cachedData: undefined | CachedDataType;
};

/**
 * Value produced by `updateCache`: one snapshot entry per cached query that was
 * optimistically modified, suitable for passing back to `revertCache`.
 */
export type OptimisticUpdateContextType<CachedDataType> = {
  snapshot: Array<CacheSnapshotType<CachedDataType>>;
};

/**
 * Helpers for optimistically updating every cached query that shares the
 * base key `queryKey[0]`.
 *
 * A cached query matches when the first element of its key, parsed as an
 * array of strings, equals `queryKey[0]` element-by-element. Cached queries
 * whose first key element is not a string array are ignored.
 *
 * @param queryKey - Key whose first element must be an array of strings; it
 *   selects the queries to snapshot/update and is also passed as-is to
 *   `cancelQueries` and `invalidateQueries`.
 * @returns `updateCache` (apply an updater to all matching queries and return
 *   a snapshot of their previous data), `revertCache` (restore a snapshot) and
 *   `invalidate` (invalidate queries for `queryKey`).
 */
export default function useOptimisticUpdate<CachedDataType>(
  queryKey: QueryKey,
) {
  const queryClient = useQueryClient();

  /** Collects the key and current data of every cached query matching the base key. */
  const snapshotCache = (): CacheSnapshotType<CachedDataType>[] => {
    const stringArray = z.array(z.string());
    // The hook's own key must have a string-array base; a ZodError here signals
    // a programming error by the caller. Parsed once instead of per cached query.
    const baseKey = stringArray.parse(queryKey[0]);

    return queryClient
      .getQueryCache()
      .getAll()
      .map((query) => query.queryKey)
      .filter((candidateKey) => {
        // Other queries in the shared cache may use unrelated key shapes;
        // skip them rather than throwing and aborting the optimistic update.
        const parsed = stringArray.safeParse(candidateKey[0]);
        if (!parsed.success) return false;
        const segments = parsed.data;
        // Element-wise comparison: `join()` would treat ["a,b"] and ["a", "b"]
        // (or [] and [""]) as equal.
        return (
          segments.length === baseKey.length &&
          segments.every((segment, index) => segment === baseKey[index])
        );
      })
      .map((matchingKey) => ({
        cachedData: queryClient.getQueryData<CachedDataType>(matchingKey),
        queryKey: matchingKey,
      }));
  };

  /**
   * Snapshots all matching queries, then applies `cacheUpdater` to each of them.
   *
   * @param cacheUpdater - Receives the current cached data (possibly
   *   `undefined`) and returns the optimistic replacement.
   * @returns The pre-update snapshot, to be handed to `revertCache` on failure.
   */
  const updateCache = (
    cacheUpdater: (cachedData: CachedDataType | undefined) => CachedDataType,
  ): OptimisticUpdateContextType<CachedDataType> => {
    // Cancellation is started but not awaited, keeping updateCache synchronous.
    // NOTE(review): TanStack docs recommend awaiting cancelQueries before
    // writing optimistic data; doing so would make this async — confirm callers first.
    void queryClient.cancelQueries(queryKey);
    const snapshot = snapshotCache();
    for (const { queryKey: matchingKey } of snapshot) {
      queryClient.setQueryData<CachedDataType>(matchingKey, cacheUpdater);
    }
    return { snapshot };
  };

  /**
   * Writes each snapshot entry's data back under its key.
   *
   * The type parameter is independent of the hook's `CachedDataType` (as in the
   * original signature); it is only renamed so it no longer shadows it.
   */
  const revertCache = <SnapshotDataType>(
    snapshot: CacheSnapshotType<SnapshotDataType>[],
  ) => {
    for (const { cachedData, queryKey: snapshotKey } of snapshot) {
      queryClient.setQueryData<SnapshotDataType>(snapshotKey, cachedData);
    }
  };

  /** Invalidates queries for `queryKey` and resolves when that call settles. */
  const invalidate = async () => {
    await queryClient.invalidateQueries(queryKey);
  };

  return { invalidate, revertCache, updateCache };
}
