Merged
30 commits
18f0a0a
Move configs from package and refine
schatty Jul 5, 2025
47dd280
Replace inheritance with composition for trainer, make visualization …
schatty Jul 6, 2025
f1193d5
Refactor replay buffer
schatty Jul 6, 2025
39eb435
Refactor environment module
schatty Jul 7, 2025
bd39d30
Add new environment files
schatty Jul 7, 2025
6feea32
Remove utils folder
schatty Jul 7, 2025
3880f02
Make trainers dataclasses
schatty Jul 12, 2025
faa8d0d
Add option for storing logs outside src folder, make logs dir creatio…
schatty Jul 12, 2025
1edb178
Introduce explicit protocols
schatty Jul 12, 2025
f4dc094
Introduce base algorithm with explore and exploit functionality
schatty Jul 12, 2025
efe2f9d
Ensure algo and buffer has been created in trainer
schatty Jul 12, 2025
d4b5486
Add proper stdout logging instead of bare prints
schatty Jul 12, 2025
0e6511f
Make algos configurable
schatty Jul 13, 2025
2ed37d0
Replace bare dict with pydantic config in configs
schatty Jul 13, 2025
fa44399
Refine annotations
schatty Jul 17, 2025
80f6325
Make distrib code work
schatty Jul 17, 2025
8e00d9c
Refactor distrib training
schatty Jul 17, 2025
45f006d
Add explicit gymnasium support
schatty Jul 17, 2025
6b972d7
Fix: gymnasium support
schatty Jul 17, 2025
5b815f2
Add uv ruff
schatty Jul 17, 2025
164b6e3
Add uv to dockerfile and ci
schatty Jul 17, 2025
32f1b1f
Add tests
schatty Jul 17, 2025
dc8db60
Add mypy check for custom modules
schatty Jul 17, 2025
2df0394
Add ruff to ci pipeline
schatty Jul 17, 2025
29aa426
Fix in ci
schatty Jul 17, 2025
6c9b56d
Fix ci
schatty Jul 17, 2025
00e255d
Fix mypy typo in ci
schatty Jul 17, 2025
6efbecb
Add missing module file
schatty Jul 17, 2025
23cfd4a
Exclude runners
schatty Jul 17, 2025
60f8f07
Update README.md
schatty Jul 17, 2025
16 changes: 14 additions & 2 deletions .github/workflows/unit-tests-docker.yml
@@ -5,22 +5,34 @@ on:

jobs:
unit_tests:
name: Unit Tests
name: Unit Tests - Coverage - MyPy
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Build Docker image
run: docker build -t oprl .

- name: Ruff
run: |
docker run --rm oprl sh -c "
uv run ruff check src
"

- name: Unit Tests
run: docker run --rm oprl

- name: MyPy
run: |
docker run --rm oprl sh -c "
uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/trainers src/oprl/buffers
"

- name: Extract coverage
run: |
docker run --rm -v $(pwd):/host oprl sh -c "
pytest tests/functional --cov=oprl --cov-report=xml &&
uv run pytest tests/functional --cov=oprl --cov-report=xml &&
cp coverage.xml /host/
"

13 changes: 6 additions & 7 deletions Dockerfile
@@ -1,19 +1,18 @@
FROM python:3.10.8
COPY --from=ghcr.io/astral-sh/uv:0.7.21 /uv /uvx /bin/

WORKDIR /app

RUN pip install --no-cache-dir --upgrade pip
COPY . .

RUN uv sync --locked && uv pip install pytest pytest-cov mypy

# Install SafetyGymnasium from external lib
RUN wget https://github.com/PKU-Alignment/safety-gymnasium/archive/refs/heads/main.zip && \
unzip main.zip && \
cd safety-gymnasium-main && \
pip install . && \
uv pip install . && \
cd ..

COPY . .

RUN pip install --no-cache-dir . && pip install pytest pytest-cov

# Run tests by default
CMD ["pytest", "tests/functional"]
CMD ["uv", "run", "pytest", "tests/functional"]
103 changes: 68 additions & 35 deletions README.md
@@ -1,48 +1,69 @@
<img align="left" width="70" alt="oprl_logo" src="https://github.com/schatty/oprl/assets/23639048/c7ea0fee-3472-4d9c-86f3-9ab01f02222d">
<p align="center">
<img src="https://github.com/user-attachments/assets/0c98353f-3de6-46f6-b40c-db1d69672b12" alt="Description" width="150">
</p>

# OPRL

A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. Benchmarking results are available at the associated homepage: [Homepage](https://schatty.github.io/oprl/)
A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. The code supports the `SafetyGymnasium` environment set as a starting point for developing SafeRL solutions. The distributed setting is implemented via the `pika` library and will be improved in the near future.
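
For orientation, below is a minimal sketch of the publish/consume pattern a `pika`-based setup typically relies on. The queue name, payload format, and localhost broker are assumptions for illustration only, not OPRL's actual worker implementation (see `oprl.distrib` in this PR for the real entry points).

```
# Hedged sketch: generic pika publish/consume loop, NOT OPRL's implementation.
# The queue name "transitions" and the pickle payload are assumptions.
import pickle

import pika

connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
channel = connection.channel()
channel.queue_declare(queue="transitions")

# Env-worker side: publish a collected transition to the broker.
transition = {"state": [0.0], "action": [0.0], "reward": 0.0, "done": False}
channel.basic_publish(exchange="", routing_key="transitions", body=pickle.dumps(transition))

# Learner side: consume transitions and hand them to the replay buffer.
def on_transition(ch, method, properties, body):
    data = pickle.loads(body)  # a replay-buffer append call would go here

channel.basic_consume(queue="transitions", on_message_callback=on_transition, auto_ack=True)
# channel.start_consuming()  # blocks forever; left commented so the sketch terminates
connection.close()
```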

[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![codecov](https://codecov.io/gh/schatty/oprl/branch/master/graph/badge.svg)](https://codecov.io/gh/schatty/oprl)

### Roadmap 🏗
- [x] Support for SafetyGymnasium
- [x] Style and readability improvements
- [ ] REDQ, DrQ Algorithms support
- [ ] Distributed Training Improvements

## In a Snapshot

# Disclaimer
The project is under active renovation; for the old code with the D4PG algorithm working with multiprocessing queues and `mujoco_py`, please refer to the branch `d4pg_legacy`.
Environments Support

### Roadmap 🏗
- [x] Switching to `mujoco 3.1.1`
- [x] Replacing multiprocessing queues with RabbitMQ for distributed RL
- [x] Baselines with DDPG, TQC for `dm_control` for 1M step
- [x] Tests
- [x] Support for SafetyGymnasium
- [ ] Style and readability improvements
- [ ] Baselines with Distributed algorithms for `dm_control`
- [ ] D4PG logic on top of TQC
| DMControl Suite | SafetyGymnasium | Gymnasium |
| -------- | -------- | -------- |

Algorithms

| DDPG | TD3 | SAC | TQC |
| --- | --- | --- | --- |

# Installation
## Installation

The project supports [uv](https://docs.astral.sh/uv/) for package management and [ruff](https://github.com/astral-sh/ruff) for formatting checks. To install the project via uv in a virtualenv:

```
pip install -r requirements.txt
cd src && pip install -e .
uv venv
source .venv/bin/activate
uv sync
```

### Installing SafetyGymnasium

To work with [SafetyGymnasium](https://github.com/PKU-Alignment/safety-gymnasium), install it manually:
```
git clone https://github.com/PKU-Alignment/safety-gymnasium
cd safety-gymnasium && pip install -e .
cd safety-gymnasium && uv pip install -e .
```

## Tests

To run tests locally:

```
uv pip install pytest
uv run pytest tests/functional
```

## RL Training

# Usage
All training is configured via Python config files located in the `configs` folder. To make your own configuration, change the code there or create a similar one. During training, all the code is copied to the logs folder to ensure full experimental reproducibility.
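
A condensed sketch of what "a similar one" could look like is shown below, mirroring the factory signatures of `configs/ddpg.py` added further down in this diff; the hard-coded environment name, device, step count, and seed values are illustrative substitutes for the CLI arguments the real config parses.

```
# Hedged sketch of a custom config, modeled on configs/ddpg.py from this PR.
from oprl.algos.ddpg import DDPG
from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
from oprl.environment import make_env as _make_env
from oprl.logging import make_text_logger_func
from oprl.runners.config import CommonParameters
from oprl.runners.train import run_training

ENV_NAME = "walker-walk"  # illustrative choice, any supported env name works


def make_env(seed: int):
    return _make_env(ENV_NAME, seed=seed)


env = make_env(seed=0)
STATE_DIM = env.observation_space.shape[0]
ACTION_DIM = env.action_space.shape[0]

config = CommonParameters(
    state_dim=STATE_DIM,
    action_dim=ACTION_DIM,
    num_steps=10_000,  # shortened for a quick smoke run
    eval_every=2500,
    device="cpu",
    estimate_q_every=5000,
    log_every=2500,
)


def make_algo(logger):
    return DDPG(state_dim=STATE_DIM, action_dim=ACTION_DIM, device="cpu", logger=logger).create()


def make_replay_buffer():
    return EpisodicReplayBuffer(
        buffer_size_transitions=1_000_000,
        state_dim=STATE_DIM,
        action_dim=ACTION_DIM,
        device="cpu",
    ).create()


make_logger = make_text_logger_func(algo="DDPG", env=ENV_NAME)

if __name__ == "__main__":
    run_training(
        make_algo=make_algo,
        make_env=make_env,
        make_replay_buffer=make_replay_buffer,
        make_logger=make_logger,
        config=config,
        seeds=1,       # assumed values; configs/ddpg.py reads these from CLI args
        start_seed=0,
    )
```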

### Single Agent

To run DDPG in a single process
```
python src/oprl/configs/ddpg.py --env walker-walk
python configs/ddpg.py --env walker-walk
```

To run distributed DDPG
### Distributed

Run RabbitMQ
```
@@ -51,23 +72,35 @@ docker run -it --rm --name rabbitmq -p 5672:5672 -p 15672:15672 rabbitmq:3.12-ma

Run training
```
python src/oprl/configs/d3pg.py --env walker-walk
```

## Tests

```
cd src && pip install -e .
cd .. && pip install -r tests/functional/requirements.txt
python -m pytest tests
python configs/distrib_ddpg.py --env walker-walk
```

## Results

Results for single process DDPG and TQC:
![ddpg_tqc_eval](https://github.com/schatty/d4pg-pytorch/assets/23639048/f2c32f62-63b4-4a66-a636-4ce0ea1522f6)

## Acknowledgements
* DDPG and TD3 code is based on the official TD3 implementation: [sfujim/TD3](https://github.com/sfujim/TD3)
* TQC code is based on the official TQC implementation: [SamsungLabs/tqc](https://github.com/SamsungLabs/tqc)
* SafetyGymnasium: [PKU-Alignment/safety-gymnasium](https://github.com/PKU-Alignment/safety-gymnasium)
## Cite

__OPRL__
```
@inproceedings{
kuznetsov2024safer,
title={Safer Reinforcement Learning by Going Off-policy: a Benchmark},
author={Igor Kuznetsov},
booktitle={ICML 2024 Next Generation of AI Safety Workshop},
year={2024},
url={https://openreview.net/forum?id=pAmTC9EdGq}
}
```

__SafetyGymnasium__
```
@inproceedings{ji2023safety,
title={Safety Gymnasium: A Unified Safe Reinforcement Learning Benchmark},
author={Jiaming Ji and Borong Zhang and Jiayi Zhou and Xuehai Pan and Weidong Huang and Ruiyang Sun and Yiran Geng and Yifan Zhong and Josef Dai and Yaodong Yang},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2023},
url={https://openreview.net/forum?id=WZmlxIuIGR}
}
```
73 changes: 73 additions & 0 deletions configs/ddpg.py
@@ -0,0 +1,73 @@
from oprl.algos.protocols import AlgorithmProtocol
from oprl.algos.ddpg import DDPG
from oprl.buffers.protocols import ReplayBufferProtocol
from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
from oprl.parse_args import parse_args
from oprl.logging import (
LoggerProtocol,
make_text_logger_func,
)
from oprl.environment.protocols import EnvProtocol
from oprl.environment import make_env as _make_env
from oprl.runners.train import run_training
from oprl.runners.config import CommonParameters


args = parse_args()

def make_env(seed: int) -> EnvProtocol:
return _make_env(args.env, seed=seed)


env = make_env(seed=0)
STATE_DIM: int = env.observation_space.shape[0]
ACTION_DIM: int = env.action_space.shape[0]


config = CommonParameters(
state_dim=STATE_DIM,
action_dim=ACTION_DIM,
num_steps=int(100_000),
eval_every=2500,
device=args.device,
estimate_q_every=5000,
log_every=2500,
)


def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
return DDPG(
state_dim=STATE_DIM,
action_dim=ACTION_DIM,
device=args.device,
logger=logger,
).create()


def make_replay_buffer() -> ReplayBufferProtocol:
return EpisodicReplayBuffer(
buffer_size_transitions=max(config.num_steps, int(1e6)),
state_dim=STATE_DIM,
action_dim=ACTION_DIM,
device=config.device,
).create()


make_logger = make_text_logger_func(
algo="DDPG",
env=args.env,
)


if __name__ == "__main__":
run_training(
make_algo=make_algo,
make_env=make_env,
make_replay_buffer=make_replay_buffer,
make_logger=make_logger,
config=config,
seeds=args.seeds,
start_seed=args.start_seed
)


92 changes: 92 additions & 0 deletions configs/distrib_ddpg.py
@@ -0,0 +1,92 @@
import os
import logging

import torch.nn as nn

from oprl.algos.ddpg import DDPG
from oprl.algos.nn_models import DeterministicPolicy
from oprl.algos.protocols import AlgorithmProtocol, PolicyProtocol
from oprl.buffers.protocols import ReplayBufferProtocol
from oprl.environment import make_env as _make_env
from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
from oprl.logging import (
LoggerProtocol,
FileTxtLogger,
get_logs_path,
)
from oprl.parse_args import parse_args_distrib
from oprl.runners.config import DistribConfig
from oprl.runners.train_distrib import run_distrib_training
from oprl.distrib.env_worker import run_env_worker
from oprl.distrib.policy_update_worker import run_policy_update_worker


config = DistribConfig(
batch_size=128,
num_env_workers=4,
episodes_per_worker=100,
warmup_epochs=16,
episode_length=1000,
learner_num_waits=10,
)


args = parse_args_distrib()

def make_env(seed: int):
return _make_env(args.env, seed=seed)


env = make_env(seed=0)
STATE_DIM = env.observation_space.shape[0]
ACTION_DIM = env.action_space.shape[0]
logging.info(f"Env state {STATE_DIM}\tEnv action {ACTION_DIM}")


def make_logger() -> LoggerProtocol:
logs_root = os.environ.get("OPRL_LOGS", "logs")
log_dir = get_logs_path(logdir=logs_root, algo="DistribDDPG", env=args.env, seed=0)
logger = FileTxtLogger(log_dir)
logger.copy_source_code()
return logger


def make_policy() -> PolicyProtocol:
return DeterministicPolicy(
state_dim=STATE_DIM,
action_dim=ACTION_DIM,
hidden_units=(256, 256),
hidden_activation=nn.ReLU(inplace=True),
device=args.device,
)


def make_replay_buffer() -> ReplayBufferProtocol:
return EpisodicReplayBuffer(
buffer_size_transitions=int(1_000_000),
state_dim=STATE_DIM,
action_dim=ACTION_DIM,
device=args.device,
).create()


def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
return DDPG(
logger=logger,
state_dim=STATE_DIM,
action_dim=ACTION_DIM,
device=args.device,
).create()


if __name__ == "__main__":
run_distrib_training(
run_env_worker=run_env_worker,
run_policy_update_worker=run_policy_update_worker,
make_env=make_env,
make_algo=make_algo,
make_policy=make_policy,
make_replay_buffer=make_replay_buffer,
make_logger=make_logger,
config=config,
)