深度强化学习的不平衡分类。
imbDRL的Python项目详细描述
imbDRL公司
非平衡分类与深度强化学习。
此存储库包含使用TensorFlow 2.3和TF Agents 0.6对不平衡数据集进行二进制分类的多种实现:
- 在
由van Hasselt et al.在this paper上发表的双深度Q网络正在使用一个基于Lin et al的this paper的定制环境。在
在 - 在 在
这两种实现的MNIST、IMDB和{a9}数据集的示例脚本可以在./imbDRL/examples
文件夹中找到。在
要求
- Python 3.8
pip install -r requirements.txt
- 可选:
./data/
文件夹,位于此存储库的根目录下。- 如果要使用Credit Card Fraud数据集,此文件夹必须包含从Kaggle下载的
creditcard.csv
。在 - {cd4>需要分开测试。请使用函数
imbDRL.utils.split_csv
- 如果要使用Credit Card Fraud数据集,此文件夹必须包含从Kaggle下载的
- 日志将保存到
./logs/
,经过训练的模型将保存到./models/
入门
- 对于DDQN示例:
python .\imbDRL\examples\ddqn\train_cartpole.py
python .\imbDRL\examples\ddqn\train_credit.py
python .\imbDRL\examples\ddqn\train_image.py
- 对于强盗的例子:
python .\imbDRL\examples\bandit\train_bandit_credit.py
python .\imbDRL\examples\bandit\train_bandit_image.py
python .\imbDRL\examples\bandit\train_bandit_imdb.py
张力板
要启用TensorBoard,请运行tensorboard --logdir logs
。在
测试和剥落
额外的参数用./tox.ini
文件处理。在
- Pytest:
python -m pytest
- 薄片8:
flake8
- 可在
./htmlcov
文件夹中找到覆盖范围
- 项目
标签: