import type {Tensor3D} from '@tensorflow/tfjs-core';
import {Tensor, browser} from '@tensorflow/tfjs-core/dist/base.js';

/**
 * DOM element types accepted as an input image. `createInputImageConvertor`
 * resolves to one of these, converting ImageData and ImageBitmap inputs via
 * the supplied `draw` function first.
 */
export type InputImage =
    | HTMLVideoElement
    | HTMLImageElement
    | HTMLCanvasElement;

import type {
    Size,
    Rect,
    Frame,
    Canvas,
    CanvasContext,
    Callback,
    Runner,
} from '../types';
import {
    createAsyncCallbackLoop,
    subscribeVisibilityChangeEvent,
} from '../utils';

import type {
    MaskUnderlyingType,
    Mask,
    Segmentation,
    Color,
    ImageType,
    ProcessInputType,
} from './types';

/**
 * Copy the whole of `source` onto the canvas, scaled into `destination`.
 *
 * @param canvasContext - Buffer canvas context the frame is drawn on
 * @param source - Frame providing the image and its full width/height
 * @param destination - Rect on the canvas that receives the frame
 */
export const drawFrame = (
    canvasContext: CanvasContext,
    source: Frame,
    destination: Rect,
) => {
    const {x, y, width, height} = destination;
    // Source rect is always the full frame, starting at its origin
    canvasContext.drawImage(
        source.source,
        0,
        0,
        source.width,
        source.height,
        x,
        y,
        width,
        height,
    );
};

/**
 * Draw a frame from `source` onto the canvas (see `drawFrame`) and read the
 * pixels back as ImageData.
 *
 * NOTE(review): the read-back region always starts at the canvas origin and
 * uses only `destination.width`/`height`; a non-zero `destination.x`/`y` is
 * not applied to it. Confirm this is intended.
 *
 * @param canvasContext - Buffer canvas context used for drawing the frame
 * @param source - Frame to draw
 * @param destination - Rect the frame is drawn into; its size is also the
 * size of the returned ImageData
 * @returns Pixels of the region (0, 0, destination.width, destination.height)
 */
export const takeFrame = (
    canvasContext: CanvasContext,
    source: Frame,
    destination: Rect,
) => {
    drawFrame(canvasContext, source, destination);
    const {width, height} = destination;
    return canvasContext.getImageData(0, 0, width, height);
};

/**
 * Paint `data` onto the canvas with its top-left corner at the canvas origin.
 * Thin wrapper over `putImageData(data, 0, 0)`.
 *
 * @param canvasContext - The canvas context to draw on
 * @param data - The pixels to draw
 */
export const drawImageData = (
    canvasContext: CanvasContext,
    data: ImageData,
) => {
    const originX = 0;
    const originY = 0;
    canvasContext.putImageData(data, originX, originY);
};

/**
 * Create a `<canvas>` element with the given pixel dimensions.
 *
 * @param width - Value assigned to `canvas.width`
 * @param height - Value assigned to `canvas.height`
 */
export const createCanvas = (width: number, height: number): HTMLCanvasElement =>
    Object.assign(document.createElement('canvas'), {width, height});

/**
 * Create an OffscreenCanvas with the given dimensions. If constructing one
 * throws (for example because the runtime has no OffscreenCanvas), a regular
 * `<canvas>` element of the same size is returned instead.
 *
 * @param width - Canvas width in pixels
 * @param height - Canvas height in pixels
 */
export const createOffscreenCanvas = (width: number, height: number) => {
    let canvas: OffscreenCanvas | HTMLCanvasElement;
    try {
        canvas = new OffscreenCanvas(width, height);
    } catch {
        canvas = createCanvas(width, height);
    }
    return canvas;
};

/**
 * Create a muted `<video>` element with the given display size.
 *
 * @param width - Value assigned to `video.width`
 * @param height - Value assigned to `video.height`
 */
export const createVideoElement = (width: number, height: number): HTMLVideoElement =>
    Object.assign(document.createElement('video'), {
        width,
        height,
        muted: true,
    });

