支持新的推理器(Inferencer)

推理器(Inferencer)是 AISBench 中负责执行模型推理的核心组件,它根据不同的模型类型(API 模型或本地模型)采用不同的推理方式。在适配新的推理器前,建议先参考 prompt_templatemeta_template 的定义方法,了解 AISBench 对于 prompt 的构建方式。

目前 AISBench 已经支持的推理器类型如下:

  • GenInferencer:用于生成式任务的推理器,支持 API 模型和本地模型

  • MultiTurnGenInferencer:用于多轮对话任务的推理器,支持 API 模型和本地模型

  • PPLInferencer:用于困惑度(Perplexity)评估的推理器

针对某些特殊的推理场景或自定义需求,通常需要实现自定义推理器。根据调用的模型类型,推理器需要实现不同的接口:

  • API 模型:需要实现 do_request 异步方法,通过 HTTP 请求调用服务化模型

  • 本地模型:需要实现 batch_inference 同步方法,直接调用本地模型进行批量推理

新增 API 模型推理器

新增基于 API 模型的推理器,需要在 ais_bench/benchmark/openicl/icl_inferencer 下新建 my_custom_api_inferencer.py 文件,继承 BaseApiInferencer,并根据使用场景实现对应的功能接口。当前支持拓展的接口如下:

  • (必需)do_request:执行单个推理请求,用于 API 模型的异步推理

  • (必需)get_data_list:从 retriever 获取数据列表,用于构建推理数据

from multiprocessing import BoundedSemaphore
from typing import List, Optional
import uuid
import copy
import aiohttp

from ais_bench.benchmark.models.output import RequestOutput
from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler


