#!/usr/bin/env python3

# Libervia plugin for Pubsub Extended Discovery.
# Copyright (C) 2009-2025 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from collections.abc import Callable, Coroutine, Iterable
from typing import (
    TYPE_CHECKING,
    ClassVar,
    Final,
    Literal,
    NamedTuple,
    Self,
    cast,
)

from pydantic import BaseModel, Field, RootModel, model_validator
from sqlalchemy.exc import SQLAlchemyError
from twisted.internet import defer
from twisted.words.protocols.jabber.jid import JID
from twisted.words.protocols.jabber import error
from twisted.words.protocols.jabber import xmlstream
from twisted.words.protocols.jabber.xmlstream import XMPPHandler
from twisted.words.xish import domish
from wokkel import disco, data_form, pubsub
from zope.interface import implementer

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPComponent, SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.memory.sqla_mapping import PubsubItem, PubsubNode
from libervia.backend.models.types import JIDType
from libervia.backend.plugins.plugin_xep_0060 import NodeMetadata
from libervia.backend.tools import utils

if TYPE_CHECKING:
    # Only needed for annotations (presumably imported lazily to avoid an import
    # cycle with the backend core).
    from libervia.backend.core.main import LiberviaBackend

log = getLogger(__name__)


# Metadata describing this plugin to the backend.
PLUGIN_INFO = {
    C.PI_NAME: "Pubsub Extended Discovery",
    C.PI_IMPORT_NAME: "XEP-0499",
    C.PI_TYPE: "XEP",
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_PROTOCOLS: ["XEP-0499"],
    C.PI_DEPENDENCIES: ["XEP-0060"],
    C.PI_RECOMMENDATIONS: [],
    C.PI_MAIN: "XEP_0499",
    C.PI_HANDLER: "yes",
    C.PI_DESCRIPTION: _(
        "Extended discovery for Pubsub nodes with linked nodes and descendants."
    ),
}

NS_PUBSUB_EXT_DISCO: Final = "urn:xmpp:pubsub-ext-disco:0"
NS_PUBSUB_RELATIONSHIPS: Final = "urn:xmpp:pubsub-relationships:0"
# Metadata form field names for node relationships, in Clark notation
# ("{namespace}name").
PARENT_VAR = "{" + NS_PUBSUB_RELATIONSHIPS + "}parent"
LINK_VAR = "{" + NS_PUBSUB_RELATIONSHIPS + "}link"


class ExtDiscoMetadata(NodeMetadata):
    """Node metadata extended with the ``parent`` and ``link`` relationships."""

    # Name of the parent node, if any.
    parent: str | None = None
    # Name of the node this node is linking to, if any.
    link: str | None = None

    @classmethod
    def from_data_form(cls, form: data_form.Form) -> Self:
        """Build an ExtDiscoMetadata from a node metadata data form.

        The base metadata is parsed by the parent class, then the relationship fields
        are read from the form (missing fields give None).

        @param form: Data form containing the node metadata.
        @return: A new, filled instance of this class.
        @raise TypeError: The form's field types don't match the specifications.
        """
        instance = super().from_data_form(form)
        instance.parent, instance.link = form.get(PARENT_VAR), form.get(LINK_VAR)
        return instance

    def to_data_form(self) -> data_form.Form:
        """Serialise this metadata to a data form.

        Relationship fields are only added when they are set.

        @return: Data form representing this instance.
        """
        form = super().to_data_form()
        relationships = ((PARENT_VAR, self.parent), (LINK_VAR, self.link))
        for var, value in relationships:
            if value is not None:
                form.addField(data_form.Field(var=var, value=value))
        return form


