"""
SQL composition utility module
"""

# Copyright (C) 2020 The Psycopg Team

from __future__ import annotations

import codecs
import string
from abc import ABC, abstractmethod
from typing import Any, overload
from collections.abc import Iterable, Iterator, Sequence

from .pq import Escaping
from .abc import AdaptContext
from ._enums import PyFormat
from ._compat import LiteralString, Template
from ._encodings import conn_encoding
from ._transformer import Transformer


def quote(obj: Any, context: AdaptContext | None = None) -> str:
    """
    Return the quoted SQL literal representation of a Python object.

    This helper is meant for the rare cases in which a quoted SQL literal
    must be produced ahead of time, e.g. to generate batch SQL when no
    connection will be available at the moment the value is needed.

    The function is relatively inefficient, because the adaptation rules are
    not cached between calls. Pass a `!context` to customize the adaptation
    rules used; without one, only the global rules apply.

    :param obj: the Python object to convert.
    :param context: the optional context to take the adaptation rules from.
    :return: the object rendered as an SQL literal.
    """
    literal = Literal(obj)
    return literal.as_string(context)


class Composable(ABC):
    """
    Base class of every object usable to compose an SQL string.

    Two `!Composable` objects can be concatenated with the ``+`` operator,
    obtaining a `Composed` instance containing both. The ``*`` operator,
    applied to an integer, gives a `!Composed` instance in which the left
    operand is repeated the requested number of times.

    `!SQL` and `!Composed` objects can be passed directly to
    `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
    `~psycopg.Cursor.copy()` in place of the query string.
    """

    def __init__(self, obj: Any):
        # The wrapped value: its meaning depends on the concrete subclass.
        self._obj = obj

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._obj!r})"

    @abstractmethod
    def as_bytes(self, context: AdaptContext | None = None) -> bytes:
        """
        Return the value of the object as bytes.

        :param context: the context to evaluate the object into.
        :type context: `connection` or `cursor`

        The method is automatically invoked by `~psycopg.Cursor.execute()`,
        `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
        `!Composable` is passed instead of the query string.

        """
        raise NotImplementedError

    def as_string(self, context: AdaptContext | None = None) -> str:
        """
        Return the value of the object as string.

        :param context: the context to evaluate the string into.
        :type context: `connection` or `cursor`

        The string is obtained decoding the result of `as_bytes()` with the
        encoding of the connection found in `!context`, if any.
        """
        conn = context.connection if context else None
        encoding = conn_encoding(conn)
        data = self.as_bytes(context)
        if not isinstance(data, bytes):
            # Not real bytes but a buffer object: it has no decode() method,
            # so go through the codec machinery.
            return codecs.lookup(encoding).decode(data)[0]
        return data.decode(encoding)

    def __add__(self, other: Composable) -> Composed:
        if not isinstance(other, Composable):
            return NotImplemented
        # Delegate to Composed.__add__, wrapping the operand if necessary.
        rhs = other if isinstance(other, Composed) else Composed([other])
        return Composed([self]) + rhs

    def __mul__(self, n: int) -> Composed:
        repeated = [self] * n
        return Composed(repeated)

    def __eq__(self, other: Any) -> bool:
        # Only objects of exactly the same class can be equal.
        if type(self) is not type(other):
            return False
        return self._obj == other._obj

    def __ne__(self, other: Any) -> bool:
        equal = self.__eq__(other)
        return not equal


class Composed(Composable):
    """
    A `Composable` object made of a sequence of `!Composable`.

    Instances are usually the result of `!Composable` operators and methods
    (such as the `SQL.format()` method). `!Composed` objects can be passed
    directly to `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
    `~psycopg.Cursor.copy()` in place of the query string.

    A `!Composed` can also be created directly, passing a sequence of objects
    as argument: the items which are not `!Composable` will be wrapped in a
    `Literal`.

    Example::

        >>> comp = sql.Composed(
        ...     [sql.SQL("INSERT INTO "), sql.Identifier("table")])
        >>> print(comp.as_string(conn))
        INSERT INTO "table"

    `!Composed` objects are iterable (so they can be used in `SQL.join` for
    instance).
    """

    _obj: list[Composable]

    def __init__(self, seq: Sequence[Any]):
        items = []
        for item in seq:
            # Anything that is not already a Composable is taken as a value.
            if not isinstance(item, Composable):
                item = Literal(item)
            items.append(item)
        super().__init__(items)

    def as_bytes(self, context: AdaptContext | None = None) -> bytes:
        chunks = [part.as_bytes(context) for part in self._obj]
        return b"".join(chunks)

    def __iter__(self) -> Iterator[Composable]:
        return iter(self._obj)

    def __add__(self, other: Composable) -> Composed:
        if not isinstance(other, Composable):
            return NotImplemented
        # Flatten another Composed instead of nesting it.
        tail = other._obj if isinstance(other, Composed) else [other]
        return Composed(self._obj + tail)

    def join(self, joiner: SQL | LiteralString) -> Composed:
        """
        Return a new `!Composed` with the `!joiner` between each of the items.

        :param joiner: the separator: it must be a `SQL` or a string, which
            will be interpreted as an `SQL`.
        :raises TypeError: if `!joiner` is neither a string nor a `SQL`.

        Example::

            >>> fields = sql.Identifier('foo') + sql.Identifier('bar')  # a Composed
            >>> print(fields.join(', ').as_string(conn))
            "foo", "bar"

        """
        if isinstance(joiner, SQL):
            return joiner.join(self._obj)

        if isinstance(joiner, str):
            return SQL(joiner).join(self._obj)

        raise TypeError(
            "Composed.join() argument must be strings or SQL,"
            f" got {joiner!r} instead"
        )


