人群计数套餐

crowdcount的Python项目详细描述


人群计数套餐

PyPi VersionGitHub starsPyPi downloads

crowdcount是使用Pytorch进行群组计数的库,由Fudan-VTS Research支持

来源

安装

  • pip install crowdcount --user --upgrade

简介

人群计数任务:

  • 估计人群数量
  • crowd counting demo

用户指南:

  • 模型

     from crowdcount.models import * 
     # crowd counting models includes csr_net, mcnn, resnet50, resnet101, unet, vgg
    
  • 变换

    ^{pr2}$ 在
  • 数据加载程序

     from crowdcount.data.data_loader import *
     # includes ShanghaiTech, UCF_QNRF, UCF_CC_50, Fudan-ShanghaiTech temporarily
    
  • 数据预处理

     from crowdcount.data.data_preprocess import *
     # gaussian preprocess for datasets
    
  • 实用工具

     from crowdcount.utils import *
     # includes loss functions, optimizers, tensorboard and save function
    
  • 发动机

     from crowdcount.engine import train
     # start to train
     train(*args, **kwargs)
    
  • 更多详细信息请参见document

演示

from crowdcount.engine import train
from crowdcount.models import Res101
from crowdcount.data.data_loader import *
from crowdcount.utils import *
import crowdcount.transforms as cc_transforms
import torchvision.transforms as transforms

# init model
model = Res101()
# init transforms
img_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.452016860247, 0.447249650955, 0.431981861591],
                                                         std=[0.23242045939, 0.224925786257, 0.221840232611])
                                    ])
gt_transform = cc_transforms.LabelEnlarge()
both_transform = cc_transforms.ComplexCompose([cc_transforms.TransposeFlip()])
# init dataset
train_set = ShanghaiTechDataset(mode="train",
                                part="b",
                                img_transform=img_transform,
                                gt_transform=gt_transform,
                                both_transform=both_transform,
                                root="/home/vts/chensongjian/CrowdCount/crowdcount/data/datasets/shtu_dataset_sigma_15")
test_set = ShanghaiTechDataset(mode="test",
                               part='b',
                               img_transform=img_transform,
                               root="/home/vts/chensongjian/CrowdCount/crowdcount/data/datasets/shtu_dataset_sigma_15")
# init loss
train_loss = AVGLoss()
test_loss = EnlargeLoss(100)
# init save function
saver = Saver(path="../exp/2019-12-22-main_sigma15_6")
# init tensorboard
tb = TensorBoard(path="../runs/2019-12-22-main_sigma15_6")
# start to train
train(model, train_set, test_set, train_loss, test_loss, optim="Adam", saver=saver, cuda_num=[3], train_batch=2,
      test_batch=2, learning_rate=1e-5, epoch_num=500, enlarge_num=100, tensorboard=tb)
  • 您可以在demo中找到更多演示

实验

我们将很快添加结果

感谢来自

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
java类。getResource和ClassLoader。getSystemResource:有没有理由选择其中一个而不是另一个?   在Java中以编程方式粘贴后恢复剪贴板   Java字符串到日期没有时间   JavaSpring注释:@Component起作用,@Repository不起作用   java“addScript”在HSQL中是否有最大记录计数?   java如何将值从JDialog框返回到父JFrame?   java我的模块库的用户有没有办法访问尚未导出的类?   java javac:未找到命令   java如何解决jsoup错误:无法找到请求目标的有效证书路径   类中的java作用域变量   Java中集合实现中的arraylist add()方法不起作用   java如何使用while循环和从用户接收输入来近似Pi?   java Spring安全CSRF培训模式   在安卓系统中,如何通过在警报框外单击来限制用户?