class ExtDiscoOptions(BaseModel):
    """Pydantic model for the pubsub extended discovery form fields."""

    type: list[Literal["items", "nodes"]] = Field(
        default_factory=lambda: ["items", "nodes"]
    )
    linked_nodes: bool = False
    full_metadata: bool = False
    depth: int = 0

    _fields_defs: ClassVar[dict] = {
        "type": {"type": "list-multi"},
        "linked_nodes": {"type": "boolean"},
        "full_metadata": {"type": "boolean"},
        "depth": {"type": "text-single"},
    }

    @classmethod
    def from_data_form(cls, form: data_form.Form) -> Self:
        """Create a PubsubExtDiscoForm instance from a data form.

        @param form: Extended Discovery Data Form.
        @return: Filled instance of this class.
        @raise TypeError: Type of the form do not correspond to what is expected according
            to specifications.
        """
        fields = {}
        form.typeCheck(cls._fields_defs)
        for field in form.fields.values():
            if field.var == "type":
                fields["type"] = field.values
            elif field.var == "linked_nodes":
                fields["linked_nodes"] = field.value
            elif field.var == "full_metadata":
                fields["full_metadata"] = field.value
            elif field.var == "depth":
                try:
                    fields["depth"] = int(field.value)
                except (ValueError, TypeError):
                    log.warning(f"Invalid depth found: {field.value!r}.")
                    fields["depth"] = 0

        return cls(**fields)

    def to_data_form(self) -> data_form.Form:
        """Convert this instance to a data form.

        @return: Data form representation of this instance.
        """
        form = data_form.Form(formType="submit", formNamespace=NS_PUBSUB_EXT_DISCO)

        form.makeFields(
            {
                "type": self.type,
                "linked_nodes": self.linked_nodes,
                "full_metadata": self.full_metadata,
                "depth": str(self.depth),
            },
            fieldDefs=self._fields_defs,
        )

        return form

    def to_element(self) -> domish.Element:
        """Generate the <x> element corresponding to this form."""
        return self.to_data_form().toElement()


class DiscoPubsubItem(BaseModel):
    """A pubsub item returned by extended discovery."""

    type: Literal["item"] = "item"
    # Name of the node containing this item.
    parent_node: str
    jid: JIDType
    # Item ID.
    name: str

    @classmethod
    def from_sqlalchemy(cls, item: PubsubItem, node_name: str, service: JID) -> Self:
        """Build a DiscoPubsubItem from a PubsubItem SQLAlchemy model.

        @param item: SQLAlchemy PubsubItem instance.
        @param node_name: Name of the node containing the item.
        @param service: JID of the service hosting the item.
        @return: A new instance of this class.
        """
        return cls(name=item.name, jid=service, parent_node=node_name)

    def to_element(self) -> domish.Element:
        """Return the disco <item> element for this item."""
        elt = domish.Element((disco.NS_DISCO_ITEMS, "item"))
        attributes = (
            ("jid", self.jid.full()),
            ("node", self.parent_node),
            ("name", self.name),
        )
        for attr_name, attr_value in attributes:
            elt[attr_name] = attr_value
        return elt


