"""Test yamlOutputParser"""

from enum import Enum
from typing import Optional

import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain.output_parsers.yaml import YamlOutputParser


class Actions(Enum):
    """Closed set of values accepted by the ``action`` field of ``TestModel``.

    Values are capitalized, so a lowercase spelling such as ``"update"`` does
    not match any member.
    """

    SEARCH = "Search"
    CREATE = "Create"
    UPDATE = "Update"
    DELETE = "Delete"


# Schema the YAML fixtures below are parsed into. A ``#`` comment is used here
# instead of a class docstring because pydantic copies the docstring into the
# model's JSON schema description.
class TestModel(BaseModel):
    # Must be one of the ``Actions`` values (case-sensitive).
    action: Actions = Field(description="Action to be performed")
    action_input: str = Field(description="Input to be used in the action")
    # Optional: may be ``null`` in YAML or omitted entirely.
    additional_fields: Optional[str] = Field(
        default=None, description="Additional fields"
    )
    # Required; the fixtures fill it from a YAML literal block scalar.
    for_new_lines: str = Field(description="To be used to test newlines")


# ``TestModel`` starts with ``Test``, so pytest would otherwise try to collect
# it as a test class; ``__test__ = False`` opts it out of collection.
setattr(TestModel, "__test__", False)


# A fenced (```yaml) completion with a leading ``---`` document marker.
# Written as concatenated literals so the trailing space after
# ``escape_newline:`` is visible and cannot be silently stripped by editors or
# trailing-whitespace hooks; ``DEF_EXPECTED_RESULT`` depends on it.
DEF_RESULT = (
    "```yaml\n"
    "---\n"
    "\n"
    "action: Update\n"
    "action_input: The yamlOutputParser class is powerful\n"
    "additional_fields: null\n"
    "for_new_lines: |\n"
    "  not_escape_newline:\n"
    "   escape_newline: \n"  # trailing space is intentional
    "\n"
    "```"
)
# The same document as the fenced fixture, but bare: no ```yaml fence and no
# ``---`` marker. Both are expected to parse into the same model. The trailing
# space after ``escape_newline:`` is spelled out so it cannot be stripped.
DEF_RESULT_NO_BACKTICKS = (
    "\n"
    "action: Update\n"
    "action_input: The yamlOutputParser class is powerful\n"
    "additional_fields: null\n"
    "for_new_lines: |\n"
    "  not_escape_newline:\n"
    "   escape_newline: \n"  # trailing space is intentional
    "\n"
)

# ``action: update`` uses a lowercase "u", which matches no ``Actions`` value,
# to exercise schema-validation failure. Every other field is valid (including
# the required ``for_new_lines``), so the enum mismatch is the only error and
# the test cannot pass merely because a required field is missing.
DEF_RESULT_FAIL = """```yaml
action: update
action_input: The yamlOutputParser class is powerful
additional_fields: null
for_new_lines: plain text
```"""

# Model that both successful fixtures must parse into. The ``|`` literal block
# keeps the second line's extra indent and its trailing space, and with no
# chomping indicator exactly one final newline survives.
DEF_EXPECTED_RESULT = TestModel(
    action=Actions.UPDATE,
    action_input="The yamlOutputParser class is powerful",
    additional_fields=None,
    for_new_lines=(
        "not_escape_newline:\n"
        " escape_newline: \n"  # leading and trailing spaces are intentional
    ),
)


@pytest.mark.parametrize(
    "result",
    [DEF_RESULT, DEF_RESULT_NO_BACKTICKS],
    ids=["fenced", "bare"],
)
def test_yaml_output_parser(result: str) -> None:
    """Parse fenced and bare YAML completions into the expected ``TestModel``.

    Both fixtures encode the same document, so both must yield the same model.
    """
    yaml_parser: YamlOutputParser[TestModel] = YamlOutputParser(
        pydantic_object=TestModel
    )

    model = yaml_parser.parse(result)

    # No debug print: pytest's assertion rewriting already shows both models
    # when they differ.
    assert model == DEF_EXPECTED_RESULT


def test_yaml_output_parser_fail() -> None:
    """Reject a completion whose YAML violates the ``TestModel`` schema."""
    yaml_parser: YamlOutputParser[TestModel] = YamlOutputParser(
        pydantic_object=TestModel
    )

    # ``pytest.raises`` fails the test when nothing is raised; unlike the
    # previous ``assert False`` fallback it is not stripped under ``python -O``.
    # ``match`` is applied with ``re.search`` to the exception's string form.
    with pytest.raises(
        OutputParserException, match="Failed to parse TestModel from completion"
    ):
        yaml_parser.parse(DEF_RESULT_FAIL)
