6. Custom RL Example using Stable Baselines
Apart from using examples from OmniIsaacGymEnvs, it is also possible to set up reinforcement learning tasks directly in Isaac Sim. Here, we will set up a new Cartpole environment that can be trained in Isaac Sim with the PPO implementation provided by the stable-baselines3 library. This is a simple single-environment setup that works directly with the RL interfaces provided by Omniverse Isaac Gym. Our goal for this task is for the Cartpole robot to learn to keep its pole in an upright position.
Scripts used in this tutorial can be found in standalone_examples/api/omni.isaac.gym.
6.1. Learning Objectives
This tutorial will walk through the steps of creating our own Cartpole reinforcement learning example using the interfaces provided in Omniverse Isaac Gym. We will:
Set up a new Cartpole task
Implement training and inferencing for our Cartpole task with stable-baselines3
15-20 Minute Tutorial
6.2. Getting Started
Refer to Overview & Getting Started for details on the RL interfaces available in Omniverse Isaac Gym.
Refer to Default Python Environment to learn about Isaac Sim’s Python environment and locate the Python executable in Isaac Sim.
Refer to the Isaac Core tutorials for details on Isaac Core APIs and workflow.
6.3. Creating Cartpole Task
First, we will look at the Task implementation, cartpole_task.py, where we will implement our task logic. We will be basing our Task structure on the BaseTask class from the Isaac Core extension.
Starting with imports, we will use a couple of classes and utilities from the Isaac Core extension.
from omni.isaac.core.utils.nucleus import get_assets_root_path
from omni.isaac.core.utils.stage import add_reference_to_stage
from omni.isaac.core.tasks.base_task import BaseTask
from omni.isaac.core.articulations import ArticulationView
from omni.isaac.core.utils.prims import create_prim
Next up is an import for an Omniverse API, which we will use to set the camera angle of the viewport.
import omni.kit
Finally, we have additional imports for general python libraries.
import gym
from gym import spaces
import numpy as np
import torch
import math
Next up, we will look at initializing our Task. We inherit our task from BaseTask, which is defined in the Isaac Core extension. This will allow us to add the task to our World and follow the recommended workflow in Isaac Core.
class CartpoleTask(BaseTask):
    def __init__(
        self,
        name,
        offset=None
    ) -> None:

        # task-specific parameters
        self._cartpole_position = [0.0, 0.0, 2.0]
        self._reset_dist = 3.0
        self._max_push_effort = 400.0

        # values used for defining RL buffers
        self._num_observations = 4
        self._num_actions = 1
        self._device = "cpu"
        self.num_envs = 1

        # a few class buffers to store RL-related states
        self.obs = torch.zeros((self.num_envs, self._num_observations))
        self.resets = torch.zeros((self.num_envs, 1))

        # set the action and observation space for RL
        self.action_space = spaces.Box(np.ones(self._num_actions) * -1.0, np.ones(self._num_actions) * 1.0)
        self.observation_space = spaces.Box(np.ones(self._num_observations) * -np.inf, np.ones(self._num_observations) * np.inf)

        # trigger __init__ of parent class
        BaseTask.__init__(self, name=name, offset=offset)
We will also need to define a set_up_scene function in our task, which will be triggered automatically from World. In this function, we include logic to create our scene, including adding in our Cartpole robot, the ground plane, and setting the viewport angle.
    def set_up_scene(self, scene) -> None:
        # retrieve file path for the Cartpole USD file
        assets_root_path = get_assets_root_path()
        usd_path = assets_root_path + "/Isaac/Robots/Cartpole/cartpole.usd"
        # add the Cartpole USD to our stage
        create_prim(prim_path="/World/Cartpole", prim_type="Xform", position=self._cartpole_position)
        add_reference_to_stage(usd_path, "/World/Cartpole")
        # create an ArticulationView wrapper for our cartpole - this can be extended towards accessing multiple cartpoles
        self._cartpoles = ArticulationView(prim_paths_expr="/World/Cartpole*", name="cartpole_view")
        # add Cartpole ArticulationView and ground plane to the Scene
        scene.add(self._cartpoles)
        scene.add_default_ground_plane()

        # set default camera viewport position and target
        self.set_initial_camera_params()

    def set_initial_camera_params(self, camera_position=[10, 10, 3], camera_target=[0, 0, 0]):
        viewport = omni.kit.viewport_legacy.get_default_viewport_window()
        # pass the x, y, and z components of each vector separately
        viewport.set_camera_position(
            "/OmniverseKit_Persp", camera_position[0], camera_position[1], camera_position[2], True
        )
        viewport.set_camera_target("/OmniverseKit_Persp", camera_target[0], camera_target[1], camera_target[2], True)
Another task API that we can implement is post_reset(). In this method, we can implement logic that gets executed once the scene is constructed and simulation starts running. In this example, we will use this method to retrieve the indices of the cart and pole joints in the Cartpole robot, and to perform an initial reset of our environment.
    def post_reset(self):
        self._cart_dof_idx = self._cartpoles.get_dof_index("cartJoint")
        self._pole_dof_idx = self._cartpoles.get_dof_index("poleJoint")
        # randomize all envs
        indices = torch.arange(self._cartpoles.count, dtype=torch.int64, device=self._device)
        self.reset(indices)
We will have to implement our reset method now. This method is used to set our environment into an initial state for starting a new training episode. In this example, we will randomize the joint positions and velocities and use APIs exposed through ArticulationView to set the joint positions and velocities.
    def reset(self, env_ids=None):
        if env_ids is None:
            env_ids = torch.arange(self.num_envs, device=self._device)
        num_resets = len(env_ids)

        # randomize DOF positions
        dof_pos = torch.zeros((num_resets, self._cartpoles.num_dof), device=self._device)
        dof_pos[:, self._cart_dof_idx] = 1.0 * (1.0 - 2.0 * torch.rand(num_resets, device=self._device))
        dof_pos[:, self._pole_dof_idx] = 0.125 * math.pi * (1.0 - 2.0 * torch.rand(num_resets, device=self._device))

        # randomize DOF velocities
        dof_vel = torch.zeros((num_resets, self._cartpoles.num_dof), device=self._device)
        dof_vel[:, self._cart_dof_idx] = 0.5 * (1.0 - 2.0 * torch.rand(num_resets, device=self._device))
        dof_vel[:, self._pole_dof_idx] = 0.25 * math.pi * (1.0 - 2.0 * torch.rand(num_resets, device=self._device))

        # apply resets
        indices = env_ids.to(dtype=torch.int32)
        self._cartpoles.set_joint_positions(dof_pos, indices=indices)
        self._cartpoles.set_joint_velocities(dof_vel, indices=indices)

        # bookkeeping
        self.resets[env_ids] = 0
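To make the randomization ranges in reset easier to see, here is a minimal pure-Python sketch of the same sampling expression, scale * (1 - 2u) with u uniform in [0, 1), which produces values in (-scale, scale]. The helper names sample_symmetric and sample_reset_state are hypothetical and only mirror the scales used above; the task itself samples with torch.rand.

```python
import math
import random


def sample_symmetric(scale, rng=random):
    """Return scale * (1 - 2u), with u drawn uniformly from [0, 1)."""
    return scale * (1.0 - 2.0 * rng.random())


def sample_reset_state(rng=random):
    """Draw one initial state using the same scales as the reset method."""
    return {
        "cart_pos": sample_symmetric(1.0, rng),               # within +/- 1.0
        "pole_pos": sample_symmetric(0.125 * math.pi, rng),   # within +/- pi/8
        "cart_vel": sample_symmetric(0.5, rng),               # within +/- 0.5
        "pole_vel": sample_symmetric(0.25 * math.pi, rng),    # within +/- pi/4
    }


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so repeated runs print the same state
    for name, value in sample_reset_state(rng).items():
        print(f"{name}: {value:+.3f}")
```

Starting each episode from a slightly perturbed state, rather than from an exactly upright pole, gives the policy a wider range of situations to recover from.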
We will now implement our pre_physics_step method. This method will be called from VecEnvBase before each simulation step, and will pass in actions from the RL policy as an argument. In this method, we can transform the actions into force vectors, which we will then apply to our Cartpole robot.
    def pre_physics_step(self, actions) -> None:
        # apply any resets that were flagged by is_done on the previous step
        reset_env_ids = self.resets.nonzero(as_tuple=False).squeeze(-1)
        if len(reset_env_ids) > 0:
            self.reset(reset_env_ids)

        actions = torch.tensor(actions)

        # only the cart joint is actuated; all other DOFs receive zero effort
        forces = torch.zeros((self._cartpoles.count, self._cartpoles.num_dof), dtype=torch.float32, device=self._device)
        forces[:, self._cart_dof_idx] = self._max_push_effort * actions

        indices = torch.arange(self._cartpoles.count, dtype=torch.int32, device=self._device)
        self._cartpoles.set_joint_efforts(forces, indices=indices)
Next, we will look at implementing our observations, rewards, and reset methods. These methods will be responsible for collecting states from physics to use as observations for the RL policy, computing the reward based on physics states, and determining when the Cartpole reaches a “bad” state, such that we should reset it to an initialization state.
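The action scaling in pre_physics_step can be sketched without any simulator dependencies. The example below is a simplified illustration, assuming a single action in [-1, 1] as declared by the action space; the function names and the DOF indices in the usage note are hypothetical.

```python
MAX_PUSH_EFFORT = 400.0  # same value as self._max_push_effort in the task


def action_to_effort(action, max_push_effort=MAX_PUSH_EFFORT):
    """Scale a normalized action into the effort applied to the cart joint."""
    return max_push_effort * action


def efforts_for_all_dofs(action, cart_dof_idx, num_dof):
    """Build a per-DOF effort list in which only the cart joint is actuated."""
    efforts = [0.0] * num_dof
    efforts[cart_dof_idx] = action_to_effort(action)
    return efforts


if __name__ == "__main__":
    # hypothetical layout: DOF 0 is the cart joint, DOF 1 is the pole joint
    print(efforts_for_all_dofs(0.5, cart_dof_idx=0, num_dof=2))
```

Keeping the policy output normalized and applying the scale inside the task means the effort limit can be tuned without changing the action space seen by the policy.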
    def get_observations(self):
        dof_pos = self._cartpoles.get_joint_positions()
        dof_vel = self._cartpoles.get_joint_velocities()

        # collect pole and cart joint positions and velocities for observation
        cart_pos = dof_pos[:, self._cart_dof_idx]
        cart_vel = dof_vel[:, self._cart_dof_idx]
        pole_pos = dof_pos[:, self._pole_dof_idx]
        pole_vel = dof_vel[:, self._pole_dof_idx]

        self.obs[:, 0] = cart_pos
        self.obs[:, 1] = cart_vel
        self.obs[:, 2] = pole_pos
        self.obs[:, 3] = pole_vel

        return self.obs

    def calculate_metrics(self) -> float:
        cart_pos = self.obs[:, 0]
        cart_vel = self.obs[:, 1]
        pole_angle = self.obs[:, 2]
        pole_vel = self.obs[:, 3]

        # compute reward based on angle of pole and cart velocity
        reward = 1.0 - pole_angle * pole_angle - 0.01 * torch.abs(cart_vel) - 0.005 * torch.abs(pole_vel)
        # apply a penalty if cart is too far from center
        reward = torch.where(torch.abs(cart_pos) > self._reset_dist, torch.ones_like(reward) * -2.0, reward)
        # apply a penalty if pole is too far from upright
        reward = torch.where(torch.abs(pole_angle) > np.pi / 2, torch.ones_like(reward) * -2.0, reward)

        # single-environment setup: return the reward as a Python scalar
        return reward.item()

    def is_done(self) -> int:
        cart_pos = self.obs[:, 0]
        pole_pos = self.obs[:, 2]

        # reset the robot if cart has reached reset_dist or pole is too far from upright
        resets = torch.where(torch.abs(cart_pos) > self._reset_dist, 1, 0)
        resets = torch.where(torch.abs(pole_pos) > math.pi / 2, 1, resets)
        self.resets = resets

        # single-environment setup: return the reset flag as a Python scalar
        return resets.item()
Note that we should return the corresponding result from each method call. These values will be collected by VecEnvBase and passed on to the RL policy.
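As a way to check the reward and termination logic by hand, the following sketch restates them for a single environment using plain Python floats. It mirrors the formulas in calculate_metrics and is_done; the function names cartpole_reward and cartpole_done are hypothetical and are not part of the task class.

```python
import math

RESET_DIST = 3.0  # same value as self._reset_dist in the task


def cartpole_reward(cart_pos, cart_vel, pole_angle, pole_vel, reset_dist=RESET_DIST):
    """Scalar version of the reward computed in calculate_metrics."""
    reward = 1.0 - pole_angle * pole_angle - 0.01 * abs(cart_vel) - 0.005 * abs(pole_vel)
    # penalty if the cart is too far from center
    if abs(cart_pos) > reset_dist:
        reward = -2.0
    # penalty if the pole is too far from upright
    if abs(pole_angle) > math.pi / 2:
        reward = -2.0
    return reward


def cartpole_done(cart_pos, pole_angle, reset_dist=RESET_DIST):
    """Scalar version of the termination check in is_done."""
    return abs(cart_pos) > reset_dist or abs(pole_angle) > math.pi / 2


if __name__ == "__main__":
    print(cartpole_reward(0.0, 0.0, 0.0, 0.0))  # upright and at rest
    print(cartpole_reward(3.5, 0.0, 0.0, 0.0))  # cart past the reset distance
    print(cartpole_done(0.0, 2.0))              # pole more than pi/2 from upright
```

The reward is largest when the pole is upright and both joints are at rest, and the same two conditions that trigger the -2.0 penalty also end the episode.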
That wraps up our task implementation!
We will now set up our training script for running this task, which can be found in cartpole_train.py.
First, we need to make an instance of VecEnvBase. This will provide the interface for the RL policy to interact with.
from omni.isaac.gym.vec_env import VecEnvBase
env = VecEnvBase(headless=True)
We have set the headless parameter to True here to speed up training. This will run our training loop without launching the Isaac Sim window. To launch training with the window, simply set headless=False.
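If you switch between headless and windowed runs often, one option is to read the setting from the environment instead of editing the script. The sketch below is only an illustration: the CARTPOLE_HEADLESS variable and the headless_from_env helper are hypothetical and are not part of Isaac Sim or of the provided scripts.

```python
import os


def headless_from_env(environ=None, var_name="CARTPOLE_HEADLESS"):
    """Return False only when the (hypothetical) variable is set to 0/false/no."""
    environ = os.environ if environ is None else environ
    value = environ.get(var_name, "1").strip().lower()
    return value not in ("0", "false", "no")


if __name__ == "__main__":
    # In a training script, this value could then be passed along as
    # VecEnvBase(headless=headless_from_env()).
    print("headless =", headless_from_env())
```

Defaulting to headless keeps the faster training path as the normal case, and the window only appears when it is explicitly requested.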
Next, we will create an instance of our Cartpole task and register it to the VecEnvBase instance we have created above.
from cartpole_task import CartpoleTask
task = CartpoleTask(name="Cartpole")
env.set_task(task, backend="torch")
Now, we can add in our PPO training code.
from stable_baselines3 import PPO

# create agent from stable baselines
model = PPO(
    "MlpPolicy",
    env,
    n_steps=1000,
    batch_size=1000,
    n_epochs=20,
    learning_rate=0.001,
    gamma=0.99,
    device="cuda:0",
    ent_coef=0.0,
    vf_coef=0.5,
    max_grad_norm=1.0,
    verbose=1,
    tensorboard_log="./cartpole_tensorboard"
)
model.learn(total_timesteps=100000)
model.save("ppo_cartpole")

env.close()
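As a rough sanity check on these hyperparameters, the sketch below works out how many rollouts and gradient steps the settings imply. It assumes that PPO in stable-baselines3 collects n_steps transitions per environment before each update and then makes n_epochs passes over that buffer in minibatches of batch_size; the helper name ppo_update_counts is hypothetical.

```python
import math


def ppo_update_counts(n_steps=1000, batch_size=1000, n_epochs=20,
                      total_timesteps=100_000, num_envs=1):
    """Estimate rollout and gradient-step counts for the PPO settings above."""
    rollout_size = n_steps * num_envs
    # number of collect-then-update iterations needed to reach total_timesteps
    num_rollouts = math.ceil(total_timesteps / rollout_size)
    # assumed: a final partial minibatch still counts as one gradient step
    minibatches_per_epoch = math.ceil(rollout_size / batch_size)
    steps_per_rollout = n_epochs * minibatches_per_epoch
    return {
        "rollout_size": rollout_size,
        "num_rollouts": num_rollouts,
        "minibatches_per_epoch": minibatches_per_epoch,
        "gradient_steps_per_rollout": steps_per_rollout,
        "total_gradient_steps": num_rollouts * steps_per_rollout,
    }


if __name__ == "__main__":
    for name, value in ppo_update_counts().items():
        print(f"{name}: {value}")
```

With a single environment, batch_size equal to n_steps means each epoch is one full-batch update over the most recent 1000 transitions.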
To run this script, we will use the python executable in Isaac Sim, which we will refer to as PYTHON_PATH. Locate the Isaac Sim python executable, which by default should be python.sh on Linux or python.bat on Windows, located at the root of the Isaac Sim directory.
To set up a PYTHON_PATH alias in the terminal that points to the Python executable, we can run a command that resembles the following. Make sure to update the paths to match your local installation.
For Linux:
alias PYTHON_PATH=~/.local/share/ov/pkg/isaac_sim-*/python.sh
For Windows:
doskey PYTHON_PATH=C:\Users\user\AppData\Local\ov\pkg\isaac_sim-*\python.bat $*
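We can run our training script as follows, from the root of the Isaac Sim directory:
PYTHON_PATH standalone_examples/api/omni.isaac.gym/cartpole_train.py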
This should start running training once initialization completes. It will take around 10 minutes to train the full 100000 timesteps.
Once the model has been trained, we can also run inference on the policy. We provide a cartpole_play.py script for this purpose, which can be run as follows:
PYTHON_PATH standalone_examples/api/omni.isaac.gym/cartpole_play.py
In this script, we will do the same VecEnvBase and CartpoleTask initialization, but run the inference in a loop. To enable visualization, we will also set headless=False.
# create isaac environment
from omni.isaac.gym.vec_env import VecEnvBase
env = VecEnvBase(headless=False)

# create task and register task
from cartpole_task import CartpoleTask
task = CartpoleTask(name="Cartpole")
env.set_task(task, backend="torch")

# import stable baselines
from stable_baselines3 import PPO

# Run inference on the trained policy
model = PPO.load("ppo_cartpole")
env._world.reset()
obs = env.reset()
while env._simulation_app.is_running():
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)

env.close()
We should see our cartpole policy running!
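To judge the policy by more than eye, one could accumulate the per-episode return inside the inference loop. The sketch below is a small, simulator-independent helper for that; the EpisodeTracker class is hypothetical, and it assumes that the rewards and dones values returned by env.step in this single-environment setup can be treated as a scalar reward and a truthy done flag.

```python
class EpisodeTracker:
    """Accumulate per-episode return and length from a stream of step results."""

    def __init__(self):
        self.current_return = 0.0
        self.current_length = 0
        self.finished = []  # list of (episode_return, episode_length) tuples

    def update(self, reward, done):
        """Record one step; close out the episode when done is truthy."""
        self.current_return += float(reward)
        self.current_length += 1
        if done:
            self.finished.append((self.current_return, self.current_length))
            self.current_return = 0.0
            self.current_length = 0


if __name__ == "__main__":
    tracker = EpisodeTracker()
    # stand-in step results; in the play script these would come from env.step
    for reward, done in [(1.0, 0), (0.5, 0), (-2.0, 1), (1.0, 0)]:
        tracker.update(reward, done)
    print(tracker.finished)
```

In the play script, a call such as tracker.update(rewards, dones) would go right after env.step(action). Episodes that stay alive longer with a return close to their length indicate that the pole is being held near upright.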
This tutorial covered the following topics:
Implementation of Cartpole task for RL
Running training and inferencing with stable-baselines3