class DiscoPubsubNode(BaseModel):
    """A pubsub node returned by extended discovery, with its relationships.

    ``children`` are the nodes having this node as ``parent``, ``linking_nodes`` are
    the nodes having this node as ``link``, and ``items`` are the items of this node.
    """

    type: Literal["node"] = "node"
    jid: JIDType
    name: str
    items: list[DiscoPubsubItem] | None = None
    linking_nodes: list[Self] | None = None
    children: list[Self] | None = None
    metadata: ExtDiscoMetadata = Field(default_factory=ExtDiscoMetadata)

    @model_validator(mode="after")
    def set_link_and_parent_metadata(self) -> Self:
        """Ensure that ``metadata.link`` and ``metadata.parent`` are set correctly.

        Linking nodes get their ``link`` set to this node's name, and children get
        their ``parent`` set to this node's name.
        """
        if self.linking_nodes is not None:
            for linking_node in self.linking_nodes:
                linking_node.metadata.link = self.name

        if self.children is not None:
            for child in self.children:
                child.metadata.parent = self.name

        return self

    def add_child(self, child_node: Self) -> None:
        """Add a child and set its ``parent`` metadata to current node."""
        if self.children is None:
            self.children = []
        self.children.append(child_node)
        child_node.metadata.parent = self.name

    def link_to(self, linked_node: Self) -> None:
        """Link this node to another one, and set the ``link`` metadata.

        This node is appended to ``linked_node.linking_nodes``.
        """
        if linked_node.linking_nodes is None:
            linked_node.linking_nodes = []
        linked_node.linking_nodes.append(self)
        self.metadata.link = linked_node.name

    @classmethod
    def from_sqlalchemy(cls, node: PubsubNode, service: JID) -> Self:
        """Create a DiscoPubsubNode instance from a PubsubNode SQLAlchemy model.

        Only the node itself is converted, relationships are only reflected in the
        ``parent`` and ``link`` metadata.

        @param node: The SQLAlchemy PubsubNode instance.
        @param service: The JID of the service where the node is.
        @return: A new instance of this class.
        """
        disco_node = cls(
            jid=service,
            name=node.name,
            metadata=ExtDiscoMetadata(
                type=node.type_ if node.type_ else None,
                access_model=node.access_model.value if node.access_model else None,
                publish_model=node.publish_model.value if node.publish_model else None,
            ),
        )

        # Set parent and link metadata based on relationships.
        if node.parent_node_id:
            disco_node.metadata.parent = node.parent_node.name

        if node.linked_node_id:
            disco_node.metadata.link = node.linked_node.name

        if node.extra is not None:
            for field in ("title", "description"):
                try:
                    setattr(disco_node.metadata, field, node.extra[field])
                except KeyError:
                    pass

        return disco_node

    @classmethod
    def from_sqlalchemy_full_hierarchy(cls, node: PubsubNode, service: JID) -> Self:
        """Create a DiscoPubsubNode instance with children and items populated.

        The whole nodes and items hierarchy (items, children and linking nodes) is
        recursively created. Relationships which can't be loaded (e.g. not loaded with
        the SQLAlchemy query) are logged and ignored.

        @param node: The SQLAlchemy PubsubNode instance.
        @param service: The JID of the service where the node is.
        @return: A new instance of this class with children and items.
        """
        disco_node = cls.from_sqlalchemy(node, service)

        try:
            items = node.items
        except SQLAlchemyError as e:
            log.debug(f"Can't load items: {e}")
        else:
            disco_node.items = [
                DiscoPubsubItem.from_sqlalchemy(item, node.name, service)
                for item in items
            ]

        try:
            child_nodes = node.child_nodes
        except Exception as e:
            log.debug(f"Can't load child nodes: {e}")
        else:
            try:
                disco_node.children = [
                    cls.from_sqlalchemy_full_hierarchy(child, service)
                    for child in child_nodes
                ]
            except SQLAlchemyError as e:
                log.debug(f"Can't handle children, ignoring them: {e}.")

        try:
            linking_nodes = node.linking_nodes
        except Exception as e:
            log.debug(f"Can't load linking nodes: {e}")
        else:
            # Same guard as for children: a relationship of a linking node may not be
            # loadable.
            try:
                disco_node.linking_nodes = [
                    cls.from_sqlalchemy_full_hierarchy(linking_node, service)
                    for linking_node in linking_nodes
                ]
            except SQLAlchemyError as e:
                log.debug(f"Can't handle linking nodes, ignoring them: {e}.")

        return disco_node

    def to_element(self) -> domish.Element:
        """Return the disco <item> element of this node only, with its metadata form."""
        item_elt = domish.Element((disco.NS_DISCO_ITEMS, "item"))
        item_elt["jid"] = self.jid.full()
        item_elt["node"] = self.name
        item_elt.addChild(self.metadata.to_element())
        return item_elt

    def to_elements(
        self,
    ) -> list[domish.Element]:
        """Return the elements of this node and of everything attached to it.

        The result contains this node, then its items, then (recursively) its linking
        nodes and its children, each with their own items, linking nodes and children.
        A node object reachable through several paths is only emitted once, which also
        prevents infinite recursion on cyclic relationships.

        @return: Flat list of disco <item> elements.
        """
        elements: list[domish.Element] = []
        seen: set[int] = set()

        def collect(node: DiscoPubsubNode) -> None:
            if id(node) in seen:
                return
            seen.add(id(node))
            elements.append(node.to_element())
            if node.items is not None:
                elements.extend(item.to_element() for item in node.items)
            if node.linking_nodes is not None:
                for linking_node in node.linking_nodes:
                    collect(linking_node)
            if node.children is not None:
                for child in node.children:
                    collect(child)

        collect(self)
        return elements