/**
 * Attach a media provider to a video element.
 *
 * `srcObject` is tried first. If the browser rejects the provider there, a
 * MediaSource or Blob is played through an object URL instead. For a
 * MediaSource, that URL is revoked once the source ends. Any other provider
 * has no fallback, so the original error is rethrown.
 *
 * @see {@link https://developer.mozilla.org/en-US/docs/Web/API/HTMLMediaElement/srcObject}
 *
 * @param video - The element to attach the provider to
 * @param src - A MediaStream, MediaSource or Blob
 * @throws Whatever `srcObject` threw, when `src` is neither a MediaSource nor a
 * Blob
 */
export const setVideoElementSrc = (
    video: HTMLVideoElement,
    src: MediaProvider,
) => {
    try {
        video.srcObject = src;
    } catch (error: unknown) {
        // An unsupported provider type is reported as a TypeError (WebIDL
        // conversion failure). That is precisely the case the object-URL
        // fallback exists for, so the error type is not inspected here.
        if ('MediaSource' in window && src instanceof MediaSource) {
            const url = URL.createObjectURL(src);
            src.addEventListener(
                'sourceended',
                () => URL.revokeObjectURL(url),
                {once: true},
            );
            video.src = url;
            return;
        }
        if (src instanceof Blob) {
            video.src = URL.createObjectURL(src);
            return;
        }
        throw error;
    }
};

/**
 * Wrap a MediaStream in a new inline-playing video element created by
 * `createVideoElement`.
 *
 * The element is sized from the first video track's reported settings. Any
 * dimension the track does not report falls back to the given width/height.
 *
 * @param input - The stream to play
 * @param width - Fallback width
 * @param height - Fallback height
 */
export const toVideoElement = (
    input: MediaStream,
    width: number,
    height: number,
): HTMLVideoElement => {
    const allTrackSettings = input
        .getVideoTracks()
        .map(track => track.getSettings());
    const settings: MediaTrackSettings = allTrackSettings[0] ?? {};

    const video = createVideoElement(
        settings.width ?? width,
        settings.height ?? height,
    );
    video.playsInline = true;
    setVideoElementSrc(video, input);
    return video;
};

/**
 * Values of `HTMLMediaElement.readyState`, from "nothing loaded" (0) to
 * "enough data to play through" (4).
 */
const HTMLMediaElementReadyState = {
    HaveNothing: 0,
    HaveMetaData: 1,
    HaveCurrentData: 2,
    HaveFutureData: 3,
    HaveEnoughData: 4,
} as const;

/**
 * Start playback of a video element, unless it is set to autoplay.
 *
 * If the element has not loaded its current frame yet (`readyState` below
 * HaveCurrentData), first wait for `loadeddata`, but for at most 3 seconds so
 * a stalled source cannot block the caller forever. Whichever of the two
 * happens first also tears down the other, so no listener or timer is left
 * behind.
 *
 * @param video - The element to play
 * @returns Resolves when `video.play()` resolves (immediately for autoplay
 * videos); rejects if `video.play()` rejects
 */
export const playVideo = async (video: HTMLVideoElement) => {
    if (video.autoplay) {
        return;
    }
    if (video.readyState < HTMLMediaElementReadyState.HaveCurrentData) {
        await new Promise<void>(resolve => {
            const finish = () => {
                clearTimeout(timer);
                video.removeEventListener('loadeddata', finish);
                resolve();
            };
            // Fallback in case `loadeddata` never arrives
            const timer = setTimeout(finish, 3000);
            video.addEventListener('loadeddata', finish);
        });
    }
    return await video.play();
};

/**
 * Point an image element at `src` and wait for it to load.
 *
 * @param image - The element to load into; its `onload`/`onerror` handlers
 * are replaced
 * @param src - URL to load
 * @returns Resolves on `load`; rejects with the `error` event
 */
export const loadImage = async (
    image: HTMLImageElement,
    src: string,
): Promise<void> => {
    const loaded = new Promise<void>((resolve, reject) => {
        image.onload = () => resolve();
        image.onerror = reject;
    });
    // Load events are dispatched asynchronously, so the handlers are in place
    // either way; setting src last simply reads in the natural order.
    image.src = src;
    return loaded;
};

