Shortcuts

Source code for torchrl.envs.llm.transforms.browser

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Browser automation transform for LLM agents."""

from __future__ import annotations

import asyncio
from typing import Any
from urllib.parse import urlparse

from tensordict import TensorDictBase

from torchrl.envs.llm.transforms.tools import MCPToolTransform

# Schema for the browser tool
BROWSER_SCHEMA = {
    "name": "browser",
    "description": "Browse and interact with web pages",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "navigate",
                    "click",
                    "type",
                    "screenshot",
                    "extract",
                    "scroll",
                ],
                "description": "The action to perform",
            },
            "url": {
                "type": "string",
                "description": "URL to navigate to (for navigate action)",
            },
            "selector": {
                "type": "string",
                "description": "CSS selector to target element (for click/type/extract actions)",
            },
            "text": {
                "type": "string",
                "description": "Text to type (for type action)",
            },
            "scroll_amount": {
                "type": "integer",
                "description": "Amount to scroll in pixels (for scroll action)",
            },
            "extract_type": {
                "type": "string",
                "enum": ["text", "html", "attribute"],
                "description": "What to extract from the element (for extract action)",
            },
            "attribute": {
                "type": "string",
                "description": "Attribute name to extract (for extract action with extract_type=attribute)",
            },
        },
        "required": ["action"],
        "allOf": [
            {
                "if": {"properties": {"action": {"const": "navigate"}}},
                "then": {"required": ["url"]},
            },
            {
                "if": {
                    "properties": {
                        "action": {"enum": ["click", "type", "extract"]},
                    }
                },
                "then": {"required": ["selector"]},
            },
            {
                "if": {"properties": {"action": {"const": "type"}}},
                "then": {"required": ["text"]},
            },
            {
                "if": {"properties": {"action": {"const": "scroll"}}},
                "then": {"required": ["scroll_amount"]},
            },
            {
                "if": {"properties": {"action": {"const": "extract"}}},
                "then": {"required": ["extract_type"]},
            },
            {
                "if": {
                    "properties": {
                        "action": {"const": "extract"},
                        "extract_type": {"const": "attribute"},
                    }
                },
                "then": {"required": ["attribute"]},
            },
        ],
    },
}


[docs]class BrowserTransform(MCPToolTransform): """A transform that enables web browsing capabilities. This transform allows LLM agents to interact with web pages through a browser, supporting actions like navigation, clicking, typing, and extracting content. For a complete example of how to use this transform, see the :ref:`llm_tools` tutorial. Args: allowed_domains (list[str], optional): List of allowed domains. If None, all domains are allowed. headless (bool): Whether to run browser in headless mode. Defaults to True. timeout (float): Timeout for browser operations in seconds. Defaults to 30.0. tokenizer: The tokenizer to use. Defaults to None. tool_name (str): The name of the tool in chat history. Defaults to "tool". """ def __init__( self, allowed_domains: list[str] | None = None, headless: bool = True, timeout: float = 30.0, tokenizer=None, # type: ignore tool_name: str = "tool", ): self.allowed_domains = allowed_domains self.headless = headless self.browser = None self.context = None self.page = None self.loop = asyncio.get_event_loop() super().__init__( tools={"browser": self._execute_browser_action}, tool_schemas={"browser": BROWSER_SCHEMA}, tokenizer=tokenizer, tool_name=tool_name, timeout=timeout, ) async def _init_browser(self): """Initialize the browser if not already initialized.""" from playwright.async_api import async_playwright if self.browser is None: playwright = await async_playwright().start() self.browser = await playwright.chromium.launch(headless=self.headless) self.context = await self.browser.new_context() self.page = await self.context.new_page() def _validate_url(self, url: str) -> bool: """Validate if the URL is allowed based on domain restrictions.""" if not self.allowed_domains: return True try: domain = urlparse(url).netloc return any(domain.endswith(d) for d in self.allowed_domains) except Exception: return False async def _navigate(self, url: str) -> dict[str, Any]: """Navigate to a URL.""" if not self._validate_url(url): return { "success": False, "error": f"Domain not allowed. Must be one of: {self.allowed_domains}", } try: await self._init_browser() response = await self.page.goto(url, wait_until="networkidle") return { "success": True, "result": { "url": self.page.url, "status": response.status if response else None, }, } except Exception as e: return {"success": False, "error": str(e)} async def _click(self, selector: str) -> dict[str, Any]: """Click an element on the page.""" try: await self._init_browser() await self.page.click(selector) return {"success": True, "result": {"clicked": selector}} except Exception as e: return {"success": False, "error": str(e)} async def _type(self, selector: str, text: str) -> dict[str, Any]: """Type text into an element.""" try: await self._init_browser() await self.page.fill(selector, text) return {"success": True, "result": {"typed": text, "into": selector}} except Exception as e: return {"success": False, "error": str(e)} async def _scroll(self, amount: int) -> dict[str, Any]: """Scroll the page.""" try: await self._init_browser() await self.page.evaluate(f"window.scrollBy(0, {amount})") return {"success": True, "result": {"scrolled": amount}} except Exception as e: return {"success": False, "error": str(e)} async def _extract( self, selector: str, extract_type: str, attribute: str | None = None ) -> dict[str, Any]: """Extract content from the page.""" try: await self._init_browser() element = await self.page.wait_for_selector(selector) if not element: return {"success": False, "error": f"Element not found: {selector}"} if extract_type == "text": content = await element.text_content() elif extract_type == "html": content = await element.inner_html() elif extract_type == "attribute" and attribute: content = await element.get_attribute(attribute) else: return {"success": False, "error": "Invalid extraction type"} return { "success": True, "result": {"content": content, "type": extract_type}, } except Exception as e: return {"success": False, "error": str(e)} async def _execute_browser_action_async(self, **kwargs) -> dict[str, Any]: """Execute a browser action asynchronously.""" action = kwargs.pop("action") if action == "navigate": return await self._navigate(kwargs["url"]) elif action == "click": return await self._click(kwargs["selector"]) elif action == "type": return await self._type(kwargs["selector"], kwargs["text"]) elif action == "scroll": return await self._scroll(kwargs["scroll_amount"]) elif action == "extract": return await self._extract( kwargs["selector"], kwargs["extract_type"], kwargs.get("attribute"), ) else: return {"success": False, "error": f"Unknown action: {action}"} def _execute_browser_action(self, **kwargs) -> dict[str, Any]: """Execute a browser action.""" return self.loop.run_until_complete( self._execute_browser_action_async(**kwargs) )
[docs] def close(self): """Close the browser and clean up resources.""" if self.browser: self.loop.run_until_complete(self.browser.close()) self.browser = None self.context = None self.page = None
def __del__(self): """Ensure browser is closed on deletion.""" self.close() def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: """Reset the browser state.""" # Close and reinitialize browser on reset self.close() return tensordict_reset
[docs] def clone(self): """Clone the browser transform.""" return self.__class__( allowed_domains=self.allowed_domains, headless=self.headless, timeout=self.timeout, tokenizer=self.tokenizer, tool_name=self.tool_name, )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources