123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- import os
- import random
- from typing import List, Tuple
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
- from PIL.Image import Image
- from matplotlib.pyplot import figure, imshow, axis
- from pytube import YouTube
- from torch import nn
- from pipeline import BuildDataset
# A segment is the list of PIL images that make it up, paired with the
# segment's length in seconds.
Segment = Tuple[List[Image], int]
class Segmentor:
    """Locate runs of positively-classified frames in a video and display them.

    A per-frame model yields a 0/1 style prediction for every still; windows
    of frames that are long enough and dense enough in positive predictions
    are reported as segments.
    """

    def __init__(self,
                 model: nn.Module,
                 min_frames: int,
                 threshold: float):
        """Store the classifier and the segmentation parameters.

        Args:
            model: Module called as ``model(audio_batch, image_batch)``; the
                index of its largest output along dim 1 is the prediction.
            min_frames: Minimum number of frames a segment must span.
            threshold: Minimum mean of the predictions inside a segment.
        """
        self.model = model
        self.min_frames = min_frames
        self.threshold = threshold

    @staticmethod
    def _find_candidates(preds: List[int],
                         min_frames: int,
                         threshold: float) -> List[Tuple[int, int]]:
        """Return, for each positive start frame, its longest valid window.

        A window ``(start, end)`` (both inclusive) is valid when both
        endpoints are predicted 1, it spans at least ``min_frames`` frames and
        the mean prediction inside it is at least ``threshold``. At most one
        window is produced per start index, in increasing start order.
        """
        n = len(preds)
        # prefix[k] == sum(preds[:k]); gives any window's sum in O(1) instead
        # of re-averaging the slice for every (start, end) pair.
        prefix = [0]
        for pred in preds:
            prefix.append(prefix[-1] + pred)
        # A window always holds at least one frame, which also keeps the
        # division below well defined for min_frames <= 0.
        min_span = max(min_frames, 1)

        candidates = []
        for start in range(n):
            if preds[start] != 1:
                continue
            # Scan end points from the right: the first valid one is the
            # longest window for this start.
            for end in range(n - 1, start + min_span - 2, -1):
                if preds[end] != 1:
                    continue
                frames = end - start + 1
                if (prefix[end + 1] - prefix[start]) / frames >= threshold:
                    candidates.append((start, end))
                    break
        return candidates

    @staticmethod
    def _first_overlap(candidates: List[Tuple[int, int]]):
        """Return the first index pair ``(i, j)``, ``i < j``, whose windows overlap, or None."""
        for i, (ref_start, ref_end) in enumerate(candidates):
            for j in range(i + 1, len(candidates)):
                comp_start, comp_end = candidates[j]
                # Two inclusive intervals intersect iff each starts no later
                # than the other ends.
                if comp_start <= ref_end and ref_start <= comp_end:
                    return i, j
        return None

    @staticmethod
    def _resolve_overlaps(candidates: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
        """Drop the shorter window of every overlapping pair until none overlap.

        On equal length the earlier window is kept. The input list is not
        modified.
        """
        remaining = list(candidates)
        while True:
            pair = Segmentor._first_overlap(remaining)
            if pair is None:
                return remaining
            i, j = pair
            ref_start, ref_end = remaining[i]
            comp_start, comp_end = remaining[j]
            # Keep the longer of the two windows.
            if comp_end - comp_start > ref_end - ref_start:
                del remaining[i]
            else:
                del remaining[j]

    @staticmethod
    def _segmentor(preds: List[int],
                   min_frames: int,
                   threshold: float) -> List[List[int]]:
        """Split per-frame predictions into non-overlapping segments.

        Args:
            preds: One prediction per frame; frames equal to 1 are positive.
            min_frames: Minimum number of frames a segment must span.
            threshold: Minimum mean prediction inside a segment.

        Returns:
            One list of consecutive frame indices per segment, ordered by
            start index. Empty when no window qualifies.
        """
        candidates = Segmentor._find_candidates(preds, min_frames, threshold)
        kept = Segmentor._resolve_overlaps(candidates)
        return [list(range(start, end + 1)) for start, end in kept]

    @staticmethod
    def _torch_img_to_pil(img: torch.Tensor) -> Image:
        """Convert a tensor image to PIL via ``BuildDataset.transform_reverse``."""
        return BuildDataset.transform_reverse(img)

    @staticmethod
    def _get_segment_len(indices: List[int]) -> int:
        """Return the number of frame positions spanned by ``indices`` (inclusive)."""
        return max(indices) - min(indices) + 1

    def segmentor(self, preds: List[int], images: List[torch.Tensor]) -> List[Segment]:
        """Build segments from predictions and the frames they refer to.

        Args:
            preds: One prediction per frame.
            images: Frame tensors, indexed in parallel with ``preds``.

        Returns:
            For every segment, its frames as PIL images and its length. The
            length is the span of frame indices; callers print it as seconds,
            which presumes one still per second -- TODO confirm.
        """
        segments = []
        for segment_idx in self._segmentor(preds, self.min_frames, self.threshold):
            pil_images = [self._torch_img_to_pil(images[idx]) for idx in segment_idx]
            segments.append((pil_images, self._get_segment_len(segment_idx)))
        return segments

    def _predict(self, audio: torch.Tensor, image: torch.Tensor) -> int:
        """Classify one (audio, image) pair and return the winning class index."""
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            # The model is fed a batch of one; add the batch dimension here.
            output = self.model(audio.unsqueeze(0), image.unsqueeze(0))
        # torch.max(..., 1) returns (values, indices); take the index for the
        # single batch element.
        return int(torch.max(output, 1)[1][0])

    def get_segments(self, path_video: str) -> List[Segment]:
        """Extract audio and stills from a video file and segment it.

        Args:
            path_video: Path handed to
                ``BuildDataset.one_video_extract_audio_and_stills``.

        Returns:
            The segments found, as produced by :meth:`segmentor`.
        """
        audio, images = BuildDataset.one_video_extract_audio_and_stills(path_video)
        # audio and images are indexed in parallel, one entry per frame.
        preds = [self._predict(audio[idx], images[idx]) for idx in range(len(images))]
        return self.segmentor(preds, images)

    @staticmethod
    def show_images_horizontally(images: List[Image]) -> None:
        """Display the given images side by side in a single row.

        Nothing is drawn when ``images`` is empty.
        """
        # https://stackoverflow.com/questions/36006136/how-to-display-images-in-a-row-with-ipython-display
        if not images:
            return
        fig = figure(figsize=(20, 20))
        number_of_files = len(images)
        for position, image in enumerate(images, start=1):
            ax = fig.add_subplot(1, number_of_files, position)
            ax.imshow(image)
            ax.axis('off')
        plt.show()

    def visualize_segments(self, path_video: str, n_to_show: int = 10) -> None:
        """Print a summary of each segment and show its first, random and last frames.

        Args:
            path_video: Path of the video file to analyse.
            n_to_show: Number of frames to display per row; rows are capped at
                the segment's size when the segment is shorter.
        """
        segments = self.get_segments(path_video)
        print(f'Found {len(segments)} segments')
        for i, (segment_images, segment_len) in enumerate(segments):
            n_images = len(segment_images)
            # random.sample raises ValueError when asked for more items than
            # exist, and a [-0:] slice would return everything, so clamp.
            shown = max(0, min(n_to_show, n_images))
            print(f'Segment {i + 1}, {segment_len} seconds')
            print(f'First {n_to_show}')
            self.show_images_horizontally(segment_images[:shown])
            print(f'{n_to_show} random shots')
            self.show_images_horizontally(random.sample(segment_images, shown))
            print(f'Last {n_to_show}')
            self.show_images_horizontally(segment_images[n_images - shown:])
            print('=' * 10)

    @staticmethod
    def _download_youtube_video(youtube_id: str, show_title: bool = True) -> str:
        """Download the first available stream of a YouTube video.

        Args:
            youtube_id: Value of the ``v`` query parameter of the watch URL.
            show_title: Print the video title before downloading.

        Returns:
            The stream's default filename, i.e. the path of the downloaded
            file relative to the current working directory.
        """
        yt = YouTube(f'http://youtube.com/watch?v={youtube_id}')
        if show_title:
            print(f'Title: {yt.title}')
        yt_stream = yt.streams.first()
        path = f'{yt_stream.default_filename}'
        yt_stream.download()
        return path

    def visualize_segments_youtube(self,
                                   youtube_id: str,
                                   n_to_show: int = 10,
                                   show_title: bool = True,
                                   remove_file: bool = True) -> None:
        """Download a YouTube video, visualize its segments, then clean up.

        Args:
            youtube_id: Value of the ``v`` query parameter of the watch URL.
            n_to_show: Number of frames to display per row.
            show_title: Print the video title before downloading.
            remove_file: Delete the downloaded file afterwards.
        """
        path = self._download_youtube_video(youtube_id, show_title)
        try:
            self.visualize_segments(path, n_to_show)
        finally:
            # Remove the download even when visualization fails.
            if remove_file:
                os.remove(path)