|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "1182c4c3", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Single-view 3D reconstruction with SAM-3D Objects\n", |
| 9 | + "\n", |
| 10 | + "This notebook requires a self-hosted inference server with a 32GB+ VRAM GPU. See the README for the recommended setup." |
| 11 | + ] |
| 12 | + }, |
| 13 | + { |
| 14 | + "cell_type": "code", |
| 15 | + "execution_count": null, |
| 16 | + "id": "8274b79d", |
| 17 | + "metadata": {}, |
| 18 | + "outputs": [], |
| 19 | + "source": [ |
| 20 | + "%load_ext autoreload\n", |
| 21 | + "%autoreload 2\n", |
| 22 | + "%pip install -r requirements.txt" |
| 23 | + ] |
| 24 | + }, |
| 25 | + { |
| 26 | + "cell_type": "markdown", |
| 27 | + "id": "c67abc01", |
| 28 | + "metadata": {}, |
| 29 | + "source": [ |
| 30 | + "Set up the notebook to point at your inference server instance and use your API key to download model weights." |
| 31 | + ] |
| 32 | + }, |
| 33 | + { |
| 34 | + "cell_type": "code", |
| 35 | + "execution_count": null, |
| 36 | + "id": "e879d5e6", |
| 37 | + "metadata": {}, |
| 38 | + "outputs": [], |
| 39 | + "source": [ |
| 40 | + "API_URL = \"http://localhost:9001\"\n", |
| 41 | + "API_KEY = \"YOUR_API_KEY\"\n", |
| 42 | + "\n", |
| 43 | + "SEGMENTATION_MODEL_ID = \"rfdetr-seg-preview\"\n", |
| 44 | + "SAM3_3D_MODEL_ID = \"sam3-3d-objects\"" |
| 45 | + ] |
| 46 | + }, |
| 47 | + { |
| 48 | + "cell_type": "markdown", |
| 49 | + "id": "717682ab", |
| 50 | + "metadata": {}, |
| 51 | + "source": [ |
| 52 | + "Set input data and output directory for logging the annotated image and 3D view." |
| 53 | + ] |
| 54 | + }, |
| 55 | + { |
| 56 | + "cell_type": "code", |
| 57 | + "execution_count": null, |
| 58 | + "id": "be41fb82", |
| 59 | + "metadata": {}, |
| 60 | + "outputs": [], |
| 61 | + "source": [ |
| 62 | + "from supervision.assets import download_assets, VideoAssets\n", |
| 63 | + "\n", |
| 64 | + "# INPUT_VIDEO_PATH = download_assets(VideoAssets.MILK_BOTTLING_PLANT)\n", |
| 65 | + "INPUT_VIDEO_PATH = download_assets(VideoAssets.VEHICLES)\n", |
| 66 | + "\n", |
| 67 | + "OUTPUT_DIR = \"sam-3d-detect\"" |
| 68 | + ] |
| 69 | + }, |
| 70 | + { |
| 71 | + "cell_type": "code", |
| 72 | + "execution_count": null, |
| 73 | + "id": "16a4b598", |
| 74 | + "metadata": {}, |
| 75 | + "outputs": [], |
| 76 | + "source": [ |
| 77 | + "import os\n", |
| 78 | + "import shutil\n", |
| 79 | + "\n", |
| 80 | + "if os.path.exists(OUTPUT_DIR):\n", |
| 81 | + " shutil.rmtree(OUTPUT_DIR)\n", |
| 82 | + "os.makedirs(OUTPUT_DIR)" |
| 83 | + ] |
| 84 | + }, |
| 85 | + { |
| 86 | + "cell_type": "markdown", |
| 87 | + "id": "6b6ddc4b", |
| 88 | + "metadata": {}, |
| 89 | + "source": [ |
| 90 | + "Step 1: Load an input image and make sure it looks how we expect." |
| 91 | + ] |
| 92 | + }, |
| 93 | + { |
| 94 | + "cell_type": "code", |
| 95 | + "execution_count": null, |
| 96 | + "id": "d64d6a9d", |
| 97 | + "metadata": {}, |
| 98 | + "outputs": [], |
| 99 | + "source": [ |
| 100 | + "import supervision as sv\n", |
| 101 | + "\n", |
| 102 | + "image = next(sv.get_video_frames_generator(INPUT_VIDEO_PATH))\n", |
| 103 | + "sv.plot_image(image)" |
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "markdown", |
| 108 | + "id": "439764ea", |
| 109 | + "metadata": {}, |
| 110 | + "source": [ |
| 111 | + "Step 2: Generate 2D object masks by running an instance segmentation model like RF-DETR Seg." |
| 112 | + ] |
| 113 | + }, |
| 114 | + { |
| 115 | + "cell_type": "code", |
| 116 | + "execution_count": null, |
| 117 | + "id": "58066e1c", |
| 118 | + "metadata": {}, |
| 119 | + "outputs": [], |
| 120 | + "source": [ |
| 121 | + "from inference_sdk import InferenceHTTPClient\n", |
| 122 | + "import time\n", |
| 123 | + "\n", |
| 124 | + "client = InferenceHTTPClient(api_url=API_URL, api_key=API_KEY)\n", |
| 125 | + "\n", |
| 126 | + "start = time.perf_counter()\n", |
| 127 | + "\n", |
| 128 | + "seg_result = client.infer(image, model_id=SEGMENTATION_MODEL_ID)\n", |
| 129 | + "\n", |
| 130 | + "print(f\"{SEGMENTATION_MODEL_ID} inference took {(time.perf_counter() - start):.2f} sec\")" |
| 131 | + ] |
| 132 | + }, |
| 133 | + { |
| 134 | + "cell_type": "markdown", |
| 135 | + "id": "9418f08b", |
| 136 | + "metadata": {}, |
| 137 | + "source": [ |
| 138 | + "Let's take a look at the detections to check if they make sense." |
| 139 | + ] |
| 140 | + }, |
| 141 | + { |
| 142 | + "cell_type": "code", |
| 143 | + "execution_count": null, |
| 144 | + "id": "f21ce7aa", |
| 145 | + "metadata": {}, |
| 146 | + "outputs": [], |
| 147 | + "source": [ |
| 148 | + "import numpy as np\n", |
| 149 | + "\n", |
| 150 | + "detections = sv.Detections.from_inference(seg_result)\n", |
| 151 | + "\n", |
| 152 | + "# remove low-confidence detections\n", |
| 153 | + "detections = detections[detections.confidence > 0.5]\n", |
| 154 | + "\n", |
| 155 | + "labels = [\n", |
| 156 | + " f\"#{i} ({class_name})\" for i, class_name in enumerate(detections.data[\"class_name\"])\n", |
| 157 | + "]\n", |
| 158 | + "mask_annotator = sv.MaskAnnotator()\n", |
| 159 | + "label_annotator = sv.LabelAnnotator()\n", |
| 160 | + "annotated = mask_annotator.annotate(scene=image.copy(), detections=detections)\n", |
| 161 | + "annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)\n", |
| 162 | + "\n", |
| 163 | + "sv.plot_image(annotated)\n", |
| 164 | + "\n", |
| 165 | + "with sv.ImageSink(target_dir_path=OUTPUT_DIR) as sink:\n", |
| 166 | + " sink.save_image(annotated, \"annotated.png\")" |
| 167 | + ] |
| 168 | + }, |
| 169 | + { |
| 170 | + "cell_type": "markdown", |
| 171 | + "id": "2c915798", |
| 172 | + "metadata": {}, |
| 173 | + "source": [ |
| 174 | + "Step 3: Pass the input image and object masks to SAM-3D to generate 3D reconstructions of each object. \n", |
| 175 | + "\n", |
| 176 | + "This will take a few minutes the first time as the model weights need to be downloaded to the server. Subsequent inference calls can take anywhere from seconds to minutes depending on the number of objects and the inference configuration." |
| 177 | + ] |
| 178 | + }, |
| 179 | + { |
| 180 | + "cell_type": "code", |
| 181 | + "execution_count": null, |
| 182 | + "id": "33f2694e", |
| 183 | + "metadata": {}, |
| 184 | + "outputs": [], |
| 185 | + "source": [ |
| 186 | + "# flatten polygons to the expected [x1 y1 x2 y2 ... xN yN] format\n", |
| 187 | + "mask_input = [\n", |
| 188 | + " np.array(sv.mask_to_polygons(mask)[0]).flatten().tolist()\n", |
| 189 | + " for mask in detections.mask\n", |
| 190 | + "]\n", |
| 191 | + "\n", |
| 192 | + "start = time.perf_counter()\n", |
| 193 | + "sam3_3d_result = client.sam3_3d_infer(\n", |
| 194 | + " inference_input=image,\n", |
| 195 | + " mask_input=mask_input,\n", |
| 196 | + " model_id=SAM3_3D_MODEL_ID,\n", |
| 197 | + " # 'Fast' SAM-3D config\n", |
| 198 | + " output_meshes=False,\n", |
| 199 | + " output_scene=False,\n", |
| 200 | + " with_mesh_postprocess=False,\n", |
| 201 | + " with_texture_baking=False,\n", |
| 202 | + " use_distillations=True,\n", |
| 203 | + ")\n", |
| 204 | + "print(f\"SAM-3D inference took {(time.perf_counter() - start):.2f} sec\")\n", |
| 205 | + "\n", |
| 206 | + "detections.data[\"sam3_3d\"] = sam3_3d_result[\"objects\"]" |
| 207 | + ] |
| 208 | + }, |
| 209 | + { |
| 210 | + "cell_type": "markdown", |
| 211 | + "id": "3dc8d557", |
| 212 | + "metadata": {}, |
| 213 | + "source": [ |
| 214 | + "Step 4: Transform the 3D objects into a common global frame using their layout metadata, and draw them in [Rerun.io](https://rerun.io).\n", |
| 215 | + "\n", |
| 216 | + "When `output_scene=True` SAM-3D will output a combined 3D asset containing all 3D objects in a common frame. The code below uses the same Y-up frame convention to draw the objects, so it's consistent with what SAM-3D provides natively.\n", |
| 217 | + "\n", |
| 218 | + "Rerun will log to disk at `OUTPUT_DIR/rerun_log.rrd`. You can then visualize this file in the notebook or using the standalone Rerun viewer `rerun [OUTPUT_DIR]/rerun_log.rrd`." |
| 219 | + ] |
| 220 | + }, |
| 221 | + { |
| 222 | + "cell_type": "code", |
| 223 | + "execution_count": null, |
| 224 | + "id": "fcb2839d", |
| 225 | + "metadata": {}, |
| 226 | + "outputs": [], |
| 227 | + "source": [ |
| 228 | + "from base64 import b64decode\n", |
| 229 | + "from io import BytesIO\n", |
| 230 | + "\n", |
| 231 | + "import torch\n", |
| 232 | + "from pytorch3d.io import IO\n", |
| 233 | + "from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix\n", |
| 234 | + "\n", |
| 235 | + "import rerun as rr\n", |
| 236 | + "\n", |
| 237 | + "rr.init(\"sam-3d-detect\")\n", |
| 238 | + "rr.save(os.path.join(OUTPUT_DIR, \"rerun_log.rrd\"))\n", |
| 239 | + "rr.log(\"/\", rr.ViewCoordinates.RIGHT_HAND_Y_UP, rr.TransformAxes3D(0.5), static=True)\n", |
| 240 | + "\n", |
| 241 | + "rr.set_time(\"tick\", sequence=0)\n", |
| 242 | + "\n", |
| 243 | + "rr.log(\"/camera/image\", rr.Image(annotated, color_model=\"bgr\"))\n", |
| 244 | + "\n", |
| 245 | + "# Coordinate transforms used in make_scene_glb\n", |
| 246 | + "z_to_y_up = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=torch.float)\n", |
| 247 | + "y_to_z_up = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float)\n", |
| 248 | + "R_view = torch.tensor([[-1, 0, 0], [0, 0, -1], [0, -1, 0]], dtype=torch.float)\n", |
| 249 | + "\n", |
| 250 | + "for i in range(len(detections)):\n", |
| 251 | + " det = detections[i]\n", |
| 252 | + " obj_id = f\"#{i}\"\n", |
| 253 | + " if \"sam3_3d\" not in det.data:\n", |
| 254 | + " print(f\"No 3D data available for {obj_id}\")\n", |
| 255 | + " continue\n", |
| 256 | + " obj_sam3_3d = det.data[\"sam3_3d\"][0]\n", |
| 257 | + "\n", |
| 258 | + " obj_ply = IO().load_pointcloud(BytesIO(b64decode(obj_sam3_3d[\"gaussian_ply\"])))\n", |
| 259 | + " obj_pts = obj_ply.points_list()[0]\n", |
| 260 | + " obj_pts = obj_pts[::100, :] # Keep 1% of points to speed up rendering\n", |
| 261 | + " obj_box_size = (obj_pts.amax(dim=0) - obj_pts.amin(dim=0))\n", |
| 262 | + " obj_rgb = sv.annotators.utils.resolve_color(sv.ColorPalette.DEFAULT, detections, i).as_rgb()\n", |
| 263 | + "\n", |
| 264 | + " metadata = obj_sam3_3d[\"metadata\"]\n", |
| 265 | + " t = torch.tensor(metadata[\"translation\"], dtype=torch.float)\n", |
| 266 | + " R = quaternion_to_matrix(torch.tensor(metadata[\"rotation\"], dtype=torch.float))\n", |
| 267 | + " s = torch.tensor(metadata[\"scale\"], dtype=torch.float)\n", |
| 268 | + " # 1. Z-up → Y-up coordinate conversion (row-vector convention throughout SAM3D)\n", |
| 269 | + " # 2. PyTorch3D quaternion_to_matrix is column-vector (R @ v), but SAM3D uses it\n", |
| 270 | + " # row-vector (v @ R), so pass R.T to Rerun's column-vector mat3x3\n", |
| 271 | + " # 3. R_view: global scene correction from make_scene_glb, applied in world space\n", |
| 272 | + " t = t @ z_to_y_up @ R_view\n", |
| 273 | + " R = R_view @ y_to_z_up @ R.T @ z_to_y_up\n", |
| 274 | + "\n", |
| 275 | + " rr.log(\n", |
| 276 | + " f\"objects/{obj_id}\",\n", |
| 277 | + " rr.Boxes3D(sizes=obj_box_size, colors=obj_rgb, labels=obj_id),\n", |
| 278 | + " rr.Transform3D(translation=t, mat3x3=R, scale=s),\n", |
| 279 | + " )\n", |
| 280 | + " rr.log(\n", |
| 281 | + " f\"objects/{obj_id}/pts\",\n", |
| 282 | + " rr.Points3D(positions=obj_pts, colors=obj_rgb),\n", |
| 283 | + " )" |
| 284 | + ] |
| 285 | + }, |
| 286 | + { |
| 287 | + "cell_type": "code", |
| 288 | + "execution_count": null, |
| 289 | + "id": "ab9ba8e6", |
| 290 | + "metadata": {}, |
| 291 | + "outputs": [], |
| 292 | + "source": [ |
| 293 | + "# You can also use the standalone viewer app\n", |
| 294 | + "# rerun [OUTPUT_DIR]/rerun_log.rrd\n", |
| 295 | + "rr.notebook_show()\n", |
| 296 | + "rr.log_file_from_path(os.path.join(OUTPUT_DIR, \"rerun_log.rrd\"))" |
| 297 | + ] |
| 298 | + } |
| 299 | + ], |
| 300 | + "metadata": { |
| 301 | + "kernelspec": { |
| 302 | + "display_name": "Python 3", |
| 303 | + "language": "python", |
| 304 | + "name": "python3" |
| 305 | + }, |
| 306 | + "language_info": { |
| 307 | + "codemirror_mode": { |
| 308 | + "name": "ipython", |
| 309 | + "version": 3 |
| 310 | + }, |
| 311 | + "file_extension": ".py", |
| 312 | + "mimetype": "text/x-python", |
| 313 | + "name": "python", |
| 314 | + "nbconvert_exporter": "python", |
| 315 | + "pygments_lexer": "ipython3", |
| 316 | + "version": "3.10.19" |
| 317 | + } |
| 318 | + }, |
| 319 | + "nbformat": 4, |
| 320 | + "nbformat_minor": 5 |
| 321 | +} |
0 commit comments