/**
 * Normalise a size that is either a plain number or an SVG animated length
 * (as found on SVG elements' `width`/`height`) to a number.
 *
 * Discriminates with `typeof` rather than `instanceof SVGAnimatedLength`.
 * That constructor does not exist in every runtime (e.g. workers), and
 * referencing it there throws a ReferenceError even for plain numbers.
 *
 * @param value - A number, or an SVGAnimatedLength whose base value is read
 */
export const toNumber = (value: number | SVGAnimatedLength) =>
    typeof value === 'number' ? value : value.baseVal.value;

/**
 * Get the 2D rendering context of a canvas.
 *
 * @param canvas - Canvas element or OffscreenCanvas
 * @param options - Passed through to `getContext('2d', options)`
 * @throws Error when the canvas cannot provide a 2D context
 */
export const getCanvasRenderingContext2D = (
    canvas: Canvas,
    options?: CanvasRenderingContext2DSettings,
) => {
    const context = canvas.getContext('2d', options);
    if (context) {
        return context as CanvasRenderingContext2D;
    }
    throw new Error('Cannot get CanvasRenderingContext2D');
};

/**
 * Determine the pixel size of an image-like input.
 *
 * Checked in this order:
 * 1. Tensor: the first two dimensions are treated as height and width.
 * 2. VideoFrame: its display size.
 * 3. The element's rendered layout size (`offsetWidth`/`offsetHeight`),
 *    when non-zero.
 * 4. Its `width`/`height` properties, when non-zero. SVG animated lengths are
 *    unwrapped.
 * 5. For a video element, its intrinsic `videoWidth`/`videoHeight`. This
 *    covers videos that have neither size attributes nor layout.
 *
 * @param image - The input to measure
 * @throws Error('Unknown input image') when none of the above yields a
 * non-zero size
 */
export const getImageSize = (
    image:
        | VideoFrame
        | CanvasImageSource
        | OffscreenCanvas
        | ImageData
        | SVGImageElement
        | Tensor3D,
): Size => {
    if (image instanceof Tensor) {
        const [height = 0, width = 0] = image.shape.slice(0, 2);
        return {height, width};
    }
    // `globalThis` instead of `window`: `window` is undefined in workers
    if ('VideoFrame' in globalThis && image instanceof VideoFrame) {
        return {height: image.displayHeight, width: image.displayWidth};
    }
    if (
        'offsetHeight' in image &&
        image.offsetHeight !== 0 &&
        'offsetWidth' in image &&
        image.offsetWidth !== 0
    ) {
        return {height: image.offsetHeight, width: image.offsetWidth};
    }
    if (
        'height' in image &&
        image.height !== 0 &&
        'width' in image &&
        image.width !== 0
    ) {
        return {height: toNumber(image.height), width: toNumber(image.width)};
    }
    // A video's `width`/`height` reflect its attributes only (0 when unset);
    // its intrinsic size is available once metadata has loaded.
    if (
        'videoWidth' in image &&
        image.videoWidth !== 0 &&
        image.videoHeight !== 0
    ) {
        return {height: image.videoHeight, width: image.videoWidth};
    }
    throw new Error('Unknown input image');
};

/**
 * Copy an ImageData, SVG image, OffscreenCanvas or tensor into a new
 * `<canvas>` element sized by `getImageSize`.
 *
 * Drawable sources are scaled into that measured size with the 5-argument
 * `drawImage`. A 3-argument `drawImage` draws at the source's natural size,
 * which for an SVG image can differ from its width/height attributes and
 * would crop or pad the result.
 *
 * @param image - The input to convert
 * @returns A new canvas holding the input's pixels
 */