class ExtDiscoResult(RootModel):
    """Result of an extended discovery request.

    The root list contains the top-level nodes (with their children, linking nodes and
    items attached) and the items whose parent node is not part of the result.
    """

    root: list[DiscoPubsubNode | DiscoPubsubItem]

    def __iter__(self) -> Iterable[DiscoPubsubNode | DiscoPubsubItem]:  # type: ignore
        return iter(self.root)

    def __getitem__(self, item) -> DiscoPubsubNode | DiscoPubsubItem:
        return self.root[item]

    def __len__(self) -> int:
        return len(self.root)

    def append(self, item: DiscoPubsubNode | DiscoPubsubItem) -> None:
        self.root.append(item)

    def sort(self, key=None, reverse=False) -> None:
        self.root.sort(key=key, reverse=reverse)  # type: ignore

    @classmethod
    def from_element(cls, query_elt: domish.Element) -> Self:
        """Build nodes hierarchy from unordered list of disco <items>

        Invalid disco items are logged and ignored. Nodes whose parent or linked node
        is not in the result are returned as root nodes, and items whose node is not in
        the result are returned as root items.

        @param query_elt: Parent <query> element from the disco result, disco <item>
            elements will be retrieved in its children.
        @return: An instance of this class, with nodes hierarchy reconstructed.
        """
        # First pass: create all nodes and items.
        nodes: dict[str, DiscoPubsubNode] = {}
        items: list[DiscoPubsubItem] = []

        for item_elt in query_elt.elements(disco.NS_DISCO_ITEMS, "item"):
            try:
                item_jid = JID(item_elt["jid"])
                node_name = item_elt["node"]
            except KeyError:
                log.warning(f"Invalid extended disco item, ignoring: {item_elt.toXml()}")
                continue

            # A metadata form indicates that this disco item is a node.
            metadata_form = data_form.findForm(item_elt, pubsub.NS_PUBSUB_META_DATA)

            if metadata_form is not None:
                try:
                    metadata = ExtDiscoMetadata.from_data_form(metadata_form)
                except (TypeError, ValueError) as e:
                    # The form comes from a remote entity, skip it like other invalid
                    # items instead of failing the whole result.
                    log.warning(
                        f"Invalid node metadata ({e}), ignoring: {item_elt.toXml()}"
                    )
                    continue
                nodes[node_name] = DiscoPubsubNode(
                    jid=item_jid, name=node_name, metadata=metadata
                )
            else:
                try:
                    name = item_elt["name"]
                except KeyError:
                    log.warning(
                        "Invalid disco item, pubsub items must have a name: "
                        f"{item_elt.toXml()}"
                    )
                    continue
                items.append(
                    DiscoPubsubItem(parent_node=node_name, jid=item_jid, name=name)
                )

        # Second pass: build hierarchy using parent/link relationships.
        root_nodes: list[DiscoPubsubNode] = []

        for node in nodes.values():
            parent_name = node.metadata.parent
            link_name = node.metadata.link

            if parent_name is not None:
                parent_node = nodes.get(parent_name)
                if parent_node is not None:
                    parent_node.add_child(node)
                else:
                    log.warning(
                        f"Parent node {parent_name!r} not found for node "
                        f"{node.name!r}, treating it as a root node."
                    )
                    root_nodes.append(node)
            elif link_name is not None:
                linked_node = nodes.get(link_name)
                if linked_node is not None:
                    # ``node`` is the linking node, it's attached to the node it links
                    # to.
                    node.link_to(linked_node)
                else:
                    log.warning(
                        f"Linked node {link_name!r} not found for node "
                        f"{node.name!r}, treating it as a root node."
                    )
                    root_nodes.append(node)
            else:
                # No parent and no link: this is a root node.
                root_nodes.append(node)

        # Third pass: assign items to their corresponding nodes.
        root_items: list[DiscoPubsubItem] = []

        for item in items:
            target_node = nodes.get(item.parent_node)
            if target_node is not None:
                if target_node.items is None:
                    target_node.items = []
                target_node.items.append(item)
            else:
                # Item's parent node not found, treat as root item.
                root_items.append(item)

        result_items: list[DiscoPubsubNode | DiscoPubsubItem] = []
        result_items.extend(root_nodes)
        result_items.extend(root_items)

        return cls(root=result_items)

    def to_elements(self) -> list[domish.Element]:
        """Return the flat list of disco <item> elements for this whole result."""
        elements = []
        for item in self.root:
            if isinstance(item, DiscoPubsubNode):
                elements.extend(item.to_elements())
            else:
                elements.append(item.to_element())
        return elements


