Source code for torchrl.envs.llm.libs.mlgym
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
import json
import os
import re
from contextlib import contextmanager
from dataclasses import asdict
from pathlib import Path
from typing import Any, Literal
import numpy as np
import torch
from tensordict import NestedKey, NonTensorData, TensorDict, TensorDictBase
from tensordict.tensorclass import is_non_tensor
from torchrl._utils import logger as torchrl_logger
from torchrl.data import Choice, Composite, NonTensor
from torchrl.data.llm import History
from torchrl.envs import ConditionalSkip, GymWrapper, Transform, TransformedEnv
# Inv transforms:
# Transforms to apply prior to passing the model output to the env
@contextmanager
def _temp_cwd_mlgym():
"""Temporarily change the current working directory to mlgym."""
import mlgym
path = Path(mlgym.__spec__.submodule_search_locations[0]).parent
old_pwd = os.getcwd()
os.chdir(str(path))
# sys.path.insert(-1, "mlgym")
try:
yield
finally:
# sys.path.pop()
os.chdir(old_pwd)
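# Usage sketch (illustrative): ``_temp_cwd_mlgym`` is used below both as a
# decorator (``@_temp_cwd_mlgym()``) and as a plain context manager, e.g.
#
#     with _temp_cwd_mlgym():
#         ...  # calls that expect mlgym's package directory as the cwd
#
# ``contextmanager`` makes the returned object ``ContextDecorator``-compatible,
# which is what allows the decorator form.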
class MLGymBaseTransform(Transform):
"""Base class for all MLGym transforms."""
@property
def config(self):
return self.parent.base_env.config
@property
def system_args(self):
return {
"command_docs": self.config.tools_handler.command_docs,
**self.config.tools_handler.env_variables,
}
@property
def task_args(self):
# Placeholder
task_args = getattr(self, "_task_args", None)
if task_args is None:
return self.parent.base_env.task.args
return task_args
@task_args.setter
def task_args(self, task_args):
self._task_args = task_args
@property
def name(self):
return "torchrl"
@property
def state_command(self):
return self.config.state_command.name
@property
def agent_args(self):
return self.parent.base_env.agent_args
@property
def model_name(self) -> Literal["human", "human_thought"]:
return self.agent_args.model.model_name
#######################################################
# Forward transforms: Format the env output
# Transform #0: Resets the env
class ResetModule(MLGymBaseTransform):
"""Runs setup pipeline and enables multi-resets.
The reset method reads the 'system' initial input from the config and parses it to a History
object.
"""
response_key: NestedKey = "text_response"
def __init__(self):
super().__init__(in_keys=[], out_keys=["history"])
@_temp_cwd_mlgym()
def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
base_env = self.parent.base_env._env
if tensordict is not None and "task" in tensordict:
import gymnasium as gym
task = tensordict["task"]
torchrl_logger.info(f"Resetting with {task=}")
if is_non_tensor(task):
task = task.data
task_id, agent_args = _TASK_IDS[task]
try:
base_env.close()
except Exception:
torchrl_logger.info(f"Failed to close {base_env=}")
            # use the registered task id (as in ``make_mlgym`` below) to rebuild the env
            base_env = gym.make(
                f"mlgym/{task_id}",
                devices=["cpu_0"],
            ).unwrapped
base_env.config = agent_args.config
self.parent.base_env.set_env(base_env)
base_env.reset_container()
base_env.communicate(f"cd {Path(base_env.task_workspace).parent}")
return tensordict
@_temp_cwd_mlgym()
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
# TODO: what to do with this?
# reset model stats
# self.model.reset_stats(init_model_stats)
# env = self.parent.base_env._env
env = self.parent.base_env._env
self.set_environment_vars(env, self.config.env_variables)
system_msg = self.config.system_template.format(
**self.system_args, **asdict(self.task_args)
)
# self.logger.log(self._default_logging_level, f"SYSTEM ({self.name})\n{system_msg}")
history = History(
role="system",
content=system_msg, # agent=self.name,
batch_size=(1,),
device=self.parent.device,
)
tensordict_reset["history"] = history
return tensordict_reset
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
# Placeholder
if "history" not in next_tensordict:
if "local_history" in tensordict:
local_history = tensordict["local_history"]
else:
local_history = None
history = tensordict["history"]
if local_history is not None:
history = history.append(local_history, inplace=False)
tensordict["history"] = history
next_tensordict["history"] = history
return next_tensordict
def set_environment_vars(
self, env: MLGymWrapper, env_variables: dict[str, Any]
) -> None:
commands_to_execute = (
[self.config.state_command.code]
+ # [code for code in self.config.util_functions] +
# [command.code for command in self.config._commands] +
[f"{k}={v}" for k, v in env_variables.items()]
)
commands = "\n".join(commands_to_execute)
try:
output = env.communicate(commands)
if env.returncode != 0:
msg = f"Nonzero return code: {env.returncode}\nOutput: {output}"
raise RuntimeError(msg)
except KeyboardInterrupt:
raise
except Exception as e:
raise e
command_files = []
for file in self.config.command_files:
datum = {}
with open(file) as f:
contents = f.read()
datum["contents"] = contents
filename = Path(file).name
if not contents.strip().startswith("#!"):
if filename.endswith(".sh"):
# files are sourced, so they are not executable
datum["name"] = Path(file).name
datum["type"] = "source_file"
elif filename.startswith("_"):
# files are sourced, so they are not executable
datum["name"] = Path(file).name
datum["type"] = "utility"
else:
msg = (
f"Non-shell script file {file} does not start with shebang.\n"
"Either add a shebang (#!) or change the file extension to .sh if you want to source it.\n"
"You can override this behavior by adding an underscore to the file name (e.g. _utils.py)."
)
raise ValueError(msg)
else:
# scripts are made executable
datum["name"] = Path(file).name.rsplit(".", 1)[0]
datum["type"] = "script"
command_files.append(datum)
# TODO: implement add commands method in environment
env.add_commands(command_files)
def transform_observation_spec(self, observation_spec: Composite) -> Composite:
observation_spec["history"] = History.default_spec()
return observation_spec
def transform_action_spec(self, action_spec: Composite) -> Composite:
if isinstance(action_spec, Composite):
action_spec[self.response_key] = self.transform_action_spec(
action_spec[self.response_key]
)
return action_spec
# make the "random" action just a choice between innocuous bash commands
return Choice(
[
NonTensor(example_data="ls -rtlh", shape=action_spec.shape),
NonTensor(example_data="pwd", shape=action_spec.shape),
]
)
def transform_state_spec(self, state_spec: Composite) -> Composite:
state_spec["history"] = History.default_spec()
return state_spec
class TaskSampler(Transform):
"""A sampler for tasks in a certain task set."""
def __init__(self, tasks: list[str]):
super().__init__()
self.tasks = tasks
def transform_observation_spec(self, observation_spec: Composite) -> Composite:
observation_spec["task"] = NonTensor(example_data="<a task>", shape=())
return observation_spec
@_temp_cwd_mlgym()
def _reset_env_preprocess(
self, tensordict: TensorDictBase | None
) -> TensorDictBase:
if tensordict is None:
tensordict = TensorDict(batch_size=self.parent.batch_size)
# Sample a task
task = np.random.choice(self.tasks)
tensordict["task"] = NonTensorData(task)
self._current_task = task
return tensordict
def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
next_tensordict["task"] = self._current_task
return next_tensordict
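# Illustrative sketch (task names are hypothetical): wrapping the env with
# ``TaskSampler(["taskA", "taskB"])`` makes every ``env.reset()`` first draw one
# task name uniformly and write it under the "task" key; ``ResetModule``'s
# ``_reset_env_preprocess`` then reads that key and rebuilds the underlying
# MLGym env for the sampled task.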
# Transform #1: env -> state
class ReadState(MLGymBaseTransform):
"""Reads current state and writes it as a parsable str in the tensordict."""
# from mlgym/agent/base.py:BaseAgent:forward_model
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
base_mlgym_env = self.parent.base_env # getattr is forwarded
command = self.state_command
state = base_mlgym_env.communicate(command) if self.state_command else None
next_tensordict["state"] = state
return next_tensordict
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
# tensordict_reset.setdefault("message", NonTensorData(""))
# tensordict_reset.setdefault("state", NonTensorData(""))
return self._step(tensordict_reset, tensordict_reset)
def transform_observation_spec(self, observation_spec):
observation_spec.set(
"state",
NonTensor(
example_data="a string",
device=observation_spec.device,
shape=observation_spec.shape,
),
)
return observation_spec
# Transform #2: state -> message
class StateToMessage(MLGymBaseTransform):
"""Parses the string using json to a given template.
Requires:
- a 'state' key from the ReadState transform
- an 'observation' key from the base environment
"""
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
base_mlgym_env = self.parent.base_env # getattr is forwarded
observation = tensordict["observation"]
state = tensordict["state"]
config = self.config
current_step = base_mlgym_env.current_step
max_steps = base_mlgym_env.max_steps
try:
state_vars = json.loads(state)
except json.JSONDecodeError as e:
msg = f"State {state!r} is not valid json. This is an internal error, please report it."
raise ValueError(msg) from e
# add step information to state_vars
state_vars["current_step"] = current_step
state_vars["remaining_steps"] = max_steps - current_step
# FIXME: we don't need to do this, we have our own observation space
# Determine observation template based on what prior observation was
history: History = tensordict["history"]
if history[..., -1].role == "system":
# Show task template if prev. obs. was initial system message
templates = [config.task_template]
if config.strategy_template is not None:
templates.append(config.strategy_template)
elif observation is None or observation.strip() == "":
# Show no output template if observation content was empty
assert config.next_step_no_output_template is not None # linting
templates = [config.next_step_no_output_template]
else:
# Show standard output template if there is observation content
assert config.next_step_template is not None # linting
templates = [config.next_step_template]
# Format selected template(s) with information
messages = []
assert self.task_args is not None
for template in templates:
messages.append(
template.format(
**asdict(self.task_args),
**self.system_args,
**state_vars,
observation=(observation if observation is not None else ""),
# missing forwarded_vars because no attempts
),
)
message = "\n".join(messages)
next_tensordict["message"] = message
# model query hooks here
return next_tensordict
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
# tensordict_reset.setdefault("message", NonTensorData(""))
# tensordict_reset.setdefault("state", NonTensorData(""))
return self._step(tensordict_reset, tensordict_reset)
def transform_observation_spec(self, observation_spec):
observation_spec.set(
"message",
NonTensor(
example_data="a string",
device=observation_spec.device,
shape=observation_spec.shape,
),
)
return observation_spec
# Transform #3: Append message to history
class MessageToHistory(MLGymBaseTransform):
"""Parses the message string to a History object, then reparses the history to a complete message.
.. seealso:: HistoryToMessage
"""
def __init__(self):
super().__init__(in_keys=["message", "history"], out_keys=["history", "chat"])
# from mlgym/agent/base.py:BaseAgent:local_history
# from mlgym/agent/base.py:BaseAgent:_append_history
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
# From PrepareDataForModel
message: str = next_tensordict["message"]
# from mlgym/agent/base.py:BaseAgent:forward_model
history = tensordict["history"]
cur_history = History(
role="user", content=message, batch_size=(), device=self.parent.device
)
# This is the basic thing our transform does: append the history to the existing one.
# (We should be able to extend the lazy stack directly)
history = history.append(cur_history, inplace=False)
next_tensordict["history"] = history
return next_tensordict
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
return self._step(tensordict_reset, tensordict_reset)
# Inverse transforms:
# Format the action from the model for the env
class TemplateTransform(MLGymBaseTransform):
"""A transform to apply the chat template to the History."""
response_key: NestedKey = "text_response"
prompt_key: NestedKey = "text"
# alternative to DummyFormat, wip
def __init__(
self,
in_keys=None,
out_keys=None,
in_keys_inv=None,
out_keys_inv=None,
tokenizer=None,
chat_template_name: Literal["chatml_format"] | None = None,
continue_final_message: bool = False,
tokenize: bool = False,
return_tensors: str = "pt",
return_dict: bool = False,
padding: bool | str = False,
truncation: bool | str = False,
):
super().__init__(
in_keys=["history"] if in_keys is None else in_keys,
out_keys=[self.prompt_key] if out_keys is None else out_keys,
in_keys_inv=[self.prompt_key, self.response_key]
if in_keys_inv is None
else in_keys_inv,
# TODO: we should not use the response key here but another dedicated entry, like "action_parsed"
out_keys_inv=[self.response_key] if out_keys_inv is None else out_keys_inv,
)
self.chat_template_name = chat_template_name
self.tokenizer = tokenizer
self.tokenize = tokenize
self.continue_final_message = continue_final_message
self.return_tensors = return_tensors
self.return_dict = return_dict
self.padding = padding
self.truncation = truncation
def transform_observation_spec(self, observation_spec: Composite):
observation_spec[self.prompt_key] = NonTensor(
example_data="<some chat string>",
shape=observation_spec.shape,
device=observation_spec.device,
)
return observation_spec
@property
def _chat_template(self):
chat_template = None
if self.chat_template_name:
from torchrl.data.llm.datatypes.chat import _CHAT_TEMPLATES
chat_template = _CHAT_TEMPLATES[self.chat_template_name]
elif self.tokenizer.chat_template is not None:
chat_template = self.tokenizer.chat_template
        else:
            raise ValueError("Failed to determine chat template.")
return chat_template
def _apply_transform(self, history: History) -> NonTensorData:
if self.tokenizer is None:
raise RuntimeError("Cannot apply chat template without a tokenizer.")
result = history.apply_chat_template(
tokenizer=self.tokenizer,
add_generation_prompt=True,
chat_template=self._chat_template,
continue_final_message=self.continue_final_message,
tokenize=self.tokenize,
padding=self.padding,
truncation=self.truncation,
return_tensors=self.return_tensors,
)
return result
def _reset(self, tensordict, tensordict_reset):
return self._call(tensordict_reset)
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.in_keys_inv:
prompt = tensordict[self.prompt_key]
response = tensordict[self.response_key]
if isinstance(prompt, list):
action = [
prompt + response for prompt, response in zip(prompt, response)
]
else:
action = prompt + response
try:
history, action = self._inv_apply_transform(action)
tensordict["local_history"] = history
tensordict[self.response_key] = action
            except RuntimeError as e:
                if "Expected assistant role" in str(e):
                    tensordict["local_history"] = History(role="assistant", content="")
                    tensordict[self.response_key] = ""
                else:
                    # any other parsing failure is unexpected; surface it
                    raise
return tensordict
def _inv_apply_transform(self, action):
if self.tokenize:
action = self.tokenizer.decode(action)
        if not isinstance(action, (str, list)):
            # Unwrap the NonTensorData wrapper, recurse on its payload, then
            # rewrap the parsed action with the wrapper's batch_size/device
            # (the recursion returns a plain string, which has neither).
            non_tensor = action
            history, action = self._inv_apply_transform(non_tensor.data)
            action = NonTensorData(
                action, batch_size=non_tensor.batch_size, device=non_tensor.device
            )
            return history, action
history = History.from_text(
action,
# chat_template=self._chat_template,
)[..., -1]
if history.role != "assistant":
raise RuntimeError(f"Expected assistant role, got {history.role=}")
action = history.get("content")
return history, action
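# Round-trip sketch (illustrative): in the forward direction the transform
# renders the history into a prompt string with the tokenizer's chat template;
# in the inverse direction the prompt and the model response are concatenated
# and re-parsed, keeping only the final assistant turn, e.g.
#
#     prompt = history.apply_chat_template(
#         tokenizer=tok, add_generation_prompt=True, tokenize=False
#     )
#     full_text = prompt + response
#     last_turn = History.from_text(full_text)[..., -1]  # expected role: "assistant"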
class IsolateCodeBlock(MLGymBaseTransform):
"""A transform that isolates the code block in the action generated by the LLM.
Optionally, wrongly formatted actions are assigned a negative reward.
"""
response_key: NestedKey = "text_response"
def __init__(self, reward_wrong_format: float | None = None):
super().__init__(
in_keys_inv=[self.response_key], out_keys_inv=[self.response_key]
)
from mlgym.agent.parsing import ThoughtActionParser
self.parser = ThoughtActionParser()
self.reward_wrong_format = reward_wrong_format
self._assign_reward = False
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
torchrl_logger.info("inv call with IsolateCodeBlock")
action = tensordict[self.response_key]
# if we didn't find an action, the action is empty
if not action:
torchrl_logger.info(
"Did not find a suitable action, skipping the call to step."
)
tensordict["retry"] = torch.ones(tensordict.shape, dtype=torch.bool)
self._assign_reward = True
else:
from mlgym.exceptions import FormatError
try:
action = self._inv_apply_transform(action)
tensordict[self.response_key] = action
torchrl_logger.info(f"Code block: {action}")
tensordict["retry"] = torch.zeros(tensordict.shape, dtype=torch.bool)
self._assign_reward = False
except FormatError:
tensordict["retry"] = torch.ones(tensordict.shape, dtype=torch.bool)
self._assign_reward = True
return tensordict
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self._assign_reward:
torchrl_logger.info(
f"Assigning penalty for unsuitable action: {self.reward_wrong_format}"
)
if self.reward_wrong_format is not None:
tensordict[self.parent.reward_key] += self.reward_wrong_format
return tensordict
def _inv_apply_transform(self, action):
if not isinstance(action, (str, list)):
return NonTensorData(
self._inv_apply_transform(action.tolist()),
batch_size=action.batch_size,
device=action.device,
)
if isinstance(action, list):
return [self._inv_apply_transform(action) for action in action]
thought, action = self.parser(action, None)
return action
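# Sketch of the expected model output (assumed format, delegated to
# ``mlgym.agent.parsing.ThoughtActionParser``): free-form reasoning followed by
# a fenced code block. The transform keeps only the action, so a response such as
#
#     I will first inspect the workspace.
#     ```
#     ls -rtlh
#     ```
#
# would reduce ``text_response`` to the command inside the fence; responses
# without a parsable block raise ``FormatError`` and trigger the retry path.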
class EvaluationOutputParser:
"""Parser for the reward transform in MLGym.
.. seealso:: :class:`~torchrl.envs.llm.libs.mlgym.MLGymRewardAssignment`
"""
def __init__(self):
# Regular expressions to match the required fields
self.patterns = {
"submission_artefact_path": r"valid submission artefact at (.*)\.",
"baseline_score": r"Baseline Score: \{'Score': (.*)\}",
"evaluation_score": r"Evaluation Score: \{'Score': (.*)\}",
"current_step": r"\(Current Step: (\d+),",
"remaining_steps": r"Remaining Steps: (\d+)\)",
"open_file": r"\(Open file: (.*)\)",
"current_directory": r"\(Current directory: (.*)\)",
}
def __call__(self, output_string):
parsed_data = {}
for key, pattern in self.patterns.items():
match = re.search(pattern, output_string)
if match:
parsed_data[key] = match.group(1).strip()
if "baseline_score" in parsed_data:
parsed_data["baseline_score"] = float(parsed_data["baseline_score"])
if "evaluation_score" in parsed_data:
parsed_data["evaluation_score"] = float(parsed_data["evaluation_score"])
if "current_step" in parsed_data:
parsed_data["current_step"] = int(parsed_data["current_step"])
if "remaining_steps" in parsed_data:
parsed_data["remaining_steps"] = int(parsed_data["remaining_steps"])
return parsed_data
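# Illustrative example (the sample string is an assumption based on the regexes
# above, not captured from a real MLGym run):
#
#     >>> parser = EvaluationOutputParser()
#     >>> parser(
#     ...     "Baseline Score: {'Score': 0.25}\n"
#     ...     "Evaluation Score: {'Score': 0.75}\n"
#     ...     "(Current Step: 3, Remaining Steps: 7)"
#     ... )
#     {'baseline_score': 0.25, 'evaluation_score': 0.75, 'current_step': 3, 'remaining_steps': 7}
#
# With these values, MLGymRewardAssignment below would add 0.75 - 0.25 = 0.5 to
# the reward.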
class MLGymRewardAssignment(MLGymBaseTransform):
"""Reward assignment through parsing of the last item in history.
By default, the :class:`~torchrl.envs.llm.libs.mlgym.EvaluationOutputParser` class is used as parser.
"""
def __init__(self):
super().__init__(in_keys=["reward", "history"], out_keys=["reward"])
self.parser = EvaluationOutputParser()
def _call(self, tensordict):
history = tensordict.get("history")
if history is None:
raise KeyError(f"History is missing in tensordict {tensordict}")
if history.ndim != 1:
raise ValueError(f"History shape must be 1D, got {history.shape}")
content = history[-1].content
torchrl_logger.info(f"Parsing reward from: {content}")
parsed = self.parser(content)
reward = parsed.get("evaluation_score", 0.0) - parsed.get("baseline_score", 0.0)
torchrl_logger.info(f"Parsed reward: {reward}")
tensordict["reward"] = tensordict["reward"] + reward
return tensordict
class _add_info_to_reset:
def __init__(self, func):
functools.update_wrapper(self, func)
self.func = func
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs), {}
class _add_truncated_to_step:
def __init__(self, func):
functools.update_wrapper(self, func)
self.func = func
@_temp_cwd_mlgym()
def __call__(self, *args, **kwargs):
obs, r, done, info = self.func(*args, **kwargs)
return obs, r, done, False, info
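# These two shims adapt MLGym's old-style Gym API to what GymWrapper expects
# from gymnasium: ``reset`` gains an empty ``info`` dict and the 4-tuple
# returned by ``step`` (obs, reward, done, info) becomes the 5-tuple
# (obs, reward, done, truncated=False, info).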
class MLGymWrapper(GymWrapper):
"""A thin wrapper for MLGym environments.
This specialized :class:`~torchrl.envs.GymWrapper` subclass defines the observation space with `observation=NonTensor()`
and the action space with `text_response=NonTensor()`, according to the :class:`~torchrl.envs.llm.ChatEnv` API.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.full_action_spec = Composite(
text_response=NonTensor(example_data="<a string>", shape=())
)
self.full_observation_spec = Composite(
observation=NonTensor(example_data="<a string>", shape=())
)
self.set_env()
def set_env(self, env: Any = None):
if env is not None:
self._env = env
self._patch_reset()
self._patch_step()
def _patch_reset(self):
if not isinstance(self._env.reset, _add_info_to_reset):
self._env.reset = _add_info_to_reset(self._env.reset)
def _patch_step(self):
        if not isinstance(self._env.step, _add_truncated_to_step):
self._env.step = _add_truncated_to_step(self._env.step)
@_temp_cwd_mlgym()
def _reset(
self, tensordict: TensorDictBase | None = None, **kwargs
) -> TensorDictBase:
return super()._reset(tensordict=tensordict, **kwargs)
_TASK_IDS = {}
def get_args(
task: Literal["prisonersDilemma"] = "prisonersDilemma",
) -> tuple[
mlgym.environment.env.EnvironmentArguments, # noqa
mlgym.agent.base.AgentArguments, # noqa
]: # noqa
"""Parse command line arguments and return a ScriptArguments object.
Args:
args: Optional list of arguments to parse. If not provided, uses sys.argv.
"""
import mlgym.environment.registration # noqa
from mlgym import CONFIG_DIR
from mlgym.agent.base import AgentArguments
from mlgym.backend.base import ModelArguments
from mlgym.environment.env import EnvironmentArguments
from mlgym.environment.registration import register_task
environment_args = EnvironmentArguments(
task_config_path=f"tasks/{task}.yaml",
max_steps=10,
seed=42,
container_type="podman",
verbose=False,
aliases_file="docker/aliases.sh",
)
agent_args = AgentArguments(
# placeholder
model=ModelArguments(""),
# Despite using torchrl as an agent, we still need the agent config - see StateToMessage parser
agent_config_path=CONFIG_DIR / "agents" / "default.yaml",
)
register_task(environment_args)
_TASK_IDS[task] = (environment_args.task.id, agent_args)
return environment_args, agent_args
def make_mlgym(
*,
task: Literal["prisonersDilemma"] | None = None,
tasks: list[Literal["prisonersDilemma"]] | None = None,
tokenizer: transformers.AutoTokenizer | str | None = None, # noqa
device="cpu",
reward_wrong_format: float | None = None,
) -> TransformedEnv:
"""Wraps an MLGymEnv in a TorchRL Environment.
    The appended transforms make sure that the data is formatted for the LLM (for the outputs of `env.step`)
    and for the MLGym API (for the inputs to `env.step`).
Keyword Args:
task (str): The task to wrap. Exclusive with `tasks` argument.
.. note:: The correct format is simply the task name, e.g., `"prisonersDilemma"`.
tasks (List[str]): The tasks available for the env. Exclusive with `task` argument.
.. note:: The correct format is simply the task name, e.g., `"prisonersDilemma"`.
tokenizer (transformers.AutoTokenizer or str, optional): A transformer that tokenizes the data.
If a string is passed, it will be converted to a `transformers.AutoTokenizer`.
device (str, optional): The device to set to the env. Defaults to "cpu".
reward_wrong_format (float, optional): The reward (negative penalty) for wrongly formatted actions.
Defaults to `None` (no penalty).
"""
import gymnasium as gym
if isinstance(tokenizer, str):
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
with _temp_cwd_mlgym():
if task and not tasks:
environment_args, agent_args = get_args(task=task)
elif tasks and not task:
for task in tasks:
environment_args, agent_args = get_args(task=task)
else:
raise ValueError(
f"Either task or tasks should be provided, not both and not none. Got {task=} and {tasks=}."
)
base_env = gym.make(
f"mlgym/{_TASK_IDS[task][0]}",
devices=["cpu_0"],
).unwrapped
# we need the env to have access to the config
base_env.config = agent_args.config
env = TransformedEnv(
MLGymWrapper(base_env, auto_reset=False, device=device), auto_unwrap=False
)
env.append_transform(ConditionalSkip(lambda td: td["retry"]))
env.append_transform(IsolateCodeBlock(reward_wrong_format=reward_wrong_format))
env.append_transform(ResetModule())
if tasks:
# Add a task sampler
env.append_transform(TaskSampler(tasks))
env.append_transform(ReadState())
env.append_transform(StateToMessage())
env.append_transform(MessageToHistory())
env.append_transform(TemplateTransform(tokenizer=tokenizer))
env.append_transform(MLGymRewardAssignment())
# # We want the env to have a batch-size of (1,) because it will be easier to interact with
# # LLMs
# env.append_transform(BatchSizeTransform(batch_size=(1,)))
return env
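if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): build the env for a single task
    # and run one reset/step cycle with a hand-written response. The tokenizer
    # name is an arbitrary example, and a working MLGym installation (container
    # runtime included) is required for this to run.
    env = make_mlgym(task="prisonersDilemma", tokenizer="Qwen/Qwen2.5-3B-Instruct")
    reset = env.reset()
    print(reset["text"])  # the chat-templated prompt rendered from the history
    # a response in the thought + fenced-code-block format expected by IsolateCodeBlock
    reset["text_response"] = "Let me look around first.\n```\nls -rtlh\n```"
    out = env.step(reset)
    print(out["next", "reward"])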