import asyncio
import json
from typing import Any, Optional, Type

from langchain.agents import AgentType, initialize_agent
from langchain.llms import OpenAI
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

class MCPToolInput(BaseModel):
    """Argument schema for ``MCPTool``.

    Both fields are required: which MCP tool to invoke, and the mapping of
    arguments handed to that tool.
    """

    tool_name: str = Field(..., description="Name of the MCP tool to call")
    arguments: dict = Field(..., description="Arguments to pass to the MCP tool")

class MCPTool(BaseTool):
    """LangChain tool that forwards a call to a named tool on an MCP server.

    The caller supplies ``tool_name`` and ``arguments`` (see ``MCPToolInput``),
    which are passed to ``mcp_client.session.call_tool``. Failures are returned
    as text rather than raised, so an agent can read them as an observation.

    NOTE(review): ``args_schema`` has two fields, so this is a multi-input tool;
    presumably it needs an agent type that supports multi-input tools.
    """

    # Annotated so pydantic treats these as overrides of BaseTool's fields
    # (un-annotated overrides are rejected by pydantic-v2-based LangChain).
    name: str = "mcp_tool"
    description: str = "Execute tools through Model Context Protocol servers"
    args_schema: Type[BaseModel] = MCPToolInput
    # Declared as a field: BaseTool is a pydantic model, so assigning an
    # undeclared attribute in __init__ would raise.
    mcp_client: Any = None

    def __init__(self, mcp_client, **kwargs):
        """Bind the tool to ``mcp_client``; extra kwargs go to ``BaseTool``."""
        super().__init__(mcp_client=mcp_client, **kwargs)

    def _run(self, tool_name: str, arguments: dict) -> str:
        """Synchronous entry point: run ``_arun`` on a fresh event loop.

        Raises:
            RuntimeError: if called from a thread that already runs an event
                loop (asyncio.run cannot nest); use the async API instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no loop running in this thread, so asyncio.run is allowed
        else:
            # Fail before creating the coroutine, so no "never awaited" warning
            # is leaked and the message says what to do instead.
            raise RuntimeError(
                "MCPTool was called synchronously from inside a running event "
                "loop; use the async agent API (e.g. `await agent.arun(...)`)."
            )
        return asyncio.run(self._arun(tool_name, arguments))

    async def _arun(self, tool_name: str, arguments: dict) -> str:
        """Call ``tool_name`` on the MCP session and return its text output.

        Returns:
            The ``text`` of every content item that has one, joined by
            newlines. Returns a fixed message when there is no text. Returns an
            ``"Error executing MCP tool: ..."`` string when the call raises or
            the result is flagged ``isError``.
        """
        try:
            result = await self.mcp_client.session.call_tool(
                name=tool_name,
                arguments=arguments,
            )
            texts = [item.text for item in (result.content or []) if hasattr(item, "text")]
            body = "\n".join(texts)
            # MCP reports tool-level failures in-band via `isError`; don't
            # present them to the agent as a success.
            if getattr(result, "isError", False):
                return f"Error executing MCP tool: {body or 'server reported an error'}"
            return body or "Tool executed successfully but returned no content"
        except Exception as e:
            # Deliberate boundary catch: surface the failure to the agent as text.
            return f"Error executing MCP tool: {e}"

class _MCPServerToolInput(BaseModel):
    """Single-string input schema for tools built from an MCP server's tool list.

    A single field keeps each tool "single input". That is assumed to be what
    a text-only agent such as the ZERO_SHOT_REACT_DESCRIPTION agent built by
    ``MCPLangChainIntegration.create_agent`` needs.
    """

    tool_input: str = Field(
        default="",
        description="JSON object with the arguments for the MCP tool",
    )


class _MCPServerTool(BaseTool):
    """LangChain tool bound to one named tool on an MCP server.

    The agent's input is a single string that is turned into an arguments dict
    by ``_parse_arguments``. Failures are returned as ``"Error: ..."`` text so
    the agent sees them as observations instead of crashing.
    """

    args_schema: Type[BaseModel] = _MCPServerToolInput
    # Declared fields: BaseTool is a pydantic model, so undeclared attributes
    # cannot be assigned after construction.
    mcp_client: Any = None
    tool_name: str = ""
    # JSON Schema advertised by the server for this tool, if it was a dict.
    mcp_input_schema: Optional[dict] = None

    def _run(self, tool_input: str = "") -> str:
        """Synchronous entry point: run ``_arun`` on a fresh event loop.

        Raises:
            RuntimeError: if called from a thread that already runs an event
                loop (asyncio.run cannot nest); use the async API instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no loop running in this thread, so asyncio.run is allowed
        else:
            raise RuntimeError(
                f"MCP tool '{self.tool_name}' was called synchronously from inside "
                "a running event loop; use the async agent API "
                "(e.g. `await agent.arun(...)`)."
            )
        return asyncio.run(self._arun(tool_input))

    async def _arun(self, tool_input: str = "") -> str:
        """Parse ``tool_input``, call the bound MCP tool and return its text output."""
        try:
            arguments = self._parse_arguments(tool_input)
            result = await self.mcp_client.session.call_tool(
                name=self.tool_name,
                arguments=arguments,
            )
            texts = [item.text for item in (result.content or []) if hasattr(item, "text")]
            body = "\n".join(texts)
            # MCP reports tool-level failures in-band via `isError`.
            if getattr(result, "isError", False):
                return f"Error: {body or 'MCP server reported a tool error'}"
            return body or "Tool executed successfully"
        except Exception as e:
            # Deliberate boundary catch: surface the failure to the agent as text.
            return f"Error: {e}"

    def _parse_arguments(self, tool_input) -> dict:
        """Turn the agent's raw input into the ``arguments`` dict for ``call_tool``.

        - a dict is used as-is;
        - empty / whitespace-only input means "no arguments";
        - a JSON object string is decoded;
        - any other value (JSON scalar or plain text) is forwarded as the sole
          argument when the tool's schema declares exactly one property.

        Raises:
            ValueError: if the input cannot be mapped to an arguments dict.
        """
        if isinstance(tool_input, dict):
            return tool_input
        text = str(tool_input or "").strip()
        if not text:
            return {}
        try:
            value = json.loads(text)
        except json.JSONDecodeError:
            value = text  # not JSON: treat as a plain-text value
        if isinstance(value, dict):
            return value
        schema = self.mcp_input_schema if isinstance(self.mcp_input_schema, dict) else {}
        properties = schema.get("properties")
        if isinstance(properties, dict) and len(properties) == 1:
            (only_name,) = properties
            return {only_name: value}
        raise ValueError(
            f"expected a JSON object of arguments for tool '{self.tool_name}', got {text!r}"
        )


class MCPLangChainIntegration:
    """Expose an MCP server's tools as LangChain tools and build an agent over them.

    Usage: construct, ``await setup_tools()``, then call ``create_agent()``.
    """

    def __init__(self, llm, mcp_client):
        self.llm = llm
        self.mcp_client = mcp_client
        # Filled by setup_tools(): one LangChain tool per MCP server tool.
        self.tools = []

    async def setup_tools(self):
        """Fetch the server's tool list and rebuild ``self.tools`` from it.

        Rebuilding instead of appending means a repeated call does not
        register every tool twice.
        """
        listed = await self.mcp_client.list_tools()
        # NOTE(review): the client's list_tools() return shape is not visible
        # here; accept both a plain list and a result object exposing `.tools`.
        mcp_tools = getattr(listed, "tools", listed)
        # Slice assignment keeps the same list object for anyone holding it.
        self.tools[:] = [self.create_langchain_tool(tool) for tool in mcp_tools]

    def create_langchain_tool(self, mcp_tool):
        """Wrap one MCP tool description as a LangChain tool.

        Reads ``mcp_tool.name``, ``mcp_tool.description`` and, if present,
        ``mcp_tool.inputSchema``.
        """
        input_schema = getattr(mcp_tool, "inputSchema", None)
        return _MCPServerTool(
            name=mcp_tool.name,
            description=self._describe(mcp_tool, input_schema),
            mcp_client=self.mcp_client,
            tool_name=mcp_tool.name,
            mcp_input_schema=input_schema if isinstance(input_schema, dict) else None,
        )

    @staticmethod
    def _describe(mcp_tool, input_schema) -> str:
        """Build the description shown to the LLM, listing argument names and types."""
        # BaseTool.description must be a str; guard against a None/empty one.
        description = mcp_tool.description or f"MCP tool '{mcp_tool.name}'"
        properties = input_schema.get("properties") if isinstance(input_schema, dict) else None
        if not isinstance(properties, dict) or not properties:
            return description
        required = set(input_schema.get("required") or ())
        parts = []
        for arg_name, spec in properties.items():
            arg_type = spec.get("type", "any") if isinstance(spec, dict) else "any"
            suffix = ", required" if arg_name in required else ""
            parts.append(f"{arg_name} ({arg_type}{suffix})")
        # Kept free of literal braces in case the agent formats tool
        # descriptions into a prompt template.
        return f"{description} Input: a JSON object with arguments: {', '.join(parts)}."

    def create_agent(self):
        """Build a ZERO_SHOT_REACT_DESCRIPTION agent over ``self.tools`` (call setup_tools first)."""
        return initialize_agent(
            tools=self.tools,
            llm=self.llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
        )

async def demonstrate_langchain_mcp_integration():
    """Demo: connect an MCP client, expose its tools to a LangChain agent, run one task.

    ``MCPDatabaseClient`` is not defined in this file; it comes from a previous
    example.
    """
    llm = OpenAI(temperature=0)
    mcp_client = MCPDatabaseClient()  # From previous example
    await mcp_client.connect()
    integration = MCPLangChainIntegration(llm, mcp_client)
    await integration.setup_tools()
    agent = integration.create_agent()
    # We are already inside a running event loop. The synchronous `agent.run`
    # would reach each tool's `_run`, which uses asyncio.run() and cannot nest
    # in a running loop. The async entry point keeps tool calls on this loop,
    # presumably the one the MCP session was connected on.
    result = await agent.arun(
        "Get the database schema and then execute a query to find all active users"
    )
    print("Agent result:", result)
    # TODO(review): close the MCP connection here once the client's shutdown API is confirmed.

if __name__ == "__main__":
    # Script entry point: drive the async demo on a fresh event loop.
    demo = demonstrate_langchain_mcp_integration()
    asyncio.run(demo)