import type { Cell, TableCellNode } from '@lexical/table'
import {
	$getTableRowIndexFromTableCellNode,
	$getTableRowNodeFromTableCellNodeOrThrow,
	$getTableNodeFromLexicalNodeOrThrow,
	$isTableRowNode,
	$isTableCellNode,
} from '@lexical/table'
import type { LexicalEditor, LexicalNode } from 'lexical'
import { $getNearestNodeFromDOMNode } from 'lexical'
import { $assertIsRowNode, $assertIsCellNode } from '../../utils/tableUtils'
import type { CustomTableCellNode } from '../../nodes/CustomTableCellNode'
import { $isCustomTableCellNode } from '../../nodes/CustomTableCellNode'

/**
 * Minimal structural view of a table: its children are the table's rows
 * (each child is passed to a row assertion before use).
 */
export interface ITableNode {
	getChildren: () => unknown[]
}

/**
 * Minimal structural view of a table row: its children are the row's cells
 * (each child is passed to a cell assertion before use).
 */
export interface ITableRowNode {
	getChildren: () => unknown[]
}

/**
 * Minimal structural view of a table cell. The span helpers in this module only read
 * its column span and row span.
 */
export interface ITableCellNode {
	getColSpan: () => number
	getRowSpan: () => number
}

// Lower bound applied to a resized cell's width in getNewCellWidth. Presumably CSS pixels,
// since the widths it is combined with come from getBoundingClientRect(); TODO confirm.
const MIN_COLUMN_WIDTH = 50

/**
 * Builds a function that maps a (row, global column) pair to cell indices within that row.
 *
 * Cell `k` of row `j` is keyed under global column `k + spanLookupTable[j][k]`. With a
 * right-edge lookup table (see modifySpanLookupTableToUseRightEdge), that is the cell's
 * right-most column. For a row `j` and column `i`, the returned function gives:
 * - `[k, k + 1]` when some cell `k` is keyed under column `i`;
 * - otherwise a two-slot array whose index 0 is an unset hole (reads as `undefined`).
 *   Index 1 holds the cell keyed under column `i + 1`, if any. When there is none, index 1
 *   holds 0 if column `i + 1` is inside the table and not marked in `spanPunchCard`, and
 *   `undefined` otherwise.
 *
 * NOTE(review): the pair presumably means [cell left of a dragged border, cell right of it];
 * confirm against callers.
 *
 * @param table table whose column count is measured with `getWidth`
 * @param spanPunchCard per-row grid where 1 marks a slot covered by another cell's span
 * @param spanLookupTable per-row offsets from a cell's index to its global column
 * @param getWidth computes the table's column count
 */
export function $createGetLocalIndexFromGlobalIndex(
	table: ITableNode,
	spanPunchCard: number[][],
	spanLookupTable: number[][],
	getWidth: (tableNode: ITableNode) => number = createGetTableWidth($assertIsRowNode, $assertIsCellNode)
) {
	const width = getWidth(table)

	// reverseLookupTable[row][globalColumn] = index of the cell in `row` keyed under `globalColumn`
	const reverseLookupTable: number[][] = []
	for (let row = 0; row < spanLookupTable.length; row++) {
		const offsets = spanLookupTable[row]
		const reverseRow: number[] = []
		reverseLookupTable[row] = reverseRow
		for (let cellIndex = 0; cellIndex < width; cellIndex++) {
			const offset = offsets[cellIndex]
			if (offset !== undefined) {
				reverseRow[cellIndex + offset] = cellIndex
			}
		}
	}

	return function (j: number, i: number) {
		const rowLookup = reverseLookupTable[j]
		const exact = rowLookup[i]
		if (exact !== undefined) {
			return [exact, exact + 1]
		}
		const nextColumn = i + 1
		const isOpenColumn = spanPunchCard[j][nextColumn] !== 1 && nextColumn < width
		const following = rowLookup[nextColumn]
		// fill(value, 1) only writes index 1, so index 0 stays a hole
		return Array<number>(2).fill(isOpenColumn ? (following ?? 0) : following, 1)
	}
}