export const toHTMLCanvasElementLossy = async (
    image: ImageData | SVGImageElement | OffscreenCanvas | Tensor3D,
) => {
    const {width, height} = getImageSize(image);
    const canvas = createCanvas(width, height);

    if (image instanceof Tensor) {
        await browser.toPixels(image, canvas);
        return canvas;
    }

    const context = getCanvasRenderingContext2D(canvas);
    if (image instanceof ImageData) {
        // Raw pixel copy; ImageData's size is exactly what was measured
        context.putImageData(image, 0, 0);
    } else {
        context.drawImage(image, 0, 0, width, height);
    }

    return canvas;
};

/**
 * Read the pixels of a drawable source or tensor as ImageData sized by
 * `getImageSize`.
 *
 * Drawable sources are scaled into the measured size. `getImageSize` may
 * report an element's rendered (layout) size, which can differ from its
 * natural size. Drawing at natural size with `drawImage(image, 0, 0)` would
 * then crop the result to the top-left corner.
 *
 * @param image - The input to read
 * @returns ImageData of size `getImageSize(image)`
 */
export const toImageDataLossy = async (image: CanvasImageSource | Tensor3D) => {
    const {width, height} = getImageSize(image);
    if (image instanceof Tensor) {
        return new ImageData(await browser.toPixels(image), width, height);
    }
    const canvas = createOffscreenCanvas(width, height);
    const context = getCanvasRenderingContext2D(canvas);
    context.drawImage(image, 0, 0, width, height);
    return context.getImageData(0, 0, canvas.width, canvas.height);
};

/**
 * Convert a drawable source or ImageData into a 4-channel tensor with
 * `browser.fromPixels`.
 *
 * SVG images and OffscreenCanvas are first copied onto a `<canvas>` element.
 * Their constructors are only referenced after checking they exist: a bare
 * `instanceof OffscreenCanvas` throws a ReferenceError in runtimes without
 * OffscreenCanvas (which `createOffscreenCanvas` in this file also allows
 * for), even when the input is something else entirely.
 *
 * @param image - The input to convert
 */
export const toTensorLossy = async (image: CanvasImageSource | ImageData) => {
    const isSVGImage = (value: typeof image): value is SVGImageElement =>
        typeof SVGImageElement !== 'undefined' &&
        value instanceof SVGImageElement;
    const isOffscreenCanvas = (value: typeof image): value is OffscreenCanvas =>
        typeof OffscreenCanvas !== 'undefined' &&
        value instanceof OffscreenCanvas;

    const pixelsInput =
        isSVGImage(image) || isOffscreenCanvas(image)
            ? await toHTMLCanvasElementLossy(image)
            : image;
    return browser.fromPixels(pixelsInput, 4);
};

/**
 * Build a Segmentation whose mask converts `data` on demand.
 *
 * Each converter returns `data` itself when it already has the requested form
 * (canvas element, ImageData, tensor). Otherwise it produces a converted copy.
 *
 * @param data - Mask data in any supported representation
 * @param type - Value reported by `mask.getUnderlyingType()`
 * @param maskValueToLabel - Maps a mask value to its label
 */
export const toSegmentation = (
    data: CanvasImageSource | ImageData | Tensor3D,
    type: MaskUnderlyingType,
    maskValueToLabel: (value: number) => string,
): Segmentation => {
    const toCanvasImageSource = async () => {
        if (data instanceof HTMLCanvasElement) {
            return data;
        }
        if (
            !(
                data instanceof HTMLVideoElement ||
                data instanceof ImageBitmap ||
                data instanceof HTMLImageElement
            )
        ) {
            return await toHTMLCanvasElementLossy(data);
        }
        // Elements and bitmaps are redrawn onto a fresh canvas of their size
        const {width, height} = data;
        const canvas = createCanvas(width, height);
        canvas.getContext('2d')?.drawImage(data, 0, 0, width, height);
        return canvas;
    };

    const mask: Mask = {
        toCanvasImageSource,
        toImageData: async () =>
            data instanceof ImageData ? data : await toImageDataLossy(data),
        toTensor: async () =>
            data instanceof Tensor ? data : await toTensorLossy(data),
        getUnderlyingType: () => type,
    };
    return {maskValueToLabel, mask};
};
/**
 * Mirror subsequent drawing on the canvas horizontally by modifying its 2D
 * context's transform. The change is cumulative, so calling this twice undoes
 * the flip.
 *
 * @param canvas - Canvas whose 2D context transform is modified
 */
