<script>
  import { onMount } from 'svelte';
  import {
    env,
    AutoModel,
    AutoProcessor,
    RawImage,
  } from '@huggingface/transformers';

  // ---- Props ------------------------------------------------------------
  // Image URL supplied by the parent; a new value is appended to `images`.
  export let inputImageUrl = '';
  // When true, render the drop zone, the "Clear All" button and the result grid.
  export let showUploadFeature = false;
  // Called with the array of processed image data URLs after a processing run.
  export let onGenerationComplete = (_processedImages) => {};

  // ---- Component state --------------------------------------------------
  let images = []; // source image URLs, in display order
  let processedImages = []; // background-removed data URLs, indexed like `images`
  let isProcessing = false;
  let isLoading = true; // true until model loading has finished or failed
  let error = null; // loading error; when set, shown instead of the UI

  let model; // assigned in onMount
  let processor; // assigned in onMount

  let fileInput; // the hidden <input type="file">, bound in the markup

  // Load the MODNet model and its processor after mount. Images that were
  // queued before loading finished (for example from the initial
  // `inputImageUrl` prop) could not be processed yet, so process them here.
  onMount(async () => {
    try {
      if (!navigator.gpu) {
        throw new Error('WebGPU is not supported in this browser.');
      }
      const model_id = 'Xenova/modnet';
      // Disables onnxruntime-web's wasm worker proxy (presumably so inference
      // runs on this thread — confirm against the transformers.js setup).
      env.backends.onnx.wasm.proxy = false;
      model = await AutoModel.from_pretrained(model_id, {
        device: 'webgpu',
      });
      processor = await AutoProcessor.from_pretrained(model_id);
    } catch (err) {
      error = err;
    } finally {
      isLoading = false;
    }

    if (!error && images.length > 0) {
      await processImages();
    }
  });

  // Last `inputImageUrl` value this component has already reacted to. The
  // statement below reads `images`, so it re-runs whenever the list changes.
  // Without this marker, removing the prop-supplied image (or clearing the
  // list) would immediately add it back.
  let lastHandledInputImageUrl = '';

  // Queue the parent-supplied image whenever the prop takes a new value.
  $: if (inputImageUrl !== lastHandledInputImageUrl) {
    lastHandledInputImageUrl = inputImageUrl;
    if (inputImageUrl && !images.includes(inputImageUrl)) {
      addImage(inputImageUrl);
    }
  }

  /**
   * Handle files from either a drop on the drop zone or a `change` on the
   * hidden file input. All image files are appended in one update, and
   * background removal then runs once.
   * @param {DragEvent | Event} event
   */
  async function onDrop(event) {
    const files = event.dataTransfer
      ? event.dataTransfer.files
      : event.target.files;
    // `accept="image/*"` only filters the file picker; dropped files can be
    // anything. Files with an unknown (empty) MIME type are still tried.
    const urls = Array.from(files ?? [])
      .filter((file) => !file.type || file.type.startsWith('image/'))
      .map((file) => URL.createObjectURL(file));
    if (urls.length === 0) return;

    // A single update plus a single run. Calling addImage per file would start
    // one overlapping processing run per file.
    images = [...images, ...urls];
    await processImages();
  }

  // Append one image URL to the list, then run background removal.
  async function addImage(imageUrl) {
    const nextImages = images.slice();
    nextImages.push(imageUrl);
    images = nextImages; // reassign so Svelte sees the change
    await processImages();
  }

  // Drop the entry at `index` from both arrays so their positions stay paired.
  function removeImage(index) {
    const keep = (_, position) => position !== index;
    images = images.filter(keep);
    processedImages = processedImages.filter(keep);
  }

  // Run the model on one image and return a PNG data URL of that image whose
  // alpha channel is replaced by the predicted mask.
  async function removeBackground(src) {
    const img = await RawImage.fromURL(src);
    const { pixel_values } = await processor(img);
    const { output } = await model({ input: pixel_values });
    // output[0]: presumably the single item of the batch. It is scaled to
    // 0-255 bytes and resized back to the source resolution.
    const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
      img.width,
      img.height,
    );

    const canvas = document.createElement('canvas');
    canvas.width = img.width;
    canvas.height = img.height;
    const ctx = canvas.getContext('2d');
    ctx.drawImage(img.toCanvas(), 0, 0);

    const pixels = ctx.getImageData(0, 0, img.width, img.height);
    // Treats the mask as one byte per pixel, written into each RGBA alpha byte
    // (assumes a single-channel mask — TODO confirm).
    for (let p = 0; p < mask.data.length; ++p) {
      pixels.data[4 * p + 3] = mask.data[p];
    }
    ctx.putImageData(pixels, 0, 0);
    return canvas.toDataURL('image/png');
  }

  /**
   * Run background removal over `images`. After each step, `processedImages[k]`
   * holds the PNG data URL for `images[k]`.
   *
   * - Does nothing until both `model` and `processor` have been assigned.
   * - Reuses results already present in `processedImages`. Appending one image
   *   therefore runs the model only on that image.
   * - If `images` is reassigned while this run awaits the model, the run stops
   *   without committing stale results:
   *   - If the new list only extends the old one, this run defers to the newer
   *     run (addImage() appends and starts its own run).
   *   - Otherwise (remove / clear), this run restarts on the current list.
   * - An image that fails to load or process is logged. Its slot is set to
   *   `null`, so the original image stays on screen.
   *
   * When a run completes, `onGenerationComplete` receives the successfully
   * processed data URLs.
   */
  async function processImages() {
    // `model`/`processor` are assigned asynchronously in onMount. Bail out
    // rather than crash on `processor(img)` while they are still undefined.
    if (!model || !processor) return;

    // `images` is only ever reassigned (never mutated), so identity tells us
    // whether the list changed while we were awaiting.
    const snapshot = images;
    const alreadyProcessed = new Map();
    processedImages.forEach((url, k) => {
      if (url && k < snapshot.length) alreadyProcessed.set(snapshot[k], url);
    });

    isProcessing = true;
    const results = [];
    for (const src of snapshot) {
      let result = alreadyProcessed.get(src) ?? null;
      if (result === null) {
        try {
          result = await removeBackground(src);
        } catch (err) {
          console.error('Failed to remove background: ', err);
        }
        if (images !== snapshot) {
          const isAppend =
            images.length > snapshot.length &&
            snapshot.every((s, k) => images[k] === s);
          return isAppend ? undefined : processImages();
        }
      }
      results.push(result);
      // Commit progressively so finished images show up while others run.
      processedImages = [...results];
    }

    isProcessing = false;
    // Call the callback function with the processed images
    onGenerationComplete(processedImages.filter(Boolean));
  }

  // Empty both the source list and the processed results. Reassigning rather
  // than truncating lets Svelte pick up the change.
  function clearAll() {
    processedImages = [];
    images = [];
  }

  // Fetch the image behind `url` and write it to the system clipboard under
  // its own MIME type. Any error is logged and swallowed.
  async function copyToClipboard(url) {
    try {
      const imageBlob = await (await fetch(url)).blob();
      const item = new ClipboardItem({ [imageBlob.type]: imageBlob });
      await navigator.clipboard.write([item]);
      console.log('Image copied to clipboard');
    } catch (err) {
      console.error('Failed to copy image: ', err);
    }
  }

  // Save `url` as "image.png" by clicking a temporary <a download> element.
  function downloadImage(url) {
    const anchor = Object.assign(document.createElement('a'), {
      href: url,
      download: 'image.png',
    });
    document.body.appendChild(anchor);
    anchor.click();
    anchor.remove();
  }
