import React, { useEffect, useMemo, useRef } from "react"

/**
 * A tracked node as stored by `ScrollTracker`: the caller-supplied key and
 * optional offset, paired with the ref that is attached to the DOM element.
 */
export interface ITrackedNode<El extends HTMLElement> {
  /** Ref to the tracked DOM element; `current` may be null (the tracker checks for this). */
  node: React.RefObject<El>
  /** Identifier used by `scrollToNode`, `refForItem` and reported by `activeNode`. */
  key: string
  /**
   * Pixels added to the reference position when deciding whether this node
   * is active, and subtracted from the scroll target in `scrollToNode`.
   * Treated as 0 when omitted.
   */
  topOffsetPx?: number
}

/**
 * Caller-facing description of a node to track, passed to `useScrollTracker`.
 * The hook pairs each entry with a ref to build an `ITrackedNode`.
 *
 * NOTE(review): the type parameter `El` is not used by any member; it is
 * presumably kept for symmetry with `ITrackedNode` — consider dropping it.
 */
export interface ITrackedNodeArg<El extends HTMLElement> {
  /** Identifier for the node; should be unique within one tracker. */
  key: string
  /** Optional per-node offset in pixels; see `ITrackedNode.topOffsetPx`. */
  topOffsetPx?: number
}

/**
 * Position of the tracker's reference range relative to the tracked
 * container. The numeric values are spelled out so the encoding is explicit.
 */
export enum ScrollTrackerState {
  /** The top of the reference range is at or above the container's top edge. */
  before = 0,
  /** The reference range lies inside the container. */
  within = 1,
  /** The bottom of the reference range is at or past the container's bottom edge. */
  after = 2,
}

/**
 * Subscribes `didScroll` to document `scroll` events and window `resize`
 * events for the lifetime of the calling component.
 *
 * The subscription is recreated only when the identity of `didScroll`
 * changes; pass a stable (e.g. memoized) callback to subscribe just once.
 *
 * @param didScroll - Listener invoked on every scroll or resize event.
 */
export const useScrollListener = (didScroll: () => void) => {
  useEffect(() => {
    document.addEventListener("scroll", didScroll)
    if (typeof window !== "undefined")
      window.addEventListener("resize", didScroll)
    return () => {
      document.removeEventListener("scroll", didScroll)
      if (typeof window !== "undefined")
        window.removeEventListener("resize", didScroll)
    }
    // Without a dependency list the listeners were torn down and re-added
    // after every render.
  }, [didScroll])
}

/**
 * React hook that creates one `ScrollTracker` for the calling component.
 *
 * Attach the returned tracker's `container` ref to the container element and
 * `refForItem(key)` to each tracked element.
 *
 * NOTE(review): the tracker is built from the first render's arguments only;
 * later changes to `trackedNodes`, the offsets or `callback` are not picked up.
 *
 * @param trackedNodes - Nodes to track, in document order.
 * @param args - Remaining `ScrollTracker` options (offsets, anchor, callback).
 * @returns A tracker instance that is reused across re-renders.
 */
export const useScrollTracker = <
  ContainerEl extends HTMLElement,
  TrackedNodeEl extends HTMLElement
>({
  trackedNodes,
  ...args
}: {
  trackedNodes: ITrackedNodeArg<TrackedNodeEl>[]
  referenceTopOffsetPx?: number
  referenceBottomOffsetPx?: number
  referenceAnchor?: number
  callback: () => void
}) => {
  const container = useRef<ContainerEl>(null)
  return useMemo(
    () =>
      new ScrollTracker<ContainerEl, TrackedNodeEl>({
        container,
        // Node refs are created here rather than via one `useRef` per node:
        // hooks must run the same number of times on every render, so a
        // `useRef` loop throws as soon as `trackedNodes` changes length.
        trackedNodes: trackedNodes.map((trackedNode) => ({
          node: React.createRef<TrackedNodeEl>(),
          ...trackedNode,
        })),
        ...args,
      }),
    // `container` is a stable ref object, so the tracker is built once.
    [container]
  )
}