export const flipCanvasHorizontal = (canvas: Canvas) => {
    const context = getCanvasRenderingContext2D(canvas);
    // Negate x, then shift the origin back so x maps to (width - x)
    context.scale(-1, 1);
    context.translate(-canvas.width, 0);
};

/**
 * Paint every pixel of the (2 * radius + 1) square around (row, column) with
 * `color`, except the centre pixel itself.
 *
 * The previous `i !== 0 && j !== 0` test skipped the entire centre row and
 * column, so a radius-1 stroke only painted the four diagonal pixels.
 * Neighbours outside the image are skipped, rather than written out of range
 * or wrapped onto an adjacent row.
 *
 * @param bytes - RGBA pixel buffer, 4 bytes per pixel, row-major
 * @param row - Row of the centre pixel
 * @param column - Column of the centre pixel
 * @param width - Image width in pixels; the height is derived from
 * `bytes.length`
 * @param radius - Half-size of the square
 * @param color - Stroke colour (cyan by default)
 */
export const drawStroke = (
    bytes: Uint8ClampedArray,
    row: number,
    column: number,
    width: number,
    radius: number,
    color: Color = {
        r: 0,
        g: 255,
        b: 255,
        a: 255,
    },
    // eslint-disable-next-line max-params -- avoid unnecessary object creation
) => {
    const height = Math.floor(bytes.length / (4 * width));
    for (let i = -radius; i <= radius; i++) {
        const y = row + i;
        if (y < 0 || y >= height) {
            continue;
        }
        for (let j = -radius; j <= radius; j++) {
            const x = column + j;
            if ((i === 0 && j === 0) || x < 0 || x >= width) {
                continue;
            }
            const offset = 4 * (y * width + x);
            bytes[offset + 0] = color.r;
            bytes[offset + 1] = color.g;
            bytes[offset + 2] = color.b;
            bytes[offset + 3] = color.a;
        }
    }
};

/**
 * Whether any neighbour of (row, column) within `radius` is background.
 *
 * A neighbour is background when its first channel is not a foreground id,
 * or its alpha is below `alphaCutoff`. All neighbours except the centre are
 * inspected. The previous `i !== 0 && j !== 0` test looked only at diagonals,
 * so a background pixel directly above, below, left or right was missed.
 * Neighbours outside the image are ignored rather than wrapped onto an
 * adjacent row.
 *
 * @param data - RGBA mask buffer, 4 bytes per pixel, row-major
 * @param row - Row of the pixel under test
 * @param column - Column of the pixel under test
 * @param width - Image width in pixels
 * @param isForegroundId - Lookup table indexed by first-channel value
 * @param alphaCutoff - Minimum alpha for a pixel to count as foreground
 * @param radius - Neighbourhood half-size (default 1, i.e. the 8 neighbours)
 */
export const isSegmentationBoundary = (
    data: Uint8ClampedArray,
    row: number,
    column: number,
    width: number,
    isForegroundId: boolean[],
    alphaCutoff: number,
    radius = 1,
    // eslint-disable-next-line max-params -- avoid unnecessary object creation
): boolean => {
    for (let i = -radius; i <= radius; i++) {
        for (let j = -radius; j <= radius; j++) {
            const x = column + j;
            if ((i === 0 && j === 0) || x < 0 || x >= width) {
                continue;
            }
            const n = (row + i) * width + x;
            // Out-of-range rows read as undefined and are skipped below
            const foregroundColor = data[4 * n];
            const alphaColor = data[4 * n + 3];
            if (
                (foregroundColor !== undefined &&
                    !isForegroundId[foregroundColor]) ||
                (alphaColor !== undefined && alphaColor < alphaCutoff)
            ) {
                return true;
            }
        }
    }
    return false;
};