class RequestHandler(NamedTuple):
    """An extended discovery request handler with its priority."""

    # Called with the session, the request <iq> element and the parsed options. The
    # result may be returned directly, from a coroutine, or through a Deferred.
    callback: Callable[
        [SatXMPPComponent, domish.Element, ExtDiscoOptions],
        ExtDiscoResult
        | Coroutine[None, None, ExtDiscoResult]
        | defer.Deferred[ExtDiscoResult],
    ]
    # Handlers with a higher priority are tried first.
    priority: int


class XEP_0499:
    """Pubsub Extended Discovery plugin.

    Extended disco#items requests are answered by the request handlers registered with
    ``register_handler``. ``disco_get`` (bridge method ``ps_disco_get``) sends an
    extended discovery request to a service.
    """

    namespace = NS_PUBSUB_EXT_DISCO

    def __init__(self, host: "LiberviaBackend") -> None:
        log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization")
        self.host = host
        # Registered request handlers, sorted by descending priority.
        self.handlers: list[RequestHandler] = []

        host.register_namespace("pubsub-ext-disco", NS_PUBSUB_EXT_DISCO)
        host.register_namespace("pubsub-relationships", NS_PUBSUB_RELATIONSHIPS)
        host.bridge.add_method(
            "ps_disco_get",
            ".plugin",
            in_sign="ssss",
            out_sign="s",
            method=self._disco_get,
            async_=True,
        )

    def get_handler(self, client: SatXMPPEntity) -> "PubsubExtDiscoHandler":
        """Return the XMPP handler for ``client``."""
        return PubsubExtDiscoHandler(self)

    def register_handler(
        self,
        callback: Callable[
            [SatXMPPComponent, domish.Element, ExtDiscoOptions],
            ExtDiscoResult
            | Coroutine[None, None, ExtDiscoResult]
            | defer.Deferred[ExtDiscoResult],
        ],
        priority: int = 0,
    ) -> None:
        """Register an extended discovery request handler.

        @param callback: Method to call when a request is received. It may return the
            result directly, from a coroutine, or through a Deferred. It must return an
            ExtDiscoResult, or None to let the next handler try.
            If the callback raises a StanzaError, its condition will be used if no other
            callback can handle the request.
        @param priority: Handlers with higher priorities will be called first.
        @raise ValueError: This callback is already registered.
        """
        # ``self.handlers`` contains RequestHandler tuples, so compare their callbacks.
        if any(handler.callback == callback for handler in self.handlers):
            raise ValueError(
                f"Extended discovery handler {callback!r} is already registered."
            )
        self.handlers.append(RequestHandler(callback, priority))
        self.handlers.sort(key=lambda handler: handler.priority, reverse=True)

    def _handle_disco_items_request(
        self, iq_elt: domish.Element, client: SatXMPPEntity
    ) -> None:
        """Dispatch an incoming disco#items request.

        Requests without an extended discovery form are forwarded to Wokkel's disco
        handler. Requests with an invalid form get a "bad-request" error. Other
        requests are answered by ``handle_disco_items_request``.
        """
        query_elt = iq_elt.query
        assert query_elt is not None
        ext_disco_form = data_form.findForm(query_elt, NS_PUBSUB_EXT_DISCO)
        if ext_disco_form is None:
            # This is a normal disco request, we transmit to Wokkel's disco handler.
            client.discoHandler.handleRequest(iq_elt)
            return
        # We have an extended pubsub discovery request (the form is present), we continue.
        iq_elt.handled = True
        try:
            ext_disco_options = ExtDiscoOptions.from_data_form(ext_disco_form)
        except Exception as e:
            # The form comes from a remote entity: reply with an error instead of
            # leaving the request (already marked as handled) unanswered.
            log.warning(f"Invalid extended discovery form received: {e}")
            client.sendError(iq_elt, "bad-request", text=str(e))
            return
        defer.ensureDeferred(
            self.handle_disco_items_request(client, iq_elt, ext_disco_options)
        )

    async def handle_disco_items_request(
        self,
        client: SatXMPPEntity,
        iq_elt: domish.Element,
        ext_disco_options: ExtDiscoOptions,
    ) -> None:
        """Answer an extended disco#items request using the registered handlers.

        Handlers are called by descending priority until one returns a non-None result.
        If a handler raises a StanzaError, the next handlers are tried, and the first
        such error is sent back if none of them returns a result. Any other exception
        is logged and answered with an "internal-server-error". If no handler returns
        a result and none raised a StanzaError, an empty result is sent.

        @param client: Session receiving the request.
        @param iq_elt: Request <iq> element.
        @param ext_disco_options: Options parsed from the extended discovery form.
        """
        query_elt = iq_elt.query
        assert query_elt is not None
        disco_result: ExtDiscoResult | None = None
        stanza_error: error.StanzaError | None = None
        for handler in self.handlers:
            try:
                disco_result = await utils.as_deferred(
                    handler.callback, client, iq_elt, ext_disco_options
                )
            except error.StanzaError as e:
                # Let lower-priority handlers try, this error is only used if none of
                # them can handle the request.
                if stanza_error is None:
                    stanza_error = e
                continue
            except Exception as e:
                log.exception("Can't retrieve disco data.")
                client.sendError(iq_elt, "internal-server-error", text=str(e))
                # Already logged; re-raising would only end in an unobserved Deferred.
                return
            if disco_result is not None:
                break

        if disco_result is None:
            if stanza_error is not None:
                client.sendError(
                    iq_elt, stanza_error.condition, text=stanza_error.text
                )
                return
            # No handler did return a result.
            disco_result = ExtDiscoResult([])

        iq_result_elt = xmlstream.toResponse(iq_elt, "result")
        result_query_elt = iq_result_elt.addElement((disco.NS_DISCO_ITEMS, "query"))
        if query_elt.hasAttribute("node"):
            result_query_elt["node"] = query_elt["node"]
        for elt in disco_result.to_elements():
            result_query_elt.addChild(elt)
        client.send(iq_result_elt)

    def _disco_get(
        self, service: str, node: str, options_s: str, profile: str
    ) -> defer.Deferred[str]:
        """Bridge wrapper for ``disco_get``.

        @param service: JID of the service to query.
        @param node: Node to query, empty string for the service root.
        @param options_s: JSON serialised ExtDiscoOptions, empty string for defaults.
        @param profile: Profile to use.
        @return: JSON serialised ExtDiscoResult (None values excluded).
        """
        client = self.host.get_client(profile)
        service_jid = JID(service)
        options = (
            ExtDiscoOptions.model_validate_json(options_s)
            if options_s
            else ExtDiscoOptions()
        )
        d = defer.ensureDeferred(
            self.disco_get(client, service_jid, node or None, options)
        )
        d.addCallback(
            lambda ext_disco_result: ext_disco_result.model_dump_json(exclude_none=True)
        )
        d = cast(defer.Deferred[str], d)
        return d

    async def disco_get(
        self,
        client: SatXMPPEntity,
        service: JID,
        node: str | None,
        options: ExtDiscoOptions,
    ) -> ExtDiscoResult:
        """Send an extended discovery request to a service.

        @param client: Session to use.
        @param service: JID of the service to query.
        @param node: Node to query, or None for the service root.
        @param options: Extended discovery options to send.
        @return: Discovered nodes and items, with nodes hierarchy reconstructed.
        @raise exceptions.DataError: The result has no <query> element.
        """
        query_elt = domish.Element((disco.NS_DISCO_ITEMS, "query"))
        if node is not None:
            query_elt["node"] = node

        query_elt.addChild(options.to_element())
        iq_elt = client.IQ("get")
        iq_elt["to"] = service.full()
        iq_elt.addChild(query_elt)
        iq_result_elt = await iq_elt.send()
        try:
            result_query_elt = next(
                iq_result_elt.elements(disco.NS_DISCO_ITEMS, "query")
            )
        except StopIteration:
            raise exceptions.DataError(
                "<query> missing in disco result, this is invalid."
            )
        return ExtDiscoResult.from_element(result_query_elt)