class SQL(Composable):
    """
    A `Composable` representing a snippet of SQL statement.

    `!SQL` exposes `join()` and `format()` methods useful to create a template
    where to merge variable parts of a query (for instance field or table
    names).

    The `!obj` string doesn't undergo any form of escaping, so it is not
    suitable to represent variable identifiers or values: you should only use
    it to pass constant strings representing templates or snippets of SQL
    statements; use other objects such as `Identifier` or `Literal` to
    represent variable parts.

    `!SQL` objects can be passed directly to `~psycopg.Cursor.execute()`,
    `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` in place of the
    query string.

    Example::

        >>> query = sql.SQL("SELECT {0} FROM {1}").format(
        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
        ...    sql.Identifier('table'))
        >>> print(query.as_string(conn))
        SELECT "foo", "bar" FROM "table"
    """

    _obj: LiteralString
    # Shared parser used by format() to split the template into its parts.
    _formatter = string.Formatter()

    def __init__(self, obj: LiteralString):
        """
        :param obj: the SQL snippet to wrap.
        :raises TypeError: if `!obj` is not a string.
        """
        super().__init__(obj)
        if not isinstance(obj, str):
            raise TypeError(f"SQL values must be strings, got {obj!r} instead")

    def as_string(self, context: AdaptContext | None = None) -> str:
        # The snippet is returned verbatim: the context is not needed.
        return self._obj

    def as_bytes(self, context: AdaptContext | None = None) -> bytes:
        conn = context.connection if context else None
        enc = conn_encoding(conn)
        return self._obj.encode(enc)

    def format(self, *args: Any, **kwargs: Any) -> Composed:
        """
        Merge `Composable` objects into a template.

        :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
            auto-numbered (``{}``) placeholders
        :param kwargs: parameters to replace to named (``{name}``) placeholders
        :return: the union of the `!SQL` string with placeholders replaced
        :rtype: `Composed`
        :raises ValueError: if the template contains a format specification
            or a conversion, or if it mixes automatic and manual numbering.

        The method is similar to the Python `str.format()` method: the string
        template supports auto-numbered (``{}``), numbered (``{0}``,
        ``{1}``...), and named placeholders (``{name}``), with positional
        arguments replacing the numbered placeholders and keywords replacing
        the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
        are not supported.

        If a `!Composable` object is passed to the template it will be merged
        according to its `as_string()` method. If any other Python object is
        passed, it will be wrapped in a `Literal` object and so escaped
        according to SQL rules.

        Example::

            >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
            ...     .format(sql.Identifier('people'), sql.Identifier('id'))
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE "id" = %s

            >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
            ...     .format(tbl=sql.Identifier('people'), name="O'Rourke")
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE name = 'O''Rourke'

        """
        rv: list[Composable] = []
        # Index of the next auto-numbered placeholder; None once a manually
        # numbered placeholder has been found.
        autonum: int | None = 0
        # TODO: this is probably not the right way to whitelist pre
        # pyre complains. Will wait for mypy to complain too to fix.
        pre: LiteralString
        for pre, name, spec, conv in self._formatter.parse(self._obj):
            if spec:
                raise ValueError("no format specification supported by SQL")
            if conv:
                raise ValueError("no format conversion supported by SQL")
            if pre:
                rv.append(SQL(pre))

            # Trailing text with no placeholder after it.
            if name is None:
                continue

            if name.isdigit():
                # A non-zero autonum means that a "{}" was already consumed.
                if autonum:
                    raise ValueError(
                        "cannot switch from automatic field numbering to manual"
                    )
                rv.append(args[int(name)])
                autonum = None

            elif not name:
                if autonum is None:
                    raise ValueError(
                        "cannot switch from manual field numbering to automatic"
                    )
                rv.append(args[autonum])
                autonum += 1

            else:
                rv.append(kwargs[name])

        # Composed wraps in a Literal whatever is not a Composable.
        return Composed(rv)

    @overload
    def join(self, seq: Iterable[Template]) -> Template: ...

    @overload
    def join(self, seq: Iterable[Any]) -> Composed: ...

    def join(self, seq: Iterable[Any]) -> Composed | Template:
        """
        Join a sequence of `Composable`.

        :param seq: the elements to join.
        :raises TypeError: if `!seq` contains both `Template` and
            non-`Template` items.

        Use the `!SQL` object's string to separate the elements in `!seq`.
        Elements that are not `Composable` will be considered `Literal`.

        If the arguments are `Template` instances, return a `Template` joining
        all the items. Note that arguments must either be all templates or
        none should be.

        Note that `Composed` objects are iterable too, so they can be used as
        argument for this method.

        Example::

            >>> snip = sql.SQL(', ').join(
            ...     sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
            >>> print(snip.as_string(conn))
            "foo", "bar", "baz"
        """

        it = iter(seq)
        try:
            first = next(it)
        except StopIteration:
            return Composed([])

        # The type of the first item decides what the others must be.
        if isinstance(first, Template):
            items = list(first)
            for t in it:
                if not isinstance(t, Template):
                    raise TypeError(f"can't mix Template and {type(t).__name__}")
                items.append(self._obj)
                items.extend(t)
            return Template(*items)

        cs = [first]
        for i in it:
            if isinstance(i, Template):
                # Here `i` is always a Template: report the type of the item
                # it clashes with instead.
                raise TypeError(f"can't mix {type(first).__name__} and Template")
            cs.append(self)
            cs.append(i)

        return Composed(cs)