/**
 * Render one or more segmentations into a single binary RGBA mask.
 *
 * A pixel is painted `foreground` when, in any mask, its first channel is one
 * of `foregroundMaskValues` and its alpha is at least
 * `round(255 * foregroundThreshold)`. Otherwise it is painted `background`.
 * All masks are read using the size of the first one.
 *
 * Contours are drawn in a second pass, after every pixel has been filled.
 * Drawing them during the fill would let the fill of the pixels still ahead
 * (to the right and below) overwrite the freshly drawn stroke.
 *
 * @param segmentation - A segmentation or a list of segmentations
 * @param foreground - Colour for foreground pixels (transparent black default)
 * @param background - Colour for background pixels (opaque black default)
 * @param drawContour - Stroke the boundary of foreground regions (stroke
 * colour is `drawStroke`'s default)
 * @param foregroundThreshold - Minimum alpha, as a fraction of 255, for a
 * pixel to count as foreground
 * @param foregroundMaskValues - First-channel values treated as foreground
 * (every byte value by default)
 * @returns The mask as ImageData, or null when there is nothing to render
 */
export const toBinaryMask = async (
    segmentation: Segmentation | Segmentation[],
    foreground: Color = {
        r: 0,
        g: 0,
        b: 0,
        a: 0,
    },
    background: Color = {
        r: 0,
        g: 0,
        b: 0,
        a: 255,
    },
    drawContour = false,
    foregroundThreshold = 0.5,
    foregroundMaskValues = Array.from(Array(256).keys()),
    // eslint-disable-next-line max-params -- avoid unnecessary object creation
) => {
    const segmentations = Array.isArray(segmentation)
        ? segmentation
        : [segmentation];
    if (segmentations.length === 0) {
        return null;
    }

    const masks = await Promise.all(
        segmentations.map(item => item.mask.toImageData()),
    );
    const [imageData] = masks;
    if (!imageData) {
        return null;
    }
    const {width, height} = imageData;
    const bytes = new Uint8ClampedArray(width * height * 4);
    const alphaCutoff = Math.round(255 * foregroundThreshold);
    const isForegroundId = new Array<boolean>(256).fill(false);
    foregroundMaskValues.forEach(id => (isForegroundId[id] = true));

    // Whether pixel `n` of `mask` has a foreground id and enough alpha
    const isForegroundPixel = (mask: ImageData, n: number) => {
        const maskForegroundColor = mask.data[4 * n];
        const maskAlphaColor = mask.data[4 * n + 3];
        return (
            maskForegroundColor !== undefined &&
            isForegroundId[maskForegroundColor] === true &&
            maskAlphaColor !== undefined &&
            maskAlphaColor >= alphaCutoff
        );
    };

    // Whether pixel `n` is foreground in at least one mask
    const isForegroundInAnyMask = (n: number) => {
        for (const mask of masks) {
            if (isForegroundPixel(mask, n)) {
                return true;
            }
        }
        return false;
    };

    // Whether (row, column) is a foreground boundary pixel in at least one mask
    const isContourPixel = (row: number, column: number) => {
        const n = row * width + column;
        for (const mask of masks) {
            if (
                isForegroundPixel(mask, n) &&
                isSegmentationBoundary(
                    mask.data,
                    row,
                    column,
                    width,
                    isForegroundId,
                    alphaCutoff,
                )
            ) {
                return true;
            }
        }
        return false;
    };

    // Pass 1: fill every pixel with the foreground or background colour
    for (let n = 0; n < width * height; n++) {
        const color = isForegroundInAnyMask(n) ? foreground : background;
        bytes[4 * n + 0] = color.r;
        bytes[4 * n + 1] = color.g;
        bytes[4 * n + 2] = color.b;
        bytes[4 * n + 3] = color.a;
    }

    // Pass 2: stroke boundaries; edge pixels are skipped so the radius-1
    // neighbourhood stays inside the image
    if (drawContour) {
        for (let i = 1; i + 1 < height; i++) {
            for (let j = 1; j + 1 < width; j++) {
                if (isContourPixel(i, j)) {
                    drawStroke(bytes, i, j, width, 1);
                }
            }
        }
    }

    return new ImageData(bytes, width, height);
};

