34 changes: 34 additions & 0 deletions .github/workflows/unit-tests-docker.yml
@@ -0,0 +1,34 @@
name: Unit Tests
on:
push:
branches: [ "master", "develop"]
pull_request:
branches: [ "master", "develop"]

jobs:
unit_tests:
name: Unit Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Build Docker image
run: docker build -t oprl .

- name: Unit Tests
run: docker run --rm oprl

- name: Extract coverage
run: |
docker run --rm -v $(pwd):/host oprl sh -c "
pytest --cov=oprl --cov-report=xml &&
cp coverage.xml /host/
"

- name: Upload coverage
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: schatty/oprl
file: ./coverage.xml
12 changes: 12 additions & 0 deletions Dockerfile
@@ -0,0 +1,12 @@
FROM python:3.10.8

WORKDIR /app

RUN pip install --no-cache-dir --upgrade pip

COPY . .

RUN pip install --no-cache-dir . && pip install pytest pytest-cov

# Run tests by default
CMD ["pytest", "tests/functional"]
2 changes: 2 additions & 0 deletions README.md
@@ -5,6 +5,8 @@
A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. Benchmarking results are available at the associated homepage: [Homepage](https://schatty.github.io/oprl/)

[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![codecov](https://codecov.io/gh/schatty/oprl/branch/master/graph/badge.svg)](https://codecov.io/gh/schatty/oprl)



# Disclaimer
44 changes: 44 additions & 0 deletions pyproject.toml
@@ -0,0 +1,44 @@
[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "oprl"
version = "0.1.0"
description = "An RL Lib"
readme = "README.md"
requires-python = "==3.10.8"
license = {text = "MIT"}
authors = [
{name = "Igor Kuznetsov"},
]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3.10",
]
dependencies = [
"torch==2.2.2",
"tensorboard==2.15.1",
"packaging==23.2",
"dm-control==1.0.11",
"mujoco==2.3.3",
"numpy==1.26.4",
]

[project.optional-dependencies]
dev = [
"pytest>=6.0",
"black",
"flake8",
]

[project.urls]
"Homepage" = "https://schatty.github.io/oprl"

[tool.setuptools.packages.find]
where = ["src"]
include = ["oprl*"]

[tool.setuptools.package-dir]
"" = "src"
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,5 +1,6 @@
torch==2.2.2
tensorboard==2.15.1
packaging==23.2
dm-control==1.0.16
mujoco==3.1.3
dm-control==1.0.11
mujoco==2.3.3
numpy==1.26.4
5 changes: 2 additions & 3 deletions src/oprl/configs/ddpg.py
@@ -26,11 +26,11 @@ def make_env(seed: int):
config = {
"state_dim": STATE_DIM,
"action_dim": ACTION_DIM,
"num_steps": int(1_000_000),
"num_steps": int(100_000),
"eval_every": 2500,
"device": args.device,
"save_buffer": False,
"visualise_every": 0,
"visualise_every": 50000,
"estimate_q_every": 5000,
"log_every": 2500,
}
@@ -48,7 +48,6 @@ def make_algo(logger):


def make_logger(seed: int) -> Logger:
global config
log_dir = create_logdir(logdir="logs", algo="DDPG", env=args.env, seed=seed)
return FileLogger(log_dir, config)

1 change: 0 additions & 1 deletion src/oprl/configs/sac.py
@@ -50,7 +50,6 @@ def make_algo(logger):


def make_logger(seed: int) -> Logger:
global config
log_dir = create_logdir(logdir="logs", algo="SAC", env=args.env, seed=seed)
return FileLogger(log_dir, config)

2 changes: 0 additions & 2 deletions src/oprl/configs/td3.py
@@ -48,8 +48,6 @@ def make_algo(logger):


def make_logger(seed: int) -> Logger:
global config

log_dir = create_logdir(logdir="logs", algo="TD3", env=args.env, seed=seed)
return FileLogger(log_dir, config)

1 change: 0 additions & 1 deletion src/oprl/configs/tqc.py
@@ -48,7 +48,6 @@ def make_algo(logger: Logger):


def make_logger(seed: int) -> Logger:
global config
log_dir = create_logdir(logdir="logs", algo="TQC", env=args.env, seed=seed)
return FileLogger(log_dir, config)

2 changes: 1 addition & 1 deletion src/oprl/configs/utils.py
@@ -29,7 +29,7 @@ def parse_args() -> argparse.Namespace:


def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str:
dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm")
dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss")
log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}")
logging.info(f"LOGDIR: {log_dir}")
return log_dir
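For reference, a minimal sketch of the directory name this produces with the new seconds-resolution timestamp (the algo, env, and seed values below are placeholders):

from datetime import datetime
import os

# Sketch only: mirrors create_logdir with the %Ss suffix, so runs launched
# within the same minute no longer collide on the same directory.
dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss")
log_dir = os.path.join("logs", "TD3", f"TD3-env_cartpole-balance-seed_0-{dt}")
# e.g. logs/TD3/TD3-env_cartpole-balance-seed_0-2024_05_01_13h45m27s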
6 changes: 3 additions & 3 deletions src/oprl/env.py
@@ -56,12 +56,12 @@ def step(
) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]:
obs, reward, cost, terminated, truncated, info = self._env.step(action)
info["cost"] = cost
return obs, reward, terminated, truncated, info
return obs.astype("float32"), reward, terminated, truncated, info

def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]:
obs, info = self._env.reset(seed=self._seed)
self._env.step(self._env.action_space.sample())
return obs, info
return obs.astype("float32"), info

def sample_action(self):
return self._env.action_space.sample()
@@ -129,7 +129,7 @@ def render(self) -> npt.ArrayLike:
width=self._render_width,
)
img = img.astype(np.uint8)
return np.expand_dims(img, 0)
return img

def _flat_obs(self, obs: OrderedDict) -> npt.ArrayLike:
obs_flatten = []
11 changes: 10 additions & 1 deletion src/oprl/trainers/base_trainer.py
@@ -1,3 +1,4 @@
import os
from typing import Any, Callable

import numpy as np
@@ -24,6 +25,7 @@ def __init__(
eval_interval: int = int(2e3),
num_eval_episodes: int = 10,
save_buffer_every: int = 0,
save_policy_every: int = int(50_000),
visualise_every: int = 0,
estimate_q_every: int = 0,
stdout_log_every: int = int(1e5),
@@ -60,6 +62,7 @@ def __init__(
self._visualize_every = visualise_every
self._estimate_q_every = estimate_q_every
self._stdout_log_every = stdout_log_every
self._save_policy_every = save_policy_every
self._logger = logger
self.seed = seed

@@ -106,6 +109,7 @@ def train(self):
self._eval_routine(env_step, batch)
self._visualize(env_step)
self._save_buffer(env_step)
self._save_policy(env_step)
self._log_stdout(env_step, batch)

def _eval_routine(self, env_step: int, batch):
@@ -151,9 +155,14 @@ def _visualize(self, env_step: int):
self._logger.log_video("eval_policy", imgs, env_step)

def _save_buffer(self, env_step: int):
# TODO: doesn't work
if self._save_buffer_every > 0 and env_step % self._save_buffer_every == 0:
self.buffer.save(f"{self.log_dir}/buffers/buffer_step_{env_step}.pickle")

def _save_policy(self, env_step: int):
if self._save_policy_every > 0 and env_step % self._save_policy_every == 0:
self._logger.save_weights(self._algo.actor, env_step)

def _estimate_q(self, env_step: int):
if self._estimate_q_every > 0 and env_step % self._estimate_q_every == 0:
q_true = self.estimate_true_q()
@@ -187,7 +196,7 @@ def visualise_policy(self):
action = self._algo.exploit(state)
state, _, terminated, truncated, _ = env.step(action)
done = terminated or truncated
return np.concatenate(imgs)
return np.concatenate(imgs, dtype="uint8")
except Exception as e:
print(f"Failed to visualise a policy: {e}")
return None
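For context, a minimal sketch of how the new hook is gated, assuming the same convention as the existing _save_buffer and _visualize hooks (a value of 0 disables it):

# Sketch only: mirrors the modulo gating used by _save_policy; an interval
# of 0 disables periodic policy checkpoints.
def should_save_policy(env_step: int, save_policy_every: int = 50_000) -> bool:
    return save_policy_every > 0 and env_step % save_policy_every == 0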
5 changes: 4 additions & 1 deletion src/oprl/trainers/safe_trainer.py
@@ -23,6 +23,7 @@ def __init__(
eval_interval: int = int(2e3),
num_eval_episodes: int = 10,
save_buffer_every: int = 0,
save_policy_every: int = int(50_000),
visualise_every: int = 0,
estimate_q_every: int = 0,
stdout_log_every: int = int(1e5),
@@ -65,6 +66,7 @@ def __init__(
eval_interval=eval_interval,
num_eval_episodes=num_eval_episodes,
save_buffer_every=save_buffer_every,
save_policy_every=save_policy_every,
visualise_every=visualise_every,
estimate_q_every=estimate_q_every,
stdout_log_every=stdout_log_every,
@@ -97,10 +99,11 @@ def train(self):
if len(self.buffer) < self.batch_size:
continue
batch = self.buffer.sample(self.batch_size)
self._algo.update(batch)
self._algo.update(*batch)

self._eval_routine(env_step, batch)
self._visualize(env_step)
self._save_policy(env_step)
self._save_buffer(env_step)
self._log_stdout(env_step, batch)

13 changes: 12 additions & 1 deletion src/oprl/utils/logger.py
@@ -3,9 +3,12 @@
import os
import shutil
from abc import ABC, abstractmethod
from sys import path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard.writer import SummaryWriter


@@ -66,11 +69,19 @@ def log_scalar(self, tag: str, value: float, step: int) -> None:
self._log_scalar_to_file(tag, value, step)

def log_video(self, tag: str, imgs, step: int) -> None:
os.makedirs(os.path.join(self._log_dir, "images"))
os.makedirs(os.path.join(self._log_dir, "images"), exist_ok=True)
fn = os.path.join(self._log_dir, "images", f"{tag}_step_{step}.npz")
with open(fn, "wb") as f:
np.save(f, imgs)

def save_weights(self, weights: nn.Module, step: int) -> None:
os.makedirs(os.path.join(self._log_dir, "weights"), exist_ok=True)
fn = os.path.join(self._log_dir, "weights", f"step_{step}.w")
torch.save(
weights,
fn
)

def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None:
fn = os.path.join(self._log_dir, f"{tag}.log")
os.makedirs(os.path.dirname(fn), exist_ok=True)
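For reference, a minimal sketch of restoring a checkpoint written by save_weights: since the whole actor module is stored with torch.save, torch.load brings it back directly (the path below is a placeholder for an actual run directory, and the actor class must be importable at load time):

import torch

# Sketch only: restore a policy checkpoint written by FileLogger.save_weights.
checkpoint_path = "logs/TD3/<run_dir>/weights/step_50000.w"
actor = torch.load(checkpoint_path)
actor.eval()  # restored module can be used for evaluation or visualisation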
20 changes: 0 additions & 20 deletions src/setup.py

This file was deleted.

6 changes: 0 additions & 6 deletions tests/functional/requirements.txt

This file was deleted.

22 changes: 17 additions & 5 deletions tests/functional/src/test_env.py → tests/functional/test_env.py
@@ -1,8 +1,9 @@
import pytest

from oprl.env import DMControlEnv
from oprl.env import make_env

dm_control_envs = [

dm_control_envs: list[str] = [
"acrobot-swingup",
"ball_in_cup-catch",
"cartpole-balance",
@@ -30,9 +31,20 @@
]


@pytest.mark.parametrize("env_name", dm_control_envs)
def test_dm_control_envs(env_name: str):
env = DMControlEnv(env_name, seed=0)
safety_envs: list[str] = [
"SafetyPointGoal1-v0",
"SafetyPointButton1-v0",
"SafetyPointPush1-v0",
"SafetyPointCircle1-v0",
]


env_names: list[str] = dm_control_envs # + safety_envs


@pytest.mark.parametrize("env_name", env_names)
def test_envs(env_name: str) -> None:
env = make_env(env_name, seed=0)
obs, info = env.reset()
assert obs.shape[0] == env.observation_space.shape[0]
assert isinstance(info, dict), "Info is expected to be a dict"
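For reference, a minimal sketch of the factory usage the updated test exercises (the environment name is one of the entries in dm_control_envs):

from oprl.env import make_env

# Sketch only: build an environment through the factory, reset it, and check
# the observation against the declared observation space, as the test does.
env = make_env("cartpole-balance", seed=0)
obs, info = env.reset()
assert obs.shape[0] == env.observation_space.shape[0]
assert isinstance(info, dict)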