/**
 * Creates `makeSpanPunchCard` from injectable node assertions and a width function.
 *
 * The produced function returns one array per table row. A slot is set to 1 when it is covered
 * by some cell's row/column span but is not that cell's own first (top-left) slot. All other
 * slots are left unset. An empty table yields `[]`.
 *
 * @param assertIsRowNode narrows a table child to a row (expected to throw otherwise)
 * @param assertIsCellNode narrows a row child to a cell (expected to throw otherwise)
 * @param getTableWidth computes the table's column count
 */
export function createMakeSpanPunchCard(
	assertIsRowNode: (node: any) => asserts node is ITableRowNode = $assertIsRowNode,
	assertIsCellNode: (node: any) => asserts node is ITableCellNode = $assertIsCellNode,
	getTableWidth: (tableNode: ITableNode) => number = createGetTableWidth(assertIsRowNode, assertIsCellNode)
) {
	return function makeSpanPunchCard(tableNode: ITableNode): number[][] {
		const rows = tableNode.getChildren()
		if (rows.length === 0) {
			return []
		}

		const columnCount = getTableWidth(tableNode)
		const punchCard: number[][] = rows.map(() => [])

		rows.forEach((row, rowIndex) => {
			assertIsRowNode(row)
			const cells = row.getChildren()
			const slots = punchCard[rowIndex]
			// Slots of this row already claimed by spans; the next child cell sits at `column - claimed`.
			let claimed = 0
			for (let column = 0; column < columnCount; column++) {
				if (slots[column] !== undefined) {
					claimed++
					continue
				}
				const cell = cells[column - claimed]
				assertIsCellNode(cell)
				addSpannedCellsToPunchCard(cell, punchCard, rowIndex, column)
			}
		})

		return punchCard
	}
}

/**
 * Marks with 1 every punch-card slot covered by `cell`'s row and column span, except the
 * cell's own first slot at `[j][i]`.
 *
 * Rows of the span that fall past the last row of `punchCard` are ignored. This can happen
 * with a row span larger than the remaining rows. Without the guard, writing into a
 * non-existent row throws a TypeError.
 *
 * @param cell table cell whose spans are read
 * @param punchCard per-row grid that receives the 1 markers (mutated in place)
 * @param j row index of the cell's first slot
 * @param i column index of the cell's first slot
 */
function addSpannedCellsToPunchCard(cell: ITableCellNode, punchCard: number[][], j: number, i: number) {
	const colSpan = cell.getColSpan()
	const rowSpan = cell.getRowSpan()

	for (let l = 0; l < rowSpan; l++) {
		const row = punchCard[j + l]
		if (row === undefined) {
			// The span runs past the table's last row; later rows cannot exist either.
			break
		}
		for (let k = 0; k < colSpan; k++) {
			if (l === 0 && k === 0) {
				continue
			}
			row[i + k] = 1
		}
	}
}

/**
 * Creates a function that returns a table's column count.
 *
 * The count is the sum of `getColSpan()` over the first row's cells. No cell spans into the
 * first row from above, because row spans only extend downward. A table with no rows has a
 * width of 0.
 *
 * @param assertIsRowNode narrows a table child to a row (expected to throw otherwise)
 * @param assertIsCellNode narrows a row child to a cell (expected to throw otherwise)
 */
export function createGetTableWidth(
	assertIsRowNode: (node: any) => asserts node is ITableRowNode,
	assertIsCellNode: (node: any) => asserts node is ITableCellNode
) {
	return function (tableNode: ITableNode): number {
		const rows = tableNode.getChildren()
		// Without this guard the row assertion would be handed `undefined` and fail misleadingly.
		if (rows.length === 0) {
			return 0
		}
		const firstRow = rows[0]
		assertIsRowNode(firstRow)

		return firstRow.getChildren().reduce((sum: number, cell) => {
			assertIsCellNode(cell)
			return sum + cell.getColSpan()
		}, 0)
	}
}

