Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ais_bench/benchmark/cli/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need, fill_test_range_use_num_prompts

class CustomConfigChecker:
MODEL_REQUIRED_FIELDS = ['type', 'abbr', 'attr']
DATASET_REQUIRED_FIELDS = ['type', 'abbr', 'reader_cfg', 'infer_cfg', 'eval_cfg']
MODEL_REQUIRED_FIELDS = ['abbr']
DATASET_REQUIRED_FIELDS = ['abbr']
SUMMARIZER_REQUIRED_FIELDS = ['attr']

def __init__(self, config, file_path):
Expand Down Expand Up @@ -106,6 +106,8 @@ def load_config(self, workflow):

def _fill_dataset_configs(self):
for dataset_cfg in self.cfg["datasets"]:
if dataset_cfg.get("infer_cfg", None) is None:
continue
fill_test_range_use_num_prompts(self.cfg["cli_args"].get("num_prompts"), dataset_cfg)
fill_model_path_if_datasets_need(self.cfg["models"][0], dataset_cfg)
retriever_cfg = dataset_cfg["infer_cfg"]["retriever"]
Expand Down
2 changes: 2 additions & 0 deletions ais_bench/benchmark/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
logger = AISLogger()

def get_config_type(obj) -> str:
if obj is None:
return None
Comment on lines +19 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While adding a None check is a good improvement for robustness, the function's return type hint -> str on line 18 is now incorrect because the function can return None. Please update the signature to -> Optional[str] to accurately reflect its behavior. You will also need to add from typing import Optional at the top of the file.

if isinstance(obj, str):
return obj
return f"{obj.__module__}.{obj.__name__}"
Expand Down
38 changes: 31 additions & 7 deletions ais_bench/benchmark/cli/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ais_bench.benchmark.partitioners import NaivePartitioner
from ais_bench.benchmark.runners import LocalRunner
from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask
from ais_bench.benchmark.tasks.base import EmptyTask
from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer
from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator
from ais_bench.benchmark.cli.utils import clear_repeat_tasks
Expand All @@ -26,6 +27,7 @@
class BaseWorker(ABC):
def __init__(self, args) -> None:
self.args = args
self.skip = False

@abstractmethod
def update_cfg(self, cfg: ConfigDict) -> None:
Expand All @@ -39,21 +41,29 @@ def do_work(self, cfg: ConfigDict):


