EMA-ViTMatting
Using EMA to train an alpha matting model.
Single RGB image input, single alpha matte output.
This project focuses on image alpha matting. Few open-source end-to-end alpha matting models are currently available, and most of them are convolutional networks with large parameter counts. This project therefore combines a mobile ViT backbone with an improved cascaded decoder module to build a lightweight alpha matting model with reduced computational complexity. The innovation lies in pairing a lightweight ViT with the improved decoder, offering a more efficient solution for the alpha matting field.
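In other words, a MobileViT-style encoder feeds a cascaded decoder that upsamples back to a single-channel alpha map. A minimal PyTorch sketch of that wiring is below; the class name, channel widths, and fusion details are illustrative assumptions, not the exact layers in `model/mobile_vit.py`.

```python
import torch
import torch.nn as nn

class MattingNet(nn.Module):
    """Illustrative wiring only: a MobileViT-style backbone that returns
    multi-scale features, followed by a cascaded decoder that fuses skip
    connections while upsampling to a single-channel alpha map."""

    def __init__(self, encoder: nn.Module, enc_channels=(32, 64, 96, 128)):
        super().__init__()
        self.encoder = encoder  # assumed to return features from shallow to deep
        # Cascaded decoder: at each stage, concatenate the upsampled deep
        # feature with the next-shallower skip feature and fuse with a conv.
        self.fuse = nn.ModuleList([
            nn.Conv2d(enc_channels[i] + enc_channels[i - 1], enc_channels[i - 1], 3, padding=1)
            for i in range(len(enc_channels) - 1, 0, -1)
        ])
        self.head = nn.Conv2d(enc_channels[0], 1, 3, padding=1)

    def forward(self, rgb):
        feats = self.encoder(rgb)  # [f0 (shallow), f1, f2, f3 (deep)]
        x = feats[-1]
        for conv, skip in zip(self.fuse, reversed(feats[:-1])):
            x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            x = torch.relu(conv(torch.cat([x, skip], dim=1)))
        return torch.sigmoid(self.head(x))  # alpha matte in [0, 1]
```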
👀 Demo
Demo: Bilibili Video
Sample results are shown in the repository as an image grid with columns: Original Image | Label | Our Result.
📦 Prerequisites
Requirements:
- Python >= 3.8
- torch >= 2.2.2
- CUDA Version >= 11.7
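To verify the installed versions before training, a quick check such as the following can be used (the repo also provides `utils/testGPU.py`):

```python
import torch

print("torch:", torch.__version__)           # expect >= 2.2.2
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)   # expect >= 11.7
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```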
🔧 Install
Configure Environment:
```bash
git clone git@github.com:CSsaan/EMA-ViTMatting.git
cd EMA-ViTMatting
conda create -n ViTMatting python=3.10 -y
conda activate ViTMatting
pip install -r requirements.txt
```
🚀 Quick Start
train script:
Dataset directory structure:

```
data
└── AIM500
    ├── train
    │   ├── original
    │   └── mask
    └── test
        ├── original
        └── mask
```
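For reference, a minimal paired loader for this layout might look like the sketch below; the project's own loader lives under `dataset/`, and the class name, mask file extension, and resize size here are assumptions.

```python
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class AIM500Pairs(Dataset):
    """Loads (RGB, alpha) pairs from <root>/<split>/{original,mask}."""

    def __init__(self, root="data/AIM500", split="train", size=224):
        self.img_dir = os.path.join(root, split, "original")
        self.mask_dir = os.path.join(root, split, "mask")
        self.names = sorted(os.listdir(self.img_dir))
        self.to_tensor = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        rgb = Image.open(os.path.join(self.img_dir, name)).convert("RGB")
        # Assumption: the mask shares the image's base name and is stored as PNG.
        alpha = Image.open(os.path.join(self.mask_dir, os.path.splitext(name)[0] + ".png")).convert("L")
        return self.to_tensor(rgb), self.to_tensor(alpha)
```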
```bash
python train.py --use_model_name 'VisionTransformer' --reload_model False --local_rank 0 --world_size 4 --batch_size 16 --data_path '/data/AIM500' --use_distribute False
```
- `--use_model_name 'VisionTransformer'`: the name of the model to load
- `--reload_model False`: whether to resume training from a saved checkpoint
- `--local_rank 0`: the local rank of the current process
- `--world_size 4`: the total number of processes
- `--batch_size 16`: batch size
- `--data_path '/data/AIM500'`: path to the dataset
- `--use_distribute False`: whether to use distributed training
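When `--use_distribute True` is passed, the rank and world-size flags typically drive a standard PyTorch DDP setup. A minimal sketch of that wiring, assuming the usual `torch.distributed` calls rather than the project's exact code in `train.py`:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_model(model, local_rank, world_size, use_distribute):
    """Move the model to GPU and, if requested, wrap it for multi-process training.
    MASTER_ADDR/MASTER_PORT must be set in the environment for init_process_group."""
    if not use_distribute:
        return model.cuda()
    dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return DDP(model.cuda(local_rank), device_ids=[local_rank])
```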
test script:
```bash
python inferenceCS.py --image_path data/AIM500/test/original/o_dc288b1a.jpg --model_name MobileViT_194_pure
```
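Under the hood, single-image inference amounts to loading the model, preprocessing the RGB input, and writing out the predicted alpha matte. A hedged sketch (the preprocessing size, output filename, and function name are assumptions, not the exact code in `inferenceCS.py`):

```python
import torch
from PIL import Image
from torchvision import transforms

def infer_alpha(model, image_path, size=224, device="cuda"):
    """Predict an alpha matte for one RGB image and save it to disk."""
    model = model.to(device).eval()
    rgb = Image.open(image_path).convert("RGB")
    preprocess = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
    x = preprocess(rgb).unsqueeze(0).to(device)      # 1 x 3 x H x W
    with torch.no_grad():
        alpha = model(x)                             # 1 x 1 x H x W, values in [0, 1]
    matte = transforms.ToPILImage()(alpha.squeeze(0).cpu()).resize(rgb.size)
    matte.save("alpha_pred.png")
    return matte
```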
📖 Paper
None
🎯 Todo
- Data preprocessing -> dataset\AIM_500_datasets.py
- Data augmentation -> dataset\AIM_500_datasets.py
- Model loading -> config.py & Trainer.py
- Loss functions -> benchmark\loss.py
- Dynamic learning rate -> train.py
- Distributed training -> train.py
- Model visualization -> model\mobile_vit.py
- Model parameters -> benchmark\config\model_MobileViT_parameters.yaml
- Training -> train.py
- Model saving -> Trainer.py
- Test visualization ->
- Model inference -> inferenceCS.py
- PyTorch model to ONNX -> onnx_demo (see the export sketch after this list)
- Model acceleration ->
- Model optimization ->
- Model tuning ->
- Model integration ->
- Model quantization, compression, deployment ->
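For the ONNX step, a minimal export sketch is shown below; the input resolution, tensor names, and opset are assumptions rather than the settings used in `onnx_demo/export_onnx.py`.

```python
import torch

def export_to_onnx(model, out_path="mobile_vit_matting.onnx", size=224):
    """Export the trained matting model to ONNX using a fixed-size dummy input."""
    model = model.cpu().eval()
    dummy = torch.randn(1, 3, size, size)            # one RGB image
    torch.onnx.export(
        model, dummy, out_path,
        input_names=["image"], output_names=["alpha"],
        dynamic_axes={"image": {0: "batch"}, "alpha": {0: "batch"}},
        opset_version=17,
    )
```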
📂 Repo structure (WIP)
```
├── README.md
├── benchmark
│   ├── loss.py            -> loss functions
│   └── config             -> all models' parameters
├── utils
│   ├── testGPU.py
│   ├── yuv_frame_io.py
│   └── print_structure.py
├── onnx_demo
│   ├── export_onnx.py
│   └── infer_onnx.py
├── data                   -> dataset
├── dataset                -> dataloader
├── log                    -> TensorBoard logs
├── model
├── Trainer.py             -> load model & train
├── config.py              -> all models dictionary
├── dataset.py             -> dataLoader
├── demo_Fc.py             -> model inference demo
├── pyproject.toml         -> project config
├── requirements.txt
├── train.py               -> main
└── inferenceCS.py         -> model inference
```