</script>

{#if error}
  <div class="flex items-center justify-center text-xs text-white bg-black">
    <div class="text-center">
      <h2 class="mb-2 text-4xl">ERROR</h2>
      <p class="text-xl max-w-[500px]">{error.message}</p>
    </div>
  </div>
{:else if isLoading}
  <!-- NOTE(review): `hidden` keeps this loading indicator invisible (and makes
       `items-center justify-center` inert); use `flex` instead of `hidden` if
       the spinner is meant to be shown while the model loads. -->
  <div class="items-center justify-center hidden text-white bg-black">
    <div class="text-center">
      <div
        class="inline-block w-8 h-8 mb-4 border-t-2 border-b-2 border-white rounded-full animate-spin"
      ></div>
      <p class="text-lg">Loading background removal model...</p>
    </div>
  </div>
{:else}
  <div class="p-8 text-white">
    <div class="max-w-6xl mx-auto">
      {#if showUploadFeature}
        <!-- The file input sits outside the drop zone, so the click that
             fileInput.click() dispatches cannot bubble back into the zone's
             own click handler. -->
        <input
          bind:this={fileInput}
          type="file"
          on:change={onDrop}
          accept="image/*"
          multiple
          class="hidden"
        />
        <!-- Drop zone. Clicking it, or pressing Enter/Space while it has
             focus, opens the file picker. -->
        <div
          role="button"
          tabindex="0"
          on:click={() => fileInput.click()}
          on:keydown={(event) => {
            if (event.key === 'Enter' || event.key === ' ') {
              event.preventDefault();
              fileInput.click();
            }
          }}
          on:dragenter|preventDefault
          on:dragover|preventDefault
          on:drop|preventDefault={onDrop}
          class="p-8 mb-8 text-center transition-colors duration-300 ease-in-out border-2 border-gray-700 border-dashed rounded-lg cursor-pointer hover:border-blue-500 hover:bg-blue-900/10"
        >
          <p class="mb-2 text-lg">Drag and drop some images here</p>
          <p class="text-sm text-gray-400">or click to select files</p>
        </div>
        <div class="flex flex-col items-center gap-4 mb-8">
          <div class="flex gap-4">
            <button
              on:click={clearAll}
              class="px-3 py-1 text-sm text-white transition-colors duration-200 bg-red-600 rounded-md hover:bg-red-700 focus:outline-none focus:ring-2 focus:ring-red-500 focus:ring-offset-2 focus:ring-offset-black"
            >
              Clear All
            </button>
          </div>
        </div>

        <div class="grid grid-cols-2 gap-6 md:grid-cols-3 lg:grid-cols-4">
          {#each images as src, index}
            <div class="relative group">
              <img
                src={processedImages[index] || src}
                alt={`Image ${index + 1}`}
                class="object-cover w-full h-48 rounded-lg"
              />
              {#if processedImages[index]}
                <div
                  class="absolute inset-0 flex items-center justify-center transition-opacity duration-300 bg-black rounded-lg opacity-0 bg-opacity-70 group-hover:opacity-100"
                >
                  <button
                    on:click={() =>
                      copyToClipboard(processedImages[index] || src)}
                    class="px-3 py-1 mx-2 text-sm text-gray-900 transition-colors duration-200 bg-white rounded-md hover:bg-gray-200"
                    aria-label={`Copy image ${index + 1} to clipboard`}
                  >
                    Copy
                  </button>
                  <button
                    on:click={() =>
                      downloadImage(processedImages[index] || src)}
                    class="px-3 py-1 mx-2 text-sm text-gray-900 transition-colors duration-200 bg-white rounded-md hover:bg-gray-200"
                    aria-label={`Download image ${index + 1}`}
                  >
                    Download
                  </button>
                </div>
              {/if}
              <button
                on:click={() => removeImage(index)}
                class="absolute flex items-center justify-center w-6 h-6 text-white transition-opacity duration-300 bg-black bg-opacity-50 rounded-full opacity-0 top-2 right-2 group-hover:opacity-100 hover:bg-opacity-70"
                aria-label={`Remove image ${index + 1}`}
              >
                &#x2715;
              </button>
            </div>
          {/each}
        </div>
      {/if}
    </div>
  </div>
{/if}