class Identifier(Composable):
    """
    A `Composable` representing an SQL identifier or a dot-separated sequence.

    Identifiers usually represent names of database objects, such as tables or
    fields. PostgreSQL identifiers follow `different rules`__ than SQL string
    literals for escaping (e.g. they use double quotes instead of single).

    .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
        SQL-SYNTAX-IDENTIFIERS

    Example::

        >>> t1 = sql.Identifier("foo")
        >>> t2 = sql.Identifier("ba'r")
        >>> t3 = sql.Identifier('ba"z')
        >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
        "foo", "ba'r", "ba""z"

    Passing more than one string creates a qualified name, i.e. a
    dot-separated sequence of identifiers.

    Example::

        >>> query = sql.SQL("SELECT {} FROM {}").format(
        ...     sql.Identifier("table", "field"),
        ...     sql.Identifier("schema", "table"))
        >>> print(query.as_string(conn))
        SELECT "table"."field" FROM "schema"."table"

    """

    _obj: Sequence[str]

    def __init__(self, *strings: str):
        """
        :param strings: the parts of the name: at least one is required.
        :raises TypeError: if no part is passed or if any is not a string.
        """
        # init super() now to make the __repr__ not explode in case of error
        super().__init__(strings)

        if len(strings) == 0:
            raise TypeError("Identifier cannot be empty")

        for part in strings:
            if isinstance(part, str):
                continue
            raise TypeError(
                f"SQL identifier parts must be strings, got {part!r} instead"
            )

    def __repr__(self) -> str:
        args = ", ".join(repr(part) for part in self._obj)
        return f"{type(self).__name__}({args})"

    def as_bytes(self, context: AdaptContext | None = None) -> bytes:
        conn = context.connection if context else None
        if conn:
            # A connection is available: let it do the escaping.
            escaper = Escaping(conn.pgconn)
            encoding = conn_encoding(conn)
            quoted = [
                escaper.escape_identifier(part.encode(encoding))
                for part in self._obj
            ]
        else:
            # No connection: fall back on the local approximation.
            quoted = [self._escape_identifier(part.encode()) for part in self._obj]
        return b".".join(quoted)

    def _escape_identifier(self, s: bytes) -> bytes:
        """
        Approximation of PQescapeIdentifier taking no connection.

        Wrap `!s` in double quotes, doubling the ones it contains.
        """
        doubled = s.replace(b'"', b'""')
        return b"".join((b'"', doubled, b'"'))


class Literal(Composable):
    """
    A `Composable` representing an SQL value to include in a query.

    The usual way to pass values to a query is to use placeholders and to
    pass the values as `~cursor.execute()` arguments. Use this object only
    if you really need to include a literal value in the query itself.

    The string returned by `!as_string()` follows the normal :ref:`adaptation
    rules <types-adaptation>` for Python objects.

    Example::

        >>> s1 = sql.Literal("fo'o")
        >>> s2 = sql.Literal(42)
        >>> s3 = sql.Literal(date(2000, 1, 1))
        >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
        'fo''o', 42, '2000-01-01'::date

    """

    def as_bytes(self, context: AdaptContext | None = None) -> bytes:
        # The conversion is delegated to a transformer built for the context.
        return Transformer.from_context(context).as_literal(self._obj)


