from typing import Any, Callable, List, Literal, Optional, Type

import pytest

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool, tool
from langchain_core.utils.function_calling import convert_to_openai_function


@pytest.fixture()
def pydantic() -> Type[BaseModel]:
    """Provide a pydantic model class describing the dummy function's arguments."""

    # The class name and docstring are deliberate: the conversion test in this
    # file expects them as the function's "name" and "description", and the
    # Field descriptions as the per-argument descriptions.
    class dummy_function(BaseModel):
        """dummy function"""

        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    return dummy_function


@pytest.fixture()
def function() -> Callable:
    """Provide a plain function describing the dummy function's arguments."""

    # The body is a no-op. The conversion test in this file expects the
    # docstring's summary line as the "description" and its ``Args`` entries
    # as the per-argument descriptions, so the docstring text is test data.
    def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
        """dummy function

        Args:
            arg1: foo
            arg2: one of 'bar', 'baz'
        """
        pass

    return dummy_function


@pytest.fixture()
def dummy_tool() -> BaseTool:
    """Provide a ``BaseTool`` instance describing the dummy function's arguments."""

    # Argument schema for the tool: an int ``arg1`` and an ``arg2`` limited to
    # the literals "bar" and "baz", each with a description.
    class Schema(BaseModel):
        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    # Tool wired to the schema above; ``name`` and ``description`` match the
    # values the conversion test in this file expects.
    class DummyFunction(BaseTool):
        args_schema: Type[BaseModel] = Schema
        name: str = "dummy_function"
        description: str = "dummy function"

        def _run(self, *args: Any, **kwargs: Any) -> Any:
            # Intentionally does nothing and returns None.
            pass

    return DummyFunction()


def test_convert_to_openai_function(
    pydantic: Type[BaseModel], function: Callable, dummy_tool: BaseTool
) -> None:
    """Check that four kinds of input convert to the same OpenAI function dict.

    The inputs are a pydantic model class, a plain function, a ``BaseTool``
    instance and the expected dict itself; each must convert to ``expected``.

    Args:
        pydantic: Model class supplied by the ``pydantic`` fixture.
        function: Plain function supplied by the ``function`` fixture.
        dummy_tool: Tool instance supplied by the ``dummy_tool`` fixture.
    """
    expected = {
        "name": "dummy_function",
        "description": "dummy function",
        "parameters": {
            "type": "object",
            "properties": {
                "arg1": {"description": "foo", "type": "integer"},
                "arg2": {
                    "description": "one of 'bar', 'baz'",
                    "enum": ["bar", "baz"],
                    "type": "string",
                },
            },
            "required": ["arg1", "arg2"],
        },
    }

    for fn in (pydantic, function, dummy_tool, expected):
        actual = convert_to_openai_function(fn)  # type: ignore
        # All four inputs share one assert, so name the input in the failure
        # message; otherwise a failure does not reveal which kind broke.
        assert actual == expected, f"unexpected conversion result for input {fn!r}"


@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
def test_function_optional_param() -> None:
    """Only the non-Optional parameter of a tool should be listed as required."""

    @tool
    def func5(
        a: Optional[str],
        b: str,
        c: Optional[List[Optional[str]]],
    ) -> None:
        """A test function"""

    # ``a`` and ``c`` are Optional, so ``b`` is the only name that should
    # appear under "required" in the converted schema.
    openai_function = convert_to_openai_function(func5)
    required = set(openai_function["parameters"]["required"])
    assert required == {"b"}