/**
 * Shifts every lookup entry in place from a cell's first column to its last column by
 * adding `getColSpan(row, index) - 1` to it.
 *
 * @param spanLookupTable per-row offsets, mutated in place
 * @param getColSpan column span of the cell at (row, index-within-row)
 */
export function modifySpanLookupTableToUseRightEdge(
	spanLookupTable: number[][],
	getColSpan: (j: number, i: number) => number
) {
	for (let row = 0; row < spanLookupTable.length; row++) {
		const offsets = spanLookupTable[row]
		for (let column = 0; column < offsets.length; column++) {
			offsets[column] += getColSpan(row, column) - 1
		}
	}
}

/**
 * Creates `makeSpanLookupTable`.
 *
 * The produced function returns one array per row. Entry `k` of row `j` is the number of
 * punch-card slots marked 1 in row `j` to the left of cell `k`'s own slot, where `k` counts the
 * row's children in order. Adding the entry to `k` gives the global column of the cell's first
 * slot. When `modifyToRightEdge` is true (the default), each entry is also increased by the
 * cell's `getColSpan() - 1`, so `k + entry` becomes the global column of the cell's last slot.
 *
 * The default `makeSpanPunchCard` and `getTableWidth` are built from the injected assertion
 * functions. A caller supplying custom assertions therefore does not silently fall back to the
 * Lexical-specific ones.
 *
 * @param assertIsRowNode narrows a table child to a row (expected to throw otherwise)
 * @param assertIsCellNode narrows a row child to a cell (expected to throw otherwise)
 * @param makeSpanPunchCard builds the punch card when the caller does not pass one
 * @param getTableWidth computes the table's column count
 * @param modifyToRightEdge whether entries are shifted from first-column to last-column offsets
 */
export function createMakeSpanLookupTable(
	assertIsRowNode: (node: any) => asserts node is ITableRowNode = $assertIsRowNode,
	assertIsCellNode: (node: any) => asserts node is ITableCellNode = $assertIsCellNode,
	makeSpanPunchCard: (tableNode: ITableNode) => number[][] = createMakeSpanPunchCard(assertIsRowNode, assertIsCellNode),
	getTableWidth: (tableNode: ITableNode) => number = createGetTableWidth(assertIsRowNode, assertIsCellNode),
	modifyToRightEdge = true
) {
	function makeSpanLookupTable(tableNode: ITableNode, spanPunchCard?: number[][]): number[][] {
		const punchCard = spanPunchCard ?? makeSpanPunchCard(tableNode)
		const spanLookupTable: number[][] = []
		const height = punchCard.length

		if (height === 0) {
			return spanLookupTable
		}

		const width = getTableWidth(tableNode)

		for (let j = 0; j < height; j++) {
			spanLookupTable[j] = []
		}

		for (let j = 0; j < height; j++) {
			// Running count of slots in row j covered by spans so far.
			let counter = 0
			for (let i = 0; i < width; i++) {
				if (punchCard[j][i] === 1) {
					counter++
				} else {
					spanLookupTable[j].push(counter)
				}
			}
		}

		if (modifyToRightEdge) {
			// Row children are only needed for the right-edge shift, so they are read lazily here.
			const cells = tableNode.getChildren().map((row) => {
				assertIsRowNode(row)
				return row.getChildren()
			})
			modifySpanLookupTableToUseRightEdge(spanLookupTable, (j, i) => {
				const cell = cells[j][i]
				assertIsCellNode(cell)
				return cell.getColSpan()
			})
		}
		return spanLookupTable
	}

	return makeSpanLookupTable
}

/**
 * Inside `editor.update`, resolves the Lexical node for `cell.elem` and looks up the DOM element
 * of its table. If both are found, it stores `target` in `targetRef` and the table's bounding
 * rect in `tableRectRef`, then calls `setActiveCell(cell)`. If either is missing, nothing changes.
 *
 * `$getTableNodeFromLexicalNodeOrThrow` presumably throws when the node is not inside a table.
 */
