Skip to content

Commit 9a0501d

Browse files
authored
fix: video recording
* screenshot_to_np Image.Image argument
* remove unused import; update documentation/types
1 parent efd8c6b commit 9a0501d

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

openadapt/video.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from PIL import Image
99
import av
1010
import numpy as np
11-
import mss
1211

1312
from openadapt import config, utils
1413

@@ -70,7 +69,7 @@ def initialize_video_writer(
7069
def write_video_frame(
7170
container: av.container.OutputContainer,
7271
stream: av.stream.Stream,
73-
screenshot: mss.base.ScreenShot,
72+
screenshot: Image.Image,
7473
timestamp: float,
7574
base_timestamp: float,
7675
last_pts: int,
@@ -87,7 +86,7 @@ def write_video_frame(
8786
container (av.container.OutputContainer): The output container to which
8887
the frame is written.
8988
stream (av.stream.Stream): The video stream within the container.
90-
screenshot (mss.base.ScreenShot): The screenshot to be written as a video frame.
89+
screenshot (Image.Image): The screenshot to be written as a video frame.
9190
timestamp (float): The timestamp of the current frame.
9291
base_timestamp (float): The base timestamp from which the video
9392
recording started.
@@ -99,16 +98,14 @@ def write_video_frame(
9998
int: The updated last_pts value, to be used for writing the next frame.
10099
101100
Note:
102-
- This function assumes the screenshot is in the correct pixel format
103-
and dimensions as specified in the video stream settings.
104101
- It is crucial to maintain monotonically increasing PTS values for the
105102
video stream's consistency and playback.
106103
- The function logs the current timestamp, base timestamp, and
107104
calculated PTS values for debugging purposes.
108105
"""
109106
logger.debug(f"{timestamp=} {base_timestamp=}")
110107

111-
# Convert MSS ScreenShot to np.ndarray
108+
# Convert PIL Image to np.ndarray
112109
frame = screenshot_to_np(screenshot)
113110

114111
# Convert the numpy array to an AVFrame
@@ -171,23 +168,22 @@ def close_container() -> None:
171168
logger.info("done")
172169

173170

174-
def screenshot_to_np(screenshot: mss.base.ScreenShot) -> np.ndarray:
175-
"""Converts an MSS screenshot to a NumPy array.
171+
def screenshot_to_np(screenshot: Image.Image) -> np.ndarray:
172+
"""Converts a PIL Image screenshot to a NumPy array.
176173
177174
Args:
178-
screenshot (mss.base.ScreenShot): The screenshot object from MSS.
175+
screenshot (PIL.Image.Image): The screenshot object from PIL.
179176
180177
Returns:
181178
np.ndarray: The screenshot as a NumPy array in RGB format.
182179
"""
183-
# Convert the screenshot to an RGB PIL Image
184-
img = screenshot.rgb
185-
# Convert the RGB data to a NumPy array and reshape it to the correct dimensions
186-
frame = np.frombuffer(img, dtype=np.uint8).reshape(
187-
screenshot.height,
188-
screenshot.width,
189-
3,
190-
)
180+
# Ensure the image is in RGB format (in case it is not)
181+
if screenshot.mode != "RGB":
182+
screenshot = screenshot.convert("RGB")
183+
184+
# Convert the PIL Image to a NumPy array
185+
frame = np.array(screenshot)
186+
191187
return frame
192188

193189

0 commit comments

Comments
 (0)