class Infer(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
def get_task_type() -> str:
if cfg["models"][0]["attr"] == "service":
return get_config_type(OpenICLApiInferTask)
else:
return get_config_type(OpenICLInferTask)

custom_infer = cfg.get("infer")
custom_task = None
if custom_infer:
custom_task = custom_infer.get("runner", {}).get("task", {}).get("type")
if custom_task == EmptyTask:
self.skip = True
return cfg

new_cfg = dict(
infer=dict(
partitioner=dict(type=get_config_type(NaivePartitioner)),
runner=dict(
max_num_workers=self.args.max_num_workers,
max_workers_per_gpu=self.args.max_workers_per_gpu,
debug=self.args.debug,
task=dict(type=get_task_type()),
task=dict(type=get_config_type(custom_task) if custom_task else get_task_type()),
type=get_config_type(LocalRunner),
),
),
Expand All @@ -66,6 +76,9 @@ def get_task_type() -> str:
return cfg

def do_work(self, cfg: ConfigDict):
if self.skip:
logger.info("EmptyTask is selected, skip inference.")
return
partitioner = PARTITIONERS.build(cfg.infer.partitioner)
logger.info("Starting inference tasks...")
tasks = partitioner(cfg)
Expand Down Expand Up @@ -119,7 +132,7 @@ def __init__(self, args) -> None:
super().__init__(args)
self.judge_model_type = None

def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
for dataset_cfg in cfg["datasets"]:
judge_infer_cfg = dataset_cfg.get("judge_infer_cfg")
if judge_infer_cfg:
Expand Down Expand Up @@ -276,20 +289,28 @@ def _result_post_process(self, tasks, cfg: ConfigDict):


class Eval(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
custom_eval = cfg.get("eval")
custom_task = None
if custom_eval:
custom_task = custom_eval.get("runner", {}).get("task", {}).get("type")
if custom_task == EmptyTask:
self.skip = True
return cfg

new_cfg = dict(
eval=dict(
partitioner=dict(type=get_config_type(NaivePartitioner)),
runner=dict(
max_num_workers=self.args.max_num_workers,
max_workers_per_gpu=self.args.max_workers_per_gpu,
debug=self.args.debug,
task=dict(type=get_config_type(OpenICLEvalTask)),
task=dict(type=get_config_type(custom_task) if custom_task else get_config_type(OpenICLEvalTask)),
type=get_config_type(LocalRunner),
),
),
)

new_cfg["eval"]["runner"]["type"] = get_config_type(LocalRunner)
new_cfg["eval"]["runner"]["max_workers_per_gpu"] = self.args.max_workers_per_gpu
cfg.merge_from_dict(new_cfg)
if cfg.cli_args.dump_eval_details:
cfg.eval.runner.task.dump_details = True
Expand All @@ -301,6 +322,9 @@ def update_cfg(self, cfg: ConfigDict) -> None:
return cfg

def do_work(self, cfg: ConfigDict):
if self.skip:
logger.info("EmptyTask is selected, skip evaluation.")
return
partitioner = PARTITIONERS.build(cfg.eval.partitioner)
logger.info("Starting evaluation tasks...")
self._cfg_pre_process(cfg)
Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/partitioners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def _check_task_cfg(self, tasks):
filtered_tasks = []
for task in tasks:
mode = task.get("cli_args", {}).get("mode")
dataset_type = task["datasets"][0][0]["type"]
model_type = task["models"][0]["type"]
dataset_type = task["datasets"][0][0].get("type", None)
model_type = task["models"][0].get("type", None)
if mode not in ["perf", "perf_viz"] and dataset_type in ONLY_PERF_DATASETS:
self.logger.warning(
f"'{dataset_type}' can only be used for performance evaluation, "
Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/partitioners/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class NaivePartitioner(BasePartitioner):
"""

def __init__(self,
out_dir: str,
out_dir: str = '',
n: int = 1,
keep_keys: Optional[List[str]] = None):
super().__init__(out_dir=out_dir, keep_keys=keep_keys)
Expand All @@ -33,7 +33,7 @@ def partition(self,
model_dataset_combinations: List[Dict[str,
List[ConfigDict]]],
work_dir: str,
out_dir: str,
out_dir: str = '',
add_cfg: Dict = {}) -> List[Dict]:
"""Partition model-dataset pairs into tasks. Each task is defined as a
dict and will run independently as a unit. Its structure is as
Expand Down
2 changes: 1 addition & 1 deletion ais_bench/benchmark/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def register_module(

PARTITIONERS = Registry('partitioner', locations=get_locations('partitioners'))
RUNNERS = Registry('runner', locations=get_locations('runners'))
TASKS = Registry('task', locations=get_locations('tasks'))
TASKS = Registry('task', locations=get_locations('tasks') + get_locations('tasks.custom_tasks'))
MODELS = Registry('model', locations=get_locations('models'))
# TODO: LOAD_DATASET -> DATASETS
LOAD_DATASET = Registry('load_dataset', locations=get_locations('datasets'))
Expand Down
30 changes: 26 additions & 4 deletions ais_bench/benchmark/summarizers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,29 +286,51 @@ def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode, requ
elif isinstance(item, (list, tuple)):
summarizer_dataset_abbrs.append((item[0], item[1]))

has_total_count = False
for dataset_abbr in dataset_metrics:
if 'total_count' in dataset_metrics[dataset_abbr]:
has_total_count = True
break

table = []
header = ['dataset', 'version', 'metric', 'mode'] + self.model_abbrs
if has_total_count:
header = ['dataset', 'version', 'metric', 'mode', 'total_count'] + self.model_abbrs
table.append(header)
for dataset_abbr, metric in summarizer_dataset_abbrs:
if dataset_abbr not in dataset_metrics:
if not skip_all_slash:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
if has_total_count:
table.append([dataset_abbr, '-', '-', '-', '-'] + ['-'] * len(self.model_abbrs))
else:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
continue
if metric is None:
metric = dataset_metrics[dataset_abbr][0]
elif metric in dataset_metrics[dataset_abbr]:
pass
else:
if not skip_all_slash:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
if has_total_count:
table.append([dataset_abbr, '-', '-', '-', '-'] + ['-'] * len(self.model_abbrs))
else:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
continue

total_count_value = '/'
if 'total_count' in dataset_metrics[dataset_abbr]:
first_model_abbr = self.model_abbrs[0]
if dataset_abbr in parsed_results[first_model_abbr] and 'total_count' in parsed_results[first_model_abbr][dataset_abbr]:
total_count_value = str(int(parsed_results[first_model_abbr][dataset_abbr]['total_count']))

row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric, dataset_eval_mode.get(dataset_abbr, '-')]
if has_total_count:
row.append(total_count_value)
for model_abbr in self.model_abbrs:
if dataset_abbr in parsed_results[model_abbr]:
row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][metric]))
correct_count = parsed_results[model_abbr][dataset_abbr].pop('correct_count', None)
total_count = parsed_results[model_abbr][dataset_abbr].pop('total_count', None)
correct_count = parsed_results[model_abbr][dataset_abbr].get('correct_count', None)
total_count = parsed_results[model_abbr][dataset_abbr].get('total_count', None)
if correct_count is not None and total_count is not None:
row[-1] = str(row[-1]) + f' ({correct_count}/{total_count})'
else:
Expand Down
8 changes: 8 additions & 0 deletions ais_bench/benchmark/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def get_output_paths(self, file_extension: str = "json") -> List[str]:
return output_paths


class EmptyTask(BaseTask):
    """A no-op task: selecting it lets a pipeline stage be skipped entirely."""

    def run(self):
        """Perform no work; the task exists only as a skip marker."""
        return None

    def get_command(self, cfg_path, template) -> str:
        """Return an empty command string — there is nothing to launch."""
        return ""


class TaskStateManager:
def __init__(self, tmp_path: str, task_name: str, is_debug: bool, refresh_interval: int = 0.5):
self.logger = AISLogger()
Expand Down
Empty file.
Loading
Loading