type Draw = (input: ImageType) => Promise<Canvas>;
/**
 * Build a converter that turns any processing input into an `InputImage`.
 * ImageData and ImageBitmap inputs are rendered through `draw`. Every other
 * input is passed through unchanged.
 *
 * @param draw - Renders an ImageData/ImageBitmap onto a canvas
 */
export const createInputImageConvertor = (draw: Draw) => {
    const convert = async (input: ProcessInputType): Promise<InputImage> => {
        if (input instanceof ImageData || input instanceof ImageBitmap) {
            const drawn = await draw(input);
            return drawn as InputImage;
        }
        return input as InputImage;
    };
    return convert;
};

/**
 * Return `prop` if it is truthy, otherwise throw.
 *
 * Note that the check is truthiness, so `0`, `''` and `false` are rejected
 * just like `null`/`undefined`.
 *
 * @param prop - Value that must be present
 * @param message - Error message used when `prop` is falsy
 * @throws Error(message) when `prop` is falsy
 */
export const ensure = <T>(
    prop: T,
    message = 'Processor is not opened, please call open() method first',
) => {
    if (prop) {
        return prop;
    }
    throw new Error(message);
};

/**
 * Type guard: true when `input` is an HTMLVideoElement and the runtime's
 * HTMLVideoElement prototype provides `requestVideoFrameCallback`.
 *
 * @param input - Value to check
 */
const hasRequestVideoFrameCallback = (
    input: unknown,
): input is HTMLVideoElement => {
    if (typeof input !== 'object' || !(input instanceof HTMLVideoElement)) {
        return false;
    }
    return 'requestVideoFrameCallback' in HTMLVideoElement.prototype;
};

/**
 * Estimate the frame rate of an HTMLVideoElement input, or return the
 * fallback rate.
 * @see {@link https://developer.mozilla.org/en-US/docs/Web/API/HTMLVideoElement/getVideoPlaybackQuality}
 *
 * `totalVideoFrames` and `droppedVideoFrames` are cumulative counts. The
 * previous implementation returned that count itself as the "rate", which
 * grows without bound and is 0 before playback. Here the presented-frame
 * count is divided by the elapsed media time (`currentTime`) instead.
 * NOTE(review): this is only an estimate and may drift after seeking or
 * looping, if the counters are not reset by those operations.
 *
 * @param input - The input to estimate the rate for
 * @param frameRate - Returned when the input is not a video, the API is
 * unavailable, or no usable estimate exists yet (e.g. nothing has played)
 */
const getFrameRate = (
    input: ProcessInputType | undefined,
    frameRate: number,
) => {
    if (
        input instanceof HTMLVideoElement &&
        typeof input.getVideoPlaybackQuality === 'function'
    ) {
        const {totalVideoFrames, droppedVideoFrames} =
            input.getVideoPlaybackQuality();
        const elapsedSeconds = input.currentTime;
        if (elapsedSeconds > 0) {
            const estimated =
                (totalVideoFrames - droppedVideoFrames) / elapsedSeconds;
            if (Number.isFinite(estimated) && estimated > 0) {
                return estimated;
            }
        }
    }
    return frameRate;
};

/**
 * Optional overrides accepted by `createFrameCallbackRequest`.
 */
interface FrameCallbackRequestOptions {
    /**
     * Function used to subscribe to the DOM `visibilitychange` event.
     * Defaults to `subscribeVisibilityChangeEvent`.
     * @see {@link subscribeVisibilityChangeEvent}
     */
    subscribeVisibilityChange?: typeof subscribeVisibilityChangeEvent;
}

/**
 * Create a callback loop for video frame processing.
 *
 * When the input is an HTMLVideoElement supporting `requestVideoFrameCallback`,
 * that API drives the loop. While the page is hidden, the fallback loop takes
 * over, because `requestVideoFrameCallback` is paused then. Any other input
 * uses the fallback loop from `createAsyncCallbackLoop`.
 *
 * @param callback - Called with the input for each processed frame
 * @param frameRate - Fallback frame rate used when the rate cannot be
 * estimated from the input
 * @param options - Optional overrides, see `FrameCallbackRequestOptions`
 * @returns A controller:
 * - `start(input)` begins the loop. Any loop left by a previous `start` is
 *   ended first. With `requestVideoFrameCallback` it resolves after the first
 *   frame was processed, or when `stop()` ends the loop before that.
 * - `stop()` ends the loop. A frame callback already running is allowed to
 *   finish but is not re-armed.
 * - `frameRate` reads the estimated rate (falling back to the configured one)
 *   and sets the configured fallback rate.
 */