class Placeholder(Composable):
    """A `Composable` representing a placeholder for query parameters.

    If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
    ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
    ``%b``).

    The object is useful to generate SQL queries with a variable number of
    arguments.

    Two placeholders are equal only if both their name and their format are
    the same.

    Examples::

        >>> names = ['foo', 'bar', 'baz']

        >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(sql.Placeholder() * len(names)))
        >>> print(q1.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)

        >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(map(sql.Placeholder, names)))
        >>> print(q2.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)

    """

    def __init__(self, name: str = "", format: str | PyFormat = PyFormat.AUTO):
        """
        :param name: the name of the placeholder; empty for a positional one.
        :param format: the format of the placeholder, as a `PyFormat` or as
            the string value of one.
        :raises TypeError: if `!name` is not a string or `!format` is not a
            `PyFormat` (or a string).
        :raises ValueError: if `!name` contains a ``)`` or if `!format` is a
            string which is not a valid `PyFormat` value.
        """
        super().__init__(name)
        if not isinstance(name, str):
            raise TypeError(f"expected string as name, got {name!r}")

        # A closing parenthesis would terminate the "%(name)s" prematurely.
        if ")" in name:
            raise ValueError(f"invalid name: {name!r}")

        if type(format) is str:
            format = PyFormat(format)
        if not isinstance(format, PyFormat):
            raise TypeError(
                f"expected PyFormat as format, got {type(format).__name__!r}"
            )

        self._format: PyFormat = format

    def __repr__(self) -> str:
        # Only show the arguments differing from the defaults.
        parts = []
        if self._obj:
            parts.append(repr(self._obj))
        if self._format is not PyFormat.AUTO:
            parts.append(f"format={self._format.name}")

        return f"{self.__class__.__name__}({', '.join(parts)})"

    def __eq__(self, other: Any) -> bool:
        # The base implementation only compares the name, but placeholders
        # with a different format are rendered differently by as_string(),
        # so they must not be considered equal.
        return super().__eq__(other) and self._format == other._format

    def as_string(self, context: AdaptContext | None = None) -> str:
        code = self._format.value
        return f"%({self._obj}){code}" if self._obj else f"%{code}"

    def as_bytes(self, context: AdaptContext | None = None) -> bytes:
        conn = context.connection if context else None
        enc = conn_encoding(conn)
        return self.as_string(context).encode(enc)


# Ready-made `SQL` snippets for the NULL and DEFAULT keywords, usable
# wherever a `Composable` is accepted.
NULL = SQL("NULL")
DEFAULT = SQL("DEFAULT")


def as_string(obj: Any, context: AdaptContext | None = None) -> str:
    """Convert an object to a string according to SQL rules.

    :param obj: the object to convert
    :param context: the context in which to convert the object
    :type context: `~psycopg.abc.AdaptContext` | `!None`

    The conversion depends on the type of `!obj`:

    - `Composable` objects are converted according to their
      `~Composable.as_string()` method;
    - `~string.templatelib.Template` strings are converted according to the
      rules documented in :ref:`template-strings`;
    - every other object is converted as if it was :ref:`a parameter passed
      to a query <types-adaptation>`.

    If `!context` is specified then it is used to customize the conversion,
    for example using the encoding of a connection or the dumpers registered.
    """
    if isinstance(obj, Composable):
        return obj.as_string(context=context)

    if isinstance(obj, Template):
        # Imported at call time only -- presumably to avoid an import cycle
        # with the template strings module (to be confirmed).
        from ._tstrings import as_string as template_as_string

        return template_as_string(obj, context)

    # Anything else is treated as a value.
    literal = Literal(obj)
    return literal.as_string(context=context)


def as_bytes(obj: Any, context: AdaptContext | None = None) -> bytes:
    """Convert an object to a bytes string according to SQL rules.

    :param obj: the object to convert
    :param context: the context in which to convert the object
    :type context: `~psycopg.abc.AdaptContext` | `!None`

    `Composable` objects are converted using their `~Composable.as_bytes()`
    method, `~string.templatelib.Template` strings using the template strings
    rules, every other object is converted as a `Literal`.

    See `as_string()` for details.
    """
    if isinstance(obj, Composable):
        return obj.as_bytes(context=context)

    if isinstance(obj, Template):
        # Imported at call time only -- presumably to avoid an import cycle
        # with the template strings module (to be confirmed).
        from ._tstrings import as_bytes as template_as_bytes

        return template_as_bytes(obj, context)

    # Anything else is treated as a value.
    literal = Literal(obj)
    return literal.as_bytes(context=context)
