import { useEffect } from 'react'
import { getFirstFocusable } from '../getFirstFocusable'

/**
 * Focuses `element` without scrolling it into view. Does nothing when `element` is `null` or `undefined`.
 *
 * Per Selectors Level 4 (https://www.w3.org/TR/selectors-4/#the-focus-visible-pseudo), if the currently
 * focused element matches `:focus-visible` and script moves focus to another element, the newly focused
 * element should also match `:focus-visible`.
 */
const focusElement = (element: HTMLElement | null | undefined) => {
	if (element == null) {
		return
	}
	element.focus({ preventScroll: true })
}

/**
 * Moves keyboard focus into a focus root whenever `isOpen` becomes true.
 *
 * The focus move is deferred with a zero-delay timeout so it runs after the current call stack.
 * If the root closes or the component unmounts before that timeout fires, the effect cleanup
 * cancels it, so no stale focus move happens afterwards.
 *
 * @param focusRootRef - Ref to the container that focus should move into.
 * @param isOpen - Focus is moved only while this is true. The effect re-runs when it changes.
 * @param rootIsTarget - When true, the root element itself is focused, unless focus is already
 *   inside it. Otherwise the first element returned by `getFirstFocusable` for the root is focused.
 */
export const useInitialFocus = (
	focusRootRef: React.RefObject<HTMLDivElement>,
	isOpen: boolean,
	rootIsTarget: boolean
) => {
	useEffect(() => {
		if (!isOpen) {
			return
		}

		const setFocusOnCorrectTarget = () =>
			rootIsTarget ? setInitialFocusOnRoot(focusRootRef) : setInitialFocusOnFirstInteractiveDescendent(focusRootRef)
		const pendingFocusTimeoutId = callFunctionAfterCallstack(setFocusOnCorrectTarget)

		// Cancel the deferred focus if `isOpen` flips or the component unmounts before the timeout runs.
		return () => window.clearTimeout(pendingFocusTimeoutId)
		// NOTE(review): keyed only on `isOpen`, so a change to `rootIsTarget` while already open does not
		// re-trigger the focus move. Confirm this is intended before adding it to the dependency list.
	}, [isOpen])
}

/**
 * Focuses the root element itself, unless the currently active element is already the root or one of
 * its descendants. This avoids pulling focus away from something inside the root.
 */
const setInitialFocusOnRoot = (rootRef: React.RefObject<HTMLDivElement>) => {
	const currentlyActive = document.activeElement
	const root = rootRef.current

	// `Node.contains` also returns true for the node itself, so an already-focused root is left alone.
	if (root?.contains(currentlyActive)) {
		return
	}

	focusElement(root)
}

/**
 * Focuses the element that `getFirstFocusable` returns for the root. If it returns nothing, no focus
 * change happens.
 */
const setInitialFocusOnFirstInteractiveDescendent = (rootRef: React.RefObject<HTMLDivElement>) => {
	focusElement(getFirstFocusable(rootRef.current))
}

/**
 * Schedules `callback` with a zero-delay `window.setTimeout`, so it runs only after the current call
 * stack has finished. Returns the timeout id, which can be passed to `window.clearTimeout`.
 */
const callFunctionAfterCallstack = (callback: () => void) => {
	const deferDelayMs = 0
	return window.setTimeout(callback, deferDelayMs)
}
