import { MarkType, Node, NodeType, ResolvedPos } from "prosemirror-model"
import { EditorState, Plugin, PluginKey } from "prosemirror-state"

import '@/util/pm-extensions'

/**
 * Externally owned storage that CacheState keeps in sync with the editor state.
 *
 * Each map is used as a set: CacheState only ever stores `true` as a value and
 * deletes keys that no longer apply, mutating the maps in place rather than
 * replacing them.
 * NOTE(review): presumably these are observable maps (e.g. MobX) so that UI can
 * react to changes — confirm with whoever supplies the container.
 */
export interface ICacheContainer {
    /** Mark types found in a non-empty selection, or the stored/cursor marks when the selection is empty (see CacheState.getActiveMarks). */
    activeMarkTypes: Map<MarkType, boolean>
    /** Types of nodes found inside a non-empty selection; emptied when the selection is empty (see CacheState.getActiveNodes). */
    selectedNodeTypes: Map<NodeType, boolean>

    /** Nodes reported by `nodesAbove` for both `$from` and `$to` of every selection range (presumably their shared ancestors). */
    containingNodes: Map<Node, boolean>
    /** The node types of the entries in `containingNodes`. */
    containingNodeTypes: Map<NodeType, boolean>

    /**
     * Invoked by CacheState.updateFromState with a callback that performs all map
     * writes for one update; the container must call `fn` synchronously.
     * NOTE(review): presumably used to batch observable updates — confirm.
     */
    runInAction: (fn: () => void) => void
}

/**
 * Plugin state for the cache plugin.
 *
 * Recomputes selection-derived facts (active marks, selected node types,
 * containing nodes) from an EditorState and writes them into an externally
 * owned ICacheContainer. The container maps are mutated in place, never
 * replaced, so anything holding a reference to them sees the updates.
 */
export class CacheState {
    /** Container whose maps are kept in sync with the latest state. */
    obj: ICacheContainer


    constructor(obj: ICacheContainer) {
        this.obj = obj
    }

    /**
     * Recompute every cached value from `state`.
     *
     * All map writes run inside the container's `runInAction` callback
     * (presumably so observers see a single batched update).
     */
    updateFromState(state: EditorState) {
        this.obj.runInAction(() => this._updateFromState(state))
    }

    private _updateFromState(state: EditorState) {
        this.getActiveMarks(state)
        this.getActiveNodes(state)
        this.getContainingNodes(state)
    }

    /**
     * Sync `selectedNodeTypes` with the types of nodes inside the selection.
     *
     * For each range, every node visited by `doc.nodesBetween` after
     * `$from.parent` is collected. For a NodeSelection this is typically the
     * selected node and its descendants. An empty selection empties the map.
     */
    getActiveNodes(state: EditorState) {
        const collected = new Set<NodeType>()
        const {empty, ranges} = state.selection

        if (!empty) {
            for (const {$from, $to} of ranges) {
                // `nodesBetween` never reports the document node itself, so when
                // `$from.parent` IS the doc (depth 0: a top-level node selection or
                // an all-selection) there is no parent visit to wait for — every
                // visited node already lies below it.
                let parentSeen = $from.depth === 0
                state.doc.nodesBetween($from.pos, $to.pos, (node) => {
                    if (parentSeen) collected.add(node.type)
                    if (node === $from.parent) parentSeen = true
                })
            }
        }

        CacheState.syncKeys(this.obj.selectedNodeTypes, collected)
    }

    /**
     * Sync `activeMarkTypes` with the marks relevant to the selection.
     *
     * Empty selection: stored marks win when present (an empty array means
     * "no marks"). A null stored-marks value means the marks at the cursor
     * position apply.
     * Non-empty selection: every mark on any node touched by any range counts.
     */
    getActiveMarks(state: EditorState) {
        const collected = new Set<MarkType>()
        const {$from, empty, ranges} = state.selection

        if (empty) {
            for (const mark of state.storedMarks ?? $from.marks()) collected.add(mark.type)
        } else {
            for (const {$from: start, $to: end} of ranges) {
                state.doc.nodesBetween(start.pos, end.pos, (node) => {
                    for (const mark of node.marks) collected.add(mark.type)
                })
            }
        }

        CacheState.syncKeys(this.obj.activeMarkTypes, collected)
    }

    /**
     * Sync `containingNodes` / `containingNodeTypes` with the nodes that
     * `nodesAbove` reports for both ends of every selection range.
     */
    getContainingNodes(state: EditorState) {
        // One set per range: nodes above `$from` that are also above `$to`.
        const rangeNodesList: Array<Set<Node>> = []

        for (const {$from, $to} of state.selection.ranges) {
            const aboveFrom = new Set<Node>()
            const shared = new Set<Node>()
            rangeNodesList.push(shared)

            $from.nodesAbove((node) => { aboveFrom.add(node) })
            $to.nodesAbove((node) => { if (aboveFrom.has(node)) shared.add(node) })
        }

        const newContainingNodes = new Set<Node>()
        const newContainingNodeTypes = new Set<NodeType>()

        // Intersect across ranges. A selection always has at least one range,
        // but guard anyway so the index access is safe.
        for (const candidate of rangeNodesList[0] ?? []) {
            if (rangeNodesList.every((set) => set.has(candidate))) {
                newContainingNodes.add(candidate)
                newContainingNodeTypes.add(candidate.type)
            }
        }

        CacheState.syncKeys(this.obj.containingNodes, newContainingNodes)
        CacheState.syncKeys(this.obj.containingNodeTypes, newContainingNodeTypes)
    }

    /**
     * Make `target`'s key set equal to `keys`.
     *
     * Every key in `keys` is set to `true`, then keys absent from `keys` are
     * deleted. The map is mutated in place rather than replaced. Keys are
     * snapshotted before deleting so this is safe for any Map implementation.
     */
    private static syncKeys<K>(target: Map<K, boolean>, keys: ReadonlySet<K>) {
        for (const key of keys) target.set(key, true)
        for (const key of Array.from(target.keys())) {
            if (!keys.has(key)) target.delete(key)
        }
    }
}

/** Key identifying the cache plugin; `cacheKey.getState(state)` yields its CacheState. */
export const cacheKey: PluginKey<CacheState> = new PluginKey("cache")

/** Options for the `cache` plugin factory. */
export interface CacheConfig {
    /** Container that the plugin's CacheState writes selection-derived data into. */
    container: ICacheContainer
}

/**
 * Create the cache plugin.
 *
 * The plugin's state is a single CacheState bound to `config.container`. It
 * populates the container when the editor state is created and refreshes it
 * after every transaction, so the container always reflects the current
 * document, selection and stored marks.
 *
 * @param config - supplies the container to keep in sync.
 * @returns a Plugin registered under `cacheKey`.
 */
export function cache(config: CacheConfig): Plugin<CacheState> {
    return new Plugin({
        key: cacheKey,

        state: {
            init: (_stateConfig, instance) => {
                const cacheState = new CacheState(config.container)
                // Base fields (doc, selection, storedMarks) are initialised before
                // plugin fields, so `instance` is already readable here. Populating
                // now avoids an empty (or stale, if the container is reused) cache
                // until the first transaction arrives.
                cacheState.updateFromState(instance)
                return cacheState
            },
            apply(_tr, cacheState, _oldState, newState) {
                cacheState.updateFromState(newState)

                return cacheState
            }
        }
    })
}