@implementer(disco.IDisco)
class PubsubExtDiscoHandler(XMPPHandler):
    """Handler for pubsub extended discovery requests."""

    def __init__(self, plugin_parent: XEP_0499) -> None:
        # Initialise XMPPHandler state (``parent`` and ``xmlstream``).
        super().__init__()
        self.plugin_parent = plugin_parent

    @property
    def client(self) -> SatXMPPEntity:
        """Session this handler is attached to."""
        return cast(SatXMPPEntity, self.parent)

    def connectionInitialized(self):
        """Route disco#items requests through the plugin, for components only.

        Wokkel's disco#items observer is replaced by
        ``XEP_0499._handle_disco_items_request``, which forwards requests that are not
        extended discovery ones back to Wokkel's disco handler.
        """
        assert self.xmlstream is not None
        if self.client.is_component:
            # We have to remove Wokkel's disco handler to avoid stanza to be replied
            # twice.
            self.xmlstream.removeObserver(
                disco.DISCO_ITEMS, self.client.discoHandler.handleRequest
            )
            self.xmlstream.addObserver(
                disco.DISCO_ITEMS,
                self.plugin_parent._handle_disco_items_request,
                client=self.client,
            )

    def getDiscoInfo(
        self, requestor: JID, target: JID, nodeIdentifier: str = ""
    ) -> list[disco.DiscoFeature]:
        """Get disco info for pubsub extended discovery

        @param requestor: JID of the requesting entity
        @param target: JID of the target entity
        @param nodeIdentifier: optional node identifier
        @return: list of disco features
        """
        return [
            disco.DiscoFeature(NS_PUBSUB_EXT_DISCO),
        ]

    def getDiscoItems(
        self, requestor: JID, target: JID, nodeIdentifier: str = ""
    ) -> list[disco.DiscoItem]:
        """Get disco items with extended discovery support

        @param requestor: JID of the requesting entity
        @param target: JID of the target entity
        @param nodeIdentifier: optional node identifier
        @return: list of disco items
        """
        # We return an empty list for Wokkel's disco handling; extended discovery is
        # answered by ``XEP_0499.handle_disco_items_request``.
        return []
