diff --git a/saycan/README.md b/saycan/README.md
new file mode 100644
index 0000000..740c593
--- /dev/null
+++ b/saycan/README.md
@@ -0,0 +1,105 @@
+# SayCan - Language-Conditioned Robotic Manipulation
+
+This experiment implements the SayCan approach for grounding language in robotic affordances,
+combining:
+- **ViLD**: Open-vocabulary object detection
+- **LLM (Ollama)**: Task planning and action scoring
+- **CLIPort**: Language-conditioned pick-and-place manipulation
+- **PyBullet**: Physics simulation with UR5e robot arm
+
+## Installation
+
+### Core Dependencies
+```bash
+pip install -r requirements.txt
+```
+
+### Ollama (for LLM)
+Install Ollama from [ollama.ai](https://ollama.ai) and pull a model:
+```bash
+ollama pull llama3.2:1b
+```
+
+### Asset Downloads
+Assets (robot URDFs, ViLD model, CLIPort checkpoint) are downloaded automatically on first run.
+
+## Setting on the Webserver
+
+
+```bash
+python manage.py shell -c "from experiment.models import Environment; Environment.objects.update_or_create(name='SayCan', defaults={
+ 'description': 'Language-conditioned robotic manipulation with LLM planning',
+ 'filepaths': {'environment': 'saycan/environment.py'}
+})"
+```
+
+
+```bash
+python manage.py shell -c "from experiment.models import Experiment, Environment; Experiment.objects.update_or_create(link='saycan', defaults={
+ 'name': 'SayCan',
+ 'short_description': 'Robot manipulation with natural language instructions',
+ 'long_description': 'Guide a robot arm to pick and place objects using natural language instructions. The system uses ViLD for object detection, an LLM for task planning, and CLIPort for language-conditioned manipulation.\r\n\n
\n
\nYou can give instructions like:\n
\n- \"task: put all blocks in bowls\" - Set a high-level task
\n- \"pick the blue block and place it on the red bowl\" - Direct instruction
\n
',
+ 'enabled': True,
+ 'environment': Environment.objects.get(name='SayCan'),
+ 'number_of_episodes': 1,
+ 'target_fps': 24.0,
+ 'wait_for_inputs': False
+})"
+```
+
+
+```bash
+python manage.py shell -c "from experiment.models import Policy; Policy.objects.update_or_create(name='SayCan', defaults={
+ 'description': 'SayCan policy with LLM planning and CLIPort execution',
+ 'filepaths': {'policy': 'saycan/policy.py'},
+ 'checkpoint_interval': 0
+})"
+```
+
+
+```bash
+python manage.py shell -c "from experiment.models import Agent, Policy; Agent.objects.update_or_create(role='agent_0', defaults={
+ 'name': 'Robot',
+ 'description': 'UR5e robot arm with Robotiq gripper',
+ 'policy': Policy.objects.get(name='SayCan'),
+ 'participant': True,
+ 'keyboard_inputs': {},
+ 'multiple_keyboard_inputs': False,
+ 'inputs_type': 'other',
+ 'textual_inputs': True
+})"
+```
+
+
+```bash
+python manage.py shell -c "from experiment.models import Experiment, Agent; exp = Experiment.objects.get(link='saycan'); exp.agents.add(Agent.objects.get(role='agent_0'))"
+```
+
+## Usage
+
+### Action Types
+The environment accepts the following action types:
+
+| Action | Description |
+|--------|-------------|
+| `"task:"` | Set a high-level task for LLM planning |
+| `"plan"` | Get next planned action from LLM |
+| `""` | Direct pick-and-place instruction |
+| `"done"` | End the episode |
+
+### Example Tasks
+- `task: put all blocks in bowls`
+- `task: stack the blocks`
+- `task: sort blocks by color`
+- `pick the blue block and place it on the red bowl`
+
+## References
+
+- **SayCan**: [Ahn et al. (2022) - Do As I Can, Not As I Say](https://arxiv.org/abs/2204.01691)
+- **CLIPort**: [Shridhar et al. (2021) - What and Where Pathways for Robotic Manipulation](https://arxiv.org/abs/2109.12098)
+- **ViLD**: [Gu et al. (2021) - Open-Vocabulary Object Detection via Vision and Language Knowledge Distillation](https://arxiv.org/abs/2104.13921)
+
+## Repository
+
+- Original SayCan: https://github.com/google-research/google-research/tree/master/saycan
+- CLIPort: https://github.com/cliport/cliport
\ No newline at end of file
diff --git a/saycan/base_environment.py b/saycan/base_environment.py
new file mode 100644
index 0000000..b972b94
--- /dev/null
+++ b/saycan/base_environment.py
@@ -0,0 +1,348 @@
+"""
+SayCan Environment Wrapper for SHARPIE.
+
+This module wraps the PickPlaceEnv from the SayCan codebase to work with the
+SHARPIE experiment framework. It integrates:
+- ViLD for open-vocabulary object detection
+- LLM (via Ollama) for task planning and action scoring
+- CLIPort for language-conditioned pick-and-place manipulation
+
+Action Types:
+- "task:" - Set task and auto-plan first action
+- "plan" - Get next planned action from LLM
+- "" - Direct CLIPort instruction
+- "done" - End episode
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., Brohan, A., Brown, N., Chebotar, Y., Cortes, Y., David, B.,
+ Finn, C., Fu, C., Gopalakrishnan, K., Hausman, K., Herzog, A., Ho, D.,
+ Hsu, J., Ibarz, J., Ichter, B., Irpan, A., Jang, E., Jang, R., Julian, R.,
+ ... & Zeng, A. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+
+import cv2
+import os
+import sys
+import tempfile
+import numpy as np
+from PIL import Image
+
+# Add the saycan directory to path for imports
+SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__))
+if SAYCAN_DIR not in sys.path:
+ sys.path.insert(0, SAYCAN_DIR)
+
+from pick_place_env import PickPlaceEnv
+from config import PICK_TARGETS, PLACE_TARGETS
+from cliport import get_cliport
+# Import LLM and helpers for planning
+from llm import make_options, gpt3_scoring, gpt3_context, termination_string
+from helpers import normalize_scores, step_to_nlp, affordance_scoring
+from vild import vild, category_name_string, vild_params
+
+
+class SayCanBaseEnvironment:
+ """Wrapper for the SayCan PickPlaceEnv with LLM planning and CLIPort integration."""
+
+ def __init__(self):
+ """Initialize the environment."""
+ self.env = PickPlaceEnv()
+ self.config = None
+ self._step_count = 0
+ self._max_steps = 100
+ self._cliport = None
+ self.cached_video_frames = []
+
+ # LLM planning state
+ self._current_task = None
+ self._max_tasks = 10
+ self._gpt3_prompt = None
+ self._options = None
+ self._found_objects = None
+ self._task_step_count = 0
+
+ def reset(self, config=None):
+ """
+ Reset the environment to an initial state.
+
+ Args:
+ config: Optional configuration dict with 'pick' and 'place' lists.
+ If None, uses default objects.
+
+ Returns:
+ observation: Initial observation dict with 'image', 'xyzmap', 'pick', 'place'
+ info: Additional information dict
+ """
+ self._step_count = 0
+ self.cached_video_frames = []
+
+ # Reset LLM planning state
+ self._current_task = None
+ self._gpt3_prompt = None
+ self._options = None
+ self._found_objects = None
+ self._task_step_count = 0
+
+ if config is None:
+ config = {'pick': ['yellow block', 'blue block', 'red block'],
+ 'place': ['blue bowl', 'red bowl']}
+
+ self.config = config
+ observation = self.env.reset(config)
+
+ info = {
+ "step": 0,
+ "config": config,
+ "pick_objects": config.get("pick", []),
+ "place_objects": config.get("place", [])
+ }
+
+ return observation, info
+
+ def set_task(self, task_text):
+ """
+ Set the current task from natural language.
+
+ Args:
+ task_text: Task instruction (e.g., "put all the blocks in different corners")
+ """
+ self._current_task = task_text
+ self._gpt3_prompt = gpt3_context + "\n# " + task_text + "\n"
+ self._task_step_count = 0
+ self._found_objects = None
+ self._options = None
+ print(f"Environment: Task set to '{task_text}'")
+
+ def detect_objects(self, observation=None):
+ """
+ Detect objects in the scene using ViLD.
+
+ Args:
+ observation: Observation dict with 'image'. If None, uses current observation.
+
+ Returns:
+ found_objects: List of detected object names
+ """
+ if observation is None:
+ observation = self.env.get_observation()
+
+ # Save image to temp file for ViLD
+ image = observation['image']
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
+ temp_path = f.name
+ Image.fromarray(image).save(temp_path)
+
+ try:
+ # Run ViLD detection
+ prompt_swaps = [('block', 'cube')]
+ found_objects = vild(temp_path, category_name_string, vild_params,
+ plot_on=False, prompt_swaps=prompt_swaps)
+ print(f"Environment: Detected objects: {found_objects}")
+ finally:
+ # Clean up temp file
+ os.unlink(temp_path)
+
+ return found_objects
+
+ def plan_next_action(self, observation=None):
+ """
+ Plan the next action using LLM + affordance scoring.
+
+ Args:
+ observation: Current observation. If None, uses current observation.
+
+ Returns:
+ action_text: Natural language action instruction
+ done: Whether the task is complete
+ """
+ if observation is None:
+ observation = self.env.get_observation()
+
+ # Detect objects if not already done
+ if self._found_objects is None:
+ self._found_objects = self.detect_objects(observation)
+
+ # Create options if not already done
+ if self._options is None:
+ self._options = make_options(PICK_TARGETS, PLACE_TARGETS,
+ termination_string=termination_string)
+
+ # Calculate affordance scores based on detected objects
+ affordance_scores = affordance_scoring(self._options, self._found_objects,
+ block_name="box", bowl_name="circle",
+ verbose=False)
+
+ # Get LLM scores
+ llm_scores, _ = gpt3_scoring(self._gpt3_prompt, self._options, verbose=True)
+
+ # Combine scores
+ combined_scores = {
+ option: np.exp(llm_scores[option]) * affordance_scores[option]
+ for option in self._options
+ }
+ combined_scores = normalize_scores(combined_scores)
+
+ # Select best action
+ selected_task = max(combined_scores, key=combined_scores.get)
+
+ # Check for termination
+ if selected_task == termination_string:
+ print("Environment: Task completed (termination signal)")
+ return "done", True
+
+ # Update prompt for next step
+ self._gpt3_prompt += selected_task + "\n"
+ self._task_step_count += 1
+
+ # Check max tasks limit
+ if self._task_step_count >= self._max_tasks:
+ print("Environment: Max steps reached")
+ return "done", True
+
+ # Convert to natural language
+ action_text = step_to_nlp(selected_task)
+ print(f"Environment: Step {self._task_step_count} - {action_text}")
+ return action_text, False
+
+ def step(self, action_dict):
+ """
+ Execute one step in the environment.
+
+ Args:
+ action_dict: Dictionary with agent id as keys and action as value.
+ Action can be:
+ - string text instruction directly
+ - "task:" to set a task and auto-plan
+ - "plan" to get the next planned action
+ - "done" to end the episode
+
+ Returns:
+ observation: New observation dict
+ reward: Reward for the action (float)
+ terminated: Whether the episode has ended (bool)
+ truncated: Whether the episode was truncated (bool)
+ info: Additional information (dict)
+ """
+ self._step_count += 1
+
+ if len(self.cached_video_frames) > 0:
+ return np.array([]), 0.0, False, False, {"info": "No action taken"}
+
+ # Extract action from dict (single-agent environment)
+ action = list(action_dict.values())[0] if isinstance(action_dict, dict) else action_dict
+
+ # Handle different action types
+ if action == 'done':
+ return np.array([]), 0.0, True, False, {"info": "Task completed"}
+ elif isinstance(action, str) and action.startswith('task:'):
+ # Execute complete task automatically
+ task_text = action[5:].strip()
+ results = self.run_task(task_text)
+ return np.array([]), results["total_reward"], False, False, results
+ elif action:
+ # Direct text instruction
+ obs, reward, _, info = self._step_with_text(action)
+ # Get the frames buffer
+ self.cached_video_frames = self.env.cache_video
+ else:
+ return np.array([]), 0.0, False, False, {"info": "No action taken"}
+
+ # Check termination conditions
+ terminated = False
+ truncated = self._step_count >= self._max_steps
+
+ info["step"] = self._step_count
+ info["max_steps"] = self._max_steps
+
+ return obs, reward, terminated, truncated, info
+
+ def _step_with_text(self, text):
+ """Execute a step using CLIPort with text instruction."""
+ if self._cliport is None:
+ self._cliport = get_cliport()
+
+ # Get current observation
+ obs = self.env.get_observation()
+
+ # Use CLIPort to predict action
+ action = self._cliport.predict(obs, text)
+
+ # Execute the predicted action
+ obs, reward, done, info = self.env.step({
+ 'pick': action['pick'],
+ 'place': action['place']
+ })
+
+ info['text_instruction'] = text
+ info['cliport_action'] = action
+
+ return obs, reward, done, info
+
+ def render(self):
+ """Render the environment."""
+ if len(self.cached_video_frames) > 0:
+ return cv2.cvtColor(self.cached_video_frames.pop(0), cv2.COLOR_BGR2RGB)
+ return cv2.cvtColor(self.env.get_camera_image(), cv2.COLOR_BGR2RGB)
+
+ def get_observation(self):
+ """Get current observation without stepping."""
+ return self.env.get_observation()
+
+ def run_task(self, task_text, max_steps=5):
+ """
+ Execute a complete task from start to finish with automatic planning.
+
+ This method sets the task and automatically executes all planned actions
+ until completion, without requiring manual 'plan' calls between steps.
+
+ Args:
+ task_text: Natural language task description (e.g., "put all blocks in bowls")
+ max_steps: Maximum number of actions to execute (default: 50)
+
+ Returns:
+ results: Dictionary containing:
+ - task: The original task text
+ - completed: Whether the task completed successfully
+ - steps: List of executed steps with actions and rewards
+ - total_reward: Cumulative reward across all steps
+ - termination_reason: Why execution stopped
+ """
+ self.set_task(task_text)
+
+ results = {
+ "task": task_text,
+ "completed": False,
+ "steps": [],
+ "total_reward": 0.0,
+ "termination_reason": None
+ }
+
+ for step in range(max_steps):
+ # Plan next action
+ action_text, task_done = self.plan_next_action()
+
+ # Check for task completion signal from LLM
+ if task_done or action_text == "done":
+ results["completed"] = True
+ results["termination_reason"] = "task_done"
+ break
+
+ # Execute the planned action
+ obs, reward, _, info = self._step_with_text(action_text)
+ results["total_reward"] += reward
+
+ results["steps"].append({
+ "step": step,
+ "action": action_text,
+ "reward": reward,
+ "info": info
+ })
+
+ # Cache final video frames for rendering
+ self.cached_video_frames = self.env.cache_video
+
+ return results
\ No newline at end of file
diff --git a/saycan/cliport.py b/saycan/cliport.py
new file mode 100644
index 0000000..6cfde29
--- /dev/null
+++ b/saycan/cliport.py
@@ -0,0 +1,490 @@
+"""
+CLIPort - CLIP + Transporter Networks for Language-Conditioned Manipulation.
+
+This module implements the CLIPort architecture for language-conditioned pick-and-place
+operations. It combines CLIP (Contrastive Language-Image Pre-training) with Transporter
+Networks to predict pick and place positions based on natural language instructions.
+
+Key Components:
+- ResNet-based encoder-decoder architecture
+- CLIP text and image encoders
+- Transporter Networks for pick and place heatmap prediction
+- Pretrained checkpoint loading
+
+CLIPort Repository:
+ https://github.com/cliport/cliport
+
+Reference:
+ Shridhar, M., Manuelli, L., & Fox, D. (2021). CLIPort: What and Where Pathways
+ for Robotic Manipulation. Conference on Robot Learning (CoRL).
+
+Used in SayCan:
+ https://github.com/google-research/google-research/tree/master/saycan
+"""
+
+import os
+import subprocess
+import numpy as np
+import torch
+import clip
+import matplotlib.pyplot as plt
+import jax
+import jax.numpy as jnp
+import optax
+import flax
+from flax import linen as nn
+from flax.training import checkpoints
+from moviepy import ImageSequenceClip
+from IPython.display import display
+
+# Get the saycan directory for checkpoint paths
+SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__))
+
+class ResNetBlock(nn.Module):
+ """ResNet pre-Activation block. https://arxiv.org/pdf/1603.05027.pdf"""
+ features: int
+ stride: int = 1
+
+ def setup(self):
+ self.conv0 = nn.Conv(self.features // 4, (1, 1), (self.stride, self.stride))
+ self.conv1 = nn.Conv(self.features // 4, (3, 3))
+ self.conv2 = nn.Conv(self.features, (1, 1))
+ self.conv3 = nn.Conv(self.features, (1, 1), (self.stride, self.stride))
+
+ def __call__(self, x):
+ y = self.conv0(nn.relu(x))
+ y = self.conv1(nn.relu(y))
+ y = self.conv2(nn.relu(y))
+ if x.shape != y.shape:
+ x = self.conv3(nn.relu(x))
+ return x + y
+
+
+class UpSample(nn.Module):
+ """Simple 2D 2x bilinear upsample."""
+
+ def __call__(self, x):
+ B, H, W, C = x.shape
+ new_shape = (B, H * 2, W * 2, C)
+ return jax.image.resize(x, new_shape, 'bilinear')
+
+
+class ResNet(nn.Module):
+ """Hourglass 53-layer ResNet with 8-stride."""
+ out_dim: int
+
+ def setup(self):
+ self.dense0 = nn.Dense(8)
+
+ self.conv0 = nn.Conv(64, (3, 3), (1, 1))
+ self.block0 = ResNetBlock(64)
+ self.block1 = ResNetBlock(64)
+ self.block2 = ResNetBlock(128, stride=2)
+ self.block3 = ResNetBlock(128)
+ self.block4 = ResNetBlock(256, stride=2)
+ self.block5 = ResNetBlock(256)
+ self.block6 = ResNetBlock(512, stride=2)
+ self.block7 = ResNetBlock(512)
+
+ self.block8 = ResNetBlock(256)
+ self.block9 = ResNetBlock(256)
+ self.upsample0 = UpSample()
+ self.block10 = ResNetBlock(128)
+ self.block11 = ResNetBlock(128)
+ self.upsample1 = UpSample()
+ self.block12 = ResNetBlock(64)
+ self.block13 = ResNetBlock(64)
+ self.upsample2 = UpSample()
+ self.block14 = ResNetBlock(16)
+ self.block15 = ResNetBlock(16)
+ self.conv1 = nn.Conv(self.out_dim, (3, 3), (1, 1))
+
+ def __call__(self, x, text):
+
+ # # Project and concatenate CLIP features (early fusion).
+ # text = self.dense0(text)
+ # text = jnp.expand_dims(text, axis=(1, 2))
+ # text = jnp.broadcast_to(text, x.shape[:3] + (8,))
+ # x = jnp.concatenate((x, text), axis=-1)
+
+ x = self.conv0(x)
+ x = self.block0(x)
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.block3(x)
+ x = self.block4(x)
+ x = self.block5(x)
+ x = self.block6(x)
+ x = self.block7(x)
+
+ # Concatenate CLIP features (mid-fusion).
+ text = jnp.expand_dims(text, axis=(1, 2))
+ text = jnp.broadcast_to(text, x.shape)
+ x = jnp.concatenate((x, text), axis=-1)
+
+ x = self.block8(x)
+ x = self.block9(x)
+ x = self.upsample0(x)
+ x = self.block10(x)
+ x = self.block11(x)
+ x = self.upsample1(x)
+ x = self.block12(x)
+ x = self.block13(x)
+ x = self.upsample2(x)
+ x = self.block14(x)
+ x = self.block15(x)
+ x = self.conv1(x)
+ return x
+
+
+class TransporterNets(nn.Module):
+ """TransporterNet with 3 ResNets (translation only)."""
+
+ def setup(self):
+ # Picking affordances.
+ self.pick_net = ResNet(1)
+
+ # Pick-conditioned placing affordances.
+ self.q_net = ResNet(3) # Query (crop around pick location).
+ self.k_net = ResNet(3) # Key (place features).
+ self.crop_size = 64
+ self.crop_conv = nn.Conv(features=1, kernel_size=(self.crop_size, self.crop_size), use_bias=False, dtype=jnp.float32, padding='SAME')
+
+ def __call__(self, x, text, p=None, train=True):
+ B, H, W, C = x.shape
+ pick_out = self.pick_net(x, text) # (B, H, W, 1)
+
+ # Get key features.
+ k = self.k_net(x, text)
+
+ # Add 0-padding before cropping.
+ h = self.crop_size // 2
+ x_crop = jnp.pad(x, [(0, 0), (h, h), (h, h), (0, 0)], 'maximum')
+
+ # Get query features and convolve them over key features.
+ place_out = jnp.zeros((0, H, W, 1), jnp.float32)
+ for b in range(B):
+
+ # Get coordinates at center of crop.
+ if p is None:
+ pick_out_b = pick_out[b, ...] # (H, W, 1)
+ pick_out_b = pick_out_b.flatten() # (H * W,)
+ amax_i = jnp.argmax(pick_out_b)
+ v, u = jnp.unravel_index(amax_i, (H, W))
+ else:
+ v, u = p[b, :]
+
+ # Get query crop.
+ x_crop_b = jax.lax.dynamic_slice(x_crop, (b, v, u, 0), (1, self.crop_size, self.crop_size, x_crop.shape[3]))
+ # x_crop_b = x_crop[b:b+1, v:(v + self.crop_size), u:(u + self.crop_size), ...]
+
+ # Convolve q (query) across k (key).
+ q = self.q_net(x_crop_b, text[b:b+1, :]) # (1, H, W, 3)
+ q = jnp.transpose(q, (1, 2, 3, 0)) # (H, W, 3, 1)
+ place_out_b = self.crop_conv.apply({'params': {'kernel': q}}, k[b:b+1, ...]) # (1, H, W, 1)
+ scale = 1 / (self.crop_size * self.crop_size) # For higher softmax temperatures.
+ place_out_b *= scale
+ place_out = jnp.concatenate((place_out, place_out_b), axis=0)
+
+ return pick_out, place_out
+
+
+def n_params(params):
+ return jnp.sum(jnp.int32([n_params(v) if isinstance(v, dict) or isinstance(v, flax.core.frozen_dict.FrozenDict) else np.prod(v.shape) for v in params.values()]))
+
+from flax.training import train_state
+
+class TrainState(train_state.TrainState):
+ pass
+
+
+
+
+#@markdown Train your own model, or load a pretrained one.
+load_pretrained = True #@param {type:"boolean"}
+
+# Initialize model weights using dummy tensors.
+rng = jax.random.PRNGKey(0)
+rng, key = jax.random.split(rng)
+init_img = jnp.ones((4, 224, 224, 5), jnp.float32)
+init_text = jnp.ones((4, 512), jnp.float32)
+init_pix = jnp.zeros((4, 2), np.int32)
+init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params']
+print(f'Model parameters: {n_params(init_params):,}')
+
+# Define the Optax optimizer
+optimizer_tx = optax.adam(learning_rate=1e-4)
+
+# Create an initial TrainState object. This will have step=0.
+optim = TrainState.create(apply_fn=TransporterNets().apply,
+ params=init_params,
+ tx=optimizer_tx)
+
+if load_pretrained:
+ ckpt_path = os.path.join(SAYCAN_DIR, f'ckpt_{40000}')
+ if not os.path.exists(ckpt_path):
+ import subprocess
+ print("Downloading CLIPort checkpoint...")
+ subprocess.run(['gdown', '--id', '1Nq0q1KbqHOA5O7aRSu4u7-u27EMMXqgP', '-O', ckpt_path], check=False)
+
+ try:
+ # Attempt to restore directly. This will fail if 'step' is missing in the checkpoint.
+ optim = checkpoints.restore_checkpoint(ckpt_path, optim)
+ print('Loaded:', ckpt_path)
+ except ValueError as e:
+ if "Missing field step in state dict" in str(e):
+ print("Attempting to load old checkpoint format (missing 'step' field).")
+ # Load the raw checkpoint data as a dictionary
+ loaded_state_dict = checkpoints.restore_checkpoint(ckpt_path, target=None)
+
+ if isinstance(loaded_state_dict, dict):
+ # Extract parameters, common keys for parameters are 'params' or 'target'
+ params_from_ckpt = loaded_state_dict.get('params', loaded_state_dict.get('target', init_params))
+
+ # Re-initialize the opt_state using the current optax optimizer with loaded parameters.
+ # This means the exact state of the old optimizer might be lost if it was not optax-compatible,
+ # but model parameters are preserved.
+ new_opt_state = optimizer_tx.init(params_from_ckpt)
+
+ # Create a new TrainState with the loaded parameters, re-initialized opt_state, and step=0
+ optim = TrainState(
+ step=0, # Default to step 0 if not present in old checkpoint
+ params=params_from_ckpt,
+ tx=optimizer_tx,
+ opt_state=new_opt_state,
+ apply_fn=TransporterNets().apply
+ )
+ print('Successfully migrated and loaded checkpoint (params restored, opt_state re-initialized, step set to 0).')
+ else:
+ print(f"Error: Checkpoint '{ckpt_path}' is not a dictionary. Cannot migrate. Using initial model state.")
+ else:
+ # Re-raise other ValueErrors
+ raise
+
+else:
+
+ # Training loop.
+ batch_size = 8
+ for train_iter in range(1, 40001):
+ batch_i = np.random.randint(dataset_size, size=batch_size)
+ text_feat = data_text_feats[batch_i, ...]
+ img = dataset['image'][batch_i, ...] / 255
+ img = np.concatenate((img, np.broadcast_to(coords[None, ...], (batch_size,) + coords.shape)), axis=3)
+
+ # Get onehot label maps.
+ pick_yx = np.zeros((batch_size, 2), dtype=np.int32)
+ pick_onehot = np.zeros((batch_size, 224, 224), dtype=np.float32)
+ place_onehot = np.zeros((batch_size, 224, 224), dtype=np.float32)
+ for i in range(len(batch_i)):
+ pick_y, pick_x = dataset['pick_yx'][batch_i[i], :]
+ place_y, place_x = dataset['place_yx'][batch_i[i], :]
+ pick_onehot[i, pick_y, pick_x] = 1
+ place_onehot[i, place_y, place_x] = 1
+ # pick_onehot[i, ...] = scipy.ndimage.gaussian_filter(pick_onehot[i, ...], sigma=3)
+
+ # Data augmentation (random translation).
+ roll_y, roll_x = np.random.randint(-112, 112, size=2)
+ img[i, ...] = np.roll(img[i, ...], roll_y, axis=0)
+ img[i, ...] = np.roll(img[i, ...], roll_x, axis=1)
+ pick_onehot[i, ...] = np.roll(pick_onehot[i, ...], roll_y, axis=0)
+ pick_onehot[i, ...] = np.roll(pick_onehot[i, ...], roll_x, axis=1)
+ place_onehot[i, ...] = np.roll(place_onehot[i, ...], roll_y, axis=0)
+ place_onehot[i, ...] = np.roll(place_onehot[i, ...], roll_x, axis=1)
+ pick_yx[i, 0] = pick_y + roll_y
+ pick_yx[i, 1] = pick_x + roll_x
+
+ # Backpropagate.
+ batch = {}
+ batch['img'] = jnp.float32(img)
+ batch['text'] = jnp.float32(text_feat)
+ batch['pick_yx'] = jnp.int32(pick_yx)
+ batch['pick_onehot'] = jnp.float32(pick_onehot)
+ batch['place_onehot'] = jnp.float32(place_onehot)
+ rng, batch['rng'] = jax.random.split(rng)
+ optim, loss, _, _ = train_step(optim, batch)
+ writer.scalar('train/loss', loss, train_iter)
+
+ if train_iter % np.power(10, min(4, np.floor(np.log10(train_iter)))) == 0:
+ print(f'Train Step: {train_iter} Loss: {loss}')
+
+ if train_iter % 1000 == 0:
+ checkpoints.save_checkpoint('.', optim, train_iter, prefix='ckpt_', keep=100000, overwrite=True)
+
+
+
+
+# ============================================================================
+# CLIPort Interface Class for easy integration
+# ============================================================================
+
+import os
+
+# Get the saycan directory for checkpoint paths
+SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+class CLIPort:
+ """CLIPort model interface for text-conditioned pick and place.
+
+ This class provides a simple interface for using CLIPort without
+ needing to manage all the global variables and initialization.
+ """
+
+ def __init__(self):
+ self.clip_model = None
+ self.optim = None
+ self.coords = None
+ self._initialized = False
+
+ def _init(self):
+ """Lazy initialization of CLIP and Transporter models."""
+ if self._initialized:
+ return
+
+ print("Initializing CLIPort...")
+
+ # Initialize CLIP
+ self.clip_model, _ = clip.load("ViT-B/32")
+ if torch.cuda.is_available():
+ self.clip_model = self.clip_model.cuda()
+ self.clip_model.eval()
+
+ # Create coordinate tensor
+ h, w = 224, 224
+ y_coords = np.linspace(0, 1, h)
+ x_coords = np.linspace(0, 1, w)
+ xx, yy = np.meshgrid(x_coords, y_coords)
+ self.coords = np.stack([xx, yy], axis=-1).astype(np.float32)
+
+ # Initialize Transporter Net
+ rng = jax.random.PRNGKey(0)
+ rng, key = jax.random.split(rng)
+ init_img = jnp.ones((1, 224, 224, 5), jnp.float32)
+ init_text = jnp.ones((1, 512), jnp.float32)
+ init_params = TransporterNets().init(key, init_img, init_text)['params']
+
+ # Create optimizer state
+ optimizer_tx = optax.adam(learning_rate=1e-4)
+ self.optim = TrainState.create(
+ apply_fn=TransporterNets().apply,
+ params=init_params,
+ tx=optimizer_tx
+ )
+
+ # Try to load checkpoint
+ ckpt_path = os.path.join(SAYCAN_DIR, 'ckpt_40000')
+ if os.path.exists(ckpt_path):
+ try:
+ # Attempt to restore directly
+ self.optim = checkpoints.restore_checkpoint(ckpt_path, self.optim)
+ print(f"Loaded CLIPort checkpoint from {ckpt_path}")
+ except ValueError as e:
+ if "Missing field step" in str(e):
+ print("Migrating old checkpoint format...")
+ try:
+ # Load the raw checkpoint data as a dictionary
+ loaded_state_dict = checkpoints.restore_checkpoint(ckpt_path, target=None)
+ if isinstance(loaded_state_dict, dict):
+ # Extract parameters
+ params_from_ckpt = loaded_state_dict.get('params', loaded_state_dict.get('target', init_params))
+ # Re-initialize the opt_state with loaded parameters
+ new_opt_state = optimizer_tx.init(params_from_ckpt)
+ # Create a new TrainState with the loaded parameters
+ self.optim = TrainState(
+ step=0,
+ params=params_from_ckpt,
+ tx=optimizer_tx,
+ opt_state=new_opt_state,
+ apply_fn=TransporterNets().apply
+ )
+ print(f"Successfully migrated checkpoint from {ckpt_path}")
+ else:
+ print(f"Could not migrate checkpoint, using random initialization")
+ except Exception as e2:
+ print(f"Could not load CLIPort checkpoint: {e2}")
+ print("Using random initialization - model may not perform well without training.")
+ else:
+ raise
+ except Exception as e:
+ print(f"Could not load CLIPort checkpoint: {e}")
+ print("Using random initialization - model may not perform well without training.")
+ else:
+ print("No CLIPort checkpoint found, using random initialization.")
+ print("Run: python config.py to download pretrained weights")
+
+ self._initialized = True
+
+ def encode_text(self, text):
+ """Encode text instruction using CLIP."""
+ with torch.no_grad():
+ tokens = clip.tokenize([text])
+ if torch.cuda.is_available():
+ tokens = tokens.cuda()
+ text_feats = self.clip_model.encode_text(tokens).float()
+ text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
+ text_feats = text_feats.cpu().numpy()
+ return text_feats.astype(np.float32)
+
+ def predict(self, observation, text):
+ """
+ Predict pick and place coordinates from text instruction.
+
+ Args:
+ observation: Dict with 'image' and 'xyzmap'
+ text: Text instruction string
+
+ Returns:
+ action: Dict with 'pick' and 'place' 3D coordinates
+ """
+ self._init()
+
+ # Get image and encode text
+ image = observation['image']
+ xyzmap = observation['xyzmap']
+ text_feats = self.encode_text(text)
+
+ # Prepare image batch
+ img = image[np.newaxis, ...] / 255.0
+ img = np.concatenate([img, self.coords[np.newaxis, ...]], axis=-1)
+
+ # Run inference
+ def eval_step(optim, batch):
+ pick_out, place_out = TransporterNets().apply(
+ {'params': optim.params}, batch['img'], batch['text']
+ )
+ return pick_out, place_out
+
+ batch = {'img': jnp.float32(img), 'text': jnp.float32(text_feats)}
+ pick_map, place_map = eval_step(self.optim, batch)
+ pick_map, place_map = np.float32(pick_map[0]), np.float32(place_map[0])
+
+ # Get pick position
+ pick_max = np.argmax(pick_map.flatten())
+ pick_y, pick_x = np.unravel_index(pick_max, (224, 224))
+ pick_y, pick_x = np.clip(pick_y, 20, 204), np.clip(pick_x, 20, 204)
+ pick_xyz = xyzmap[pick_y, pick_x]
+
+ # Get place position
+ place_max = np.argmax(place_map.flatten())
+ place_y, place_x = np.unravel_index(place_max, (224, 224))
+ place_y, place_x = np.clip(place_y, 20, 204), np.clip(place_x, 20, 204)
+ place_xyz = xyzmap[place_y, place_x]
+
+ return {
+ 'pick': pick_xyz,
+ 'place': place_xyz,
+ 'pick_map': pick_map,
+ 'place_map': place_map
+ }
+
+
+# Global CLIPort instance
+_cliport = None
+
+
+def get_cliport():
+ """Get or create the global CLIPort instance."""
+ global _cliport
+ if _cliport is None:
+ _cliport = CLIPort()
+ return _cliport
\ No newline at end of file
diff --git a/saycan/config.py b/saycan/config.py
new file mode 100644
index 0000000..43e64b8
--- /dev/null
+++ b/saycan/config.py
@@ -0,0 +1,141 @@
+"""
+SayCan Configuration and Asset Downloader.
+
+This module provides global configuration constants for the SayCan environment
+and handles downloading required assets (robot URDFs, model weights).
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., Brohan, A., Brown, N., Chebotar, Y., Cortes, Y., David, B.,
+ Finn, C., Fu, C., Gopalakrishnan, K., Hausman, K., Herzog, A., Ho, D.,
+ Hsu, J., Ibarz, J., Ichter, B., Irpan, A., Jang, E., Jang, R., Julian, R.,
+ ... & Zeng, A. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+
+import collections
+import datetime
+import os
+import random
+import threading
+import time
+
+import cv2 # Used by ViLD.
+import clip
+from easydict import EasyDict
+import flax
+from flax import linen as nn
+from flax.training import checkpoints
+from flax.metrics import tensorboard
+import imageio
+from heapq import nlargest
+import IPython
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+from moviepy import ImageSequenceClip
+import numpy as np
+import optax
+import pickle
+from PIL import Image
+import pybullet
+import pybullet_data
+import tensorflow.compat.v1 as tf
+import torch
+from tqdm import tqdm
+
+import subprocess
+
+# Get the directory where this script is located
+SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def download_assets():
+ """Download PyBullet robot assets, ViLD model weights, and CLIPort checkpoint."""
+ # Change to saycan directory for downloads
+ original_dir = os.getcwd()
+ os.chdir(SAYCAN_DIR)
+
+ try:
+ # Download PyBullet assets (UR5e robot, Robotiq gripper, bowl)
+ if not os.path.exists('ur5e/ur5e.urdf'):
+ print("Downloading UR5e robot assets...")
+ subprocess.run(['gdown', '--id', '1Cc_fDSBL6QiDvNT4dpfAEbhbALSVoWcc'], check=True)
+ subprocess.run(['gdown', '--id', '1yOMEm-Zp_DL3nItG9RozPeJAmeOldekX'], check=True)
+ subprocess.run(['gdown', '--id', '1GsqNLhEl9dd4Mc3BM0dX3MibOI1FVWNM'], check=True)
+
+ print("Extracting assets...")
+ subprocess.run(['unzip', '-o', 'ur5e.zip'], check=True)
+ subprocess.run(['unzip', '-o', 'robotiq_2f_85.zip'], check=True)
+ subprocess.run(['unzip', '-o', 'bowl.zip'], check=True)
+
+ # Download ViLD pretrained model weights
+ if not os.path.exists('image_path_v2'):
+ print("Downloading ViLD model weights...")
+ # Try using wget with public URL since gsutil may not be available
+ os.makedirs('image_path_v2/variables', exist_ok=True)
+ base_url = 'https://storage.googleapis.com/cloud-tpu-checkpoints/detection/projects/vild/colab/image_path_v2/'
+ subprocess.run(['wget', '-q', base_url + 'saved_model.pb', '-O', 'image_path_v2/saved_model.pb'], check=False)
+ subprocess.run(['wget', '-q', base_url + 'variables/variables.data-00000-of-00001', '-O', 'image_path_v2/variables/variables.data-00000-of-00001'], check=False)
+ subprocess.run(['wget', '-q', base_url + 'variables/variables.index', '-O', 'image_path_v2/variables/variables.index'], check=False)
+
+ # Download CLIPort pretrained checkpoint
+ if not os.path.exists('cliport_checkpoint'):
+ print("Downloading CLIPort pretrained checkpoint...")
+ os.makedirs('cliport_checkpoint', exist_ok=True)
+ # CLIPort checkpoint from original SayCan paper
+ subprocess.run(['gdown', '--id', '1NqJDTyxZOOqvCM2RZthJT5qPX3Xi-a-g', '-O', 'cliport_checkpoint/checkpoint'], check=False)
+
+ # Download training dataset (optional, for fine-tuning)
+ if not os.path.exists('dataset-9999.pkl'):
+ print("Downloading CLIPort training dataset...")
+ subprocess.run(['gdown', '--id', '1yCz6C-6eLWb4SFYKdkM-wz5tlMjbG2h8'], check=False)
+ finally:
+ os.chdir(original_dir)
+
+download_assets()
+
+# =============================================================================
+# Global Constants
+# =============================================================================
+
+# Objects that can be picked up
+PICK_TARGETS = {
+ "blue block": None,
+ "red block": None,
+ "green block": None,
+ "yellow block": None,
+}
+
+# RGBA colors for objects
+COLORS = {
+ "blue": (78/255, 121/255, 167/255, 255/255),
+ "red": (255/255, 87/255, 89/255, 255/255),
+ "green": (89/255, 169/255, 79/255, 255/255),
+ "yellow": (237/255, 201/255, 72/255, 255/255),
+}
+
+# Target locations for placing objects (None = dynamic, tuple = fixed position)
+PLACE_TARGETS = {
+ "blue block": None,
+ "red block": None,
+ "green block": None,
+ "yellow block": None,
+
+ "blue bowl": None,
+ "red bowl": None,
+ "green bowl": None,
+ "yellow bowl": None,
+
+ "top left corner": (-0.3 + 0.05, -0.2 - 0.05, 0),
+ "top right corner": (0.3 - 0.05, -0.2 - 0.05, 0),
+ "middle": (0, -0.5, 0),
+ "bottom left corner": (-0.3 + 0.05, -0.8 + 0.05, 0),
+ "bottom right corner": (0.3 - 0.05, -0.8 + 0.05, 0),
+}
+
+# Workspace configuration
+PIXEL_SIZE = 0.00267857 # Meters per pixel
+BOUNDS = np.float32([[-0.3, 0.3], [-0.8, -0.2], [0, 0.15]]) # X, Y, Z bounds in meters
\ No newline at end of file
diff --git a/saycan/datasets.py b/saycan/datasets.py
new file mode 100644
index 0000000..063198b
--- /dev/null
+++ b/saycan/datasets.py
@@ -0,0 +1,60 @@
+#@markdown Collect demonstrations with a scripted expert, or download a pre-generated dataset.
+load_pregenerated = True #@param {type:"boolean"}
+
+# Load pre-existing dataset.
+if load_pregenerated:
+ if not os.path.exists('dataset-9999.pkl'):
+ # !gdown --id 1TECwTIfawxkRYbzlAey0z1mqXKcyfPc-
+ !gdown --id 1yCz6C-6eLWb4SFYKdkM-wz5tlMjbG2h8
+ dataset = pickle.load(open('dataset-9999.pkl', 'rb')) # ~10K samples.
+ dataset_size = len(dataset['text'])
+
+# Generate new dataset.
+else:
+ dataset = {}
+ dataset_size = 2 # Size of new dataset.
+ dataset['image'] = np.zeros((dataset_size, 224, 224, 3), dtype=np.uint8)
+ dataset['pick_yx'] = np.zeros((dataset_size, 2), dtype=np.int32)
+ dataset['place_yx'] = np.zeros((dataset_size, 2), dtype=np.int32)
+ dataset['text'] = []
+ policy = ScriptedPolicy(env)
+ data_idx = 0
+ while data_idx < dataset_size:
+ np.random.seed(data_idx)
+ num_pick, num_place = 3, 3
+
+ # Select random objects for data collection.
+ pick_items = list(PICK_TARGETS.keys())
+ pick_items = np.random.choice(pick_items, size=num_pick, replace=False)
+ place_items = list(PLACE_TARGETS.keys())
+ for pick_item in pick_items: # For simplicity: place items != pick items.
+ place_items.remove(pick_item)
+ place_items = np.random.choice(place_items, size=num_place, replace=False)
+ config = {'pick': pick_items, 'place': place_items}
+
+ # Initialize environment with selected objects.
+ obs = env.reset(config)
+
+ # Create text prompts.
+ prompts = []
+ for i in range(len(pick_items)):
+ pick_item = pick_items[i]
+ place_item = place_items[i]
+ prompts.append(f'Pick the {pick_item} and place it on the {place_item}.')
+
+ # Execute 3 pick and place actions.
+ for prompt in prompts:
+ act = policy.step(prompt, obs)
+ dataset['text'].append(prompt)
+ dataset['image'][data_idx, ...] = obs['image'].copy()
+ dataset['pick_yx'][data_idx, ...] = xyz_to_pix(act['pick'])
+ dataset['place_yx'][data_idx, ...] = xyz_to_pix(act['place'])
+ data_idx += 1
+ obs, _, _, _ = env.step(act)
+ debug_clip = ImageSequenceClip(env.cache_video, fps=25)
+ display(debug_clip.ipython_display(autoplay=1, loop=1))
+ env.cache_video = []
+ if data_idx >= dataset_size:
+ break
+
+ pickle.dump(dataset, open(f'dataset-{dataset_size}.pkl', 'wb'))
\ No newline at end of file
diff --git a/saycan/environment.py b/saycan/environment.py
new file mode 100644
index 0000000..effd2c5
--- /dev/null
+++ b/saycan/environment.py
@@ -0,0 +1,67 @@
+"""
+SayCan Environment Wrapper for SHARPIE.
+
+This module wraps the PickPlaceEnv from the SayCan codebase to work with the
+SHARPIE experiment framework. It integrates:
+- ViLD for open-vocabulary object detection
+- LLM (via Ollama) for task planning and action scoring
+- CLIPort for language-conditioned pick-and-place manipulation
+
+Action Types:
+- "task:" - Set task and auto-plan first action
+- "plan" - Get next planned action from LLM
+- "" - Direct CLIPort instruction
+- "done" - End episode
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., Brohan, A., Brown, N., Chebotar, Y., Cortes, Y., David, B.,
+ Finn, C., Fu, C., Gopalakrishnan, K., Hausman, K., Herzog, A., Ho, D.,
+ Hsu, J., Ibarz, J., Ichter, B., Irpan, A., Jang, E., Jang, R., Julian, R.,
+ ... & Zeng, A. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+import os
+import sys
+
+# Add the saycan directory to path for imports
+SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__))
+if SAYCAN_DIR not in sys.path:
+ sys.path.insert(0, SAYCAN_DIR)
+
+from base_environment import SayCanBaseEnvironment
+
+class EnvironmentWrapper(SayCanBaseEnvironment):
+ """Wrapper for the SayCan PickPlaceEnv with LLM planning and CLIPort integration."""
+
+ def __init__(self):
+ """Initialize the environment."""
+ super().__init__()
+
+ def step(self, action_dict):
+ """
+ Execute one step in the environment.
+
+ Args:
+ action_dict: Dictionary with agent id as keys and action as value.
+ Action can be:
+ - string text instruction directly
+ - "task:" to set a task and auto-plan
+ - "plan" to get the next planned action
+ - "done" to end the episode
+
+ Returns:
+ observation: New observation dict
+ reward: Reward for the action (float)
+ terminated: Whether the episode has ended (bool)
+ truncated: Whether the episode was truncated (bool)
+ info: Additional information (dict)
+ """
+ # Extract action from dict (single-agent environment)
+ action = list(action_dict.values())[0] if isinstance(action_dict, dict) else action_dict
+ return super().step(action)
+
+# Create the environment instance for SHARPIE runner
+environment = EnvironmentWrapper()
diff --git a/saycan/helpers.py b/saycan/helpers.py
new file mode 100644
index 0000000..93c1028
--- /dev/null
+++ b/saycan/helpers.py
@@ -0,0 +1,111 @@
+"""
+SayCan Helper Functions.
+
+Utility functions for affordance scoring, scene description, and visualization.
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+from heapq import nlargest
+from config import PLACE_TARGETS
+
+def build_scene_description(found_objects, block_name="box", bowl_name="circle"):
+ scene_description = f"objects = {found_objects}"
+ scene_description = scene_description.replace(block_name, "block")
+ scene_description = scene_description.replace(bowl_name, "bowl")
+ scene_description = scene_description.replace("'", "")
+ return scene_description
+
+def step_to_nlp(step):
+ step = step.replace("robot.pick_and_place(", "")
+ step = step.replace(")", "")
+ pick, place = step.split(", ")
+ return "Pick the " + pick + " and place it on the " + place + "."
+
+def normalize_scores(scores):
+ max_score = max(scores.values())
+ normed_scores = {key: np.clip(scores[key] / max_score, 0, 1) for key in scores}
+ return normed_scores
+
+def plot_saycan(llm_scores, vfs, combined_scores, task, correct=True, show_top=None):
+ if show_top:
+ top_options = nlargest(show_top, combined_scores, key = combined_scores.get)
+ # add a few top llm options in if not already shown
+ top_llm_options = nlargest(show_top // 2, llm_scores, key = llm_scores.get)
+ for llm_option in top_llm_options:
+ if not llm_option in top_options:
+ top_options.append(llm_option)
+ llm_scores = {option: llm_scores[option] for option in top_options}
+ vfs = {option: vfs[option] for option in top_options}
+ combined_scores = {option: combined_scores[option] for option in top_options}
+
+ sorted_keys = dict(sorted(combined_scores.items()))
+ keys = [key for key in sorted_keys]
+ positions = np.arange(len(combined_scores.items()))
+ width = 0.3
+
+ fig = plt.figure(figsize=(12, 6))
+ ax1 = fig.add_subplot(1,1,1)
+
+ plot_llm_scores = normalize_scores({key: np.exp(llm_scores[key]) for key in sorted_keys})
+ plot_llm_scores = np.asarray([plot_llm_scores[key] for key in sorted_keys])
+ plot_affordance_scores = np.asarray([vfs[key] for key in sorted_keys])
+ plot_combined_scores = np.asarray([combined_scores[key] for key in sorted_keys])
+
+ ax1.bar(positions, plot_combined_scores, 3 * width, alpha=0.6, color="#93CE8E", label="combined")
+
+ score_colors = ["#ea9999ff" for score in plot_affordance_scores]
+ ax1.bar(positions + width / 2, 0 * plot_combined_scores, width, color="#ea9999ff", label="vfs")
+ ax1.bar(positions + width / 2, 0 * plot_combined_scores, width, color="#a4c2f4ff", label="language")
+ ax1.bar(positions - width / 2, np.abs(plot_affordance_scores), width, color=score_colors)
+
+ plt.xticks(rotation="vertical")
+ ax1.set_ylim(0.0, 1.0)
+
+ ax1.grid(True, which="both")
+ ax1.axis("on")
+
+ ax1_llm = ax1.twinx()
+ ax1_llm.bar(positions + width / 2, plot_llm_scores, width, color="#a4c2f4ff", label="language")
+ ax1_llm.set_ylim(0.01, 1.0)
+ plt.yscale("log")
+
+ font = {"fontname":"Arial", "size":"16", "color":"k" if correct else "r"}
+ plt.title(task, **font)
+ key_strings = [key.replace("robot.pick_and_place", "").replace(", ", " to ").replace("(", "").replace(")","") for key in keys]
+ plt.xticks(positions, key_strings, **font)
+ ax1.legend()
+ plt.show()
+
+
+
+#@title Affordance Scoring
+#@markdown Given this environment does not have RL-trained policies or an asscociated value function, we use affordances through an object detector.
+
+def affordance_scoring(options, found_objects, verbose=False, block_name="box", bowl_name="circle", termination_string="done()"):
+ affordance_scores = {}
+ found_objects = [
+ found_object.replace(block_name, "block").replace(bowl_name, "bowl")
+ for found_object in found_objects + list(PLACE_TARGETS.keys())[-5:]]
+ verbose and print("found_objects", found_objects)
+ for option in options:
+ if option == termination_string:
+ affordance_scores[option] = 0.2
+ continue
+ pick, place = option.replace("robot.pick_and_place(", "").replace(")", "").split(", ")
+ affordance = 0
+ found_objects_copy = found_objects.copy()
+ if pick in found_objects_copy:
+ found_objects_copy.remove(pick)
+ if place in found_objects_copy:
+ affordance = 1
+ affordance_scores[option] = affordance
+ verbose and print(affordance, '\t', option)
+ return affordance_scores
\ No newline at end of file
diff --git a/saycan/llm.py b/saycan/llm.py
new file mode 100644
index 0000000..34e1058
--- /dev/null
+++ b/saycan/llm.py
@@ -0,0 +1,170 @@
+"""
+SayCan LLM Module - Language Model Integration for Task Planning.
+
+This module provides integration with Large Language Models (LLMs) for task
+planning and action scoring. Originally designed for GPT-3, now adapted for
+local Ollama-based models.
+
+The LLM provides:
+- Few-shot prompting for task decomposition
+- Action scoring based on task context
+- Natural language to action mapping
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+
+import ollama
+
+# Ollama client configuration
+client = ollama.Client(host='http://localhost:11434')
+ENGINE = "llama3.2:1b"
+
+from config import PICK_TARGETS, PLACE_TARGETS
+
+# LLM Cache for repeated queries
+LLM_CACHE = {}
+
+
+def gpt3_call(engine=ENGINE, prompt="", max_tokens=128, temperature=0):
+ """
+ Call the LLM with caching for repeated queries.
+
+ Args:
+ engine: Model name to use
+ prompt: Input prompt string
+ max_tokens: Maximum tokens to generate
+ temperature: Sampling temperature
+
+ Returns:
+ Generated text response
+ """
+ cache_id = (engine, prompt, max_tokens, temperature)
+ if cache_id in LLM_CACHE:
+ print('cache hit, returning cached response')
+ return LLM_CACHE[cache_id]
+
+ ollama_options = {}
+ if max_tokens > 0:
+ ollama_options['num_predict'] = max_tokens
+ if temperature > 0:
+ ollama_options['temperature'] = temperature
+
+ response = client.generate(model=engine, prompt=prompt, options=ollama_options)
+ generated_text = response['response']
+ LLM_CACHE[cache_id] = generated_text
+ return generated_text
+
+
+def gpt3_scoring(query, options, engine=ENGINE, limit_num_options=None, option_start="\n", verbose=False, print_tokens=False):
+ """
+ Score action options using the LLM.
+
+ Note: For local models without log probability access, this returns
+ uniform scores. The actual discrimination comes from affordance scoring.
+
+ Args:
+ query: Prompt context for scoring
+ options: List of action options to score
+ engine: Model name
+ limit_num_options: Limit number of options to score
+ option_start: Prefix for options (unused)
+ verbose: Print scoring details
+ print_tokens: Print token details (unused)
+
+ Returns:
+ Tuple of (scores dict, empty response dict)
+ """
+ if limit_num_options:
+ options = options[:limit_num_options]
+ verbose and print("Scoring", len(options), "options with uniform LLM scores.")
+
+ # Uniform scores since local models don't provide log probs
+ uniform_logprob = 0.0
+ scores = {option: uniform_logprob for option in options}
+
+ if verbose:
+ for i, (option, score) in enumerate(sorted(scores.items(), key=lambda x: -x[1])):
+ print(score, "\t", option)
+ if i >= 10:
+ break
+
+ return scores, {}
+
+
+def make_options(pick_targets=None, place_targets=None, options_in_api_form=True, termination_string="done()"):
+ """
+ Generate all possible pick-and-place action options.
+
+ Args:
+ pick_targets: Dict of pickable objects (uses PICK_TARGETS if None)
+ place_targets: Dict of place targets (uses PLACE_TARGETS if None)
+ options_in_api_form: If True, use API format; otherwise natural language
+ termination_string: String to append for task completion
+
+ Returns:
+ List of action option strings
+ """
+ if not pick_targets:
+ pick_targets = PICK_TARGETS
+ if not place_targets:
+ place_targets = PLACE_TARGETS
+
+ options = []
+ for pick in pick_targets:
+ for place in place_targets:
+ if options_in_api_form:
+ option = f"robot.pick_and_place({pick}, {place})"
+ else:
+ option = f"Pick the {pick} and place it on the {place}."
+ options.append(option)
+
+ options.append(termination_string)
+ print("Considering", len(options), "options")
+ return options
+
+
+# Termination string for task completion
+termination_string = "done()"
+
+# Few-shot prompt examples for task decomposition
+gpt3_context = """
+objects = [red block, yellow block, blue block, green bowl]
+# move all the blocks to the top left corner.
+robot.pick_and_place(blue block, top left corner)
+robot.pick_and_place(red block, top left corner)
+robot.pick_and_place(yellow block, top left corner)
+done()
+
+objects = [red block, yellow block, blue block, green bowl]
+# put the yellow one the green thing.
+robot.pick_and_place(yellow block, green bowl)
+done()
+
+objects = [yellow block, blue block, red block]
+# move the light colored block to the middle.
+robot.pick_and_place(yellow block, middle)
+done()
+
+objects = [blue block, green bowl, red block, yellow bowl, green block]
+# stack the blocks.
+robot.pick_and_place(green block, blue block)
+robot.pick_and_place(red block, green block)
+done()
+
+objects = [red block, blue block, green bowl, blue bowl, yellow block, green block]
+# group the blue objects together.
+robot.pick_and_place(blue block, blue bowl)
+done()
+
+objects = [green bowl, red block, green block, red bowl, yellow bowl, yellow block]
+# sort all the blocks into their matching color bowls.
+robot.pick_and_place(green block, green bowl)
+robot.pick_and_place(red block, red bowl)
+robot.pick_and_place(yellow block, yellow bowl)
+done()
+"""
\ No newline at end of file
diff --git a/saycan/pick_place_env.py b/saycan/pick_place_env.py
new file mode 100644
index 0000000..892d94c
--- /dev/null
+++ b/saycan/pick_place_env.py
@@ -0,0 +1,487 @@
+"""
+SayCan Pick and Place Environment.
+
+A Gym-style PyBullet environment for robotic pick-and-place manipulation tasks.
+This environment simulates a UR5e robot arm with a Robotiq 2F-85 gripper
+manipulating blocks and bowls on a workspace.
+
+Key Features:
+- UR5e robot arm with Robotiq 2F-85 gripper
+- Configurable objects (blocks and bowls of various colors)
+- Pick-and-place motion primitives
+- RGB-D observation with heightmap generation
+- Video recording of episodes
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+
+import os
+import numpy as np
+import pybullet
+import pybullet_data
+from robot import Robotiq2F85
+from config import COLORS, BOUNDS, PIXEL_SIZE, SAYCAN_DIR
+
+class PickPlaceEnv():
+
+ def __init__(self):
+ self.dt = 1/480
+ self.sim_step = 0
+
+ # Configure and start PyBullet.
+ # python3 -m pybullet_utils.runServer
+ # pybullet.connect(pybullet.SHARED_MEMORY) # pybullet.GUI for local GUI.
+ pybullet.connect(pybullet.DIRECT) # pybullet.GUI for local GUI.
+ pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0)
+ pybullet.setPhysicsEngineParameter(enableFileCaching=0)
+ pybullet.setAdditionalSearchPath(SAYCAN_DIR)
+ pybullet.setAdditionalSearchPath(pybullet_data.getDataPath())
+ pybullet.setTimeStep(self.dt)
+
+ self.home_joints = (np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 3 * np.pi / 2, 0) # Joint angles: (J0, J1, J2, J3, J4, J5).
+ self.home_ee_euler = (np.pi, 0, np.pi) # (RX, RY, RZ) rotation in Euler angles.
+ self.ee_link_id = 9 # Link ID of UR5 end effector.
+ self.tip_link_id = 10 # Link ID of gripper finger tips.
+ self.gripper = None
+
+ def reset(self, config):
+ pybullet.resetSimulation(pybullet.RESET_USE_DEFORMABLE_WORLD)
+ pybullet.setGravity(0, 0, -9.8)
+ self.cache_video = []
+
+ # Temporarily disable rendering to load URDFs faster.
+ pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_RENDERING, 0)
+
+ # Add ground plane (from pybullet_data) and robot (from saycan directory).
+ pybullet.loadURDF("plane.urdf", [0, 0, -0.001])
+ ur5e_urdf = os.path.join(SAYCAN_DIR, "ur5e", "ur5e.urdf")
+ self.robot_id = pybullet.loadURDF(ur5e_urdf, [0, 0, 0], flags=pybullet.URDF_USE_MATERIAL_COLORS_FROM_MTL)
+ self.ghost_id = pybullet.loadURDF(ur5e_urdf, [0, 0, -10]) # For forward kinematics.
+ self.joint_ids = [pybullet.getJointInfo(self.robot_id, i) for i in range(pybullet.getNumJoints(self.robot_id))]
+ self.joint_ids = [j[0] for j in self.joint_ids if j[2] == pybullet.JOINT_REVOLUTE]
+
+ # Move robot to home configuration.
+ for i in range(len(self.joint_ids)):
+ pybullet.resetJointState(self.robot_id, self.joint_ids[i], self.home_joints[i])
+
+ # Add gripper.
+ if self.gripper is not None:
+ while self.gripper.constraints_thread.is_alive():
+ self.constraints_thread_active = False
+ self.gripper = Robotiq2F85(self.robot_id, self.ee_link_id)
+ self.gripper.release()
+
+ # Add workspace.
+ plane_shape = pybullet.createCollisionShape(pybullet.GEOM_BOX, halfExtents=[0.3, 0.3, 0.001])
+ plane_visual = pybullet.createVisualShape(pybullet.GEOM_BOX, halfExtents=[0.3, 0.3, 0.001])
+ plane_id = pybullet.createMultiBody(0, plane_shape, plane_visual, basePosition=[0, -0.5, 0])
+ pybullet.changeVisualShape(plane_id, -1, rgbaColor=[0.2, 0.2, 0.2, 1.0])
+
+ # Load objects according to config.
+ self.config = config
+ self.obj_name_to_id = {}
+ obj_names = list(self.config["pick"]) + list(self.config["place"])
+ obj_xyz = np.zeros((0, 3))
+ for obj_name in obj_names:
+ if ("block" in obj_name) or ("bowl" in obj_name):
+
+ # Get random position 15cm+ from other objects.
+ while True:
+ rand_x = np.random.uniform(BOUNDS[0, 0] + 0.1, BOUNDS[0, 1] - 0.1)
+ rand_y = np.random.uniform(BOUNDS[1, 0] + 0.1, BOUNDS[1, 1] - 0.1)
+ rand_xyz = np.float32([rand_x, rand_y, 0.03]).reshape(1, 3)
+ if len(obj_xyz) == 0:
+ obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0)
+ break
+ else:
+ nn_dist = np.min(np.linalg.norm(obj_xyz - rand_xyz, axis=1)).squeeze()
+ if nn_dist > 0.15:
+ obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0)
+ break
+
+ object_color = COLORS[obj_name.split(" ")[0]]
+ object_type = obj_name.split(" ")[1]
+ object_position = rand_xyz.squeeze()
+ if object_type == "block":
+ object_shape = pybullet.createCollisionShape(pybullet.GEOM_BOX, halfExtents=[0.02, 0.02, 0.02])
+ object_visual = pybullet.createVisualShape(pybullet.GEOM_BOX, halfExtents=[0.02, 0.02, 0.02])
+ object_id = pybullet.createMultiBody(0.01, object_shape, object_visual, basePosition=object_position)
+ elif object_type == "bowl":
+ object_position[2] = 0
+ bowl_urdf = os.path.join(SAYCAN_DIR, "bowl", "bowl.urdf")
+ object_id = pybullet.loadURDF(bowl_urdf, object_position, useFixedBase=1)
+ pybullet.changeVisualShape(object_id, -1, rgbaColor=object_color)
+ self.obj_name_to_id[obj_name] = object_id
+
+ # Re-enable rendering.
+ pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_RENDERING, 1)
+
+ for _ in range(200):
+ pybullet.stepSimulation()
+ return self.get_observation()
+
+ def servoj(self, joints):
+ """Move to target joint positions with position control."""
+ pybullet.setJointMotorControlArray(
+ bodyIndex=self.robot_id,
+ jointIndices=self.joint_ids,
+ controlMode=pybullet.POSITION_CONTROL,
+ targetPositions=joints,
+ positionGains=[0.01]*6)
+
+ def movep(self, position):
+ """Move to target end effector position."""
+ joints = pybullet.calculateInverseKinematics(
+ bodyUniqueId=self.robot_id,
+ endEffectorLinkIndex=self.tip_link_id,
+ targetPosition=position,
+ targetOrientation=pybullet.getQuaternionFromEuler(self.home_ee_euler),
+ maxNumIterations=100)
+ self.servoj(joints)
+
+ def step(self, action=None):
+ """Do pick and place motion primitive."""
+ pick_xyz, place_xyz = action["pick"].copy(), action["place"].copy()
+
+ # Set fixed primitive z-heights.
+ hover_xyz = pick_xyz.copy() + np.float32([0, 0, 0.2])
+ pick_xyz[2] = 0.03
+ place_xyz[2] = 0.15
+
+ # Move to object.
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+ while np.linalg.norm(hover_xyz - ee_xyz) > 0.01:
+ self.movep(hover_xyz)
+ self.step_sim_and_render()
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+ while np.linalg.norm(pick_xyz - ee_xyz) > 0.01:
+ self.movep(pick_xyz)
+ self.step_sim_and_render()
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+
+ # Pick up object.
+ self.gripper.activate()
+ for _ in range(240):
+ self.step_sim_and_render()
+ while np.linalg.norm(hover_xyz - ee_xyz) > 0.01:
+ self.movep(hover_xyz)
+ self.step_sim_and_render()
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+
+ # Move to place location.
+ while np.linalg.norm(place_xyz - ee_xyz) > 0.01:
+ self.movep(place_xyz)
+ self.step_sim_and_render()
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+
+ # Place down object.
+ while (not self.gripper.detect_contact()) and (place_xyz[2] > 0.03):
+ place_xyz[2] -= 0.001
+ self.movep(place_xyz)
+ for _ in range(3):
+ self.step_sim_and_render()
+ self.gripper.release()
+ for _ in range(240):
+ self.step_sim_and_render()
+ place_xyz[2] = 0.2
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+ while np.linalg.norm(place_xyz - ee_xyz) > 0.01:
+ self.movep(place_xyz)
+ self.step_sim_and_render()
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+ place_xyz = np.float32([0, -0.5, 0.2])
+ while np.linalg.norm(place_xyz - ee_xyz) > 0.01:
+ self.movep(place_xyz)
+ self.step_sim_and_render()
+ ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0])
+
+ observation = self.get_observation()
+ reward = self.get_reward()
+ done = False
+ info = {}
+ return observation, reward, done, info
+
+ def set_alpha_transparency(self, alpha: float) -> None:
+ for id in range(20):
+ visual_shape_data = pybullet.getVisualShapeData(id)
+ for i in range(len(visual_shape_data)):
+ object_id, link_index, _, _, _, _, _, rgba_color = visual_shape_data[i]
+ rgba_color = list(rgba_color[0:3]) + [alpha]
+ pybullet.changeVisualShape(
+ self.robot_id, linkIndex=i, rgbaColor=rgba_color)
+ pybullet.changeVisualShape(
+ self.gripper.body, linkIndex=i, rgbaColor=rgba_color)
+
+ def step_sim_and_render(self):
+ pybullet.stepSimulation()
+ self.sim_step += 1
+
+ # Render current image at 8 FPS.
+ if self.sim_step % 60 == 0:
+ self.cache_video.append(self.get_camera_image())
+
+ def get_camera_image(self, resolution_factor=4):
+ """
+ Get camera image with adjustable resolution.
+
+ Args:
+ resolution_factor: Multiplier for resolution (default 4 = 960x960)
+ 1 = 240x240, 2 = 480x480, 3 = 720x720, 4 = 960x960
+ """
+ base_size = 240
+ base_focal = 120.
+ image_size = (base_size * resolution_factor, base_size * resolution_factor)
+ focal = base_focal * resolution_factor
+ intrinsics = (focal, 0, focal, 0, focal, focal, 0, 0, 1)
+ color, _, _, _, _ = self.render_image(image_size, intrinsics)
+ return color
+
+ def get_camera_image_top(self,
+ image_size=(240, 240),
+ intrinsics=(2000., 0, 2000., 0, 2000., 2000., 0, 0, 1),
+ position=(0, -0.5, 5),
+ orientation=(0, np.pi, -np.pi / 2),
+ zrange=(0.01, 1.),
+ set_alpha=True):
+ set_alpha and self.set_alpha_transparency(0)
+ color, _, _, _, _ = self.render_image_top(image_size,
+ intrinsics,
+ position,
+ orientation,
+ zrange)
+ set_alpha and self.set_alpha_transparency(1)
+ return color
+
+ def get_reward(self):
+ return 0 # TODO: check did the robot follow text instructions?
+
+ def get_observation(self):
+ observation = {}
+
+ # Render current image.
+ color, depth, position, orientation, intrinsics = self.render_image()
+
+ # Get heightmaps and colormaps.
+ points = self.get_pointcloud(depth, intrinsics)
+ position = np.float32(position).reshape(3, 1)
+ rotation = pybullet.getMatrixFromQuaternion(orientation)
+ rotation = np.float32(rotation).reshape(3, 3)
+ transform = np.eye(4)
+ transform[:3, :] = np.hstack((rotation, position))
+ points = self.transform_pointcloud(points, transform)
+ heightmap, colormap, xyzmap = self.get_heightmap(points, color, BOUNDS, PIXEL_SIZE)
+
+ observation["image"] = colormap
+ observation["xyzmap"] = xyzmap
+ observation["pick"] = list(self.config["pick"])
+ observation["place"] = list(self.config["place"])
+ return observation
+
+ def render_image(self, image_size=(720, 720), intrinsics=(360., 0, 360., 0, 360., 360., 0, 0, 1)):
+
+ # Camera parameters.
+ position = (0, -0.85, 0.4)
+ orientation = (np.pi / 4 + np.pi / 48, np.pi, np.pi)
+ orientation = pybullet.getQuaternionFromEuler(orientation)
+ zrange = (0.01, 10.)
+ noise=True
+
+ # OpenGL camera settings.
+ lookdir = np.float32([0, 0, 1]).reshape(3, 1)
+ updir = np.float32([0, -1, 0]).reshape(3, 1)
+ rotation = pybullet.getMatrixFromQuaternion(orientation)
+ rotm = np.float32(rotation).reshape(3, 3)
+ lookdir = (rotm @ lookdir).reshape(-1)
+ updir = (rotm @ updir).reshape(-1)
+ lookat = position + lookdir
+ focal_len = intrinsics[0]
+ znear, zfar = (0.01, 10.)
+ viewm = pybullet.computeViewMatrix(position, lookat, updir)
+ fovh = (image_size[0] / 2) / focal_len
+ fovh = 180 * np.arctan(fovh) * 2 / np.pi
+
+ # Notes: 1) FOV is vertical FOV 2) aspect must be float
+ aspect_ratio = image_size[1] / image_size[0]
+ projm = pybullet.computeProjectionMatrixFOV(fovh, aspect_ratio, znear, zfar)
+
+ # Render with OpenGL camera settings.
+ # Use brighter lighting to prevent colors from appearing dark/washed out
+ _, _, color, depth, segm = pybullet.getCameraImage(
+ width=image_size[1],
+ height=image_size[0],
+ viewMatrix=viewm,
+ projectionMatrix=projm,
+ shadow=1,
+ lightDirection=[0.5, 0.5, 1],
+ lightColor=[1.0, 1.0, 1.0],
+ lightDistance=2.0,
+ flags=pybullet.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX,
+ renderer=pybullet.ER_BULLET_HARDWARE_OPENGL)
+
+ # Get color image.
+ color_image_size = (image_size[0], image_size[1], 4)
+ color = np.array(color, dtype=np.uint8).reshape(color_image_size)
+ color = color[:, :, :3] # remove alpha channel
+ if noise:
+ color = np.int32(color)
+ color += np.int32(np.random.normal(0, 3, color.shape))
+ color = np.uint8(np.clip(color, 0, 255))
+
+ # Get depth image.
+ depth_image_size = (image_size[0], image_size[1])
+ zbuffer = np.float32(depth).reshape(depth_image_size)
+ depth = (zfar + znear - (2 * zbuffer - 1) * (zfar - znear))
+ depth = (2 * znear * zfar) / depth
+ if noise:
+ depth += np.random.normal(0, 0.003, depth.shape)
+
+ intrinsics = np.float32(intrinsics).reshape(3, 3)
+ return color, depth, position, orientation, intrinsics
+
+ def render_image_top(self,
+ image_size=(240, 240),
+ intrinsics=(2000., 0, 2000., 0, 2000., 2000., 0, 0, 1),
+ position=(0, -0.5, 5),
+ orientation=(0, np.pi, -np.pi / 2),
+ zrange=(0.01, 1.)):
+
+ # Camera parameters.
+ orientation = pybullet.getQuaternionFromEuler(orientation)
+ noise=True
+
+ # OpenGL camera settings.
+ lookdir = np.float32([0, 0, 1]).reshape(3, 1)
+ updir = np.float32([0, -1, 0]).reshape(3, 1)
+ rotation = pybullet.getMatrixFromQuaternion(orientation)
+ rotm = np.float32(rotation).reshape(3, 3)
+ lookdir = (rotm @ lookdir).reshape(-1)
+ updir = (rotm @ updir).reshape(-1)
+ lookat = position + lookdir
+ focal_len = intrinsics[0]
+ znear, zfar = (0.01, 10.)
+ viewm = pybullet.computeViewMatrix(position, lookat, updir)
+ fovh = (image_size[0] / 2) / focal_len
+ fovh = 180 * np.arctan(fovh) * 2 / np.pi
+
+ # Notes: 1) FOV is vertical FOV 2) aspect must be float
+ aspect_ratio = image_size[1] / image_size[0]
+ projm = pybullet.computeProjectionMatrixFOV(fovh, aspect_ratio, znear, zfar)
+
+ # Render with OpenGL camera settings.
+ # Use brighter lighting to prevent colors from appearing dark/washed out
+ _, _, color, depth, segm = pybullet.getCameraImage(
+ width=image_size[1],
+ height=image_size[0],
+ viewMatrix=viewm,
+ projectionMatrix=projm,
+ shadow=1,
+ lightDirection=[0.5, 0.5, 1],
+ lightColor=[1.0, 1.0, 1.0],
+ lightDistance=2.0,
+ flags=pybullet.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX,
+ renderer=pybullet.ER_BULLET_HARDWARE_OPENGL)
+
+ # Get color image.
+ color_image_size = (image_size[0], image_size[1], 4)
+ color = np.array(color, dtype=np.uint8).reshape(color_image_size)
+ color = color[:, :, :3] # remove alpha channel
+ if noise:
+ color = np.int32(color)
+ color += np.int32(np.random.normal(0, 3, color.shape))
+ color = np.uint8(np.clip(color, 0, 255))
+
+ # Get depth image.
+ depth_image_size = (image_size[0], image_size[1])
+ zbuffer = np.float32(depth).reshape(depth_image_size)
+ depth = (zfar + znear - (2 * zbuffer - 1) * (zfar - znear))
+ depth = (2 * znear * zfar) / depth
+ if noise:
+ depth += np.random.normal(0, 0.003, depth.shape)
+
+ intrinsics = np.float32(intrinsics).reshape(3, 3)
+ return color, depth, position, orientation, intrinsics
+
+ def get_pointcloud(self, depth, intrinsics):
+ """Get 3D pointcloud from perspective depth image.
+ Args:
+ depth: HxW float array of perspective depth in meters.
+ intrinsics: 3x3 float array of camera intrinsics matrix.
+ Returns:
+ points: HxWx3 float array of 3D points in camera coordinates.
+ """
+ height, width = depth.shape
+ xlin = np.linspace(0, width - 1, width)
+ ylin = np.linspace(0, height - 1, height)
+ px, py = np.meshgrid(xlin, ylin)
+ px = (px - intrinsics[0, 2]) * (depth / intrinsics[0, 0])
+ py = (py - intrinsics[1, 2]) * (depth / intrinsics[1, 1])
+ points = np.float32([px, py, depth]).transpose(1, 2, 0)
+ return points
+
+ def transform_pointcloud(self, points, transform):
+ """Apply rigid transformation to 3D pointcloud.
+ Args:
+ points: HxWx3 float array of 3D points in camera coordinates.
+ transform: 4x4 float array representing a rigid transformation matrix.
+ Returns:
+ points: HxWx3 float array of transformed 3D points.
+ """
+ padding = ((0, 0), (0, 0), (0, 1))
+ homogen_points = np.pad(points.copy(), padding,
+ "constant", constant_values=1)
+ for i in range(3):
+ points[Ellipsis, i] = np.sum(transform[i, :] * homogen_points, axis=-1)
+ return points
+
+ def get_heightmap(self, points, colors, bounds, pixel_size):
+ """Get top-down (z-axis) orthographic heightmap image from 3D pointcloud.
+ Args:
+ points: HxWx3 float array of 3D points in world coordinates.
+ colors: HxWx3 uint8 array of values in range 0-255 aligned with points.
+ bounds: 3x2 float array of values (rows: X,Y,Z; columns: min,max) defining
+ region in 3D space to generate heightmap in world coordinates.
+ pixel_size: float defining size of each pixel in meters.
+ Returns:
+ heightmap: HxW float array of height (from lower z-bound) in meters.
+ colormap: HxWx3 uint8 array of backprojected color aligned with heightmap.
+ xyzmap: HxWx3 float array of XYZ points in world coordinates.
+ """
+ width = int(np.round((bounds[0, 1] - bounds[0, 0]) / pixel_size))
+ height = int(np.round((bounds[1, 1] - bounds[1, 0]) / pixel_size))
+ heightmap = np.zeros((height, width), dtype=np.float32)
+ colormap = np.zeros((height, width, colors.shape[-1]), dtype=np.uint8)
+ xyzmap = np.zeros((height, width, 3), dtype=np.float32)
+
+ # Filter out 3D points that are outside of the predefined bounds.
+ ix = (points[Ellipsis, 0] >= bounds[0, 0]) & (points[Ellipsis, 0] < bounds[0, 1])
+ iy = (points[Ellipsis, 1] >= bounds[1, 0]) & (points[Ellipsis, 1] < bounds[1, 1])
+ iz = (points[Ellipsis, 2] >= bounds[2, 0]) & (points[Ellipsis, 2] < bounds[2, 1])
+ valid = ix & iy & iz
+ points = points[valid]
+ colors = colors[valid]
+
+ # Sort 3D points by z-value, which works with array assignment to simulate
+ # z-buffering for rendering the heightmap image.
+ iz = np.argsort(points[:, -1])
+ points, colors = points[iz], colors[iz]
+ px = np.int32(np.floor((points[:, 0] - bounds[0, 0]) / pixel_size))
+ py = np.int32(np.floor((points[:, 1] - bounds[1, 0]) / pixel_size))
+ px = np.clip(px, 0, width - 1)
+ py = np.clip(py, 0, height - 1)
+ heightmap[py, px] = points[:, 2] - bounds[2, 0]
+ for c in range(colors.shape[-1]):
+ colormap[py, px, c] = colors[:, c]
+ xyzmap[py, px, c] = points[:, c]
+ colormap = colormap[::-1, :, :] # Flip up-down.
+ xv, yv = np.meshgrid(np.linspace(BOUNDS[0, 0], BOUNDS[0, 1], height),
+ np.linspace(BOUNDS[1, 0], BOUNDS[1, 1], width))
+ xyzmap[:, :, 0] = xv
+ xyzmap[:, :, 1] = yv
+ xyzmap = xyzmap[::-1, :, :] # Flip up-down.
+ heightmap = heightmap[::-1, :] # Flip up-down.
+ return heightmap, colormap, xyzmap
\ No newline at end of file
diff --git a/saycan/policy.py b/saycan/policy.py
new file mode 100644
index 0000000..8b89a47
--- /dev/null
+++ b/saycan/policy.py
@@ -0,0 +1,61 @@
+"""
+SayCan Policy for SHARPIE.
+
+A simple policy wrapper for the SayCan environment. The actual LLM planning
+and CLIPort execution are handled by the environment module.
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+
+
+class Policy:
+ """
+ SayCan-based policy for pick-and-place operations.
+
+ This policy passes participant inputs directly to the environment,
+ which handles LLM planning and CLIPort execution.
+ """
+
+ def __init__(self, room_name=""):
+ """
+ Initialize the SayCan policy.
+
+ Args:
+ room_name: Optional room identifier (unused, kept for compatibility)
+ """
+ self.name = "SayCan_Policy"
+ self.room_name = room_name
+
+ def predict(self, observation, participant_input=None):
+ """
+ Predict an action based on the observation.
+
+ Args:
+ observation: Current observation from the environment
+ participant_input: Text instruction from participant:
+ - "task:" to set task and auto-plan
+ - "plan" to get next planned action
+ - Direct text instruction for CLIPort
+
+ Returns:
+ The participant_input (passed through to environment)
+ """
+ return participant_input
+
+ def update(self, observation, action, reward, done, next_observation):
+ """
+ Update the policy based on experience (no-op for SayCan).
+
+ SayCan doesn't use traditional RL updates. This method is kept
+ for compatibility with the SHARPIE framework.
+ """
+ pass
+
+
+# Create an instance of the policy for use by the runner
+policy = Policy('saycan')
\ No newline at end of file
diff --git a/saycan/requirements.txt b/saycan/requirements.txt
new file mode 100644
index 0000000..a822930
--- /dev/null
+++ b/saycan/requirements.txt
@@ -0,0 +1,41 @@
+# Text processing utilities
+ftfy
+regex
+tqdm
+fvcore
+
+# OpenAI CLIP (install from git)
+git+https://github.com/openai/CLIP.git
+
+# Google Drive downloader
+gdown
+
+# Video and image processing
+moviepy
+imageio
+imageio-ffmpeg
+opencv-python
+pillow
+
+# Plotting and display
+matplotlib
+ipython
+
+# Robotics simulation and utilities
+pybullet
+ollama
+easydict
+
+# Deep learning frameworks
+tensorflow
+torch
+torchvision
+
+# JAX with CUDA support
+jax[cuda]
+flax
+optax
+
+# Numerical computing
+numpy
+scipy
\ No newline at end of file
diff --git a/saycan/robot.py b/saycan/robot.py
new file mode 100644
index 0000000..0eb827d
--- /dev/null
+++ b/saycan/robot.py
@@ -0,0 +1,161 @@
+"""
+Robotiq 2F-85 Gripper Control Module.
+
+This module provides control for the Robotiq 2F-85 parallel gripper in PyBullet
+simulation. The gripper is commonly used with UR5e robot arm in manipulation tasks.
+
+Key Features:
+- Gripper open/close control
+- Grasp detection
+- Thread-based constraint enforcement
+
+Original SayCan Repository:
+ https://github.com/google-research/google-research/tree/master/saycan
+
+Reference:
+ Ahn, M., et al. (2022). Do As I Can, Not As I Say: Grounding Language in
+ Robotic Affordances. arXiv preprint arXiv:2204.01691.
+"""
+
+import os
+import threading
+import time
+import numpy as np
+import pybullet
+
+# Get the saycan directory for asset paths
+SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+class Robotiq2F85:
+ """
+ Gripper handling for Robotiq 2F-85.
+
+ This class manages the gripper attached to a robot arm, providing
+ open/close functionality and grasp detection.
+
+ Attributes:
+ robot: PyBullet robot body ID
+ tool: Link ID of the robot's tool (end effector)
+ body: PyBullet gripper body ID
+ n_joints: Number of joints in the gripper
+ activated: Whether the gripper is currently activated (grasping)
+ """
+
+ def __init__(self, robot, tool):
+ """
+ Initialize the gripper.
+
+ Args:
+ robot: PyBullet body ID of the robot
+ tool: Link ID of the robot's end effector
+ """
+ self.robot = robot
+ self.tool = tool
+ pos = [0.1339999999999999, -0.49199999999872496, 0.5]
+ rot = pybullet.getQuaternionFromEuler([np.pi, 0, np.pi])
+ urdf = os.path.join(SAYCAN_DIR, "robotiq_2f_85", "robotiq_2f_85.urdf")
+ self.body = pybullet.loadURDF(urdf, pos, rot)
+ self.n_joints = pybullet.getNumJoints(self.body)
+ self.activated = False
+
+ # Connect gripper base to robot tool
+ pybullet.createConstraint(
+ self.robot, tool, self.body, 0,
+ jointType=pybullet.JOINT_FIXED,
+ jointAxis=[0, 0, 0],
+ parentFramePosition=[0, 0, 0],
+ childFramePosition=[0, 0, -0.07],
+ childFrameOrientation=pybullet.getQuaternionFromEuler([0, 0, np.pi / 2])
+ )
+
+ # Set friction coefficients for gripper fingers
+ for i in range(pybullet.getNumJoints(self.body)):
+ pybullet.changeDynamics(
+ self.body, i,
+ lateralFriction=10.0,
+ spinningFriction=1.0,
+ rollingFriction=1.0,
+ frictionAnchor=True
+ )
+
+ # Start thread to handle additional gripper constraints
+ self.motor_joint = 1
+ self.constraints_thread = threading.Thread(target=self.step)
+ self.constraints_thread.daemon = True
+ self.constraints_thread.start()
+
+ def step(self):
+ """Control joint positions by enforcing hard constraints on gripper behavior."""
+ while True:
+ try:
+ currj = [pybullet.getJointState(self.body, i)[0] for i in range(self.n_joints)]
+ indj = [6, 3, 8, 5, 10]
+ targj = [currj[1], -currj[1], -currj[1], currj[1], currj[1]]
+ pybullet.setJointMotorControlArray(
+ self.body, indj, pybullet.POSITION_CONTROL, targj,
+ positionGains=np.ones(5)
+ )
+ except:
+ return
+ time.sleep(0.001)
+
+ def activate(self):
+ """Activate the gripper (close fingers to grasp)."""
+ pybullet.setJointMotorControl2(
+ self.body, self.motor_joint,
+ pybullet.VELOCITY_CONTROL,
+ targetVelocity=1,
+ force=10
+ )
+ self.activated = True
+
+ def release(self):
+ """Release the gripper (open fingers)."""
+ pybullet.setJointMotorControl2(
+ self.body, self.motor_joint,
+ pybullet.VELOCITY_CONTROL,
+ targetVelocity=-1,
+ force=10
+ )
+ self.activated = False
+
+ def detect_contact(self):
+ obj, _, ray_frac = self.check_proximity()
+ if self.activated:
+ empty = self.grasp_width() < 0.01
+ cbody = self.body if empty else obj
+ if obj == self.body or obj == 0:
+ return False
+ return self.external_contact(cbody)
+ # else:
+ # return ray_frac < 0.14 or self.external_contact()
+
+ # Return if body is in contact with something other than gripper
+ def external_contact(self, body=None):
+ if body is None:
+ body = self.body
+ pts = pybullet.getContactPoints(bodyA=body)
+ pts = [pt for pt in pts if pt[2] != self.body]
+ return len(pts) > 0 # pylint: disable=g-explicit-length-test
+
+ def check_grasp(self):
+ while self.moving():
+ time.sleep(0.001)
+ success = self.grasp_width() > 0.01
+ return success
+
+ def grasp_width(self):
+ lpad = np.array(pybullet.getLinkState(self.body, 4)[0])
+ rpad = np.array(pybullet.getLinkState(self.body, 9)[0])
+ dist = np.linalg.norm(lpad - rpad) - 0.047813
+ return dist
+
+ def check_proximity(self):
+ ee_pos = np.array(pybullet.getLinkState(self.robot, self.tool)[0])
+ tool_pos = np.array(pybullet.getLinkState(self.body, 0)[0])
+ vec = (tool_pos - ee_pos) / np.linalg.norm((tool_pos - ee_pos))
+ ee_targ = ee_pos + vec
+ ray_data = pybullet.rayTest(ee_pos, ee_targ)[0]
+ obj, link, ray_frac = ray_data[0], ray_data[1], ray_data[2]
+ return obj, link, ray_frac
\ No newline at end of file
diff --git a/saycan/vild.py b/saycan/vild.py
new file mode 100644
index 0000000..4bc018c
--- /dev/null
+++ b/saycan/vild.py
@@ -0,0 +1,680 @@
+"""
+ViLD - Vision and Language Knowledge Distillation for Open-Vocabulary Object Detection.
+
+This module provides the ViLD (Vision-Language Detection) model for open-vocabulary
+object detection. ViLD enables detecting objects beyond a fixed set of categories
+by leveraging CLIP embeddings.
+
+Key Components:
+- Text embedding building with prompt engineering
+- Object detection with confidence scoring
+- Visualization of detection results
+
+Original ViLD Repository:
+ https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild
+
+Reference:
+ Gu, X., Lin, T., Kuo, C., & Cui, Y. (2021). Open-Vocabulary Object Detection
+ via Vision and Language Knowledge Distillation. arXiv preprint arXiv:2104.13921.
+
+Used in SayCan:
+ https://github.com/google-research/google-research/tree/master/saycan
+"""
+
+import os
+import collections
+import numpy as np
+import cv2
+import torch
+import clip
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+from easydict import EasyDict
+from PIL import Image
+import tensorflow.compat.v1 as tf
+
+# Get the directory where this script is located
+SAYCAN_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def softmax(x, axis=-1):
+ """Compute softmax values for each element in x."""
+ e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
+ return e_x / np.sum(e_x, axis=axis, keepdims=True)
+
+
+# ViLD configuration flags
+FLAGS = {
+ 'prompt_engineering': True,
+ 'this_is': True,
+ 'temperature': 100.0,
+ 'use_softmax': False,
+}
+FLAGS = EasyDict(FLAGS)
+
+# Visualization parameters
+display_input_size = (10, 10)
+overall_fig_size = (18, 24)
+line_thickness = 1
+fig_size_w = 35
+mask_color = 'red'
+alpha = 0.5
+
+
+def article(name):
+ """Return 'an' if name starts with a vowel, 'a' otherwise."""
+ return "an" if name[0] in "aeiou" else "a"
+
+
+def processed_name(name, rm_dot=False):
+ """Process category name by replacing underscores and slashes."""
+ res = name.replace("_", " ").replace("/", " or ").lower()
+ if rm_dot:
+ res = res.rstrip(".")
+ return res
+
+
+# Prompt templates for CLIP embedding
+single_template = ["a photo of {article} {}."]
+
+multiple_templates = [
+ 'There is {article} {} in the scene.',
+ 'There is the {} in the scene.',
+ 'a photo of {article} {} in the scene.',
+ 'a photo of the {} in the scene.',
+ 'a photo of one {} in the scene.',
+ 'itap of {article} {}.',
+ 'itap of my {}.',
+ 'itap of the {}.',
+ 'a photo of {article} {}.',
+ 'a photo of my {}.',
+ 'a photo of the {}.',
+ 'a photo of one {}.',
+ 'a photo of many {}.',
+ 'a good photo of {article} {}.',
+ 'a good photo of the {}.',
+ 'a bad photo of {article} {}.',
+ 'a bad photo of the {}.',
+ 'a photo of a nice {}.',
+ 'a photo of the nice {}.',
+ 'a photo of a cool {}.',
+ 'a photo of the cool {}.',
+ 'a photo of a weird {}.',
+ 'a photo of the weird {}.',
+ 'a photo of a small {}.',
+ 'a photo of the small {}.',
+ 'a photo of a large {}.',
+ 'a photo of the large {}.',
+ 'a photo of a clean {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a dirty {}.',
+ 'a photo of the dirty {}.',
+ 'a bright photo of {article} {}.',
+ 'a bright photo of the {}.',
+ 'a dark photo of {article} {}.',
+ 'a dark photo of the {}.',
+ 'a photo of a hard to see {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of {article} {}.',
+ 'a low resolution photo of the {}.',
+ 'a cropped photo of {article} {}.',
+ 'a cropped photo of the {}.',
+ 'a close-up photo of {article} {}.',
+ 'a close-up photo of the {}.',
+ 'a jpeg corrupted photo of {article} {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a blurry photo of {article} {}.',
+ 'a blurry photo of the {}.',
+ 'a pixelated photo of {article} {}.',
+ 'a pixelated photo of the {}.',
+ 'a black and white photo of the {}.',
+ 'a black and white photo of {article} {}.',
+ 'a plastic {}.',
+ 'the plastic {}.',
+ 'a toy {}.',
+ 'the toy {}.',
+ 'a plushie {}.',
+ 'the plushie {}.',
+ 'a cartoon {}.',
+ 'the cartoon {}.',
+ 'an embroidered {}.',
+ 'the embroidered {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+]
+
+# Lazy-loaded models (loaded on first use, can be freed)
+_clip_model = None
+_clip_preprocess = None
+_tf_session = None
+
+
+def get_clip_model():
+ """Get or lazily load the CLIP model."""
+ global _clip_model, _clip_preprocess
+ if _clip_model is None:
+ _clip_model, _clip_preprocess = clip.load("ViT-B/32")
+ if torch.cuda.is_available():
+ _clip_model.cuda()
+ _clip_model.eval()
+ print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in _clip_model.parameters()]):,}")
+ print("Input resolution:", _clip_model.visual.input_resolution)
+ print("Context length:", _clip_model.context_length)
+ print("Vocab size:", _clip_model.vocab_size)
+ return _clip_model, _clip_preprocess
+
+
+def get_tf_session():
+ """Get or lazily load the TensorFlow session."""
+ global _tf_session
+ if _tf_session is None:
+ config = tf.ConfigProto(allow_soft_placement=True)
+ if torch.cuda.is_available():
+ gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
+ config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+ _tf_session = tf.Session(graph=tf.Graph(), config=config)
+ saved_model_dir = os.path.join(SAYCAN_DIR, "image_path_v2")
+ _ = tf.saved_model.loader.load(_tf_session, ["serve"], saved_model_dir)
+ return _tf_session
+
+
+def cleanup_models():
+ """Free model resources. Call this when done with ViLD."""
+ global _clip_model, _clip_preprocess, _tf_session
+
+ if _tf_session is not None:
+ _tf_session.close()
+ _tf_session = None
+
+ if _clip_model is not None:
+ del _clip_model
+ del _clip_preprocess
+ _clip_model = None
+ _clip_preprocess = None
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+
+# Backward compatibility - load on import (can be disabled by setting LAZY_LOAD=true)
+import os as _os
+if _os.environ.get('VILD_LAZY_LOAD', '').lower() != 'true':
+ clip_model, clip_preprocess = get_clip_model()
+ session = get_tf_session()
+else:
+ clip_model, clip_preprocess = None, None
+ session = None
+
+
+def build_text_embedding(categories):
+ """
+ Build text embeddings for object categories using CLIP.
+
+ Args:
+ categories: List of category dicts with 'name' and 'id' keys
+
+ Returns:
+ Numpy array of text embeddings
+ """
+ clip_model, _ = get_clip_model()
+
+ if FLAGS.prompt_engineering:
+ templates = multiple_templates
+ else:
+ templates = single_template
+
+ run_on_gpu = torch.cuda.is_available()
+
+ with torch.no_grad():
+ all_text_embeddings = []
+ print("Building text embeddings...")
+ for category in tqdm(categories):
+ texts = [
+ template.format(processed_name(category["name"], rm_dot=True),
+ article=article(category["name"]))
+ for template in templates
+ ]
+ if FLAGS.this_is:
+ texts = [
+ "This is " + text if text.startswith("a") or text.startswith("the") else text
+ for text in texts
+ ]
+ texts = clip.tokenize(texts)
+ if run_on_gpu:
+ texts = texts.cuda()
+ text_embeddings = clip_model.encode_text(texts)
+ text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
+ text_embedding = text_embeddings.mean(dim=0)
+ text_embedding /= text_embedding.norm()
+ all_text_embeddings.append(text_embedding.cpu()) # Move to CPU immediately
+ all_text_embeddings = torch.stack(all_text_embeddings, dim=1)
+
+ # Clear GPU cache after embedding
+ if run_on_gpu:
+ torch.cuda.empty_cache()
+
+ return all_text_embeddings.cpu().numpy().T
+
+
+# Load ViLD TensorFlow model
+config = tf.ConfigProto(allow_soft_placement=True)
+if torch.cuda.is_available():
+ gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
+ config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
+session = tf.Session(graph=tf.Graph(), config=config)
+saved_model_dir = os.path.join(SAYCAN_DIR, "image_path_v2")
+_ = tf.saved_model.loader.load(session, ["serve"], saved_model_dir)
+
+numbered_categories = [{"name": str(idx), "id": idx} for idx in range(50)]
+numbered_category_indices = {cat["id"]: cat for cat in numbered_categories}
+
+
+def nms(dets, scores, thresh, max_dets=1000):
+ """
+ Non-maximum suppression.
+
+ Args:
+ dets: Detection boxes [N, 4]
+ scores: Detection scores [N,]
+ thresh: IoU threshold
+ max_dets: Maximum detections to keep
+
+ Returns:
+ List of indices to keep
+ """
+ y1 = dets[:, 0]
+ x1 = dets[:, 1]
+ y2 = dets[:, 2]
+ x2 = dets[:, 3]
+
+ areas = (x2 - x1) * (y2 - y1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0 and len(keep) < max_dets:
+ i = order[0]
+ keep.append(i)
+
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1)
+ h = np.maximum(0.0, yy2 - yy1)
+ intersection = w * h
+ overlap = intersection / (areas[i] + areas[order[1:]] - intersection + 1e-12)
+
+ inds = np.where(overlap <= thresh)[0]
+ order = order[inds + 1]
+ return keep
+
+
+import PIL.ImageColor as ImageColor
+import PIL.ImageDraw as ImageDraw
+import PIL.ImageFont as ImageFont
+
+STANDARD_COLORS = ["White"]
+
+
+def draw_bounding_box_on_image(image, ymin, xmin, ymax, xmax, color="red", thickness=4,
+ display_str_list=(), use_normalized_coordinates=True):
+ """Adds a bounding box to an image."""
+ draw = ImageDraw.Draw(image)
+ im_width, im_height = image.size
+ if use_normalized_coordinates:
+ (left, right, top, bottom) = (
+ xmin * im_width, xmax * im_width, ymin * im_height, ymax * im_height
+ )
+ else:
+ (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
+ draw.line([(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
+ width=thickness, fill=color)
+ try:
+ font = ImageFont.truetype("arial.ttf", 24)
+ except IOError:
+ font = ImageFont.load_default()
+
+ display_str_heights = [font.getbbox(ds)[3] - font.getbbox(ds)[1] for ds in display_str_list]
+ total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
+
+ if top > total_display_str_height:
+ text_bottom = top
+ else:
+ text_bottom = bottom + total_display_str_height
+
+ for display_str in display_str_list[::-1]:
+ text_left = min(5, left)
+ bbox = font.getbbox(display_str)
+ text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
+ margin = np.ceil(0.05 * text_height)
+ draw.rectangle([(left, text_bottom - text_height - 2 * margin),
+ (left + text_width, text_bottom)], fill=color)
+ draw.text((left + margin, text_bottom - text_height - margin), display_str,
+ fill="black", font=font)
+ text_bottom -= text_height - 2 * margin
+
+
+def draw_bounding_box_on_image_array(image, ymin, xmin, ymax, xmax, color="red", thickness=4,
+ display_str_list=(), use_normalized_coordinates=True):
+ """Adds a bounding box to an image (numpy array)."""
+ image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
+ draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, thickness,
+ display_str_list, use_normalized_coordinates)
+ np.copyto(image, np.array(image_pil))
+
+
+def draw_mask_on_image_array(image, mask, color="red", alpha=0.4):
+ """Draws mask on an image."""
+ if image.dtype != np.uint8:
+ raise ValueError("`image` not of type np.uint8")
+ if mask.dtype != np.uint8:
+ raise ValueError("`mask` not of type np.uint8")
+ if np.any(np.logical_and(mask != 1, mask != 0)):
+ raise ValueError("`mask` elements should be in [0, 1]")
+ if image.shape[:2] != mask.shape:
+ raise ValueError("Image and mask dimensions don't match")
+
+ rgb = ImageColor.getrgb(color)
+ pil_image = Image.fromarray(image)
+ solid_color = np.expand_dims(np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
+ pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert("RGBA")
+ pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert("L")
+ pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
+ np.copyto(image, np.array(pil_image.convert("RGB")))
+
+
+def visualize_boxes_and_labels_on_image_array(image, boxes, classes, scores, category_index,
+ instance_masks=None, instance_boundaries=None,
+ use_normalized_coordinates=False, max_boxes_to_draw=20,
+ min_score_thresh=0.5, agnostic_mode=False,
+ line_thickness=1, groundtruth_box_visualization_color="black",
+ skip_scores=False, skip_labels=False, mask_alpha=0.4,
+ plot_color=None):
+ """Overlay labeled boxes on an image with formatted scores and label names."""
+ box_to_display_str_map = collections.defaultdict(list)
+ box_to_color_map = collections.defaultdict(str)
+ box_to_instance_masks_map = {}
+ box_to_score_map = {}
+ box_to_instance_boundaries_map = {}
+
+ if not max_boxes_to_draw:
+ max_boxes_to_draw = boxes.shape[0]
+
+ for i in range(min(max_boxes_to_draw, boxes.shape[0])):
+ if scores is None or scores[i] > min_score_thresh:
+ box = tuple(boxes[i].tolist())
+ if instance_masks is not None:
+ box_to_instance_masks_map[box] = instance_masks[i]
+ if instance_boundaries is not None:
+ box_to_instance_boundaries_map[box] = instance_boundaries[i]
+ if scores is None:
+ box_to_color_map[box] = groundtruth_box_visualization_color
+ else:
+ display_str = ""
+ if not skip_labels:
+ if not agnostic_mode:
+ if classes[i] in list(category_index.keys()):
+ class_name = category_index[classes[i]]["name"]
+ else:
+ class_name = "N/A"
+ display_str = str(class_name)
+ if not skip_scores:
+ if not display_str:
+ display_str = "{}%".format(int(100 * scores[i]))
+ else:
+ float_score = ("%.2f" % scores[i]).lstrip("0")
+ display_str = "{}: {}".format(display_str, float_score)
+ box_to_score_map[box] = int(100 * scores[i])
+
+ box_to_display_str_map[box].append(display_str)
+ if plot_color is not None:
+ box_to_color_map[box] = plot_color
+ elif agnostic_mode:
+ box_to_color_map[box] = "DarkOrange"
+ else:
+ box_to_color_map[box] = STANDARD_COLORS[classes[i] % len(STANDARD_COLORS)]
+
+ if box_to_score_map:
+ box_color_iter = sorted(box_to_color_map.items(), key=lambda kv: box_to_score_map[kv[0]])
+ else:
+ box_color_iter = box_to_color_map.items()
+
+ for box, color in box_color_iter:
+ ymin, xmin, ymax, xmax = box
+ if instance_masks is not None:
+ draw_mask_on_image_array(image, box_to_instance_masks_map[box], color=color, alpha=mask_alpha)
+ if instance_boundaries is not None:
+ draw_mask_on_image_array(image, box_to_instance_boundaries_map[box], color="red", alpha=1.0)
+ draw_bounding_box_on_image_array(image, ymin, xmin, ymax, xmax, color=color,
+ thickness=line_thickness,
+ display_str_list=box_to_display_str_map[box],
+ use_normalized_coordinates=use_normalized_coordinates)
+
+ return image
+
+
+def paste_instance_masks(masks, detected_boxes, image_height, image_width):
+ """Paste instance masks to generate the image segmentation results."""
+ def expand_boxes(boxes, scale):
+ w_half = boxes[:, 2] * 0.5
+ h_half = boxes[:, 3] * 0.5
+ x_c = boxes[:, 0] + w_half
+ y_c = boxes[:, 1] + h_half
+ w_half *= scale
+ h_half *= scale
+ boxes_exp = np.zeros(boxes.shape)
+ boxes_exp[:, 0] = x_c - w_half
+ boxes_exp[:, 2] = x_c + w_half
+ boxes_exp[:, 1] = y_c - h_half
+ boxes_exp[:, 3] = y_c + h_half
+ return boxes_exp
+
+ _, mask_height, mask_width = masks.shape
+ scale = max((mask_width + 2.0) / mask_width, (mask_height + 2.0) / mask_height)
+ ref_boxes = expand_boxes(detected_boxes, scale)
+ ref_boxes = ref_boxes.astype(np.int32)
+ padded_mask = np.zeros((mask_height + 2, mask_width + 2), dtype=np.float32)
+ segms = []
+
+ for mask_ind, mask in enumerate(masks):
+ im_mask = np.zeros((image_height, image_width), dtype=np.uint8)
+ padded_mask[1:-1, 1:-1] = mask[:, :]
+ ref_box = ref_boxes[mask_ind, :]
+ w = ref_box[2] - ref_box[0] + 1
+ h = ref_box[3] - ref_box[1] + 1
+ w = np.maximum(w, 1)
+ h = np.maximum(h, 1)
+ mask = cv2.resize(padded_mask, (w, h))
+ mask = np.array(mask > 0.5, dtype=np.uint8)
+ x_0 = min(max(ref_box[0], 0), image_width)
+ x_1 = min(max(ref_box[2] + 1, 0), image_width)
+ y_0 = min(max(ref_box[1], 0), image_height)
+ y_1 = min(max(ref_box[3] + 1, 0), image_height)
+ im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - ref_box[1]):(y_1 - ref_box[1]),
+ (x_0 - ref_box[0]):(x_1 - ref_box[0])]
+ segms.append(im_mask)
+
+ segms = np.array(segms)
+ return segms
+
+
+def plot_mask(color, alpha, original_image, mask):
+ """Plot instance mask on image."""
+ rgb = ImageColor.getrgb(color)
+ pil_image = Image.fromarray(original_image)
+ solid_color = np.expand_dims(np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
+ pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert("RGBA")
+ pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert("L")
+ pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
+ return np.array(pil_image.convert("RGB"))
+
+
+def display_image(path_or_array, size=(10, 10)):
+ """Display an image from path or array."""
+ if isinstance(path_or_array, str):
+ image = np.asarray(Image.open(open(path_or_array, "rb")).convert("RGB"))
+ else:
+ image = path_or_array
+ plt.figure(figsize=size)
+ plt.imshow(image)
+ plt.axis("off")
+ plt.show()
+
+
+def vild(image_path, category_name_string, params, plot_on=True, prompt_swaps=[]):
+ """
+ Run ViLD object detection on an image.
+
+ Args:
+ image_path: Path to the input image
+ category_name_string: Semicolon-separated category names
+ params: Tuple of (max_boxes, nms_thresh, min_rpn_score, min_box_area, max_box_area)
+ plot_on: Whether to display visualization
+ prompt_swaps: List of (old, new) string replacements for categories
+
+ Returns:
+ List of detected object names
+ """
+ # Preprocessing categories
+ for a, b in prompt_swaps:
+ category_name_string = category_name_string.replace(a, b)
+ category_names = [x.strip() for x in category_name_string.split(";")]
+ category_names = ["background"] + category_names
+ categories = [{"name": item, "id": idx + 1} for idx, item in enumerate(category_names)]
+ category_indices = {cat["id"]: cat for cat in categories}
+
+ max_boxes_to_draw, nms_threshold, min_rpn_score_thresh, min_box_area, max_box_area = params
+ fig_size_h = min(max(5, int(len(category_names) / 2.5)), 10)
+
+ # Run ViLD model
+ roi_boxes, roi_scores, detection_boxes, scores_unused, box_outputs, detection_masks, visual_features, image_info = session.run(
+ ["RoiBoxes:0", "RoiScores:0", "2ndStageBoxes:0", "2ndStageScoresUnused:0",
+ "BoxOutputs:0", "MaskOutputs:0", "VisualFeatOutputs:0", "ImageInfo:0"],
+ feed_dict={"Placeholder:0": [image_path]})
+
+ roi_boxes = np.squeeze(roi_boxes, axis=0)
+ roi_scores = np.squeeze(roi_scores, axis=0)
+ detection_boxes = np.squeeze(detection_boxes, axis=(0, 2))
+ scores_unused = np.squeeze(scores_unused, axis=0)
+ box_outputs = np.squeeze(box_outputs, axis=0)
+ detection_masks = np.squeeze(detection_masks, axis=0)
+ visual_features = np.squeeze(visual_features, axis=0)
+
+ image_info = np.squeeze(image_info, axis=0)
+ image_scale = np.tile(image_info[2:3, :], (1, 2))
+ image_height = int(image_info[0, 0])
+ image_width = int(image_info[0, 1])
+
+ rescaled_detection_boxes = detection_boxes / image_scale
+
+ # Read image
+ image = np.asarray(Image.open(open(image_path, "rb")).convert("RGB"))
+ assert image_height == image.shape[0]
+ assert image_width == image.shape[1]
+
+ # Filter boxes with NMS
+ nmsed_indices = nms(detection_boxes, roi_scores, thresh=nms_threshold)
+ box_sizes = (rescaled_detection_boxes[:, 2] - rescaled_detection_boxes[:, 0]) * \
+ (rescaled_detection_boxes[:, 3] - rescaled_detection_boxes[:, 1])
+
+ valid_indices = np.where(
+ np.logical_and(
+ np.isin(np.arange(len(roi_scores), dtype=int), nmsed_indices),
+ np.logical_and(
+ np.logical_not(np.all(roi_boxes == 0., axis=-1)),
+ np.logical_and(
+ roi_scores >= min_rpn_score_thresh,
+ np.logical_and(box_sizes > min_box_area, box_sizes < max_box_area)
+ )
+ )
+ )
+ )[0]
+
+ detection_roi_scores = roi_scores[valid_indices][:max_boxes_to_draw, ...]
+ detection_boxes = detection_boxes[valid_indices][:max_boxes_to_draw, ...]
+ detection_masks = detection_masks[valid_indices][:max_boxes_to_draw, ...]
+ detection_visual_feat = visual_features[valid_indices][:max_boxes_to_draw, ...]
+ rescaled_detection_boxes = rescaled_detection_boxes[valid_indices][:max_boxes_to_draw, ...]
+
+ # Compute text embeddings and scores
+ text_features = build_text_embedding(categories)
+ raw_scores = detection_visual_feat.dot(text_features.T)
+ if FLAGS.use_softmax:
+ scores_all = softmax(FLAGS.temperature * raw_scores, axis=-1)
+ else:
+ scores_all = raw_scores
+
+ indices = np.argsort(-np.max(scores_all, axis=1))
+ indices_fg = np.array([i for i in indices if np.argmax(scores_all[i]) != 0])
+
+ # Get found objects
+ found_objects = []
+ for a, b in prompt_swaps:
+ category_names = [name.replace(b, a) for name in category_names]
+
+ for anno_idx in indices[0:int(rescaled_detection_boxes.shape[0])]:
+ scores = scores_all[anno_idx]
+ if np.argmax(scores) == 0:
+ continue
+ found_object = category_names[np.argmax(scores)]
+ if found_object == "background":
+ continue
+ print("Found a", found_object, "with score:", np.max(scores))
+ found_objects.append(category_names[np.argmax(scores)])
+
+ if not plot_on:
+ return found_objects
+
+ # Visualization
+ ymin, xmin, ymax, xmax = np.split(rescaled_detection_boxes, 4, axis=-1)
+ processed_boxes = np.concatenate([xmin, ymin, xmax - xmin, ymax - ymin], axis=-1)
+ segmentations = paste_instance_masks(detection_masks, processed_boxes, image_height, image_width)
+
+ if len(indices_fg) == 0:
+ display_image(np.array(image), size=overall_fig_size)
+ print("ViLD does not detect anything belonging to the given category")
+ else:
+ image_with_detections = visualize_boxes_and_labels_on_image_array(
+ np.array(image),
+ rescaled_detection_boxes[indices_fg],
+ valid_indices[:max_boxes_to_draw][indices_fg],
+ detection_roi_scores[indices_fg],
+ numbered_category_indices,
+ instance_masks=segmentations[indices_fg],
+ use_normalized_coordinates=False,
+ max_boxes_to_draw=max_boxes_to_draw,
+ min_score_thresh=min_rpn_score_thresh,
+ skip_scores=False,
+ skip_labels=True)
+
+ plt.imshow(image_with_detections)
+ plt.title("ViLD detected objects and RPN scores.")
+ plt.show()
+
+ return found_objects
+
+
+# Default category names for pick-and-place tasks
+category_names = [
+ 'blue block', 'red block', 'green block', 'orange block', 'yellow block',
+ 'purple block', 'pink block', 'cyan block', 'brown block', 'gray block',
+ 'blue bowl', 'red bowl', 'green bowl', 'orange bowl', 'yellow bowl',
+ 'purple bowl', 'pink bowl', 'cyan bowl', 'brown bowl', 'gray bowl'
+]
+
+image_path = 'tmp.jpg'
+
+# ViLD settings
+category_name_string = ";".join(category_names)
+max_boxes_to_draw = 8
+prompt_swaps = [('block', 'cube')]
+nms_threshold = 0.4
+min_rpn_score_thresh = 0.4
+min_box_area = 10
+max_box_area = 3000
+vild_params = max_boxes_to_draw, nms_threshold, min_rpn_score_thresh, min_box_area, max_box_area
+
+
+if __name__ == "__main__":
+ found_objects = vild(image_path, category_name_string, vild_params, plot_on=True, prompt_swaps=prompt_swaps)
+ print("Found objects:", found_objects)
\ No newline at end of file