export const createFrameCallbackRequest = (
    callback: Callback<Promise<void>, [ProcessInputType]>,
    frameRate: number,
    {
        subscribeVisibilityChange = subscribeVisibilityChangeEvent,
    }: FrameCallbackRequestOptions = {},
) => {
    const props: {
        input?: ProcessInputType;
        callbackId: number;
        started: boolean;
        runner?: Runner<[ProcessInputType]>;
        frameRate: number;
        unsubscribe?: ReturnType<typeof subscribeVisibilityChange>;
        // Bumped whenever a requestVideoFrameCallback loop is torn down, so a
        // frame callback that was already running can tell it must not re-arm
        session: number;
        // Settles a start() that is still waiting for its first frame
        settleStart?: () => void;
    } = {callbackId: 0, frameRate, started: false, session: 0};

    const getCallbackLoopRunner = (input: ProcessInputType) => {
        if (!props.runner) {
            const fallbackRequestCallback = async (input: ProcessInputType) => {
                if (props.runner) {
                    props.runner.frameRate = getFrameRate(
                        input,
                        props.frameRate,
                    );
                }
                await callback(input);
            };
            props.runner = createAsyncCallbackLoop(
                fallbackRequestCallback,
                getFrameRate(input, props.frameRate),
            );
        }
        return props.runner;
    };

    const cancelFrameCallback = () => {
        if (hasRequestVideoFrameCallback(props.input) && props.callbackId) {
            props.input.cancelVideoFrameCallback(props.callbackId);
            props.callbackId = 0;
        }
    };

    // Tear down the current requestVideoFrameCallback loop: invalidate its
    // session, cancel the pending frame request, drop the visibility
    // subscription and settle a start() still awaiting its first frame.
    const endSession = () => {
        props.session += 1;
        cancelFrameCallback();
        props.unsubscribe?.();
        props.unsubscribe = undefined;
        props.settleStart?.();
        props.settleStart = undefined;
    };

    return {
        start: async (input: ProcessInputType) => {
            // Must run before props.input changes: cancellation targets the old input
            endSession();
            props.input = input;
            if (hasRequestVideoFrameCallback(input)) {
                const session = props.session;
                // Subscribe visibility changes and switch to fallback callback
                // loop when the page is hidden since the
                // `requestVideoFrameCallback` is paused when the page is hidden
                props.unsubscribe = subscribeVisibilityChange(async hidden => {
                    if (props.started && props.input) {
                        if (hidden && props.callbackId) {
                            await getCallbackLoopRunner(props.input).start(
                                props.input,
                            );
                        } else {
                            getCallbackLoopRunner(props.input).stop();
                        }
                    }
                });
                await new Promise<void>(resolve => {
                    props.settleStart = () => resolve();
                    const wrap = async () => {
                        await callback(input);
                        if (session !== props.session) {
                            // stop() or a newer start() ended this loop while
                            // the callback was running: do not re-arm
                            resolve();
                            return;
                        }
                        props.started = true;
                        resolve();
                        props.callbackId =
                            input.requestVideoFrameCallback(wrap);
                    };
                    props.callbackId = input.requestVideoFrameCallback(wrap);
                });
            } else {
                await getCallbackLoopRunner(input).start(input);
                props.started = true;
            }
        },
        stop: () => {
            endSession();
            props.runner?.stop();
            props.started = false;
        },
        get frameRate() {
            return getFrameRate(props.input, props.frameRate);
        },
        set frameRate(value) {
            props.frameRate = value;
            if (props.runner) {
                props.runner.frameRate = value;
            }
        },
    };
};
