nano-graphrag NO.3
用于提取发现的摘要。如果发现是字符串,则直接返回;否则返回其摘要字段。
Splitter.py
这个比较简单,就一个类,就是用来给文本分块。
解释一下chunk_overlap就是在分的时候为了让部分之间具有相关性,会重叠一部分文字。
通过分隔符分为子列再根据chunk_size划分为块,设置重叠部分后,如果有超出chunk_size的切断
主要方法
-
split_tokens:- 接受一个整数列表
tokens,并返回经过分隔符分割和合并后的子列表。
- 接受一个整数列表
-
_split_tokens_with_separators:- 该方法根据指定的分隔符将令牌列表分割成多个子列表。
- 它遍历令牌列表,查找分隔符并根据
keep_separator的值决定是否保留分隔符。 - 返回一个包含分割结果的列表。
-
_merge_splits:- 将分割后的子列表合并成块,确保每个块的大小不超过
chunk_size。 - 如果合并后的块数量为 1 且其大小超过
chunk_size,则调用_split_chunk方法进一步分割。 - 如果设置了重叠,则调用
_enforce_overlap方法以确保块之间的重叠。
- 将分割后的子列表合并成块,确保每个块的大小不超过
-
_split_chunk:- 将一个超出
chunk_size的块进一步分割成多个小块,确保每个小块的大小不超过chunk_size,并考虑重叠。
- 将一个超出
class SeparatorSplitter:
def __init__(
self,
separators: Optional[List[List[int]]] = None,
keep_separator: Union[bool, Literal["start", "end"]] = "end",
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: callable = len,
):
self._separators = separators or []
self._keep_separator = keep_separator
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
def split_tokens(self, tokens: List[int]) -> List[List[int]]:
splits = self._split_tokens_with_separators(tokens)
return self._merge_splits(splits)
def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
splits = []
current_split = []
i = 0
while i < len(tokens):
separator_found = False
for separator in self._separators:
if tokens[i:i+len(separator)] == separator:
if self._keep_separator in [True, "end"]:
current_split.extend(separator)
if current_split:
splits.append(current_split)
current_split = []
if self._keep_separator == "start":
current_split.extend(separator)
i += len(separator)
separator_found = True
break
if not separator_found:
current_split.append(tokens[i])
i += 1
if current_split:
splits.append(current_split)
return [s for s in splits if s]
def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
if not splits:
return []
merged_splits = []
current_chunk = []
for split in splits:
if not current_chunk:
current_chunk = split
elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
current_chunk.extend(split)
else:
merged_splits.append(current_chunk)
current_chunk = split
if current_chunk:
merged_splits.append(current_chunk)
if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
return self._split_chunk(merged_splits[0])
if self._chunk_overlap > 0:
return self._enforce_overlap(merged_splits)
return merged_splits
def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
result = []
for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
new_chunk = chunk[i:i + self._chunk_size]
if len(new_chunk) > self._chunk_overlap: # 只有当 chunk 长度大于 overlap 时才添加
result.append(new_chunk)
return result
def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
result = []
for i, chunk in enumerate(chunks):
if i == 0:
result.append(chunk)
else:
overlap = chunks[i-1][-self._chunk_overlap:]
new_chunk = overlap + chunk
if self._length_function(new_chunk) > self._chunk_size:
new_chunk = new_chunk[:self._chunk_size]
result.append(new_chunk)
return result
op.py
大头来了。一千多行的代码。开始了。。。
chunking_by_token_size
这个函数应该很好理解,就是对文本分块,其中results存储了各个块的内容,id,描述,长度,文件id,这个分块是根据max_size分的
def chunking_by_token_size(
tokens_list: list[list[int]],
doc_keys,
tiktoken_model,
overlap_token_size=128,
max_token_size=1024,
):
results = []
for index, tokens in enumerate(tokens_list):
chunk_token = []
lengths = []
for start in range(0, len(tokens), max_token_size - overlap_token_size):
chunk_token.append(tokens[start : start + max_token_size])
lengths.append(min(max_token_size, len(tokens) - start))
# here somehow tricky, since the whole chunk tokens is list[list[list[int]]] for corpus(doc(chunk)),so it can't be decode entirely
chunk_token = tiktoken_model.decode_batch(chunk_token)
for i, chunk in enumerate(chunk_token):
results.append(
{
"tokens": lengths[i],
"content": chunk.strip(),
"chunk_order_index": i,
"full_doc_id": doc_keys[index],
}
)
return results
chunking_by_seperators
这是根据分隔符分的,并且文本之间有重叠,我理解就是现根据数量划分,然后要理解文字,肯定要考虑分隔符,就再加一个划分的函数,并且保证重叠部分。
虽然是根据分割符划分完还合并,但是我觉得是对这个块内的序列,而不是对所有文本每个序列。
def chunking_by_seperators(
tokens_list: list[list[int]],
doc_keys,
tiktoken_model,
overlap_token_size=128,
max_token_size=1024,
):
splitter = SeparatorSplitter(
separators=[
tiktoken_model.encode(s) for s in PROMPTS["default_text_separator"]
],
chunk_size=max_token_size,
chunk_overlap=overlap_token_size,
)
results = []
for index, tokens in enumerate(tokens_list):
chunk_token = splitter.split_tokens(tokens)
lengths = [len(c) for c in chunk_token]
# here somehow tricky, since the whole chunk tokens is list[list[list[int]]] for corpus(doc(chunk)),so it can't be decode entirely
chunk_token = tiktoken_model.decode_batch(chunk_token)
for i, chunk in enumerate(chunk_token):
results.append(
{
"tokens": lengths[i],
"content": chunk.strip(),
"chunk_order_index": i,
"full_doc_id": doc_keys[index],
}
)
return results
get_chunks
函数的主要作用是从一组新的文档中提取和生成块(chunks),并返回一个包含这些块的字典。指定分块函数chunking_by_token_size,并允许传递额外的参数来定制分块行为。
函数参数
1. new_docs
- 一个字典,键是文档的唯一标识符,值是包含文档内容的字典(通常包含 `content` 字段)。
2. chunk_func
- 一个可调用的分块函数,默认为 `chunking_by_token_size`。该函数用于将令牌化的文档分割成块。
3. **chunk_func_params 注:**是不确定参数数量
- 额外的关键字参数,这些参数将传递给 `chunk_func`,以便定制分块的行为。
函数逻辑
1. **初始化插入块的字典**:
- `inserting_chunks` 用于存储生成的块,键是块的哈希 ID,值是块的详细信息。
2. **准备文档列表和键**:
- 使用 `list(new_docs.items())` 将 `new_docs` 转换为列表,便于遍历。
- 从文档中提取内容和对应的键:
- `docs` 列表包含所有文档的内容。
- `doc_keys` 列表包含所有文档的唯一标识符。
3. **编码文档**:
- 使用 `tiktoken` 库中的 `encoding_for_model` 函数获取适用于指定模型(在这里是 `"gpt-4o"`)的编码器。
- 使用 `ENCODER.encode_batch(docs, num_threads=16)` 将文档内容批量编码为令牌,支持多线程处理以提高效率。
4. **调用分块函数**:
- 调用 `chunk_func`(默认为 `chunking_by_token_size`),将编码后的令牌、文档键和编码器传递给它,并传递任何额外的参数。
5. **构建插入块字典**:
- 遍历生成的块,将每个块的内容通过 `compute_mdhash_id` 函数计算出一个唯一的哈希 ID,并将其与块的详细信息一起存储在 `inserting_chunks` 字典中。
6. **返回结果**:
- 最后,返回包含所有生成块的字典 `inserting_chunks`。
返回结果的结构
- 返回的 `inserting_chunks` 字典的结构如下:
- 键是通过 `compute_mdhash_id` 生成的唯一哈希 ID(以 `"chunk-"` 为前缀)。
- 值是包含块信息的字典,通常包括块的内容和其他相关信息。
总结:就是从文档里提取文字转成块,编码内容,生成块的哈希ID
def get_chunks(new_docs, chunk_func=chunking_by_token_size, **chunk_func_params):
inserting_chunks = {}
new_docs_list = list(new_docs.items())
docs = [new_doc[1]["content"] for new_doc in new_docs_list]
doc_keys = [new_doc[0] for new_doc in new_docs_list]
ENCODER = tiktoken.encoding_for_model("gpt-4o")
tokens = ENCODER.encode_batch(docs, num_threads=16)
chunks = chunk_func(
tokens, doc_keys=doc_keys, tiktoken_model=ENCODER, **chunk_func_params
)
for chunk in chunks:
inserting_chunks.update(
{compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk}
)
return inserting_chunks
_handle_entity_relation_summary
是一个异步函数,旨在生成实体或关系描述的摘要。使用大模型处理description并生成摘要。
函数参数
1. `entity_or_relation_name`:
- 一个字符串,表示要生成摘要的实体或关系的名称。
函数逻辑
1. **提取配置参数**:
- 从 `global_config` 中提取以下参数:
- `use_llm_func`: 用于生成摘要的语言模型函数。
- `llm_max_tokens`: 语言模型允许的最大令牌数。
- `tiktoken_model_name`: 用于编码和解码的 tiktoken 模型名称。
- `summary_max_tokens`: 生成摘要时允许的最大令牌数。
2. **编码描述**:
- 使用 `encode_string_by_tiktoken` 函数将输入的 `description` 编码为令牌,使用指定的 tiktoken 模型。
3. **检查令牌长度**:
- 如果编码后的令牌数量小于 `summary_max_tokens`,则表示描述不需要摘要,直接返回原始描述。
4. **准备提示模板**:
- 从 `PROMPTS` 中获取用于生成摘要的提示模板(`summarize_entity_descriptions`)。
5. **解码令牌**:
- 使用 `decode_tokens_by_tiktoken` 函数解码前 `llm_max_tokens` 个令牌,以便将其作为上下文传递给语言模型。
6. **构建上下文字典**:
- 创建一个字典 `context_base`,包含以下信息:
- `entity_name`: 实体或关系的名称。
- `description_list`: 解码后的描述字符串,按 `GRAPH_FIELD_SEP` 分隔。
7. **格式化提示**:
- 使用 `prompt_template` 和 `context_base` 中的内容格式化生成最终的提示(`use_prompt`)。
8. **记录调试信息**:
- 使用 `logger.debug` 记录触发摘要生成的实体或关系名称。
9. **调用语言模型生成摘要**:
- 使用 `await` 调用 `use_llm_func` 函数,传入格式化后的提示和最大令牌数,生成摘要。
10. **返回摘要**:
- 返回生成的摘要。
总结:根据实体和描述编码后输入模型,模型根据提示给出回答,再将回答解码,填入模板
async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
global_config: dict,
) -> str:
use_llm_func: callable = global_config["cheap_model_func"]
llm_max_tokens = global_config["cheap_model_max_token_size"]
tiktoken_model_name = global_config["tiktoken_model_name"]
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
if len(tokens) < summary_max_tokens: # No need for summary
return description
prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = decode_tokens_by_tiktoken(
tokens[:llm_max_tokens], model_name=tiktoken_model_name
)
context_base = dict(
entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP),
)
use_prompt = prompt_template.format(**context_base)
logger.debug(f"Trigger summary: {entity_or_relation_name}")
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
return summary
_handle_single_entity_extraction
处理单个实体的情况就是清理一下无用信息,构建词典传入参数
async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
if not entity_name.strip():
return None
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
entity_source_id = chunk_key
return dict(
entity_name=entity_name,
entity_type=entity_type,
description=entity_description,
source_id=entity_source_id,
)
同理
async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_source_id = chunk_key
weight = (
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
)
return dict(
src_id=source,
tgt_id=target,
weight=weight,
description=edge_description,
source_id=edge_source_id,
)
_merge_nodes_then_upsert
检查边是否存在
- 检查边的存在性: 使用
has_edge方法检查从src_id到tgt_id的边是否存在。 - 获取已存在的边数据: 如果边存在,调用
get_edge方法获取边的详细信息,并将其属性(如权重、源 ID、描述和顺序)存储到相应的列表中。
计算新边的属性
- 顺序: 通过
min函数计算新边的顺序,确保它是所有边数据中的最小值。 - 权重: 计算新边的总权重,包含新边数据和已存在边的权重。
- 描述: 使用
join和set组合新边数据和已存在边的描述,确保描述唯一且排序。
节点插入
- 检查节点存在性: 对于源节点和目标节点,检查它们是否存在于知识图谱中。
- 插入节点: 如果节点不存在,调用
upsert_node方法插入节点,传入节点的相关数据(如源 ID、描述和实体类型)。
更新边
- 处理描述: 调用
_handle_entity_relation_summary函数处理描述 - 插入或更新边: 使用
upsert_edge方法插入或更新边,传入计算得到的权重、描述、源 ID 和顺序。
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
knwoledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entitiy_types = []
already_source_ids = []
already_description = []
already_node = await knwoledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
)
already_description.append(already_node["description"])
entity_type = sorted(
Counter(
[dp["entity_type"] for dp in nodes_data] + already_entitiy_types
).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in nodes_data] + already_description))
)
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
)
description = await _handle_entity_relation_summary(
entity_name, description, global_config
)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=source_id,
)
await knwoledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
node_data["entity_name"] = entity_name
return node_data
extract_entities
使用大llm根据提示词提取实体,共3次
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
global_config: dict,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["best_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
ordered_chunks = list(chunks.items())
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
)
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
already_processed = 0
already_entities = 0
already_relations = 0
_process_single_content
结合上下文和输入成hint_prompt输入到达模型中,并将模型回答打包成历史信息,反复提取实体,将结果提取到历史信息。使用 if_loop_prompt 提示 LLM 是否继续提取,结果存储在 if_loop_result 中。如果结果不是 "yes",则退出循环。这里的yes是跟用户交互得到的。
感觉这是核心部分
- 正则表达式匹配: 对每条记录应用正则表达式,提取括号内的内容。
- 提取实体: 调用
_handle_single_entity_extraction函数处理记录属性,如果返回的实体不为None,则将其添加到maybe_nodes中。 - 提取关系: 调用
_handle_single_relationship_extraction函数处理记录属性,如果返回的关系不为None,则将其添加到maybe_edges中。 -
异步处理多个内容块
- 并发处理: 使用
asyncio.gather并发处理所有内容块,调用_process_single_content函数。 - 合并结果: 初始化
maybe_nodes和maybe_edges,将所有结果合并到这两个字典中。 - 合并节点: 使用
asyncio.gather并发调用_merge_nodes_then_upsert函数,将所有节点合并并更新到知识图谱中。 - 合并边: 使用
asyncio.gather并发调用_merge_edges_then_upsert函数,将所有边合并并更新到知识图谱中。
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
nonlocal already_processed, already_entities, already_relations
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
final_result = await use_llm_func(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await use_llm_func(continue_prompt, history_messages=history)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
final_result += glean_result
if now_glean_index == entity_extract_max_gleaning - 1:
break
if_loop_result: str = await use_llm_func(
if_loop_prompt, history_messages=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
records = split_string_by_multi_markers(
final_result,
[context_base["record_delimiter"], context_base["completion_delimiter"]],
)
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
record = record.group(1)
record_attributes = split_string_by_multi_markers(
record, [context_base["tuple_delimiter"]]
)
if_entities = await _handle_single_entity_extraction(
record_attributes, chunk_key
)
if if_entities is not None:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
if_relation = await _handle_single_relationship_extraction(
record_attributes, chunk_key
)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
if_relation
)
already_processed += 1
already_entities += len(maybe_nodes)
already_relations += len(maybe_edges)
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
)
return dict(maybe_nodes), dict(maybe_edges)
# use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
results = await asyncio.gather(
*[_process_single_content(c) for c in ordered_chunks]
)
print() # clear the progress bar
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for m_nodes, m_edges in results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
# it's undirected graph
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
await asyncio.gather(
*[
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
if not len(all_entities_data):
logger.warning("Didn't extract any entities, maybe your LLM is not working")
return None
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
await entity_vdb.upsert(data_for_vdb)
return knwoledge_graph_inst
_pack_single_community_by_sub_communities
排序子社区: 根据每个子社区的 occurrence 属性对子社区进行降序排序
截断子社区列表: 调用 truncate_list_by_token_size 函数,限制子社区列表的长度,以确保其 report_string 的总令牌数不超过 max_token_size。
转换为CSV格式
该函数的主要功能是处理一个社区及其子社区的信息,生成描述并提取相关的节点和边。限制输出的长度和格式化数据。
def _pack_single_community_by_sub_communities(
community: SingleCommunitySchema,
max_token_size: int,
already_reports: dict[str, CommunitySchema],
) -> tuple[str, int]:
# TODO
all_sub_communities = [
already_reports[k] for k in community["sub_communities"] if k in already_reports
]
all_sub_communities = sorted(
all_sub_communities, key=lambda x: x["occurrence"], reverse=True
)
may_trun_all_sub_communities = truncate_list_by_token_size(
all_sub_communities,
key=lambda x: x["report_string"],
max_token_size=max_token_size,
)
sub_fields = ["id", "report", "rating", "importance"]
sub_communities_describe = list_of_list_to_csv(
[sub_fields]
+ [
[
i,
c["report_string"],
c["report_json"].get("rating", -1),
c["occurrence"],
]
for i, c in enumerate(may_trun_all_sub_communities)
]
)
already_nodes = []
already_edges = []
for c in may_trun_all_sub_communities:
already_nodes.extend(c["nodes"])
already_edges.extend([tuple(e) for e in c["edges"]])
return (
sub_communities_describe,
len(encode_string_by_tiktoken(sub_communities_describe)),
set(already_nodes),
set(already_edges),
)
_pack_single_community_describe
升序排列节点,根据边的key升序排列还是咋
-
排序节点: 将社区中的节点按字母顺序排序
nodes_in_order = sorted(community["nodes"]) -
排序边: 将社区中的边按源和目标节点的组合排序。大概率是字符拼接。
edges_in_order = sorted(community["edges"], key=lambda x: x[0] + x[1])
并发处理多个模块
插入节点,根据src和tgt源节点和目标节点构成边
-
排序节点列表: 根据节点的度数对节点列表进行降序排序。
nodes_list_data = sorted(nodes_list_data, key=lambda x: x[-1], reverse=True)
构建节点的列表,序号,名字,类型描述,等插入之后获取?await函数的作用?
节点根据描述长度还是啥倒序排列
划分数据
总结:描述一个社区及其相关节点和边的信息,并在必要时使用子社区来确保描述的完整性。异步获取数据和动态调整输出。
async def _pack_single_community_describe(
knwoledge_graph_inst: BaseGraphStorage,
community: SingleCommunitySchema,
max_token_size: int = 12000,
already_reports: dict[str, CommunitySchema] = {},
global_config: dict = {},
) -> str:
nodes_in_order = sorted(community["nodes"])
edges_in_order = sorted(community["edges"], key=lambda x: x[0] + x[1])
nodes_data = await asyncio.gather(
*[knwoledge_graph_inst.get_node(n) for n in nodes_in_order]
)
edges_data = await asyncio.gather(
*[knwoledge_graph_inst.get_edge(src, tgt) for src, tgt in edges_in_order]
)
node_fields = ["id", "entity", "type", "description", "degree"]
edge_fields = ["id", "source", "target", "description", "rank"]
nodes_list_data = [
[
i,
node_name,
node_data.get("entity_type", "UNKNOWN"),
node_data.get("description", "UNKNOWN"),
await knwoledge_graph_inst.node_degree(node_name),
]
for i, (node_name, node_data) in enumerate(zip(nodes_in_order, nodes_data))
]
nodes_list_data = sorted(nodes_list_data, key=lambda x: x[-1], reverse=True)
nodes_may_truncate_list_data = truncate_list_by_token_size(
nodes_list_data, key=lambda x: x[3], max_token_size=max_token_size // 2
)
edges_list_data = [
[
i,
edge_name[0],
edge_name[1],
edge_data.get("description", "UNKNOWN"),
await knwoledge_graph_inst.edge_degree(*edge_name),
]
for i, (edge_name, edge_data) in enumerate(zip(edges_in_order, edges_data))
]
edges_list_data = sorted(edges_list_data, key=lambda x: x[-1], reverse=True)
edges_may_truncate_list_data = truncate_list_by_token_size(
edges_list_data, key=lambda x: x[3], max_token_size=max_token_size // 2
)
truncated = len(nodes_list_data) > len(nodes_may_truncate_list_data) or len(
edges_list_data
) > len(edges_may_truncate_list_data)
# If context is exceed the limit and have sub-communities:
report_describe = ""
need_to_use_sub_communities = (
truncated and len(community["sub_communities"]) and len(already_reports)
)
force_to_use_sub_communities = global_config["addon_params"].get(
"force_to_use_sub_communities", False
)
if need_to_use_sub_communities or force_to_use_sub_communities:
logger.debug(
f"Community {community['title']} exceeds the limit or you set force_to_use_sub_communities to True, using its sub-communities"
)
report_describe, report_size, contain_nodes, contain_edges = (
_pack_single_community_by_sub_communities(
community, max_token_size, already_reports
)
)
report_exclude_nodes_list_data = [
n for n in nodes_list_data if n[1] not in contain_nodes
]
report_include_nodes_list_data = [
n for n in nodes_list_data if n[1] in contain_nodes
]
report_exclude_edges_list_data = [
e for e in edges_list_data if (e[1], e[2]) not in contain_edges
]
report_include_edges_list_data = [
e for e in edges_list_data if (e[1], e[2]) in contain_edges
]
# if report size is bigger than max_token_size, nodes and edges are []
nodes_may_truncate_list_data = truncate_list_by_token_size(
report_exclude_nodes_list_data + report_include_nodes_list_data,
key=lambda x: x[3],
max_token_size=(max_token_size - report_size) // 2,
)
edges_may_truncate_list_data = truncate_list_by_token_size(
report_exclude_edges_list_data + report_include_edges_list_data,
key=lambda x: x[3],
max_token_size=(max_token_size - report_size) // 2,
)
nodes_describe = list_of_list_to_csv([node_fields] + nodes_may_truncate_list_data)
edges_describe = list_of_list_to_csv([edge_fields] + edges_may_truncate_list_data)
return f"""-----Reports-----
```csv
{report_describe}
```
-----Entities-----
```csv
{nodes_describe}
```
-----Relationships-----
```csv
{edges_describe}'''"""
过滤节点和边
-
排除节点: 创建一个列表
report_exclude_nodes_list_data,其中包含所有不在contain_nodes集合中的节点。 -
包含节点: 创建一个列表
report_include_nodes_list_data,其中包含所有在contain_nodes集合中的节点。 -
排除边,包含边
-
截断,以限制长度
-
生成描述,返回格式化格式
_community_report_json_to_str
获取标题: 从 parsed_output 中获取标题,如果没有则默认为 "Report"。
定义发现摘要函数: 用于提取发现的摘要。如果发现是字符串,则直接返回;否则返回其摘要字段。
join 方法
"\n\n".join(...)是将生成器表达式的结果连接成一个单一的字符串。"\n\n"是连接符,表示在每个字符串之间插入两个换行符(即空行)。- 这意味着每个发现的摘要和解释之间将有一个空行,以便在最终的报告中更清晰地分隔每个部分。
def _community_report_json_to_str(parsed_output: dict) -> str:
"""refer official graphrag: index/graph/extractors/community_reports"""
title = parsed_output.get("title", "Report")
summary = parsed_output.get("summary", "")
findings = parsed_output.get("findings", [])
def finding_summary(finding: dict):
if isinstance(finding, str):
return finding
return finding.get("summary")
def finding_explanation(finding: dict):
if isinstance(finding, str):
return ""
return finding.get("explanation")
report_sections = "\n\n".join(
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
)
return f"# {title}\n\n{summary}\n\n{report_sections}"
async def generate_community_report(
community_report_kv: BaseKVStorage[CommunitySchema],
knwoledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]
use_llm_func: callable = global_config["best_model_func"]
use_string_json_convert_func: callable = global_config[
"convert_response_to_json_func"
]
community_report_prompt = PROMPTS["community_report"]
communities_schema = await knwoledge_graph_inst.community_schema()
community_keys, community_values = list(communities_schema.keys()), list(
communities_schema.values()
)
already_processed = 0
_from_single_community_report
为多个社区生成报告。它首先定义了一个异步函数 _form_single_community_report 来处理单个社区的报告生成,然后在主逻辑中按层级处理所有社区,使用并发方式生成报告,并最终将结果存储到指定的存储中。通过使用异步编程,代码能够高效地处理多个社区的报告生成任务。
async def _form_single_community_report(
community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]
):
nonlocal already_processed
describe = await _pack_single_community_describe(
knwoledge_graph_inst,
community,
max_token_size=global_config["best_model_max_token_size"],
already_reports=already_reports,
global_config=global_config,
)
prompt = community_report_prompt.format(input_text=describe)
response = await use_llm_func(prompt, **llm_extra_kwargs)
data = use_string_json_convert_func(response)
already_processed += 1
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed} communities\r",
end="",
flush=True,
)
return data
levels = sorted(set([c["level"] for c in community_values]), reverse=True)
logger.info(f"Generating by levels: {levels}")
community_datas = {}
for level in levels:
this_level_community_keys, this_level_community_values = zip(
*[
(k, v)
for k, v in zip(community_keys, community_values)
if v["level"] == level
]
)
this_level_communities_reports = await asyncio.gather(
*[
_form_single_community_report(c, community_datas)
for c in this_level_community_values
]
)
community_datas.update(
{
k: {
"report_string": _community_report_json_to_str(r),
"report_json": r,
**v,
}
for k, r, v in zip(
this_level_community_keys,
this_level_communities_reports,
this_level_community_values,
)
}
)
print() # clear the progress bar
await community_report_kv.upsert(community_datas)
_find_most_related_text_unit_from_entities
提取文本单元,获取与源节点关联的目标节点进行计数
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in node_datas
]
edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
)
all_one_hop_nodes = set()
for this_edges in edges:
if not this_edges:
continue
all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes)
all_one_hop_nodes_data = await asyncio.gather(
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
)
all_one_hop_text_units_lookup = {
k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
if v is not None
}
all_text_units_lookup = {}
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
for c_id in this_text_units:
if c_id in all_text_units_lookup:
continue
relation_counts = 0
for e in this_edges:
if (
e[1] in all_one_hop_text_units_lookup
and c_id in all_one_hop_text_units_lookup[e[1]]
):
relation_counts += 1
all_text_units_lookup[c_id] = {
"data": await text_chunks_db.get_by_id(c_id),
"order": index,
"relation_counts": relation_counts,
}
if any([v is None for v in all_text_units_lookup.values()]):
logger.warning("Text chunks are missing, maybe the storage is damaged")
all_text_units = [
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
]
all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
)
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.local_max_token_for_text_unit,
)
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
return all_text_units
_find_most_related_edges_from_entities
从给定的节点数据中查找与之相关的边。它通过获取节点的边信息,收集唯一的边,获取边的详细信息和度数,整理并排序边的数据,最后返回相关的边数据。
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
all_related_edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
)
all_edges = []
seen = set()
for this_edges in all_related_edges:
for e in this_edges:
sorted_edge = tuple(sorted(e))
if sorted_edge not in seen:
seen.add(sorted_edge)
all_edges.append(sorted_edge)
all_edges_pack = await asyncio.gather(
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
)
all_edges_degree = await asyncio.gather(
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
)
all_edges_data = [
{"src_tgt": k, "rank": d, **v}
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
if v is not None
]
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
all_edges_data = truncate_list_by_token_size(
all_edges_data,
key=lambda x: x["description"],
max_token_size=query_param.local_max_token_for_local_context,
)
return all_edges_data
_build_local_query_context
构建一个本地查询上下文,整合与查询相关的实体、关系、社区报告和文本单元。它通过查询向量数据库和知识图谱,获取相关信息,并将其格式化为 CSV 格式,最终返回一个包含所有相关信息的字符串。这种结构化的返回格式便于后续处理和展示
async def _build_local_query_context(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
community_reports: BaseKVStorage[CommunitySchema],
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
):
results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return None
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
)
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
)
node_datas = [
{**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
]
use_communities = await _find_most_related_community_from_entities(
node_datas, query_param, community_reports
)
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
use_relations = await _find_most_related_edges_from_entities(
node_datas, query_param, knowledge_graph_inst
)
logger.info(
f"Using {len(node_datas)} entites, {len(use_communities)} communities, {len(use_relations)} relations, {len(use_text_units)} text units"
)
entites_section_list = [["id", "entity", "type", "description", "rank"]]
for i, n in enumerate(node_datas):
entites_section_list.append(
[
i,
n["entity_name"],
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
n["rank"],
]
)
entities_context = list_of_list_to_csv(entites_section_list)
relations_section_list = [
["id", "source", "target", "description", "weight", "rank"]
]
for i, e in enumerate(use_relations):
relations_section_list.append(
[
i,
e["src_tgt"][0],
e["src_tgt"][1],
e["description"],
e["weight"],
e["rank"],
]
)
relations_context = list_of_list_to_csv(relations_section_list)
communities_section_list = [["id", "content"]]
for i, c in enumerate(use_communities):
communities_section_list.append([i, c["report_string"]])
communities_context = list_of_list_to_csv(communities_section_list)
text_units_section_list = [["id", "content"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return f"""
-----Reports-----
```csv
{communities_context}
```
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
local_query函数: 处理用户查询,构建上下文并生成响应。它首先构建查询上下文,然后根据查询参数决定是否返回上下文或生成响应。_map_global_communities函数: 将社区数据分组,以便于后续处理。它根据查询参数的限制截断社区数据并将其分组。
async def local_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
community_reports: BaseKVStorage[CommunitySchema],
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
use_model_func = global_config["best_model_func"]
context = await _build_local_query_context(
query,
knowledge_graph_inst,
entities_vdb,
community_reports,
text_chunks_db,
query_param,
)
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["local_rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
return response
async def _map_global_communities(
query: str,
communities_data: list[CommunitySchema],
query_param: QueryParam,
global_config: dict,
):
use_string_json_convert_func = global_config["convert_response_to_json_func"]
use_model_func = global_config["best_model_func"]
community_groups = []
while len(communities_data):
this_group = truncate_list_by_token_size(
communities_data,
key=lambda x: x["report_string"],
max_token_size=query_param.global_max_token_for_community_report,
)
community_groups.append(this_group)
communities_data = communities_data[len(this_group) :]
_process
处理被截断的社区数据,生成包含社区信息的响应。它首先将社区数据整理成一个 CSV 格式的字符串,然后构建系统提示,调用模型函数生成响应,并将响应转换为 JSON 格式,最后提取并返回 "points" 字段。通过并发处理多个社区组,代码能够高效地进行全局搜索。
async def _process(community_truncated_datas: list[CommunitySchema]) -> dict:
communities_section_list = [["id", "content", "rating", "importance"]]
for i, c in enumerate(community_truncated_datas):
communities_section_list.append(
[
i,
c["report_string"],
c["report_json"].get("rating", 0),
c["occurrence"],
]
)
community_context = list_of_list_to_csv(communities_section_list)
sys_prompt_temp = PROMPTS["global_map_rag_points"]
sys_prompt = sys_prompt_temp.format(context_data=community_context)
response = await use_model_func(
query,
system_prompt=sys_prompt,
**query_param.global_special_community_map_llm_kwargs,
)
data = use_string_json_convert_func(response)
return data.get("points", [])
logger.info(f"Grouping to {len(community_groups)} groups for global search")
responses = await asyncio.gather(*[_process(c) for c in community_groups])
return responses
全局检索
async def global_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
community_reports: BaseKVStorage[CommunitySchema],
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
community_schema = await knowledge_graph_inst.community_schema()
community_schema = {
k: v for k, v in community_schema.items() if v["level"] <= query_param.level
}
if not len(community_schema):
return PROMPTS["fail_response"]
use_model_func = global_config["best_model_func"]
sorted_community_schemas = sorted(
community_schema.items(),
key=lambda x: x[1]["occurrence"],
reverse=True,
)
sorted_community_schemas = sorted_community_schemas[
: query_param.global_max_consider_community
]
community_datas = await community_reports.get_by_ids(
[k[0] for k in sorted_community_schemas]
)
community_datas = [c for c in community_datas if c is not None]
community_datas = [
c
for c in community_datas
if c["report_json"].get("rating", 0) >= query_param.global_min_community_rating
]
community_datas = sorted(
community_datas,
key=lambda x: (x["occurrence"], x["report_json"].get("rating", 0)),
reverse=True,
)
logger.info(f"Revtrieved {len(community_datas)} communities")
map_communities_points = await _map_global_communities(
query, community_datas, query_param, global_config
)
final_support_points = []
for i, mc in enumerate(map_communities_points):
for point in mc:
if "description" not in point:
continue
final_support_points.append(
{
"analyst": i,
"answer": point["description"],
"score": point.get("score", 1),
}
)
final_support_points = [p for p in final_support_points if p["score"] > 0]
if not len(final_support_points):
return PROMPTS["fail_response"]
final_support_points = sorted(
final_support_points, key=lambda x: x["score"], reverse=True
)
final_support_points = truncate_list_by_token_size(
final_support_points,
key=lambda x: x["answer"],
max_token_size=query_param.global_max_token_for_community_report,
)
points_context = []
for dp in final_support_points:
points_context.append(
f"""----Analyst {dp['analyst']}----
Importance Score: {dp['score']}
{dp['answer']}
"""
)
points_context = "\n".join(points_context)
if query_param.only_need_context:
return points_context
sys_prompt_temp = PROMPTS["global_reduce_rag_response"]
response = await use_model_func(
query,
sys_prompt_temp.format(
report_data=points_context, response_type=query_param.response_type
),
)
return response
朴素检索
async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
):
use_model_func = global_config["best_model_func"]
results = await chunks_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size(
chunks,
key=lambda x: x["content"],
max_token_size=query_param.naive_max_token_for_text_unit,
)
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context:
return section
sys_prompt_temp = PROMPTS["naive_rag_response"]
sys_prompt = sys_prompt_temp.format(
content_data=section, response_type=query_param.response_type
)
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
return response
更多推荐

所有评论(0)