export function updateActiveCell(
	editor: LexicalEditor,
	cell: Cell,
	targetRef: React.MutableRefObject<HTMLElement | null>,
	target: EventTarget | null,
	tableRectRef: React.MutableRefObject<DOMRect | null>,
	setActiveCell: React.Dispatch<React.SetStateAction<Cell | null>>
) {
	editor.update(() => {
		const lexicalNode = $getNearestNodeFromDOMNode(cell.elem)
		if (lexicalNode) {
			const tableKey = $getTableNodeFromLexicalNodeOrThrow(lexicalNode).getKey()
			const tableElement = editor.getElementByKey(tableKey)
			// A table without a DOM element is silently ignored rather than treated as an error.
			if (tableElement) {
				targetRef.current = target as HTMLElement
				tableRectRef.current = tableElement.getBoundingClientRect()
				setActiveCell(cell)
			}
		}
	})
}

/**
 * Applies a column-border drag to a single table row.
 *
 * The cell at `correctedTableColumnIndex` (an index into `tableRow`'s children) gets the width
 * computed by getNewCellWidth.
 * - If it is not the row's last child and `shouldShrinkNext` is set, the cell at
 *   `rightCellIndex` changes by `-deltaWidth * compressionRatio`.
 * - If it is the last child, resizeRightMostCell clamps it against `availableWidth`.
 * - When `correctedTableColumnIndex` is undefined, only the right-hand cell is adjusted.
 *
 * Both resized cells are kept at or above MIN_COLUMN_WIDTH.
 *
 * Presumably must run inside `editor.update()`, since it mutates nodes via `setWidth`.
 *
 * @throws Error if `tableRow` is not a table row, if `correctedTableColumnIndex` is outside the
 *   row, if a targeted child is not a table cell, or if a cell's width/element cannot be determined
 */
export function resizeCellsInRow(
	editor: LexicalEditor,
	tableRow: LexicalNode,
	correctedTableColumnIndex: number | undefined,
	deltaWidth: number,
	compressionRatio: number,
	availableWidth: number,
	shouldShrinkNext: boolean,
	rightCellIndex: number | undefined
) {
	if (!$isTableRowNode(tableRow)) {
		throw new Error('Expected table row')
	}

	const tableCells = tableRow.getChildren()

	const isRightMostCell = correctedTableColumnIndex === tableCells.length - 1

	// Moves the neighbouring cell's width in the opposite direction of the drag, when applicable.
	function resizeRightCell() {
		if (isRightMostCell || !shouldShrinkNext || rightCellIndex === undefined) {
			return
		}
		const rightCell = tableCells[rightCellIndex]
		if (!rightCell) {
			return
		}
		if (!$isTableCellNode(rightCell)) {
			throw new Error('Expected table cell')
		}
		const currentRightCellWidth = getRightCellWidth(rightCell, editor)
		if (currentRightCellWidth === undefined) {
			throw new Error('TableCellResizer: Expected table cell to have width')
		}
		// Apply the same minimum as the dragged cell (see getNewCellWidth). Without it, a large
		// drag could push the neighbour below MIN_COLUMN_WIDTH or even to a negative width.
		rightCell.setWidth(Math.max(currentRightCellWidth - deltaWidth * compressionRatio, MIN_COLUMN_WIDTH))
	}

	if (correctedTableColumnIndex === undefined) {
		resizeRightCell()
		return
	}

	if (correctedTableColumnIndex >= tableCells.length || correctedTableColumnIndex < 0) {
		throw new Error('Expected table cell to be inside of table row.')
	}

	const tableCell = tableCells[correctedTableColumnIndex]

	if (!tableCell) {
		resizeRightCell()
		return
	}

	if (!$isCustomTableCellNode(tableCell)) {
		throw new Error('Expected table cell')
	}

	const newCellWidth = getNewCellWidth(editor, tableCell, isRightMostCell, deltaWidth, compressionRatio)

	if (!isRightMostCell) {
		resizeInternalCell(tableCell, newCellWidth, resizeRightCell)
	} else {
		resizeRightMostCell(tableCell, editor, availableWidth, newCellWidth, tableCells)
	}
}