/**
 * Tracks the page's vertical scroll position relative to a container element
 * and a list of tracked nodes inside it.
 *
 * The reference range is
 * `[scrollTop + innerHeight * referenceAnchor + referenceTopOffsetPx,
 *   ... + referenceBottomOffsetPx]` in document coordinates. Each
 * (requestAnimationFrame-throttled) measurement updates `state`
 * (before / within / after the container) and `activeNode`. It then invokes
 * `callback` so the owner can react (e.g. re-render).
 */
export default class ScrollTracker<
  ContainerEl extends HTMLElement,
  TrackedNodeEl extends HTMLElement
> {
  /** How long programmatic scrolls suppress active-node updates, in ms. */
  private static readonly SCROLL_LOCK_MS = 750

  public container: React.RefObject<ContainerEl>
  public trackedNodes: ITrackedNode<TrackedNodeEl>[]
  private referenceTopOffsetPx: number
  private referenceBottomOffsetPx: number
  private referenceAnchor: number
  /** Lazily cached `<html>` element. */
  private _htmlNode: HTMLHtmlElement | null = null

  /** Reported as `before` until the first measurement has run. */
  private _state: ScrollTrackerState = ScrollTrackerState.before
  private _activeNodeKey: string | null = null
  /** True while a requestAnimationFrame measurement is pending. */
  private _ticking = false
  /** True shortly after `scrollToNode`, so the smooth scroll keeps its chosen node. */
  private _ignoreScroll = false
  private _timer: ReturnType<typeof setTimeout> | null = null

  private callback: () => void

  constructor({
    container,
    trackedNodes,
    referenceTopOffsetPx = 0,
    referenceBottomOffsetPx = 0,
    referenceAnchor = 0,
    callback,
  }: {
    container: React.RefObject<ContainerEl>
    trackedNodes: ITrackedNode<TrackedNodeEl>[]
    referenceTopOffsetPx?: number
    referenceBottomOffsetPx?: number
    referenceAnchor?: number
    callback: () => void
  }) {
    this.container = container
    this.trackedNodes = trackedNodes
    this.referenceTopOffsetPx = referenceTopOffsetPx
    this.referenceBottomOffsetPx = referenceBottomOffsetPx
    this.referenceAnchor = referenceAnchor
    this.callback = callback

    // Schedule an initial measurement (no-op when `window` is unavailable).
    this.didScroll()
  }

  /** The document's `<html>` element, or null outside the browser. */
  private get htmlNode(): HTMLHtmlElement | null {
    if (typeof window === "undefined") {
      return null
    }
    if (this._htmlNode == null) {
      this._htmlNode = document.querySelector("html")
    }
    return this._htmlNode
  }

  /** Where the reference range was at the last measurement. */
  public get state(): ScrollTrackerState {
    return this._state
  }

  /** Key of the currently active tracked node, or null if none yet. */
  public get activeNode(): string | null {
    return this._activeNodeKey
  }

  /** `[top, bottom]` of the reference range in document pixels; `[0, 0]` outside the browser. */
  public get currentScrollPosition(): [number, number] {
    const html = this.htmlNode
    if (html == null) {
      return [0, 0]
    }

    const top =
      html.scrollTop +
      window.innerHeight * this.referenceAnchor +
      this.referenceTopOffsetPx
    return [top, top + this.referenceBottomOffsetPx]
  }

  /**
   * Smooth-scrolls so the node with `key` lines up with the reference top,
   * marks it active and notifies `callback`.
   *
   * Does nothing when no node has that key or its element is not mounted.
   */
  public scrollToNode(key: string) {
    if (typeof window === "undefined") return

    const trackedNode = this.trackedNodes.find((n) => n.key === key)
    const domNode = trackedNode?.node.current
    if (trackedNode == null || domNode == null) return

    this.lockScroll()

    // Use the same document-relative measurement as `calculateState`;
    // `offsetTop` is relative to the offsetParent and is wrong for nested layouts.
    window.scrollTo({
      top:
        this.getDocumentTopOf(domNode) -
        this.referenceTopOffsetPx -
        (trackedNode.topOffsetPx ?? 0),
      behavior: "smooth",
    })

    this._activeNodeKey = key
    this.callback()
  }

  /** Suppresses active-node updates for `SCROLL_LOCK_MS`, restarting any pending lock. */
  private lockScroll() {
    this._ignoreScroll = true
    if (this._timer != null) {
      clearTimeout(this._timer)
    }
    this._timer = setTimeout(() => {
      this._ignoreScroll = false
      this._timer = null
    }, ScrollTracker.SCROLL_LOCK_MS)
  }

  /** The ref to attach to the element tracked under `key`, or null if unknown. */
  public refForItem(key: string) {
    return this.trackedNodes.find((n) => n.key === key)?.node ?? null
  }

  private firstNodeKey(): string | null {
    return this.trackedNodes.length > 0 ? this.trackedNodes[0].key : null
  }

  private lastNodeKey(): string | null {
    const count = this.trackedNodes.length
    return count > 0 ? this.trackedNodes[count - 1].key : null
  }

  /**
   * Recomputes `state` and (unless scroll-locked) the active node, then
   * notifies `callback`. Skipped entirely while the container is unmounted.
   */
  private calculateState(): void {
    const container = this.container.current
    if (container == null) {
      return
    }

    const [topScrollPos, bottomScrollPos] = this.currentScrollPosition
    const containerTop = this.getDocumentTopOf(container)

    if (topScrollPos <= containerTop) {
      this._state = ScrollTrackerState.before
      if (!this._ignoreScroll) {
        this._activeNodeKey = this.firstNodeKey()
      }
      this.callback()
      return
    }

    if (bottomScrollPos >= containerTop + container.offsetHeight) {
      this._state = ScrollTrackerState.after
      if (!this._ignoreScroll) {
        this._activeNodeKey = this.lastNodeKey()
      }
      this.callback()
      return
    }

    this._state = ScrollTrackerState.within

    if (!this._ignoreScroll) {
      // The last node (in list order) whose top the reference point has
      // passed becomes active; unmounted nodes are skipped.
      let newActiveNodeKey = this._activeNodeKey
      for (const trackedNode of this.trackedNodes) {
        const domNode = trackedNode.node.current
        if (domNode == null) {
          continue
        }
        const nodeTop = this.getDocumentTopOf(domNode)
        if (topScrollPos + (trackedNode.topOffsetPx ?? 0) >= nodeTop) {
          newActiveNodeKey = trackedNode.key
        }
      }
      this._activeNodeKey = newActiveNodeKey
    }
    this.callback()
  }

  /** Document-relative top edge of `node`: viewport top plus vertical scroll. */
  private getDocumentTopOf(node: HTMLElement): number {
    const scrollY = typeof window !== "undefined" ? window.scrollY : 0
    return node.getBoundingClientRect().top + scrollY
  }

  /**
   * Scroll/resize handler: optionally updates the reference offsets, then
   * schedules at most one measurement per animation frame.
   */
  public didScroll({
    referenceBottomOffsetPx,
    referenceTopOffsetPx,
  }: {
    referenceBottomOffsetPx?: number
    referenceTopOffsetPx?: number
  } = {}): void {
    if (typeof window === "undefined") {
      return
    }

    if (referenceBottomOffsetPx != null) {
      this.referenceBottomOffsetPx = referenceBottomOffsetPx
    }
    if (referenceTopOffsetPx != null) {
      this.referenceTopOffsetPx = referenceTopOffsetPx
    }

    if (this._ticking) {
      return
    }
    // Set the flag before scheduling so a synchronous rAF cannot leave it
    // stuck, and reset it in `finally` so a throwing callback cannot either.
    this._ticking = true
    window.requestAnimationFrame(() => {
      try {
        this.calculateState()
      } finally {
        this._ticking = false
      }
    })
  }
}