@ICL_INFERENCERS.register_module()
class MyCustomApiInferencer(BaseApiInferencer):
    """自定义 API 模型推理器类。

    Attributes:
        model_cfg: 模型配置
        batch_size (:obj:`int`, optional): 批处理大小
        output_json_filepath (:obj:`str`, optional): 输出 JSON 文件路径
        save_every (:obj:`int`, optional): 每处理多少个样本保存一次中间结果
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        mode: Optional[str] = "infer",
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        super().__init__(
            model_cfg=model_cfg,
            batch_size=batch_size,
            mode=mode,
            output_json_filepath=output_json_filepath,
            save_every=save_every,
            **kwargs,
        )

        # 初始化输出处理器
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=self.perf_mode,
            save_every=self.save_every
        )

    async def do_request(
        self,
        data: dict,
        token_bucket: BoundedSemaphore,
        session: aiohttp.ClientSession
    ):
        """执行单个推理请求。

        Args:
            data: 包含请求数据的字典,通常包含以下字段:
                - prompt: 输入提示词
                - index: 数据索引
                - data_abbr: 数据集标识
                - max_out_len: 最大输出长度
                - gold: 标准答案(可选)
            token_bucket: 用于限流的信号量
            session: HTTP 会话对象
        """
        data = copy.deepcopy(data)
        index = data.pop("index")
        input = data.pop("prompt")
        data_abbr = data.pop("data_abbr")
        max_out_len = data.pop("max_out_len")
        gold = data.pop("gold", None)

        # 生成唯一标识
        uid = str(uuid.uuid4()).replace("-", "")
        output = RequestOutput(self.perf_mode)
        output.uuid = uid

        # 更新状态计数器
        await self.status_counter.post()

        # 调用模型进行推理
        await self.model.generate(input, max_out_len, output, session=session, **data)

        # 更新状态
        if output.success:
            await self.status_counter.rev()
        else:
            await self.status_counter.failed()
        await self.status_counter.finish()
        await self.status_counter.case_finish()

        # 报告结果到输出处理器
        await self.output_handler.report_cache_info(index, input, output, data_abbr, gold)

    def get_data_list(
        self,
        retriever: BaseRetriever,
    ) -> List:
        """从 retriever 获取数据列表。

        Args:
            retriever: 检索器实例,用于获取数据和生成 prompt

        Returns:
            数据列表,每个元素是一个字典,包含推理所需的信息
        """
        data_abbr = retriever.dataset.abbr
        ice_idx_list = retriever.retrieve()
        prompt_list = []

        # 为每个样本生成 prompt
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=self.gen_field_replace_token if hasattr(self, 'gen_field_replace_token') else "",
            )
            # 解析模板
            parsed_prompt = self.model.parse_template(prompt, mode="gen")
            prompt_list.append(parsed_prompt)

        self.logger.info(f"Apply ice template finished")

        # 获取标准答案
        gold_ans = retriever.get_gold_ans()

        # 构建数据列表
        data_list = []
        for index, prompt in enumerate(prompt_list):
            data_list.append(
                {
                    "prompt": prompt,
                    "data_abbr": data_abbr,
                    "index": index,
                    "max_out_len": self.model.max_out_len,
                }
            )

        # 添加标准答案
        if gold_ans is not None:
            for index, gold in enumerate(gold_ans):
                data_list[index]["gold"] = gold

        # 数据集指定的 max_out_len 具有最高优先级
        max_out_lens = retriever.dataset_reader.get_max_out_len()
        if max_out_lens is not None:
            self.logger.warning("Dataset-specified max_out_len has highest priority, use dataset-specified max_out_len")
            for index, max_out_len in enumerate(max_out_lens):
                data_list[index]["max_out_len"] = max_out_len if max_out_len else self.model.max_out_len

        return data_list

新增推理器类建议补充到__init__.py中,方便后续自动导入。

详细实现可参考:GenInferencer

新增本地模型推理器

新增基于本地模型的推理器,需要在 ais_bench/benchmark/openicl/icl_inferencer 下新建 my_custom_local_inferencer.py 文件,继承 BaseLocalInferencer,并根据使用场景实现对应的功能接口。当前支持拓展的接口如下:

  • (必需)batch_inference:执行批量推理,用于本地模型的同步推理

  • (必需)get_data_list:从 retriever 获取数据列表,用于构建推理数据

from typing import List, Optional
from torch.utils.data import DataLoader

from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_local_inferencer import BaseLocalInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler


@ICL_INFERENCERS.register_module()
class MyCustomLocalInferencer(BaseLocalInferencer):
    """自定义本地模型推理器类。

    Attributes:
        model_cfg: 模型配置
        batch_size (:obj:`int`, optional): 批处理大小
        output_json_filepath (:obj:`str`, optional): 输出 JSON 文件路径
        save_every (:obj:`int`, optional): 每处理多少个样本保存一次中间结果
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        super().__init__(
            model_cfg=model_cfg,
            batch_size=batch_size,
            output_json_filepath=output_json_filepath,
        )

        self.save_every = save_every

        # 初始化输出处理器
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=False,  # 本地推理器通常不支持性能模式
            save_every=self.save_every
        )

    def batch_inference(
        self,
        datum: dict,
    ) -> None:
        """执行批量推理。

        Args:
            datum: 包含批量数据的字典,通常包含以下字段:
                - prompt: 输入提示词列表
                - index: 数据索引列表
                - data_abbr: 数据集标识列表
                - max_out_len: 最大输出长度列表
                - gold: 标准答案列表(可选)
        """
        indexs = datum.pop("index")
        inputs = datum.pop("prompt")
        data_abbrs = datum.pop("data_abbr")
        max_out_lens = datum.pop("max_out_len")
        golds = datum.pop("gold", [None] * len(inputs))

        # 调用本地模型进行批量推理
        # 本地模型使用模型配置中统一的 max_out_len
        outputs = self.model.generate(inputs, self.model.max_out_len, **datum)

        # 处理每个输出结果
        for index, input, output, data_abbr, gold in zip(
            indexs, inputs, outputs, data_abbrs, golds
        ):
            self.output_handler.report_cache_info_sync(
                index, input, output, data_abbr, gold
            )

    def get_data_list(
        self,
        retriever: BaseRetriever,
    ) -> List:
        """从 retriever 获取数据列表。

        Args:
            retriever: 检索器实例,用于获取数据和生成 prompt

        Returns:
            数据列表,每个元素是一个字典,包含推理所需的信息
        """
        data_abbr = retriever.dataset.abbr
        ice_idx_list = retriever.retrieve()
        prompt_list = []

        # 为每个样本生成 prompt
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=self.gen_field_replace_token if hasattr(self, 'gen_field_replace_token') else "",
            )
            # 解析模板
            parsed_prompt = self.model.parse_template(prompt, mode="gen")
            prompt_list.append(parsed_prompt)

        self.logger.info(f"Apply ice template finished")

        # 获取标准答案
        gold_ans = retriever.get_gold_ans()

        # 构建数据列表
        data_list = []
        for index, prompt in enumerate(prompt_list):
            data_list.append(
                {
                    "prompt": prompt,
                    "data_abbr": data_abbr,
                    "index": index,
                    "max_out_len": self.model.max_out_len,
                }
            )

        # 添加标准答案
        if gold_ans is not None:
            for index, gold in enumerate(gold_ans):
                data_list[index]["gold"] = gold

        # 数据集指定的 max_out_len 具有最高优先级
        max_out_lens = retriever.dataset_reader.get_max_out_len()
        if max_out_lens is not None:
            self.logger.warning("Dataset-specified max_out_len has highest priority, use dataset-specified max_out_len")
            for index, max_out_len in enumerate(max_out_lens):
                data_list[index]["max_out_len"] = max_out_len if max_out_len else self.model.max_out_len

        return data_list