/**
 * Returns the cell's stored width. When that is unset or 0, returns its rendered width from the
 * DOM instead. Returns `undefined` if neither is available.
 */
function getRightCellWidth(rightCell: TableCellNode, editor: LexicalEditor) {
	const storedWidth = rightCell.getWidth()
	if (storedWidth) {
		return storedWidth
	}
	const element = editor.getElementByKey(rightCell.getKey())
	return element?.getBoundingClientRect().width
}

/**
 * Computes the dragged cell's new width, never below MIN_COLUMN_WIDTH.
 *
 * - The right-most cell starts from its rendered width and takes `deltaWidth` as is.
 * - Other cells start from their stored width, falling back to the rendered width, and take
 *   `deltaWidth * compressionRatio`.
 *
 * @throws Error (from getTableCellBoundingClientRect) when a rendered width is needed but the
 *   cell has no DOM element
 */
function getNewCellWidth(
	editor: LexicalEditor,
	tableCell: CustomTableCellNode,
	isRightMostCell: boolean,
	deltaWidth: number,
	compressionRatio: number
) {
	if (isRightMostCell) {
		const renderedWidth = getTableCellBoundingClientRect(tableCell, editor).width
		return Math.max(renderedWidth + deltaWidth, MIN_COLUMN_WIDTH)
	}
	const baseWidth = tableCell.getWidth() ?? getTableCellBoundingClientRect(tableCell, editor).width
	return Math.max(baseWidth + deltaWidth * compressionRatio, MIN_COLUMN_WIDTH)
}

/**
 * Returns the bounding client rect of the cell's DOM element.
 *
 * @throws Error when the editor has no DOM element for the cell
 */
function getTableCellBoundingClientRect(tableCell: CustomTableCellNode, editor: LexicalEditor) {
	const element = editor.getElementByKey(tableCell.getKey())
	if (element) {
		return element.getBoundingClientRect()
	}
	throw new Error('TableCellResizer: Expected table cell to have element')
}

/**
 * Resolves the custom table cell node behind `activeCell.elem`.
 *
 * Calls `$getNearestNodeFromDOMNode`, so presumably it must run inside a Lexical read/update.
 *
 * @throws Error when the nearest node is not a custom table cell node
 */
export function getLexicalTableCellNodeFromActiveCell(activeCell: Cell) {
	const nearestNode = $getNearestNodeFromDOMNode(activeCell.elem)
	if ($isCustomTableCellNode(nearestNode)) {
		return nearestNode
	}
	throw new Error('TableCellResizer: Table cell node not found.')
}

/**
 * Resizes the row's last cell to `newWidth`, capped at its current width plus `availableWidth`.
 *
 * When `availableWidth` is not positive, every other table cell in `tableCells` that already
 * has a stored width is pinned to its current rendered `clientWidth`.
 *
 * @throws Error when the cell's current width cannot be determined, or when a sibling that
 *   needs pinning has no DOM element
 */
function resizeRightMostCell(
	tableCell: CustomTableCellNode,
	editor: LexicalEditor,
	availableWidth: number,
	newWidth: number,
	tableCells: LexicalNode[]
) {
	const storedOrRenderedWidth =
		tableCell.getWidth() || editor.getElementByKey(tableCell.getKey())?.getBoundingClientRect().width
	if (storedOrRenderedWidth === undefined) {
		throw new Error('TableCellResizer: Expected table cell to have width')
	}
	tableCell.setWidth(Math.min(storedOrRenderedWidth + availableWidth, newWidth))

	tableCells.forEach((sibling) => {
		// With spare width (availableWidth > 0) no sibling is touched.
		if (sibling === tableCell || !$isTableCellNode(sibling) || !sibling.getWidth() || availableWidth > 0) {
			return
		}
		const siblingElement = editor.getElementByKey(sibling.getKey())
		if (!siblingElement) {
			throw new Error('TableCellResizer: Expected table cell to have element')
		}
		sibling.setWidth(siblingElement.clientWidth)
	})
}

