VITA-MLLM,全称Visual Interactive Task AI - Multimodal Large Language Model,是由腾讯优图实验室联合南京大学、厦门大学以及中国科学院自动化研究所共同研发的首个开源多模态大语言模型。

VITA-MLLM是一个基于Mixtral 8×7B基础架构的扩展模型,它通过扩充中文词表并进行双语指令微调来提升其在中文环境下的表现。

不同于传统的单模态或仅能处理有限模态的语言模型,VITA-MLLM旨在成为一个能够同时处理多种类型输入信息的全能型AI系统。

VITA-MLLM可以同时理解并分析来自不同感官通道的信息,比如视觉(视频、图像)和听觉(音频),这使得它能够在更加复杂的情境下做出响应。

为了提高用户体验,VITA采用了生成模型与监控模型相结合的方式工作,其中生成模型用于应答用户提问,而监控模型则负责监听环境中的声音变化,以便适时调整响应策略。

GitHub项目地址:https://github.com/VITA-MLLM/VITA

一、环境安装

1、python环境

建议安装Python 3.10及以上版本。

2、pip库安装

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install -r web_demo/web_demo_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install flash-attn --no-build-isolation

3、VITA模型下载

git lfs install

git clone https://huggingface.co/VITA-MLLM/VITA

4、InternViT-300M-448px模型下载

git lfs install

git clone https://huggingface.co/OpenGVLab/InternViT-300M-448px

二、功能测试

1、运行测试

(1)python代码调用测试

import argparse
import os
import time

import numpy as np
import torch
from PIL import Image

from decord import VideoReader, cpu
from vita.constants import (
    DEFAULT_AUDIO_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    MAX_IMAGE_LENGTH,
)
from vita.conversation import SeparatorStyle, conv_templates
from vita.model.builder import load_pretrained_model
from vita.util.data_utils_video_audio_neg_patch import dynamic_preprocess
from vita.util.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    tokenizer_image_audio_token,
    tokenizer_image_token,
)
from vita.util.utils import disable_torch_init