新增推理器类建议补充到__init__.py中,方便后续自动导入。

详细实现可参考:GenInferencer

同时支持 API 模型和本地模型的推理器

如果推理器需要同时支持 API 模型和本地模型,可以同时继承 BaseApiInferencerBaseLocalInferencer,并实现两个基类的必需方法。这样同一个推理器类可以用于两种类型的模型。

from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_local_inferencer import BaseLocalInferencer


@ICL_INFERENCERS.register_module()
class MyCustomInferencer(BaseApiInferencer, BaseLocalInferencer):
    """同时支持 API 模型和本地模型的自定义推理器。

    该类同时继承 BaseApiInferencer 和 BaseLocalInferencer,
    需要实现两个基类的必需方法。
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        mode: Optional[str] = "infer",
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        # 调用两个基类的初始化方法
        BaseApiInferencer.__init__(
            self,
            model_cfg=model_cfg,
            batch_size=batch_size,
            mode=mode,
            output_json_filepath=output_json_filepath,
            save_every=save_every,
            **kwargs,
        )

        # 初始化输出处理器
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=self.perf_mode,
            save_every=self.save_every
        )

    async def do_request(self, data: dict, token_bucket: BoundedSemaphore, session: aiohttp.ClientSession):
        """API 模型的推理方法(必需)"""
        # 实现 API 模型的推理逻辑
        pass

    def batch_inference(self, datum: dict) -> None:
        """本地模型的推理方法(必需)"""
        # 实现本地模型的推理逻辑
        pass

    def get_data_list(self, retriever: BaseRetriever) -> List:
        """获取数据列表(必需)"""
        # 实现数据列表获取逻辑
        pass

详细实现可参考:GenInferencer

在配置文件中使用自定义推理器

定义好自定义推理器后,需要在数据集配置文件中使用它。在 ais_bench/benchmark/configs/datasets 下的相应配置文件中,将 infer_cfg 中的 inferencer 类型设置为自定义推理器类:

from ais_bench.benchmark.openicl.icl_inferencer import MyCustomInferencer
from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate
from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever

# 推理配置
mydataset_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role="HUMAN",
                    prompt="{question}\nRemember to put your final answer within \\boxed{}.",
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),      # 检索器配置
    inferencer=dict(type=MyCustomInferencer), # 自定义推理器配置
)

# 数据集配置列表
mydataset_datasets = [
    dict(
        type=MyDataset,                    # 自定义数据集类名
        # ... 其他数据集初始化参数 ...
        reader_cfg=mydataset_reader_cfg,   # 数据集读取配置
        infer_cfg=mydataset_infer_cfg,     # 推理配置(包含自定义推理器)
        eval_cfg=mydataset_eval_cfg        # 精度评估配置
    )
]

注意事项

  1. 注册装饰器:自定义推理器必须使用 @ICL_INFERENCERS.register_module() 装饰器进行注册,才能被配置系统识别。

  2. 输出处理器:根据实际需求选择合适的输出处理器,常用的有:

    • GenInferencerOutputHandler:用于生成式任务的输出处理

    • PPLInferencerOutputHandler:用于困惑度评估的输出处理

  3. 状态管理:对于 API 模型推理器,需要注意:

    • 使用 status_counter 来跟踪请求状态(post、rev、failed、finish、case_finish)

    • do_request 方法中正确更新状态计数器

  4. 错误处理:在推理过程中应该妥善处理异常情况,确保输出结果中包含错误信息,便于后续分析和调试。

  5. 性能模式:如果推理器需要支持性能测评(mode="perf"),需要确保:

    • API 模型推理器必须实现 parse_stream_response 接口(在模型类中)

    • 正确设置 perf_mode 标志

    • 使用 RequestOutput 来保存性能相关的指标

  6. 数据格式get_data_list 方法返回的数据列表中的每个字典必须包含以下必需字段:

    • prompt:输入提示词

    • index:数据索引

    • data_abbr:数据集标识

    • max_out_len:最大输出长度

    • gold:标准答案(可选)