My Portrait Segmentation Model: ViTMatting by CS

Alpha matting with a Vision Transformer

Posted by CS on July 30, 2024

EMA-ViTMatting

[Project Page] [Chinese Homepage]

Using EMA to train a matting model.

Single RGB image input, single alpha image output.
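
Here EMA refers to an exponential moving average of the model weights that is tracked alongside the optimizer and used for evaluation and saving. A minimal sketch of how such an EMA tracker typically looks in PyTorch (the class name and decay value are illustrative, not this repo's exact implementation):

import copy
import torch

class EMA:
    """Keep an exponential moving average of a model's parameters."""
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()  # averaged copy used for eval/saving
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # shadow = decay * shadow + (1 - decay) * current weights
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

# Typical use inside the training loop (illustrative):
#   ema = EMA(model)
#   loss.backward(); optimizer.step(); ema.update(model)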

This project focuses on image alpha matting. Few open-source end-to-end alpha matting models are currently available, and most of them are convolutional neural networks with large parameter counts. This project therefore combines a mobile ViT with an improved cascaded decoder module to build a lightweight alpha matting model with reduced computational complexity. The innovation lies in pairing a lightweight ViT backbone with an improved decoder module, bringing a more efficient solution to the alpha matting field.
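
As a rough illustration of the data flow described above (the encoder and decoder classes here are placeholders, not the actual modules under model/):

import torch
import torch.nn as nn

class MattingNet(nn.Module):
    """Illustrative wrapper: lightweight ViT encoder + cascaded decoder -> alpha matte."""
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder  # e.g. a MobileViT-style backbone producing multi-scale features
        self.decoder = decoder  # cascaded decoder upsampling features back to input resolution

    def forward(self, rgb: torch.Tensor) -> torch.Tensor:
        feats = self.encoder(rgb)        # multi-scale feature maps
        alpha = self.decoder(feats)      # single-channel prediction, same H x W as the input
        return torch.sigmoid(alpha)      # alpha matte in [0, 1]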

👀 Demo

Demo: Bilibili Video

Demo gallery (images omitted here): side-by-side comparisons of the original image, the ground-truth label, and our results.

Model structure: [architecture diagram]

📦 Prerequisites

Requirements:

  • Python >= 3.8
  • torch >= 2.2.2
  • CUDA Version >= 11.7

🔧 Install

Configure Environment:

git clone git@github.com:CSsaan/EMA-ViTMatting.git
cd EMA-ViTMatting
conda create -n ViTMatting python=3.10 -y
conda activate ViTMatting
pip install -r requirements.txt

🚀 Quick Start

train script:

Dataset directory structure:
data
└── AIM500
    ├── train
    │   ├── original
    │   └── mask
    └── test
        ├── original
        └── mask

python train.py --use_model_name 'VisionTransformer' --reload_model False --local_rank 0 --world_size 4 --batch_size 16 --data_path '/data/AIM500' --use_distribute False

  • --use_model_name 'VisionTransformer': The name of the model to load
  • --reload_model False: Whether to resume training from a saved checkpoint
  • --local_rank 0: The local rank of the current process
  • --world_size 4: The total number of processes
  • --batch_size 16: Batch size
  • --data_path '/data/AIM500': Data path
  • --use_distribute False: Whether to use distributed training
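
For reference, a minimal sketch of a dataset class that walks the original/mask layout shown above and is pointed at by --data_path. The file pairing rule, image size, and transforms are assumptions; the project's actual loader lives in the dataset/ directory:

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class AIM500Dataset(Dataset):
    """Pairs RGB images in <root>/<split>/original with alpha mattes in <root>/<split>/mask."""
    def __init__(self, root: str, split: str = "train", size: int = 512):
        self.orig_dir = os.path.join(root, split, "original")
        self.mask_dir = os.path.join(root, split, "mask")
        self.names = sorted(os.listdir(self.orig_dir))
        self.to_tensor = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        rgb = Image.open(os.path.join(self.orig_dir, name)).convert("RGB")
        # Assumes the mask shares the file stem; the real pairing rule may differ.
        mask_name = os.path.splitext(name)[0] + ".png"
        alpha = Image.open(os.path.join(self.mask_dir, mask_name)).convert("L")
        return self.to_tensor(rgb), self.to_tensor(alpha)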

test script:

python inferenceCS.py --image_path data/AIM500/test/original/o_dc288b1a.jpg --model_name MobileViT_194_pure
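
The actual checkpoint loading and preprocessing are handled in inferenceCS.py; the following is only a generic sketch of single-image inference (function name, input size, and normalization are assumptions):

import torch
from PIL import Image
from torchvision import transforms

def infer_alpha(model, image_path: str, size: int = 512, device: str = "cuda"):
    """Run a trained matting model on a single RGB image and return the alpha matte."""
    preprocess = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    rgb = Image.open(image_path).convert("RGB")
    x = preprocess(rgb).unsqueeze(0).to(device)        # 1 x 3 x size x size
    model = model.to(device).eval()
    with torch.no_grad():
        alpha = model(x).squeeze(0).clamp(0, 1).cpu()  # 1 x size x size, values in [0, 1]
    return transforms.ToPILImage()(alpha)              # grayscale PIL image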

📖 Paper

None

🎯 Todo

  • Data preprocessing -> dataset\AIM_500_datasets.py
  • Data augmentation -> dataset\AIM_500_datasets.py
  • Model loading -> config.py & Trainer.py
  • Loss functions -> benchmark\loss.py
  • Dynamic learning rate -> train.py
  • Distributed training -> train.py
  • Model visualization -> model\mobile_vit.py
  • Model parameters -> benchmark\config\model_MobileViT_parameters.yaml
  • Training -> train.py
  • Model saving -> Trainer.py
  • Test visualization ->
  • Model inference -> inferenceCS.py
  • PyTorch model to ONNX -> onnx_demo (see the export sketch after this list)
  • Model acceleration ->
  • Model optimization ->
  • Model tuning ->
  • Model integration ->
  • Model quantization, compression, deployment ->
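
For the PyTorch-to-ONNX item above, a minimal export sketch using torch.onnx.export (output file name, input size, and opset are assumptions; the project's own script is onnx_demo/export_onnx.py):

import torch

def export_to_onnx(model, onnx_path: str = "vitmatting.onnx", size: int = 512):
    """Export a trained matting model to ONNX with a 1 x 3 x size x size input."""
    model = model.eval().cpu()
    dummy = torch.randn(1, 3, size, size)
    torch.onnx.export(
        model, dummy, onnx_path,
        input_names=["rgb"], output_names=["alpha"],
        opset_version=17,
        dynamic_axes={"rgb": {0: "batch"}, "alpha": {0: "batch"}},
    )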

📂 Repo structure (WIP)

├── README.md
├── benchmark 
│   ├── loss.py              -> loss functions
│   └── config               -> all models' parameters
├── utils
│   ├── testGPU.py
│   ├── yuv_frame_io.py
│   └── print_structure.py
├── onnx_demo
│   ├── export_onnx.py
│   └── infer_onnx.py
├── data                     -> dataset
├── dataset                  -> dataloader
├── log                      -> tensorboard log
├── model
├── Trainer.py               -> load model & train.
├── config.py                -> dictionary of all models
├── dataset.py               -> dataLoader
├── demo_Fc.py               -> model inference demo
├── pyproject.toml           -> project config
├── requirements.txt
├── train.py                 -> main
└── inferenceCS.py           -> model inference