def _get_rawvideo_dec(
    video_path,
    image_processor,
    max_frames=MAX_IMAGE_LENGTH,
    min_frames=4,
    image_resolution=384,
    video_framerate=1,
    s=None,
    e=None,
    image_aspect_ratio="pad",
):
    """Decode a video file into a stack of preprocessed frame tensors.

    Args:
        video_path: Path of the video file to decode.
        image_processor: Object exposing ``image_mean`` and
            ``preprocess(img, return_tensors="pt")``; the latter must return a
            mapping whose ``"pixel_values"`` entry is indexable.
        max_frames: Upper bound on the number of frames returned.
        min_frames: Lower bound on the number of frames returned; shorter
            clips are padded by repeating frames.
        image_resolution: Unused; kept for interface compatibility.
        video_framerate: Target sampling rate in frames per second.
        s: Optional start of the window, in seconds. When ``None`` the whole
            video is used and ``e`` is ignored.
        e: Optional end of the window, in seconds. When ``None`` (with ``s``
            given) the window runs to the end of the video.
        image_aspect_ratio: ``"pad"`` pads each frame to a square before
            preprocessing; any other value skips the padding.

    Returns:
        A ``(frames, count)`` tuple: the stacked frame tensor and the number
        of frames it contains.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If the requested window contains no frames.
    """
    if s is None:
        # No window requested: decode the whole clip.
        start_time, end_time = None, None
    else:
        start_time = max(0, int(s))
        end_time = None if e is None else max(0, int(e))
        if end_time is not None:
            if start_time > end_time:
                start_time, end_time = end_time, start_time
            elif start_time == end_time:
                # A zero-length window would select nothing; widen it to 1 s.
                end_time += 1

    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video path {video_path} does not exist.")

    vreader = VideoReader(video_path, ctx=cpu(0))
    fps = vreader.get_avg_fps()
    last_frame = len(vreader) - 1
    # ``None`` bounds mean "from the first frame" / "to the last frame".
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = last_frame if end_time is None else int(min(end_time * fps, last_frame))
    num_frames = f_end - f_start + 1

    if num_frames <= 0:
        raise ValueError(f"Invalid frame range in video {video_path}")

    sample_fps = int(video_framerate)
    # Clamp to 1 so that a sampling rate above the video's own fps does not
    # produce a zero step for ``range``.
    t_stride = max(1, int(round(float(fps) / sample_fps)))
    all_pos = list(range(f_start, f_end + 1, t_stride))
    all_pos_len = len(all_pos)

    if all_pos_len > max_frames:
        # Too many candidates: thin them out evenly.
        sample_pos = [all_pos[i] for i in np.linspace(0, all_pos_len - 1, num=max_frames, dtype=int)]
    elif all_pos_len < min_frames:
        # Too few candidates: repeat frames until the minimum is reached.
        sample_pos = [all_pos[i] for i in np.linspace(0, all_pos_len - 1, num=min_frames, dtype=int)]
    else:
        sample_pos = all_pos

    patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]

    if image_aspect_ratio == "pad":
        def expand2square(pil_img, background_color):
            """Centre ``pil_img`` on a square canvas filled with ``background_color``."""
            width, height = pil_img.size
            size = max(width, height)
            result = Image.new(pil_img.mode, (size, size), background_color)
            result.paste(pil_img, (0, (size - height) // 2) if width > height else ((size - width) // 2, 0))
            return result

        # Pad with the processor's mean colour scaled to the 0-255 range.
        patch_images = [expand2square(img, tuple(int(x * 255) for x in image_processor.image_mean)) for img in patch_images]

    patch_images = [image_processor.preprocess(img, return_tensors="pt")["pixel_values"][0] for img in patch_images]
    return torch.stack(patch_images), len(patch_images)


def parse_args():
    """Parse the demo's command-line options.

    ``--model_path`` is mandatory; every other flag is an optional string.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Process model and video paths.")
    cli.add_argument("--model_path", type=str, required=True, help="Path to the model directory")

    # Remaining flags differ only in name and default, so register them from a table.
    optional_flags = (
        ("--model_base", None),
        ("--video_path", None),
        ("--image_path", None),
        ("--audio_path", None),
        ("--model_type", "mixtral-8x7b"),
        ("--conv_mode", "mixtral_two"),
        ("--question", ""),
    )
    for flag, fallback in optional_flags:
        cli.add_argument(flag, type=str, default=fallback)

    return cli.parse_args()


def load_model(args):
    """Load the pretrained model and the processors used by the demo.

    Args:
        args: Parsed options; reads ``model_path``, ``model_base`` and
            ``model_type``.

    Returns:
        A ``(tokenizer, model, image_processor, audio_processor)`` tuple.
    """
    resolved_path = os.path.expanduser(args.model_path)
    # The fourth value returned by the loader (context length) is not used here.
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        resolved_path,
        args.model_base,
        get_model_name_from_path(resolved_path),
        args.model_type,
    )
    model.resize_token_embeddings(len(tokenizer))

    # Load the vision tower only if that has not already happened.
    tower = model.get_vision_tower()
    if not tower.is_loaded:
        tower.load_model()

    encoder = model.get_audio_encoder()
    encoder.to(dtype=torch.float16)

    return tokenizer, model, image_processor, encoder.audio_processor


def load_audio(audio_path, audio_processor):
    """Build the audio input dictionary passed to the model.

    Args:
        audio_path: Path of the audio file, or ``None`` to use an all-zero
            placeholder of shape ``(400, 80)``.
        audio_processor: Object whose ``process(path)`` returns a pair; only
            the first element is used.

    Returns:
        A dict with keys ``"audios"`` and ``"lengths"``, both half-precision
        CUDA tensors with a leading dimension of size 1.
    """
    if audio_path is not None:
        features, _ = audio_processor.process(os.path.join(audio_path))
    else:
        features = torch.zeros(400, 80)

    # Record the length before the leading dimension is added.
    frame_count = torch.tensor(features.shape[0])
    batched = {
        "audios": features.unsqueeze(0),
        "lengths": frame_count.unsqueeze(0),
    }
    return {key: value.half().cuda() for key, value in batched.items()}


def load_input_data(args, image_processor, model):
    """Prepare the visual input, its prompt prefix and the modality tag.

    A video path takes precedence over an image path; with neither, a zero
    tensor is used and the modality is ``"lang"``.

    Args:
        args: Parsed options; reads ``video_path`` and ``image_path``.
        image_processor: Passed through to ``_get_rawvideo_dec`` for videos.
        model: Supplies ``config``, ``dtype`` and ``process_images``.

    Returns:
        A ``(image_tensor, prompt_prefix, modality)`` tuple, where
        ``modality`` is ``"video"``, ``"image"`` or ``"lang"``.
    """
    if args.video_path:
        aspect_mode = getattr(model.config, "image_aspect_ratio", None)
        frames, frame_count = _get_rawvideo_dec(
            args.video_path,
            image_processor,
            max_frames=MAX_IMAGE_LENGTH,
            video_framerate=1,
            image_aspect_ratio=aspect_mode,
        )
        # One image token per decoded frame.
        return frames.half().cuda(), DEFAULT_IMAGE_TOKEN * frame_count, "video"

    if args.image_path:
        rgb_image = Image.open(args.image_path).convert("RGB")
        patches, patch_counts = dynamic_preprocess(
            rgb_image, min_num=1, max_num=12, image_size=448, use_thumbnail=True
        )
        assert len(patch_counts) == 1
        pixels = model.process_images(patches, model.config).to(dtype=model.dtype, device="cuda")
        # One image token per patch reported by the preprocessor.
        return pixels, DEFAULT_IMAGE_TOKEN * patch_counts[0], "image"

    # Neither a video nor an image was given: use an all-zero tensor.
    placeholder = torch.zeros((1, 3, 448, 448)).to(dtype=model.dtype, device="cuda")
    return placeholder, "", "lang"


def main():
    """Run one inference pass from the command line and print the reply.

    Parses the CLI options, loads the model and inputs, builds the
    conversation prompt, generates up to 1024 new tokens, then prints the
    decoded answer followed by the elapsed generation time.
    """
    # Imported locally because the module-level ``vita.constants`` import does
    # not include it; without this the tokenizer calls below raise NameError.
    from vita.constants import IMAGE_TOKEN_INDEX

    args = parse_args()

    disable_torch_init()
    tokenizer, model, image_processor, audio_processor = load_model(args)

    model.eval()
    audios = load_audio(args.audio_path, audio_processor)
    image_tensor, prompt_prefix, modality = load_input_data(args, image_processor, model)

    qs = args.question
    conv = conv_templates[args.conv_mode].copy()
    # The audio token is appended only when an audio file was supplied.
    conv.append_message(conv.roles[0], f"{prompt_prefix}\n{qs}{DEFAULT_AUDIO_TOKEN if args.audio_path else ''}")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt(modality)

    # Choose the tokenizer helper that matches whether audio is present.
    if args.audio_path:
        input_ids = tokenizer_image_audio_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
    else:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    start_time = time.time()
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            audios=audios,
            do_sample=False,
            temperature=0.01,
            top_p=None,
            num_beams=1,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    infer_time = time.time() - start_time
    output_ids = output_ids.sequences
    input_token_len = input_ids.shape[1]
    # Decode only what follows the first ``input_token_len`` positions.
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0].strip()

    # Drop the trailing stop string, if present.
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)].strip()

    print(outputs)
    print(f"Time consume: {infer_time:.2f} seconds")


if __name__ == "__main__":
    main()

未完......

更多详细内容,欢迎关注:杰哥新技术

Logo

Agent 垂直技术社区,欢迎活跃、内容共建。

更多推荐