/**
 * Resizes a cell that is not the last in its row.
 *
 * The neighbour adjustment (`resizeRightCell`) runs first, then the cell itself gets `newWidth`.
 */
function resizeInternalCell(tableCell: CustomTableCellNode, newWidth: number, resizeRightCell: () => void) {
	const adjustNeighbour = resizeRightCell
	adjustNeighbour()
	tableCell.setWidth(newWidth)
}

/**
 * Returns the factor used to convert rendered pixel deltas into stored-width units.
 *
 * - Returns 1 when `availableWidth > 0`.
 * - Otherwise, over the first row's custom cells that have both a DOM element and a stored width,
 *   returns (sum of stored widths) / (sum of rendered widths).
 * - Returns 1 if no cell contributes.
 *
 * @throws Error when the first entry of `tableRows` is not a table row
 */
export function getCompressionRatio(editor: LexicalEditor, tableRows: LexicalNode[], availableWidth: number) {
	if (availableWidth > 0) {
		return 1
	}
	const firstRow = tableRows[0]
	if (!$isTableRowNode(firstRow)) {
		throw new Error('Expected table row')
	}

	let renderedTotal = 0
	let storedTotal = 0
	for (const cellNode of firstRow.getChildren()) {
		if (!$isCustomTableCellNode(cellNode)) {
			continue
		}
		const element = editor.getElementByKey(cellNode.getKey())
		if (!element) {
			continue
		}
		const storedWidth = cellNode.getWidth()
		// TODO: cells without a stored width are skipped; their default width is not known here.
		if (storedWidth === undefined) {
			continue
		}
		renderedTotal += element.getBoundingClientRect().width
		storedTotal += storedWidth
	}

	return renderedTotal === 0 ? 1 : storedTotal / renderedTotal
}

/**
 * Returns the cell's index among its row's children plus `spanLookupTable[row][index]`
 * (0 when that entry is unset).
 *
 * With a first-column lookup table this is the global column of the cell's first slot. With a
 * right-edge table it is the global column of the cell's last slot.
 *
 * Presumably must run inside a Lexical read/update (`$` helpers).
 */
export function $getTableColumnIndexFromTableCellNode(
	tableCellNode: CustomTableCellNode,
	spanLookupTable: (number | undefined)[][]
): number {
	const rowIndex = $getTableRowIndexFromTableCellNode(tableCellNode)
	const rowNode = $getTableRowNodeFromTableCellNodeOrThrow(tableCellNode)
	const localIndex = rowNode.getChildren().findIndex((child) => child.is(tableCellNode))
	const spannedBefore = spanLookupTable[rowIndex][localIndex] ?? 0
	return localIndex + spannedBefore
}

/**
 * Returns `columnIndex` plus the sum of the truthy offsets at positions `0..columnIndex`
 * (inclusive) of row `rowIndex`. Unset, 0 and NaN entries contribute nothing.
 */
export function correctIndexFromColSpanOffsets(
	colSpanOffsets: (number | undefined)[][],
	rowIndex: number,
	columnIndex: number
) {
	const offsets = colSpanOffsets[rowIndex]
	let total = 0
	let position = 0
	while (position < columnIndex + 1) {
		const offset = offsets[position]
		if (offset) {
			total += offset
		}
		position++
	}
	return columnIndex + total
}

// Lexical-bound instances: the project's row/cell assertions are passed explicitly
// (they are also the factories' defaults).
export const $makeSpanPunchCard = createMakeSpanPunchCard($assertIsRowNode, $assertIsCellNode)

export const $makeSpanLookupTable = createMakeSpanLookupTable($assertIsRowNode, $assertIsCellNode)
