Skip to content

Commit

Permalink
Split out subjects and add more checks for VideoSubject
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 692892203
Change-Id: I01dd3726dee34a37cd5a67139cccc932f65a2bf7
  • Loading branch information
jagapiou authored and copybara-github committed Nov 4, 2024
1 parent 5ff2a7f commit 663653f
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 199 deletions.
92 changes: 5 additions & 87 deletions meltingpot/utils/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
import collections
from collections.abc import Collection, Iterator, Mapping
import contextlib
import os
from typing import Optional, TypeVar
import uuid

from absl import logging
import cv2
import dm_env
import meltingpot
from meltingpot.utils.evaluation import return_subject
from meltingpot.utils.evaluation import video_subject as video_subject_lib
from meltingpot.utils.policies import policy as policy_lib
from meltingpot.utils.policies import saved_model_policy
from meltingpot.utils.scenarios import population as population_lib
Expand All @@ -32,7 +30,6 @@
import numpy as np
import pandas as pd
from reactivex import operators as ops
from reactivex import subject

T = TypeVar('T')

Expand All @@ -52,85 +49,6 @@ def run_episode(
actions = population.await_action()


class VideoSubject(subject.Subject):
"""Subject that emits a video at the end of each episode."""

def __init__(
self,
root: str,
*,
extension: str = 'webm',
codec: str = 'vp90',
fps: int = 30,
) -> None:
"""Initializes the instance.
Args:
root: directory to write videos in.
extension: file extention of file.
codec: codex to write with.
fps: frames-per-second for videos.
"""
super().__init__()
self._root = root
self._extension = extension
self._codec = codec
self._fps = fps
self._path = None
self._writer = None

def on_next(self, timestep: dm_env.TimeStep) -> None:
"""Called on each timestep.
Args:
timestep: the most recent timestep.
"""
rgb_frame = timestep.observation[0]['WORLD.RGB']
if timestep.step_type.first():
self._path = os.path.join(
self._root, f'{uuid.uuid4().hex}.{self._extension}')
height, width, _ = rgb_frame.shape
self._writer = cv2.VideoWriter(
filename=self._path,
fourcc=cv2.VideoWriter_fourcc(*self._codec),
fps=self._fps,
frameSize=(width, height),
isColor=True)
elif self._writer is None:
raise ValueError('First timestep must be StepType.FIRST.')
bgr_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)
assert self._writer.isOpened() # Catches any cv2 usage errors.
self._writer.write(bgr_frame)
if timestep.step_type.last():
self._writer.release()
super().on_next(self._path)
self._path = None
self._writer = None

def dispose(self):
"""See base class."""
if self._writer is not None:
self._writer.release()
super().dispose()


class ReturnSubject(subject.Subject):
"""Subject that emits the player returns at the end of each episode."""

def on_next(self, timestep: dm_env.TimeStep):
"""Called on each timestep.
Args:
timestep: the most recent timestep.
"""
if timestep.step_type.first():
self._return = np.zeros_like(timestep.reward)
self._return += timestep.reward
if timestep.step_type.last():
super().on_next(self._return)
self._return = None


def run_and_observe_episodes(
population: population_lib.Population,
substrate: substrate_lib.Substrate,
Expand Down Expand Up @@ -173,19 +91,19 @@ def subscribe(observable, *args, **kwargs):
stack.callback(disposable.dispose)

if video_root:
video_subject = VideoSubject(video_root)
video_subject = video_subject_lib.VideoSubject(video_root)
subscribe(substrate_observables.timestep, video_subject)
subscribe(video_subject, on_next=data['video_path'].append)

focal_return_subject = ReturnSubject()
focal_return_subject = return_subject.ReturnSubject()
subscribe(focal_observables.timestep, focal_return_subject)
subscribe(focal_return_subject, on_next=data['focal_player_returns'].append)
subscribe(focal_return_subject.pipe(ops.map(np.mean)),
on_next=data['focal_per_capita_return'].append)
subscribe(focal_observables.names,
on_next=data['focal_player_names'].append)

background_return_subject = ReturnSubject()
background_return_subject = return_subject.ReturnSubject()
subscribe(background_observables.timestep, background_return_subject)
subscribe(background_return_subject,
on_next=data['background_player_returns'].append)
Expand Down
112 changes: 0 additions & 112 deletions meltingpot/utils/evaluation/evaluation_test.py

This file was deleted.

39 changes: 39 additions & 0 deletions meltingpot/utils/evaluation/return_subject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Subject that emits the player returns at the end of each episode."""

import dm_env
import numpy as np
from reactivex import subject


class ReturnSubject(subject.Subject):
"""Subject that emits the player returns at the end of each episode."""

_return: np.ndarray | None = None

def on_next(self, timestep: dm_env.TimeStep) -> None:
"""Called on each timestep.
Args:
timestep: the most recent timestep.
"""
if timestep.step_type.first():
self._return = np.zeros_like(timestep.reward)
elif self._return is None:
raise ValueError('First timestep must be StepType.FIRST.')
self._return += timestep.reward
if timestep.step_type.last():
super().on_next(self._return)
self._return = None
53 changes: 53 additions & 0 deletions meltingpot/utils/evaluation/return_subject_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from absl.testing import absltest
import dm_env
from meltingpot.utils.evaluation import return_subject
import numpy as np


def _send_timesteps_to_subject(subject, timesteps):
results = []
subject.subscribe(on_next=results.append)

for n, timestep in enumerate(timesteps):
subject.on_next(timestep)
if results:
return n, results.pop()
return None, None


class ReturnSubjectTest(absltest.TestCase):

def test(self):
timesteps = [
dm_env.restart(observation=[{}])._replace(reward=[0, 0]),
dm_env.transition(observation=[{}], reward=[2, 4]),
dm_env.termination(observation=[{}], reward=[1, 3]),
]
subject = return_subject.ReturnSubject()
step_written, episode_returns = _send_timesteps_to_subject(
subject, timesteps
)

with self.subTest('written_on_final_step'):
self.assertEqual(step_written, 2)

with self.subTest('returns'):
np.testing.assert_equal(episode_returns, [3, 7])

if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 663653f

Please sign in to comment.