diff --git a/saycan/README.md b/saycan/README.md new file mode 100644 index 0000000..740c593 --- /dev/null +++ b/saycan/README.md @@ -0,0 +1,105 @@ +# SayCan - Language-Conditioned Robotic Manipulation + +This experiment implements the SayCan approach for grounding language in robotic affordances, +combining: +- **ViLD**: Open-vocabulary object detection +- **LLM (Ollama)**: Task planning and action scoring +- **CLIPort**: Language-conditioned pick-and-place manipulation +- **PyBullet**: Physics simulation with UR5e robot arm + +## Installation + +### Core Dependencies +```bash +pip install -r requirements.txt +``` + +### Ollama (for LLM) +Install Ollama from [ollama.ai](https://ollama.ai) and pull a model: +```bash +ollama pull llama3.2:1b +``` + +### Asset Downloads +Assets (robot URDFs, ViLD model, CLIPort checkpoint) are downloaded automatically on first run. + +## Setting on the Webserver + + +```bash +python manage.py shell -c "from experiment.models import Environment; Environment.objects.update_or_create(name='SayCan', defaults={ + 'description': 'Language-conditioned robotic manipulation with LLM planning', + 'filepaths': {'environment': 'saycan/environment.py'} +})" +``` + + +```bash +python manage.py shell -c "from experiment.models import Experiment, Environment; Experiment.objects.update_or_create(link='saycan', defaults={ + 'name': 'SayCan', + 'short_description': 'Robot manipulation with natural language instructions', + 'long_description': 'Guide a robot arm to pick and place objects using natural language instructions. The system uses ViLD for object detection, an LLM for task planning, and CLIPort for language-conditioned manipulation.\r\n\n
\n
\nYou can give instructions like:\n', + 'enabled': True, + 'environment': Environment.objects.get(name='SayCan'), + 'number_of_episodes': 1, + 'target_fps': 24.0, + 'wait_for_inputs': False +})" +``` + + +```bash +python manage.py shell -c "from experiment.models import Policy; Policy.objects.update_or_create(name='SayCan', defaults={ + 'description': 'SayCan policy with LLM planning and CLIPort execution', + 'filepaths': {'policy': 'saycan/policy.py'}, + 'checkpoint_interval': 0 +})" +``` + + +```bash +python manage.py shell -c "from experiment.models import Agent, Policy; Agent.objects.update_or_create(role='agent_0', defaults={ + 'name': 'Robot', + 'description': 'UR5e robot arm with Robotiq gripper', + 'policy': Policy.objects.get(name='SayCan'), + 'participant': True, + 'keyboard_inputs': {}, + 'multiple_keyboard_inputs': False, + 'inputs_type': 'other', + 'textual_inputs': True +})" +``` + + +```bash +python manage.py shell -c "from experiment.models import Experiment, Agent; exp = Experiment.objects.get(link='saycan'); exp.agents.add(Agent.objects.get(role='agent_0'))" +``` + +## Usage + +### Action Types +The environment accepts the following action types: + +| Action | Description | +|--------|-------------| +| `"task:"` | Set a high-level task for LLM planning | +| `"plan"` | Get next planned action from LLM | +| `""` | Direct pick-and-place instruction | +| `"done"` | End the episode | + +### Example Tasks +- `task: put all blocks in bowls` +- `task: stack the blocks` +- `task: sort blocks by color` +- `pick the blue block and place it on the red bowl` + +## References + +- **SayCan**: [Ahn et al. (2022) - Do As I Can, Not As I Say](https://arxiv.org/abs/2204.01691) +- **CLIPort**: [Shridhar et al. (2021) - What and Where Pathways for Robotic Manipulation](https://arxiv.org/abs/2109.12098) +- **ViLD**: [Gu et al. (2021) - Open-Vocabulary Object Detection via Vision and Language Knowledge Distillation](https://arxiv.org/abs/2104.13921) + +## Repository + +- Original SayCan: https://github.com/google-research/google-research/tree/master/saycan +- CLIPort: https://github.com/cliport/cliport \ No newline at end of file diff --git a/saycan/base_environment.py b/saycan/base_environment.py new file mode 100644 index 0000000..b972b94 --- /dev/null +++ b/saycan/base_environment.py @@ -0,0 +1,348 @@ +""" +SayCan Environment Wrapper for SHARPIE. + +This module wraps the PickPlaceEnv from the SayCan codebase to work with the +SHARPIE experiment framework. It integrates: +- ViLD for open-vocabulary object detection +- LLM (via Ollama) for task planning and action scoring +- CLIPort for language-conditioned pick-and-place manipulation + +Action Types: +- "task:" - Set task and auto-plan first action +- "plan" - Get next planned action from LLM +- "" - Direct CLIPort instruction +- "done" - End episode + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., Brohan, A., Brown, N., Chebotar, Y., Cortes, Y., David, B., + Finn, C., Fu, C., Gopalakrishnan, K., Hausman, K., Herzog, A., Ho, D., + Hsu, J., Ibarz, J., Ichter, B., Irpan, A., Jang, E., Jang, R., Julian, R., + ... & Zeng, A. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" + +import cv2 +import os +import sys +import tempfile +import numpy as np +from PIL import Image + +# Add the saycan directory to path for imports +SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__)) +if SAYCAN_DIR not in sys.path: + sys.path.insert(0, SAYCAN_DIR) + +from pick_place_env import PickPlaceEnv +from config import PICK_TARGETS, PLACE_TARGETS +from cliport import get_cliport +# Import LLM and helpers for planning +from llm import make_options, gpt3_scoring, gpt3_context, termination_string +from helpers import normalize_scores, step_to_nlp, affordance_scoring +from vild import vild, category_name_string, vild_params + + +class SayCanBaseEnvironment: + """Wrapper for the SayCan PickPlaceEnv with LLM planning and CLIPort integration.""" + + def __init__(self): + """Initialize the environment.""" + self.env = PickPlaceEnv() + self.config = None + self._step_count = 0 + self._max_steps = 100 + self._cliport = None + self.cached_video_frames = [] + + # LLM planning state + self._current_task = None + self._max_tasks = 10 + self._gpt3_prompt = None + self._options = None + self._found_objects = None + self._task_step_count = 0 + + def reset(self, config=None): + """ + Reset the environment to an initial state. + + Args: + config: Optional configuration dict with 'pick' and 'place' lists. + If None, uses default objects. + + Returns: + observation: Initial observation dict with 'image', 'xyzmap', 'pick', 'place' + info: Additional information dict + """ + self._step_count = 0 + self.cached_video_frames = [] + + # Reset LLM planning state + self._current_task = None + self._gpt3_prompt = None + self._options = None + self._found_objects = None + self._task_step_count = 0 + + if config is None: + config = {'pick': ['yellow block', 'blue block', 'red block'], + 'place': ['blue bowl', 'red bowl']} + + self.config = config + observation = self.env.reset(config) + + info = { + "step": 0, + "config": config, + "pick_objects": config.get("pick", []), + "place_objects": config.get("place", []) + } + + return observation, info + + def set_task(self, task_text): + """ + Set the current task from natural language. + + Args: + task_text: Task instruction (e.g., "put all the blocks in different corners") + """ + self._current_task = task_text + self._gpt3_prompt = gpt3_context + "\n# " + task_text + "\n" + self._task_step_count = 0 + self._found_objects = None + self._options = None + print(f"Environment: Task set to '{task_text}'") + + def detect_objects(self, observation=None): + """ + Detect objects in the scene using ViLD. + + Args: + observation: Observation dict with 'image'. If None, uses current observation. + + Returns: + found_objects: List of detected object names + """ + if observation is None: + observation = self.env.get_observation() + + # Save image to temp file for ViLD + image = observation['image'] + with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f: + temp_path = f.name + Image.fromarray(image).save(temp_path) + + try: + # Run ViLD detection + prompt_swaps = [('block', 'cube')] + found_objects = vild(temp_path, category_name_string, vild_params, + plot_on=False, prompt_swaps=prompt_swaps) + print(f"Environment: Detected objects: {found_objects}") + finally: + # Clean up temp file + os.unlink(temp_path) + + return found_objects + + def plan_next_action(self, observation=None): + """ + Plan the next action using LLM + affordance scoring. + + Args: + observation: Current observation. If None, uses current observation. + + Returns: + action_text: Natural language action instruction + done: Whether the task is complete + """ + if observation is None: + observation = self.env.get_observation() + + # Detect objects if not already done + if self._found_objects is None: + self._found_objects = self.detect_objects(observation) + + # Create options if not already done + if self._options is None: + self._options = make_options(PICK_TARGETS, PLACE_TARGETS, + termination_string=termination_string) + + # Calculate affordance scores based on detected objects + affordance_scores = affordance_scoring(self._options, self._found_objects, + block_name="box", bowl_name="circle", + verbose=False) + + # Get LLM scores + llm_scores, _ = gpt3_scoring(self._gpt3_prompt, self._options, verbose=True) + + # Combine scores + combined_scores = { + option: np.exp(llm_scores[option]) * affordance_scores[option] + for option in self._options + } + combined_scores = normalize_scores(combined_scores) + + # Select best action + selected_task = max(combined_scores, key=combined_scores.get) + + # Check for termination + if selected_task == termination_string: + print("Environment: Task completed (termination signal)") + return "done", True + + # Update prompt for next step + self._gpt3_prompt += selected_task + "\n" + self._task_step_count += 1 + + # Check max tasks limit + if self._task_step_count >= self._max_tasks: + print("Environment: Max steps reached") + return "done", True + + # Convert to natural language + action_text = step_to_nlp(selected_task) + print(f"Environment: Step {self._task_step_count} - {action_text}") + return action_text, False + + def step(self, action_dict): + """ + Execute one step in the environment. + + Args: + action_dict: Dictionary with agent id as keys and action as value. + Action can be: + - string text instruction directly + - "task:" to set a task and auto-plan + - "plan" to get the next planned action + - "done" to end the episode + + Returns: + observation: New observation dict + reward: Reward for the action (float) + terminated: Whether the episode has ended (bool) + truncated: Whether the episode was truncated (bool) + info: Additional information (dict) + """ + self._step_count += 1 + + if len(self.cached_video_frames) > 0: + return np.array([]), 0.0, False, False, {"info": "No action taken"} + + # Extract action from dict (single-agent environment) + action = list(action_dict.values())[0] if isinstance(action_dict, dict) else action_dict + + # Handle different action types + if action == 'done': + return np.array([]), 0.0, True, False, {"info": "Task completed"} + elif isinstance(action, str) and action.startswith('task:'): + # Execute complete task automatically + task_text = action[5:].strip() + results = self.run_task(task_text) + return np.array([]), results["total_reward"], False, False, results + elif action: + # Direct text instruction + obs, reward, _, info = self._step_with_text(action) + # Get the frames buffer + self.cached_video_frames = self.env.cache_video + else: + return np.array([]), 0.0, False, False, {"info": "No action taken"} + + # Check termination conditions + terminated = False + truncated = self._step_count >= self._max_steps + + info["step"] = self._step_count + info["max_steps"] = self._max_steps + + return obs, reward, terminated, truncated, info + + def _step_with_text(self, text): + """Execute a step using CLIPort with text instruction.""" + if self._cliport is None: + self._cliport = get_cliport() + + # Get current observation + obs = self.env.get_observation() + + # Use CLIPort to predict action + action = self._cliport.predict(obs, text) + + # Execute the predicted action + obs, reward, done, info = self.env.step({ + 'pick': action['pick'], + 'place': action['place'] + }) + + info['text_instruction'] = text + info['cliport_action'] = action + + return obs, reward, done, info + + def render(self): + """Render the environment.""" + if len(self.cached_video_frames) > 0: + return cv2.cvtColor(self.cached_video_frames.pop(0), cv2.COLOR_BGR2RGB) + return cv2.cvtColor(self.env.get_camera_image(), cv2.COLOR_BGR2RGB) + + def get_observation(self): + """Get current observation without stepping.""" + return self.env.get_observation() + + def run_task(self, task_text, max_steps=5): + """ + Execute a complete task from start to finish with automatic planning. + + This method sets the task and automatically executes all planned actions + until completion, without requiring manual 'plan' calls between steps. + + Args: + task_text: Natural language task description (e.g., "put all blocks in bowls") + max_steps: Maximum number of actions to execute (default: 50) + + Returns: + results: Dictionary containing: + - task: The original task text + - completed: Whether the task completed successfully + - steps: List of executed steps with actions and rewards + - total_reward: Cumulative reward across all steps + - termination_reason: Why execution stopped + """ + self.set_task(task_text) + + results = { + "task": task_text, + "completed": False, + "steps": [], + "total_reward": 0.0, + "termination_reason": None + } + + for step in range(max_steps): + # Plan next action + action_text, task_done = self.plan_next_action() + + # Check for task completion signal from LLM + if task_done or action_text == "done": + results["completed"] = True + results["termination_reason"] = "task_done" + break + + # Execute the planned action + obs, reward, _, info = self._step_with_text(action_text) + results["total_reward"] += reward + + results["steps"].append({ + "step": step, + "action": action_text, + "reward": reward, + "info": info + }) + + # Cache final video frames for rendering + self.cached_video_frames = self.env.cache_video + + return results \ No newline at end of file diff --git a/saycan/cliport.py b/saycan/cliport.py new file mode 100644 index 0000000..6cfde29 --- /dev/null +++ b/saycan/cliport.py @@ -0,0 +1,490 @@ +""" +CLIPort - CLIP + Transporter Networks for Language-Conditioned Manipulation. + +This module implements the CLIPort architecture for language-conditioned pick-and-place +operations. It combines CLIP (Contrastive Language-Image Pre-training) with Transporter +Networks to predict pick and place positions based on natural language instructions. + +Key Components: +- ResNet-based encoder-decoder architecture +- CLIP text and image encoders +- Transporter Networks for pick and place heatmap prediction +- Pretrained checkpoint loading + +CLIPort Repository: + https://github.com/cliport/cliport + +Reference: + Shridhar, M., Manuelli, L., & Fox, D. (2021). CLIPort: What and Where Pathways + for Robotic Manipulation. Conference on Robot Learning (CoRL). + +Used in SayCan: + https://github.com/google-research/google-research/tree/master/saycan +""" + +import os +import subprocess +import numpy as np +import torch +import clip +import matplotlib.pyplot as plt +import jax +import jax.numpy as jnp +import optax +import flax +from flax import linen as nn +from flax.training import checkpoints +from moviepy import ImageSequenceClip +from IPython.display import display + +# Get the saycan directory for checkpoint paths +SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__)) + +class ResNetBlock(nn.Module): + """ResNet pre-Activation block. https://arxiv.org/pdf/1603.05027.pdf""" + features: int + stride: int = 1 + + def setup(self): + self.conv0 = nn.Conv(self.features // 4, (1, 1), (self.stride, self.stride)) + self.conv1 = nn.Conv(self.features // 4, (3, 3)) + self.conv2 = nn.Conv(self.features, (1, 1)) + self.conv3 = nn.Conv(self.features, (1, 1), (self.stride, self.stride)) + + def __call__(self, x): + y = self.conv0(nn.relu(x)) + y = self.conv1(nn.relu(y)) + y = self.conv2(nn.relu(y)) + if x.shape != y.shape: + x = self.conv3(nn.relu(x)) + return x + y + + +class UpSample(nn.Module): + """Simple 2D 2x bilinear upsample.""" + + def __call__(self, x): + B, H, W, C = x.shape + new_shape = (B, H * 2, W * 2, C) + return jax.image.resize(x, new_shape, 'bilinear') + + +class ResNet(nn.Module): + """Hourglass 53-layer ResNet with 8-stride.""" + out_dim: int + + def setup(self): + self.dense0 = nn.Dense(8) + + self.conv0 = nn.Conv(64, (3, 3), (1, 1)) + self.block0 = ResNetBlock(64) + self.block1 = ResNetBlock(64) + self.block2 = ResNetBlock(128, stride=2) + self.block3 = ResNetBlock(128) + self.block4 = ResNetBlock(256, stride=2) + self.block5 = ResNetBlock(256) + self.block6 = ResNetBlock(512, stride=2) + self.block7 = ResNetBlock(512) + + self.block8 = ResNetBlock(256) + self.block9 = ResNetBlock(256) + self.upsample0 = UpSample() + self.block10 = ResNetBlock(128) + self.block11 = ResNetBlock(128) + self.upsample1 = UpSample() + self.block12 = ResNetBlock(64) + self.block13 = ResNetBlock(64) + self.upsample2 = UpSample() + self.block14 = ResNetBlock(16) + self.block15 = ResNetBlock(16) + self.conv1 = nn.Conv(self.out_dim, (3, 3), (1, 1)) + + def __call__(self, x, text): + + # # Project and concatenate CLIP features (early fusion). + # text = self.dense0(text) + # text = jnp.expand_dims(text, axis=(1, 2)) + # text = jnp.broadcast_to(text, x.shape[:3] + (8,)) + # x = jnp.concatenate((x, text), axis=-1) + + x = self.conv0(x) + x = self.block0(x) + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + + # Concatenate CLIP features (mid-fusion). + text = jnp.expand_dims(text, axis=(1, 2)) + text = jnp.broadcast_to(text, x.shape) + x = jnp.concatenate((x, text), axis=-1) + + x = self.block8(x) + x = self.block9(x) + x = self.upsample0(x) + x = self.block10(x) + x = self.block11(x) + x = self.upsample1(x) + x = self.block12(x) + x = self.block13(x) + x = self.upsample2(x) + x = self.block14(x) + x = self.block15(x) + x = self.conv1(x) + return x + + +class TransporterNets(nn.Module): + """TransporterNet with 3 ResNets (translation only).""" + + def setup(self): + # Picking affordances. + self.pick_net = ResNet(1) + + # Pick-conditioned placing affordances. + self.q_net = ResNet(3) # Query (crop around pick location). + self.k_net = ResNet(3) # Key (place features). + self.crop_size = 64 + self.crop_conv = nn.Conv(features=1, kernel_size=(self.crop_size, self.crop_size), use_bias=False, dtype=jnp.float32, padding='SAME') + + def __call__(self, x, text, p=None, train=True): + B, H, W, C = x.shape + pick_out = self.pick_net(x, text) # (B, H, W, 1) + + # Get key features. + k = self.k_net(x, text) + + # Add 0-padding before cropping. + h = self.crop_size // 2 + x_crop = jnp.pad(x, [(0, 0), (h, h), (h, h), (0, 0)], 'maximum') + + # Get query features and convolve them over key features. + place_out = jnp.zeros((0, H, W, 1), jnp.float32) + for b in range(B): + + # Get coordinates at center of crop. + if p is None: + pick_out_b = pick_out[b, ...] # (H, W, 1) + pick_out_b = pick_out_b.flatten() # (H * W,) + amax_i = jnp.argmax(pick_out_b) + v, u = jnp.unravel_index(amax_i, (H, W)) + else: + v, u = p[b, :] + + # Get query crop. + x_crop_b = jax.lax.dynamic_slice(x_crop, (b, v, u, 0), (1, self.crop_size, self.crop_size, x_crop.shape[3])) + # x_crop_b = x_crop[b:b+1, v:(v + self.crop_size), u:(u + self.crop_size), ...] + + # Convolve q (query) across k (key). + q = self.q_net(x_crop_b, text[b:b+1, :]) # (1, H, W, 3) + q = jnp.transpose(q, (1, 2, 3, 0)) # (H, W, 3, 1) + place_out_b = self.crop_conv.apply({'params': {'kernel': q}}, k[b:b+1, ...]) # (1, H, W, 1) + scale = 1 / (self.crop_size * self.crop_size) # For higher softmax temperatures. + place_out_b *= scale + place_out = jnp.concatenate((place_out, place_out_b), axis=0) + + return pick_out, place_out + + +def n_params(params): + return jnp.sum(jnp.int32([n_params(v) if isinstance(v, dict) or isinstance(v, flax.core.frozen_dict.FrozenDict) else np.prod(v.shape) for v in params.values()])) + +from flax.training import train_state + +class TrainState(train_state.TrainState): + pass + + + + +#@markdown Train your own model, or load a pretrained one. +load_pretrained = True #@param {type:"boolean"} + +# Initialize model weights using dummy tensors. +rng = jax.random.PRNGKey(0) +rng, key = jax.random.split(rng) +init_img = jnp.ones((4, 224, 224, 5), jnp.float32) +init_text = jnp.ones((4, 512), jnp.float32) +init_pix = jnp.zeros((4, 2), np.int32) +init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params'] +print(f'Model parameters: {n_params(init_params):,}') + +# Define the Optax optimizer +optimizer_tx = optax.adam(learning_rate=1e-4) + +# Create an initial TrainState object. This will have step=0. +optim = TrainState.create(apply_fn=TransporterNets().apply, + params=init_params, + tx=optimizer_tx) + +if load_pretrained: + ckpt_path = os.path.join(SAYCAN_DIR, f'ckpt_{40000}') + if not os.path.exists(ckpt_path): + import subprocess + print("Downloading CLIPort checkpoint...") + subprocess.run(['gdown', '--id', '1Nq0q1KbqHOA5O7aRSu4u7-u27EMMXqgP', '-O', ckpt_path], check=False) + + try: + # Attempt to restore directly. This will fail if 'step' is missing in the checkpoint. + optim = checkpoints.restore_checkpoint(ckpt_path, optim) + print('Loaded:', ckpt_path) + except ValueError as e: + if "Missing field step in state dict" in str(e): + print("Attempting to load old checkpoint format (missing 'step' field).") + # Load the raw checkpoint data as a dictionary + loaded_state_dict = checkpoints.restore_checkpoint(ckpt_path, target=None) + + if isinstance(loaded_state_dict, dict): + # Extract parameters, common keys for parameters are 'params' or 'target' + params_from_ckpt = loaded_state_dict.get('params', loaded_state_dict.get('target', init_params)) + + # Re-initialize the opt_state using the current optax optimizer with loaded parameters. + # This means the exact state of the old optimizer might be lost if it was not optax-compatible, + # but model parameters are preserved. + new_opt_state = optimizer_tx.init(params_from_ckpt) + + # Create a new TrainState with the loaded parameters, re-initialized opt_state, and step=0 + optim = TrainState( + step=0, # Default to step 0 if not present in old checkpoint + params=params_from_ckpt, + tx=optimizer_tx, + opt_state=new_opt_state, + apply_fn=TransporterNets().apply + ) + print('Successfully migrated and loaded checkpoint (params restored, opt_state re-initialized, step set to 0).') + else: + print(f"Error: Checkpoint '{ckpt_path}' is not a dictionary. Cannot migrate. Using initial model state.") + else: + # Re-raise other ValueErrors + raise + +else: + + # Training loop. + batch_size = 8 + for train_iter in range(1, 40001): + batch_i = np.random.randint(dataset_size, size=batch_size) + text_feat = data_text_feats[batch_i, ...] + img = dataset['image'][batch_i, ...] / 255 + img = np.concatenate((img, np.broadcast_to(coords[None, ...], (batch_size,) + coords.shape)), axis=3) + + # Get onehot label maps. + pick_yx = np.zeros((batch_size, 2), dtype=np.int32) + pick_onehot = np.zeros((batch_size, 224, 224), dtype=np.float32) + place_onehot = np.zeros((batch_size, 224, 224), dtype=np.float32) + for i in range(len(batch_i)): + pick_y, pick_x = dataset['pick_yx'][batch_i[i], :] + place_y, place_x = dataset['place_yx'][batch_i[i], :] + pick_onehot[i, pick_y, pick_x] = 1 + place_onehot[i, place_y, place_x] = 1 + # pick_onehot[i, ...] = scipy.ndimage.gaussian_filter(pick_onehot[i, ...], sigma=3) + + # Data augmentation (random translation). + roll_y, roll_x = np.random.randint(-112, 112, size=2) + img[i, ...] = np.roll(img[i, ...], roll_y, axis=0) + img[i, ...] = np.roll(img[i, ...], roll_x, axis=1) + pick_onehot[i, ...] = np.roll(pick_onehot[i, ...], roll_y, axis=0) + pick_onehot[i, ...] = np.roll(pick_onehot[i, ...], roll_x, axis=1) + place_onehot[i, ...] = np.roll(place_onehot[i, ...], roll_y, axis=0) + place_onehot[i, ...] = np.roll(place_onehot[i, ...], roll_x, axis=1) + pick_yx[i, 0] = pick_y + roll_y + pick_yx[i, 1] = pick_x + roll_x + + # Backpropagate. + batch = {} + batch['img'] = jnp.float32(img) + batch['text'] = jnp.float32(text_feat) + batch['pick_yx'] = jnp.int32(pick_yx) + batch['pick_onehot'] = jnp.float32(pick_onehot) + batch['place_onehot'] = jnp.float32(place_onehot) + rng, batch['rng'] = jax.random.split(rng) + optim, loss, _, _ = train_step(optim, batch) + writer.scalar('train/loss', loss, train_iter) + + if train_iter % np.power(10, min(4, np.floor(np.log10(train_iter)))) == 0: + print(f'Train Step: {train_iter} Loss: {loss}') + + if train_iter % 1000 == 0: + checkpoints.save_checkpoint('.', optim, train_iter, prefix='ckpt_', keep=100000, overwrite=True) + + + + +# ============================================================================ +# CLIPort Interface Class for easy integration +# ============================================================================ + +import os + +# Get the saycan directory for checkpoint paths +SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class CLIPort: + """CLIPort model interface for text-conditioned pick and place. + + This class provides a simple interface for using CLIPort without + needing to manage all the global variables and initialization. + """ + + def __init__(self): + self.clip_model = None + self.optim = None + self.coords = None + self._initialized = False + + def _init(self): + """Lazy initialization of CLIP and Transporter models.""" + if self._initialized: + return + + print("Initializing CLIPort...") + + # Initialize CLIP + self.clip_model, _ = clip.load("ViT-B/32") + if torch.cuda.is_available(): + self.clip_model = self.clip_model.cuda() + self.clip_model.eval() + + # Create coordinate tensor + h, w = 224, 224 + y_coords = np.linspace(0, 1, h) + x_coords = np.linspace(0, 1, w) + xx, yy = np.meshgrid(x_coords, y_coords) + self.coords = np.stack([xx, yy], axis=-1).astype(np.float32) + + # Initialize Transporter Net + rng = jax.random.PRNGKey(0) + rng, key = jax.random.split(rng) + init_img = jnp.ones((1, 224, 224, 5), jnp.float32) + init_text = jnp.ones((1, 512), jnp.float32) + init_params = TransporterNets().init(key, init_img, init_text)['params'] + + # Create optimizer state + optimizer_tx = optax.adam(learning_rate=1e-4) + self.optim = TrainState.create( + apply_fn=TransporterNets().apply, + params=init_params, + tx=optimizer_tx + ) + + # Try to load checkpoint + ckpt_path = os.path.join(SAYCAN_DIR, 'ckpt_40000') + if os.path.exists(ckpt_path): + try: + # Attempt to restore directly + self.optim = checkpoints.restore_checkpoint(ckpt_path, self.optim) + print(f"Loaded CLIPort checkpoint from {ckpt_path}") + except ValueError as e: + if "Missing field step" in str(e): + print("Migrating old checkpoint format...") + try: + # Load the raw checkpoint data as a dictionary + loaded_state_dict = checkpoints.restore_checkpoint(ckpt_path, target=None) + if isinstance(loaded_state_dict, dict): + # Extract parameters + params_from_ckpt = loaded_state_dict.get('params', loaded_state_dict.get('target', init_params)) + # Re-initialize the opt_state with loaded parameters + new_opt_state = optimizer_tx.init(params_from_ckpt) + # Create a new TrainState with the loaded parameters + self.optim = TrainState( + step=0, + params=params_from_ckpt, + tx=optimizer_tx, + opt_state=new_opt_state, + apply_fn=TransporterNets().apply + ) + print(f"Successfully migrated checkpoint from {ckpt_path}") + else: + print(f"Could not migrate checkpoint, using random initialization") + except Exception as e2: + print(f"Could not load CLIPort checkpoint: {e2}") + print("Using random initialization - model may not perform well without training.") + else: + raise + except Exception as e: + print(f"Could not load CLIPort checkpoint: {e}") + print("Using random initialization - model may not perform well without training.") + else: + print("No CLIPort checkpoint found, using random initialization.") + print("Run: python config.py to download pretrained weights") + + self._initialized = True + + def encode_text(self, text): + """Encode text instruction using CLIP.""" + with torch.no_grad(): + tokens = clip.tokenize([text]) + if torch.cuda.is_available(): + tokens = tokens.cuda() + text_feats = self.clip_model.encode_text(tokens).float() + text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) + text_feats = text_feats.cpu().numpy() + return text_feats.astype(np.float32) + + def predict(self, observation, text): + """ + Predict pick and place coordinates from text instruction. + + Args: + observation: Dict with 'image' and 'xyzmap' + text: Text instruction string + + Returns: + action: Dict with 'pick' and 'place' 3D coordinates + """ + self._init() + + # Get image and encode text + image = observation['image'] + xyzmap = observation['xyzmap'] + text_feats = self.encode_text(text) + + # Prepare image batch + img = image[np.newaxis, ...] / 255.0 + img = np.concatenate([img, self.coords[np.newaxis, ...]], axis=-1) + + # Run inference + def eval_step(optim, batch): + pick_out, place_out = TransporterNets().apply( + {'params': optim.params}, batch['img'], batch['text'] + ) + return pick_out, place_out + + batch = {'img': jnp.float32(img), 'text': jnp.float32(text_feats)} + pick_map, place_map = eval_step(self.optim, batch) + pick_map, place_map = np.float32(pick_map[0]), np.float32(place_map[0]) + + # Get pick position + pick_max = np.argmax(pick_map.flatten()) + pick_y, pick_x = np.unravel_index(pick_max, (224, 224)) + pick_y, pick_x = np.clip(pick_y, 20, 204), np.clip(pick_x, 20, 204) + pick_xyz = xyzmap[pick_y, pick_x] + + # Get place position + place_max = np.argmax(place_map.flatten()) + place_y, place_x = np.unravel_index(place_max, (224, 224)) + place_y, place_x = np.clip(place_y, 20, 204), np.clip(place_x, 20, 204) + place_xyz = xyzmap[place_y, place_x] + + return { + 'pick': pick_xyz, + 'place': place_xyz, + 'pick_map': pick_map, + 'place_map': place_map + } + + +# Global CLIPort instance +_cliport = None + + +def get_cliport(): + """Get or create the global CLIPort instance.""" + global _cliport + if _cliport is None: + _cliport = CLIPort() + return _cliport \ No newline at end of file diff --git a/saycan/config.py b/saycan/config.py new file mode 100644 index 0000000..43e64b8 --- /dev/null +++ b/saycan/config.py @@ -0,0 +1,141 @@ +""" +SayCan Configuration and Asset Downloader. + +This module provides global configuration constants for the SayCan environment +and handles downloading required assets (robot URDFs, model weights). + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., Brohan, A., Brown, N., Chebotar, Y., Cortes, Y., David, B., + Finn, C., Fu, C., Gopalakrishnan, K., Hausman, K., Herzog, A., Ho, D., + Hsu, J., Ibarz, J., Ichter, B., Irpan, A., Jang, E., Jang, R., Julian, R., + ... & Zeng, A. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" + +import collections +import datetime +import os +import random +import threading +import time + +import cv2 # Used by ViLD. +import clip +from easydict import EasyDict +import flax +from flax import linen as nn +from flax.training import checkpoints +from flax.metrics import tensorboard +import imageio +from heapq import nlargest +import IPython +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +from moviepy import ImageSequenceClip +import numpy as np +import optax +import pickle +from PIL import Image +import pybullet +import pybullet_data +import tensorflow.compat.v1 as tf +import torch +from tqdm import tqdm + +import subprocess + +# Get the directory where this script is located +SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def download_assets(): + """Download PyBullet robot assets, ViLD model weights, and CLIPort checkpoint.""" + # Change to saycan directory for downloads + original_dir = os.getcwd() + os.chdir(SAYCAN_DIR) + + try: + # Download PyBullet assets (UR5e robot, Robotiq gripper, bowl) + if not os.path.exists('ur5e/ur5e.urdf'): + print("Downloading UR5e robot assets...") + subprocess.run(['gdown', '--id', '1Cc_fDSBL6QiDvNT4dpfAEbhbALSVoWcc'], check=True) + subprocess.run(['gdown', '--id', '1yOMEm-Zp_DL3nItG9RozPeJAmeOldekX'], check=True) + subprocess.run(['gdown', '--id', '1GsqNLhEl9dd4Mc3BM0dX3MibOI1FVWNM'], check=True) + + print("Extracting assets...") + subprocess.run(['unzip', '-o', 'ur5e.zip'], check=True) + subprocess.run(['unzip', '-o', 'robotiq_2f_85.zip'], check=True) + subprocess.run(['unzip', '-o', 'bowl.zip'], check=True) + + # Download ViLD pretrained model weights + if not os.path.exists('image_path_v2'): + print("Downloading ViLD model weights...") + # Try using wget with public URL since gsutil may not be available + os.makedirs('image_path_v2/variables', exist_ok=True) + base_url = 'https://storage.googleapis.com/cloud-tpu-checkpoints/detection/projects/vild/colab/image_path_v2/' + subprocess.run(['wget', '-q', base_url + 'saved_model.pb', '-O', 'image_path_v2/saved_model.pb'], check=False) + subprocess.run(['wget', '-q', base_url + 'variables/variables.data-00000-of-00001', '-O', 'image_path_v2/variables/variables.data-00000-of-00001'], check=False) + subprocess.run(['wget', '-q', base_url + 'variables/variables.index', '-O', 'image_path_v2/variables/variables.index'], check=False) + + # Download CLIPort pretrained checkpoint + if not os.path.exists('cliport_checkpoint'): + print("Downloading CLIPort pretrained checkpoint...") + os.makedirs('cliport_checkpoint', exist_ok=True) + # CLIPort checkpoint from original SayCan paper + subprocess.run(['gdown', '--id', '1NqJDTyxZOOqvCM2RZthJT5qPX3Xi-a-g', '-O', 'cliport_checkpoint/checkpoint'], check=False) + + # Download training dataset (optional, for fine-tuning) + if not os.path.exists('dataset-9999.pkl'): + print("Downloading CLIPort training dataset...") + subprocess.run(['gdown', '--id', '1yCz6C-6eLWb4SFYKdkM-wz5tlMjbG2h8'], check=False) + finally: + os.chdir(original_dir) + +download_assets() + +# ============================================================================= +# Global Constants +# ============================================================================= + +# Objects that can be picked up +PICK_TARGETS = { + "blue block": None, + "red block": None, + "green block": None, + "yellow block": None, +} + +# RGBA colors for objects +COLORS = { + "blue": (78/255, 121/255, 167/255, 255/255), + "red": (255/255, 87/255, 89/255, 255/255), + "green": (89/255, 169/255, 79/255, 255/255), + "yellow": (237/255, 201/255, 72/255, 255/255), +} + +# Target locations for placing objects (None = dynamic, tuple = fixed position) +PLACE_TARGETS = { + "blue block": None, + "red block": None, + "green block": None, + "yellow block": None, + + "blue bowl": None, + "red bowl": None, + "green bowl": None, + "yellow bowl": None, + + "top left corner": (-0.3 + 0.05, -0.2 - 0.05, 0), + "top right corner": (0.3 - 0.05, -0.2 - 0.05, 0), + "middle": (0, -0.5, 0), + "bottom left corner": (-0.3 + 0.05, -0.8 + 0.05, 0), + "bottom right corner": (0.3 - 0.05, -0.8 + 0.05, 0), +} + +# Workspace configuration +PIXEL_SIZE = 0.00267857 # Meters per pixel +BOUNDS = np.float32([[-0.3, 0.3], [-0.8, -0.2], [0, 0.15]]) # X, Y, Z bounds in meters \ No newline at end of file diff --git a/saycan/datasets.py b/saycan/datasets.py new file mode 100644 index 0000000..063198b --- /dev/null +++ b/saycan/datasets.py @@ -0,0 +1,60 @@ +#@markdown Collect demonstrations with a scripted expert, or download a pre-generated dataset. +load_pregenerated = True #@param {type:"boolean"} + +# Load pre-existing dataset. +if load_pregenerated: + if not os.path.exists('dataset-9999.pkl'): + # !gdown --id 1TECwTIfawxkRYbzlAey0z1mqXKcyfPc- + !gdown --id 1yCz6C-6eLWb4SFYKdkM-wz5tlMjbG2h8 + dataset = pickle.load(open('dataset-9999.pkl', 'rb')) # ~10K samples. + dataset_size = len(dataset['text']) + +# Generate new dataset. +else: + dataset = {} + dataset_size = 2 # Size of new dataset. + dataset['image'] = np.zeros((dataset_size, 224, 224, 3), dtype=np.uint8) + dataset['pick_yx'] = np.zeros((dataset_size, 2), dtype=np.int32) + dataset['place_yx'] = np.zeros((dataset_size, 2), dtype=np.int32) + dataset['text'] = [] + policy = ScriptedPolicy(env) + data_idx = 0 + while data_idx < dataset_size: + np.random.seed(data_idx) + num_pick, num_place = 3, 3 + + # Select random objects for data collection. + pick_items = list(PICK_TARGETS.keys()) + pick_items = np.random.choice(pick_items, size=num_pick, replace=False) + place_items = list(PLACE_TARGETS.keys()) + for pick_item in pick_items: # For simplicity: place items != pick items. + place_items.remove(pick_item) + place_items = np.random.choice(place_items, size=num_place, replace=False) + config = {'pick': pick_items, 'place': place_items} + + # Initialize environment with selected objects. + obs = env.reset(config) + + # Create text prompts. + prompts = [] + for i in range(len(pick_items)): + pick_item = pick_items[i] + place_item = place_items[i] + prompts.append(f'Pick the {pick_item} and place it on the {place_item}.') + + # Execute 3 pick and place actions. + for prompt in prompts: + act = policy.step(prompt, obs) + dataset['text'].append(prompt) + dataset['image'][data_idx, ...] = obs['image'].copy() + dataset['pick_yx'][data_idx, ...] = xyz_to_pix(act['pick']) + dataset['place_yx'][data_idx, ...] = xyz_to_pix(act['place']) + data_idx += 1 + obs, _, _, _ = env.step(act) + debug_clip = ImageSequenceClip(env.cache_video, fps=25) + display(debug_clip.ipython_display(autoplay=1, loop=1)) + env.cache_video = [] + if data_idx >= dataset_size: + break + + pickle.dump(dataset, open(f'dataset-{dataset_size}.pkl', 'wb')) \ No newline at end of file diff --git a/saycan/environment.py b/saycan/environment.py new file mode 100644 index 0000000..effd2c5 --- /dev/null +++ b/saycan/environment.py @@ -0,0 +1,67 @@ +""" +SayCan Environment Wrapper for SHARPIE. + +This module wraps the PickPlaceEnv from the SayCan codebase to work with the +SHARPIE experiment framework. It integrates: +- ViLD for open-vocabulary object detection +- LLM (via Ollama) for task planning and action scoring +- CLIPort for language-conditioned pick-and-place manipulation + +Action Types: +- "task:" - Set task and auto-plan first action +- "plan" - Get next planned action from LLM +- "" - Direct CLIPort instruction +- "done" - End episode + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., Brohan, A., Brown, N., Chebotar, Y., Cortes, Y., David, B., + Finn, C., Fu, C., Gopalakrishnan, K., Hausman, K., Herzog, A., Ho, D., + Hsu, J., Ibarz, J., Ichter, B., Irpan, A., Jang, E., Jang, R., Julian, R., + ... & Zeng, A. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" +import os +import sys + +# Add the saycan directory to path for imports +SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__)) +if SAYCAN_DIR not in sys.path: + sys.path.insert(0, SAYCAN_DIR) + +from base_environment import SayCanBaseEnvironment + +class EnvironmentWrapper(SayCanBaseEnvironment): + """Wrapper for the SayCan PickPlaceEnv with LLM planning and CLIPort integration.""" + + def __init__(self): + """Initialize the environment.""" + super().__init__() + + def step(self, action_dict): + """ + Execute one step in the environment. + + Args: + action_dict: Dictionary with agent id as keys and action as value. + Action can be: + - string text instruction directly + - "task:" to set a task and auto-plan + - "plan" to get the next planned action + - "done" to end the episode + + Returns: + observation: New observation dict + reward: Reward for the action (float) + terminated: Whether the episode has ended (bool) + truncated: Whether the episode was truncated (bool) + info: Additional information (dict) + """ + # Extract action from dict (single-agent environment) + action = list(action_dict.values())[0] if isinstance(action_dict, dict) else action_dict + return super().step(action) + +# Create the environment instance for SHARPIE runner +environment = EnvironmentWrapper() diff --git a/saycan/helpers.py b/saycan/helpers.py new file mode 100644 index 0000000..93c1028 --- /dev/null +++ b/saycan/helpers.py @@ -0,0 +1,111 @@ +""" +SayCan Helper Functions. + +Utility functions for affordance scoring, scene description, and visualization. + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" + +import numpy as np +import matplotlib.pyplot as plt +from heapq import nlargest +from config import PLACE_TARGETS + +def build_scene_description(found_objects, block_name="box", bowl_name="circle"): + scene_description = f"objects = {found_objects}" + scene_description = scene_description.replace(block_name, "block") + scene_description = scene_description.replace(bowl_name, "bowl") + scene_description = scene_description.replace("'", "") + return scene_description + +def step_to_nlp(step): + step = step.replace("robot.pick_and_place(", "") + step = step.replace(")", "") + pick, place = step.split(", ") + return "Pick the " + pick + " and place it on the " + place + "." + +def normalize_scores(scores): + max_score = max(scores.values()) + normed_scores = {key: np.clip(scores[key] / max_score, 0, 1) for key in scores} + return normed_scores + +def plot_saycan(llm_scores, vfs, combined_scores, task, correct=True, show_top=None): + if show_top: + top_options = nlargest(show_top, combined_scores, key = combined_scores.get) + # add a few top llm options in if not already shown + top_llm_options = nlargest(show_top // 2, llm_scores, key = llm_scores.get) + for llm_option in top_llm_options: + if not llm_option in top_options: + top_options.append(llm_option) + llm_scores = {option: llm_scores[option] for option in top_options} + vfs = {option: vfs[option] for option in top_options} + combined_scores = {option: combined_scores[option] for option in top_options} + + sorted_keys = dict(sorted(combined_scores.items())) + keys = [key for key in sorted_keys] + positions = np.arange(len(combined_scores.items())) + width = 0.3 + + fig = plt.figure(figsize=(12, 6)) + ax1 = fig.add_subplot(1,1,1) + + plot_llm_scores = normalize_scores({key: np.exp(llm_scores[key]) for key in sorted_keys}) + plot_llm_scores = np.asarray([plot_llm_scores[key] for key in sorted_keys]) + plot_affordance_scores = np.asarray([vfs[key] for key in sorted_keys]) + plot_combined_scores = np.asarray([combined_scores[key] for key in sorted_keys]) + + ax1.bar(positions, plot_combined_scores, 3 * width, alpha=0.6, color="#93CE8E", label="combined") + + score_colors = ["#ea9999ff" for score in plot_affordance_scores] + ax1.bar(positions + width / 2, 0 * plot_combined_scores, width, color="#ea9999ff", label="vfs") + ax1.bar(positions + width / 2, 0 * plot_combined_scores, width, color="#a4c2f4ff", label="language") + ax1.bar(positions - width / 2, np.abs(plot_affordance_scores), width, color=score_colors) + + plt.xticks(rotation="vertical") + ax1.set_ylim(0.0, 1.0) + + ax1.grid(True, which="both") + ax1.axis("on") + + ax1_llm = ax1.twinx() + ax1_llm.bar(positions + width / 2, plot_llm_scores, width, color="#a4c2f4ff", label="language") + ax1_llm.set_ylim(0.01, 1.0) + plt.yscale("log") + + font = {"fontname":"Arial", "size":"16", "color":"k" if correct else "r"} + plt.title(task, **font) + key_strings = [key.replace("robot.pick_and_place", "").replace(", ", " to ").replace("(", "").replace(")","") for key in keys] + plt.xticks(positions, key_strings, **font) + ax1.legend() + plt.show() + + + +#@title Affordance Scoring +#@markdown Given this environment does not have RL-trained policies or an asscociated value function, we use affordances through an object detector. + +def affordance_scoring(options, found_objects, verbose=False, block_name="box", bowl_name="circle", termination_string="done()"): + affordance_scores = {} + found_objects = [ + found_object.replace(block_name, "block").replace(bowl_name, "bowl") + for found_object in found_objects + list(PLACE_TARGETS.keys())[-5:]] + verbose and print("found_objects", found_objects) + for option in options: + if option == termination_string: + affordance_scores[option] = 0.2 + continue + pick, place = option.replace("robot.pick_and_place(", "").replace(")", "").split(", ") + affordance = 0 + found_objects_copy = found_objects.copy() + if pick in found_objects_copy: + found_objects_copy.remove(pick) + if place in found_objects_copy: + affordance = 1 + affordance_scores[option] = affordance + verbose and print(affordance, '\t', option) + return affordance_scores \ No newline at end of file diff --git a/saycan/llm.py b/saycan/llm.py new file mode 100644 index 0000000..34e1058 --- /dev/null +++ b/saycan/llm.py @@ -0,0 +1,170 @@ +""" +SayCan LLM Module - Language Model Integration for Task Planning. + +This module provides integration with Large Language Models (LLMs) for task +planning and action scoring. Originally designed for GPT-3, now adapted for +local Ollama-based models. + +The LLM provides: +- Few-shot prompting for task decomposition +- Action scoring based on task context +- Natural language to action mapping + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" + +import ollama + +# Ollama client configuration +client = ollama.Client(host='http://localhost:11434') +ENGINE = "llama3.2:1b" + +from config import PICK_TARGETS, PLACE_TARGETS + +# LLM Cache for repeated queries +LLM_CACHE = {} + + +def gpt3_call(engine=ENGINE, prompt="", max_tokens=128, temperature=0): + """ + Call the LLM with caching for repeated queries. + + Args: + engine: Model name to use + prompt: Input prompt string + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + + Returns: + Generated text response + """ + cache_id = (engine, prompt, max_tokens, temperature) + if cache_id in LLM_CACHE: + print('cache hit, returning cached response') + return LLM_CACHE[cache_id] + + ollama_options = {} + if max_tokens > 0: + ollama_options['num_predict'] = max_tokens + if temperature > 0: + ollama_options['temperature'] = temperature + + response = client.generate(model=engine, prompt=prompt, options=ollama_options) + generated_text = response['response'] + LLM_CACHE[cache_id] = generated_text + return generated_text + + +def gpt3_scoring(query, options, engine=ENGINE, limit_num_options=None, option_start="\n", verbose=False, print_tokens=False): + """ + Score action options using the LLM. + + Note: For local models without log probability access, this returns + uniform scores. The actual discrimination comes from affordance scoring. + + Args: + query: Prompt context for scoring + options: List of action options to score + engine: Model name + limit_num_options: Limit number of options to score + option_start: Prefix for options (unused) + verbose: Print scoring details + print_tokens: Print token details (unused) + + Returns: + Tuple of (scores dict, empty response dict) + """ + if limit_num_options: + options = options[:limit_num_options] + verbose and print("Scoring", len(options), "options with uniform LLM scores.") + + # Uniform scores since local models don't provide log probs + uniform_logprob = 0.0 + scores = {option: uniform_logprob for option in options} + + if verbose: + for i, (option, score) in enumerate(sorted(scores.items(), key=lambda x: -x[1])): + print(score, "\t", option) + if i >= 10: + break + + return scores, {} + + +def make_options(pick_targets=None, place_targets=None, options_in_api_form=True, termination_string="done()"): + """ + Generate all possible pick-and-place action options. + + Args: + pick_targets: Dict of pickable objects (uses PICK_TARGETS if None) + place_targets: Dict of place targets (uses PLACE_TARGETS if None) + options_in_api_form: If True, use API format; otherwise natural language + termination_string: String to append for task completion + + Returns: + List of action option strings + """ + if not pick_targets: + pick_targets = PICK_TARGETS + if not place_targets: + place_targets = PLACE_TARGETS + + options = [] + for pick in pick_targets: + for place in place_targets: + if options_in_api_form: + option = f"robot.pick_and_place({pick}, {place})" + else: + option = f"Pick the {pick} and place it on the {place}." + options.append(option) + + options.append(termination_string) + print("Considering", len(options), "options") + return options + + +# Termination string for task completion +termination_string = "done()" + +# Few-shot prompt examples for task decomposition +gpt3_context = """ +objects = [red block, yellow block, blue block, green bowl] +# move all the blocks to the top left corner. +robot.pick_and_place(blue block, top left corner) +robot.pick_and_place(red block, top left corner) +robot.pick_and_place(yellow block, top left corner) +done() + +objects = [red block, yellow block, blue block, green bowl] +# put the yellow one the green thing. +robot.pick_and_place(yellow block, green bowl) +done() + +objects = [yellow block, blue block, red block] +# move the light colored block to the middle. +robot.pick_and_place(yellow block, middle) +done() + +objects = [blue block, green bowl, red block, yellow bowl, green block] +# stack the blocks. +robot.pick_and_place(green block, blue block) +robot.pick_and_place(red block, green block) +done() + +objects = [red block, blue block, green bowl, blue bowl, yellow block, green block] +# group the blue objects together. +robot.pick_and_place(blue block, blue bowl) +done() + +objects = [green bowl, red block, green block, red bowl, yellow bowl, yellow block] +# sort all the blocks into their matching color bowls. +robot.pick_and_place(green block, green bowl) +robot.pick_and_place(red block, red bowl) +robot.pick_and_place(yellow block, yellow bowl) +done() +""" \ No newline at end of file diff --git a/saycan/pick_place_env.py b/saycan/pick_place_env.py new file mode 100644 index 0000000..892d94c --- /dev/null +++ b/saycan/pick_place_env.py @@ -0,0 +1,487 @@ +""" +SayCan Pick and Place Environment. + +A Gym-style PyBullet environment for robotic pick-and-place manipulation tasks. +This environment simulates a UR5e robot arm with a Robotiq 2F-85 gripper +manipulating blocks and bowls on a workspace. + +Key Features: +- UR5e robot arm with Robotiq 2F-85 gripper +- Configurable objects (blocks and bowls of various colors) +- Pick-and-place motion primitives +- RGB-D observation with heightmap generation +- Video recording of episodes + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" + +import os +import numpy as np +import pybullet +import pybullet_data +from robot import Robotiq2F85 +from config import COLORS, BOUNDS, PIXEL_SIZE, SAYCAN_DIR + +class PickPlaceEnv(): + + def __init__(self): + self.dt = 1/480 + self.sim_step = 0 + + # Configure and start PyBullet. + # python3 -m pybullet_utils.runServer + # pybullet.connect(pybullet.SHARED_MEMORY) # pybullet.GUI for local GUI. + pybullet.connect(pybullet.DIRECT) # pybullet.GUI for local GUI. + pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0) + pybullet.setPhysicsEngineParameter(enableFileCaching=0) + pybullet.setAdditionalSearchPath(SAYCAN_DIR) + pybullet.setAdditionalSearchPath(pybullet_data.getDataPath()) + pybullet.setTimeStep(self.dt) + + self.home_joints = (np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 3 * np.pi / 2, 0) # Joint angles: (J0, J1, J2, J3, J4, J5). + self.home_ee_euler = (np.pi, 0, np.pi) # (RX, RY, RZ) rotation in Euler angles. + self.ee_link_id = 9 # Link ID of UR5 end effector. + self.tip_link_id = 10 # Link ID of gripper finger tips. + self.gripper = None + + def reset(self, config): + pybullet.resetSimulation(pybullet.RESET_USE_DEFORMABLE_WORLD) + pybullet.setGravity(0, 0, -9.8) + self.cache_video = [] + + # Temporarily disable rendering to load URDFs faster. + pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_RENDERING, 0) + + # Add ground plane (from pybullet_data) and robot (from saycan directory). + pybullet.loadURDF("plane.urdf", [0, 0, -0.001]) + ur5e_urdf = os.path.join(SAYCAN_DIR, "ur5e", "ur5e.urdf") + self.robot_id = pybullet.loadURDF(ur5e_urdf, [0, 0, 0], flags=pybullet.URDF_USE_MATERIAL_COLORS_FROM_MTL) + self.ghost_id = pybullet.loadURDF(ur5e_urdf, [0, 0, -10]) # For forward kinematics. + self.joint_ids = [pybullet.getJointInfo(self.robot_id, i) for i in range(pybullet.getNumJoints(self.robot_id))] + self.joint_ids = [j[0] for j in self.joint_ids if j[2] == pybullet.JOINT_REVOLUTE] + + # Move robot to home configuration. + for i in range(len(self.joint_ids)): + pybullet.resetJointState(self.robot_id, self.joint_ids[i], self.home_joints[i]) + + # Add gripper. + if self.gripper is not None: + while self.gripper.constraints_thread.is_alive(): + self.constraints_thread_active = False + self.gripper = Robotiq2F85(self.robot_id, self.ee_link_id) + self.gripper.release() + + # Add workspace. + plane_shape = pybullet.createCollisionShape(pybullet.GEOM_BOX, halfExtents=[0.3, 0.3, 0.001]) + plane_visual = pybullet.createVisualShape(pybullet.GEOM_BOX, halfExtents=[0.3, 0.3, 0.001]) + plane_id = pybullet.createMultiBody(0, plane_shape, plane_visual, basePosition=[0, -0.5, 0]) + pybullet.changeVisualShape(plane_id, -1, rgbaColor=[0.2, 0.2, 0.2, 1.0]) + + # Load objects according to config. + self.config = config + self.obj_name_to_id = {} + obj_names = list(self.config["pick"]) + list(self.config["place"]) + obj_xyz = np.zeros((0, 3)) + for obj_name in obj_names: + if ("block" in obj_name) or ("bowl" in obj_name): + + # Get random position 15cm+ from other objects. + while True: + rand_x = np.random.uniform(BOUNDS[0, 0] + 0.1, BOUNDS[0, 1] - 0.1) + rand_y = np.random.uniform(BOUNDS[1, 0] + 0.1, BOUNDS[1, 1] - 0.1) + rand_xyz = np.float32([rand_x, rand_y, 0.03]).reshape(1, 3) + if len(obj_xyz) == 0: + obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) + break + else: + nn_dist = np.min(np.linalg.norm(obj_xyz - rand_xyz, axis=1)).squeeze() + if nn_dist > 0.15: + obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) + break + + object_color = COLORS[obj_name.split(" ")[0]] + object_type = obj_name.split(" ")[1] + object_position = rand_xyz.squeeze() + if object_type == "block": + object_shape = pybullet.createCollisionShape(pybullet.GEOM_BOX, halfExtents=[0.02, 0.02, 0.02]) + object_visual = pybullet.createVisualShape(pybullet.GEOM_BOX, halfExtents=[0.02, 0.02, 0.02]) + object_id = pybullet.createMultiBody(0.01, object_shape, object_visual, basePosition=object_position) + elif object_type == "bowl": + object_position[2] = 0 + bowl_urdf = os.path.join(SAYCAN_DIR, "bowl", "bowl.urdf") + object_id = pybullet.loadURDF(bowl_urdf, object_position, useFixedBase=1) + pybullet.changeVisualShape(object_id, -1, rgbaColor=object_color) + self.obj_name_to_id[obj_name] = object_id + + # Re-enable rendering. + pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_RENDERING, 1) + + for _ in range(200): + pybullet.stepSimulation() + return self.get_observation() + + def servoj(self, joints): + """Move to target joint positions with position control.""" + pybullet.setJointMotorControlArray( + bodyIndex=self.robot_id, + jointIndices=self.joint_ids, + controlMode=pybullet.POSITION_CONTROL, + targetPositions=joints, + positionGains=[0.01]*6) + + def movep(self, position): + """Move to target end effector position.""" + joints = pybullet.calculateInverseKinematics( + bodyUniqueId=self.robot_id, + endEffectorLinkIndex=self.tip_link_id, + targetPosition=position, + targetOrientation=pybullet.getQuaternionFromEuler(self.home_ee_euler), + maxNumIterations=100) + self.servoj(joints) + + def step(self, action=None): + """Do pick and place motion primitive.""" + pick_xyz, place_xyz = action["pick"].copy(), action["place"].copy() + + # Set fixed primitive z-heights. + hover_xyz = pick_xyz.copy() + np.float32([0, 0, 0.2]) + pick_xyz[2] = 0.03 + place_xyz[2] = 0.15 + + # Move to object. + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + while np.linalg.norm(hover_xyz - ee_xyz) > 0.01: + self.movep(hover_xyz) + self.step_sim_and_render() + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + while np.linalg.norm(pick_xyz - ee_xyz) > 0.01: + self.movep(pick_xyz) + self.step_sim_and_render() + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + + # Pick up object. + self.gripper.activate() + for _ in range(240): + self.step_sim_and_render() + while np.linalg.norm(hover_xyz - ee_xyz) > 0.01: + self.movep(hover_xyz) + self.step_sim_and_render() + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + + # Move to place location. + while np.linalg.norm(place_xyz - ee_xyz) > 0.01: + self.movep(place_xyz) + self.step_sim_and_render() + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + + # Place down object. + while (not self.gripper.detect_contact()) and (place_xyz[2] > 0.03): + place_xyz[2] -= 0.001 + self.movep(place_xyz) + for _ in range(3): + self.step_sim_and_render() + self.gripper.release() + for _ in range(240): + self.step_sim_and_render() + place_xyz[2] = 0.2 + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + while np.linalg.norm(place_xyz - ee_xyz) > 0.01: + self.movep(place_xyz) + self.step_sim_and_render() + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + place_xyz = np.float32([0, -0.5, 0.2]) + while np.linalg.norm(place_xyz - ee_xyz) > 0.01: + self.movep(place_xyz) + self.step_sim_and_render() + ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) + + observation = self.get_observation() + reward = self.get_reward() + done = False + info = {} + return observation, reward, done, info + + def set_alpha_transparency(self, alpha: float) -> None: + for id in range(20): + visual_shape_data = pybullet.getVisualShapeData(id) + for i in range(len(visual_shape_data)): + object_id, link_index, _, _, _, _, _, rgba_color = visual_shape_data[i] + rgba_color = list(rgba_color[0:3]) + [alpha] + pybullet.changeVisualShape( + self.robot_id, linkIndex=i, rgbaColor=rgba_color) + pybullet.changeVisualShape( + self.gripper.body, linkIndex=i, rgbaColor=rgba_color) + + def step_sim_and_render(self): + pybullet.stepSimulation() + self.sim_step += 1 + + # Render current image at 8 FPS. + if self.sim_step % 60 == 0: + self.cache_video.append(self.get_camera_image()) + + def get_camera_image(self, resolution_factor=4): + """ + Get camera image with adjustable resolution. + + Args: + resolution_factor: Multiplier for resolution (default 4 = 960x960) + 1 = 240x240, 2 = 480x480, 3 = 720x720, 4 = 960x960 + """ + base_size = 240 + base_focal = 120. + image_size = (base_size * resolution_factor, base_size * resolution_factor) + focal = base_focal * resolution_factor + intrinsics = (focal, 0, focal, 0, focal, focal, 0, 0, 1) + color, _, _, _, _ = self.render_image(image_size, intrinsics) + return color + + def get_camera_image_top(self, + image_size=(240, 240), + intrinsics=(2000., 0, 2000., 0, 2000., 2000., 0, 0, 1), + position=(0, -0.5, 5), + orientation=(0, np.pi, -np.pi / 2), + zrange=(0.01, 1.), + set_alpha=True): + set_alpha and self.set_alpha_transparency(0) + color, _, _, _, _ = self.render_image_top(image_size, + intrinsics, + position, + orientation, + zrange) + set_alpha and self.set_alpha_transparency(1) + return color + + def get_reward(self): + return 0 # TODO: check did the robot follow text instructions? + + def get_observation(self): + observation = {} + + # Render current image. + color, depth, position, orientation, intrinsics = self.render_image() + + # Get heightmaps and colormaps. + points = self.get_pointcloud(depth, intrinsics) + position = np.float32(position).reshape(3, 1) + rotation = pybullet.getMatrixFromQuaternion(orientation) + rotation = np.float32(rotation).reshape(3, 3) + transform = np.eye(4) + transform[:3, :] = np.hstack((rotation, position)) + points = self.transform_pointcloud(points, transform) + heightmap, colormap, xyzmap = self.get_heightmap(points, color, BOUNDS, PIXEL_SIZE) + + observation["image"] = colormap + observation["xyzmap"] = xyzmap + observation["pick"] = list(self.config["pick"]) + observation["place"] = list(self.config["place"]) + return observation + + def render_image(self, image_size=(720, 720), intrinsics=(360., 0, 360., 0, 360., 360., 0, 0, 1)): + + # Camera parameters. + position = (0, -0.85, 0.4) + orientation = (np.pi / 4 + np.pi / 48, np.pi, np.pi) + orientation = pybullet.getQuaternionFromEuler(orientation) + zrange = (0.01, 10.) + noise=True + + # OpenGL camera settings. + lookdir = np.float32([0, 0, 1]).reshape(3, 1) + updir = np.float32([0, -1, 0]).reshape(3, 1) + rotation = pybullet.getMatrixFromQuaternion(orientation) + rotm = np.float32(rotation).reshape(3, 3) + lookdir = (rotm @ lookdir).reshape(-1) + updir = (rotm @ updir).reshape(-1) + lookat = position + lookdir + focal_len = intrinsics[0] + znear, zfar = (0.01, 10.) + viewm = pybullet.computeViewMatrix(position, lookat, updir) + fovh = (image_size[0] / 2) / focal_len + fovh = 180 * np.arctan(fovh) * 2 / np.pi + + # Notes: 1) FOV is vertical FOV 2) aspect must be float + aspect_ratio = image_size[1] / image_size[0] + projm = pybullet.computeProjectionMatrixFOV(fovh, aspect_ratio, znear, zfar) + + # Render with OpenGL camera settings. + # Use brighter lighting to prevent colors from appearing dark/washed out + _, _, color, depth, segm = pybullet.getCameraImage( + width=image_size[1], + height=image_size[0], + viewMatrix=viewm, + projectionMatrix=projm, + shadow=1, + lightDirection=[0.5, 0.5, 1], + lightColor=[1.0, 1.0, 1.0], + lightDistance=2.0, + flags=pybullet.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX, + renderer=pybullet.ER_BULLET_HARDWARE_OPENGL) + + # Get color image. + color_image_size = (image_size[0], image_size[1], 4) + color = np.array(color, dtype=np.uint8).reshape(color_image_size) + color = color[:, :, :3] # remove alpha channel + if noise: + color = np.int32(color) + color += np.int32(np.random.normal(0, 3, color.shape)) + color = np.uint8(np.clip(color, 0, 255)) + + # Get depth image. + depth_image_size = (image_size[0], image_size[1]) + zbuffer = np.float32(depth).reshape(depth_image_size) + depth = (zfar + znear - (2 * zbuffer - 1) * (zfar - znear)) + depth = (2 * znear * zfar) / depth + if noise: + depth += np.random.normal(0, 0.003, depth.shape) + + intrinsics = np.float32(intrinsics).reshape(3, 3) + return color, depth, position, orientation, intrinsics + + def render_image_top(self, + image_size=(240, 240), + intrinsics=(2000., 0, 2000., 0, 2000., 2000., 0, 0, 1), + position=(0, -0.5, 5), + orientation=(0, np.pi, -np.pi / 2), + zrange=(0.01, 1.)): + + # Camera parameters. + orientation = pybullet.getQuaternionFromEuler(orientation) + noise=True + + # OpenGL camera settings. + lookdir = np.float32([0, 0, 1]).reshape(3, 1) + updir = np.float32([0, -1, 0]).reshape(3, 1) + rotation = pybullet.getMatrixFromQuaternion(orientation) + rotm = np.float32(rotation).reshape(3, 3) + lookdir = (rotm @ lookdir).reshape(-1) + updir = (rotm @ updir).reshape(-1) + lookat = position + lookdir + focal_len = intrinsics[0] + znear, zfar = (0.01, 10.) + viewm = pybullet.computeViewMatrix(position, lookat, updir) + fovh = (image_size[0] / 2) / focal_len + fovh = 180 * np.arctan(fovh) * 2 / np.pi + + # Notes: 1) FOV is vertical FOV 2) aspect must be float + aspect_ratio = image_size[1] / image_size[0] + projm = pybullet.computeProjectionMatrixFOV(fovh, aspect_ratio, znear, zfar) + + # Render with OpenGL camera settings. + # Use brighter lighting to prevent colors from appearing dark/washed out + _, _, color, depth, segm = pybullet.getCameraImage( + width=image_size[1], + height=image_size[0], + viewMatrix=viewm, + projectionMatrix=projm, + shadow=1, + lightDirection=[0.5, 0.5, 1], + lightColor=[1.0, 1.0, 1.0], + lightDistance=2.0, + flags=pybullet.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX, + renderer=pybullet.ER_BULLET_HARDWARE_OPENGL) + + # Get color image. + color_image_size = (image_size[0], image_size[1], 4) + color = np.array(color, dtype=np.uint8).reshape(color_image_size) + color = color[:, :, :3] # remove alpha channel + if noise: + color = np.int32(color) + color += np.int32(np.random.normal(0, 3, color.shape)) + color = np.uint8(np.clip(color, 0, 255)) + + # Get depth image. + depth_image_size = (image_size[0], image_size[1]) + zbuffer = np.float32(depth).reshape(depth_image_size) + depth = (zfar + znear - (2 * zbuffer - 1) * (zfar - znear)) + depth = (2 * znear * zfar) / depth + if noise: + depth += np.random.normal(0, 0.003, depth.shape) + + intrinsics = np.float32(intrinsics).reshape(3, 3) + return color, depth, position, orientation, intrinsics + + def get_pointcloud(self, depth, intrinsics): + """Get 3D pointcloud from perspective depth image. + Args: + depth: HxW float array of perspective depth in meters. + intrinsics: 3x3 float array of camera intrinsics matrix. + Returns: + points: HxWx3 float array of 3D points in camera coordinates. + """ + height, width = depth.shape + xlin = np.linspace(0, width - 1, width) + ylin = np.linspace(0, height - 1, height) + px, py = np.meshgrid(xlin, ylin) + px = (px - intrinsics[0, 2]) * (depth / intrinsics[0, 0]) + py = (py - intrinsics[1, 2]) * (depth / intrinsics[1, 1]) + points = np.float32([px, py, depth]).transpose(1, 2, 0) + return points + + def transform_pointcloud(self, points, transform): + """Apply rigid transformation to 3D pointcloud. + Args: + points: HxWx3 float array of 3D points in camera coordinates. + transform: 4x4 float array representing a rigid transformation matrix. + Returns: + points: HxWx3 float array of transformed 3D points. + """ + padding = ((0, 0), (0, 0), (0, 1)) + homogen_points = np.pad(points.copy(), padding, + "constant", constant_values=1) + for i in range(3): + points[Ellipsis, i] = np.sum(transform[i, :] * homogen_points, axis=-1) + return points + + def get_heightmap(self, points, colors, bounds, pixel_size): + """Get top-down (z-axis) orthographic heightmap image from 3D pointcloud. + Args: + points: HxWx3 float array of 3D points in world coordinates. + colors: HxWx3 uint8 array of values in range 0-255 aligned with points. + bounds: 3x2 float array of values (rows: X,Y,Z; columns: min,max) defining + region in 3D space to generate heightmap in world coordinates. + pixel_size: float defining size of each pixel in meters. + Returns: + heightmap: HxW float array of height (from lower z-bound) in meters. + colormap: HxWx3 uint8 array of backprojected color aligned with heightmap. + xyzmap: HxWx3 float array of XYZ points in world coordinates. + """ + width = int(np.round((bounds[0, 1] - bounds[0, 0]) / pixel_size)) + height = int(np.round((bounds[1, 1] - bounds[1, 0]) / pixel_size)) + heightmap = np.zeros((height, width), dtype=np.float32) + colormap = np.zeros((height, width, colors.shape[-1]), dtype=np.uint8) + xyzmap = np.zeros((height, width, 3), dtype=np.float32) + + # Filter out 3D points that are outside of the predefined bounds. + ix = (points[Ellipsis, 0] >= bounds[0, 0]) & (points[Ellipsis, 0] < bounds[0, 1]) + iy = (points[Ellipsis, 1] >= bounds[1, 0]) & (points[Ellipsis, 1] < bounds[1, 1]) + iz = (points[Ellipsis, 2] >= bounds[2, 0]) & (points[Ellipsis, 2] < bounds[2, 1]) + valid = ix & iy & iz + points = points[valid] + colors = colors[valid] + + # Sort 3D points by z-value, which works with array assignment to simulate + # z-buffering for rendering the heightmap image. + iz = np.argsort(points[:, -1]) + points, colors = points[iz], colors[iz] + px = np.int32(np.floor((points[:, 0] - bounds[0, 0]) / pixel_size)) + py = np.int32(np.floor((points[:, 1] - bounds[1, 0]) / pixel_size)) + px = np.clip(px, 0, width - 1) + py = np.clip(py, 0, height - 1) + heightmap[py, px] = points[:, 2] - bounds[2, 0] + for c in range(colors.shape[-1]): + colormap[py, px, c] = colors[:, c] + xyzmap[py, px, c] = points[:, c] + colormap = colormap[::-1, :, :] # Flip up-down. + xv, yv = np.meshgrid(np.linspace(BOUNDS[0, 0], BOUNDS[0, 1], height), + np.linspace(BOUNDS[1, 0], BOUNDS[1, 1], width)) + xyzmap[:, :, 0] = xv + xyzmap[:, :, 1] = yv + xyzmap = xyzmap[::-1, :, :] # Flip up-down. + heightmap = heightmap[::-1, :] # Flip up-down. + return heightmap, colormap, xyzmap \ No newline at end of file diff --git a/saycan/policy.py b/saycan/policy.py new file mode 100644 index 0000000..8b89a47 --- /dev/null +++ b/saycan/policy.py @@ -0,0 +1,61 @@ +""" +SayCan Policy for SHARPIE. + +A simple policy wrapper for the SayCan environment. The actual LLM planning +and CLIPort execution are handled by the environment module. + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" + + +class Policy: + """ + SayCan-based policy for pick-and-place operations. + + This policy passes participant inputs directly to the environment, + which handles LLM planning and CLIPort execution. + """ + + def __init__(self, room_name=""): + """ + Initialize the SayCan policy. + + Args: + room_name: Optional room identifier (unused, kept for compatibility) + """ + self.name = "SayCan_Policy" + self.room_name = room_name + + def predict(self, observation, participant_input=None): + """ + Predict an action based on the observation. + + Args: + observation: Current observation from the environment + participant_input: Text instruction from participant: + - "task:" to set task and auto-plan + - "plan" to get next planned action + - Direct text instruction for CLIPort + + Returns: + The participant_input (passed through to environment) + """ + return participant_input + + def update(self, observation, action, reward, done, next_observation): + """ + Update the policy based on experience (no-op for SayCan). + + SayCan doesn't use traditional RL updates. This method is kept + for compatibility with the SHARPIE framework. + """ + pass + + +# Create an instance of the policy for use by the runner +policy = Policy('saycan') \ No newline at end of file diff --git a/saycan/requirements.txt b/saycan/requirements.txt new file mode 100644 index 0000000..a822930 --- /dev/null +++ b/saycan/requirements.txt @@ -0,0 +1,41 @@ +# Text processing utilities +ftfy +regex +tqdm +fvcore + +# OpenAI CLIP (install from git) +git+https://github.com/openai/CLIP.git + +# Google Drive downloader +gdown + +# Video and image processing +moviepy +imageio +imageio-ffmpeg +opencv-python +pillow + +# Plotting and display +matplotlib +ipython + +# Robotics simulation and utilities +pybullet +ollama +easydict + +# Deep learning frameworks +tensorflow +torch +torchvision + +# JAX with CUDA support +jax[cuda] +flax +optax + +# Numerical computing +numpy +scipy \ No newline at end of file diff --git a/saycan/robot.py b/saycan/robot.py new file mode 100644 index 0000000..0eb827d --- /dev/null +++ b/saycan/robot.py @@ -0,0 +1,161 @@ +""" +Robotiq 2F-85 Gripper Control Module. + +This module provides control for the Robotiq 2F-85 parallel gripper in PyBullet +simulation. The gripper is commonly used with UR5e robot arm in manipulation tasks. + +Key Features: +- Gripper open/close control +- Grasp detection +- Thread-based constraint enforcement + +Original SayCan Repository: + https://github.com/google-research/google-research/tree/master/saycan + +Reference: + Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in + Robotic Affordances. arXiv preprint arXiv:2204.01691. +""" + +import os +import threading +import time +import numpy as np +import pybullet + +# Get the saycan directory for asset paths +SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class Robotiq2F85: + """ + Gripper handling for Robotiq 2F-85. + + This class manages the gripper attached to a robot arm, providing + open/close functionality and grasp detection. + + Attributes: + robot: PyBullet robot body ID + tool: Link ID of the robot's tool (end effector) + body: PyBullet gripper body ID + n_joints: Number of joints in the gripper + activated: Whether the gripper is currently activated (grasping) + """ + + def __init__(self, robot, tool): + """ + Initialize the gripper. + + Args: + robot: PyBullet body ID of the robot + tool: Link ID of the robot's end effector + """ + self.robot = robot + self.tool = tool + pos = [0.1339999999999999, -0.49199999999872496, 0.5] + rot = pybullet.getQuaternionFromEuler([np.pi, 0, np.pi]) + urdf = os.path.join(SAYCAN_DIR, "robotiq_2f_85", "robotiq_2f_85.urdf") + self.body = pybullet.loadURDF(urdf, pos, rot) + self.n_joints = pybullet.getNumJoints(self.body) + self.activated = False + + # Connect gripper base to robot tool + pybullet.createConstraint( + self.robot, tool, self.body, 0, + jointType=pybullet.JOINT_FIXED, + jointAxis=[0, 0, 0], + parentFramePosition=[0, 0, 0], + childFramePosition=[0, 0, -0.07], + childFrameOrientation=pybullet.getQuaternionFromEuler([0, 0, np.pi / 2]) + ) + + # Set friction coefficients for gripper fingers + for i in range(pybullet.getNumJoints(self.body)): + pybullet.changeDynamics( + self.body, i, + lateralFriction=10.0, + spinningFriction=1.0, + rollingFriction=1.0, + frictionAnchor=True + ) + + # Start thread to handle additional gripper constraints + self.motor_joint = 1 + self.constraints_thread = threading.Thread(target=self.step) + self.constraints_thread.daemon = True + self.constraints_thread.start() + + def step(self): + """Control joint positions by enforcing hard constraints on gripper behavior.""" + while True: + try: + currj = [pybullet.getJointState(self.body, i)[0] for i in range(self.n_joints)] + indj = [6, 3, 8, 5, 10] + targj = [currj[1], -currj[1], -currj[1], currj[1], currj[1]] + pybullet.setJointMotorControlArray( + self.body, indj, pybullet.POSITION_CONTROL, targj, + positionGains=np.ones(5) + ) + except: + return + time.sleep(0.001) + + def activate(self): + """Activate the gripper (close fingers to grasp).""" + pybullet.setJointMotorControl2( + self.body, self.motor_joint, + pybullet.VELOCITY_CONTROL, + targetVelocity=1, + force=10 + ) + self.activated = True + + def release(self): + """Release the gripper (open fingers).""" + pybullet.setJointMotorControl2( + self.body, self.motor_joint, + pybullet.VELOCITY_CONTROL, + targetVelocity=-1, + force=10 + ) + self.activated = False + + def detect_contact(self): + obj, _, ray_frac = self.check_proximity() + if self.activated: + empty = self.grasp_width() < 0.01 + cbody = self.body if empty else obj + if obj == self.body or obj == 0: + return False + return self.external_contact(cbody) + # else: + # return ray_frac < 0.14 or self.external_contact() + + # Return if body is in contact with something other than gripper + def external_contact(self, body=None): + if body is None: + body = self.body + pts = pybullet.getContactPoints(bodyA=body) + pts = [pt for pt in pts if pt[2] != self.body] + return len(pts) > 0 # pylint: disable=g-explicit-length-test + + def check_grasp(self): + while self.moving(): + time.sleep(0.001) + success = self.grasp_width() > 0.01 + return success + + def grasp_width(self): + lpad = np.array(pybullet.getLinkState(self.body, 4)[0]) + rpad = np.array(pybullet.getLinkState(self.body, 9)[0]) + dist = np.linalg.norm(lpad - rpad) - 0.047813 + return dist + + def check_proximity(self): + ee_pos = np.array(pybullet.getLinkState(self.robot, self.tool)[0]) + tool_pos = np.array(pybullet.getLinkState(self.body, 0)[0]) + vec = (tool_pos - ee_pos) / np.linalg.norm((tool_pos - ee_pos)) + ee_targ = ee_pos + vec + ray_data = pybullet.rayTest(ee_pos, ee_targ)[0] + obj, link, ray_frac = ray_data[0], ray_data[1], ray_data[2] + return obj, link, ray_frac \ No newline at end of file diff --git a/saycan/vild.py b/saycan/vild.py new file mode 100644 index 0000000..4bc018c --- /dev/null +++ b/saycan/vild.py @@ -0,0 +1,680 @@ +""" +ViLD - Vision and Language Knowledge Distillation for Open-Vocabulary Object Detection. + +This module provides the ViLD (Vision-Language Detection) model for open-vocabulary +object detection. ViLD enables detecting objects beyond a fixed set of categories +by leveraging CLIP embeddings. + +Key Components: +- Text embedding building with prompt engineering +- Object detection with confidence scoring +- Visualization of detection results + +Original ViLD Repository: + https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild + +Reference: + Gu, X., Lin, T., Kuo, C., & Cui, Y. (2021). Open-Vocabulary Object Detection + via Vision and Language Knowledge Distillation. arXiv preprint arXiv:2104.13921. + +Used in SayCan: + https://github.com/google-research/google-research/tree/master/saycan +""" + +import os +import collections +import numpy as np +import cv2 +import torch +import clip +import matplotlib.pyplot as plt +from tqdm import tqdm +from easydict import EasyDict +from PIL import Image +import tensorflow.compat.v1 as tf + +# Get the directory where this script is located +SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def softmax(x, axis=-1): + """Compute softmax values for each element in x.""" + e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return e_x / np.sum(e_x, axis=axis, keepdims=True) + + +# ViLD configuration flags +FLAGS = { + 'prompt_engineering': True, + 'this_is': True, + 'temperature': 100.0, + 'use_softmax': False, +} +FLAGS = EasyDict(FLAGS) + +# Visualization parameters +display_input_size = (10, 10) +overall_fig_size = (18, 24) +line_thickness = 1 +fig_size_w = 35 +mask_color = 'red' +alpha = 0.5 + + +def article(name): + """Return 'an' if name starts with a vowel, 'a' otherwise.""" + return "an" if name[0] in "aeiou" else "a" + + +def processed_name(name, rm_dot=False): + """Process category name by replacing underscores and slashes.""" + res = name.replace("_", " ").replace("/", " or ").lower() + if rm_dot: + res = res.rstrip(".") + return res + + +# Prompt templates for CLIP embedding +single_template = ["a photo of {article} {}."] + +multiple_templates = [ + 'There is {article} {} in the scene.', + 'There is the {} in the scene.', + 'a photo of {article} {} in the scene.', + 'a photo of the {} in the scene.', + 'a photo of one {} in the scene.', + 'itap of {article} {}.', + 'itap of my {}.', + 'itap of the {}.', + 'a photo of {article} {}.', + 'a photo of my {}.', + 'a photo of the {}.', + 'a photo of one {}.', + 'a photo of many {}.', + 'a good photo of {article} {}.', + 'a good photo of the {}.', + 'a bad photo of {article} {}.', + 'a bad photo of the {}.', + 'a photo of a nice {}.', + 'a photo of the nice {}.', + 'a photo of a cool {}.', + 'a photo of the cool {}.', + 'a photo of a weird {}.', + 'a photo of the weird {}.', + 'a photo of a small {}.', + 'a photo of the small {}.', + 'a photo of a large {}.', + 'a photo of the large {}.', + 'a photo of a clean {}.', + 'a photo of the clean {}.', + 'a photo of a dirty {}.', + 'a photo of the dirty {}.', + 'a bright photo of {article} {}.', + 'a bright photo of the {}.', + 'a dark photo of {article} {}.', + 'a dark photo of the {}.', + 'a photo of a hard to see {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of {article} {}.', + 'a low resolution photo of the {}.', + 'a cropped photo of {article} {}.', + 'a cropped photo of the {}.', + 'a close-up photo of {article} {}.', + 'a close-up photo of the {}.', + 'a jpeg corrupted photo of {article} {}.', + 'a jpeg corrupted photo of the {}.', + 'a blurry photo of {article} {}.', + 'a blurry photo of the {}.', + 'a pixelated photo of {article} {}.', + 'a pixelated photo of the {}.', + 'a black and white photo of the {}.', + 'a black and white photo of {article} {}.', + 'a plastic {}.', + 'the plastic {}.', + 'a toy {}.', + 'the toy {}.', + 'a plushie {}.', + 'the plushie {}.', + 'a cartoon {}.', + 'the cartoon {}.', + 'an embroidered {}.', + 'the embroidered {}.', + 'a painting of the {}.', + 'a painting of a {}.', +] + +# Lazy-loaded models (loaded on first use, can be freed) +_clip_model = None +_clip_preprocess = None +_tf_session = None + + +def get_clip_model(): + """Get or lazily load the CLIP model.""" + global _clip_model, _clip_preprocess + if _clip_model is None: + _clip_model, _clip_preprocess = clip.load("ViT-B/32") + if torch.cuda.is_available(): + _clip_model.cuda() + _clip_model.eval() + print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in _clip_model.parameters()]):,}") + print("Input resolution:", _clip_model.visual.input_resolution) + print("Context length:", _clip_model.context_length) + print("Vocab size:", _clip_model.vocab_size) + return _clip_model, _clip_preprocess + + +def get_tf_session(): + """Get or lazily load the TensorFlow session.""" + global _tf_session + if _tf_session is None: + config = tf.ConfigProto(allow_soft_placement=True) + if torch.cuda.is_available(): + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + _tf_session = tf.Session(graph=tf.Graph(), config=config) + saved_model_dir = os.path.join(SAYCAN_DIR, "image_path_v2") + _ = tf.saved_model.loader.load(_tf_session, ["serve"], saved_model_dir) + return _tf_session + + +def cleanup_models(): + """Free model resources. Call this when done with ViLD.""" + global _clip_model, _clip_preprocess, _tf_session + + if _tf_session is not None: + _tf_session.close() + _tf_session = None + + if _clip_model is not None: + del _clip_model + del _clip_preprocess + _clip_model = None + _clip_preprocess = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# Backward compatibility - load on import (can be disabled by setting LAZY_LOAD=true) +import os as _os +if _os.environ.get('VILD_LAZY_LOAD', '').lower() != 'true': + clip_model, clip_preprocess = get_clip_model() + session = get_tf_session() +else: + clip_model, clip_preprocess = None, None + session = None + + +def build_text_embedding(categories): + """ + Build text embeddings for object categories using CLIP. + + Args: + categories: List of category dicts with 'name' and 'id' keys + + Returns: + Numpy array of text embeddings + """ + clip_model, _ = get_clip_model() + + if FLAGS.prompt_engineering: + templates = multiple_templates + else: + templates = single_template + + run_on_gpu = torch.cuda.is_available() + + with torch.no_grad(): + all_text_embeddings = [] + print("Building text embeddings...") + for category in tqdm(categories): + texts = [ + template.format(processed_name(category["name"], rm_dot=True), + article=article(category["name"])) + for template in templates + ] + if FLAGS.this_is: + texts = [ + "This is " + text if text.startswith("a") or text.startswith("the") else text + for text in texts + ] + texts = clip.tokenize(texts) + if run_on_gpu: + texts = texts.cuda() + text_embeddings = clip_model.encode_text(texts) + text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) + text_embedding = text_embeddings.mean(dim=0) + text_embedding /= text_embedding.norm() + all_text_embeddings.append(text_embedding.cpu()) # Move to CPU immediately + all_text_embeddings = torch.stack(all_text_embeddings, dim=1) + + # Clear GPU cache after embedding + if run_on_gpu: + torch.cuda.empty_cache() + + return all_text_embeddings.cpu().numpy().T + + +# Load ViLD TensorFlow model +config = tf.ConfigProto(allow_soft_placement=True) +if torch.cuda.is_available(): + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) +session = tf.Session(graph=tf.Graph(), config=config) +saved_model_dir = os.path.join(SAYCAN_DIR, "image_path_v2") +_ = tf.saved_model.loader.load(session, ["serve"], saved_model_dir) + +numbered_categories = [{"name": str(idx), "id": idx} for idx in range(50)] +numbered_category_indices = {cat["id"]: cat for cat in numbered_categories} + + +def nms(dets, scores, thresh, max_dets=1000): + """ + Non-maximum suppression. + + Args: + dets: Detection boxes [N, 4] + scores: Detection scores [N,] + thresh: IoU threshold + max_dets: Maximum detections to keep + + Returns: + List of indices to keep + """ + y1 = dets[:, 0] + x1 = dets[:, 1] + y2 = dets[:, 2] + x2 = dets[:, 3] + + areas = (x2 - x1) * (y2 - y1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0 and len(keep) < max_dets: + i = order[0] + keep.append(i) + + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + intersection = w * h + overlap = intersection / (areas[i] + areas[order[1:]] - intersection + 1e-12) + + inds = np.where(overlap <= thresh)[0] + order = order[inds + 1] + return keep + + +import PIL.ImageColor as ImageColor +import PIL.ImageDraw as ImageDraw +import PIL.ImageFont as ImageFont + +STANDARD_COLORS = ["White"] + + +def draw_bounding_box_on_image(image, ymin, xmin, ymax, xmax, color="red", thickness=4, + display_str_list=(), use_normalized_coordinates=True): + """Adds a bounding box to an image.""" + draw = ImageDraw.Draw(image) + im_width, im_height = image.size + if use_normalized_coordinates: + (left, right, top, bottom) = ( + xmin * im_width, xmax * im_width, ymin * im_height, ymax * im_height + ) + else: + (left, right, top, bottom) = (xmin, xmax, ymin, ymax) + draw.line([(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], + width=thickness, fill=color) + try: + font = ImageFont.truetype("arial.ttf", 24) + except IOError: + font = ImageFont.load_default() + + display_str_heights = [font.getbbox(ds)[3] - font.getbbox(ds)[1] for ds in display_str_list] + total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) + + if top > total_display_str_height: + text_bottom = top + else: + text_bottom = bottom + total_display_str_height + + for display_str in display_str_list[::-1]: + text_left = min(5, left) + bbox = font.getbbox(display_str) + text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1] + margin = np.ceil(0.05 * text_height) + draw.rectangle([(left, text_bottom - text_height - 2 * margin), + (left + text_width, text_bottom)], fill=color) + draw.text((left + margin, text_bottom - text_height - margin), display_str, + fill="black", font=font) + text_bottom -= text_height - 2 * margin + + +def draw_bounding_box_on_image_array(image, ymin, xmin, ymax, xmax, color="red", thickness=4, + display_str_list=(), use_normalized_coordinates=True): + """Adds a bounding box to an image (numpy array).""" + image_pil = Image.fromarray(np.uint8(image)).convert("RGB") + draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, thickness, + display_str_list, use_normalized_coordinates) + np.copyto(image, np.array(image_pil)) + + +def draw_mask_on_image_array(image, mask, color="red", alpha=0.4): + """Draws mask on an image.""" + if image.dtype != np.uint8: + raise ValueError("`image` not of type np.uint8") + if mask.dtype != np.uint8: + raise ValueError("`mask` not of type np.uint8") + if np.any(np.logical_and(mask != 1, mask != 0)): + raise ValueError("`mask` elements should be in [0, 1]") + if image.shape[:2] != mask.shape: + raise ValueError("Image and mask dimensions don't match") + + rgb = ImageColor.getrgb(color) + pil_image = Image.fromarray(image) + solid_color = np.expand_dims(np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) + pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert("RGBA") + pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert("L") + pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) + np.copyto(image, np.array(pil_image.convert("RGB"))) + + +def visualize_boxes_and_labels_on_image_array(image, boxes, classes, scores, category_index, + instance_masks=None, instance_boundaries=None, + use_normalized_coordinates=False, max_boxes_to_draw=20, + min_score_thresh=0.5, agnostic_mode=False, + line_thickness=1, groundtruth_box_visualization_color="black", + skip_scores=False, skip_labels=False, mask_alpha=0.4, + plot_color=None): + """Overlay labeled boxes on an image with formatted scores and label names.""" + box_to_display_str_map = collections.defaultdict(list) + box_to_color_map = collections.defaultdict(str) + box_to_instance_masks_map = {} + box_to_score_map = {} + box_to_instance_boundaries_map = {} + + if not max_boxes_to_draw: + max_boxes_to_draw = boxes.shape[0] + + for i in range(min(max_boxes_to_draw, boxes.shape[0])): + if scores is None or scores[i] > min_score_thresh: + box = tuple(boxes[i].tolist()) + if instance_masks is not None: + box_to_instance_masks_map[box] = instance_masks[i] + if instance_boundaries is not None: + box_to_instance_boundaries_map[box] = instance_boundaries[i] + if scores is None: + box_to_color_map[box] = groundtruth_box_visualization_color + else: + display_str = "" + if not skip_labels: + if not agnostic_mode: + if classes[i] in list(category_index.keys()): + class_name = category_index[classes[i]]["name"] + else: + class_name = "N/A" + display_str = str(class_name) + if not skip_scores: + if not display_str: + display_str = "{}%".format(int(100 * scores[i])) + else: + float_score = ("%.2f" % scores[i]).lstrip("0") + display_str = "{}: {}".format(display_str, float_score) + box_to_score_map[box] = int(100 * scores[i]) + + box_to_display_str_map[box].append(display_str) + if plot_color is not None: + box_to_color_map[box] = plot_color + elif agnostic_mode: + box_to_color_map[box] = "DarkOrange" + else: + box_to_color_map[box] = STANDARD_COLORS[classes[i] % len(STANDARD_COLORS)] + + if box_to_score_map: + box_color_iter = sorted(box_to_color_map.items(), key=lambda kv: box_to_score_map[kv[0]]) + else: + box_color_iter = box_to_color_map.items() + + for box, color in box_color_iter: + ymin, xmin, ymax, xmax = box + if instance_masks is not None: + draw_mask_on_image_array(image, box_to_instance_masks_map[box], color=color, alpha=mask_alpha) + if instance_boundaries is not None: + draw_mask_on_image_array(image, box_to_instance_boundaries_map[box], color="red", alpha=1.0) + draw_bounding_box_on_image_array(image, ymin, xmin, ymax, xmax, color=color, + thickness=line_thickness, + display_str_list=box_to_display_str_map[box], + use_normalized_coordinates=use_normalized_coordinates) + + return image + + +def paste_instance_masks(masks, detected_boxes, image_height, image_width): + """Paste instance masks to generate the image segmentation results.""" + def expand_boxes(boxes, scale): + w_half = boxes[:, 2] * 0.5 + h_half = boxes[:, 3] * 0.5 + x_c = boxes[:, 0] + w_half + y_c = boxes[:, 1] + h_half + w_half *= scale + h_half *= scale + boxes_exp = np.zeros(boxes.shape) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + return boxes_exp + + _, mask_height, mask_width = masks.shape + scale = max((mask_width + 2.0) / mask_width, (mask_height + 2.0) / mask_height) + ref_boxes = expand_boxes(detected_boxes, scale) + ref_boxes = ref_boxes.astype(np.int32) + padded_mask = np.zeros((mask_height + 2, mask_width + 2), dtype=np.float32) + segms = [] + + for mask_ind, mask in enumerate(masks): + im_mask = np.zeros((image_height, image_width), dtype=np.uint8) + padded_mask[1:-1, 1:-1] = mask[:, :] + ref_box = ref_boxes[mask_ind, :] + w = ref_box[2] - ref_box[0] + 1 + h = ref_box[3] - ref_box[1] + 1 + w = np.maximum(w, 1) + h = np.maximum(h, 1) + mask = cv2.resize(padded_mask, (w, h)) + mask = np.array(mask > 0.5, dtype=np.uint8) + x_0 = min(max(ref_box[0], 0), image_width) + x_1 = min(max(ref_box[2] + 1, 0), image_width) + y_0 = min(max(ref_box[1], 0), image_height) + y_1 = min(max(ref_box[3] + 1, 0), image_height) + im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - ref_box[1]):(y_1 - ref_box[1]), + (x_0 - ref_box[0]):(x_1 - ref_box[0])] + segms.append(im_mask) + + segms = np.array(segms) + return segms + + +def plot_mask(color, alpha, original_image, mask): + """Plot instance mask on image.""" + rgb = ImageColor.getrgb(color) + pil_image = Image.fromarray(original_image) + solid_color = np.expand_dims(np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) + pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert("RGBA") + pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert("L") + pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) + return np.array(pil_image.convert("RGB")) + + +def display_image(path_or_array, size=(10, 10)): + """Display an image from path or array.""" + if isinstance(path_or_array, str): + image = np.asarray(Image.open(open(path_or_array, "rb")).convert("RGB")) + else: + image = path_or_array + plt.figure(figsize=size) + plt.imshow(image) + plt.axis("off") + plt.show() + + +def vild(image_path, category_name_string, params, plot_on=True, prompt_swaps=[]): + """ + Run ViLD object detection on an image. + + Args: + image_path: Path to the input image + category_name_string: Semicolon-separated category names + params: Tuple of (max_boxes, nms_thresh, min_rpn_score, min_box_area, max_box_area) + plot_on: Whether to display visualization + prompt_swaps: List of (old, new) string replacements for categories + + Returns: + List of detected object names + """ + # Preprocessing categories + for a, b in prompt_swaps: + category_name_string = category_name_string.replace(a, b) + category_names = [x.strip() for x in category_name_string.split(";")] + category_names = ["background"] + category_names + categories = [{"name": item, "id": idx + 1} for idx, item in enumerate(category_names)] + category_indices = {cat["id"]: cat for cat in categories} + + max_boxes_to_draw, nms_threshold, min_rpn_score_thresh, min_box_area, max_box_area = params + fig_size_h = min(max(5, int(len(category_names) / 2.5)), 10) + + # Run ViLD model + roi_boxes, roi_scores, detection_boxes, scores_unused, box_outputs, detection_masks, visual_features, image_info = session.run( + ["RoiBoxes:0", "RoiScores:0", "2ndStageBoxes:0", "2ndStageScoresUnused:0", + "BoxOutputs:0", "MaskOutputs:0", "VisualFeatOutputs:0", "ImageInfo:0"], + feed_dict={"Placeholder:0": [image_path]}) + + roi_boxes = np.squeeze(roi_boxes, axis=0) + roi_scores = np.squeeze(roi_scores, axis=0) + detection_boxes = np.squeeze(detection_boxes, axis=(0, 2)) + scores_unused = np.squeeze(scores_unused, axis=0) + box_outputs = np.squeeze(box_outputs, axis=0) + detection_masks = np.squeeze(detection_masks, axis=0) + visual_features = np.squeeze(visual_features, axis=0) + + image_info = np.squeeze(image_info, axis=0) + image_scale = np.tile(image_info[2:3, :], (1, 2)) + image_height = int(image_info[0, 0]) + image_width = int(image_info[0, 1]) + + rescaled_detection_boxes = detection_boxes / image_scale + + # Read image + image = np.asarray(Image.open(open(image_path, "rb")).convert("RGB")) + assert image_height == image.shape[0] + assert image_width == image.shape[1] + + # Filter boxes with NMS + nmsed_indices = nms(detection_boxes, roi_scores, thresh=nms_threshold) + box_sizes = (rescaled_detection_boxes[:, 2] - rescaled_detection_boxes[:, 0]) * \ + (rescaled_detection_boxes[:, 3] - rescaled_detection_boxes[:, 1]) + + valid_indices = np.where( + np.logical_and( + np.isin(np.arange(len(roi_scores), dtype=int), nmsed_indices), + np.logical_and( + np.logical_not(np.all(roi_boxes == 0., axis=-1)), + np.logical_and( + roi_scores >= min_rpn_score_thresh, + np.logical_and(box_sizes > min_box_area, box_sizes < max_box_area) + ) + ) + ) + )[0] + + detection_roi_scores = roi_scores[valid_indices][:max_boxes_to_draw, ...] + detection_boxes = detection_boxes[valid_indices][:max_boxes_to_draw, ...] + detection_masks = detection_masks[valid_indices][:max_boxes_to_draw, ...] + detection_visual_feat = visual_features[valid_indices][:max_boxes_to_draw, ...] + rescaled_detection_boxes = rescaled_detection_boxes[valid_indices][:max_boxes_to_draw, ...] + + # Compute text embeddings and scores + text_features = build_text_embedding(categories) + raw_scores = detection_visual_feat.dot(text_features.T) + if FLAGS.use_softmax: + scores_all = softmax(FLAGS.temperature * raw_scores, axis=-1) + else: + scores_all = raw_scores + + indices = np.argsort(-np.max(scores_all, axis=1)) + indices_fg = np.array([i for i in indices if np.argmax(scores_all[i]) != 0]) + + # Get found objects + found_objects = [] + for a, b in prompt_swaps: + category_names = [name.replace(b, a) for name in category_names] + + for anno_idx in indices[0:int(rescaled_detection_boxes.shape[0])]: + scores = scores_all[anno_idx] + if np.argmax(scores) == 0: + continue + found_object = category_names[np.argmax(scores)] + if found_object == "background": + continue + print("Found a", found_object, "with score:", np.max(scores)) + found_objects.append(category_names[np.argmax(scores)]) + + if not plot_on: + return found_objects + + # Visualization + ymin, xmin, ymax, xmax = np.split(rescaled_detection_boxes, 4, axis=-1) + processed_boxes = np.concatenate([xmin, ymin, xmax - xmin, ymax - ymin], axis=-1) + segmentations = paste_instance_masks(detection_masks, processed_boxes, image_height, image_width) + + if len(indices_fg) == 0: + display_image(np.array(image), size=overall_fig_size) + print("ViLD does not detect anything belonging to the given category") + else: + image_with_detections = visualize_boxes_and_labels_on_image_array( + np.array(image), + rescaled_detection_boxes[indices_fg], + valid_indices[:max_boxes_to_draw][indices_fg], + detection_roi_scores[indices_fg], + numbered_category_indices, + instance_masks=segmentations[indices_fg], + use_normalized_coordinates=False, + max_boxes_to_draw=max_boxes_to_draw, + min_score_thresh=min_rpn_score_thresh, + skip_scores=False, + skip_labels=True) + + plt.imshow(image_with_detections) + plt.title("ViLD detected objects and RPN scores.") + plt.show() + + return found_objects + + +# Default category names for pick-and-place tasks +category_names = [ + 'blue block', 'red block', 'green block', 'orange block', 'yellow block', + 'purple block', 'pink block', 'cyan block', 'brown block', 'gray block', + 'blue bowl', 'red bowl', 'green bowl', 'orange bowl', 'yellow bowl', + 'purple bowl', 'pink bowl', 'cyan bowl', 'brown bowl', 'gray bowl' +] + +image_path = 'tmp.jpg' + +# ViLD settings +category_name_string = ";".join(category_names) +max_boxes_to_draw = 8 +prompt_swaps = [('block', 'cube')] +nms_threshold = 0.4 +min_rpn_score_thresh = 0.4 +min_box_area = 10 +max_box_area = 3000 +vild_params = max_boxes_to_draw, nms_threshold, min_rpn_score_thresh, min_box_area, max_box_area + + +if __name__ == "__main__": + found_objects = vild(image_path, category_name_string, vild_params, plot_on=True, prompt_swaps=prompt_swaps) + print("Found objects:", found_objects) \ No newline at end of file