"""
|
|
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
|
|
"""
import math
import random

import decord
import numpy as np
from PIL import Image

decord.bridge.set_bridge("torch")


def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
    """
    Converts a presentation timestamp (pts) with the given time base and start_pts offset to seconds.

    Returns:
        time_in_seconds (float): The corresponding time in seconds.

    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
    """
    if pts == math.inf:
        return math.inf

    return int(pts - start_pts) * time_base
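

# Quick worked example (illustrative values only, not tied to any real stream):
# with a time base of 1/90000 and no start offset, a pts of 900000 maps to 10 seconds.
#   from fractions import Fraction
#   pts_to_secs(900000, Fraction(1, 90000), 0)  # -> 10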


def get_pyav_video_duration(video_reader):
    """Return the duration (in seconds) of the first video stream of a PyAV container."""
    video_stream = video_reader.streams.video[0]
    video_duration = pts_to_secs(
        video_stream.duration,
        video_stream.time_base,
        video_stream.start_time
    )
    return float(video_duration)
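

# Illustrative usage (not part of the original module): `video_reader` here is a PyAV
# container rather than a decord reader, e.g.
#   import av
#   container = av.open("clip.mp4")  # "clip.mp4" is a placeholder path
#   seconds = get_pyav_video_duration(container)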


def get_frame_indices_by_fps():
    pass


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """
    Select `num_frames` frame indices from a video with `vlen` frames.

    sample: 'rand' (one random frame per interval), 'middle' (the midpoint of each
        interval), or 'fps{x}' (sequential sampling at x frames per second, e.g. 'fps0.5').
    fix_start: when set (and sample is 'middle'), take the frame at this fixed offset
        from the start of each interval instead.
    """
    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:
                # short videos can yield empty intervals; fall back to a random permutation
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError(f"Unknown sample strategy: {sample}")
    return frame_indices
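

# Illustrative behaviour (worked out by hand, not taken from the upstream repo):
# for a 10-frame clip and num_frames=4, the interval boundaries are [0, 2, 5, 7, 10], so
#   get_frame_indices(4, 10, sample='middle')               # -> [0, 3, 5, 8]
#   get_frame_indices(4, 10, sample='rand')                 # -> one random index per interval
#   get_frame_indices(4, 300, sample='fps2', input_fps=30)  # -> ~20 indices spaced ~15 frames
#                                                           #    apart (num_frames is ignored here)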


def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1):
    """
    Decode `num_frames` frames from `video_path` with decord and return them as RGB
    PIL images, together with the sampled frame indices and the clip duration in seconds.
    """
    video_reader = decord.VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    frames = frames.split(1, dim=0)

    # Image.fromarray expects (H, W, C), so move the channel axis back before converting.
    frames = [Image.fromarray(f.squeeze(0).permute(1, 2, 0).numpy(), mode='RGB') for f in frames]
    return frames, frame_indices, duration
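

# Example usage (illustrative only; the path and frame count are placeholders):
#
#   frames, indices, duration = read_frames_decord("example.mp4", num_frames=8, sample="middle")
#   print(f"sampled {len(frames)} frames at indices {indices} from a {duration:.1f}s clip")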