From d72d17143bac785d1e1a2985d9f1045478773512 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 20 Nov 2024 21:26:45 +0000 Subject: [PATCH 01/33] move commit from PR 8498 to here in order to fix DCO Signed-off-by: KuntaiDu Co-authored-by: ApostaC Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- .buildkite/test-pipeline.yaml | 12 + .../disagg_overhead_benchmark.sh | 150 ++++++++++ .../disagg_performance_benchmark.sh | 168 +++++++++++ .../disagg_prefill_proxy_server.py | 61 ++++ .../disagg_benchmarks/round_robin_proxy.py | 60 ++++ .../visualize_benchmark_results.py | 46 +++ .../kv_transfer/disagg_prefill_example.sh | 97 +++++++ tests/kv_transfer/disagg_test.py | 133 +++++++++ tests/kv_transfer/module_test.py | 64 +++++ tests/kv_transfer/test_lookup_buffer.py | 161 +++++++++++ tests/kv_transfer/test_lookup_buffer.sh | 3 + tests/kv_transfer/test_send_recv.py | 156 +++++++++++ tests/kv_transfer/test_send_recv.sh | 3 + vllm/config.py | 56 ++++ vllm/distributed/kv_transfer/__init__.py | 0 .../kv_transfer/kv_connector/__init__.py | 0 .../kv_transfer/kv_connector/base.py | 194 +++++++++++++ .../kv_transfer/kv_connector/factory.py | 14 + .../kv_connector/pynccl_connector/buffer.py | 240 ++++++++++++++++ .../pynccl_connector/connector.py | 258 +++++++++++++++++ .../kv_connector/pynccl_connector/pipe.py | 265 ++++++++++++++++++ .../kv_transfer/kv_transfer_agent.py | 83 ++++++ vllm/distributed/parallel_state.py | 28 +- vllm/distributed/utils.py | 12 + vllm/engine/arg_utils.py | 82 +++++- vllm/worker/model_runner.py | 98 ++++++- vllm/worker/worker.py | 12 +- vllm/worker/worker_base.py | 1 + 28 files changed, 2439 insertions(+), 18 deletions(-) create mode 100644 benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh create mode 100644 benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh create mode 100644 benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py create mode 100644 benchmarks/disagg_benchmarks/round_robin_proxy.py create mode 100644 benchmarks/disagg_benchmarks/visualize_benchmark_results.py create mode 100644 examples/kv_transfer/disagg_prefill_example.sh create mode 100644 tests/kv_transfer/disagg_test.py create mode 100644 tests/kv_transfer/module_test.py create mode 100644 tests/kv_transfer/test_lookup_buffer.py create mode 100644 tests/kv_transfer/test_lookup_buffer.sh create mode 100644 tests/kv_transfer/test_send_recv.py create mode 100644 tests/kv_transfer/test_send_recv.sh create mode 100644 vllm/distributed/kv_transfer/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/base.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/factory.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py create mode 100644 vllm/distributed/kv_transfer/kv_transfer_agent.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 501743c88759..aaa79a347e5a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -473,6 +473,18 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py +- label: Disaggregated Prefill Test # 4min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/parallel_state.py + - vllm/distributed/kv_transfer + - vllm/worker/worker_base.py + - vllm/worker/model_runner.py + commands: + - pytest -v -s kv_transfer/module_test.py + - pytest -v -s kv_transfer/disagg_test.py + - label: LoRA Long Context (Distributed) # 11min # This test runs llama 13B, so it is required to run on 4 GPUs. num_gpus: 4 diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 000000000000..e7d30001e850 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,150 @@ +#!/bin/bash + +# benchmark the overhead of disaggregated prefill. +# methodology: +# - send all request to prefill vLLM instance. It will buffer KV cache. +# - then send all request to decode instance. +# - The TTFT of decode instance is the overhead. + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill -f pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=10 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_producer \ + --kv-rank 0 \ + --kv-parallel-size 2 \ + --kv-buffer-size 1e10 & + + + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_consumer \ + --kv-rank 1 \ + --kv-parallel-size 2 \ + --kv-buffer-size 1e10 & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate "inf" + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate "$qps" + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 000000000000..2bd7aa14de5f --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,168 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +launch_chunked_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + wait_for_server 8100 + wait_for_server 8200 + python3 round_robin_proxy.py & + sleep 1 +} + + +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_producer \ + --kv-rank 0 \ + --kv-parallel-size 2 \ + --kv-buffer-size 5e9 & + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_consumer \ + --kv-rank 1 \ + --kv-parallel-size 2 \ + --kv-buffer-size 5e9 & + wait_for_server 8100 + wait_for_server 8200 + python3 disagg_prefill_proxy_server.py & + sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=100 + qps=$1 + prefix_len=50 + input_len=1024 + output_len=$2 + tag=$3 + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename "$tag"-qps-"$qps".json \ + --request-rate "$qps" + + sleep 2 + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx matplotlib aiohttp + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_output_len=6 + + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes + + python3 visualize_benchmark_results.py + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 000000000000..4058b1c0a3b7 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,61 @@ +import os + +import aiohttp +from quart import Quart, make_response, request + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + + +async def forward_request(url, data): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked( + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request['max_tokens'] = 1 + + # finish prefill + async for _ in forward_request('http://localhost:8100/v1/completions', + prefill_request): + continue + + # return decode + generator = forward_request('http://localhost:8200/v1/completions', + original_request_data) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == '__main__': + app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py new file mode 100644 index 000000000000..6eb5f6398007 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -0,0 +1,60 @@ +import asyncio +import itertools + +import aiohttp +from aiohttp import web + + +class RoundRobinProxy: + + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) + + async def handle_request(self, request): + target_port = next(self.port_cycle) + target_url = f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse(status=response.status, + headers=response.headers) + await resp.prepare(request) + + # Stream the response content + async for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) + + +async def main(): + proxy = RoundRobinProxy([8100, 8200]) + app = web.Application() + app.router.add_route('*', '/{path:.*}', proxy.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8000) + await site.start() + + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 000000000000..e59d8bb0e6c8 --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,46 @@ +import json + +import matplotlib.pyplot as plt +import pandas as pd + +if __name__ == "__main__": + + data = [] + for name in ['disagg_prefill', 'chunked_prefill']: + for qps in [2, 4, 6, 8]: + with open(f"results/{name}-qps-{qps}.json") as f: + x = json.load(f) + x['name'] = name + x['qps'] = qps + data.append(x) + + df = pd.DataFrame.from_dict(data) + dis_df = df[df['name'] == 'disagg_prefill'] + chu_df = df[df['name'] == 'chunked_prefill'] + + plt.style.use('bmh') + plt.rcParams['font.size'] = 20 + + for key in [ + 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', + 'median_itl_ms', 'p99_itl_ms' + ]: + + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot(dis_df['qps'], + dis_df[key], + label='disagg_prefill', + marker='o', + linewidth=4) + plt.plot(chu_df['qps'], + chu_df[key], + label='chunked_prefill', + marker='o', + linewidth=4) + ax.legend() + + ax.set_xlabel('QPS') + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f'results/{key}.png') + plt.close(fig) diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh new file mode 100644 index 000000000000..e6c9d17227c7 --- /dev/null +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. + +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + +# install quart first -- required for disagg prefill proxy serve +if python3 -c "import quart" &> /dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..." + python3 -m pip install quart +fi + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +# You can also adjust --kv-ip and --kv-port for distributed inference. + +# prefilling instance, which is the KV producer +CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 \ + --kv-connector PyNcclConnector \ + --kv-role kv_producer \ + --kv-rank 0 \ + --kv-parallel-size 2 & + +# decoding instance, which is the KV consumer +CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 \ + --kv-connector PyNcclConnector \ + --kv-role kv_consumer \ + --kv-rank 1 \ + --kv-parallel-size 2 & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (port 8100), change max_tokens +# to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM +# instance +python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +sleep 1 + +# serve two example requests +output1=$(curl -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "San Francisco is a", +"max_tokens": 10, +"temperature": 0 +}') + +output2=$(curl -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + + +# Cleanup commands, suppressing their output +pgrep pt_main_thread | xargs kill -9 > /dev/null 2>&1 +pkill -f python3 > /dev/null 2>&1 + +sleep 4 + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "Successfully finished 2 test requests!" +echo "" diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py new file mode 100644 index 000000000000..d86be2eaa5d0 --- /dev/null +++ b/tests/kv_transfer/disagg_test.py @@ -0,0 +1,133 @@ +import os +import subprocess +import sys +import time +from subprocess import Popen + +import pytest +import requests +import torch + + +# Fixture to set up environment variables and teardown servers after tests +@pytest.fixture(scope="module", autouse=True) +def setup_servers(): + if torch.cuda.device_count() < 4: + pytest.skip("Skipping test: fewer than 4 GPUs available") + + # Set up environment variables + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", + shell=True).decode().strip() + os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP + + # Start prefill instance + prefill_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "-tp", + "2", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8100", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-connector", + "PyNcclConnector", + "--kv-role", + "kv_producer", + "--kv-rank", + "0", + "--kv-parallel-size", + "2", + ] + prefill_env = os.environ.copy() + prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" + prefill_proc = Popen(prefill_cmd, env=prefill_env) + + # Start decode instance + decode_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "-tp", + "2", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8200", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-connector", + "PyNcclConnector", + "--kv-role", + "kv_consumer", + "--kv-rank", + "1", + "--kv-parallel-size", + "2", + ] + decode_env = os.environ.copy() + decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" + decode_proc = Popen(decode_cmd, env=decode_env) + + # Wait for servers to be ready + assert wait_for_server(8100), "Prefill server did not start in time" + assert wait_for_server(8200), "Decode server did not start in time" + + # Yield to the test function and handle teardown after tests + yield + + # Cleanup: kill the processes + prefill_proc.terminate() + decode_proc.terminate() + + # Additional cleanup if needed + prefill_proc.wait() + decode_proc.wait() + + +# Helper function to wait for server +def wait_for_server(port, timeout=240): + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/v1/completions") + if response.status_code in [200, 405]: + return True + except requests.ConnectionError: + time.sleep(1) + return False + + +# Test function to send curl requests and validate responses +@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) +def test_disaggregated_prefilling(prompt): + # Send to prefill + response = requests.post("http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + }) + assert response.status_code == 200 + + # Send to decode + response = requests.post("http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + }) + assert response.status_code == 200 diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py new file mode 100644 index 000000000000..355461919cd7 --- /dev/null +++ b/tests/kv_transfer/module_test.py @@ -0,0 +1,64 @@ +import subprocess +import sys + +import pytest +import torch + + +def run_python_script(script_name, timeout): + script_name = f'kv_transfer/{script_name}' + try: + # Start both processes asynchronously using Popen + process0 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "0"}, # Set the RANK environment variable for process 0 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + process1 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "1"}, # Set the RANK environment variable for process 1 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + # Wait for both processes to complete, with a timeout + process0.wait(timeout=timeout) + process1.wait(timeout=timeout) + + # Check the return status of both processes + if process0.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=0, {process0.returncode}") + if process1.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=1, {process1.returncode}") + + except subprocess.TimeoutExpired: + # If either process times out, terminate both and fail the test + process0.terminate() + process1.terminate() + pytest.fail(f"Test {script_name} timed out") + except Exception as e: + pytest.fail(f"Test {script_name} failed with error: {str(e)}") + + +# Define the test cases using pytest's parametrize +@pytest.mark.parametrize( + "script_name,timeout", + [ + ("test_lookup_buffer.py", + 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout + ]) +def test_run_python_script(script_name, timeout): + # Check the number of GPUs + if torch.cuda.device_count() < 2: + pytest.skip( + f"Skipping test {script_name} because <2 GPUs are available") + + # Run the test if there are at least 2 GPUs + run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py new file mode 100644 index 000000000000..a323ac731990 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -0,0 +1,161 @@ +import os +import random + +import torch +from tqdm import tqdm + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( + lookup_buffer as lb) +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( + pynccl_pipe as pnp) + +# TODO: the test depends on a lot of fields in the current implementation. +# We should have standard interface instead direct field access + + +def test_run(my_rank, buffer, device): + + # buffer should be empty in the beginning + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + print("My rank: %d, device: %s" % (my_rank, device)) + + # insert + tokens = torch.tensor([1, 2, 3]).to(device) + roi = (tokens > 0) + if my_rank == 0: + key = 2.0 * torch.ones([5, 6]).to(device) + value = 3.0 * torch.ones([5, 6]).to(device) + + placeholder = torch.tensor([1]).to(device) + + buffer.insert(tokens, roi, key, value, placeholder) + + torch.distributed.barrier() + + # drop_select + if my_rank == 1: + tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) + assert torch.allclose(tokens, tok) + assert torch.allclose(roi, roi_) + assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) + assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) + torch.distributed.barrier() + + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + print("Test run passed!") + + +def stress_test(my_rank, buf, device): + + torch.distributed.barrier() + torch.manual_seed(100) + + reqs = [ + ( + torch.rand(100).to(device), # tokens + torch.ones(100).bool().to(device), # roi + torch.rand(100).to(device), # key + torch.rand(100).to(device), # value + torch.rand(100).to(device), # hidden + ) for i in tqdm(range(200)) + ] + + random.seed(my_rank) + random.shuffle(reqs) + + torch.distributed.barrier() + + n = 0 + + # the buffer size can only store 100 reqs + # so the sender will occasionally block to wait for the receiver. + for req in tqdm(reqs): + if my_rank == 0: + buf.insert(*req) + else: + tok, roi, k, v, h = req + tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) + + if tok_ is None: + assert roi_ is None + assert k_ is None + assert v_ is None + assert h_ is None + n += 1 + else: + assert torch.allclose(tok, tok_) + assert torch.allclose(roi, roi_) + assert torch.allclose(k, k_) + assert torch.allclose(v, v_) + assert torch.allclose(h, h_) + print('Rank %d done' % my_rank) + torch.distributed.barrier() + + if my_rank == 0: + x = torch.tensor([0]) + torch.distributed.recv(x, 1) + # the # of None received is the kv that are not selected + assert x.item() == len(buf.buffer) + # and the size of the buffer should be 2000 * buffer len + print(buf.buffer_size) + assert buf.buffer_size == 1700 * len(buf.buffer) + else: + torch.distributed.send(torch.tensor([n]), 0) + + print("Passed stress test!") + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12398', + world_size=2, + rank=my_rank, + ) + + print("initialized! My rank is %d" % my_rank) + + config = KVTransferConfig( + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + data_pipe = pnp.PyNcclPipe( + local_rank=my_rank, + config=config, + device="cuda", + port_offset=0, + ) + cpu_pipe = pnp.PyNcclPipe( + local_rank=my_rank, + config=config, + device="cpu", + port_offset=1, + ) + + buffer = lb.LookupBuffer(cpu_pipe, data_pipe, 170000) + + test_run(my_rank, buffer, data_pipe.device) + + stress_test(my_rank, buffer, data_pipe.device) + + buffer.close() + data_pipe.close() + cpu_pipe.close() + print('Done') diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh new file mode 100644 index 000000000000..09d7ee018c3f --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.sh @@ -0,0 +1,3 @@ +#!/bin/bash +RANK=0 python test_lookup_buffer.py & +RANK=1 python test_lookup_buffer.py & \ No newline at end of file diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py new file mode 100644 index 000000000000..ad791b456c7b --- /dev/null +++ b/tests/kv_transfer/test_send_recv.py @@ -0,0 +1,156 @@ +import os +import time +from typing import List + +import torch +from tqdm import tqdm + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import ( + pynccl_pipe as pnp) + + +def test_run(my_rank, pipe): + # test run + x = torch.tensor([1]).to(pipe.device) + y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) + if my_rank == 0: + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + + else: + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + + assert torch.allclose(x, x2) + assert torch.allclose(y, y2) + + +def stress_test(my_rank, pipe): + + torch.distributed.barrier() + + tensors: List[torch.Tensor] = [] + + torch.manual_seed(0) + + for i in tqdm(range(500)): + mean = torch.rand(1).item() * 100 + std = torch.rand(1).item() * 100 + size = torch.randint(900, 1000, (2, )) + x = torch.normal(mean * 1.0, std * 1.0, + size=size.tolist()).to(pipe.device) + + # 5% probability of sending a None + if torch.rand(1).item() < 0.05: + tensors.append(None) + tensors.append(None) + tensors.append(None) + else: + tensors.append(x) + tensors.append(x.mean().unsqueeze(0)) + tensors.append(x.std().unsqueeze(0)) + + torch.distributed.barrier() + + for i in tqdm(range(500)): + if my_rank == int((i % 10) > 3): + pipe.send_tensor(tensors[3 * i]) + pipe.send_tensor(tensors[3 * i + 1]) + pipe.send_tensor(tensors[3 * i + 2]) + else: + x = pipe.recv_tensor() + mean = pipe.recv_tensor() + std = pipe.recv_tensor() + + if x is None: + assert mean is None + assert std is None + else: + assert torch.allclose(x, tensors[3 * i]) + assert x.mean() == mean[0] + assert x.std() == std[0] + + torch.distributed.barrier() + + +def latency_test(my_rank, pipe, nelement, ntensor): + + latencies = [] + + torch.distributed.barrier() + + for i in tqdm(range(500)): + + tensors = [] + + if my_rank == 0: + # create tensor + tensors = [ + torch.rand(nelement).to(pipe.device) for _ in range(ntensor) + ] + + torch.distributed.barrier() + + if my_rank == 0: + t = torch.tensor([time.time()], + dtype=torch.float64).to(pipe.device) + for tensor in tensors: + pipe.send_tensor(tensor) + pipe.send_tensor(t) + else: + for _ in range(ntensor): + pipe.recv_tensor() + t = pipe.recv_tensor() + latencies.append(time.time() - t.item()) + + torch.distributed.barrier() + + print('Latency test passed.') + print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12398', + world_size=2, + rank=my_rank, + ) + + config = KVTransferConfig( + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + pipe = pnp.PyNcclPipe( + local_rank=my_rank, + config=config, + ) + + test_run(my_rank, pipe) + stress_test(my_rank, pipe) + + # Use this function if you want to test the latency of pipe impl. + # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh new file mode 100644 index 000000000000..935487bd00d6 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.sh @@ -0,0 +1,3 @@ +#!/bin/bash +RANK=0 python test_send_recv.py & +RANK=1 python test_send_recv.py & \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 3d0c61686822..5e6834687a94 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2055,6 +2055,60 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + # NOTE: these default values should align with EngineArgs + kv_connector: Optional[str] = None + kv_buffer_device: Optional[str] = None + kv_buffer_size: float = 1e9 + kv_role: Optional[str] = None + kv_rank: Optional[int] = None + kv_parallel_size: int = 1 + kv_ip: str = "127.0.0.1" + kv_port: int = 14579 + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def need_kv_parallel_group(self) -> bool: + # for those database-based connector, vLLM does not need to create + # parallel group, and in that case the kv parallel size will be 1. + return self.kv_connector is not None and self.kv_parallel_size > 1 + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_both"] + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_consumer", "kv_both"] + + def __post_init__(self): + + if self.kv_connector not in [None, "PyNcclConnector"]: + raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " + f"Supported connectors are " + f"`PyNcclConnector`.") + + if self.kv_role not in [None, "kv_producer", "kv_consumer", "kv_both"]: + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + "is set, supported roles are `kv_producer`, " + "`kv_consumer`, and `kv_both`") + + class CompilationLevel: # constants for the levels of the compilation process NO_COMPILATION = 0 @@ -2285,6 +2339,8 @@ class VllmConfig: quant_config: Optional[QuantizationConfig] = None compilation_config: CompilationConfig = field(default=None, init=True) # type: ignore + kv_transfer_config: KVTransferConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py new file mode 100644 index 000000000000..8254d3fdd533 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -0,0 +1,194 @@ +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +All distributed communications are abstracted behind this class. +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import KVTransferConfig + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class KVConnectorBase(ABC): + """ + Abstract base class for a KV connector. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "KVTransferConfig", + ): + raise NotImplementedError + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer, similar to SQL insert + + The functionality is similar to the following python statement + ``` + connector[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + """Select KV cache entries from the connector. + + The functionality is similar to the following python statements + ``` + return connector[input_tokens, roi] + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + List[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + """ + Send KV caches and hidden states to the connector. + + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. + kv_caches (List[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None + + """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (List[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. + + """ + + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 000000000000..ac8bf788b39c --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,14 @@ +from .base import KVConnectorBase + + +class KVConnectorFactory: + + @staticmethod + def create_connector(rank: int, local_rank: int, + config) -> KVConnectorBase: + if config.kv_connector == 'PyNcclConnector': + from .pynccl_connector.connector import PyNcclConnector + return PyNcclConnector(rank, local_rank, config) + else: + raise ValueError(f"Unsupported connector type: " + f"{config.kv_connector}") diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py new file mode 100644 index 000000000000..883ff101d4b2 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py @@ -0,0 +1,240 @@ +""" + This file implements a simple torch distributed connector by 3 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, + using `torch.distributed` + - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top + of `TorchDistributedPipe` + - `TorchDistributedConnector`: a torch distributed connector between P/D + instance, implemented on top of `TorchDistributedBuffer` +""" +import threading +import time +from collections import deque +from typing import Deque, List, Optional, Union + +import torch + +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( + PyNcclPipe) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class LookupBuffer: + + def __init__(self, signal_pipe: PyNcclPipe, data_pipe: PyNcclPipe, + buffer_size_thresh: float): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) + """ + + self.buffer: Deque[List[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_lock = threading.Lock() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0], device="cpu") + self.end_signal = None + + def _matches(self, tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + if tensor.dtype == torch.bool: + tensor = tensor.float() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError(f"Unknown data type {type(data)}") + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + + with self.buffer_lock: + for data in buffer_item: + self.buffer_size += self._get_element_size(data) + self.buffer.append(buffer_item) + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None) + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone().float() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + if roi is not None: + # convert from float tensor to bool tensor + # as PyNccl does not support sending bool tensor + roi = (roi > 0.5) + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def full_handler(self): + time.sleep(0.001) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py new file mode 100644 index 000000000000..fe45adba0605 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py @@ -0,0 +1,258 @@ +""" + This file implements a simple torch distributed connector by 3 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, + using `torch.distributed` + - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top + of `TorchDistributedPipe` + - `TorchDistributedConnector`: a torch distributed connector between P/D + instance, implemented on top of `TorchDistributedBuffer` +""" +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( + LookupBuffer) +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( + PyNcclPipe) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class PyNcclConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: KVTransferConfig, + ): + + self.config = config + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[LookupBuffer] = None + self.consumer_buffer: Optional[LookupBuffer] = None + + # 2 pipes for every rank in the world + port_offset_base = 2 * rank + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + if config.is_kv_producer: + + self.producer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base, + ) + self.producer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base + 1, + device="cpu", + ) + self.producer_buffer = LookupBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + config.kv_buffer_size) + + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + self.consumer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base, + ) + self.consumer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base + 1, + device="cpu", + ) + self.consumer_buffer = LookupBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, + config.kv_buffer_size, + ) + + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." + return self.consumer_buffer.drop_select(input_tokens, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # so we need to adjust model_input and redo the forwarding. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.producer_data_pipe.close() + self.producer_signal_pipe.close() + self.consumer_data_pipe.close() + self.consumer_signal_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py new file mode 100644 index 000000000000..d74f0967e473 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py @@ -0,0 +1,265 @@ +""" + This file implements a simple PyNccl pipe that can send and receive + Optional[torch.Tensor] between two ranks. + + We will first transmit the metadata, and then the tensor. + Metadata format: + Metadata = Dict[str, Optional[torch.Tensor]] + - "dtype": The data type of the tensor (tensor.dtype) or None + - "shape": The shape of the tensor (tensor.shape) or None +""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Dict, Optional, Tuple + +import torch + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +Metadata = Dict[str, Optional[torch.Tensor]] + + +class PyNcclPipe: + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0): + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + self.kv_parallel_size = self.config.kv_parallel_size + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + # build distributed connection and send/recv implementation + self.group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port + port_offset, + rank=self.kv_rank, + world_size=self.kv_parallel_size, + ) + # add a barrier to make sure the connection is initiated properly + self.group.barrier() + impl = self._get_device_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = impl + # set target rank + self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size + self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size + + # transportation-related variables + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + self.buffer_size_thresh = self.config.kv_buffer_size + + def _get_device_send_recv_impl( + self, group: StatelessProcessGroup + ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[ + [torch.Tensor, int], None]]: + + send: Callable[[torch.Tensor, int], None] + recv: Callable[[torch.Tensor, int], None] + if self.device.type == "cuda": + # use PyNCCL for send / recv + comm = PyNcclCommunicator(group, device=self.local_rank) + comm.disabled = False + send, recv = comm.send, comm.recv # type: ignore + else: + # use cpu communication + send = group.send + recv = group.recv + + return send, recv + + def _select_device(self, device: str): + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: + """ + Create the metadata as a dictionary based on the input tensor. + + Parameters: + - tensor: The input tensor or None if no tensor is provided. + + Returns: + - metadata: A dictionary with the following keys: + - "dtype": The data type of the tensor or None. + - "shape": The shape of the tensor or None. + """ + if tensor is None: + return {"dtype": None, "shape": None} + else: + return {"dtype": tensor.dtype, "shape": tensor.shape} + + def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the provided metadata. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape", describing + the tensor's data type and shape. + + Returns: + - buffer: A tensor of the specified type and shape, allocated on + self.device. + """ + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], + device=self.device) + + def _send_metadata(self, metadata: Metadata): + """ + Send the metadata dictionary to the target rank. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape". + """ + self.group.send_obj(metadata, self.target_rank_for_send) + + def _recv_metadata(self) -> Metadata: + """ + Receive the metadata dictionary from the target rank. + + Returns: + - metadata: A dictionary with keys "dtype" and "shape" describing + the tensor. + """ + return self.group.recv_obj(self.target_rank_for_recv) + + def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + """ + The actual implementation of sending the tensor and its metadata to the + target rank. + + Parameters: + - tensor: The input tensor to be sent, or None if no tensor is + being sent. + """ + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + if tensor is not None: + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) + + def _recv_impl(self) -> Optional[torch.Tensor]: + """ + The actual implementation of receiving a tensor and its metadata from + the target rank. + + Returns: + - buffer: The received tensor, or None if no tensor is received. + """ + metadata = self._recv_metadata() + if metadata["dtype"] is None: + return None + buffer = self._prepare_recv_buffer(metadata) + self.device_recv_func(buffer, self.target_rank_for_recv) + + return buffer + + def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], + tensor_size: int) -> None: + """ + Wrapper for _send_impl to handle exceptions and update buffer size. + """ + try: + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size -= tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """ + Block the current thread if the buffer size is larger than the + threshold. + """ + while self.buffer_size > self.buffer_size_thresh: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """ + Sends a tensor and its metadata to the destination rank in a + non-blocking way. + + Parameters: + - tensor: The tensor to send, or None if no tensor is being sent. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is not None: + tensor_size = tensor.element_size() * tensor.numel() + else: + tensor_size = 0 + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size += tensor_size + + self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """ + Receives a tensor and its metadata from the source rank. Blocking call. + + Returns: + - tensor: The received tensor, or None if no tensor is received. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + logger.error("My device: %s", self.device) + import traceback + traceback.print_exc() + raise e + + return tensor + + def close(self): + """ + Close the pipe and release associated resources. + """ + if hasattr(self, + "transport_thread") and self.transport_thread is not None: + self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py new file mode 100644 index 000000000000..98ca06138ebd --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -0,0 +1,83 @@ +"""vLLM distributed KV cache transfer API. +These APIs are used in `vllm/worker/model_runner.py`. + +Currently supporting TP. The TP between prefill and decode instance needs to be +the same. + +Workflow (disaggregated prefill) +- In prefill instance + - After prefill, vLLM `insert` its KV caches into a lookup buffer. + - The prefill instance will also open up a thread that listens to + `drop_select` request. +- In decode instance + - vLLM first runs `drop_select` to send input tokens and a mask on input + tokens (we call it roi, region of interest) to prefill instance + - The prefill instance then respond to `drop_select` request by + - Finding a match in current lookup buffer. + - Clone and send the matched item out + - Delete the matched item in the lookup buffer to free up GPU memory. + - The decode vLLM then store the KV cache into paged memory. +""" +from typing import TYPE_CHECKING, List, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +import torch + +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class KVTransferAgent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + rank: int, + local_rank: int, + config, + ): + + self.config = config + assert self.config.is_kv_transfer_instance, "KV cache transfer "\ + "agent should only be used when kv_connector is set." + + self.connector = KVConnectorFactory.create_connector( + rank, local_rank, config) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + self.connector.send_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches, + hidden_or_intermediate_states) + + def close(self) -> None: + self.connector.close() + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + return self.connector.recv_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 87ade377266a..b308281494f4 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -27,18 +27,23 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op +if TYPE_CHECKING: + from vllm.config import KVTransferConfig + @dataclass class GraphCaptureContext: @@ -942,6 +947,14 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group +_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None + + +def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: + assert _KV_TRANSFER is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_TRANSFER + @contextmanager def graph_capture(): @@ -1090,6 +1103,19 @@ def initialize_model_parallel( group_name="pp") +def ensure_kv_transfer_initialized(config: "KVTransferConfig") -> None: + """ + Initialize KV cache transfer parallel group. + """ + + global _KV_TRANSFER + if config.need_kv_parallel_group and _KV_TRANSFER is None: + _KV_TRANSFER = kv_transfer.KVTransferAgent( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=config) + + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index dcfcb848cbe0..7cd35d85b893 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -131,6 +131,9 @@ def send_obj(self, obj: Any, dst: int): self.send_dst_counter[dst] += 1 self.entries.append((key, time.time())) + def send(self, tensor: torch.Tensor, dst: int): + self.send_obj(tensor, dst) + def expire_data(self): """Expire data that is older than `data_expiration_seconds` seconds.""" while self.entries: @@ -150,6 +153,15 @@ def recv_obj(self, src: int) -> Any: self.recv_src_counter[src] += 1 return obj + def recv(self, tensor: torch.Tensor, src: int): + """Receive a tensor from a source rank.""" + recv_tensor = self.recv_obj(src) + assert isinstance(recv_tensor, torch.Tensor), "Received object is"\ + " not a tensor." + assert tensor.size() == recv_tensor.size(), "Received tensor size"\ + " does not match the recv buffer size." + tensor[...] = recv_tensor + def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9288cd22c003..76f3037cd2d1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -14,7 +14,7 @@ ObservabilityConfig, ParallelConfig, PoolerConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TaskOption, TokenizerPoolConfig, - VllmConfig) + VllmConfig, KVTransferConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -107,6 +107,7 @@ class EngineArgs: # notice. distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None + # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -192,6 +193,16 @@ class EngineArgs: override_pooler_config: Optional[PoolerConfig] = None compilation_config: Optional[CompilationConfig] = None + # P/D disaggregation coonfiguration + kv_connector: Optional[str] = None + kv_buffer_size: Optional[float] = 1e9 + kv_buffer_device: Optional[str] = "cuda" + kv_role: Optional[str] = None + kv_rank: Optional[str] = None + kv_parallel_size: int = 1 + kv_ip: str = "127.0.0.1" + kv_port: int = 14579 + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -884,6 +895,61 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'To specify the full compilation config, ' 'use a JSON string.') + parser.add_argument( + '--kv-parallel-size', + type=int, + default=EngineArgs.kv_parallel_size, + help="The number of parallel instances for KV cache transfer. " + "For PyNcclConnector, this should be >1.") + + parser.add_argument( + '--kv-connector', + type=str, + default=None, + choices=["PyNcclConnector"], + help="The KV connector for vLLM to transmit KV caches between vLLM" + " instances.") + + parser.add_argument( + '--kv-buffer-size', + type=float, + default=EngineArgs.kv_buffer_size, + help="The buffer size for TorchDistributedConnector. Measured in " + "number of bytes. Recommended value: 1e9 (about 1GB).") + + parser.add_argument( + '--kv-buffer-device', + type=str, + default=EngineArgs.kv_buffer_device, + choices=["cpu", "cuda"], + help="The device used by kv connector to buffer the KV cache. Can " + "be CPU or GPU. Recommended value: CPU.") + + parser.add_argument( + '--kv-role', + type=str, + default=None, + choices=["kv_producer", "kv_consumer", "both"], + help="Whether this vLLM instance produces, consumes KV cache, or " + "both. Choices are 'kv_producer', 'kv_consumer', and 'both'.") + + parser.add_argument( + '--kv-rank', + type=int, + default=None, + help="The rank of this vLLM instance in the KV cache transfer." + " Typical value: 0 for prefill instance, 1 for decode instance.") + + parser.add_argument('--kv-ip', + type=str, + default=EngineArgs.kv_ip, + help="The IP address of the KV cache producer.") + + parser.add_argument('--kv-port', + type=int, + default=EngineArgs.kv_port, + help="The port of the KV cache producer.") + return parser @classmethod @@ -996,7 +1062,18 @@ def create_engine_config(self) -> VllmConfig: self.tokenizer_pool_extra_config, ), ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend) + distributed_executor_backend=self.distributed_executor_backend, + ) + kv_transfer_config = KVTransferConfig( + kv_parallel_size=self.kv_parallel_size, + kv_connector=self.kv_connector, + kv_buffer_size=self.kv_buffer_size, + kv_buffer_device=self.kv_buffer_device, + kv_role=self.kv_role, + kv_rank=self.kv_rank, + kv_ip=self.kv_ip, + kv_port=self.kv_port, + ) max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 @@ -1163,6 +1240,7 @@ def create_engine_config(self) -> VllmConfig: observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, compilation_config=self.compilation_config, + kv_transfer_config=kv_transfer_config, ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ed0360fb7f72..9c58e179d3ca 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,7 +21,7 @@ from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group +from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -1638,6 +1638,24 @@ def execute_model( else: model_executable = self.model + # Receive KV cache in distributed KV cache transfer setting + # In disagg prefill setting, it will also recv hidden states and bypass + # model forwarding + # In KV cache database setting, it will change the model input so that + # we can skip prefilling on tokens that successfully received KV caches + # NOTE: The receive operation is blocking + bypass_model_exec = False + if self.need_recv_kv(model_input, kv_caches): + hidden_or_intermediate_states, bypass_model_exec, model_input = \ + get_kv_transfer_group().recv_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can receive KV for only those + # layers. + model_executable, + model_input, + kv_caches=kv_caches + ) + multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -1649,21 +1667,35 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - with set_forward_context(model_input.attn_metadata): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + if not bypass_model_exec: + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() + # Sending KV cache in distributed KV cache transfer setting + # NOTE: the send operation is non-blocking + if self.need_send_kv(model_input, kv_caches): + get_kv_transfer_group().send_kv_caches_and_hidden_states( + # model_executable is used to know which layer the current + # worker is working on, so that we can send KV for only those + # layers. + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker @@ -1731,6 +1763,50 @@ def execute_model( return [output] + def need_recv_kv(self, model_input, kv_caches) -> bool: + """Check if we need to receive kv-cache from the other worker. + We need to receive KV when + 1. current vLLM instance is KV cache consumer/decode vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return self.vllm_config.kv_transfer_config.is_kv_consumer and ( + not is_profile_run) and is_prefill_run + + def need_send_kv(self, model_input, kv_caches) -> bool: + """Check if we need to send kv-cache to the other worker. + We need to send KV when + 1. current vLLM instance is KV cache producer/prefill vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return self.vllm_config.kv_transfer_config.is_kv_producer and ( + not is_profile_run) and is_prefill_run + # NOTE: this is nn.Module so the profiler can properly capture/group # kernels calls made within the graph diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 80fd7bc3b67c..05c25f9a5b91 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,8 +8,9 @@ import torch.distributed import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, +from vllm.config import KVTransferConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_kv_transfer_initialized, + ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger @@ -143,7 +144,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.parallel_config, + self.kv_transfer_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. @@ -457,6 +459,7 @@ def get_cache_block_size_bytes(self) -> int: def init_worker_distributed_environment( parallel_config: ParallelConfig, + kv_transfer_config: KVTransferConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, @@ -466,10 +469,11 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(kv_transfer_config) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index cf8a4946a71c..9ca4667f78cc 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -44,6 +44,7 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + self.kv_transfer_config = vllm_config.kv_transfer_config @abstractmethod def init_device(self) -> None: From 1eadc9433b4a9c9d63325e42179ae8e89bd6477d Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 20 Nov 2024 21:39:21 +0000 Subject: [PATCH 02/33] Make ruff happy. Signed-off-by: KuntaiDu --- vllm/engine/arg_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 76f3037cd2d1..d949bc0de407 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,12 +9,12 @@ import vllm.envs as envs from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, - DecodingConfig, DeviceConfig, HfOverrides, LoadConfig, - LoadFormat, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, PoolerConfig, - PromptAdapterConfig, SchedulerConfig, + DecodingConfig, DeviceConfig, HfOverrides, + KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PoolerConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TaskOption, TokenizerPoolConfig, - VllmConfig, KVTransferConfig) + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS From a36c12c3256d968cb7811dd34d3263add84bfbfb Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 20 Nov 2024 21:56:34 +0000 Subject: [PATCH 03/33] fix wrong import in tests Signed-off-by: KuntaiDu --- tests/kv_transfer/test_lookup_buffer.py | 10 +++++----- tests/kv_transfer/test_send_recv.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index a323ac731990..7e158b166dc4 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -6,9 +6,9 @@ from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( - lookup_buffer as lb) + LookupBuffer) from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( - pynccl_pipe as pnp) + PyNcclPipe) # TODO: the test depends on a lot of fields in the current implementation. # We should have standard interface instead direct field access @@ -136,20 +136,20 @@ def stress_test(my_rank, buf, device): kv_port=12345, ) - data_pipe = pnp.PyNcclPipe( + data_pipe = PyNcclPipe( local_rank=my_rank, config=config, device="cuda", port_offset=0, ) - cpu_pipe = pnp.PyNcclPipe( + cpu_pipe = PyNcclPipe( local_rank=my_rank, config=config, device="cpu", port_offset=1, ) - buffer = lb.LookupBuffer(cpu_pipe, data_pipe, 170000) + buffer = LookupBuffer(cpu_pipe, data_pipe, 170000) test_run(my_rank, buffer, data_pipe.device) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index ad791b456c7b..79b55731e79c 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -7,7 +7,7 @@ from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import ( - pynccl_pipe as pnp) + pipe as pnp) def test_run(my_rank, pipe): From 6768108f3e25547252ad584cd7a9477ffed154d7 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 20 Nov 2024 22:14:34 +0000 Subject: [PATCH 04/33] adjust docstring Signed-off-by: KuntaiDu --- .../kv_connector/pynccl_connector/buffer.py | 16 +++++++++------- .../kv_connector/pynccl_connector/pipe.py | 18 ++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py index 883ff101d4b2..d9febc733c5d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py @@ -1,11 +1,13 @@ """ - This file implements a simple torch distributed connector by 3 classes: - - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, - using `torch.distributed` - - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top - of `TorchDistributedPipe` - - `TorchDistributedConnector`: a torch distributed connector between P/D - instance, implemented on top of `TorchDistributedBuffer` + Implements a distributed key-value (KV) cache transfer mechanism for vLLM + instances with buffer management. + + Key Features: + - Distributed KV cache transmission using PyNccl pipes. + - Non-blocking `insert`, blocking `drop_select`. + - Use CPU signal pipe to avoid racing condition + - Handles buffer size constraints and provide backpressure mechanism to + stop the prefill instance when the decode instance is slow. """ import threading import time diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py index d74f0967e473..c53696c7c328 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py @@ -1,12 +1,14 @@ """ - This file implements a simple PyNccl pipe that can send and receive - Optional[torch.Tensor] between two ranks. - - We will first transmit the metadata, and then the tensor. - Metadata format: - Metadata = Dict[str, Optional[torch.Tensor]] - - "dtype": The data type of the tensor (tensor.dtype) or None - - "shape": The shape of the tensor (tensor.shape) or None + This module implements a PyNccl pipe for sending and receiving + Optional[torch.Tensor] between distributed ranks with advanced + communication features. + + Key Features: + - Supports sending and receiving tensors with metadata + - Handles both CUDA and CPU device communications + - Implements a non-blocking tensor transfer mechanism + - Manages buffer size and provides backpressure control + - Supports distributed process groups with configurable parameters """ import threading From 83e35897d4e0e641cfa6446b32432c73ccec31db Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 20 Nov 2024 22:26:05 +0000 Subject: [PATCH 05/33] make yapf happy Signed-off-by: KuntaiDu --- tests/kv_transfer/test_send_recv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 79b55731e79c..b6d074532579 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -6,8 +6,8 @@ from tqdm import tqdm from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import ( - pipe as pnp) +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import (pipe as + pnp) def test_run(my_rank, pipe): From 263677b0712bb6909edf1b5e2b81aa4e6255ad5f Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 20 Nov 2024 22:30:01 +0000 Subject: [PATCH 06/33] make format checker happy Signed-off-by: KuntaiDu --- tests/kv_transfer/test_send_recv.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index b6d074532579..c0447bed4912 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -6,8 +6,7 @@ from tqdm import tqdm from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import (pipe as - pnp) +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import pipe def test_run(my_rank, pipe): @@ -144,7 +143,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): kv_port=12345, ) - pipe = pnp.PyNcclPipe( + pipe = pipe.PyNcclPipe( local_rank=my_rank, config=config, ) From 47a3f79d48e12ee884a5de77f87f1a49cad717f4 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 21 Nov 2024 07:27:49 +0000 Subject: [PATCH 07/33] add an empty __init__.py file to make sure it is recognized as python package Signed-off-by: KuntaiDu --- .../kv_transfer/kv_connector/pynccl_connector/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/pynccl_connector/__init__.py diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From e84e14a9928be4916d00ec7ebba879574a1bab35 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 21 Nov 2024 08:21:19 +0000 Subject: [PATCH 08/33] bug fix: handle the case where kv_transfer_config is None Signed-off-by: KuntaiDu --- .buildkite/test-pipeline.yaml | 1 - vllm/distributed/parallel_state.py | 5 ++++- vllm/worker/model_runner.py | 6 ++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aaa79a347e5a..3a7aaf4f887b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -482,7 +482,6 @@ steps: - vllm/worker/worker_base.py - vllm/worker/model_runner.py commands: - - pytest -v -s kv_transfer/module_test.py - pytest -v -s kv_transfer/disagg_test.py - label: LoRA Long Context (Distributed) # 11min diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b308281494f4..5af0a23dc08c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1109,7 +1109,10 @@ def ensure_kv_transfer_initialized(config: "KVTransferConfig") -> None: """ global _KV_TRANSFER - if config.need_kv_parallel_group and _KV_TRANSFER is None: + if all([ + config is not None, config.need_kv_parallel_group, + _KV_TRANSFER is None + ]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9c58e179d3ca..5c58c0f4a012 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1782,6 +1782,9 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None + if self.vllm_config.kv_transfer_config is None: + return False + return self.vllm_config.kv_transfer_config.is_kv_consumer and ( not is_profile_run) and is_prefill_run @@ -1804,6 +1807,9 @@ def need_send_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None + if self.vllm_config.kv_transfer_config is None: + return False + return self.vllm_config.kv_transfer_config.is_kv_producer and ( not is_profile_run) and is_prefill_run From 0243d7150c5d8903037e14b93ab6955cea93059e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 21 Nov 2024 15:17:40 +0000 Subject: [PATCH 09/33] further fix the case when config is None Signed-off-by: KuntaiDu --- vllm/distributed/parallel_state.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5af0a23dc08c..b170d35d1eaf 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1109,8 +1109,12 @@ def ensure_kv_transfer_initialized(config: "KVTransferConfig") -> None: """ global _KV_TRANSFER + + if config is None: + return + if all([ - config is not None, config.need_kv_parallel_group, + config.need_kv_parallel_group, _KV_TRANSFER is None ]): _KV_TRANSFER = kv_transfer.KVTransferAgent( From fdd605a1e19c6ad32f66aea007072e65021960bf Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 21 Nov 2024 16:16:46 +0000 Subject: [PATCH 10/33] make yapf happy Signed-off-by: KuntaiDu --- vllm/distributed/parallel_state.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b170d35d1eaf..0a3bbc7aec17 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1113,10 +1113,7 @@ def ensure_kv_transfer_initialized(config: "KVTransferConfig") -> None: if config is None: return - if all([ - config.need_kv_parallel_group, - _KV_TRANSFER is None - ]): + if all([config.need_kv_parallel_group, _KV_TRANSFER is None]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, From f6ca0bb0f57b7a9ad2f873af8c4c19c85e5ccf51 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 24 Nov 2024 20:50:36 +0000 Subject: [PATCH 11/33] make format checker happy Signed-off-by: KuntaiDu --- vllm/worker/model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 76a3cb4ec444..f7fb01be69f3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1668,7 +1668,8 @@ def execute_model( model_forward_start.record() if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, + self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -1676,7 +1677,7 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), + device=self.device), **seqlen_agnostic_kwargs) if (self.observability_config is not None From 2fcb099c56056ed2af6bfd2cfff25e2c4fa29db4 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 28 Nov 2024 06:13:30 +0000 Subject: [PATCH 12/33] adjust docstring Signed-off-by: KuntaiDu --- .../kv_transfer/kv_connector/base.py | 86 ++----------------- 1 file changed, 7 insertions(+), 79 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 8254d3fdd533..c8cc3d0c69be 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -1,10 +1,9 @@ """ -This file contains a new class `KVLookupBufferBase` that allows developers to -think of KV cache operations as inserting new KV cache entries (`insert`) -into the lookup buffer and querying existing KV caches (`drop_select`) -from the lookup buffer. +KVConnectorBase Class for Distributed KV Cache & Hidden State communication -All distributed communications are abstracted behind this class. +The class provides two primary abstract methods: +1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states +2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states """ from abc import ABC, abstractmethod @@ -23,24 +22,9 @@ class KVConnectorBase(ABC): """ Abstract base class for a KV connector. - This class provides an abstraction for a key-value (KV) cache lookup buffer. - - The key of the lookup buffer: - - input_tokens: token IDs of the request - - roi: a binary mask on top of input_tokens. - - Purpose of roi: Since KV cache may only be available for a subset of - tokens in the input (for example, when vLLM is connected to an external - KV cache service), roi specifies the subset of tokens that the KV cache - is associated with. - - NOTE: roi can be further extended to describe which part of KV the - current process is holding (each process may only hold a part of KV - due to TP and PP). This is not implemented for now. - - The value of the lookup buffer: - - key: the key tensor in the KV cache - - value: the value tensor in the KV cache - - hidden: the final hidden state generated by model forwarding. This allows - vLLM to bypass further model forwarding by transmitting the hidden state. + The class provides two primary abstract methods: + 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states + 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states """ @abstractmethod @@ -52,62 +36,6 @@ def __init__( ): raise NotImplementedError - @abstractmethod - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - """Insert into the lookup buffer, similar to SQL insert - - The functionality is similar to the following python statement - ``` - connector[input_tokens, roi] = [key, value, hidden] - ``` - - FIXME: in the future, we should only have two arguments, key and value, - where key is a tensor dict and value is a tensor dict. - - FIXME: we should transmit both sampler outputs and the hidden states. - - Args: - input_tokens (torch.Tensor): token IDs. - roi (torch.Tensor): A binary mask on top of the input tokens - key (torch.Tensor): The key tensor in the KV cache. - value (torch.Tensor): The value tensor in the KV cache. - hidden (torch.Tensor): The final hidden state tensor generated - during model forwarding to bypass model - forwarding. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def select(self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - """Select KV cache entries from the connector. - - The functionality is similar to the following python statements - ``` - return connector[input_tokens, roi] - ``` - - If `input_tokens` and `roi` is `None`, it means selecting any of the - KV caches in the buffer, return, and remove it from the buffer, useful - when offloading KV cache to KV cache storage service. - - Args: - input_tokens (torch.Tensor): token IDs. - roi (torch.Tensor): A binary mask on top of the input tokens - - Returns: - List[Optional[torch.Tensor]]: A list of tensors. Can be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - @abstractmethod def close(self) -> None: """Close the buffer and release resources. From 68ef9301d3b4d503ee4cc8a0e64c2557ec7d756a Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 28 Nov 2024 06:19:00 +0000 Subject: [PATCH 13/33] align default value for kv-buffer-device Signed-off-by: KuntaiDu --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 52e2534e5a6d..4e51b8fbdf9f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -933,7 +933,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.kv_buffer_device, choices=["cpu", "cuda"], help="The device used by kv connector to buffer the KV cache. Can " - "be CPU or GPU. Recommended value: CPU.") + "be cpu or cuda. Recommended value: cuda.") parser.add_argument( '--kv-role', From c0d8d22d050bb256c7b84cfef2547bd3977218a1 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 28 Nov 2024 07:17:02 +0000 Subject: [PATCH 14/33] Adjust implementation for CPU send recv in statelessprocessgroup Signed-off-by: KuntaiDu --- tests/kv_transfer/test_send_recv.sh | 4 ++-- vllm/distributed/utils.py | 28 ++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh index 935487bd00d6..1e89e246b499 100644 --- a/tests/kv_transfer/test_send_recv.sh +++ b/tests/kv_transfer/test_send_recv.sh @@ -1,3 +1,3 @@ #!/bin/bash -RANK=0 python test_send_recv.py & -RANK=1 python test_send_recv.py & \ No newline at end of file +RANK=0 python3 test_send_recv.py & +RANK=1 python3 test_send_recv.py & \ No newline at end of file diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 7cd35d85b893..c66c549138b3 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -132,7 +132,15 @@ def send_obj(self, obj: Any, dst: int): self.entries.append((key, time.time())) def send(self, tensor: torch.Tensor, dst: int): - self.send_obj(tensor, dst) + """Send out a tensor to destination rank. + This function implements send using CPU communication, which is + needed when the KV cache buffer is placed on CPU in distributed KV + cache communication.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, tensor.numpy().tobytes()) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) def expire_data(self): """Expire data that is older than `data_expiration_seconds` seconds.""" @@ -154,13 +162,17 @@ def recv_obj(self, src: int) -> Any: return obj def recv(self, tensor: torch.Tensor, src: int): - """Receive a tensor from a source rank.""" - recv_tensor = self.recv_obj(src) - assert isinstance(recv_tensor, torch.Tensor), "Received object is"\ - " not a tensor." - assert tensor.size() == recv_tensor.size(), "Received tensor size"\ - " does not match the recv buffer size." - tensor[...] = recv_tensor + """Receive a tensor from source rank. + This function implements recv using CPU communication, which is + needed when the KV cache buffer is placed on CPU in distributed KV + cache communication.""" + received_tensor = torch.frombuffer( + self.store.get( + f"send_to/{self.rank}/{self.recv_src_counter[src]}"), + dtype=tensor.dtype + ).reshape(tensor.shape) + self.recv_src_counter[src] += 1 + tensor[...] = received_tensor def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """Broadcast an object from a source rank to all other ranks. From b9315ec92dd264d535563f1d94cb896f96a3ed6c Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 28 Nov 2024 07:43:22 +0000 Subject: [PATCH 15/33] send VllmConfig instead of KVTransferConfig to connector so that it can access parallelconfig in the future. Signed-off-by: KuntaiDu --- .../kv_transfer/disagg_prefill_example.sh | 13 ++++++++++ .../kv_transfer/kv_connector/base.py | 4 +-- .../kv_transfer/kv_connector/factory.py | 8 ++++-- .../pynccl_connector/connector.py | 26 ++++++++++--------- .../kv_transfer/kv_transfer_agent.py | 12 ++++++--- vllm/distributed/parallel_state.py | 10 ++++--- vllm/worker/worker.py | 9 +++---- 7 files changed, 54 insertions(+), 28 deletions(-) diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index e6c9d17227c7..676f474bbc48 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -3,6 +3,19 @@ # We will launch 2 vllm instances (1 for prefill and 1 for decode), # and then transfer the KV cache between them. +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'cleanup' INT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands, suppressing their output + pgrep pt_main_thread | xargs kill -9 > /dev/null 2>&1 + pkill -f python3 > /dev/null 2>&1 + echo "Cleanup complete. Exiting." + exit 0 +} + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') # install quart first -- required for disagg prefill proxy serve diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index c8cc3d0c69be..acf6c1a25806 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -14,7 +14,7 @@ from vllm.sequence import IntermediateTensors if TYPE_CHECKING: - from vllm.config import KVTransferConfig + from vllm.config import VllmConfig from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata @@ -32,7 +32,7 @@ def __init__( self, rank: int, local_rank: int, - config: "KVTransferConfig", + config: "VllmConfig", ): raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index ac8bf788b39c..99fc36c1fcdf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -1,12 +1,16 @@ from .base import KVConnectorBase +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.config import VllmConfig class KVConnectorFactory: @staticmethod def create_connector(rank: int, local_rank: int, - config) -> KVConnectorBase: - if config.kv_connector == 'PyNcclConnector': + config: "VllmConfig") -> KVConnectorBase: + if config.kv_transfer_config.kv_connector == 'PyNcclConnector': from .pynccl_connector.connector import PyNcclConnector return PyNcclConnector(rank, local_rank, config) else: diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py index fe45adba0605..9b9c7f950472 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py @@ -12,7 +12,7 @@ import torch from vllm import _custom_ops as ops -from vllm.config import KVTransferConfig +from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( LookupBuffer) @@ -33,10 +33,10 @@ def __init__( self, rank: int, local_rank: int, - config: KVTransferConfig, + config: VllmConfig, ): - self.config = config + self.config = config.kv_transfer_config self.lookup_buffer_size = self.config.kv_buffer_size @@ -48,22 +48,24 @@ def __init__( # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe - if config.is_kv_producer: + if self.config.is_kv_producer: self.producer_data_pipe = PyNcclPipe( local_rank=local_rank, - config=config, + config=self.config, port_offset=port_offset_base, ) self.producer_signal_pipe = PyNcclPipe( local_rank=local_rank, - config=config, + config=self.config, port_offset=port_offset_base + 1, device="cpu", ) - self.producer_buffer = LookupBuffer(self.producer_signal_pipe, - self.producer_data_pipe, - config.kv_buffer_size) + self.producer_buffer = LookupBuffer( + self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size + ) else: @@ -71,19 +73,19 @@ def __init__( # its recv pipe to the send pipe of KV producder self.consumer_data_pipe = PyNcclPipe( local_rank=local_rank, - config=config, + config=self.config, port_offset=port_offset_base, ) self.consumer_signal_pipe = PyNcclPipe( local_rank=local_rank, - config=config, + config=self.config, port_offset=port_offset_base + 1, device="cpu", ) self.consumer_buffer = LookupBuffer( self.consumer_signal_pipe, self.consumer_data_pipe, - config.kv_buffer_size, + self.config.kv_buffer_size, ) def select(self, input_tokens: Optional[torch.Tensor], diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 98ca06138ebd..36d4a7fda424 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.config import VllmConfig import torch @@ -46,12 +47,17 @@ def __init__( self, rank: int, local_rank: int, - config, + config: "VllmConfig", ): self.config = config - assert self.config.is_kv_transfer_instance, "KV cache transfer "\ - "agent should only be used when kv_connector is set." + + if config.kv_transfer_config is None: + raise ValueError("KVTransferConfig is not set in the VllmConfig," + " cannot initialize KVConnector.") + + assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ + "TransferAgent should only be used when kv_connector is set." self.connector = KVConnectorFactory.create_connector( rank, local_rank, config) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0a3bbc7aec17..4e4d3f10e084 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -42,7 +42,7 @@ from vllm.utils import direct_register_custom_op, supports_custom_op if TYPE_CHECKING: - from vllm.config import KVTransferConfig + from vllm.config import VllmConfig @dataclass @@ -1103,17 +1103,19 @@ def initialize_model_parallel( group_name="pp") -def ensure_kv_transfer_initialized(config: "KVTransferConfig") -> None: +def ensure_kv_transfer_initialized(config: "VllmConfig") -> None: """ Initialize KV cache transfer parallel group. """ global _KV_TRANSFER - if config is None: + if config.kv_transfer_config is None: return - if all([config.need_kv_parallel_group, _KV_TRANSFER is None]): + if all([ + config.kv_transfer_config.need_kv_parallel_group, + _KV_TRANSFER is None]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 05c25f9a5b91..9d163c8fe699 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -144,8 +144,7 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, - self.kv_transfer_config, self.rank, + init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. @@ -458,13 +457,13 @@ def get_cache_block_size_bytes(self) -> int: def init_worker_distributed_environment( - parallel_config: ParallelConfig, - kv_transfer_config: KVTransferConfig, + config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + parallel_config = config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, @@ -472,7 +471,7 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - ensure_kv_transfer_initialized(kv_transfer_config) + ensure_kv_transfer_initialized(config) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): From dda64b92373a497347bf10f85a8b279904e1693e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 28 Nov 2024 07:51:49 +0000 Subject: [PATCH 16/33] add -X POST flag to curl -- thanks @coolkp Signed-off-by: KuntaiDu --- examples/kv_transfer/disagg_prefill_example.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index 676f474bbc48..0264e57bad3a 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -76,7 +76,7 @@ python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & sleep 1 # serve two example requests -output1=$(curl -s http://localhost:8000/v1/completions \ +output1=$(curl -X POST -s http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", @@ -85,7 +85,7 @@ output1=$(curl -s http://localhost:8000/v1/completions \ "temperature": 0 }') -output2=$(curl -s http://localhost:8000/v1/completions \ +output2=$(curl -X POST -s http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", @@ -106,5 +106,5 @@ echo "" echo "Output of first request: $output1" echo "Output of second request: $output2" -echo "Successfully finished 2 test requests!" +echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉" echo "" From 25aab25391e07fa64bc8b8cf9741851387a8004b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 28 Nov 2024 08:01:04 +0000 Subject: [PATCH 17/33] lowering the max model len, to make sure the example is runnable on low-end GPU Signed-off-by: KuntaiDu --- examples/kv_transfer/disagg_prefill_example.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index 0264e57bad3a..0d9f9f111982 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -43,7 +43,7 @@ CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ - --max-model-len 10000 \ + --max-model-len 100 \ --gpu-memory-utilization 0.8 \ --kv-connector PyNcclConnector \ --kv-role kv_producer \ @@ -55,7 +55,7 @@ CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ - --max-model-len 10000 \ + --max-model-len 100 \ --gpu-memory-utilization 0.8 \ --kv-connector PyNcclConnector \ --kv-role kv_consumer \ From 87eaea2ebea92d218e93960bbb01403dd6fc859f Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 28 Nov 2024 08:03:46 +0000 Subject: [PATCH 18/33] make format checker happy Signed-off-by: KuntaiDu --- vllm/distributed/kv_transfer/kv_connector/base.py | 2 +- vllm/distributed/kv_transfer/kv_connector/factory.py | 5 +++-- .../kv_connector/pynccl_connector/connector.py | 8 +++----- vllm/distributed/parallel_state.py | 7 ++++--- vllm/distributed/utils.py | 11 +++++------ vllm/worker/worker.py | 2 +- 6 files changed, 17 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index acf6c1a25806..6089e3babac3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -7,7 +7,7 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union import torch diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 99fc36c1fcdf..2d495191855c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -1,6 +1,7 @@ -from .base import KVConnectorBase from typing import TYPE_CHECKING +from .base import KVConnectorBase + if TYPE_CHECKING: from vllm.config import VllmConfig @@ -9,7 +10,7 @@ class KVConnectorFactory: @staticmethod def create_connector(rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + config: "VllmConfig") -> KVConnectorBase: if config.kv_transfer_config.kv_connector == 'PyNcclConnector': from .pynccl_connector.connector import PyNcclConnector return PyNcclConnector(rank, local_rank, config) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py index 9b9c7f950472..69fec5fe711d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py @@ -61,11 +61,9 @@ def __init__( port_offset=port_offset_base + 1, device="cpu", ) - self.producer_buffer = LookupBuffer( - self.producer_signal_pipe, - self.producer_data_pipe, - self.config.kv_buffer_size - ) + self.producer_buffer = LookupBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size) else: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4e4d3f10e084..7f15d8119bef 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -42,7 +42,7 @@ from vllm.utils import direct_register_custom_op, supports_custom_op if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import VllmConfig @dataclass @@ -1114,8 +1114,9 @@ def ensure_kv_transfer_initialized(config: "VllmConfig") -> None: return if all([ - config.kv_transfer_config.need_kv_parallel_group, - _KV_TRANSFER is None]): + config.kv_transfer_config.need_kv_parallel_group, + _KV_TRANSFER is None + ]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index c66c549138b3..6c63481b2f46 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -138,7 +138,7 @@ def send(self, tensor: torch.Tensor, dst: int): cache communication.""" self.expire_data() key = f"send_to/{dst}/{self.send_dst_counter[dst]}" - self.store.set(key, tensor.numpy().tobytes()) + self.store.set(key, tensor.numpy().tobytes()) self.send_dst_counter[dst] += 1 self.entries.append((key, time.time())) @@ -166,11 +166,10 @@ def recv(self, tensor: torch.Tensor, src: int): This function implements recv using CPU communication, which is needed when the KV cache buffer is placed on CPU in distributed KV cache communication.""" - received_tensor = torch.frombuffer( - self.store.get( - f"send_to/{self.rank}/{self.recv_src_counter[src]}"), - dtype=tensor.dtype - ).reshape(tensor.shape) + received_tensor = torch.frombuffer(self.store.get( + f"send_to/{self.rank}/{self.recv_src_counter[src]}"), + dtype=tensor.dtype).reshape( + tensor.shape) self.recv_src_counter[src] += 1 tensor[...] = received_tensor diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9d163c8fe699..04e3d1c14452 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,7 @@ import torch.distributed import vllm.envs as envs -from vllm.config import KVTransferConfig, ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import (ensure_kv_transfer_initialized, ensure_model_parallel_initialized, init_distributed_environment, From 6b0c370f6a7dbd0d036e4067fe673c0d26e3ad7f Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 29 Nov 2024 05:49:52 +0000 Subject: [PATCH 19/33] Merge kv transfer config into a single CLI, and people can add value by using JSON string. Signed-off-by: KuntaiDu --- .../disagg_overhead_benchmark.sh | 14 +--- .../disagg_performance_benchmark.sh | 16 ++-- .../kv_transfer/disagg_prefill_example.sh | 12 +-- vllm/config.py | 71 +++++++++++----- .../pynccl_connector/connector.py | 3 + vllm/engine/arg_utils.py | 84 ++----------------- 6 files changed, 75 insertions(+), 125 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index e7d30001e850..2924ea4a49f5 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -58,11 +58,8 @@ benchmark() { --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ - --kv-connector PyNcclConnector \ - --kv-role kv_producer \ - --kv-rank 0 \ - --kv-parallel-size 2 \ - --kv-buffer-size 1e10 & + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & CUDA_VISIBLE_DEVICES=1 python3 \ @@ -71,11 +68,8 @@ benchmark() { --port 8200 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ - --kv-connector PyNcclConnector \ - --kv-role kv_consumer \ - --kv-rank 1 \ - --kv-parallel-size 2 \ - --kv-buffer-size 1e10 & + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 2bd7aa14de5f..d8d9e976dce7 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -67,22 +67,18 @@ launch_disagg_prefill() { --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ - --kv-connector PyNcclConnector \ - --kv-role kv_producer \ - --kv-rank 0 \ - --kv-parallel-size 2 \ - --kv-buffer-size 5e9 & + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8200 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ - --kv-connector PyNcclConnector \ - --kv-role kv_consumer \ - --kv-rank 1 \ - --kv-parallel-size 2 \ - --kv-buffer-size 5e9 & + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + wait_for_server 8100 wait_for_server 8200 python3 disagg_prefill_proxy_server.py & diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index 0d9f9f111982..0fd7aa42b01f 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -45,10 +45,8 @@ CUDA_VISIBLE_DEVICES=0 python3 \ --port 8100 \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ - --kv-connector PyNcclConnector \ - --kv-role kv_producer \ - --kv-rank 0 \ - --kv-parallel-size 2 & + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & # decoding instance, which is the KV consumer CUDA_VISIBLE_DEVICES=1 python3 \ @@ -57,10 +55,8 @@ CUDA_VISIBLE_DEVICES=1 python3 \ --port 8200 \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ - --kv-connector PyNcclConnector \ - --kv-role kv_consumer \ - --kv-rank 1 \ - --kv-parallel-size 2 & + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & # wait until prefill and decode instances are ready wait_for_server 8100 diff --git a/vllm/config.py b/vllm/config.py index 611008672e05..3accc72fe9e5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2055,19 +2055,66 @@ def __post_init__(self): @dataclass -class KVTransferConfig: +class KVTransferConfig(BaseModel): """Configuration for distributed KV cache transfer.""" - # NOTE: these default values should align with EngineArgs + # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None - kv_buffer_device: Optional[str] = None + + # The device used by kv connector to buffer the KV cache. Can be cpu or + # cuda. Recommended value: cuda. + kv_buffer_device: Optional[str] = "cuda" + + # The buffer size for TorchDistributedConnector. Measured in number of + # bytes. Recommended value: 1e9 (about 1GB). kv_buffer_size: float = 1e9 + + # Whether this vLLM instance produces, consumes KV cache, or both. Choices + # are 'kv_producer', 'kv_consumer', and 'both'. kv_role: Optional[str] = None + + # The rank of this vLLM instance in the KV cache transfer. Typical value: + # 0 for prefill instance, 1 for decode instance. + # Currently only 1P1D is supported. kv_rank: Optional[int] = None + + # The number of parallel instances for KV cache transfer. For + # PyNcclConnector, this should be 2. kv_parallel_size: int = 1 + + # The KV connector ip, used to build distributed connection kv_ip: str = "127.0.0.1" + + # The KV connector port, used to build distributed connection kv_port: int = 14579 + @classmethod + def from_cli(cls, cli_value: str) -> "KVTransferConfig": + """Parse the CLI value for the compilation config.""" + return KVTransferConfig.model_validate_json(cli_value) + + def model_post_init(self, __context: Any) -> None: + if all([ + self.kv_connector is not None, + self.kv_connector != "PyNcclConnector" + ]): + raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " + f"Supported connectors are " + f"`PyNcclConnector`.") + + if self.kv_role is not None and self.kv_role not in [ + "kv_producer", "kv_consumer", "kv_both" + ]: + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + "is set, supported roles are `kv_producer`, " + "`kv_consumer`, and `kv_both`") + @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ @@ -2089,24 +2136,6 @@ def is_kv_consumer(self) -> bool: return self.kv_connector is not None and \ self.kv_role in ["kv_consumer", "kv_both"] - def __post_init__(self): - - if self.kv_connector not in [None, "PyNcclConnector"]: - raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " - f"Supported connectors are " - f"`PyNcclConnector`.") - - if self.kv_role not in [None, "kv_producer", "kv_consumer", "kv_both"]: - raise ValueError( - f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") - - if self.kv_connector is not None and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - "is set, supported roles are `kv_producer`, " - "`kv_consumer`, and `kv_both`") - class CompilationLevel: # constants for the levels of the compilation process diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py index 69fec5fe711d..be22077d410b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py @@ -38,6 +38,9 @@ def __init__( self.config = config.kv_transfer_config + logger.info("Initializing PyNcclConfig under kv_transfer_config %s", + self.config) + self.lookup_buffer_size = self.config.kv_buffer_size self.producer_buffer: Optional[LookupBuffer] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4e51b8fbdf9f..5b12fc38f5ce 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -194,15 +194,7 @@ class EngineArgs: compilation_config: Optional[CompilationConfig] = None worker_cls: str = "auto" - # P/D disaggregation coonfiguration - kv_connector: Optional[str] = None - kv_buffer_size: Optional[float] = 1e9 - kv_buffer_device: Optional[str] = "cuda" - kv_role: Optional[str] = None - kv_rank: Optional[str] = None - kv_parallel_size: int = 1 - kv_ip: str = "127.0.0.1" - kv_port: int = 14579 + kv_transfer_config: Optional[KVTransferConfig] = None def __post_init__(self): if not self.tokenizer: @@ -899,67 +891,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'compilers, using -O without space is also ' 'supported. -O3 is equivalent to -O 3.') + parser.add_argument('--kv-transfer-config', + type=KVTransferConfig.from_cli, + default=None, + help='The configurations for distributed KV cache ' + 'transfer. Should be a JSON string.') + parser.add_argument( '--worker-cls', type=str, default="auto", help='The worker class to use for distributed execution.') - parser.add_argument( - '--kv-parallel-size', - type=int, - default=EngineArgs.kv_parallel_size, - help="The number of parallel instances for KV cache transfer. " - "For PyNcclConnector, this should be >1.") - - parser.add_argument( - '--kv-connector', - type=str, - default=None, - choices=["PyNcclConnector"], - help="The KV connector for vLLM to transmit KV caches between vLLM" - " instances.") - - parser.add_argument( - '--kv-buffer-size', - type=float, - default=EngineArgs.kv_buffer_size, - help="The buffer size for TorchDistributedConnector. Measured in " - "number of bytes. Recommended value: 1e9 (about 1GB).") - - parser.add_argument( - '--kv-buffer-device', - type=str, - default=EngineArgs.kv_buffer_device, - choices=["cpu", "cuda"], - help="The device used by kv connector to buffer the KV cache. Can " - "be cpu or cuda. Recommended value: cuda.") - - parser.add_argument( - '--kv-role', - type=str, - default=None, - choices=["kv_producer", "kv_consumer", "both"], - help="Whether this vLLM instance produces, consumes KV cache, or " - "both. Choices are 'kv_producer', 'kv_consumer', and 'both'.") - - parser.add_argument( - '--kv-rank', - type=int, - default=None, - help="The rank of this vLLM instance in the KV cache transfer." - " Typical value: 0 for prefill instance, 1 for decode instance.") - - parser.add_argument('--kv-ip', - type=str, - default=EngineArgs.kv_ip, - help="The IP address of the KV cache producer.") - - parser.add_argument('--kv-port', - type=int, - default=EngineArgs.kv_port, - help="The port of the KV cache producer.") - return parser @classmethod @@ -1075,17 +1018,6 @@ def create_engine_config(self) -> VllmConfig: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, ) - kv_transfer_config = KVTransferConfig( - kv_parallel_size=self.kv_parallel_size, - kv_connector=self.kv_connector, - kv_buffer_size=self.kv_buffer_size, - kv_buffer_device=self.kv_buffer_device, - kv_role=self.kv_role, - kv_rank=self.kv_rank, - kv_ip=self.kv_ip, - kv_port=self.kv_port, - ) - max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 if self.enable_chunked_prefill is None: @@ -1251,7 +1183,7 @@ def create_engine_config(self) -> VllmConfig: observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, compilation_config=self.compilation_config, - kv_transfer_config=kv_transfer_config, + kv_transfer_config=self.kv_transfer_config, ) From 11bd1158c12f7a5dc146798267222b8f2601b8af Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 29 Nov 2024 05:53:04 +0000 Subject: [PATCH 20/33] clean up diff Signed-off-by: KuntaiDu --- vllm/engine/arg_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5b12fc38f5ce..fb315c3afa1e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1018,6 +1018,7 @@ def create_engine_config(self) -> VllmConfig: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, ) + max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 if self.enable_chunked_prefill is None: From bef81719f31ed385764e6c2542ad61a31f4cdd4e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 29 Nov 2024 06:14:10 +0000 Subject: [PATCH 21/33] Syntax adjustments for disagg prefill example. Signed-off-by: KuntaiDu --- ...g_prefill_example.sh => disagg_prefill.sh} | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) rename examples/{kv_transfer/disagg_prefill_example.sh => disagg_prefill.sh} (80%) diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/disagg_prefill.sh similarity index 80% rename from examples/kv_transfer/disagg_prefill_example.sh rename to examples/disagg_prefill.sh index 0fd7aa42b01f..f9bfc294cd23 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/disagg_prefill.sh @@ -9,9 +9,8 @@ trap 'cleanup' INT # Cleanup function cleanup() { echo "Caught Ctrl+C, cleaning up..." - # Cleanup commands, suppressing their output - pgrep pt_main_thread | xargs kill -9 > /dev/null 2>&1 - pkill -f python3 > /dev/null 2>&1 + # Cleanup commands + pgrep python | xargs kill -9 echo "Cleanup complete. Exiting." exit 0 } @@ -39,9 +38,7 @@ wait_for_server() { # You can also adjust --kv-ip and --kv-port for distributed inference. # prefilling instance, which is the KV producer -CUDA_VISIBLE_DEVICES=0 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ @@ -49,9 +46,7 @@ CUDA_VISIBLE_DEVICES=0 python3 \ '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & # decoding instance, which is the KV consumer -CUDA_VISIBLE_DEVICES=1 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ @@ -68,7 +63,7 @@ wait_for_server 8200 # to 1 # - after the prefill vLLM finishes prefill, send the request to decode vLLM # instance -python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +python3 ../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & sleep 1 # serve two example requests @@ -91,11 +86,12 @@ output2=$(curl -X POST -s http://localhost:8000/v1/completions \ }') -# Cleanup commands, suppressing their output -pgrep pt_main_thread | xargs kill -9 > /dev/null 2>&1 -pkill -f python3 > /dev/null 2>&1 +# Cleanup commands +pgrep python | xargs kill -9 -sleep 4 +echo "" + +sleep 1 # Print the outputs of the curl requests echo "" From f826c44404190e142b000bf1284116f5ff9377f5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 29 Nov 2024 06:24:30 +0000 Subject: [PATCH 22/33] Adjust test---we now pass '--kv-transfer-config' instead of multiple CLI args. Signed-off-by: KuntaiDu --- tests/kv_transfer/disagg_test.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index d86be2eaa5d0..d240a58d459c 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -35,14 +35,9 @@ def setup_servers(): "0.5", "--max-model-len", "1000", - "--kv-connector", - "PyNcclConnector", - "--kv-role", - "kv_producer", - "--kv-rank", - "0", - "--kv-parallel-size", - "2", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ + '"kv_rank":0,"kv_parallel_size":2}', ] prefill_env = os.environ.copy() prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" @@ -63,14 +58,9 @@ def setup_servers(): "0.5", "--max-model-len", "1000", - "--kv-connector", - "PyNcclConnector", - "--kv-role", - "kv_consumer", - "--kv-rank", - "1", - "--kv-parallel-size", - "2", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ + '"kv_rank":1,"kv_parallel_size":2}', ] decode_env = os.environ.copy() decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" From 2cd7ee5108989b6179ecbc5f30c3cc04070b5d1c Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 18:11:40 +0000 Subject: [PATCH 23/33] Remove send and recv from StatelessProcessGroup, add comments to remind people that the disaggregated prefill proxy server is temporary usage and will be deprecated soon. Signed-off-by: KuntaiDu --- ...gg_prefill.sh => disaggregated_prefill.sh} | 2 ++ vllm/config.py | 4 ++-- .../kv_connector/pynccl_connector/pipe.py | 5 ++--- vllm/distributed/utils.py | 21 ------------------- 4 files changed, 6 insertions(+), 26 deletions(-) rename examples/{disagg_prefill.sh => disaggregated_prefill.sh} (94%) diff --git a/examples/disagg_prefill.sh b/examples/disaggregated_prefill.sh similarity index 94% rename from examples/disagg_prefill.sh rename to examples/disaggregated_prefill.sh index f9bfc294cd23..fdb99bab1b7f 100644 --- a/examples/disagg_prefill.sh +++ b/examples/disaggregated_prefill.sh @@ -63,6 +63,8 @@ wait_for_server 8200 # to 1 # - after the prefill vLLM finishes prefill, send the request to decode vLLM # instance +# NOTE: the usage of this API is subject to change --- in the future we will +# introduce "vllm connect" to connect between prefill and decode instances python3 ../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & sleep 1 diff --git a/vllm/config.py b/vllm/config.py index 3accc72fe9e5..fde06e518a69 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2061,8 +2061,8 @@ class KVTransferConfig(BaseModel): # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None - # The device used by kv connector to buffer the KV cache. Can be cpu or - # cuda. Recommended value: cuda. + # The device used by kv connector to buffer the KV cache. + # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" # The buffer size for TorchDistributedConnector. Measured in number of diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py index c53696c7c328..31bd323e5546 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py @@ -90,9 +90,8 @@ def _get_device_send_recv_impl( comm.disabled = False send, recv = comm.send, comm.recv # type: ignore else: - # use cpu communication - send = group.send - recv = group.recv + raise NotImplementedError("cpu device and non-nvidia hardawres are"\ + " not supported yet.") return send, recv diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 6c63481b2f46..1ffae2e80b99 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -131,16 +131,6 @@ def send_obj(self, obj: Any, dst: int): self.send_dst_counter[dst] += 1 self.entries.append((key, time.time())) - def send(self, tensor: torch.Tensor, dst: int): - """Send out a tensor to destination rank. - This function implements send using CPU communication, which is - needed when the KV cache buffer is placed on CPU in distributed KV - cache communication.""" - self.expire_data() - key = f"send_to/{dst}/{self.send_dst_counter[dst]}" - self.store.set(key, tensor.numpy().tobytes()) - self.send_dst_counter[dst] += 1 - self.entries.append((key, time.time())) def expire_data(self): """Expire data that is older than `data_expiration_seconds` seconds.""" @@ -161,17 +151,6 @@ def recv_obj(self, src: int) -> Any: self.recv_src_counter[src] += 1 return obj - def recv(self, tensor: torch.Tensor, src: int): - """Receive a tensor from source rank. - This function implements recv using CPU communication, which is - needed when the KV cache buffer is placed on CPU in distributed KV - cache communication.""" - received_tensor = torch.frombuffer(self.store.get( - f"send_to/{self.rank}/{self.recv_src_counter[src]}"), - dtype=tensor.dtype).reshape( - tensor.shape) - self.recv_src_counter[src] += 1 - tensor[...] = received_tensor def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """Broadcast an object from a source rank to all other ranks. From 6b7901e939426b7424bb419647f723d3e148475c Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 18:12:23 +0000 Subject: [PATCH 24/33] Make format checker happy Signed-off-by: KuntaiDu --- vllm/config.py | 2 +- vllm/distributed/utils.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fde06e518a69..47a47b1655b5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2061,7 +2061,7 @@ class KVTransferConfig(BaseModel): # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None - # The device used by kv connector to buffer the KV cache. + # The device used by kv connector to buffer the KV cache. # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 1ffae2e80b99..dcfcb848cbe0 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -131,7 +131,6 @@ def send_obj(self, obj: Any, dst: int): self.send_dst_counter[dst] += 1 self.entries.append((key, time.time())) - def expire_data(self): """Expire data that is older than `data_expiration_seconds` seconds.""" while self.entries: @@ -151,7 +150,6 @@ def recv_obj(self, src: int) -> Any: self.recv_src_counter[src] += 1 return obj - def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object. From 0cc51a02a7263684d91b80ae9377b665ad38aa69 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 18:59:09 +0000 Subject: [PATCH 25/33] bug fix on PyNcclPipe: it needs to support CPU as a device in order to transmit control-plane messages. Signed-off-by: KuntaiDu --- .../kv_connector/pynccl_connector/pipe.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py index 31bd323e5546..eadf50d4f304 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py @@ -90,12 +90,21 @@ def _get_device_send_recv_impl( comm.disabled = False send, recv = comm.send, comm.recv # type: ignore else: - raise NotImplementedError("cpu device and non-nvidia hardawres are"\ - " not supported yet.") + # This send / recv implementation here is NOT intended to transfer + # KV caches (and should NOT be repurposed to transfer KV caches). + # Currently it is only used to transmit control-plane messages + # for PyNcclBuffer. + send = group.send_obj + + def my_recv(x, src): + x[...] = group.recv_obj(src) + + recv = my_recv return send, recv def _select_device(self, device: str): + logger.info("Selecting device: %s", device) if device == "cuda": return torch.device(f"cuda:{self.local_rank}") else: From 6af15f01c80c3969b58799e34761dec08cfdab53 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 19:04:30 +0000 Subject: [PATCH 26/33] Add a warning to notify people that the usage of disaggregated prefill is experimental. Signed-off-by: KuntaiDu --- examples/disaggregated_prefill.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/disaggregated_prefill.sh b/examples/disaggregated_prefill.sh index fdb99bab1b7f..6a4b233a1a4a 100644 --- a/examples/disaggregated_prefill.sh +++ b/examples/disaggregated_prefill.sh @@ -3,6 +3,9 @@ # We will launch 2 vllm instances (1 for prefill and 1 for decode), # and then transfer the KV cache between them. +echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧" +sleep 1 + # Trap the SIGINT signal (triggered by Ctrl+C) trap 'cleanup' INT From 33973e9dcf3bf75dfe511a8b8658afd32a606d6c Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 19:15:29 +0000 Subject: [PATCH 27/33] Reduce the # of GPU to 2, in order to save CI cost. Signed-off-by: KuntaiDu --- .buildkite/test-pipeline.yaml | 24 ++++++++++++------------ tests/kv_transfer/disagg_test.py | 8 ++------ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 02ba0e56d2c5..24833f743ebd 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -401,6 +401,18 @@ steps: - pytest -v -s distributed/test_comm_ops.py - pytest -v -s distributed/test_shm_broadcast.py +- label: Disaggregated Prefill Test # 2min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed/parallel_state.py + - vllm/distributed/kv_transfer + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/model_runner.py + commands: + - pytest -v -s kv_transfer/disagg_test.py + - label: 2 Node Tests (4 GPUs in total) # 16min working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -475,18 +487,6 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py -- label: Disaggregated Prefill Test # 4min - working_dir: "/vllm-workspace/tests" - num_gpus: 4 - source_file_dependencies: - - vllm/distributed/parallel_state.py - - vllm/distributed/kv_transfer - - vllm/worker/worker_base.py - - vllm/worker/worker.py - - vllm/worker/model_runner.py - commands: - - pytest -v -s kv_transfer/disagg_test.py - - label: LoRA TP Test (Distributed) num_gpus: 4 soft_fail: true diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index d240a58d459c..adc6150edece 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -25,8 +25,6 @@ def setup_servers(): sys.executable, "-m", "vllm.entrypoints.openai.api_server", - "-tp", - "2", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", @@ -40,7 +38,7 @@ def setup_servers(): '"kv_rank":0,"kv_parallel_size":2}', ] prefill_env = os.environ.copy() - prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" + prefill_env["CUDA_VISIBLE_DEVICES"] = "0" prefill_proc = Popen(prefill_cmd, env=prefill_env) # Start decode instance @@ -48,8 +46,6 @@ def setup_servers(): sys.executable, "-m", "vllm.entrypoints.openai.api_server", - "-tp", - "2", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", @@ -63,7 +59,7 @@ def setup_servers(): '"kv_rank":1,"kv_parallel_size":2}', ] decode_env = os.environ.copy() - decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" + decode_env["CUDA_VISIBLE_DEVICES"] = "1" decode_proc = Popen(decode_cmd, env=decode_env) # Wait for servers to be ready From a14180f8b46b7b8a880012fe02b52cc6a82a426c Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 21:40:28 +0000 Subject: [PATCH 28/33] Refactor --- create three abstractions: kv_pipe, kv_lookup_buffer and kv_connector. Signed-off-by: KuntaiDu --- tests/kv_transfer/test_lookup_buffer.py | 9 ++- tests/kv_transfer/test_send_recv.py | 4 +- vllm/config.py | 1 - vllm/distributed/kv_transfer/README.md | 4 ++ .../kv_transfer/kv_connector/factory.py | 4 +- .../connector.py => simple_connector.py} | 28 ++++---- .../__init__.py | 0 .../simple_buffer.py} | 7 +- .../kv_transfer/kv_pipe/__init__.py | 0 vllm/distributed/kv_transfer/kv_pipe/base.py | 65 +++++++++++++++++++ .../pipe.py => kv_pipe/pynccl_pipe.py} | 3 +- .../kv_transfer/kv_transfer_agent.py | 24 ++----- 12 files changed, 98 insertions(+), 51 deletions(-) create mode 100644 vllm/distributed/kv_transfer/README.md rename vllm/distributed/kv_transfer/kv_connector/{pynccl_connector/connector.py => simple_connector.py} (90%) rename vllm/distributed/kv_transfer/{kv_connector/pynccl_connector => kv_lookup_buffer}/__init__.py (100%) rename vllm/distributed/kv_transfer/{kv_connector/pynccl_connector/buffer.py => kv_lookup_buffer/simple_buffer.py} (97%) create mode 100644 vllm/distributed/kv_transfer/kv_pipe/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_pipe/base.py rename vllm/distributed/kv_transfer/{kv_connector/pynccl_connector/pipe.py => kv_pipe/pynccl_pipe.py} (99%) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 7e158b166dc4..96b0e5871333 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -5,10 +5,9 @@ from tqdm import tqdm from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( - LookupBuffer) -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( - PyNcclPipe) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe # TODO: the test depends on a lot of fields in the current implementation. # We should have standard interface instead direct field access @@ -149,7 +148,7 @@ def stress_test(my_rank, buf, device): port_offset=1, ) - buffer = LookupBuffer(cpu_pipe, data_pipe, 170000) + buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000) test_run(my_rank, buffer, data_pipe.device) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index c0447bed4912..65973bf10a4d 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -6,7 +6,7 @@ from tqdm import tqdm from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import pipe +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe def test_run(my_rank, pipe): @@ -143,7 +143,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): kv_port=12345, ) - pipe = pipe.PyNcclPipe( + pipe = PyNcclPipe( local_rank=my_rank, config=config, ) diff --git a/vllm/config.py b/vllm/config.py index 47a47b1655b5..34db165c727e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2054,7 +2054,6 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") -@dataclass class KVTransferConfig(BaseModel): """Configuration for distributed KV cache transfer.""" diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md new file mode 100644 index 000000000000..ec9c3b2ad8cd --- /dev/null +++ b/vllm/distributed/kv_transfer/README.md @@ -0,0 +1,4 @@ + +# vLLM KV cache transfer + +This folder implements a series of functionalities \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 2d495191855c..015f892cec93 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -12,8 +12,8 @@ class KVConnectorFactory: def create_connector(rank: int, local_rank: int, config: "VllmConfig") -> KVConnectorBase: if config.kv_transfer_config.kv_connector == 'PyNcclConnector': - from .pynccl_connector.connector import PyNcclConnector - return PyNcclConnector(rank, local_rank, config) + from .simple_connector import SimpleConnector + return SimpleConnector(rank, local_rank, config) else: raise ValueError(f"Unsupported connector type: " f"{config.kv_connector}") diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py similarity index 90% rename from vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py rename to vllm/distributed/kv_transfer/kv_connector/simple_connector.py index be22077d410b..007f2d0fede4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -1,11 +1,4 @@ """ - This file implements a simple torch distributed connector by 3 classes: - - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, - using `torch.distributed` - - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top - of `TorchDistributedPipe` - - `TorchDistributedConnector`: a torch distributed connector between P/D - instance, implemented on top of `TorchDistributedBuffer` """ from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -14,10 +7,9 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( - LookupBuffer) -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( - PyNcclPipe) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe from vllm.logger import init_logger from vllm.sequence import IntermediateTensors @@ -27,7 +19,7 @@ logger = init_logger(__name__) -class PyNcclConnector(KVConnectorBase): +class SimpleConnector(KVConnectorBase): def __init__( self, @@ -43,8 +35,8 @@ def __init__( self.lookup_buffer_size = self.config.kv_buffer_size - self.producer_buffer: Optional[LookupBuffer] = None - self.consumer_buffer: Optional[LookupBuffer] = None + self.producer_buffer: Optional[SimpleBuffer] = None + self.consumer_buffer: Optional[SimpleBuffer] = None # 2 pipes for every rank in the world port_offset_base = 2 * rank @@ -64,7 +56,7 @@ def __init__( port_offset=port_offset_base + 1, device="cpu", ) - self.producer_buffer = LookupBuffer(self.producer_signal_pipe, + self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, self.producer_data_pipe, self.config.kv_buffer_size) @@ -83,7 +75,7 @@ def __init__( port_offset=port_offset_base + 1, device="cpu", ) - self.consumer_buffer = LookupBuffer( + self.consumer_buffer = SimpleBuffer( self.consumer_signal_pipe, self.consumer_data_pipe, self.config.kv_buffer_size, @@ -239,7 +231,9 @@ def recv_kv_caches_and_hidden_states( if not bypass_model_exec: # Some of the KV cache is not retrieved - # so we need to adjust model_input and redo the forwarding. + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. logger.debug( "[rank%d]: Failed to receive all KVs and hidden " "states, redo model forwarding.", torch.distributed.get_rank()) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py similarity index 100% rename from vllm/distributed/kv_transfer/kv_connector/pynccl_connector/__init__.py rename to vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py similarity index 97% rename from vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py rename to vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index d9febc733c5d..359e25b7b4bb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -16,16 +16,15 @@ import torch -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( - PyNcclPipe) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger logger = init_logger(__name__) -class LookupBuffer: +class SimpleBuffer: - def __init__(self, signal_pipe: PyNcclPipe, data_pipe: PyNcclPipe, + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float): """ signal_pipe: on CPU diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py new file mode 100644 index 000000000000..4b0cb44cc5b8 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -0,0 +1,65 @@ +""" +This file defines an interface `KVPipeBase` +that provides an abstraction for sending and receiving tensors, or None, via +distributed communications. + +All classes instantiated from this interface are assumed to be a FIFO pipe. + +If your distributed communication platform already supports key-value lookup, +you can bypass this interface and directly start from `kv_lookup_buffer`. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVPipeBase(ABC): + """ + This class provides an interface for sending and receiving tensors, or + None, by distributed communications. + """ + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send a tensor, or None, via the pipe. + + Need to support sending None -- important for error handling. + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind + the pipe. + + Args: + tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive a tensor (can be None) from the pipeline. + + Returns: + Optional[torch.Tensor]: The tensor received from the pipeline. Can + be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the pipeline and release resources. + + This method is responsible for closing the communication pipeline + and releasing any resources associated with it. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py similarity index 99% rename from vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py rename to vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index eadf50d4f304..98222fa67e49 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -20,6 +20,7 @@ from vllm.config import KVTransferConfig from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger @@ -36,7 +37,7 @@ def __init__(self, message): Metadata = Dict[str, Optional[torch.Tensor]] -class PyNcclPipe: +class PyNcclPipe(KVPipeBase): METADATA_LENGTH = 16 MAX_TENSOR_DIMENSIONS = 14 diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 36d4a7fda424..9ce97851dc84 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -1,22 +1,8 @@ -"""vLLM distributed KV cache transfer API. -These APIs are used in `vllm/worker/model_runner.py`. - -Currently supporting TP. The TP between prefill and decode instance needs to be -the same. - -Workflow (disaggregated prefill) -- In prefill instance - - After prefill, vLLM `insert` its KV caches into a lookup buffer. - - The prefill instance will also open up a thread that listens to - `drop_select` request. -- In decode instance - - vLLM first runs `drop_select` to send input tokens and a mask on input - tokens (we call it roi, region of interest) to prefill instance - - The prefill instance then respond to `drop_select` request by - - Finding a match in current lookup buffer. - - Clone and send the matched item out - - Delete the matched item in the lookup buffer to free up GPU memory. - - The decode vLLM then store the KV cache into paged memory. +"""A centralized entrypoint to perform distributed KV cache transfer. + +This implementation is a shim wrapper on two APIs exposed by `kv_connector`: +1. `send_kv_caches_and_hidden_states` +2. `recv_kv_caches_and_hidden_states """ from typing import TYPE_CHECKING, List, Tuple, Union From 4fe6fcc2543db5f61bc950cdd7bbf12074ebbac3 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 21:44:21 +0000 Subject: [PATCH 29/33] Add KVLookupBufferBase abstraction Signed-off-by: KuntaiDu --- examples/disaggregated_prefill.sh | 2 + .../kv_transfer/kv_lookup_buffer/base.py | 108 ++++++++++++++++++ .../kv_lookup_buffer/simple_buffer.py | 4 +- 3 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/base.py diff --git a/examples/disaggregated_prefill.sh b/examples/disaggregated_prefill.sh index 6a4b233a1a4a..87155273a81d 100644 --- a/examples/disaggregated_prefill.sh +++ b/examples/disaggregated_prefill.sh @@ -14,6 +14,7 @@ cleanup() { echo "Caught Ctrl+C, cleaning up..." # Cleanup commands pgrep python | xargs kill -9 + pkill -f python echo "Cleanup complete. Exiting." exit 0 } @@ -93,6 +94,7 @@ output2=$(curl -X POST -s http://localhost:8000/v1/completions \ # Cleanup commands pgrep python | xargs kill -9 +pkill -f python echo "" diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py new file mode 100644 index 000000000000..bad119a1aa92 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -0,0 +1,108 @@ +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +All distributed communications are abstracted behind this class. +""" + +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + + +class KVLookupBufferBase(ABC): + """ + Abstract base class for a lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + """Select and *drop* KV cache entries from the lookup buffer. + + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + List[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + lookup buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 359e25b7b4bb..81d9880d2f18 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -16,13 +16,15 @@ import torch +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger logger = init_logger(__name__) -class SimpleBuffer: +class SimpleBuffer(KVLookupBufferBase): def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float): From b721b5c182d475f5d5581cbd647338670b41a175 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 22:14:50 +0000 Subject: [PATCH 30/33] Add README and adjust docstrings Signed-off-by: KuntaiDu --- vllm/distributed/kv_transfer/README.md | 30 ++++++++++++++++-- .../kv_transfer/disagg_prefill_workflow.jpg | Bin 0 -> 142656 bytes .../kv_connector/simple_connector.py | 6 ++++ .../kv_lookup_buffer/simple_buffer.py | 3 +- 4 files changed, 35 insertions(+), 4 deletions(-) create mode 100644 vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index ec9c3b2ad8cd..76911bee4644 100644 --- a/vllm/distributed/kv_transfer/README.md +++ b/vllm/distributed/kv_transfer/README.md @@ -1,4 +1,30 @@ -# vLLM KV cache transfer +# Distributed KV cache transfer + +This folder implements distributed KV cache transfer across vLLM instances. +Currently the main usecase is for disaggregated prefilling. + +## Abstractions + +The KV cache transfer contains three layer of abstractions: + +- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`. +- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics). +- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`. + +Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. + +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or +RDMA database). + +NOTE: If you want to not only transfer KV caches, but adjust the model exectuion flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. + +## Disaggregated prefilling + +The example usage is in [this file](../../../examples/disaggregated_prefill.sh). + +Here is the diagram of how we run disaggretgated prefilling. + +![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) -This folder implements a series of functionalities \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a25ec5ef52491a0e3faf596669e6cf0e7c7ae175 GIT binary patch literal 142656 zcmeFZ2Ut_v)-D`+C(=7XKv5|{1nERTnuv%F604naBs!odHl%P*G7*Q~lOz z^3lQMdH^*W4g1;4nzS4y59ma_IOW2UbLqvjsyn&vjAF&F*gpwpVC3fE<>NnhUgE;V zOY#a=6_u1#w6E*v>ggNYFuiMLZee-P%E9sBBPVAUS8pF*KYw^YV8qkNsOV=gu_>u9 z($X_tzRJwYFDQIhR9y1@Lrra6eM4hYb60myZ(skX&tJyICnl$+XTG7(nB|q#??2Yo zH*kAD_YV$#;g61glZygC^;fa}F4^zoVk65%Nli^fP4}Bz6qNqtLd8Z+bM`VVyQT@< z11}CyxiET8t>oP5P6n|ncd%UcPevKJ#pTiGaKDN67s>wn1PlL{B>THy|Bwp>V5XuV z4;~d800JQVIy8pjdvhX?hv%IakF{9>D^Ppw-&9FPPD_g?F`0^_=iPRt*BI&(YHM#( zy5+k`a{{uQXAN}#=TTS3M*o$HFhdkLa;(uf3{qa&EmrenluI(%9tgcT?n%4HxEe+b3BuFNLJJh5|B+h^GMX$uZ|Y zwkSqNZGOCZj~w1R1uW+3{eA12kkEgf@b~`u-?o>yxCR$WJcLgH^(8C71MQChUrr0< z@2)vNiS57EO#WlzHl3c1clM)_QvhV(6c7?A{Etn5kI#Rd@YnA8Kc<(AyxvRZoiv4b z96MfI_^+Y6!4$cHa#MDCFxdH z*W%%Mt5d+b?*S*paS1^J`MHcJa0;ltqI(J$N;l|DUcyc@Z+LgP*A5%-PH%x>^A)jR1Y2<$MaV+5E?n!F<%2Wu;DPZPRK60H|VWVyp`oW)=Io#P!;E?w?IUDuSJe2W5 z-oIQfz}oldTFA|OR8x*zo&x>iIWkWq{}u^qKmHC=nI@?}I0bki(=@_2$kyx~-hVb{ zEnpRq>`XpI+>^kIyG9IeY(TXQajig1 zC2$?d9U=%0-yaOby3||kyeh3Ia&EZE+t1mUz-PFvdTV^KN?~8OkF>s-_-1nzA_zxf zklTi(aBZDwl5l%Bglb)7LS-S}`@L)ZFaL}-U9=#Z9)*V8_K$Pb@swBMuX^+(G0r3O zPg);hYmdL@zJos@ba)Mv-kZ5Ks<34U@BhMDz&Y@E06bVL7p`bp1+v?AoAvY$bjUQ< zHEt~o$!oigFF`vO;a_1}4n44gy}zQ=D9&ejmu^(g%o+Zev+6npNEg3(SwHopx9SNk z9TgX}&yHpk>EID_3Sd%z0@M%F&`3_07u>@dyOr@_&Pu=LfkTt#S^E^-?bhX8fF567 zr}vgA0TeX@OTt(?%V%A5yc^f`)vf4VtM70BeY?E0!*omfzrUA{gx{W{E#&(y&ar z9PbsDB74h_X1IFNP>>>Z+W;y_492a11Y7B_{$0YtZH^Hbzg!nz-*aUIMCGC5 zZ?s(!q)Kd5547mGg9yYMr+@pw<_4UHP%rJ2g;wW-3HgUGu}k6u+H~%eEG=5Z8oDi-*t_`~bTA z@8 zdw%)1*?R_rIi5VY}>KPO;JDQ_HLqyJ38-q z(-5XeLZZmPDxOe^dOiAK@5{AcTzeAR;AOmqFp&*!i-uCcGiB$t!8;le*K{oA9<#&nCD7WHjs zko4bdT$_UnyL+aqr9!37PJ}-j7Lks32PA%PC?kT4LI#m_JfS}G%mzhp$vzdqY>=UW zgny#s?43uUJBWlqKF7@zEmq?@9l@rx+Y*9-V`c6G&&%|RxqeDS`(`EagiJs<5MCFu zp+KVgsBXjqwmH!h-1IijEzZ~OX7pL1D>u695J%V(5?QqX~V)QMcYg?Bxw!o(V(f*!O0O11vBI!k@AKN;l)}rrX2n|93 z$)#QJu6DEd-pon5VFaM(XMs)wjYmt1Qom9^*Sc5xEB^Y+jZ{!2WTz%%W$rs!6-Y^^ zfWlFW(fU&WZkLCUe{^S!E)oW48?aGGe0rFBv5rNfNpJ+a4uq-F5Y^G7u^g$2E$>t~mOC(nvGcYC3>+L4+VZyOC{dI7GT(kwL|} zIV$yDS)23I0rs{Iyg+peU&oM&g>7y2HKlbs6N1Dk;Bys`{J+}Th-^4IZ}yX<;X0&q zA+v$NNcV#iZ@p0otWodaoQm{MEA_F$Ifu_qc3M|Y0a@-$6!A+FvnDy#ALN2pB@Z}r z`Klu&)Ri`&0^rd7#czs-#dwE&Jk75BDS#G!OYhlhdsA*hDsUy%rtQdbVV@YJ$#bMuJXjvZ&~i>~!E{&#qF^Sv%dL_Oex_-EX2 zI$Se<)Iptd^Ul`&PJNMM`!*1M8UqAwfP?Wyh^G zdv2d=8B#^pc?w{hQS5ZGr^b6!9$`r+^Ph;Y9BOyH2`i2Uqe=-l4aX zN83jEYiG2s#-{|bs#=V*U%1E18k!8{B8@NZJwURwwUK>@QA4QfAgS5-SJKatAm)dh zCi)L_IB!$@u2A1WR>a{-T>ok?*ZwJ@dFMmg2L16^u&xebkf6SR?8?P|#-*a`OFU__ zJY;d%&k5aCRj=)R`}{%DUGevqLuhG7EQM!r-u>)|%lHQv5AJ?@1)k9@@sHLP&wFY( z`DWYc+=D$``@11LkRD(Zi5+({z104AC&}r(!u!wP>2G4i3#snA&$g4F@+)CbF?f*i zN3tF52u)n%RB%vg$L&E~&9OtQB2`+WvzHuGxL)vV#>*{vrS1=JEuR9;Z(~bIz0zW* z0R7x|Be5EC_;JNZ4Gp3)-m_W7RvR+_BScRmeAcZa$OA7>sz--xza(_??b^MIYf1}r z#I!l!$pa`cOHe0U=BhSieNx_BUqen65xP?Oe(rAlAee#s!Tj)nZpC9ty!Y)`8!i^M8$t)Pw~|8FHrN?DVVjJ%QgI{4}i4 zreUL)prCGzkyxy2!RP10^={eU(ipfe@zaud`FoSxneO*t77w4ehr~2*KeB67DzyAY z+DJcU^p+unw^b%hwaMW1usPZ%AL>B72t9oIaa7a0{z_Yy5YUJ@$ysWj0DJQO!==72 zHw)4S1f5b7bw!-2h_v`z3`iJp2S1y)sd(AA2o8lt7>{8RY2tgCwC+X7@;%WXvR2q0 z5RhBBC^o@}pFMO)0Y)6RZ=0xrgG*6(CsXUE1fOm7uj@1~D!2+do$PtsJq?yiD$iq8 z5fFSOb_&8pygnWx0bAtob^g&3-4p1S_TE7F#m?ZWbWHq_mb`#hsAj+VMww(z0q6IG ziHB(~f@1NNJvrPVBKWLBBp!?oiksK&wv)r{RI~|A97x8QhaalKiJkIa5@|n+-@Rol z79f_bS4&hrc^)D}l)$;?V89HTrEAA87}uXTOjZc1>v@VPUifb7{J-ow-fG8B(9|Of6^`9CLotc9iU~Y-*0k^X?Qj)K!O5H4ODUG`w*% zXXvBmgQ+E*L1<2EFmL7vl7tsGc(TuqX@J(T8zfN9Z5A(2X~$TLy>rh68?}mgazz>* zqGx1+{h5u#ioV=nr%=MQiXDRSg&5;_B;{kG5vFBk_aLT@sgd7#N#DZpiQ6?UfRO4z zCHN%!6!376V;R(K%u2k8k0{$jE%t!fyE4D39JH0Iw>IC-0CM`}TDwx)=2NrHFjLq7 zJE5_eIh_#v`TBB90@tW-X|V1N>WfPT+~8r8C9&&rxS65sFdvn&AlVUm8)CgZUsr#; z9P1Y11+VN5VHiPBU2Rz!1Y=qd?lla>y_nJj;9YiCp7*o!=ggttfmE3=za_dKVj`qS zIP`9g4Yn_-#MWBj&HU;MMGk8N0Wlr1zOqkdO}ygq!tfH;4wq3K&ZytR8C)zm#CSI3 zG%d+jLEGsRAVv*6=xl!@tHOUF?>Q)T5o^_cq~gJZoOwlFzze4*O#c6g5^yRk9#PGo z0;aFrn`*~x0ts=%Qa#&Kz;pG*6K{n#~a{8NzO@Uqe5vP@07*Ni2%J84uwkuY{dA$u7GK_-f$@^fpfy55!|!P4Vs5 zU%DFvDz~rWmj$%WFutYBF7Q9rh?NoR((6!dkAZR_pcrHo+L7J+F9yVqodG&R9T(bL zZ_`UW*A#x5V^v{|(F=WmPeZ>Obyp~dX+^FYSpCp!B}jLoylX-jP(fiwqn;Y(CkZ5h zVGmplhNf%rjdf=mMf-%BGCJl(vY=gXM>ggRklEqrcGdEDxa0;jE)W@`kQ8Niew--m zmfG_xUC!Flz+Y99^N!ND?>uJqS81hWl`qlp4H{vm(Us9{Y?9Q(jfi%>L!>Im@P%AtA0v^T~1}m9JsB69k#^xVkLTofg@Iv z3T_4>mi&9#g4z#zo(O%V5L?woeY_K^)@L#zRWR-&Bsb*N=Jv7dcH4Mz)9s6$`=2U# zaK8Q9bPn(F;r%zs%DfB$LJdU5SKAvi2e)${L?&X`g&vwMD=9zDenZz{at$Q=jnuCg zXuP$I22!stoVPaShZ~#nlTs_nbjmSp>^0ka!x{!GwUu_nef9Qm-(eJ;$yDtCloRjW zfx4Q(`N>yEci6P4F&!S|U*4*E%|b%33BZhnS32dzJ;UVyiDJ_}i>UElB>N+nF<#FX z%slISxbY#?>qkL#3wxF)4`aZPH>wut0*tiE==A_t^`h6&wK%`%_g1h>q2BHnG45|) z$ljx05xaQbYs=FsVx!i?EBhWngDBpNe~A6O-+m@bwVi#F$KUGRM{j%UFWDOy-GHk# z0DX$^MB2-}0sdhj!bIQp;#Bq09#L1Yj;X9T z@m2=p#nO@|IFbK^;{6l)Y5jMvZG@)z{|)Ke+eJapZl{3Wq^R!$+UWvFPM5&Qch&jE zUS|`%+CoF6t!0h8-2B|l zAuVpe_T3AOeR8Zym0Wm#?CfDY{<8Nn6Nqsg!dfRSV2KmXE!VzUiG)iqd0kN{rJuOa zwVW6*e7>zE?@h164{!pMIph-2^CU~48&%g3S*8|(DlGqz>foz1-r6XWt|KsSxRMu- zn>Z|0F)t_F32ytMy2y%+maN?oDOf5`@SFa2F;UI3bf=DegKSy(5kVZZSTDeHr5$^yia&n{{+Wp zA88VM8IbnnH9HLghx1ZIQu}#)TI!WcG{#-ghP8t@)2Gz#$$tLT>i$^^R!7^n;04}JNMw*Nrwxz}pNsYHPY7}OqloYAv=2ez5>OvN~AU%;)pJS{3N zKw%CY=5hI5a+)c%8|pn!J9UUINQZ|H1NuNlr^=B_q=@BN>42?lTkl?#s2bPD-(5I0 zm1z`Lbp;aj$j<#&IL{#rcRr2)!snpv3Vn;-wZCV%*{#s}^>D+fF~EMVmqL(>ZdK{? zCe+YGrw37uV_M4S@rZ#|XGFJhc)we>@=x-CR(+T$Kis?*nTra&K|zr%+e97K0+AJ* z9#>(RDh$q>v4L@Wxg6I%7n?_ED!vgkW&8QQjdqnzNohUyqYf2Wt)Bl`n4isn)C}$` zEG{?&*zzpj8Tq}}^OxVyt07Z4H4SlXQm)6W?Mo~-`VFS@t3F~$MGE_o+(uVd>Fo=| z*m2eN`AKa7l?7`>Rl*K94)pcR=&kiKgfjN4A6C^_OmBS5{Bl>8!`+*&**PiJV4qu= z-jOCq6j(?!0_^k_A=OkfDnIKc$5)}iv@55Nj0?8ws@%L~Z zJy6aNPQ-O=WKrZp958af8c4B6ME(?*4k2?XrMMdx^LpgEjUaVTTXl;x$40}R>^(>1V7M6_Q+YJiI*{zRPfY0 z!FEiLGj@h9GO4~Y*CQ_Qem=t_?Dg{uEkIfNK%LXZDJ*2Tol?zsNyq`SI-KY$q-x_K z0M&+%uYVbpbZ>b0My=+{rIrs8Uk%KdJ#W#{g(85s*(KJLY9{ z8poh|!O8lLs3UJGCpw5CU}>526hN^7spT%g1Pg@FMQw&n>reZ$O#K)f6!v4ERzLH>+jMsS3xMtXr+V zC!5`(;bkcrmL+`DOU&leoyczsrvS}yVo)hjccPrAvkV3aHP59}q)ajS>gsRbpC6iF zJq29m6DKhK3cSz=0wb>Ay$RP4V%Wik?gaKgfPLlMeMCO z&0_7+!%#sK-jHx%e!cyJMkzL^Z`$vJaq(yPU_cRvXpfzVxLjxl;>(xvh9ch2m4$0X zp4C5xpDv#is(}mE+qkobJDM?8)OE%@^H7^ym=7a$3hL61OKb@m>S8bFl?v5*w;k6S zg+P4G2V}1>zXtYI416f!Fm8WlsS)F~}Q$Ki}sFU+lHx#>4Uk zO7S(lS2YC3#R$_@g>#Bt`s3Zvi*O%v;%0 znXWm96^~3BvR7q;_+6k6Ko+CWvkg7NsD|Dz*0-?WFa#BYvPs1z`{EqMj`Za7oF3EM zu~UF+OD~Hz?n+0!^vM8})!+gsPG8bYD}T*cPNO-m&EOR96nXge6p+`11`^L@k^6w` zQPGp*PnxFy9q%p>-jLk?!bJyxZz8fKrY0;tL$FDt3DCZU+FA0cHAvhB1PwYVhOAY< zwNC*Ik4^#06Tx;x_kYA3)N$<@6Ui`49o-4nFg+42wMY&CI*8OixpDq)F%DD&DfY)x2sYh=q^>u);l~Qku{ME@v+(wStV53Rw zgMSSv|3R4l>l>b2Ym7=^Ylg%wFu!SU0Y?dG60g_(ZG=8@?JAg>zV1vh_&>;))`rIbwV(fvr}P;mF!;?yEBDGlaW0UMx3* z;q>4zrGhMXB+TA%q0%zom&&bM=y$}a71%2RXU&lTiP z6v<2hLViYeD}R<7|4rl43@yg5bBDFA3&`Yu?=Pnm&nr6xH~^3E8T8vL;nrCEaFHLP`-)~S~yfQqWjPg0~Ustnnzs6?j(52#$MbJf6#1T z`J`I6wI(o)I(~&xJtX$Xevup@u1o9DB8Ngiz{lU+t(cI&K{<&kKVK>JPh4Q4D5MU0 z!Udu@1)#5x10h1i-;=z-Zmd5h?ph4AW1o>Gk^6TMbYxxo+D`OK!e2+awjW>8pdvr~ zkDw$~J^7&M`e^TsOOmULf2jMzwevI8yJxkV4!F4X$RNV+ZcJ7+5UP?Cn-qEqz`n{j z1tg9R3Zu7k=Z<4zkNp2wN!He+Z42@&q1S=L-y9(8pY!)5&%=?OU>zL$k@prYLQ)H1T!*}ZGr`nIONFj9F%hYoO-Ql zj)@6WhTH+auZ5u5PTrpa){5~!@>KK+iQg6+-4R9!4Hp+zW+4Ai3Hp8SCOZED{#Vcc zHZFhZdF1Gw-*^12&wsJ3E3?~Zi8i?T{$9zXNC>x-PotUxce+BB5Ul%ya`i;H|B*YT z%1+|-5PCJXaCPI+#hGqW$hnhC`;TJYRm4u1<^0FuiTtbHzRj!Sa{8O@g8q8!10 z_5Vpv{1d1O{hQrlcD9Zn@hsPN`}}=}^y3jRKXP#S3eNfuq$S87h)r*Slg7ViU)!1e zk;vx!2OII9=vx1+I?Aa3AT7BDLr9v3iwB>P&x`*Kz?!a~0``FQWS)~DF_O$>vivLD z+PD~(!-J5b#T8Evr>}x;3qJKgBq>OYnWF+(HF=*=`U#|%6e)X~~@g`25IsG*| zaK`P^yz>4d-eEH;7MtFjP~t;^BWx;L0pBS!6o$p{>GVgFcuSwxTv(Oy71balhl{SM zc*=p+VFChfGGCZfRpI5*BASA?wZhmHgt+c{(p6O)TRMRQAIX)tv7tjaPZAO5SZhGf zNFvY;2sh%vuPk%i-FMETjpEvFIWh{L2#pX^yIq(YOG26X)JVUZ+nHIxGjFeJuDw^H zTEAmr7w^DZs>55qfY#wOd-l&L^nYvnqFHSH-87B76f&v=jmi1Y-i?g4Ff3kNV_Vk{ zNof6~u+5N=K$P?Il7f>>v5%Q6Qmx*&%|p}s$I&Pab)szaD{&o5qQSfCS5m_wn{Pun zl9q%*V)Cj=bh#9DO5{C=LipvS$|O53<1=sx>u8G)L!U8#g2N9EFTclBjaKc8ttZ>A zB8+j?hfmrXa1ID#yn1o}QckSjvQMP0{D~Lc`Gvi+o%8VwbYg9uj0DiuFu~MiQCRvw zn#p3^5sU2kNXNFB&D9F|7}j{QMQ&nQ!B^(GOB8qp`gNwq`BxLSB=x;DyR7CxsfKuI&tMH2C&j6mNr@^^N*O zn3q3aifEtn$j)s|G3KDy?`pvWozF|q2w~RA{F)>id?aS@>en?AsE%4;b!9}X8YO4W zC1v0+!tA7QLwf5@oKw>%R0BVbDe3_}wd(I|B${1K3Vw%=!mbO8HGQb^;r{iZ*qQm} z!?ONm(7fRh0O`0U}ZMrwL zPd-wmW#{TVbg{6o@{fF6fzb1HVcF+4`CMWP-f^2Ui|2do4iTkg4p3u#*V>N;ZcCzv zWBGr@)izAlO7fAB*U$MbUgyi12Hsz_&>9*=pc&fp;e9N=z zSlb1=eok*)3r|NWU1Ft}yZ5!Sq43z@5jE{Xe5GOZol##&Bj~M{$wHIy1M;^tJ1Jfy zKZlpSvBvmE)JmThvM|9tihlUcNV1b_1r;yJnIgvJV?H1CP$Y`#x#rLXs&{#%)*F8Q znSs#R`Pt$7xT67I!S|!`@FDEdvb*w-mnA=4FPIiyX_;X#im=YzKotZv+`H$>W+*%R zE3w!9mg;gLz*oq*GD$-O-#5HWx9y@B?u0M*MFZ)a0~x=Rq8HEfKbl6Re&rbezYd6N zO?X|k*U>RxF%fbWmk}Kzw4R=XF=n$|8<9QJ>s4k~Agy|VZ}ro{{E)i$-DMc(GE9hu zOACC)#h)K^IAQhXMW*8OX$Z=|GivrTAU$=E&b@=;Tumqi*(SzY3iG%*+!)K~2{hdZ zEIW_s8U;JJ$mI8JU22ymdML&0yqE5$y5ajEjFp zsr&Spj9kp?8n#Bq3lhneNM4PLGaF>0QcjZZA{{>OOYp639T(Dg?#4x_^gAC`L!4(` zxC`)gE*+b$I+#2@43zv#4l&tc4~n|lqt4*-`&1rloSP&n7Y>bX1coz}mN?cw%eX(- zSkvMpQ6f2= zYq+aLCF>~Nx(Dx}>xtENv-jo-K3HEK;_6DR#Ujt>-F8iXg@bf9Pd@``jT7(XF2eNh zS>LLkc#wi9J{Q&TP5k}}>r3Tz-Q1FKlY6LN&-QbeT(*(ruJtPFvCmnia-TfE&alKX zbZuDd=X(+tE<0>FL6M-{-8>o95u%j_cY69vM7 zavr0GpwgCss4i{(_!>}m)ylr$boU}~KiKrJ6Cdq_%juu;w#G}|OiKzSiuts+=y&!X zCUgn&TXb9k0ssLW<+~`@oz%m(_@584T4OH2P_;`UQ>`B2Qd;w0^@D8Hs*V|WS!Z@= z)pzJc>%uRH#GDNmw_J`@U@U%n|58{7IrUUHZ80i;aQ9PRL3tiXhs)YDsCa+2zdGSj zo2AWc-wI6t<(tryYn_i-^($|8q#HiVPf=R9YassV`XfL+?JYV4V@ezE$!iT^1PDKm z&U&dK!l!_Ye=>K=H*Ko&e8ipl7dJYuU$T5)rnLm{dz~mt6KvsIyaugV>_XPURO)lR z^PI1ujpl!5Wi`=mi8_@woio!5p>g8#mfYuCFtOt)3K=^tuPDiFk0bGJhA?2Kn*7qO zLK=dDs2AH1&!#ZcM@!6}6p!PRH*Jjjz;Pi8WMa#PPXvjj%s*L7le_h zje%LRtxVEtCztvT-E+){x=cpO~p4q!wkbi#w8qfcz9GgDge=L*r9(L@znrS;bv+may z)THPswm1Ezv(0pUxbU3#A(PNN7rGu&n_m46-*9b==z$%u#L8q;&Xp%-mDxc%pCovm z0yZ$O#P}j#6W?A|OWsWWwurEa(8mwHLugFYHa%Xns7n87IC-;)%reC~HZaxNUu(%J zK^xx;z2!)U)bWI}kfdgbx29Wd-QKx;IG>Crkh3X!CsFquJ~o=Xy6#HH_%zfq+u%@X z*EO}RgVW+GIq7Y;$Cv{T?jrFvjO@QLajWp% zymE2xs4?ZaTyY(KHyd+MZuW?6cD=0QB}Uow`;G25?{(?KV}G=o60W&k#@F1#?uMyR zyhmJbA7wmWqF%Vg?N>i0JHK-D!3mKYa3NnU%X(qZSct$=oR{ck{^O&Kb~10%fMa&6F=rR+>vwh!jmYLT`I_C~&()1E#eXayo^2SdI2cgZtR2)kkSd&TbNiZpc&_oD zi~vNG{n~Ijd~qupZ4IF&QEa*gUg*L?%JJvW_JVna<)`u9FQr1IxSKTdxme1DvudvJTMc1%!XBb+sf7*9 z>jr%stE1ZM59T>n zRuSkP(<$P#ldevouw8ri%cUD%UB1mT5adEeP)^%I*j<7c%&hbZldArF;B^xo(8k~ubx}zi#PA}SevG6`j4L;pc9<=naDVmJ2DOw^N9hsA4R1f`eY-4IIZ*=?S z6$a0vb-m=NZze{Iy)QYANc6UI!6fnCBaJR7oxItP(_MVv?XKosVqHopW+`UE`Q+lK zQ-ED>0BCOz`$Q5&q9RI7kl2U{5uIcoCn;|GabdcuWLBKFVBDkri_`@z9+K&{b1W^o zv~uvAZw-fL&~IesrQO+y?>O;nG3#Or<1vTy*)J+?q1i5^;g%qeqwf;_1B>xz0GS5i zE@Jb$wFFzMFPJaKD1lkYmngpN3Ndy)smHRPzGyNRu+yuZfvxP$45N_|;!|ysu3$6^ zMIDPk|Mc_q47b9LM^y_^~_M3FlXPV7X--NO%9u;dyhonHr5D(Uy8RL(Fxn z`?BjWYB}u(j-X3>0#SP+M{Vn~t;iylIh&x18{2QEJZ)IuWF?{TBY+Gt-rtdNF zMY3PU0^;k@t05lV-OSZi7*?!rlxRQECW27K4dz6-AC%HRR{0A>l}X2Z@8Ii))DU1@qd%>wos1nUUHJa!uAxv5wIe%$==k9j((!$PJKwJ=MSU z3+o3BTVR{0z31yVmHE@Eb@J)@`{^51839AT5$iW~ zX|Df>k$e*}YgJus7IX!ia$QI0wbIEH!>3R7vX_d87d)N`E8+Z~DnMd1)W+&QUg>Hs z3+yFRJaSK-4SVxySbzC$Awr9qwMe8FKU7X_D^!;7eu^FLTBMl^S+u)ao6*HjonD2v zn>LN*8C$w2G;7QTi@fuaJ*DRr-B)t7UlJ{S=O>1)e=#mwVv@vWhq&83c2I;piGMP0 z$>eqkbq1glwV_G)e7Wg0*dbzo@fd1=SC+{jT!%+sye7x=eT`U!cwLZ>&gNZiR($n+ z_lN4`2ItMGCL|Yv0;iFH6S#}3>>B2^#%mF*&(@Vom{cid25;q)>?rB(qzZ3c$>eUGta`P8A_*tXFSZ=JYdp!T`;rbZuH~_^4BZv&CadElU}H!8QYl%E?8FrcaI44)=)3UoJ{mk_}4c9W0eQk4oqax(@F z0bcGYM$Utxw9c>N`}6_q1&)mExCrzqAd<*LkRWQ2!y&M4D@b(6EPCu0^xIW7tD~)I z;yYt$U5w&Yh9O0e)jqkYIn0mS@wCVmdKmfV;Uw? z$bJxIZyxs(#ki#_KQHOGUPH}|#(*Oy{Fpo=ojl&$=Au7yCRfA8z9eLXCu2eKBtC@E z$QP+*@vFHlT3}Ik@uT_W8(E#CL*?O0h5gz^hKH$#-X4*QsQ2eJ*vXPP5G=NSxbId6 zXASQ)_zVx$_}A|oe1f+eIyK_-I>FHk6i~K)H*loKK4EHtCHY|!LDuuG-glO)2H!ca z#Dg1Fn{%$@vfA!4%(RZIsoa{zh+)UC*I}nh0uQZjq7u^{Coc@8+^ptFdF<}r-9EIX zFLNuPQVCbo_2!00r1>!@4D#s2qAUC2*jGg>NQ}tDepmp@#N)E_DM;0?&tD_HNvNxM zK_3urU|ExJ)}1OpB7N7GQRe0DQ$H^)Klzb=aqEkx=3(Wl46=+kx`rNjy>U9bOd!-loW1o5~xwF-|n{@PIw87!r-Y=2t9b(6J^5=M|!Pj3N!% zUBUZfcG(sX()bv<$W~>HVPAXusIT-{k8~6opJHvk(855W%PV;QiB_1cM~87}28jB= zLug?R|8BSU2|oXJ^zEQ>EgHi}-#FMSK6-Ph-j{B78D(X?GQ}se->#G$s0Wgf-gx8e zy4@pvy{=sXM_tm10820odJeX@_dF?b@lXN=I=;$(Jk{(_y)lYEW1_>Q+xfd}nZNMH z{?&c=t{^`Rc#sX1ZQ`cY2Qp70Ien4zarS7O%8DFrp~E!%;gtY?d~Lur(EFU)!$>{3UUJ#MtOQI?}Y(#PeJe#>p;@A}?a^bJHx7JPj z_D%abM={o=g~wb~B5mq_f-IeXZE?kbYNVVXh5%2H6E)67-^NKQ^e497^OAVhs^cq5 zY0=WCdZ_rW_pJV!K#0mTnS~XEi;Xo9beo^MlwC(oHbLsOO@a&d9+$lhidkW_e(YTM za)~*cV#(|aJu|(S2v+AdNP~bJMa8Zbe$B=)`HlPy>2@2oG6K=po02WPYq=25HR*xl8wu8j z_ITq87!)<#4&1-!5qFe`88tcQ)>$Je!>#>X1FNvM>F_vRYtRkD9S*r*{Z!cUri1&! z%YNRn0D(Dw4d-bR>l{KKo6zhEtyL$LqUUeuJ4X9(8wz!Mm_MVGHJGni6|>D*wlw+zR$V{ErqAD;7Q(h*CU-lBXDqr@ zRwHW_6^E6_r3+fct!Mc>9Ah442yST((Y@t5u!u&0^5A;iRcrh*X|Tezs(d}6gfcCq zG;7sRlg-qlNmY+15?3wOgD7eON8S?VLw!C;ncA3|y%#4tqpNIhJTdX4Hm!yA$0umr z`E+?86Z`@U$)f<|#pnI(vGT#KM+L#9=GDdq)vi{{+-3P3J~TrYF&LP!?d2fRb7I&H zVNnAMg0NMOatL;_<^jT2-p*dvd*%I+R`Yq}GjMQ!e3)HEUdRyH`BqbXN6ulKflKJc zP7lQ{m8E0MT||}R(-GAu-D?tnm0t>zGk8VqafD$pJZjdkF4|^7OuyoQCO%+XGOm@4=MMCI^(xb0U99`r+|*Q zVbb~Xp!Z){wGTIw-AG*cTNv|i)WTQBR?~F})o*`(_!je{ciUbw z1x_~;miVF#h_^t40eHVM}YEv0cM`qvFCxB?qA@= z(QVCRRE;2J)ss62L)Y{`_40zxQW@U*>#I+%c~8XS_I~wxB5S9jF)aq|P9Fs{KQ`3k z0&VMFDE_*3Q4vzK#5-xO8t+PwWr@rKnjiH7( z-W?iel@sbqfU?;_jq`)EXO6I6XY-oOUp$Nu5mFM+(~KoB-NQ!3VT3I7CJvb&lS^#{{52!;a4!9zhhr>za)IDFC;aCK?Xj9QsB760AZq%!n)M7BG%(uhHux%p!RPmJQ2;hFaU&lSOT8axGVQlUC=GhY`~cmTe3O z%!}>uoHJ*f&Md^0QDDn<$QT&)N5yG8iVF+Zy)zd=jY( zwME#9Bww^2b_M0u+RS7X99H`)p4X^tFxzgNs0#1ZxZQ+qws_hKT&^U$;^^~nn0Y&C z4inNqP0Of97VEAPMFpmbq-Un1fqje?eJzxyPWt(1Kp_hD4#(+azHY1st@zcO4au^_1T8_NhW zAg?VMBP}y~Z(L|fx3W>4N^t^%@7 z1c9kJSvfb3&0D7pPclhdYt0FDCrP35Ox6P~+|G{@KU0a7e0tD-` VyoS(94Q7bHvy3QPzj(r^spo3ajpLn~R&DgWWN)+dLQ?s9eREqH zJW4W7V9nh5kttCso?dM8T4r0y>`tsZ1Q+ zzC?t7Ns`~$+}pgvF*czup<1vHqC=5+vybGFUxGY3+_zPB^g)pt?WPC$6Y_@<-^*PjUS-L z0Ny$C6>=Mhb)zA_{Q=%XQqb6U{{ebSy$mqfbu4~>;tPI&LLZSo!@ixH1{SDRF>=R< zum7X+KgZ+0c07jTba-2BS;B=aACI?HWTbS#^Z)qW6r+ggM<>V_Hmv*gIMph{F$pR-}vVsFq2G@3{i3(`TWqTEPNiuFiv#66R#d` zRCROd4)NO42luUwIr%**uHkZhPd-Xaj49AaQhck%Vkq{rUHs2_i2kNz?__*?Zs*Dr zz+7$w6w90d&9DDtLdLwOD=$36l_)G*i)A0#%Ir7;2Or&#&H~ub7O3hep;VfKnnrNv z5Fp$o+Ji7TvP6jCU1n=-3*ZsPBhr4@JWIf^dwo36aW{l(;xs?R+C(*crB*jJ(Z>7W zLW9T(fL~~a56BrC-LGYI@S?&uv9k@L@5T#JygcCK87Sc0WW@$tJ+5wHhO>8ZWrl9= ztxq*;+)C&h#=QIH*En8+;$-e)x&16^heE)IbrwE_vkwG9O=u0FHcW;nrg+t}yt#T_ zd?5Q z_GwN@uHFa%%9cURHI-5bU%()QsfB|WaimI+3mdvUvhj^rEm!4amg5s*#&?ZW8ip8q z7~L8z2klF*TZ+7_BoRRnXW#E5>XZ^;W@9Fu?q(f1mg2cH(VM}YChMa9o!SN&&cTgT zhOZ$vt>oIaC&6tHUfU6k7uE8QC8d*1x4#`tX}LTujeIKfzPtXZC@q{i(~IMZE{n(ttr`uHl5$Pw;rd)H zsx6)7_B#yUg{I-n$1VoMmbC1d&*JmW3tI0{XFMUI*cKaaT&h^>Azj2$=7LFZHlBk} zl6w3ifLS){;OM?q;M8}42l5`RJjWhA=7ja_sj>N`1o0xr>oxWHr7JVG0%VTinpyP0 z+pc;Ess1}J^dDYbaTF;_F0@#`ZkPA=(Z~H-y&S^TYTpJ!1M->iI8DBbKo!E>d1;hg zr~di_56Zxh-M*b??~kTx_GePpEcT;`h)il6v1oEJs5SnnpBJ>;a_e|;{c&df4r$iT z4pt9HeEf9|d_^D$H#=bTBpN>)lzELYpXmPiVuoK=ZPNXM&XdZPj-Q|28*rFa1Zc3R z5M{>9E^DEYg+WhYBh%^LlCD}?$1Guq4M3nx%>NR_u)ehg9x*Ef@CI`+leSg2cZrJlET2a02 z=Gn+FoPf>oywtW@#)3kTXf)t}eVPX!@-`{`00n76iGWw-0jD8{x`L7OgJ7~q?C0JH z9P;nqLlG@!fw=YURL2}qI_L)|;tkV3^dS@gruwUZgK{PylX&EMxQKaga;B8@0DpE> z?pBKo{)jmg*zOY>AsxL@(d46q3*YfN%1%7yg^9qglWOf^jsK8&b2LTAC5(@cQZ4P>H5) zo@qo-dw?u}(m_cN6~H|#%Ez2gf;9Kr1zrxG3n36`+)quBH`7K&7K!){Zo_8}K^Ty?*cq|}HV z2KTB)%|5H(EVhOa1QQlMT~|0}H>esZ$vH}9GZZU8vI@T9k~X;n?n0fMhQES)OR_y0 zbec+mRvI5(RsJe)dE)irX?3OIfL*ilA+)%-1lnb@%vG_!vP1g0xvRKO%NJb*lZ2R{ zREcNxY97|S#1t%q-HF}TIDElH$Ygl)>hh$SwT4&)uFOb_pOe+2%Ub%2A~@e4(|P?< zOt{@^|7^WFO~dC2eP`50=bZ;&&~ehugicuXQYfIb^~_P3tpRVIJ%@4;M> zd8y4MC+cYI9kH(~6&a`~=U?LJleBNVPpB|yh`n%i^P-{B2YwEo)LUB5Yt^2HwnCyO zeY#Pg74u?i2RC@g&U>p_2R1p1+2Lg>R)KO*to&XX%9>a`pGm*2zjD|U>1lPAN>~tS*+AS(H$F+y_uVa;*dZBi(@YDl9 z#6Zq>E^Ijjr-w3YdS?>I8jdKKP-QIceyTXw(EMy)M6?U)t(OL8;XV_tNr#`VMp z-s+Ou_l)AX>q}o$EUCzy{#XW5bf@-++Cx_7cQe{!X*#ZLDN1~ijPT;l6#P_n=7H(` z2-%D!Z`D(Fn)jU#UHcs~ydEAg^;8d7_Ou`S?smZ#o5V}j7|_T9hol1&hA|9KvH+O)tG{s35nqqIZVD#g(H+(@HRmM}>#9W_Q5hp3ji$WJ$8fGr5&AB$o zIeb!IWOOGo9*OOJlHGD8YD$Of%r{0)o!Gx}hhpWVzgadsL31Q4T+GWtbZk=D!tX`w zw|&9qX5u=I!dcIOoYr5uIk0R)aXWp-OJpPW!dYa`*W{y5q@u6?ZZT|Ncl^`8{Vzj1 ze~^ZJ*Xj?0ZU95FxtZ%bw~-F0)m7qM$(^6L0kyh&r53P2U&hQH?*52ErSF7`qmjpf z_^QHveb=QPiTN}+*BR#BMh>BfM_}-A;bO(+rYa<~$+G~)6(keF z9)~*XkT3@n8rru+XKqO37)SHT*y*39ql=RItZhJ)n_mhA#1t~*2pO45-%@KM7i*IB zViwdM>9s+UG7N%Gs?XMo_vb2v?Ldh_!7agZ8_8IC&wjxrCv;Uyp%9}$t@=~B@A-{3 zXc%^COuht)@?qONat5l<>J1g>;L$n#jCM{HQ*M3Jh@M!u7^Z=a=6t%pQ_&cJZ->y# zIhlXj4n-K^bVK*APgSL6ps-4hCdQ{5{gDL_^xVoj&`ka!rXOzBx-rZEzrf%i@d}E7Ip?DfBc3fLl3NTd=^eyG%836V6#ZoV? zvL<0!rqS+k@^H`BTK+um1pN^UNN{M)kh!+>y;NGp&E90yfYnLO z_Zv>lM290mqAT_w`W(SKdqJ0#2tdk=W2>8k52C6K3uFSN`kr#?mzTE+WxTc(Y&0c< zSK>(anF~)c4#>>3Rgb<2hJJ)q`lnxfdh_H|W5BZFR<~4_Xc%7UYE_}k$)7sH@?=v!oDCaPV zPHifKA0;6<(Jg{G2l<`ssj4gQR0ky|UeKS5J>BTz+8g6z%UhUg(iTmxLHtxox*h1Z zM4(3Jb>&X0?wqeVqkgy0XsS;3?t>)$$;hbh`TIUVHqhb-$Zw8MEJz*;vQ7l_1V9EC zj)Of<8Q!VZtY&1}(XKMROCR#^l~e-iRV-+^9BqV00trVfhU$?rbWr*z{g$>X*BX0j zQqI^}ent1`ybn}Aotu2aL`mJ_?gYd3xBGKrQSCDdt=Dgl#jn)JzBB6*pN<%M>E0RN zrqcfz89JC|gRsqAxwVF9Ej_?mnj`mJ&@juZ`b`!WB*yP;e>l*YkbQa=dA~!@@M+m%vkt3J#k7>UjR81g?xkzel#~b-J{c5oa<$& z_*Wnsp@FGObEle0YA z6VsHmyzftJ-EsOVZ&`9J$|d4@GFS(_KIGO7e^cVN_N=InxBOoEPR%`d$H(Oy zpHvi`#MO^TM{op}Mpyp?6x$zNbuF)IV4Gx=rMvhs&3wvkUPGM{6*rO`b8VreeN@^B z3H7#G`9XXnADnyW{0hwp3VhYuHL7tR+_1%maQ0txmT>+>)U1UoF)$5(fP#LHUc zJ-S~r$m-|D$M87D-KP-rwK8E$Ugj-60c%0JG_bWHKtM$x9y}H>QyyAKzquLT6->t_ zJ9)|yq`534R483V383XeM2Yl2KtF}H%@Kt|e}KZX$sFi;$zx^+4RSsULKbY!i-^k* z2hd$6cy@9)I0%E+2Sf?ZYLKL0t*}D@WhS74_!E?*+AD|QG2e+)0&%l+e^kO?SLje>&ZwZDa zpf{5Okg-$J?ld4VO|A-{IG9K&=HK7efa@#ujT7d7)XpX#|2!XcWpyFBsQ%sTgnmLNBK zk#=(-02(0_MV4`#+=_N7nKdH<%2ZREy$@qJKA17RpAsoCT`cFk26|1QN8%opXyGDB z9A{zihp*1zmoNqCBBQ(o;+TPf3J8ZD_fogi>%-Io$si=n6jnVS*Oy%w{S!RXI@z4= z!o;Gho?_*P&r+jriGFz1skUtbAFt|~xUDWBrc^fjVq+^PS|D0K%f_8Mcs=I$l5nHs zE6e!+5Q!GAOS&2meuztB)0XsFiK$QO{qBv+osSul7(Xv?41A1r;qLmCZmc zdHID1p_>#6O3Z6twFN2)c8{FP-2PKcC~3_koRe69~a?-r((p1{PCQjzC%Rp zn^|~AC*<8|9PW{wzxs|}q`%`?2`-0#zQ;nwc0EitsVM_FVnUkfJ1|r&fj92$$V) zs%O)@;nX4E_ICAI{-q3R`BlRYc{8CNy~;9hp^hIQurjblYsp*#&eHOWV>yGc4YqZ< z=Bl_d{lMq1ZE#9*a-|Z<*mX$cTums-u*+C#?a(K%s*&4Yu>Cum=Dm3Mpvf!2%ck4T zQXkr^TPfCIEXrr)I$hZi3Zu#T7H6Y z7cj(|2$OxOyg99lS<_RnDVpgl&3+TZfrfcu+@lURpQ(ZvN)7y-p!v+HngMBE!74m8yA`#=;z`{O#@w$uBY17^mj+Gq zh1BT+P}YW?&=x5R0Ps%}U8ZPQT7!~J_Ix)faE5WGd!;<A< ziF0cuMG~ix5Ok&(SqW}~o`fgwd*Sz%n%|DYll2HP*|4>SkhhNoG8VL6RbKK)`c6Ge zb>@I#K6Sn2B!Y`z(zdBA$!HBP>*(Y=RT=R}tZG%dHJL7^?WLCB#$`^sHlC2>sdNBZ zX%C1BF%2TY2qj{%TnLF(`kfYI^im4z%xL=3HG7MvX*Jq$uWjo2UstQ@>VT5>1xF>( z%FJu#Ix~CS#jXJC??|HpQx~f_WbxCyMt}J?TwIKZWLPf-*sUDrK=2PV33>3|QC&Jm zRGrN0{6}u0u`n>fZNtH(DV-LOy-X92k*dKvp@N8`6Z_i}gl%C;pnm`-R zTAkobu~wx=d`cMiDXcX11_2vTAOPNA9zbq90*3Zw^|#F)5<;6yAbSBo|FbP1V)nvv zTN;EGj_NMABnFNdp$y1Vs5_Plxi;rjvd8he3I3$Z z*=0)|!rkCcM0R3h_i6HgGB;s4ru4S!8z69OUbASwvn^!MT<-Il0n*HKvrj0qNn>mU z%}5eLf=N#Zp`8>3B_9^(=`En;-%T&SGE(21SmbRo@xSSNms z&w22w&eu;qXZfK*7ftx-GiDs+`LgQ^@xV}Ho zJ@y${8{9F6U>Rx>I-p0?T`vJc}lS4S&mN!j7~8`n@qc zEoT65jhvm=F=~PF;sOT0RL&XXFouu1DJ=R~Dp3h;+vPr2$-YV9L&(OC%*HLGAvZF^ zGtM;WaCaVB<{Il(b+d54yb+~#sqa$v)z~W_p*D+j`vFgVOQ9ZcR4TYtH^jF;8llGT z9|?p>-UFF{>O#+%8ys6(ofO23XM-GbQFA;^oOt?n$puL+r!%Wbc1?|Nl{q&biE9xY zJs-FHX>iYB5^E5^g?ogYeDFO3B1Hlbk&=gv#p_;;1l|pMOwygcC+Z4YG{m%A4LlM{ z-ALf?8X!3k`?(zXt~KdQ3Wd&z+hZe*zg`FjQ@dX_>p!=Q;a`dP=5Q~uKJT7L`kCEQ z^o9_AThG=52Y8DtB+cRJt)zo4uFcNeDa;n`!EL|O0r{oQ=#86raTZ}fY%o?yKl}&i zOY!uhvXMLVZNNfM0V1wlGkt+dkT~YJ)$7ikB$;Y+Cu{wR-iH>p)LF z?KneX0OpDlP`#h4`2pgBV#CPP$N`xFA;|Mpy{zq>VStQy~E-k_8NMd;rR8<08wyIXUb7S8^kKa4Wz{zo(j z);d9L2%R?e2tD60AS0S4fBG;w9;7?Wbx8d$UlU^472kvFX-#3<#lpHSHF;%q1kN%M zBO?RXpds8NlkeSLQ&#)M&xlTtPXMimc!|wrALi#yJ0v5)%0lLO>zTCc4`>BAzHa8U zn&IYm>aKBic~&oF^0b^-f$&tFvT3>EpnPFyZj(zcTr5C)=E^((RgdDWrxbn7BJ}r6|?3EJjbV*@IFnYZY3kIcxAtPBkumR`JC=mZR>B1BLG?w%T{oj}RZP!*dVN(fZRqIe)9Ot38GFprQa}MEV7MO~M-o zJ)m(S!8oVkVWQL1^D*OceJy}WBaB zu(W9xw$8T1Pc^yG6k0NT&4!njT3WIJM4o4woh+dwU<&HBIeL z%NHt)4O`nDpUVCC>|3hnrxqsQlf?8w_|D+6Zb%RL41fXQlM7Lu^%c^v@rB&$^3{>v zDtiXN%xwOw7{aAc{spyU{dzGm|A+t1B&K~NjsFejD!x>jP;~kM_s-aaA z8kfIg2j9%_X4%%LA6#3nGhq`!lz-iF(sA-w=OS+Jx8qy#@sd5}KAZgW>-3Sh{2sn^ zoIl$~Y_-+Za{}XSd=D6_mQgPmUW@yOTk9VEgm4;ErX||zw2W4PPdg>CA4EC{?X$fc zYd_g0wky2ys=rY28RJu-8zZ^|k2Y@%(|y9XS;3+Ec*}AP_lCt}x|$ZtsG=xKJjFw! z<+s;lcI-O=lcpOm{$<7WTqDr~pTyxs`__1dkzD^Ir(O5q>7Il+8l@a`EJ=e9q&g~b z|2Uj}rNlc?kFCS#I`hH3E1L$YEJ6-22+<$ zb|gj#vliQaUifmU*F3WdrC>ZLE7|GsRMpnE9_bOAUQSIrB@UjT+b6RU?^b{qGH+-G z7*Nb-sO$k/(Q$Jff*IbJuFlQEzVZ>WIJ)I7Ob73=o-)SB=lAiDH?c;OjygkDbJ zIc6Heh@0(6iXuB&E5e|x#FvF_R1ja;*iPF-P~2+9OVCRbZhyk{!9E;yA|5|J{h@G; z>!;LqD~Y*F7qmF(;?k$Wt*`X}aamtJb>bFanGAoCBcxn*om|#z!K-0H$k&LD)UIbq zO7#|LFT=2C&q5CIkw)L>5>tDHJGg0}WXVZj4UJNi%XOO6X`fm3&CM|xO{O2zv{PB> zSStbJtPRUnj5CD4?1{7M>B@E()Fky{rAFmXXM;aPKY`+Ppl7VtGT-H)-mV#NdVy)z zOlEmf#&?gxr)&;It_$z$TU(3B#wUAZr49A~s_|-1xL*Fc3;lmt3;i!_ZTycaKV^bR z&K0l;=>99zg7AqkEUGi*emQ~V*?r!tCi-EIrzq*-3q+QWre0;j-umt^X%v)=Etqtq z(=04M+r!WhGrGJOX-sRKoC;{$+S(ZMiDssvqD(i1`EwCIc-wBS<Hw$i~LD`&?W168}FGFY|+$oT(st%2B4^~=Bxh8zWCc`nv8%ivloV& zF5>074?m-PE`}7^=4r`P51Y<#N8JnY8;k0gca5?ZC*I5_-O!CZjLr{V8sr}p`{r`J zS7u07@xC*?%pS8@OFG;GV;Fctt!WGW+=5C#+I#0x^ffhM;AG~v!wrtDmI=i^|sqi zpl*xkWPlPvG8xC-cI?;geULdZHWn*tJ{_;36&IAZ!SFmTP=jB7Q!N+mN3k0k21qVu zoWN@n*CZ#obqR7&C8~4QaD7)7*sTWDqui>Hi+zQlL^Fz+9XvH0>fpi$NEpKaGK+os zV4$?_w&)Rx7MgB zA0+JHN)3rlK#RIuLISz<(1aXuYn+%kf9z^_oKHRj4;;E(S2y6Jg4tb^1=nXAk2s{m4G1&sJmD(rTub;Xz z(K9fh^)pNTzTQXM;e&pw*eO)AuD@A$}R?u9`%hNTneW$iG;zppi^?* z@!lYc)j;(@r!8@Ajd%Dhlx?m<&A`O!=;e#m)lD~PSv243)0Bh2?M+Vm`p2*PTcOmm z{bi1nozT-HjVE4`% z7(w*VHyX;S)8u>^&smG1q|Ed^Ol?Z*%%tBdwb)MrL{!c>-FW-xP-Zi@w<$+=sV_av zub7i# z^>x(l`)>A5a2U|e07bxc0h$A5idueHDD5ZSq_j6H#>u)2Lsd8O>gxP#I5J;}6gu1R zow|m029!ROeCKd;ApsYON*l)NwbxC;HY9SS{NMEpSlSJDW^+ zy;3uKu8D}mHlC_RfhC$9+|IW5bElLh=5eRW#F>CHK(-{KW_x zGqoprZ~dB=;K;)0!{^_YVxLCxzI#ho)cEXskUP`m%QWRgGR&`1|mK{|M^hr4w4$sc$sU5oiZoV-y9-|E+|3o=sV@Yll5Ct z1w+RD+hpntOy7K0F$O)$ju3Z(-KBSK2+~pzO zowG4}S5Nax0O3BRpH#m9#oUG3*6q`e%)y&9`73iwVoh33^Zp(>uhwL*kEE})**`rc z$rD?({B@UIL__|ZEpIdTf_*!TV+P*e4s;W5;zi~n+ud>wQQGH*lLW6Q^A03Yo0e7`tx9aA+N_5>6fQblMy9@_dKk-_ zDVwzdhjf5xW`XvQoP?tr+G*JZS~y0-jVl6M?=-}oSSa4$cz}B4$yZ&9bB#Wjy?z{{ z2z{FUnI^Kyxn`Cl+$izZY@bk^N9$4XPG3J2jqyVB&3zx@y;6jD1!35{&%AQoTBK>T zG~mOfC+Cu`vC7%0b4%5zykGtF#QMx}nxY1F4y9A5_0(~WQOqlHo7urzRQ_u2-mq{~ zDc?sbRlBL(47%6$g2|%qIC~LK2-ImBrS;l*&FJ&q zk;|w0cPk0;nz##j#2X8eV7Tl4GyHxs{&P>4aR$TK#ujeW^OHd@0YC-ck5XhmlQJCJ zUmZ&tKR)T()9o6pEHKjKfXxgGplB?wA%ujoB)Lzlr3IJe+VP+GqEK-0l)OzVpWCVx z_*peF+_$C%))l}5}p-(dko^YmbAiw4JgU6=o#g7LY-%iWa2LU+ms_Cor{=_ zuLL^N9C(hEcyE)lQRv6!Q>&KHX9(7N_;oClB;oJqH+KfU1Ete0oSw6sz4N8&;q>it z0q%$kMOo);8td-s3O4KF@&orEVRJN17T9Qk)`pD2`=k3yUwbv`-esmd6JSzewD^Kx zK;c7m@DCfprR`NB)iIhDA{`s;oT~BHA!$^o_R{X1I!|R0{8m07-o!_d(3v6DOjyb$ zKuvOsiDkcgpGYJ>%z;G)wti-`3)dm$6aJt~}%9}asTrv3?5ztISN6AQrDW4<;&zn|F9eZ>@9RjT+@20$s zS)(1gcS40}P)=%f+@F5iZk!=AD%r5($P_^xEEsn_Q3iQM#+{XL0>}Thr}-%|CV(HV zyCB!o&%K%nWSOyoE%^N42Po|!X>qo}z0I%n9h_VCP6+0|M7?%EBZMO?qW zVJ_41`Z4LETv+MW{nXJVQ_kD87u)(S+@dhh;QSuxeut*j3WD{@tH$yQbxOhWTQM9i zb`jdoi`Lji=J^Ef+L0lULbI>vy`%l~ zWJTMD*Rg5+H+SZNGNJq~lougS%qL7rf*%D)d?86~zX% z=b;meW8+-bX|CC;H8pHu1Cx`T<((xPix-}O#5xmN*R*67pd5hRUYf!shbjvbxpfEk z*3)z@STP@H3_p%8yYCT3Ge^+S-?WJO+u} zBbLwWC}WwZ-M$#LnKe`Oz0_B8{F%x9b2v(2l0-S%lt!lQxnN1UPR#cLll&|Zx5n&V z#01m#1S%DyiG{dzSTKeW0l3-SMxkm$OHC9FSBo;mq4T=cNG|&g3jESGHgh9wUO`ER z(<}PU#9rejNj`pc(yh+h2E`S(me;glJHc(Z_7}0h)W?Q^vN@_x@*wTIQExs#?}jA7 zw&iw;U;ZnpFP4Av?|;ij94$$~EZWrq@`T=G&(=|ZD^5WK*w-hE0loD+eC=^~66Vm# zUk&p6Ca`)mYx%zi)1r;gBF71s#nAvp0IZmjja+4rS^M)2*S}`&{*RYWar4jkL3g=+ z=@wN8C~|$LnL>_TB&z~#U&VmKhOjPSAQ+(BMn@h9IGNlLAQ-4QE62{|UHr(~YPFVs zOP6xbn1;ubKgc>!4TQLe1xNFV5WzTlXb6PwrG@3kS=~_bd6A7FuVILn;N#e7TTYrl z|0}uVIAf-w-kYKfAvEVu=34>hNY{ZV3r05!cLK=qdCghY4BGbaw3itl!AAEco>O_t z?64K$p*W$9%&WTi-lt?qgrs2|QNgrss>C6@pUsEr{$1X<5MPF~$!OIi!O7d*cMr2y z%td?19OMxUkb{TS#PJmhg*54NPC-LRTI}PJSmf4Aj}Yh<8|#Kx?wwQTj>fV%FNx;e zD+ea7-d~sy{{q|QS?+YDG6$rz_!_ZJQ{koJGETP)8ap#Q2&TeCrH*DLR@VWCz5TaD^Dl1j-#uf=LI@YU zTw5`90|6CYB7Nsd(vWt;g04zznN{V6=!eoKEdpl0!0R2+HESAZs>aZBmWb;d1t3>f z>QIWu5g#*?`LpZ!-p1nH(oKhEUS=(NLAjw0^tz!9)#yOd#XL-mvdCKY!xXF^W!+oo z&g9pg)1A#)_h~7I#1vo?sQtTjMB5S=;wNHgu~1Zt9?Ye@nfY z-iB?TDq^#Rg8Z?C_Bh@1DHi`VXDoGOA7V&gD-_BsSY!5auv5e-ZgS|aT31*bnGY6_ zV#2z90`j!WVxn@^z`|q&x?L2K)^AmMZ4(Kw!?n!MsnfRy(NR#Rbn4xCNH>a!{AE(D zDuL*}z=B_n*|agZHBkw!oSXNOk{uB$lVu8y_RqX4vkBHEC`bC|IF3G9SESDFUX{FF1UQSaJjc!sf~)O0w!r? zM$Fu&92FR>HuVCtlKYddBoE0HB<45b3m<-{t)on?`S2Z9J9h4v7mGgYz<1@#CJ;Ru zRF6uh<<2mEeVL{C+VArA=CDcaof1dU>?D0R>6NUA@z{SceZOSm{_L3&V7}uqTEx;^ z2k+}uwWoXX9lZ}r7@e`Y>GY3P+?h<2)Vx{rwLbz=Chxa7mA}3QogGU3zPuo`$yf}> zb)5xz)w8=mB1lgqx3;}Hzkfr6)!`2Luk6fkd>-$9$~SfUmRn#`NP-`W@?{~C^l}dv zOrjeEPd3Tt>ax}28Vf$#M$BH}qGyB7C04`@ODGWz+%I%TguOLVKT*(1IS9yfxgBSC zLMQrLOS4ypw2Pg5U*sKFj%~}=k4SPSC*2%>5L^6Ocuyc9uUyWe<=Y9^AIy{@om0(fv6O9H%(ON%B+* z+d-Q$1z}MNU-_t^YSKC0ZzILhe&>)54e3IoY8G>sH%5vpV@mB8L^)3%)7n9E{Y%xH zD_GeLo?buzXgT^I9P*f3Hw>shKg|a$QZE$y1OX_PsCa>J5Bcw(4lN5fTEb+)dl)I zy+pRp|D51aA6kHt_#A$KNbW6zS*ViXEuJ`K0K6L9M!=Ax^E;75MF2iS{0hO>Da!$R z(N#p$4-ic%@&||qT~(01E~5}Cpc{$&!{us_Bz|TTeREL;@PutknLvU~bW=q5@Y5aa zKkwj|?_fe%lH_IA39VFbfz_sZKPSfR5O;CGwJ)SUTz)V}=5xsS=>=brM6RF_iI#6xdWX!u)Y5 zWdCu2CGID(_&>~l@^85$$iIB?f3leW)F5b#I{GmEH@XF&Vf+~!`ZPlA`1OVt>a63~ z_1PjXe*?;=-cvbJJ$@ot3DYM*U>h38L&1MqnqN)(|It>?qrza-8l9M6WuEH?K=Rps z6R9ysg`Z3-^hy)rdr-zK(=M-AL7DHn@|O7gNE05{6b;!&hN{^^#OK1Av>#VU)|=j} z;G19_0qn3gW^r8b2dE+x*f<8qi`Oh%i3gF#ZKOi~`}uA=KM^OP6X1n|?ixy*T$jv1 zlboR9HVsKBy!1l4U(T3r3a78#Q<|C~m>x-rsEY*$DYMq%=K`zfq2Nb&&)TZn?9Mh# zG46x%TOYNQt49Y+Z{4}G`7*HvC@rz@Q{c3kh9e)MZ+-7tHOco?!?)O0?uzUVK&;$h({juNh4thd?tYbevcO0H7 zM~>~VF8&W)JIkLpvbzEFW%z4j;!I$tj`goSc84uq`o?b1^=G{R{y}BeKRUqx>%#1B zI7@S*UuzqlqYJ;Gk^hyP@uz{uoxa-{-(7Rbsd=#u4C z`0D>39x)LYAaI<$L82s#-68rf9CP714!xVXz)bKAGfLAnxodoO*=NW_HUVq3vm8Bo zB0Yb&b(8Z1Tb$&^0HL_G0rjHEgrM7Ma;UQq1X#a|Ss$o^UZx5$+X)!HzxV#|;WNdv zKm|0bB!D5HZ)ute>n-5B050`;UyV9-hoUE(UFZi@{djlT9!YQbZi}J;NW3i{? zr*7dqj&D*p5)b&W?;G8nMksCz;3tUhHfW(8tVJ;;X4*sxxLv$JG{f@-YzfF`k-2u~U60u$Pav@xK)lfdcx53?L?Dh)i)_s#>Q5pG zy`a5b(f2@V8?8(Ygc|qpLSz>2Op-h5zuOUe#eaaDpn-A*VoyjgyjZ+cdloU~W5PE+ zJD)>vR27)jBDXuTbak=`eDI*6A;u1>n)8L`69yo2f0w1jY}^|F0;3v{!=x#)(!vkW z6Qsy@Co&z7VH9kWJC3>zB`0;k45pAhO|UKSR~6;~wZG*iDeml&p<*RLib>C2=5 zGi1|TZ`uDAW%<9e+W+eN|8o_~)mRKT{=ua0Ka|h@t$`a&{MGWw)TQM9^;`%yg5Q1` z!Kd4m-7%8^YX5z6K~{3#^$bZZ1<|J@klDZJQ`<|r_4OwJ8rZ9y2b%#i2SaKhn82@$QPTHALcIj-y;tc%)JlTRm_j2 zUCDmF<*r*`lB-LXR4J8sr>FmPKu61okvt1Dk7Ix!F}dLJ>=>W$=)-h5pUzvxmz8S9 zTyw~Mm{Xnt^tSk^4n45HvK%yx()z8MSJ9K0_kAg^g-F7%5X`r8Z~(1xQV$8RTHAwv zMy{yRgkbPSGyq!AYKGOaxZLU=`Xq~cS)2EhO#IsbE`ZMP9E&Eg0)WF#U>VV|W6{T_ z_<*q-(8Q)%z%c-2D?3AqkeY?$nVcG=A*{*&kOyBM>NAa z;4_Ba>r|8v1gt1?f$tQ#?*;N=8NcaqNDHGQ`-yE`d_L*IVIXQ|7Vf>{M5}ig#EtqI zaZq&Wn%Y+22Jr|7*+?dM7~**V{xPbXew!5zx15@l3{~_BA8%qVDd2B5n#Ej+p?l_6 z?U&r)I@PF&GOxmic0(CvNj!w87i!@Bl0YlL?+p)Uzad`79awO@r+77%KPnW}1b&0~ z34i=KbiswM-*WIVyK_=)1?}mhsBiq045yz`*V_k2weT{ss@E!gz7!9SH%)&K-D>A0 z&j3CbG;O#(Un+s?`2i{~9hzA|7b4V&4BY;+;cn+Mjq#IbceR|CQ{GtTg3mEcSNfOE zL-`&-&m#Eq;F%p(6q2uPc+U0zr5~81Ub}o4Io4<&{JmPF<=qu}(`0 z-_e+Rx?!aNNJs-{uT9yc<(u8bMP>7Dh5gFuY`q%&16v`AFMhrfbNP-j^#DyaO${?B zV_~{o)wjYoxiqPtm?U@744?`1=E@kz+NoQX?DIka)xZ*QbV)u?YFhfa)MWE>sp(zZ z@aY@Htc0-*!-ZpAtVxWb{*gQ}V@n%N)X3sO2cfsP#h+%qA3yUT!<^>3^7CiyZ~Dip zu^+BAUM+r-6(OSj&uaGHyvJW0ntx9<`{!u=KkLEzYp%g7_tzYR&SOD`|Ha;$$3y+@ z{o^B*Eu@HSQ&B|7z8eyfkjhRfStn#43?o9;389FwW+(f;WKGs=Bl`?xXNHVnmcH+E z?)%*5oX_WT?%(;H`}=tOzTf+I|5129xaN9q*Y$e6p3m3wY1$3p-bO9uZ?$1;P9%)M z^e{oGD!PS0i{Q2I6Rv8OHy@9@F`Q>D32B|MXJilpc|b8c-+1ey_T6G2Yz}wP{8#Zw zSk3Dz-g#Lh6%o}JXHY~Aw+N){%{z1PKULX{9)A1s|9lqab^KSp0zS1~4f}y=%=g^( zO{A=~tNMC^%=%_wKiPvvv$LBULk;Wvwm-y8()TEdNJx(vm_*G;`@^v zr>nk7S6i+Ho&XFZ84CX=&-hJU?f>A7NW4ZqK4+nS?CBZ#}}apB45}FoU!PX&5fp@E=)H*;$yuPnaa{f~BR7GRIlHH`C-tMA@)b_(UZ`Gqp~C51ZZ;1nM_be) zeQKqge9oOQ^mLbe;V{D($3;IS`NUY~hDMJU*$%(g!B0CX!!vg;V9&i=U0h~PWT?Y3 zdgJAEN!N|wq-n0FwV<^G=s@f7)mkx<2|RF;ds1#3+cdV7(*By8Rk+DOs{pw4 zn3;C=a-wQ`V{LI1!VBAC9ZZwNCINq(ij;Z8RI8*+x9X zKzLwoUED(`%7=)~748p7p80*}vYdUkX-bC&^gvTR8Y1l;%DFK3H2We z^9;n(8lm@@W&)k}3|>bE*fCagr5tyO($aZc{WzF*#Izn@zgjMlz_E$2U!d;P`B26B zucq1wXw8}oHeubEil)PuZ+6!@Ro!2Vm`ajZC4usd<TX zO9^fEYCYLD;Vr=mNJg&nh|3n9-@cQEe9KjrK2t1zLjV)PkNZ|TBsFyppC;fOXjtl)9ks{&<*#_ef9E#ee>MOV{$tO#&_ANV z-Cq^Y1Gcx<9e=@$|95=-PjKgRcfUM{3AwHXS0!Bt$2BjU3*sL#7FW5tq`V)+LFKVb zCfyYKsk}@}8MARW?FOF{vs`r8n#h91Ma#7L8VpRWZDg+Qma(+OFsiK+BKYS+MY@hO})&z%{?JR)Dnl zh3J6Ehz4Y#VZ~jI{vw<{;KJ@F`a}w`(WHkC#Ouu9OZUk?+f^ixy)15YsjJ^aH=B{nzI{Elql|hR_8E;&v zcNtYMruF1KRneo{Yf<;_N+N&sI(tISHal>3qQa@d_&Y;vUV&2BeLXMn_zEqur6bvy zl_)g);tDa4#7~y1Y|>Auo1`;+UpH^-Enp?S^E}}KS@bjFnH&m`9eP1!TC2Q1u6AKV z((<971M|6ewAauc0grFI$XGU=A;{fzsx=`bR|uY)#FrRpadl)1cNFjigBW)!?lr#& z;A5Xm5dv=g%kjVAyC(fL2KqA-_m603|MIcFVY_;l{!O)Y?L{e&!qVa`sZ6FopO-$y zjdWHF@w4$L3bI&(*!7wR=n555XaFcHFUVw_F#IhZ`Ougogy<_F1B-2iv^D_pazQ5? z|K(X`M+qMbZvyo21KN&$I0r(g8ia-YKE&&SKL#$Wk#G^fvjlSD|XE(AI{4arQCP06X>RJ89aapFfn`3+~oNGX<7bF z_WXY~K>G@4(FgoeHOgbuu26@lQj~O+$AxfBBIl#jwy^#4%)_s!+^$mzZ6*z4Bpq?- zw@3WJUbI_`A~`~wURTN+_UJyfAv0KsD|y52m?na?_Q0EQ8NKCQw@pDh=(Q2lyXUhG z9|F+XrV#uLK^mviQ|bZcu zogSX0w15mj-#=z8+CBYrQmQBLeAF9oG~>P)c*(SLm`z)7v^fA1tIMkCR5GLY3zXSV z)~(QQ?_eJ{6b1Cmz1-^r|CDk96TzInKpZ+`p$wo)D=YEVzl-A+XYo{e=B_h>^Zr#3I~=Y|N?d`qj(^mjSb-hU){dEiHf^VhesdNv)Cp znUm{RtBIAaos(-7=f%m^nbSRr5^DL!XoxB0@_58nDGu`K*1|I)GG}2x!7}ii_P8}N zbC)(aDEFR6)Lh&EVR{fR6j_Ztx$5gb)KWOTdZeGdlR+RDu^Sq^G_#83!4jHdlb@9mX4x}6fs6`L zC04pcDdjaWl){xNC}|p1uFaOTq){b%gZgFe?-^Ivh~8rN-8_e}lg$Xy^I5%YvI=zP z7@XT-*V~47_U7Z3Ndc;Y%uiQi);X#7-U1aSBuxym6U+?f!>OfUBp96Bow)ha!70qm znhA)=xyGuBzH5GhXk93jIkrS@a-`SIjq#eh=9#dYcdv7(FM>{jIK=!C3q|5~mChLb zX}KPF{6Ac+-|@QsFX3vPghV#;Wx__g%!v`2W8NE*31a)ox3&dV%D*f|?dn4bJVZJy z!!&dujGsqJvhTWuAMI{TLAhskuo}P@itwinp?=;5l-$0g%uwVu1|B2_l3Yd_tAc(%@n`?l+ zTRu>{Z3S{AfO6j1f^aZAx~wEY*_ZzlfjnC{lQkXSd(5|0h&?TQ%*Yf4?dKx~ zxWuHy<9zSP(67Inpa~uM?0u@feJ8cEv6q@MXmfvr4zLdX?R5YME`NVPZEhN=gO!1J z_3tkOVAc6|7c`I!Aja(guA+Z;5s<&To=HCe#Bi_ecXtW$o9k)%72xUpGxM{C^bAk; zV#?cf1uF4Z8VJbK29U4K%DG*x`#d4q>SkvJ9VR|1ZalmkNYbQ^g3q-rqu`QXgSZ_N z0taoe%rkmE>wYY?uMx!xn=`#FR<>8+!uoj)YqhA4tebnb7ZN1GxEUUQ-aDg9w>+ff z?-m-w2v;NeUDCxMkHZie=_k(xLhgj0N9@#C_u^%D)5zSEOt@iWvp(sLvrpWd3tSRI z3Xm~;d}R|=c-IMnc^7oimy$=lvNHYSB1IHY{9sc@#bDtd5@~E`iJb&7p#PYUeT`p` zk{NVy!{0nraWHTBW8e7%IhNRv7UN)nS(6pAv{0ACks>0pjmQCP`e$_}ihNj|&Teyy zVs?o3Hm_`xC?C|WoZM=}UV>b^pF`9mX%X+7hf9)l4U5=Dhk?fK5Da)~fa@cfQzp}LM{a504={2Nj3ZgF0{?Ih>FZEH$=+*I9JEZ>sE ztq_8%lgK<+VF`&X;MLjCZ8$l9G!;jmK-U?=FC7zrR>=`2->M#)+6QX1xpx0Uto^$M z`~Q-~`nO>B-~RkB5Nj_gEaOC%z>(nU{8)se+u8M?V?q&(pzhy zFvBuO`}d4Ybks0^*4n#kZDVP6sb#RsWVK7!$x2udN2j`m=DcgsVV~-Vd3oxc;G?3G z((h5upyVj$Tb063AARxXP-I^R0qv`jLH$%VVx8Z~Cbk9~B5w0lq9HyaVqx=uFr?&a zCopJ3u&jMLstukqX z9ErSq>?n-jq3kM2)ui{1-C)+f z-G5N~jQ{4%F4I&83NrtdV$SdW_Wws^sA6ZOqz_12*)7VroXPXDHCJ&#ww~zq0Rzy{ z$ca?Zl{TKB1;jBGR@gIp{C;vxq3hJOW(d!^RK-vJTaOU5RxeR?%c)~ZFrCGSu5z|t zpkc)yV`n-z4+e&Kz&LJavCWr(4qEV#wwktH%ay^D%{~=({s<#{STifJ;6+#f^ijB`7@MKtvQa4k5?tkY(w%m;ihP zlaIi716YfI(PlomL@~v4BT94`L3C!F=A_1U8URG-bud*shsqsF0J6-S=VNWS#)w~_ z_yWRxstIp@!+tl%F68%v{mqOP{%u0z0at&4L|N=+kFZgP;NOo26vZ+rS%}qHnH{ot z7N8O3tor5`C@Df}JJ4MVj{JiG5uJB~M%sS5a{E#_fkw++S^h7GFP-mv)#|fbb$&k_ z1XBQUv^kPz2BIXU0>mLV(hl)(B*hShwE_Lt@>pB z3&i4c>6i!gx~>skfseId=uXrO&thG-LmZfetUa>Q1b;HLzHBj^NWOS^T93;>{|UbVAbExVQ6c=d$KVyeK*z^{#t8F|0G_)I{FJw6I=;W#`@6{VKQ;|O2lF1> z*(MlFOV-5_vq>h!r*5EcS5KSG+T4cMpB5%9C?C^#{-Nbbt0JYSETU_8x@obuQ-Cxl zeKUNim%HSeRO$Cb<_yGEW(t1Bm}^h>qP(4_4q&~YDDkO0j6g^&*(xA*nkUsD&p3P3 z9*Wm6wE8i)Vd)Lp172j5Kqm{bTZr9S=DYzY8nQFKHf0e^6t+(*+c*N0Ca$1t-BfnJ z^g6$A^q24v%A^%cJ%R#qXIpl-rKRu&wOCAxiQUAK=7q{}v1x|O8VB#BJ8kRf1V&q3 zhoTo+gA8FA# z>p76x-c;wWFnw(<{r36$?d2t|ANSl{4wD1#=a2z4&j!4wDG}=A_ndldw3IsG?K(MB zQ4;XvStR|mfe+lI&h&tH$YszkI+Vi4 zaNC$`h{b%H)0f_FCXjS?x6sWoVkI`TR9f~3>M>_l(E7R5d@_4g*Y32HrZ4@NP)r|F zgmnj3r?}2}x=_PQABuC=t#nT;%fdk^{MJNddK=h?_%0qFsJz~glfH~}Qc zukXLm(Q>K7kSJW&CF)R78{oKg9-Le!i{aX%n`u-=+eN6YgR4v4SZ5zY zq>nNWv_nt)x`Dc)T2gF!i@DB-Cm`=+O_G9xak(DOV)?6Q%(aiSlDVNX1+#P~?(axa z7W-6AEyi2^0-Xe;&ylonAIr~Vp`z-8NlkgBz*(^q1A4FJ*&lydeo_DQ>5?M0^$u}D zJ_<}AjBa_QNw8WupLR0(5axy*mL{npyt*nyq@Eb?3J*X0HiI-GTH+wsSk#uD9Uk2M zk*K3%@-Z2Gw*uF%667y)r=m&w`&-m?o#bI0>+Np_rj-_FZ5t*LE{0*;%n`4)%W7uA z09z-0r5~&nw=O66NP#K}X<_{L6Xs5yG!#&4!hp|Oz{WZnbzWUN^Yr6{*rk>_PKHWD z5v@~UDHS3by)AHM5+{)pF!0{Gi1##}sO}Hc-qUM(Yo7Lo1*j_VDxXkspgn&edyv=C z>{E-^I0a>w6uudIlHXn@>7 z-IEQ>dU!lVv`T3NQ1uu1XCdH!>i7QIT=)NpYV|)gM}YJAufy3QqG<)sw!YWbCtV1+ zEFoBV-B^oOFu@-re9U66Ed~%m=Is0h`bN^=u8Sk6HFL--(5+8?oYm+L<%LhbgqiIb(Taf=C(!$Z2a zj4cPw9hL4)uYIi8_-@hhKzR6(y6R+W+>oD2s_1AqKiiroo!k%<=syP*n(faPFq=+p zoruzkQ2sPKUY?ezf)mJwzOv*f(HxgQaZCTr2KNsbZhuC*PSKuXi(hrj4pt#yIX4e* zXl%8siuqzJk9krL^LTQHay)etzZH)-$??O&ce8WJFw|v*3`Fo{#3z$#rXAEPZd)!F z8-~VB53}jGzK|ETOuvqs+{Vu?*plqNO(@*uPfYNw4=tEj`~HVt%g|o()9?)$!*!|r5arO%D%^X0 zi+~*Zg$%(|2KLuV%raKpf3r_5?BmqYlX`Bb_A?+LxbWwvdl22W)Q#9ELMV-v%>Bg(_(R=UnS$&n z&41?kS2F`1IcESAo$?J3ZDYn`1e6j@~ZASFPZLC5wr>Q?P)p zy>I5dmAl4Yx*|aDpJps4KHLYW1p>K1>}o>;R5Yhvq5^TuOT>L@7wP(BYrn))&@|Vll}$C9cXjD$74JP`03(s#{GI7I zRs=BHs@2H=}lsE&aBKa);jzFosp zq;?di+C>|{BI=%V!n(X1)Um>H_L@<<@iD>Tu%Z1PZ3F3PWC)^?H6)0Rn7Y^H51)>! zD;wZe*r{#iWC-EAJPcCr0z z>Gxsm-LawK^t_}OhrtS;H$jT+r4r$7mGZ&VGvph!c;E1DO5R*+I^}(BxrmGbW0;f8 z_t{Z7DFgMsRXUC>FQY!5`Aqoyj%_2n)20&kGSXHt&2D%O_5ATyE`B;JY_EB6caRn^=3Qc&2HQkzo@3Zh0K4 z$|e-z$YTj6-db#^9p8l2On4n0_hrjclzbmj0vOIqv0jLI|CK02y1r_2yk?`tW2}q+ z5_&%%&E0X&8T~znjrMh`S-ZkkeXTOcuiQ2}<(gZ7Gp9-uMq(3>SqNsZ;ntBGZ$F+A0d z;F+;@J%0557Tu!g79(-sP8|eycR1qQTK_ar-Z<6J&hLpahz; z4*d6b#iu{7b^q$I|JI%B9HtK@uKiS%+A$gze^U9eI&!7HA-zL~-I+bl1hUq5jr1_Kw5T$FJt6j$=GvnM^?Pd7V)tIi<|0 z5o>T{pW{$?xp~k{+I&{9cXgkpXG8F=4#qh&saSdRN-G~ z2^f4#T=>Kd?m{xI1CCf%(pXBf*V(wr#2n6_Qaac2B&@1ooJhjMxaYx&mlP7fXl@aNEu zfuW|c!u3tkdCKd+u6nEEKcNI%eb ziWYKR2pwme&dm-MT{{XaOSE;fvvuJ}Pj7~@%i%8ET2jBx@!&a#FANS{vXsW{P#Ld@ zr^8zySfP*(!PxeIjjKU+{JEOn0ie&DkOVIWHp=6;Qc7KC8#DRA$d<*1H4$3eTly9Y zZcgoP^YuGTkbJ#B7%+MBy-)Vi)3OqO7mhRxifxU09k-oQwkChoFK^xC>)iCLwnwG* z`gy};Ehp1v4`6A$bF6viA+d=Pf$#5VV?k7^07ZZ~xfLVo7|CEk;0x=npC!xY=0?4( zDzpD(cqqvb)@FsX*|Ns<>~$(>Op>)cvmTl`5#Dw9jmf+!`vuy(;#*glI^^-)lv}l_ z8^RNg!rR^?24Ds`sb^w676Wrk9Jk4@ssjuUH$ z%@vOahL7#eHEokB<_ z@yTxz*GgBjk+d+ZI~t{#4+#A7*7^nk9liYMvSy6#Jw z3QCbUWe}4=(l9@y7*voYTj1)HR#JC?=m_Q zJ={hkhQP=X=kT6b!J^AyZs?AqX9Lo&T(qw+A{y<5>Tsb6k~SK~-E{>aQNqPQf%iLA z7=Cjs2ym>x(?##w8`+|~D%Wq4wtF8L&V_WQ%tzz1NpyL29C6MF~zl-OVc=vdb za1JM;9oQ)O{&o!^7T!lkt&D%^HTo}q{ckQ3iU)6TVq-N4rNlP*PhggL zT_@Z|ryB=4ZXarYVZJZ@>zez={hl>d$BOY*uo!52H15&l~sUPtIsL4L_qJ8eoXYT;ra(Li#?Et--K6Yoo zPtlb|HMp!qRl-TMCI0>QFF=NRXLR$YwH(@TYzx?(Ya$nRcqI9hL3d5LJlN$SO}SNI(S||%Wwx!oIIf;{9P14QK4|QS8utSE2BoA`J}{zZ z{sPW-YC%te5+P20LM@FDxww0IH#bAQ zd$u*>lvv{x`g z;nT@_@YJwjZJkYYoqJcYr{~h(1HCJXkE%Wzq0|}3gYFAVKNonM!PT6-Ow4pW{x%V^ zoh7~Q*OWel625Jh5ZGvYdc3-z;XLXJvkZ{y?1y=FKLoy65?+pwHE?NNYS(c4aZLqQ zPC+htlSsFG$?;9skoZbVcu120*X!pXAr)R2coSe-&}*SL$4YfdU_9n^Nk3I@Tl;oZ zI$|%6Z>4`Qx2Gyaw?(|fqSLt>xK#S z&8WPGDeq60w}5MNZ98K?vR8pw-$h$88G zF0}=ltx!e&{0mpTNWn*!nxR&dA{EIG#H#xv)#f(X?FucU9y(Z4fh3*$U>w2+WwAboI z)Kn9tviaii6m;!0U9_wr{nWsS6|Bi{l-otp+VYf%y2oebkG%f8F^7=O1~*wZK!PtU zMI*@-KybWkYZK0WhUdbIk156Mr5?0}D)qixq`W0d_zyd)fLS91Ec0HMCtoy^RB@G# zU4Tnvco_JVWX~mN%9CC2DIL_~M0*v4pZDd1@bhBd^-5)rf4KF-?t%C5807o2L+%;U z-Op2OYK2~Gd5ft~BD(YeJ>Bk)(LMTkD*lph697X~sfOdOw7in%LFl1G&4O+KKIjlK z^5fO@Pm3Oeh-EE_hXx4Kuq^5fX+MX4OQxyb}LU~>bj^h zm?ce7$o!&XIuj+Y)#^ILYxjhjlYA!fd6^#8w~NAx0f)_4qf@_^PQ$hy)mFY7boW$= zU|=_jGe@fU9}6(maVavNx3-$Gocp9YSjeOz`T0Cn@aaj=@>75GS$r`jE)j3rDaWmr zRn(VlGqcXN;KUu)RFFJRkPVEBnpI>RPEyr?SbUwIrjNiZ^t^5ME7Q-OBSN|}&P>>b zZi(`@?Ap>D|Ze&T!B&PM$HA1+nU=DZWji#;( zT@^TIRsZQw+uv3ZTLEXnD`QzhgT$OrOhFeBc2N20(wKurNTk<WjliO3@BOayT~{FDhh&>q<-tqJ&4yjC~QOR`(`wZCEhBXNt-?Bh(x1Fv3%LA!yU zw|vy1$JJC|YW)}{@}2M_YrPKWncEIkF*0wR&Z)3JsTUcLV&M9;(F2RMCG_JSA_MJr za$rSnr{=80N`l0qbr72=@voS^4;k*Tr+11#UM|Y|9dz%!tV03y4qHg%v@xNMr9~~P zK(1)daw_^(e12;RUMYj*4=EQpluKVGoFk3T#u1NT98{=XW7W3l8%yh)n zAeyWV#2K}#umFrsn2ino#8|JR^^*`LuRU?`JlmI1T)RvSn-CS#fi`Zcqy=0XAMr`~ ze6uXpjeh+#Z0B};+;CT;{UF3OUp1XC!rOY9mfZ?Pu=H9mgZ_ zDB-e@8$dr0i|7((hxOSdhfa!NN9sC9$eQLSUoOg4ddD6wK4YkKm`_{hhI7g$pf78U z=Oiq27)G)b;&nbA_vB!?WJ&<^zW=jDqug_(7LliC?qZhav-+g-e zRpuiQ9f)Si6cyPWXq^Y0olb{1!WRGwa}6SZ_^PVh2a1PV!+)Y-kyYsyk>sCvL+aT& z2w6Yy>I#rUsq#>dWc+p!!$7l3YpnY(kQ3xP8L`L{XftqFPUR29=}`S)`9Q!lgQ8wR z68%{}fe+WPD6)_%f|z=Ae>joqFHiaDxVH>c17ZOJBR8ZPwM&KIXxFI@g9zMDa0p~? ztQ$f!m<42c0qfY2AbJ4%H)ID$fyw~h+=U*^fTQ-lw(aETpoGXj-AJonDL|z(1^){a z3^v_U12|q0nPfI#^b7!7@90v(y^Tga$N;qL{6!J$gEiv(@eMhK4AX7m@!mo|@s| z>cYX-9Fp!YkOWX82dS0+fP+bj)G?b%#AY34h9cD7h=_I>%2LeOyW2A2YTs19ys4e| z=rFTvA>hX^o{W*;H_qwP004uV;Q97n9= zM%FHf56?n3C#G{a+n z+|cd)XjR5}Rk`7A2nvQRJ$do+Pgsp+p9UaPgSg4AcyOm7S2r$YDWNx) z*a(?hoxj>MsND5Bj^pbvgPLRmi&P*Hq90LG+9wd=zK+ax%#Qkb>PB-!x6P^j7E3|~ zP9n}_bas_nE!NN0Yf$|1nJ#QYfEEAtQk5iQk(Akc5Zyhe;F>*yeAk6q)2$OOi>mSm0Jx$Cp4TaW6&u~V zqx2y+;=0qS&Z&Y5@U7a6T!5Z*s_Vg;cS&F{E1Ny`rr9a^orZzd6!?Hjk&gT=-+0?kDvg% zB$#Xh6tyRQf%?W^mej6(8;W>qHj*f2Lr#|uYumh214QN<>%z4A_X!4nqL&cme}SIZ z11n$db#xIr>Gzta|E0NS0Tr6xNtuc;`%={p5c>n;M+kE0R@+9xFVIE9%*JHl?B+TG zirCn8td)j{?kV~rAv33@`Lr~&agBOQo7E=$P5dhJMU0iKA3MWBt3k?6lEIhnCQ(lklFq1P9nxOe4R zK+?6dsvV5`Q3t}E2@JdmzK;RL{pCMZ^LZyjaVzRLbc~7)HdbjcbCRqHY!6OW=a=2( z-!$2FoD+FVM^qp7UvtcU&ASj1#J37zS(;cHc{V-h?Q$c~WK(8W?}ydEOs;!N+=wnI z2jp|?BUJ#-fOG5Czdd#zEwh$lR{g5;WMTeXzk^qNVtU5YJGo_um&+coa`o(_q!K#t zBI+fT8z^I?QRuX*>*G4eP1ZcceP#A<8=qayi~ly<-@X#8(5;IB(6|h4TwABnqL|#L zMO@|d@Y8F9U2*ewx(A|vNOLZKCU1Wp=4vmijJg(#k#z$aRVG{6g*u~4C1*N}KJhjC zpd1t7&T)M7uWKn^+E06Ch)A4+pTgzGHVXh_v#gKITbZv)b*iDB*CmJzNZrgcZXWSH zTm9M-t%A-p0wmOqiEQ0fRVWw4(`^>xX1X%2Hlw#?pwR&{Yj@E2sdf2tlLtcvx~amm z0ttSKV0n0)C|dZ#XuyjB7#$`*yp?n9fS|@Tx-|N%yqy0C>i8}u2wGIkaqpH^NBNg} zT0YU?9U!^{&e9^e2%SLm(Qv!`sWI(#92Mhrm?xy1F$GEX6N;@@{e*4FhywKFc*|pe zl*OTU{se*!%KfAdQ)_mAQ*Kwav8e0I*P^#yl0Zc`)@Q~j9vGb_F>;YK%(Ov>^~<)- z){$)-z@K|yXU8)V&Lz^q>b&Zh|CCO8i8VaqJpSNiZ+NptZfB9>m#9KFIZytH>5@2> z?~jzO0~Rdbn`uZ2L{Us@8%teYXOY$V#~YKl+4a<#!G)yrJ;vvBM3lUH0yoB1^S%Jy zlvoI3E2r)UGK?6DKPGmbteU^U#Gmm||Dx2wAk)VL+=u6{W>0LZi@j(~C4ju3M!4M5i3Pq_I2k-Cm@b<)re_`!RgoCVEkVAC|CU(>7LqZ9h7**w;6A z-SnJYipHgn6w($#)!4?>xcYF^Ksj9%(M#9d99@6RFWwUD zOF_KVea@g<88C^8R^wlND2G!;zm$!ZYa9_TO&+o|PQBHCGZWM_5y}$!+S8hkeKUBO zY>O-J00Jy$n5E5V;@z(TW|ImJbxe`=a#FW)l|J^)10KFh)i&RBFlCYMd2?m4o&kkU zb%dj&+B`(Ttd&*nh}>!$Y{;M+vOB!Zdn)u{la}~{19hOX|8Lea0wSMF%8FJyPn7`b zI2wH2V(c`Sg^HvV@>QT{f4cS9&<|B+Gh%Uv@G|C^1Ki5;Lx9%RMTe|mPaAvFj<%Za z@dZZj@^BPyi4l4DIMjV+ood#XsM~yrDLC-CxjS_%bh`CNU47L0ydoyvK3t0>;Ub_> z^;%<_zMQ36@ygc;PA}cRN*evGtg?61><6dc^8esTq-Th}r<4@PW`GUl40QGAu#+(L z3zX1D5=TrC+KvR#U-uX7A-sVsvuXpvMHb6QmWYBr9;sW&qm`BKPv1C0uT}%x&oLaT z7tEtZPqghxP3d)eWx_fA{+>)N0lnChXZLhTK$YDT#prD0kyxp{n`qEyMY6j3Y|%cU znYK5B33m`C7R2T_FG+lXEiSz=(?y@^{WCE+(YR63o#X3WkNSs7G_@@}y*4}m32eb6 z(E!g@xi!87Vr8bsLL=`r;+J$&c& z?qZE&!Vfn=29D#8UYSUu6_4|KTBEk?@bz#eqH+oJTvfkIxs^I@O9n2b=NhC=9%;Ti0rF79N9+0$7h)|xMA1oy?& z^f>s2=r{6Xyk_S-?L@PD?mF6_vHVq8Yxf6p%busnexN>O9WB~~-e9Fv)wJ*Ttlfxo za;4HG+08&$3_J;RaB!MI4gO@7=z3%PrQutE7{vC`BB-lUgG=0W`Ee81=iX8aq9;DF z0|Fq!P15O1bhnp zsj7xr<4^WV%d(G)mfcc0O@ASTUL}M+U&J${IqAi1iJOC|heeP>-JW;9Ku@cFfjm6X zqcwXE5W9)7($qc&>hen204qKQN~S}5_mEA95P(M-JI_`Rr~!E?yZGKpLy14q?3%(X zxrgiUy}Kp*&B*4q-%95JioR7fsK1f}XFN=n0Ebr~ibnPWUYu;p@9phpC(bCn7`Z)AZw^3ODTP?Lf~*7*WR<5t2CNS^Cp`-SHVyh#Ci1D2t_=23 zSv@{0D7VYoRDF-#$KMCN;y=qZzw{{4cJB^ohT!gizfmySYJ8JY55R|9nnQV1}< zD(c+^$}Wux15#{w>Izw73CT<};7qZ^vn1tehhB3le-iQ2o}&}dSq`9}3s~Iz$4*iy zR+J1T$6B+7w=Jfls_MeZ&#W(ClETRAp6O4vSBy0G^&+hiw{5Jbs{s({ftt1%+jc2u z)*RwFW*cn&t4a>91$wiVk-kN@w=cL)?F&**#wrs6h=U|UgG9LHcyGsW#4}$+2P-D6 zP1>mSVV0W&sH1Szm`*@|KFFEG*%@?pgv|G$H-D8|4wY@DeLcC+yU4ODXG8>a0&1#) zKbw)pt^%oUIxgJam^t*fQ04*H!PI124@&N{!Ha%V^hOrHFZDzJ zoc5*Zb=zk}Na5qIQ}Ql!G4BZ+XMAdck2<$r2flMvAyS9hrTUv)F6`2)O6o!VPV6@h zyU6&St`il!bKYa`&AI#ONo|U{FjM=7(T{H>{v$ze}#0n;NWg<3RCmB z*wc?a+!Itc0v1buW^fS4%K8@EOM0<)zi(13MhD-tobrd_gEEJona-b5WXz3tb}1W%oL5O6XT z!uPl9U?aLzW#Q)7a&~L$&kn8euRew=zq=I21$w4a-DAcU!emQy&TofG=2Hj3{H3SH ztm$P9*>2c3zV5sB;<4TArl0)<4?_#~`Bc|)YHj9hIeWBM?>IbHvE8_FtMY=^qjhb~ ztDj|o<(2+{#HMjp*$5|PSLfnA0`q<}8X)1A+BQ4+FD)=Chm?H@bnV98Bv_FlQv z!nA}AP99!LBhK_tk2ew%ao@e(RE8+aa!Ocr&MKI6KS|zodVkZ+M!qYkIA*T7`w;;deMB>;zSB;TLVjtdr2E&XD&rcn2 zJ_EJb`VjN+aNa!T&Mm1Gx{J>hmJ~Ov$(}k=$#Giop0QsvICjxneT!fgXfY1`U|c2u z+Hzy5+dn8#tgA>*H8-y3mLbHuX<~RNTY4nP`Wc zm$G*p3Nn3i+Cd<8YM)M8f0i|OYhP5aIPyaH`uTk1iKRI8frho$tQ=2YXC^$SvpAF| zgb+_)SYOyMy21l+8+Z8G@rExaTr*p%l0vOBpYw~pUl}y%6H+yFw=r5h__{I8oV3*4 z$TRaaDf^vLapq-n!-`KQ3w7rQw5FDIU#mPTi^#e9xm*G8gBN}Xzf9CDhgIo@Pg4&J zt`~Ja*NsVecRGC9v)u#Q2l(b@FD$Y~1YIXe+_Sfops_YAUR&+D7^vjsl*RO(g-gmI z>yf`#8Myl)eLXIMeD=VmWjsN$jF1evf9y36EBDW!dE5v6&xHai zNMiETmTW+P`~*xBsIMI80F@I-pqx@=Ec6i>)LIX<}$^u&Dv{*a#l(r!m)T5 zNd==(W@C$EvMpDBT*q0`v9&Z8Se!q(?RQmJwe_xsVmWjYvdk6;m~{vOL+iseNXoOl zxP@CpMLf5YK-E|8Xy^h1v`YDUoGj;=eyFs^b4~3p29hzSVn*MjCTz+C_;%S$-`)VW z{A%A&HEt&pAK%)eNaKm^%TEMroH*v$)mg6@GqJ_d+@EhDv~rF;X`Ue4;i7vj-s|({ zES2-mZHOeY&Rc%FTs`uwgQt-`*nR85L3yKasO!N=AgTHfBU;)-V%?%1A_~kzHpPgR zW{Wx8$rw;Rb?Ex}24unp%IeJlC9vTsyj||1Kl6$C3Zeq;qPooHbIVOY_YJGdyq(`v{c5g@{!g`RMgB^G z-b?;5F#2bSsq)LjcO*0NCE|@?&M0_5uK7a=@T0mnbJmIPPhWT?^Af3Evz!91Y3nwO zD*g^XN8~J-o4LN!^-wiQCYs~Rck9QqPm_Y4s&G7!ZeItf)s(9dQ6vqx98m(g0xT!f z@)41Sl49FsZ)R^d2uOA`#2?aa7p^@#&$?NM&&Ot$v~j8_+nAewh6G~TB78axYJ1P$ z_gWhCR#;iSC-vOl_AB)~wFg3{c6N?ui|3PxyhW8le2ulOH)Td{uBo{%`m~>_$Pf1oY z0`K38p{7U&ku-_E4v{bYhQ0<~-{Yu1FHGm2Ai6#+OpA4!)@jGX0i~18c+6*6Yg5a( z&K?Y;TDG{hvFZKWZ_&D<9$g904W<7q*UZ9#4>|V@kA9~70;$X5*~Tmh&8CtP*5@R|?CajxUp2~k&LWljh#KBt^K#*o z3^LS)^=CQVs>JgngVeH*JPovc;U2x$K*QId;pwIZX;E`EM`>SFM)flUNT%lL^464e zms7iQJP0214p(s=uY|3-*KW3RaZ-kb1woe`xr5q(hyg@2k!3R5?oJ5K;kvdZRtE_UV+!OcBDxiLHtcu2O=(v`|U9RhQk5@NlSJ_x^^lFKL#0wps-=kAd#~%cwmAWsX z7B$WK{el47;DIA@(zBztj3;;0-_f6LD|?eKGV+(#4U(28WeVt_L_H*?CLhzHC`n0W z2nhM@9^^UYTDCcUB?h5m|H&XwfqeEC(;bW`c*>A3C0DEi(-u=jB6!{pk7V%R+4T>RpaW+-v9(9IG_V&9^KUU~14 z99c!|(>^i3Noaj_HWcyc)w)eqGEjG<6We&@`6(X8fe#_xRZl555-GOwaC)RXTldoF zuxZ^HBi&le`YhPsAjKUkeHI?AoyAcdX)u90{hc#ljhp@vpEeiFeMRIcExhbB(!PaS zhuZZ@DX~F zAg9^L+qsd?`L6fm>l4&4g3$y^0h+Z>*zuE_qaVloX0^M4I#w0oT|{Z*PB!-bH7_}{ z{g^vak{lgzIoPCG`jcN;xap0F%`%r}9!E^oDfcj?=D(!UCnBF4a^JltKtgawFN>Q&ZP zE&oeFV!RHFJJZq{8VInmhqK^7QSXn;k!lA;I^Re!i10k&_Ct~-xBuI>B^{yoZ}PC->If2U3DQx*CIT2&Aw;whP-xu+3v4>;Tq>K#~jJ zHUooiMUAQV&ellevXXA!rm+9|3_!1Q-~2sOm-wzOfG;q{&-WOCiOV6=dVaq_Y?hNZ zJl%Pas=`e6=FJbK82Ifn_KvS8c&P_s3xLKU+Z^%40C?QD%`1jTuC8uV{HFZJw; zCSP)EUA|R+r<{gk&AYc7mPho35k%KQPO@VHHTgJ)Z(pbx+M~bpnE(Bb{?)uR5k$^I z_qLUc?Lwe+;HBx0&u(i0WNftBIY7{1x&3SzZ!k7fYYM@qSXnD9r!Uv6Xzj#UTo*~5 z9o=v3RcWaJwz1qXO>UC$VA-~m^u;=YZH0rvVqCm@PEz-ar;1T3B2Q&rtwIgWWELP% z$KVdDP+MkN26xPu0#CL6>$C@A+e=UCt#_z28Xx*t6tk>6BdRSZbu&v~&`>5%v@?bb zX-f+jw&*H$IT(>bc~i~STy6e^rKw=TSUeSC+ShmpsZ?E?uY^X2@vH}bo?dIvXlk@t z=QS6xnr)xX&A8U4N7cxGPHwpAM;szX3O*3!H=kUwX6cyLSO(``ReaSs5&ZuB`#zib zc7D$rJQtPUz7i_urlUDCLU76?BpVU@0loUOeW0*Cb$TMbQg33@Il}5X)mWG4hA<$E z35}SCQ_{)O@`-d^XXiz8^4)BCENuLau#kpVhXq%EN|jK7CElG4671LFFACpHYQ#HZ z=erqg)HFEh4WAGtjW+y+aiKP4`=YFSH)AVs!lo1w5eG0PeHdQOuc7&l+SM7S8rvT zSXT2|bg0{cxEMuo&2u%gu4kPc+Pb2ou&Pe4EFLh5JJo+Qff3=)9cn zOqo@NG|6%rvp^=httR3Q-7YVgmTv{ML_7(TH1I9#!1c9I_kHSfAxs0p8HLjuz6`M- zF1ZSu%s@XtRrdm)ywTz|UulDCPv)MBFC1G8@|1U;`ex8^_mg6pT!bubz%g#TX)aF1 z%vewlHxMBQo4ZIf^8IM!T;Qt7(7Zb}?tYR*uL6*SX*+OU>t>SwH9X_WTBc!7osbC@P{)?dAyrVPSA(TtkZ@U_pNW z#_{%?hpPjDd|-PG7-%t70+g%hv0h`iY!621dKo50VjhsEM6>mn6D!^n$+{p}exQULPpox4@ zQR^MIRRx}YwLlk>u9(fGE^+q9{2uwnRWfp=X27t;$oM+nJQLC}`xy`-au_{BGvuz; z-r9zKrJItPS%|O<(%0>IHP)DIxqoLFEk}@R7nG+VhI9psqN!MX9b=9M#!=unS-fmOEDeItsSR(<`y^R9zP3c!F$l`eO!CPru>GdVL@ec^zeDJA5eTd?>3n~xto=AnU8cC2*uA% zAb%}}T|;87&tt410KLw^vqY+AsOrD_UY$Rb;$mc-Ufns;uz7fi&g#h+ZrVD^pW2lz|Y;Nu0E$zufdqAd&>7;9apwK$P1fVE#2B1Eifnc5~%pa?O0UPi=C89CF?ptR4 zf6WI*^ykfJl_1*;A9LQ995{esSlXj_ij=x`Gpw z4+zw(tF|q8#se7Mb7LiM%?s`-jjFbJI1M> z0WXP{#J2=k+dM5vNwf*)37D{WhhHxkgvHUut2~n|{e7MCYoOXR6RWf7TcYb&JxJk#2XHRgHPoTe{s)%= zn8*_5BkyW%?}~)J(n0{eCx*aH;4c6IWkBN3{00H!#-V?+HWqT`ePo8lK;>l=fJ?wg zk;{=Q>lkSY6%-IuPC@?bg2r>hwtm5t$=7DxNPK$OPsr{*9!L{FK%0rU3I3NA!S({X z^=4)6d_ z4`6XQ)cpRjSq_M18V`T~G8^@*g?Rc;v(M$;);A?6t4+-1jOmhr`8^< zdZ#SA&C#_O7@Htt6KJm12?RULqIXqoY!gJ!Hj*&-VcAC*#BdRYA67hr9+NAu?rh*ka(x&mZ6vzzb0Dg z>;R;=rJ-$5&UMq=HYDD7`)s$*1p+Eu67_nI*&F8d$bRd#O96#O0-;uY`!0oj_7~tJ z`*H7|b~jMIasAVM<-(P>k{@2WvIr;YR~%4_HoFl+u-oxK-Ai7{RsmlSSCxGVerwZ5 z>*H#AhU)y+LJ{}!uj5w zdxG;peC-$jPZ{7?Z9awfQ}UT|shO};wH+u5@~r;!uAV%Vb)6yW1`d#xv_h*ADiL5p zZH&4@fn;(Ff8sEwk?P_Tta}EhF`Xo%XTZ`^K`gK}VkA17jz4~0#%@%d zwwhx)WqmPCob?Iz3m~l90>=VsZgmK@FGQ(4(!4T>Ao)Gt;Hx}MhHCY0R zA+74DE!cXzP1>ug(MxP!v)_=f|KQ`9w)Mq3L1P7hDDPQd7Z_iSR>MJ9-Po=k*Sw`o z;eRxM(2B6qXWzPts)mFlzm|Jp)Mu4NMNBOt#vjBIQx~QuaiNUYt45Y9Vjr3KyN)Hi zweJfZO4!Tq520H3GwW+%1*TVJ;`&?Su|w*{IYc?5A^(%vEB@LeT1%tbg5$d5vL?b> zI@hh!9+{_%aJ1l|2NC1`tl`Zn1aP8#)Cm+X*_{I`1m-#wYM~Vit$dzd%a3vV#g(w& z1N!vF{Fo)~h`lpRkvN}^poVZ6CHuM19N^c)Ix1YeHJYkR%0AG%O*%|5Y3uM`bXG?w#6*yLLwW+_%SVTy+{$nV z+;m7UgxcD?U`qsMWm9@>u zef1G)F}v?yez4EB=3S^pO-^9UVnto6jiR#HM#Y{`D|AOc+41r{jB`dc1q+;yLTr{I z+tdO5i+iwkNb>Urb*b;>)tUXCUscq#7aXSf14YVS+)5^ryrLN_4n zv)2pFxLaqg|B_e(47E!YC+x&XJ=g3GV6#)|l9VrRzqDK=a4cPLyeP6KZ~}gyLz0`{ zWX&@WUIkZGrKfBjU7SE%fkD*wT_!bjlsBRG0j{)!-UrH03OGxP-)Ho2$F!NJjLw zv&r5w*d|6D!1a4(efiP|IM^viZTE2pBrXH*zRlO`&!UCrT3vL$G%3J}fy3a)yjFX?QAPIyNf zvC8kPyb)S#Z1b4Mlu6@8KVqdAxN{Kaev`xGXhl`utr@U>UDi-I4(*`MXfr$>w`Pbc7)?+exN z?#tBc;u4E3Y6lbesdAWSs1q}6X^%xuis1CbY~)!(G{pl${S`(6j%H-@UZWCQe@ z$#v@lS<>B+s|;7wo9?K%gOgp)f(uH6St3vLNumUGw_|8$HlwWsKXXPEwMMAb&q2C; zV~YmoSieM9y^(4w2|No!b)BBd-jQCZ_`elT9hv<#vs)6`rr)yo6L^R!qJe(<@Xz8k zD^KEKJh_cjc*^phrrG|rS_`dzP)#I5TA@)e&1feDB>(!vl{`31uiK!1;~lN=NsR1< z5kE+`ALNUyD%iC0$DeeXBALT@R~c}S_>8I$Y`J8-fhapr^D`^c#IocB*J7!I2Evh! z5I+)CS7s&~8}#CAd11Q-79@j`Ha@SfM&nVp&C*V`XiT3JBc)Puo*;_0W; z-#Y{UJMTf@uIhMSOjD0Dsvl&TJUYJ(?Q%)m>5C^zugU{S<^gE zFo>96<@#Eko|mG0@}+yeeJzN2eY4AFEQdF4-3KSe{>xHg5pa`FUp|2r4?e{Gx(#}0V2f(j4q+I|W>r+t5Fz$UXJa7X+%sE?1>p5yRKBr$kz z2Dbis%&>8ph01$TEMtv|``-KC}+JC_MII1N$ zM)#(RmAd<4vP5TOuYJ?mN4*A)HGp6?5;ibL7Y`t@5f`xDWxtf}Ppxj|yPli9_4($H zf%7T7xgNU8D$3jSbYCtRw9KrkAR->XuU#&@#mATLNU~5IxypGSKHV2-K2Tyzk^vhux!Tb-{64(S5ZQN)N!?GFa>h^A9kHn%!z&Wg`k{hZNC z&^deZJworsjd|`Z*f5Ee(5OwUFValO2k3U#Ze0?*Ug&-3#; zbFA(?SgIV<`39y76EQluM$}%xpb`)=SpA|s*mPU28?R={r5V-gs+Z`B*D{ecH-F4a z_{d(N2oSljo-t6GFd5BiJA?)n5)=4kO)NNzmD0ZwwT9Dr@j~ZJy4* zMIfv*n|Y&T&mja0X0=*9{^(*x2tq!QaR_aUYP0#K3rPUgi-RZO|;cNYU%<= zCE1k)k5?#!>KFoyg#oTxi}l(stK~meebl*z45Z)k9sLG*n@K6j*bqp*?TGPW@{v1xqDQvI-1SU7<#>xOU>e~o^x6dfZ)PSdHF?9waKi8k-!Yh(SAx#7Vb7K5Jl zqw?{+X;!d!WZmjCb?qm$M`)x$C6VBa2|=<*y$e?OKLTMXjOCe zpE|!(k+^LxOE_I?^F@Nf(X?ak>LJ&zZ-nXP$zH(~Zox%UXA=8%+e8){`87#E^Aq2R z71RQK1ODThrG7oH9g@RQVG&uAke;P9sLq4;)Je@9*A#)(96#v?OR{7$GHfmNeEfVW z^twINV^!6rEINO3wdPuBb>h1#o-|hu9TLxr7aoA9Y)!?@UM4$y%otv3_)l@yKgHx_ zm+Eu>=KzF?B{vlMh(J7*9z@f`uLX-}$iIpo?gJM)6-cp4f=T z>vx{c0E&>1$byv2pcgJRzZQ24Egk|Wvfw3Pl!{_LK+TvTXDNWXr9n067u23*I8}tY zC+TSEwIA$ld4KJ9$R>&@J*%ooLCX*z*xJZ9HA^V(usm+ z2*IZ>aCHjuaOUk8B8MUt_5`!PlO2IsFo0*1L1{Dk;|yd9&4aJ(#1}5CW~44f%B&P* zK&MeT`TOy)m_2nf=+QS$ zG!P~uJfm`7o%1A*rG*6T12lZQ<-&*4jEGmb+6_dp-0J!;EYN7yyI_aJ2;s-ZzI>WT z>rKJwB3GdgfZtBj1&P1MI=Rwmz&k}U8N>xDGs1_h>~QKwF?Gj#s9S>o2^#WzPPBD6 zOqgh7{K)&uv}H}efV!P#2>4qy;ub!$l_K>U^kO?w9a4(?c}qVctbSnr7I~R8P#RRz zc)~;i5zcl^v_lxRb~%58Zj@UY8j-!Hdus214o*S}J%x8Amy_VTi9_Y~*@Z>0B?%N! z6x@2+S2PjRZ`Q%Zm-_5ju`!4*f?oJ^p2xGWPf_G>WW^=|#YhSywqSNs0kLGe;UOF7 z8f5HC_#jo*&z?IVNx1?P!eI1TblK)6(yuq(}y9 z>=TaxZm}o!wd@0ic}a5D!4HR^KzQ8IiXF;{A?7X^b%R-Tugl8d8~e}r$i-5FT4{NN z0=wnzyF%~}wkMA_KPEpUh!$bu->O=zt1>S>q4i?8!dFQ5>K!if-95UoQL|_w@?e2@ z8v9VYl30(?4SWe2RXAU%A=jH}4^9j=il$tKD|Nk51;_hIK4z~%Is_>PkMqz^NXOups#it7bASC2J2{qmhc?*j70Oe|2KWAae9!5Gm4_W&lE>3hB* z?aeEB-*~2MP4nvyns0Nfg(?YpM^LSq;OF9lfjK7A`rTQwb%My}np`>fnurA1#3E(p zP~|lxX|_dwK?_u|pJ%+SAb`*@`;*Mqhc>=N(R?IBux<2?L_lxFJpoug2k0uV# zWa-irB_Rl`&zi+8iR|qZP-b^V$<)0uqgO?3&R-xhFMfkch4{~^tz@N2=GO^MXu;>h zfu|sLA0qIRDAqm4Fct=eXB)@qRsGs_id!2d3e9}wC!EV~U))^T%DU02`Sqro6(G&w z0Gy0>>aNOq)h`a7GPaKHV0P#Z9*+)EN>Yb7VrC6I^ONSLy0|;`6K3BmMQMQ4 zojE3kXMkqVr+*9s_`65qZ|^=Zg2Y|VxM4W;*LuoN10IDIHMb){{k!<*KYSZ~*b^MM zmm9tTKNkQ={zpco=}WGVL^Xpgfa%)F(HK{S?dagwflA(;UK7C&zh7*8dg5*uTCY*O z`JUAh7!?&?ur}ql?q4uJ-Q(}B>i1IYUT??G2v(e}G%lHOiim3q29jy%*9uvvf*JF> z^aEe7K6yqlFmL&jWOG`>D-UkqshXTZHf85*L_=3Z6ws)%;E2!3?dXz+?w6!iw3PCb z^ZwE%gH`zdb~*rD_Y2v~B&89qgCKGdA{1ZLGl`bNb|blo7Dn_5D_k{Sf4Y0uUH(jjW|>OZ3EwbumwJU62VHB zIEkOXO^Aw{cK4^hrSD-S4ZDWbg_85^bQVJOCi$IVjSW}qXCJ=Ww1GaQ$PfdukZ@1+2f2D` zQ|7U0JHzmVK*cYejd@YGA^xpBg1GU5w;G@=alVkgkO8@kLpGzO3Z0#IisJ$gaXAMl z0bgJ00|s48xR1ZJLX?cg{jd6V~`fw5uE&x=rzaGle*&2&X-4Z4 z$cK?tZ_ygA?pVgo@%*M2SpB+|i8~Ad?fx3k55o2IG@c%)MywrA^?W-q=rpA~iMULf zU~C%D;#o8LfI8Je{^C~*VbRCFL7&D+44jTtuR~}tzlJ(zuZRnnPzuEK zlk0E?j9FY&3+p-zpf?w}O7wKYw&_`Lzibe$I>O?Ue;oJg#@>}P zpK0~P%0XtS#1Y&v+_|$MI(jdg8s6A-TpA}~*r1&mW9HS{Is8S~d*GW9)Uq}oac?M0 z8o^HBijgbjZxM>Hh2$`WXL%HR5MF<6fGQr+4OCj6RWJ{ ziqBuii}+Yur3vZt=-BD-?@nlVxh%)8acbDc?X>VCf(1DywyWHQTjwMRi3i!`a{^pO zf_E>ZLdj1`;#OeP&m;%vI5eLFcMV|2(`_|hbWzMmtN3J@TXh_Y7MB)&Y%p*VE=rbB zaUOOV-vX)t_kYf{Sqnqe9&y3w$Tb#a`J-0>)cHVPB~~8A>1ym2ER_o@b{&2x+)$b7 z7+db;_00>W-E2mFNfC0HXGA^=N&6AP=1Lu2lE=gblYIriQ6%v3meKn#vpEe za%q@Q7!=c2j`AS6+ z&!ge0X(Lrw+r_l+*9BwfzV}x#B?aZa_4KM(uQJO-2opDRN!FtZWN`txhxqLn`73M+ z%v@H_FV;U(RQU3=$*rinWWgJ%_kd${qg4uNFI1@M~#`Rr2&>1Y^djs=yWeT z%R`eywN?=?wdZ4(X9}5a^`PS1T(wlIMbXRTC@lnOU?5iSXX<7(JQ!u zHnP%z1s(=lIe12FbPo!YwOY00=Y4JYemvhxriS}snymT$*Ip3+J3Q;mVK(8&k+O>v zFD@b)X^mHxikXed7pu)a(34fn4pJB2nRlNbsKD=F#9a*MB7ABGv;r)r1t;NC7o)1KzxgBvI68LHJ>s=bBPb` z5uUf9tU2T9$(n2HgNjjK-;VGm_Ao;@& zsK@uM?vV-ZlEuvGa6zjDjgU@owP!aVYvl^&*JQ;HbcAYeagLOgbvC>{jug6C7JfaE z%WY9`{i&4*N)8K+4`U;y5`4RBkZkI%rEp}2tzZkKe^}mzVfzAm!{uEmr?g4EM|z+O z>UP<84l-MwHvmFXTb~ z3oSs(&Qmb5A2Pt@s>yuKniyg>o@OuEX~y=nyw`>!`|E4ASACCn!!y=3`6)lxiMeeV z;5dXaE}I*YxUXzKy3$-R?i3;6^rEEscEE4YT01DU}m1C2#%;I>(A zEd2D(hq;tql^)-LW7<<+YSW+9Bs9A8c2g}r#%pqnX+nss?E!*k;#)1bVmzzr&UFuT z`K&)uQ?W}B7pc5arE7kf=iWCf8hmi0d;e$Zw?s0*N`Ev=y$gTR6?S!$bYY=X*dYFX zgX)DktEM~4M)mCa^inCGU-o#?B(V2#IkcxU!l_2fo`h-e#&=E&SpG%xDBYg8aU^flf#qW((LkEZ+P`pb)R!&PKltx)f%0M zkfaD)?b1|mAb`A7Wtu*bC4wGZFn87qkI==?cNPip=X#Gms@n#nrIod@>SBbKiepbU zxiBlwxBU6G86R!vOTUNz@_O*OL~mJ$Nwr)>?A4V!eczp=y9Z)tv_q0(BFbOXn%7n0TAi!=D3-CZ#YdY?}0oEPc z6S#@nSS_65WNH)hJw6+W z*VwriUxe(wJ5#yL60LILq+~OE}*Uwb3d@v(WmdEx&cVSP& z^#1At#6@8IFv(%OI#j;NM?jDE}iy|$WvZJdSarAO)adMBNPSX22>`F;&O;dX?;_QqtoHOxy!GT({R zdY^}xpP-eyX6`$%+o#M|k2ED~3CWTEtkIRk^QMPqY|2TP+EBDwUSk!E1G&wJG}C77ORj*8Ng^NoV*JTdDWcIfG-g&z)yG*|_R z*7$rcP(0uZ2DqXBAdk!P(pu8CzKav2|%-z3iWbbRRy14S~Q;F{#jDXuFsZfNh|4xT;x54^EjxcT?bu-dl}k;azRYk(nXqYkq+0LyuDC*zR{qgLO$Net zn`bOVg1{Q7uHuesIYa!ia-#hkr2q0jO}~-xM<7}&{Yeat7`#%1yVTeeY|UbZfp$ut zheu+}y2z}p$K5&D%u*_AxS8pweTs*Zp3qzhyZF}2+jif0qQ3_$CgrqLHl=r<&biHj z)9(@v<2S=D{Cu8nI=J8Ogz@mmdq`f>gUMXTWzsuM2eM=9MY)god&V|IcaUw5&~g|u zUDuP7y?$OSdCPvM&APO6==0MJC7kD(J1aS;s|2M;L@1%JTadnPJWOzsr0L1JMp#R4 zAdFpTs7p1bKT#L7jIukam)=RtU1NMgcOeLGqyr3@%g0Iz?phLV2W(?ps==xg(E2F< zlr8J?791V)(Z<|IH$-p>2O?+V;SRC-vfVjUKISiIRm6%=r-+FdX@0ocK}0;(`58jC zS+1&iP&+>8lO!p?=lZIQ-wU}SraZ-KDyIz%=~2NIFM~=-7aN{hS)<*CTg&*h7f*fT z4&C5SIP9|K$Yd7oN{Es;m8#}xwJ zB1wX0ycCbA2A_T%`ra(B)~%(2v23^G01n>+Ry;tmWC18X*-=xDGfrxgdtIX*RO`){ zgqahp*1#}&#$l50!CJ`DUrfER4XQ8II)9&) z2NZ`ZWP$3aKC18o>zL7`8pXhQVV)eDp+tGnVOwl#@~^$$e@OrTVeAc)XM|lNOR;7T)0XO$+j_noo}n5gU}P<2aTbyjiF_*YV_3x z>?XA^t450`lcDHeH`-`?=xKb!F=7n97Zm&*28PCOn`xVAQ>DI+px2-mE-a8zV%qCR z^a0r$bOQOfC5978+blAt7lP5ac{A%YA-|10LpT`OvikV_EtTGy%+G{SIony)xIB*} zJuxna8zMA9VYlmZ}y^#)-9o7}TsLmmAk zt17jrs6;NeUTYIhdB^X}ZwCE9AIvPbYqSVo*kp^f5FvfA`YI z(DiZ?mtFl5vVu9WIr~J42(r5QBDRz_NRi1iCEw?JsGdSQWq?yPpa_7UlG7^GELXWS zFA->zYT1gM;%E~}5G?6Ue zboOV8V?NP7}Osq5MNO)wYA-^S>R;<0+M+rrN=}M3dpVw%NBf3ghSt?pJ8Ot`SEPy=uDUaLV$ ziHqhj;)6p<+%ez6e21hyAn!R%V(=uG{=nTMup|bOi=L$KwoFOpnZB?(`d*rr^5}(x z$AHR!!n@EV`wa(#_PzO<^!9`V+>Nmq5-ty5CuC2nEiA?;qNlV=8EZe&PgWc~g-!0W zQvCwZu2m?GS@T$bvJ^vUsT8T3R(Yv=e~M+~eQx2WwJfT^>TDjDW^weFUH@4@75j>Y zUG5-#@l@|Be*VZAQ#6XeF(5iZ3c_V4AkU&x{4LO!V5CNsV?JMN!-LXi55cLU+B+&aq z-sT>lEm4seZ{{IM;#b-QXhXT$$6uPe`r2qMLj~Xac-rU1 z@ChA)`&fD=YW1wXr=$P!wVm7-XK^mdDonv~WoF-A2yzhEZiUT5>30C7|M|Jx-=L3M zfEF@7*M!{AELe9>0DY;jSW!ht^E%{i_V@jN5~(G&cqssPwDHc4ah{N zvPZKt-k9H6XN@8-X5Gjo zuJj=J)M2dKXF zX1DpOld!rU(84(xB6eQ{X?;{$f57e-J+-AqGQu=@lW*u7P@6Y$)lT{6p_t@!PbIkD zS1jmlxrl*x%XCRu*e)P4@wDU8r%5h&&)Dj9wLS=l_|~5u&A>d`Hlm+}#pL~5yFTTZ zAXl`xOf~9#_}&MoLpxxy1>v@=EbErmfxH^%7k#JEcL$Fo*=S`Go^eQoCxG3S{FB)b zHdSLNyJ0jNrmO6q~?C9%?z|JY3eL&)3erobnk{T zU6x$NtzKF02aVgvF$8 z&v;di<7Wi_+61YUQB#9vzo9xK@w?8Klo8;@ufp5B$qNEwHPt7E2;Z?i?lF=AFf&9i z@C;?m{igE4594UM7rsdalGzze5?G~;^#yQT#xP9vAg;&4935F`KBEUyk8M4($-{@I zrQKNY9tIL=TB6tu05Uu6&e#!b(wvxQwbEBzbD!48^pW1xo`lMChII0;ZR5+q@1dvn zvU#jXY2%m085NrhEq;H)^behVPvrQPsqFH4D8Uq(=K>8jVb>P4E3%Z133n z19&|%^L9IW#5dc6;}nvx2^?e zg?~}=2||64r0@}aCVGOi6LLq=Dd!~beDynN{1jk+m9f_CyqJfk$);_n|A_$UBDy(f zzWq+xj{eHm0<9obi{7JN(X%IBEV>}7uQ>kuOQEM$ir^1A&k`%!j}_}Z`MNxdCU1rC z`#AWN#FSC>c=1-dR@{D}ozfl851588x873)X^1x$eYW@&>^`1bID5;%9>#9q?eB#V zj#J=qastNnty{)?BWx#C0peShF|;w3uexGruZ|asbrMz<;Lr(hcurj}WJYzuPm-EI zyJJg@fcjdr^CiH5aLe-m%H-m^<7WU?OVQoD+7jAGnyuyFH|xlzTZ7co3tbe!-!R_>smnkMbPtOpeyBl1|nMilNYNhXjG{lbc-c|2n zO8tMGf68jVqOG@*=H$5*cM^=A2CgVyZrI>&5Tieo&TZ%G;|VIKiqQ|?=JJMH4OeiQ zpXYEwDJdCsdhK^~4Q$N&KccySQtDW#H7p~!-80TuCl6tPpW1j0P80AL^%mH#Sh!{< z+v=wgE#ACuyYcu_S(0|i22e08{$~ZzKll;KeM%9V9Lo_ZR8IGz81cnBAyHI$> zKa6jVgYxvR-sLggC5iFlGB$sGOnrL!W1QW;$}DYN3DdYwsd%+^f=$1>R5e!Pj>2Vo-uS@Y%%$<~C&DT3}* z@qx@0F~1yNoCeGdS#MDdb`fgk4pH3WmmTiK^1E@ETyF&d`1Yg2gWKe+{p9DtwzW72 z*t)BTWJ*XV4l13weM`rg8{AM4laQFgqP&px7R0`%08e!b(X(EUTQdDCS?ZU|<3PMm zMFfK0)5Sns@d;l`j!kPhwh7e&<1Jf`R6QAyKi>kPy8MaJdvt3B8}EOEL@V*0l+J#> z-ymIklnF|l{HkSfJd7Do%f8J;t}-dCRsGfF{-0T=`s4V0p)pVPRnW`^5|j5Os*<06Tw1~| z+$m|9dpYWM2G4m_Jw95~2ZIsH!gJYBH5l|1MJ%l6(03kjC2AdO)_$|w0>K#Q-S}W= zbf>%9>1!$-^h&*kiU^>O1CEs`>NY6<)lZ(xZn-D-M^l@Iq+hocS|}?!rqLcn(F+R; z&o#OE=RO!+er;`}4P?C^Qs!X$bCcb3O-)y-{#P0yAiY)|1#HXt7B^3T;?>CrRQnmt zXAwHYWKC06MtI1%Z;3rh=a@?98F(an_ljP>Hd$m)s#vAS6KbC0w>tuxg$Pnb?zt`p zF1E1Zf~nk}fleB)A7{0rx!Ncog4yPD$&C`?m+q5xFzCs$MTz{CO6J43w}z&$rz%O) z$`@2?*VI^}AVA?Xruh>5b03y#lPwg6T1!K%xKL;+M1CF1tnqXsheZ#cgi|K^YS$OF z5;@pzeoi=T{FE*~$iIRmY$HHbn9=Rqu~Vk&$JOz1PR=0<^QrFgHqTQuNg~G2O^jPX zOm-%pFP91b&A0dO*=#UxxjS)}^RE3jDfEArc2h4w|GHQ7M+$dZ3V62cGAZ2xU?A~- zger1x!Ock@((Qa;0QUZWI1K&?os(*~OMzT%ne$vSyZqAvQX{IxT+gbT{EJPW&nJDQ%*BK2nCd+VkjWa za-06Y9Nqu21bjcM(*W!v`T>Qz5obgeFl_-cwz1m3S)%`CW!-B-us;jXL;d%kL}wiH zQzKKOERJAAe6*bi+4HvqeA(=#^M9<||8=3XqdY`0#D)W#jsG9)y?0nsYq|#-1jGVF z1f&xc6cG^V(h?gWBA|lOi3);r0jY+BBE3XVKv4+2NQp?Vp%($^AVnZSr9*-c0wnRS zJv004;@)#-&N{iyz2E!(O6xQ~YJL&8HK->GaX+!de?MRrdWY;t zu9j$p;UnvT=0g845QQnaW+>9QCM^mo*Ov8)oIe_GX#EVIe+Tp(ge&!P777|GnK*^T zsEO=KBlJjHp@j--?8wpa^0EwzwRPzNs=9;dO(##y@g~cM?;Cc5%>Jp}{&*nJ3iih3 zv6V+g5vXP!1n{+Hya5VZ#3Z1a1dzIa9Qx}Kj?ZJcfJF8T1?VSe-2zDRcm=1w(#C%r z+nH2l`Cs)tzgZkFO9;BRRh|5$KG~MH%iQ8>v8D`t@Q_6NC|QJ3>^-3%J^re7zzkDc zjf`c^#02HPM&7Ez1K%Bhcx?K($uxv;$faNjMX8zRp0$|{CCcm-qUMrCz_x7KgxP-Z zSV<5Gyz*d5emI+mZNRddBW{!Ipk`AOjKYyQ&!}EJS_fR8dZx`~&v-)%ch!-dZrvRVyAvRRZ0HSi{nE!>FhK5_$6nIX#4V6$N!Yq<$5p6BZi4jdb(g=HGgCs~U) zdiwfRUpNdt{|3aJZ+=VSYWCOIh%f-n7jV}^4%N}+0u&fls7t2rY`%kH6Q=2P6WfD~ zra-i0AyoW8S^1IHavt&q^UjN!)Z2*L<&<}!9Sf%Rd6{_$hx zKk{X{?toio0g9eQHaWm8_89`)UFM0~2mF|aih{IN zD>eJd*$KdTRz;^|a54l-F17gko)Z(?Q-$?N_lE+WK`=X*|FMv?OcSaPF;iW7!>?TB zvun&6!}iH@4cpLr_Gp~+d{?eqW(*VgtoazKc=D$!zq-&zUIiG!6w582PF zs9|#_N>B~A53|ha13}W33;8Z~|cSQKmkHX-^J&HhE*YxG}&DFp0-<{^e1a zRJdUEJR}^uZy4u<-D{uJk^-UFWjUaDsMl)!&`qBDQnh|1CpsCpQR94Y-wiMrXqPA> z!4vj}7LgJaBXrTl{upUBc)srA557+(d5CQi^Eg&-ol8W$XGh}kEBwtD2SvUB>RQ$n z;eJg=85MdZ%}@uV3@r4YIwAK^k+f7`ItmUYIG}>24!xqegsO?0tsJ(!G<>pdH*XRm zeRQtQ4+1E|`Vm={xAD`xhsFAgziI9Kq_^{T{(cqTN6LQH)ym#m23$#S ztIH2YTtFpZ__tk5ng4(1TsVFcj^{e%p`W zH}lwd7=Ct|4L5(A5)Mo>$vqGvX)R=WLjrO6C$h`eNN9tfkp?0Ukna$t0F(H|iWYsU z=<0*t@85*XxL7UmMW5om)$FpvU4I;f7jxJ8&F<5C-$HKn!vWE^-yG`i_J73Bf2XMa z4yASXvNv{yr8Pev73S*~OOYd-OPJ6i-Fl)2(W;eUJ_}#h>=cwS0fDU9TclvW*{w1C zEUkvl4k*ht^J7908Z-LBi#sQ?7|lM7S$JRkk`d4A22NYu_ZhH|iTdd(u+t)dg&l^= z64?vi2dLi*4Hhb31oADWD3P~@ZoJYb(gKE7&E&X69u0_1o_7ynsTA25sAr{scXg4y z({2dVJk`eX?G%rXt7EsU;*(+nH79AXwqQeFZHW`UkTcu zOKCsWIe$_5{b!#0C)|AhnUDSXwFsbK`KN^yYyxg3%fDi|`4polB~lXv-l0X5BUE>< z+oQN~4X^mb%Zqx_Qo3@|sqo9k%buQ5nfp&pN4Gx(sXctaxWNAxZc(Pq zzN{C)!kGmBBN34+K##Q-2tjLU*R>0bFMbiNtB7vDWoHPPJI_2vMVg>$r{5RdmRhaJ zE*#Df((F5(Ao_9$B(n)jPE8hRs6GchK)VfS-<_PLVicfQh2qAk9LfPB>3Ksc;+ar-j zm$Z6Ds3v*TEh9I<`v%JE1m339<-P6^o7Ec&cjHDLzh2M#2q+$f!b(`0Fi%K8$!gzd zfV_iFCZO;b+*(;xBR2XW(LHKs4@;SEFqe`2RTYt==@F-YDz5)W4eTHNUm}Jy*NA1H zUa#=3HJhRFv-4MN7PUp_id90C#_8Rb5){#1KO&fLG7Eq1`AZlc zOk%?3&btfOTvV23*MoYTsO!DDF&zXef8skDwL%x9ni9WJWht?zS~HboSq42~U%Ql4 z&bMG@#FS@6KZ7q7DG}P6AO{il#Mg9Tl4VjN+38!eEVX1@XWZC&F-Bc$Hr71;Y=1&= zf0AZ~83qux6h`*q1k?EL4vr4c*W<3ggE8fGEoN0d)kT8U9e9FZ%-j z{nFh3kD5mR=-8=0UT_EfQ6=AC{!CWHsr08!|1a6z-_2Qn|K2|%3KX{Voc?dv7dQ?z zlRLuphTf@?9QR?Q(7n*^diaoao9U`dC2Lgd>rEkvFW9j;ve6)tld4@!ksv9xV-Bea z720#w%>{2tvpJ-9X^y6~cZ=wsdR_%$>*&^=VZGzL)BgK~#I>FbyPY?7>r)>mEOj1q zDb(~HsJV2Mm;DJkzOG%K;oAwo6rsIY>1nAP)tFKxuN&PlB{4HINB4Pbh#{_GLZ>&* zQLh4lruu7rE>N<#6-wVU8~4y6}BJZGgPp1EF$0o4KkPF(o|m`84Rv)ahZ_sX3mbw)5f`;tFnzR;0F| zCX(ALRxl+xG_7X!deR-P2k1FTTLPl@DBHDCGff1COw7f`p(U@+95I-sMH-IE`HJjO z5qu9GQhtUfL)*)38ChU0N>|uW_>ZW0ouw1T3A>LQ!;&3|dU+)1H+L;xC zabUC zEJ)=8r(MQ(w(Dh({nPvD2abJ{mpDXMB1t-QcsosWm4z2mrNosDX!E^yEx%aC9KlT3 zD4gg6LeNIkPxl`?4*}Pj4|(#|+s153ARk|}bbDO#;>XXLNyAck4>+GDL@g$>MKwRNmF<%-LkeVh6HFQZ|{%N zBkUtg)~1GBF*vqzW4VzNBcV#?O{VhX)00I4ma zK5p-n_c?pU>$-{^7W$ca4jqzP=u+){f@{WGhqOen_8Z`_5ABUfSzF|GfBG59t2V{` zPVc?}v=4Z>liPU{Q)Le8@OS&(`wJ|L7?LkoP0wS)9O>u0RWXgao7 zGB)qa6I#adZv9T!6CPPLS5Ap4=0w+H7uqByQP$Jpqp-tN#)85RcA z2g#uI)Lopf=8KqW;4&~{=1`pS68p<09No}#fy(=)lLopu1ZYu}~x za=lyOEP||`XOqHv7M~8lu8?2b3i|fFH{xUR`dZIJ%Pw0J22V&gA%%%rVRQY`WlmC( z88$EJy*x?pGTvLUFvn_phziPzFH2mSz4f7LmkZfjvY^5f37C^nu&us>Tq9Tv=ftx; z+fe~8t7v`A9^F-90k>9^2Jag0-p{ztcl!JmSAdSS{}t}-C6dF;92#{W_|COj0P#EB zkxd~F+lt8B71dA6N`o8Eee6cGFfli?&c+4}T9L;oWd)NI+Zj}pnnLc>1I|Sv;Oop( zUU~mx_tPu66cJ?GJ>{91odI%!-n^vWm$JY)yY?b+iEm5f2qHKJW-Wj@NYxpvKj{F# z-(DH+9?2KZk#nlF61c*3%x145$A`1iUgBt%`bQQw)m^lThwcj8s|i#d^4)vA0a@mn zC})>?=H?qY>tIpv1#Nf|@;FJ1DhebUVlndAXviU8>J^A|AnUVr-usYnn-ng(o_SQ5 z$0echIa}+C1QvNcL^)FtGU4Vs*_rsG!Y{c=-$5(tBnLAJH7w!VP>2NP{3;j+VT(w9{smXujC0A>k6Iq2`Do+*1fklCk@e^;S(sQ;>SfBChK znunPfWw3_o`uGjmJS%@$zkC8Z622ZM4HV-xXJ8w*hTVYEGisYhcDcieY(AMbV{AL_hlxa_qJ#n@;QMyL`F2$pBmgOn=%%~4dQ$a52|M*p;sDV zVN`blCX62!D&J+c_RenHGG^s-LxTi0zm}ntd%L(Qyi@I7{`eQB{s zR!ndTGui-Hs%IFEj<%@y1viup3CMh0iJowT%gk(7szV8VfWx)1RNA8DIZ9|ILfwLE zu~+sx$YJsN!tHe4$ge(nT!N2;zgF9C8<2;I(&V}DdTuy`gA%FFgFNwWfZmm}D17bE z{p9+=!z-%JdzbgiKICHc-9YgwSUAjt`XeM}PzSs*Wu*Z}`2q%){UQ+>2gS0m{>!U7j+|V%o31cq& zm^3@{A)YjW?*W=vWjqc#wULE#22WxYG2XA6RX94vKXwQi#DQUp{`&F zB{1StQQ%J*4fr%I8_Kx|;iYr;W)gzOVh=(;zv(Ha4^c7HVKqJ}eUKv|XMJgpt$&XJN#J<&TXroGtr3`;f;U?;b z*Lv(;e1PHr`}5$VQ^ z7Cp9RNm7hYr>eliko$%<#T9fF9dvddxJX)>^WA-}8`xDT!zvfxulpF&T(jb0KpgV2 zJVvY)J~WgW+hzCAqTRjtLQ{fwTTcta)2vY3-NF8g=raw7oLy4{8|iBqNwpEhUC%#p zDUX6~CLOdZ-Oe(!6EWd?noZzBPhaD8W?_{XsFdbF$WXQc=~`G58sd#@OmdFSc~IT@ z8k;i)zw~v?G^<;-Q?llB*I}-wkyRlB2ZH>O3+Y*dy66|%?gXPo6dxssJV+EnOUlDb zz4fmJiq;H4NB3-B-Y=_G?s{yut6$%A73QGtFH7&k90XV+%V$Qq%m#}z({i+rjiXOz zhZ_Sk}Jn>%Th2YQ%f6a=%A9+zm|k#0uwQY@RBOoa-HJM$=EJzOCu z$wP%@Csfh!_`Um&#_g3*r5}*Ocft2lP%~cxvX}#Vb+sLx5X(>V_!}b37{&beAGQ7l zM;XN;5VRyU_1#NQ|8a8On}v2wAJL#)8G^+Z@*L_VP<(j!sNKNiJ*IjkUuY9>OnXUe z2k?$7+5~mRrYR_KBp$#Q9eA*n{dM<7NwxO9+NXPJ`jtwGPg@^23!3ucX!XltC99Zo zn9#CMCF~nf&=$SkcphX+Y=y^m$H@s7xLcacx{O8QQ7tZqniq~5^v>mEbPx4>u5Z@%HnzMz&mkKx(5l}&Zas$+ENHMB_eS|hvt+M zyN2)Dv%CX8tRI)W(fM)y!(zvKSa@LI`3~jAEi9Pc3ENjIQM9o0;n|FzBH#%*r}N6; zMUZpEW#!v~n@mPUgRXA3?-rA$qhMD09QM=idIPBthc_3bqp1>d_D8x0I2_kiOVm`u z<{3c+=lO1=usKkbhU(=PcgW!x-ib4|Pd4&T-M^6UV8rIkk;W7vbDC%Jvm1!Xnl+w9 zVbh7LE~&21d>c%Tp_q{lPmoiHVy)79DGv0*uhbQ}0tqAH{T62}1dhG*78h@1nu&SH zEY_y)4^<`VuP}I!tnd|pss?{^Uen!SHCP_+Q1P)#D%-(^LB(cASDCDs_;<(=>PcP) zQv`_8bk1HLZEV-61mlD%U_xfq&=f_MKgFihvjS35jhlsO?07)@bo|X!; zrkSM7&Fy|1C_2ysdP zX^CtSsxd6Nu_PTPE^qoI3Bkj-ye8^A``$#EAWx6{kq2l>nVVH^#s*U zX96}ir#0S|l@-_TUcE9Jl@_~o$~}rXDsZo|v#(Lr4gEy3$ur&+E7~)V!<4RODDSY^ ziMztZg%m4`@{x>z%$Ij}@3rZDW^VyKbnl!Kt`)?FTe1N+0yeZ-5ep8@B(g+P6^gX- z?T(lhel{rhDk-SinjG*!OC&nzXb2XW{wlf!(KtJgl%Nk-LQ08>ImA`{Q2rCcEz;N} zp7E@*k1}mt$3mvgb=pk6uSA{)MeWcOY#oH6GmrG|IUb@A&hUBVI*7)sqj5lS2nf6R ze`&p7lTj{aN6=)yS1skAvor8Q4!WZ&qL#TN^GwGEa*JCB>cGzFS_AUXdhc=$NvWr8 zdf7(?vv%K;URkgx(HNteu~a8=6Ky=;hy~yz1|EC#FstuEfb`_qFK7*EyyhC^Rz3Ul z77#BAwroc`u8m|)uge0ssL>jf3Q@25-wDC-Gm_)Kvc2tG2jub+z|3i|JrZxTG)*Gl zHiu!!3Y9OjevKCSPwwOI8teh7bYO~>tJrc5BH^8+l5Ww$ad9E9)36DV@hhtgty(z< zZd(2dwFkV{@_X9tg+2?;5!u@q* zAAppdea^5c#l2pKX@-oCk7(LbdS@bAMHp;}B_MUWk>D_vVgz6i&@^szD1cDX0aDub zT*ie(iLa0!pLkP<3j(-t2hE~aNC)v*lU7GX$GJNU=8+~0G@w@A)k~qi`wsHK#rsSL z0UBrrDb0kf2s$5kT|f0co>)?j$!q>1Fmd}}WWIZxXrv^A$Wz~l^_pZBa^>!=^0gz> z(^8cyZQ*_k(_5SCm=!;~e$38jD>j&XgA&xpAKKR8f}pY0a_3l>+21Ztg(&LeS;uFX zU3eCI{Ha^RL-RvmPl{Ge6RNCzg&QyaQN&`}j()NXy-S+BP_7BlywbR34&LaYh8O@%whDLy0WnhzAj?|V)QQ!<0EU9eQ#l-$?I zWPRHyryUNVUQ)>}B#t(Wdsg7pp*~WC`r$G;C&gS++&jOQ+`LYc?0g+;0ZXxJ^OBF( z^K3wSwC9;mAYylKBFu-7M*+lf$}uASX)|RKyxnQho9Lhlu<59;IfwX`7TLCifNY+##ehRh98LHLDdN4Zf z5l_!>u=UBia_?P5(siucET8evdW*HKkeoIFe0TxMIY^%_wD(NXP1r=Q@RUeu*TrzV zwdbibBff)nZ!9JxiM64}wr=1mb#w%D;Wtx!KA$t5H9lrzwG290j6Ccxf+`2Ms$v8` zP3BW18offMFB6xX_d2BeM=k}caRz#_e{7^59qh$QUoYMmNhIisnb!_UZqTldkk|A* zV%GPaU?(p(#2-HOMH>{GQkk=ua&%xr+40_xhkzNMOOBbtmDQ7m4&nNj&+D0>Q?bl3 z-3WPFS;~W1<44#q--Aa^y%xIIS442_dqEZ??zd5Uo%E{ppi&ELLZB(>y5u9=p8549 zgCh1p5BF-`s#(ZUt4T!Z_6x{`eiO^n!_uYtY{1g~WoB2=;*pUxNq%eD0{urb)|v&= zr$kB5TD{?6^-zQk#r7n+!227X{4B>?>v~e)`Y-}0hPPR6r12>Ws- z<(v>|@#6%%AKLU|cu)tax-%=Bdc+~RSUpKH@sOF>cz1DC+_|mXE<~FCi5Hmo3u>cE zsR&QfNHi(3dFmh~bf!o3QaS0eyJD>a~zWNX~- zK5=Go|MGr?E(*ByJBYhX*ImbjgQ9eyocJ||?}eqEX|t|E#EY)fcoPtUSQV9G)@PId z(y~canw@&Nk9v#H%Smz9tMVyxVJqFAcey{_IO_chFRQm>*hJ2?+vRmwJ33FlOPQbP zsw$go+sTx#i;l|xPe{_0CNYP8qC&>tQR9r{mm>-enZ<*W)&pnGIE#MxIBV=2u7BzeJC2d@ou9Lg4XGdmiiqGP{##M=%wFeDpX;O3uy$i`li3S999UiNO z70gCnD!$95*x#|cr#%?N48YpiCvU;|=zWmGNI|oBbc}{BVju{s zJgO-P(>%@MePr5rs1{1j+10;3gzT{{?s5`rd4EwZr2CcO>DRMj+RA!? z!&Ey$;p0ZAEOF5r-Zn1jYTo1Gq+D5ERq-r1NvvU3?T`=6A%+N9IWg-JIcge=4Wxpt zBpy_^j3bM^-q^FT)w*VOzd3enf85Q;+9*)XbK4|ypoA%;8+SQGF;3G^gZbr)8IVex zLz~(KQcf%MC?%lKTMM5UQlNIh9lxGeusGiH&2-GHyVOds0q~%ELZxd|H~~7;r|2WqZ3hygK0!D>H1!mvlMh z@T=>BcRrW98zB3pj#iM2mdl9WjLQb^^sivcoc0o=pdT5yt4m3-{Qe*i`(Dt3$+vp5 z;4)|!NvXvN+{C{Xvf-phxPql#dN@+;E2{A%?ljNQ*e&FBeT6}zURbFXHDP*em)^Vu z+qts2q3|Gi0$@UM-{JmNv9gk@M3pO>3r%G%19RK&I0xAAZ^mgm-xxc;_JR;sAC65a zTT!o9K~CVnh!ZSpfKK#>Pqh{W)$2ItXP=A1QeF+wiQw~_!FFzg_cbgqcJtMazGu)@ zIW6wonH&>oWQ_x`_}j1Onr5h-Etfgd;a~MHYcCc2UuEjIFp7V%%2X?bWyUB7Aa_o* z(xv-1w!z_uTOQ)(Wj)+<4*D9JM>nu!HwSI#TB z=C&vme+T(0^RS5E)cM@Weu&{+BjgAE-e?7NF+yUq_|ZYA6valiobJ-aUSjvUF1r0% z(1#J&qm&F~f&RSZFUTiX)jC#lu@V%XpGr*vjxJJ1#!>8)E^Vk@1>WL{W9{Hq!Lz5& zPraO-~8y@l{M9TGH}xNB5NtcgrEXmMRF9)7mM z8#Bo^&sywcD`}=`u@}s{lU^&<0wAswx&TdK{-T;9@6*%mbdKtQyCs8O=HqJvlZS2f zPkNn4v@t~Hi7{w7?oVmHjAKMdWOUONM+Vz{;nB4g1fx&6i~JI8#b)W*9qHlRAK#>& z>pHCc6D-7|pDKl4+k=;2+=!s*-#{;*MJN`p1RE0$%$_!n=ZsSbY1!4(XRh8#Cs$6h9fwMu7FVQdxZVy(rSuo9$QtrkYS;@^V;Yl zTs?dGT%0xo$aRRQ3h~_$1J^ih8pg1ewxzVP28y;5N;^Q`?2YhNjsyR;SH}tAroJZ#(avH>sx6gD=4+~oCJ^I2%5p9 zXBP&$eVICig9u#*UE)rZ96K`gne{FrTZ3s+8Gc)AL83>v`Fd?hP4&Bqs_J((kL|L< zO=fI(?>#|3f6AhN)iW#GWJyuZ#QwVuUVw(b){2j_6O|c{ z4jaE~l0~#Bu`3L>`e51rLhCG(dq|8P27i|r~XP8iTK_W{k4r?lh%0KQx|O^>{~b9_l-a1C3q z<}B$nS6d?KR3sWE*JOM*;;?L$pzCs7QZ5)TN&|KSMkHhgFCyDE=&i4kA14?GH2%zJ11z9 zD6!4l4UXiWej*QJSag}Kx{Gcaz2JytZn`qRbZnyZ3A6v^-L3PYPZ>nMequ2$rl^vF z@FM}lUV`M560toD%N9KPWNygd2ESjfrE}Wz!=y}aF}v+67x(UUSXxMnA@Chbl272{ z%5Og&a@%0&koWPowJ_pV6P>FzVNHX$nwVvw)7)!oE2yU_6Ww6T_2tN?2FkjApc24y z17M~;pPXN#6zkXOczwRbB<-60dQE>1A_ zMj|Dyjl1d#pvTn8*Ah4vs(cO3nHw=S1d&BQjRyqlB7{JP+E7q>8{k! zyAPgw?m=W7nTjR%D$eGCHsNJjO@3wsSi>0N#LSA~SkcJ(z!k}nD=X<&G7xW>!)^qO z*Ikunq}cbi$Srv$?R~j7UnD`Pf2^%UFwVV_b?#ZOrJY;NM2D9IaFbJ~IEg%=kIX_o zx#lbo&8}$ocE`R+aqT@`??Y4qVw*RF+(h~0LJ)H9z3g{>LEL_8eV}P8O~e@M~;?9a4j~Z8e#ChJr)WhSU8U}&2%M`T>bA_q3|rNH`u&AO$5P4Af}xa&&qo$GsiikBpm|S=W{YxKsC5G3L%ECgEZm zw3FD3O~&jM2Nf@|z$C88rX}XKDYyD=1yaRl{G=OToMpW%-svb_-8_??BUx|sa}PER zv^u`euVp>V6_DaX4 z%lWgaQ(+E~I;ay>@Zbbm(vR8Ci#;IQY#ukpVMi`rurvu>mN*=SD0yAfZN(yn_H}9x z*?3jQI$yGo@XT(n4EUoZ_t|A39vhi}1+9|_>jzU9b~WwGuTxEE7zOb&c(k6%a`%~+ z4Uu5`M)p;VjEHa*_eQYQG_VpKme&E&0_eF{?$+x2r-3n za&mI~-5!*WyCc+_T_z4>o&s8cKe}PsRFFrNOyrj8eF-Sd_CeU_QkMiMI&Ia^VtN02 z1FBg%v75mpF}kYdla8+8D?1TAl;qtbR1Qj`hr(PqUGPB#e^g6pBH2|yHPMQyFM5Kn zceiEq^#x#Z7^S#pC0C0@Aw|92!1#)R+FSBBuy2|5zZJDcu13{iUM_AJC8C;Z5i)?2 zsTpC5egs%wkhp#r;f9-A1dY|t<@Q?w-Ewvkv7N^45;;vV22%Ex8$XV0wgL?ARo`gt zFVn;|2rH`pHU(%I1V`bhfJVme2mTd4TGoXkh6{GW$C*GS5thcYH@l~aFWx@{bv-_MMOHT4r5-B`Tb1Y|Ilz*VGk}##C zqp_%lE;Nfr@0_0I!4ZoS0a=HPrO3CBl_od&CAx6!63Q*7w8#U|JYT=1h4+(W_m`O}-kdCP_wB9vB z5mDqbmB8}g!yrMcZTgUO3mP&(-SK-G+10-RbvrV`9@UJVI@inBhOeY)+1k3~j5j+t zFnNyjT;IhMCcyB{eTUCd+PUazy3^fEzD9^{3`1cN~CTnm_gSk2FH`4^%_O zO@lw6jE}}6JEzwjV0!^A=WG4AjS&7#(|?*k@N1$$S`k1U6kelOC1C)x?o+=Cnj8G{ z=|AF@>!(&H0Pno?k3!@Bq1XK@D?j~(A(Q_1=>u7Xs1ANW7NhR-cTk_M#T(O6Y5*1x z^O@48G4-+Zz=+pewq>S6a9fAJg9ej_{j}g`P>nLU&+SOyU6VTBL3&=dHiiBnyURc2 zOq~7Kcij1}`W^f%k^7hJFP(0?U$?*q&BFgzcq{yku+6OcA6;ehH@|3){ki1>cNapB zl8o;Ss%IXqg%f%-cfJpOS|V}#)Hn^qHDR1@m#5lCON}~!cIh|@pt#PDK#!tkE<^BI zR7DI>>)TO4he0SDEEL1X_NeaA_q!1D8*mYT0BT1Cf=Pxk1t@w196|*{e*gCJ9BPL$ z+-fOeTa_F_sUk@LGNI9rXW5W8B(DYw8{U;Ocgu>SUDPC$&!>2PzCL(Cr!AV|kOmZc1}T{DAkAVDBcOYC z6li&?m%+YX_)lCLiO(pC6D%B1rE`EBf)S5{i5YYOeFBQc-`)bI?hB%xA`g8By=0MC z_H)%hL<7={`6LA&moZxcdw=JIyrk>0OeYz8pN7FbSUW)PPk{_cKjPhFlQtZ1hPLO= ztMn9C;0A5mJazbb%>~bj3UV|u9f+Eh4xYU2;64mQ`r_aGYJSFS{>e(|8k;>pH-bb= z|6V%{NGpG@uLeTC`EeFo$NpXq`p<9mpO5A5Y`Oke=J~=OO?ToOJU3ha8?+OCBir)t zC)2?49zwi+%dCfUW-9sIH%Dot$b{{7?4l^ydA^96IyI3x8BnJ%dg{T<^+xHeaM?GP!0#X@1`=m8(IJJOqBGlLdnNg$ zRc(1$#i>?3Ue<0E{sH+f@{_k)s!px?iD>vZ=aXDo85@y;qw+2xDl40-Id4uxPocde z6-QMbiN1As3AJI?>~Dazs2Js?>~B|)jf|Pql{uCEoUMgnB=g=4^J#zN@dqeQc@v=2 zH(B;6N-6OOn~0{tR$qtJQ_$nRH|3_e2Mp|5$0Gdz*l^}qY%-m1ycwXcHeN!Wpd~2; ztxT~DtQH}P(JD9GMej}w2c9nS17Xu8=6V4P(fZRh;qRa^*dS`-o&5%Y2bx91_5xxe?3$Ge+^bM1@iSj6A-=5R0>Ei{tVSOsQtxH z^H-8~&wq0&zq%X@>xnA>j2*ad7;(#)^T3QMVAQwyt^E%MXKY}CddtvmATY1%rS*RY z6~J?`L4cRKbd;qO5wkptUy;+MXb=E@ol8=U7B`y~M^0Ivym;k?NXTrqE8JW2;iCSv z(Lj1Losq=StUuP?JOv>^6BW?~6Fm_qAkJSGMu-XC0S^17;nX7X;$nW0nl-p_EtcEaW(f z!wlwtcWmiFFvo&Q+#>)Gl6rgSMxETPin{u1T!Gq5NYe`C z?YFLN=L+pAvv{%xzJvV5EmwR6A?>+T1ENtA^k6><4<$XZ++REF3~#kmnINxRZ95rO zdX?{b=?%WOvO!;19?XJvc$(pi8gCG=Q34Uy)_cgi=F+g0#^9#s&{15?ppct+R)G-Ac7Rx*l*Ys>dvZamsk|0S-Jz>J7CHAH8mIn#6NG_?_Z%;5Ery&1L_w59eBa zV9RsFS(9?gr&*Nnr5Q7H@D%$E}A#JB4Rc~4qgGAaKodwOWgxt2M36k0V0G|3< z-;Pa0+3-?riI%%JDYk{ciHHrpg4frUlJ4W4kNcD*X~vp#I5k!XI(`23Eb3l;dO@-7 z`}{<$;z#uYz#MO$od$;@VBbN{L%p}?o%C0L=IJw|K96s6K#_U-iWeHSj9I|#m?-$3 z1OGavUHQM5#`R}E?0;zde{oE^?TIBwi8aro-`FSk+ABAeeQCPPyTjIEn?}u#7>3zF zo>YrQ3#XATHJ0DZm9)aZOSKwAAiiKzDKqjGEoI~KCwCIOd0Ck-Zd9pe zRchZ}=Qau_8t9L$I0K9~O52=yrHI6l#7v9s58h1yUjns7!G$0{T@6GDgvuq+IG!o1 z7(m=2goX>bU0j#~SKn#yP!oO)=Yxr zb3GtGIiK$!NCJFKX+1xFJ$`hBkF<7U=(q)pY%}E4RBq)2_(5I|5%6KLq4Jj~wP45F zjYEzM(%C5@m+l#%$T<;~Gcz&XF(=w&RZ~Lzp_%FR=26Ho{tT5;9LvV3vGc4>n*xc~ z=0rJ8F~z7v@n_@qqa2(}1>p!E!1j2zdY;Z1f!sg$vhg!RyU>Tn{Ggrsh;69Ln9giM zmD{@|T$ydSsp8Fd8W|MdUj2@95DCqP>?{^lOmj7f z3mmf$_pUqVI<|l_b1>csjVfh3@zSNNN3kQ1A-HwYbtc9C^Jq6v$7T9J(en8s+zcpnAZKVV1YBM~S|_aaY;ST`i9|axnXD5H0Nu+yeVWa+zi36yuPG z%eWlgGL*m5j|6FYD6H!vS3 zW^RY`pTIUj_6pTpt!s;9THPloG$0{Jhi*;)jn{otHaz+E0F@6a3A0HacCLa8kI}6;VzC~Ff$!SKOU84 z{Y;M--DlJ=&l;dgT8M>TYt1x1FC!FRYejqU9Cfw)4c8r3usvw|KH#11U>Jg029He; zK4iHZuLbm3e)Ws?BvUh4xaazwQK%oYj zEU~zFiV^UwtT88Pg+n3qUSD(+RgNd9R!6BNT6!f(QY8?cRb}$o@c1JJGdE@=OBc?r zlMkD}G>4w}Mpvh^pvJU;DSAT@o3_-@jXan9(GLJnbAL|0pmBy>E=K$(VcRboxNsz7 z1bWN?xxcK|mLweQ#P;b+{^bVIo580Vt0A@GnNv;GT0pk&E8)go4EQ_f%dsi-@1S?6 zY=Bs24LFM-wuwNA>88-^^sd;B3q3*CGKP^dEs}z3@G}4s?g3%%U!@%X5t;r+)qJd% zrlk6)XNt>93Kn**c5bn44qU2PsEq2{TgKgiYMLR6(b!WUK<1qZq27WJE7kUY2VKem zvTULq3xJ3{FXZ=c`7(!j=-EERO>2I|iX8hksJZbcEg5*xfa`gIpEJy#x7Ytj{Ed5W z&5wphIeiv7z`Op-o}s^~2K-YsGy{_ng~@JiKkbH8mK!EK$TNMVo1Z)i?MAvr+#^|A_Xx}yoKszTopLr_1-H7c0JmB3I5O*>Gii^ z#2e_)(Q8W?6jg#&nCEL-U1BChgi+KdSjC#%XO5I>MxH0J=e?z%$c~hnT#?vt>C&)5 z9sxJKHHG$UejTcHTs6 zThreIB=M6U3U{wl%9{5GJ*qoC(EX^PeM%&zc5^P`F|rBkCsLkI8YN6+ResJR?Kv+C z0?`*wECAG%#UrsIyWRPl=iYcFeSMX4@qW#al+!(4=C3brXzHy8_>$PPhGekIM>-e9;z0b-&O32}*lP(NP=%hp`VG^eDhDO(>KTZwk>X4W z^=Rklr*7;QZG+AW^W~!&^>Mc9O5KMjQ%eT14DUbAd#%0pWW9U9ecAPv8U29Hj6|zJ zDusm*8gbfh0O#5p*{|{i&I7F5_1rAm@0tt7w-Gp_)B#{u~u3S~avL7$-C2 z)2OfaqFrIq4pt0nu;31;plINEc5$T5EDhJeoSlYg8gg^x(RCJ%Gz2@v74JQwE`SHx z&!u6*LBYlSPN@t5qJpQ{-L$C=@5;+U-ZO_x@y-MzmtAvj3C{VSHvNlw{GMO*M{oVA zqW-6HoZ#oUAFEd>uph@5f2nB2c|iE6nj~Dy1Z=|n0??l4%UR@6Oc5S3(kpYaeIK*a zpN*y+zi6~?sg&&n%=?>V4@q@)HEl+ka*?WFg%2vCUMF1^7Znz$sM;go^d_62ZQIVF z>E%zEbFs%mPdFq9gK0JO!uGzOCVeemdEtdL-wSkQyyJOsEyloSMyPiEG1@er#MzPr zZ^lg=1e|;}Yd$_}03foZ<5ES{1Dv)lZr`*deI6bNI3BgooUZwXj?h>Aar0-?s1DvdCR%M_VBgfFcWDoMF1&Do2u+k;}aTgjJt zA?!E4E#q@z&xIzRS0;5AXX^})2`JJg+TnGa;;LIE(3g}8r1$y}NTE^ceMK~oaIX1F zC20qiglf9yxpdEMR>U3(yl(@{BW(9Qj59>{&FIO?fC7g!{&>^ zw~O6va`t{zbqe;Jl`bv6n%Aby*W5?dz#?pD@kmjGbWNbDbOe%bU{(I$r!x}=$lZ-6EpEO8hN)GrES!cH)>n<>lgrfC2)LCC*#P$d=%Anv*_qyNkC2D+!Q1Ve zM@~=qXdkl^c`(~%5`Mi#Pi2$vwh$1*(HHb~CwM)P2ZzB}QFxAiH4{}6yXq*F0xvAv z$iNaSLB_v@)wRv-!cjq27@wM02IZhOEgk{G?^a2ds>50p9~xU+#V0+`C>{90>R*Qk^659)(f#v~@9SdN!UpJnA{{MfnS$^^1{+~){wx~cP7cCowm!cGt z46X*XtE^^W7y137%Zpe?s_(~_H8%^MZPM#QUTiIX7CZQw23!!2`uKI1<%8crqU@+$ zzOr80IbBOnm7CELtrG_65qhS{fO7-70lOng2l{4>=5U~8_S?snJsJQxuoUx6t|s17 z&6dvFK4U~Xi~{o9nIdVR)AQTMSet?#gI3d%p%35Mn|{FL=yiUTH1cw7Y1I~a$abgXp6MR(|oXTrLZ z^K3j(CHYOuY?4bb!Ya*sR@YCfxwTKfhQi(Bt*h!dyd48M=1s^8dr$ zdq*|drt70YL<9t+cY=bV(nNX<7Mh5NNRbv5kR~9~YXAk577+!d1u243A~iG{b_V1j1_HXvuf3V0}#w70(-sidR>%Pi$saK0(cO864 z-5Mbg4OhxbGJ^5qs_&PRfObU87()95-0I(T?UBWrO}TNuA3DD==KfE@NB?wB{L{Zb z2-bXh=W>1zo1N!cs2dg9T{nEHspGm8+;lH))~^?~nuB_0E+ z%#K~EV)&^DE#5eSu?=bS00B@cUC5QVF6Ej5l1++tk3wgvkobf*wvGisKD)8HTq&(= zuUtDc@6hHisD!UMZ>TEIcH|8G@&~DTE7(k~{js6bn^khL2<}FGIet5p7C!b7%|{n| z=(HfEYA_EIzU8ZkF-Rq}UE^I?PFRCUQrjVej?iI#{*5FJpf)ecMcY8CEZ68i+*yKE zo;ql}uB{J5lsT|5Niw`LI?vw`2-qW+!`x-uAlq_J!jF9Ck7*tXHIX3C(Is=+ztMuV z{0`4pZvqnSFf;KrFq5K$uh-sPnEzt%F_C4a6X(`YvP2wOYEedpl9}6sIysL{4yT3 zid>}I%65++2CRJ}uo$CU-jPdj!>KA3XBnDb6MKc|*6fbNQ~0Sz3AE{o&=UloSBoRT znko`LnLo@OGI!Vd>uunVrw$Pg5*&;O`mu`N%Xq@O^2ggFY$RBVR8PN#pU9vw8xvT3 zO4H)Ys$(ukDSkS+KC|5uZnb|5{`Hmza04ABn9mlwOdTP=LcSFXO4a75AKwjdmd>kz zsPFqdSi@63>S3-NDllhh;ujr^39U!epfMlhjfF@7%fj zIc-b`AfZMe_!SZ$NQiBvuTz8_)m0U~1G>MEUH4fmTeVOhh@^KLL}mGqN{Hqi_y*c3 zL7p=#3hIQrNcp_^?|rL;3cq;kp4C9B49qjbW&G9!#ne+k)OoD zM(?M1*zaj5hNrib+@}QtpAugUX_^OL?0_v~yq(M3*>_0aPX}szbT`mOGha%sRrCLd1$JoWMMF(e+pw>$=z9Pz z9dT`2Dlolx8)4W7RfV-=ACy2AE6I|;*9fo@XMsC2T8m-@>vJZU&^W*S1htPK4iGk< zr^lOi<@f^GTf-scNxf}Ywk#x8k9O~J(YVBN=`ub2TXr!ApIS%u5L$iGM6LCgtQ%cp(o>p5)l}h<;ndpiDECbz) z-8wWTA^5ZiHBI3=`HS)l!XmP$UCeZ3sY*yMa6S1Q@V$x-sykQk#GzisPZ0_mzoEHOMj$$>3Fv8rA9EP^USjKl+L}^143lZtBHiiTbYM$g`2li`N$}ou^j0R(WDK`9LUDbvWPuDewDoE)g$bG-%?go^q^p zDR&x_SLrr;0z5veT4*Nb1I{Y+)vOA72d}9@LB(AR=wvTktz*Lryy(rA6Wyt#4j-s` zAa;F(Qmhz zuvXe3PgBDhUxc8xOsGL0%pDF`!$VrvE?WQ&Rf7^A2WE*P)m7G5emv{L%Od6c&!lRg zuXn8X3tQ4t3rhmMOwhx=7KHU;&5P^4h4o2uShUG!C-{8$1X6gU7c5XM<;6zBg6Vhc zZaRJ(@bt}sniw_2jPYRwc!pR+=}uf%n5KWvvt12Dj=go~`Ib+#v$?9GugE4M92_Ie zv0YMZ^Mcn7VqvR&*3B50be#BT`E8nDDz~Qy18DanJ#T)33WePE_0XOGgX0;jE05es zL;C`~+PK9aEFOI83XM0+44~|N<^nnxB#>OO&0+cBugiL?%0|?nOdF}P(}#sK*T7c& z1=C^2;N$v(g8N%vPXCAXzF(7#K?BHf?|wn~UC}I0t2GwR7d8nGe`B0|2xCfdXi;KVCGof9|cvInEpwELw_n^GlGxVqKZ17x~O?6T@aZDxu6zW{&P4D|gy6eP$xeXz>I;a|HwfR5U*$rle>KI%FpE7l~bpO8*jTi-dX@(2)d?D znqcBNhTwucuyYfDRnm@;xmq7(zV4}EOGuhTyFtvv`}I}qqG(K7557B+tmxvMnjI9(j-OaAmB zmCg>@Cf`=tB@w>PAjTQyv!Mev={k3*F5s4EiT5o}!0^W^?^92qvu9Dkn}RHEAVncF z`_)z$VB=kkvrWwvaWEyKp4L&YhZihn87Nix>TbIz0cEUE9iyrzR!3ZXvf%@6hXd=8}c6tKS?$-zU5Kkxl`;?lj+Rplk`w3aTwSM zBh-a6WW!|xZ(e?mSqzS|P0|Z)uXUIvtXL>wMjK`<;@V4S+=150Vo~1;#WOkCEf~LB zx)z3VkOT71$)rVUD!Cm${!Zrf#Bp|e{~L_YreV$3AkOXQ6rCoOQ}^yWl7A z*{gC55?O)*Zi|=XL+h&GO9XJ+2KTc#z*NgI&_Zv%PWba8Ti{Z5F1oP%`}Ug5v&w-r z5B8GdA42>^cY-6djp3vp#OviJ>=;ST^`cx0 z@2lSyTY^-S?!0|4Yi!d#d?Rwp;gpqBFhOlyYDN7F{k%Alh14`G0iYFRV*vNkdG`Fd z5|qPGBhOdhe3TY>Z9*{JLz2`rq&ei_gAtuM6UgPd?#=$@0wx~7iD1x8a!fqmmnJ|Gzaw@U3!m#)yXiQ*qy zfntVSwycAfc&;NZ_TrJ$G7wSvgB8HI3n*$*KHO{ z-ppCnp%#Cv7+CWi4b3oe@wmG(-^#DM$ary;ao*w|1Y7-A2Lil;%@q8xv8(y!2$j2m zW*p^3;rE3d$i$%wfCvY8AZ7ymCF0btTjhTelRxxp=)rq;=;jd&uw4g1wz(|bQmX+F z#vLHe?WO>jWGC8XxwP(6+!|x&y2$iK_GyRGUW2TrXZ$ z|EF{_|GMSXY1q-l0By+{q6Ahq0WDUIayL-YM_u;emNs-5zUIknDCwd)H(bqaR~lix z0JN5|;k^B!XM=F@a%`^t*}{#tUo05L^e9%paXNfeR80BfLxTrNKHq#5%^Gazj7<6z z4 zHm=X}j$ONIA|KRqT7H7V4M+b;eBA%?JPN~ zkVFN`fnlj{LT!RRK>_vV*rKzFa)XbKSibGNM<)obTu&oHT4}6x{n}@j)h!x|T}M1T z97>noChtfonQty+UpSv4{QOIufWQXSK1l!NcsO?I)bOdzZP-}v?5pwfr%yM0jQS>& zdF;(8p^NyUC2X!27K%Th!75=+Ylf;0S!V6LsrV=m+c2+Y(DRhALu10H z>p}L@h#_EX*JVryA=M8n3t_N~t?G0W4Qcv>ms}C@J;#=PdcxG%vd*%v#uF>eYm;{{ z*w%KRPH`#jeqSj&%>+MC-g6CKdbCgaJ~j?SrrbD)kf`sQAg7R`$ls8!VL|g?&XOA_ zE#}vm+nsvU2D^UPOgiGIEbs8j?FM4nGuzHIJIl7+D~7e}YVs9e(+>6hE$GdaKh8!4 z?O}UqOuZEKu?3n;^`|Kx1Zp{81Pb49_6*w}ffj~$LJn;g|jIi>@)X5elDqKvnmC5-QSHR$49_oKpaSsvT) zOJ#lkuOtos=LC%ZOg{R*Jr@ilb^ODbZt*0?^AVW_nRz|E>40H{tpj)b@B9A0FV6dufbTDQ z|GzTsdM!BkUz#-}_=kNm^fqMDs~|LkatH{Tq({&?JbVBZAeZHv*D!|QEJbbc#pN@( zXx|OZZc-fiTOpN|Y>5nv@h~!~8g?kJ4HNA(%hZ!A>p!`~EIVWSg7Hh^mSs2_)s2Aa zkYgLSUCb8m^UJy^r7AhNrdZsTIo5S@!Tm$|*)2x;fL)Vw^tCYCme+D>t%L<$H{ z0Qg{sgj6~}Q!jo9^dbeNQ}l=Ad^392615nH5swH))dkIVix-Dm6yTNzSp<@C#PNlvKhR2Q@6YL_jvuHQsM*I$x8`gYY6I|@xZ|LpE zbC7umm$Z<|@0qPAIT)>jri8h>DdN_|&uqT{AWVBbpOu99M8(w{LKyfque2{ zV5ff!n=$?*w)YLZThLS7N3_Z3Rw$6H&;pF6wyOq1Es;w!(TZehe#VVjtH+pKOUWPF zEoeJ3WEM`gnXbNt1qAkCf)s{{;MJ|~q^8W32w|c?{$1gwEGxu4Gv#xnN z>%|?8#T+j?N%m%MyRCVuaP;<)asal9_r_A633Io^a^+dJO@;Fl>TGO*ZDDpkk5Cm2 z;{++JtIrDb2zjLybG+)xu%<%k=nB?G{-R%DPo;jX-_2VD{f9w3U)^C<`oITfT~fBr zE+o>oruy?WW!OAZI#bg_&wWbU>-CXiZ0dLREz_yT2;jsjXi~Nc@qXcq1|*gyG_u`b zo?=`*e4_ET+-v}=s#S!!jLcHs6;D*vsPfxGX z;ea!zoQ|q2xJ8aVR7e>dtiOd1Qpv3uUbfWCdZAt0zhk83Wdg|1%5doN+KUn@%Kc2@ z$tEXW#U2KmTDULBjUI7@;4&2{l6U%C>ucUS9XTS;X@8a5K@-FVS_L8V+_UqwGAA?y zb+>fSBQr2FxzJ4)nmkY)W+?!7jq=cT;vY=*KIBoY_TQS|oQKrG)JUM%*xv=Ryt+n=b<${;wkD|M6L0{{p4G5bat34tdaq*jftGn1S+Y-rG6$-(Psts``I&H$#oG|asCV&R47MwL z!I;=a!VJcYv zfCH8YBkKYb3{gNZ+G_wz5afuz?a}~W7(awNut05QWeK z=gw3c5tBi_I&@EZ$_ygi(d>6!=p`E zYYVApjoQ77E4F-8mF#4ojfdi)G!En6+_V8OX{#Ty7*7S-vzD_M@{V=h(}bW~zNdOJ z0&1tp4ls^Z+*l}+D-qFV0b|aW0aL~M#%vDCWir&NZTeN7?sX{Rg*`fa7Qq0S;r=Aw ztV*+%1}-%taEh-_aY_1H@kK@~^Qk(9TQN=(HVs@a3?(H-=z;BRAs^PqM+QoTCuT6> zG6dJS4c|Wr1Oo&sOWkqGutkjAR zd}2c9CrCQNVP0bXQ7|!y{qNR(5Y`&B1m4!A0Zk~5dzu>jPO)6RIE7oU1lhmVXaAPCJ&Kn~Fa$H|viEsXE0 zmdVO$k4YTj3>}e6e4FvkdnA1mPWI9gr#;1VV5ld7xe*)~p8n0Mr0E|Yux*ISQX1Sr z03r4`vdD}DVNkFny+!$L0F3lT?09Cc!c|L``;>le$=29r^=mg7q(jGU=b1)4Quy3Fp_mu*;e&4N&S`V8*?#9q!N1Y!C098VmEN> zFY!G8TrK^-h8MrcLH`xW`u{q;BG0!_12P;p{RGW2qyYpjU_!R9jh*Is*}8fm>XPp@TVO7-v`QWF5p41O_%x2-B> zjP3)Y$q<^5Q5$006TuB5@b&^HH89kR=_g7a2o<)2sSE(FH5UhWH`z;@(!TAl{WceE z|M&k6!BsUH0xl3sME##2I$)H&ckYA_jr?;RWHv_#8}oJs*WYloGd zT447)N-s7Ls%TahA)ggwDzmLJUrL8K{Kn|}oA*b;D0kh13liI3ZLE$FFRb^>D#y78 zm6T*j=sM3L+6RGV=;u;as)~!vY&C!Q-TT5=scs2q`>mpW*$@AWKUV-G zkmS^9+Nbfo;p$&X$ahEn*Cp^b&f}kdN+v^SnwGTs!mZMCf3bLS_g1+2o3FyJ&-_1l zOg|6($M91qFc0{L(dao$F2Ce$7Y6-BAbNd^29d^0dCgn?u|2v&Fx5QBVIe^vd>(FY=GX0)Mu@FdeC$gtY9-(2yS{^&Q?> zHSMaXTv&Kunj`(ydlclpPz3WhIOnc_lxsD4$`#1;>^!+ z`L6b$=LPEN?LuF!Y5+VX^Quo*r>zHMlvq?QVC`B`T2`Z3#LqZsSRGOokD@D9dIOkl z+@SmOPmtv^^m#Wy_BEjyOe|m%TCs#(hX=hk4xVpZ%;u#t*U@+wGI#%z#RpUG>iuVm z8+U$p!Duf8@W9L3r{Z0P>B45hWY-rZ54ooPFla>;A2zq$XWl>g_w?^FDh5y?G6VbH z?-m&UD;L-Or-^od7HR$}-2A^-F8)i+^Z$bd#vEEERO|69rd5uMto~!Z1UJ0BSCaTe z7SG`)rmZnEgGWm($g$AAiz*+y_Vo~NfeD84c?10~4Z%6ujGt!v3CbVs;>3w7v>jPQ zn|AfUMn&A0pwG+#E2f+FtFjp%U#}zH>|jk$ue>nc;3@Q&XF(&tp=Q-NVWgL)4!Bi8 zW+ht>pv?|5r$0H(irnuv);2_ztx5EkevPmC;xuI>cFbT;;gvA+eBOqBs-f!ir}Qnw z_j&ab*n>it3qWxy(H3>J7bV(f;L7pt&48TMxGTxv!S)XP>kHrM%~Ypz$FnbVuh{$K zhuZhk1VU}eK2A83#9b(Nl^mB_>y@GHEazLO#@eXcX~%6_x&&W|-dACZjgYT($R~EL z_Rd0%Q6w(A`27f3!73wo9$G^nu zKnfqylsS58y}Iv%E7af_gSw*U#TQXMUBxXo1{d~XYtD&KLI>8KG6B?UJ+Na5=HcSO<^*6_o z?pXi}cqS_T;y`*^O{vy_lQC=P#Ivs<*D|gf-V2h<5w>7nL~9;ohO3h|+SFAZ%QeDg zWrP?4%vo*YHgvorb!%RQqp5uMjA5+rR0{Mgxz2o!rGGP3HZ zA5(7jjqj^3Bs1=wdxGAPN(Ir4q}y*OdU+aagi;Um`*zjXpF&w^QdAMLLl9Bu8o3oo zhPPuG@|Nq%UX>ip@7uVmE9cGq@Gz_3hmCUqTmZM6NvsrsWyC|`l)TuepkbPb+w`bH z;DBxQQ~iK9Vp5t$PuF=K+L>lNKiXwoN@VE-3!*Pjtf;34HKk_l?hRE<+Sm?y z0Zi$NoamS_EU%^Wb9G*Gp^J3zhYFQMhnet72M>zMVz|J-!Z0wAWgw@Rq}-@Cel~A& z>A2=9lQ1V!Q`f>LTaQ4Bjr~T=@aER&8~ffg2gc4}ZS8|hniy4qNKZ8xY4moQg&o(; zaksXwI*~M%?lR|ns>-zE@QKixmiebMjtL{iM4z^71#$-EvK19fKs;M-6e7WnlLy-UKT9 zcriG9I4=Ms8WI7~%TYjgfUknJc2H#iz%K>QPUBbu06+J|gS7+vUf(B7BTRh@QdK}L zG#MK4SzspzPDf)FUxeNTa6l~#c@0b;Vpw2(5@di&^CSdD#O^~vV1YDWFRu-74}9$` z>`a(lfc8Uxl8{KYp*@rA0FXg|NuLCu7VJU<)B&eS0)(s&ke5V|dAq1x_<|mZMx-v% z`qMUJ|Fz$6#EgI%$^f zB`Jf!%tHf3v?{vrxCNy439{w&#O#%}dfke%SZkZbmq4JSUkS?wmJ1h?+8>X38qQa9 zKCz$4FObir<54*5DrlQo>T}@%63u8`dplf~d`?+hOu4~hfJQ0vr;zp!va5cWr_Eww z=VUU&vhTn4UN*fBDFuZb@NAh+Q@I030aYFt!pGqSX>K7IQQDlmc8n@NxD_MfuD(lS zzD>9={Y^O?`P9!`_p>|Xm>a)C!kQ4yKD-tX{_H66Y(52;Oz^-%OUb?)^weuZZ(i0> zz|p5PEEBlu(jI4z(aeP*gd5;B1JiBz*cSf&+u9TaSx?*?%f@v0`XqTt2YlNUDuG-p zC$itrS9e}T5ta2^HEUzua$DT3qnByW&|~xYyt{D~Qw{B;bmjAcE9`5CVqm3(_y=pR zQx*jkva_C1&m00S@d2{XU_Y7Yl(iMq?I#c; zFZ_Vt^C5Spu(KpTGRJ{9y%rxPEagtX;A&o0+=;=AEWEx_+cUhmpZ!=W8Q})n?}*#G zu6IzaBo!p6s?Fjm+sO>~dEanazL(>?+QL?W*HPLSU(XK{QL4-U_Zemx-;BCl(MQ(_ z^5Y{^9qr;=v^9#XPcer|8-)>9TJ~W=StX%@70*W(Rj$^FQp0}3cYF`t%5KU*T*vwx zB`~9!C)}BY=MKwAIiPHb>7f18xIJ%(lU%%t;mKui)lX2O#57P5yj`lXn30u-*FFs1 zJbV8jPHCkZuDk+_<&p9%Rfh*lbLlgI@i4vCBK?v(VCD({M?gfVjWpl;L+Bf@QK{=T zpET`59mv(8bSF^c3dafmojp(|&Jq9_qWr1KgF!=$xoHgkzn?(X1vigi^W65^?4dp1UcVErS zgOJx+cf(S_=FN4cUSdl6qT`9@0dZ%u zDa1MTxvs-(sgXri=7-BpIa8IjpchSwmlN2L)ts>AtKt0JLWYHb&a=GLqS}Vino=6l zi3dY6s7zvy)birpq$#B&2`jvPTm%Th90m$401hNjP2(1%Qu{CNt+o9bjYNS zQSDLhmt3pHOW}?K^W^)`xJr$(O*zH6YC7-b1!Q?MvbZRR{S{CNy7eaP=K|ya49SXU zC3pJsH&UIG$F#tmIS>AgfvJ2z6-1<5vu(^aS?%Mg@N5Y~c323AK3R3>>_@-LjnF=O zP+Se;UYTb9je5KxxX3%jPt?S4`eJHaEdrv}&lqIu8dpT99MMS}NE}^*T3PmjQ47=% zD$oN}qyRwLq6hTyNaSm(x5h^bfXhgAsYk`x3Oftm1P}wB9snUbPSyGlxL`r>E6Yy_ zOuJeWw@zvjkBK=eUmnjW0~)*Ukzsd&W{LsapVt-ixy`fiJNLGRK_7i zQ{Au=3(>!lfBy9N7nJSSV2=N4yV-v*s70)!`q9P*k+8W)02p>q2I}2HrsA((AyJ>M zr-M1R{en!|*|9dw(=4Hv`C>iO%tsoHBu>sP(!c|A=-RlG4<(+9wcR02=!IAoTXtou zqKF%-3ltr4?`(L+%s3oOmTI>kJ-%Ugs@u)_ZqzCGTs#@CVRSqXBCp1eWi4NlmYApG>|itD}m~6x@*d8qP%x zB!>`vTJ}-#no>oV75!baEYwbHdalJ4pTGTb>$@_Wf$9TMHcg{upuNo?c3({Ctaol- zW5l&5l9qhIOvcX6zDVX5;vrp*Z8=|#%ogc=Zv}CfuiEC*z5rw(Ty(EvJ;DwEZEa~I zf;9yc)j+QHbhczCA3R3DO@2|Cmy3_WZ8_zmD_-~6;rJOV(8Uw6TJ8FzND6V@l-o3;4{f#h#I2GuYA>C>Q>xXu80I`lnfnrTy4r+2cQ$|{Zs zZO1MoHF^^|XT_@PiC#~fB1bGMo?aGx?|gnz_=kLpGP9!`So{3ZlNoO%GV1Fm{rIZ% z8%&`>p-e={P9)Z0<`ews8tvrL6k{r1ZIni_ohVCPwb~1d6Ynf_o}Wy7U(FNGAaO7& zRYHKQ*}b=+b?i>_gqjsz4ww}Y`Zln>CPSB`kug}y`&qT;`=+{QhfbF`J?PfQ&G-t} zhEbj?J6c<_k_lWm0N2CXBpZ3QqH{_WdEU`B9s*?%$>d*orOEX3U32~n04n(#=xTWoo+iGQAn&PsNpej@} zUiw~=Z(GBlJYCWwoykngJ`4yp7agh&EM9=gPf)+h()G-{EQJ9N6EX|jMdts2KZDG7 z|FJHc#7^j~gj%0Bv#$#lg{IGFT zZ}Z|($GOwwJ<8QdX{0MfX}l!|$4lwp{n8ov;*$6=zYXl~(tT*ktjZd+NIGU<1~7Zy z{wmAk6>2U~7@$dgo{*I3z4KwE6e!LT`Jt>-$$=x2UkB``dQ&?1SWqH+KYo7!Yq-fl zg`d#F7b}pXifQ-faRb#%w~vI}zsIJJ3QnM^9h4#S{WI{45!>xPC?|M~{f+ntm9xp{ zU9-o@2v5L;AYE$S35d76F==4jxu)dqm=l`Y+bkI3sZANAy|TkNPABp8G6Mah+Xh#a z_v2R;gq^H93HXMT&u5W~8w0?%km&6G!MaOFK-3r)pg38 zyU4`{%?qkG5CKW5s-!jZYB*VPCx5sMTiA(6On-n}I`Q*)qEFbB$Ge^E42y2L#G&MnZT=+@1W$oGNz%qZPmBiYWPUuw|F!|K(>W*8v zOlS0t-4(HBipf)C$4`(?h10@PoKD!;0K6C563w-Kq^kQaM_abvoO>lo%~MKn1#j8f zeO9vpE*f~BNj|G$g?HI|C29xn4BSsKnPU)=NodEl~I zu)FMIyC%*qXil%O-xnYR;o*a{)jWdPbrzD>6Z<1iK0KK~i&RbMiYRbM zmdbTez!8lKu5m_ut?XZPinMf-VtWZQa#75|rTG>sA7p$w_BuL5_xV zhguGbWVlJ&b0@3#jL3e0C_v35T*1t7nDqp+BKrBNgF3klP9Fz~cBp;5Jp5%%yq92V zCuqI)v%tn3}ljgT0!%esOg$B?(87AUVVHMxj zt>%QL1+1G6;WZJbJ_dX54+YUK(1!WCGr(gjt@Yu1+Gb&RRj17n)xD_PluqwG|KE$l z6h~s=T1IF^1~eANvRxv};Z+@y@wkEU$HBO){&iLD?F@*=q~Q&L@S{um_YJM8_0&{n z<`Kzjox|u#pfJ3KBCxNeg)r*}5sh#)ufJZ{99=ZK{w0j@HAs9S(ilDiUAqgztYNA; zlJ7{Qyeuc^=0|dgQ48WSlL94C_Uu#TqEFKLv9u=|7*Jt0y|vsdmsYe^Bi#2(twI(%Si` zT!@)2XZLO|Za2pWzO`_v!RyVv#+7V!8|m5XfHrHD!bVT#DH}Y-&^WZ=Oh3&ozc7tF zjZOs?0S46GcOYXM#+y`3wDqx$3kE9ZtJ)SIAtAy$8#ct8t;OOAc#8aqN1WU`Ra;n{Fwk%k~s#I_<&MRu05H;SS&Vz2pqU|%7 zXM!JA51Q%-qynkxB}0b=brVjo-O_&$U$vmdFYIH@(f0g3M{Z8f!#wW&KMxoG+h&+@ zf9Fs7A6b?Hxkw!0XQ&1P7!NzPfop`r`YdaQUy54h*|L+53- zMBPWj&xZF^P*vc~^%C008;R^~2=4J`GMeHJ0it)j`CUVoKDTtl2Jc1MXfV->s6wAk zi48QKrG$iPG!A>+?O#P|o+!HA>=z^!C;9qo;zy0+Dh@Rva8q+_iNHK#vNy#wkiAn| z3>_uMK#px0SgcW>=$9KUEwjFo#owLyLE=i2$LNWz=XFX;k2*rLH^wk_0lzz(Wj~%V z!SY!lW_e;GftNZ7WGzb$<)(&--_K2tPR%XfI;pRIZdUg&XTGm^woaKsvwCz{3dKLv z z$-fuFg-#|+et9dh{Pwy!YJEz5&+XaEtPvXppbvWjpq*>9vw_ZSqH^kHE}{(OsiuW1 z^@$z5YeIorKkSZvVV9MD9A%Y~6xFyv6Kv>6o8qV?N9I$724!jd<1M=|EK0ELhMO6- z%*a;%-i;SIXHOVKH9hV)T9U{JC9-$G#>{5k1%fT_@e?2L?#SBuTkA4YX?ZNj$Im zu)4}j8q2KH)in{EWRf1L zKE=KN?qvCr!d=jx-G&l^ZY z7>wJG8P#rob6&fZ=PE_oOY3s>?S0WGRJshV;7bgJ<+^J#5tZv^dO6&E0%HwELX`@u z_pjW)vgNU$tCslqgjjY%(vc(j<`ktsl{T~(xf%Zxlt4Sv=cx&dP?#=_FYW8a7In@~ zjgRT@T%M4lRyuDE)i{f1665xzIn{-4n!O4se24TmsoeU$H$b8%OqJqH#t`}gJD+5C z^)WUpoysWE#^JZu3K`u&58b;P>4B(T2nXaYLV&D7QO92+igaL#+|>QhZ@F|FUufVo z#TPnwjWYs7&+n~4J^M0u3U{9QTy=HB!XAbwIlduI$HQFMc>qsFt`f3r?+j2vcXGMB z0HVNm1zl0)m3Ia7y*0p@A_=DEhM{*X`w;XG))hCR&~(-HQAX}XiY!M9VjRy9Yf(Dp zFE@{h9~N9!5mn1C%F;cv*`d6A>)t>^z(1|by!D$?gBn$OKgGB{rnA)_rE;Ou%ZoMF zP@@0iE#0F`_e?-^J;wJ!r2)Guc%nAj<92)YK^u$jvZvFJda*V>FL`%XQ2CsM`Kf@j zt#pr}03W5<2-rwBF2iw%$XUYNR;bhn{xQY*My!RglHb|gp+?k;BbgrsYcxTbV-7pjshV(xySGNdoQ20KUiG$b{t$|1n>>9(L{7La(u1(#yfSCN<#$ z!=xl=X!0moaN>?5q0Ez!e+SyjO}C@Pvsq00f)U)f@@^M!MNrH~6hQSQ5IkH`rpD}%F=&r%iQDt}`LK{UDi?lO|n5@{04xjOKZlY!_B z_awG{o`_D7y29`JgT_A05BqZw{ybc6sfN2qnUNJhXqXt$yzXJ!RV=g2VDKGO)43W}5cbC-d0gBygiQ9LhJg?6s*K@=o) z+GX$roVEAMqsc(c5KFE1;r(DH-}`rqLT?b9p50(!K^(W%j|k>^qcOthM4=1swKxQk z!|6rLD?C;-Iqj3(@Ol9mJXwB_NMm#@FYyQLmm1H-6+Z0A5b`0%fsFM@U&-tRPKZl{D=pSwDlsh$ zT2m&iUQDOgF+_FTtF7As2`quQ|=x(76{!3$;ci5#xgLmG`gxXQ=ESt7Ka8RmTzetz#NbuTVB&heW zzK|dOT`=u;=VF;EXpI(L+mU?vU2!J6%ns96hmqGK0R!;Tb7Xzer$)}8>Eqf0K*1!- zwYtyqwN7)6+JB6rvouT}XeR1Zny0aGn^)wBl-8Q3Idj%zr3=$`1A>VVXBhV6!rg|0 zLfP5w_2;((Z?+$~5NOx@QGO-+)s7Z$?US9W8ALOS+Yn&--|_!Eu^RZ zAmI*iyi^`@&*b=&6tJBGNfY<+336jQuWFjKU_(W{o+19o_lr5~sTf?OVtm^(KY*OT zuVgAnkkxedPMmOxExd{ZQ@NnHE0Hk;AINwc%X|H&^Je4IZZoNMgY+H-?Z@rc06rDh z*~zB>o#Zl#U28gO-OteWy-Y~)WF0=i(SZFLdy~aC1eGh6rhCYWpw&WkrBc78rlE)dLVy&vCB$RrLi?|hpD>+&&OV~&!|Sc zi4J^EJ;itj+>)D$y`^>L#8n0EXwcjK_(b5tq2fX>A2gHzE`9ac<(LqAJJ)iJ&!Zhs zsV&5ArM(%j-KUWNpfiQI=#MvOlC8aRg0kbgmc`s?~@`)3u ztnk%mi2BxYo*CF*y*j{a`K$MqD;i#hc!+lPEB-xgOy$#hTEu-B0bT&@RAWlMrPjclu+gDjmU*}%&bFXgU9YcULJNvpX z+`lkB|Md93FMw9*kNO#O#UF|-70%O`chp@%_4@`*bwih?xGkU3Mll@K$QTSA)hr}u zslIA=U}?luwz%MGxroH;r#6N6njUUTF)L9iHeCeRDU$!{g$KvO=FW@2I0)bD1+1UhNRL zDVe=(Sd-~TlC@Ma*@$>qm;8;e zp-W7QopbKrF$f*JEss_{T--5=e^gO{@V4`Fl9{Ee_2-MAU`3vG~DaDy7T4(QC^yx#_- zL5$R@2J?$VM}(x^o&Vl zILCyW0-(el-dBL;B>6qdO1X2Y=3My+OUdUs^pQ(;ipyu(W#3PHG4vbML8U;>cfo{rYwJP4S1g*9+F?dOna2H}=h4opEispXagP_unan zJXKwv{EWXn{iyxs?Qc21o$nUkRkz;!sKzCYKb`_Nc7=Y?igQc*QG){@;05-^BCT@%}J> zsJr)0oXCZ>UnW=FUYF){Ugt@H`EH@J4GfYtVe6Y}*ca4Te)v-*-}dE1$yzl_xzMfc z*3Oo9j!(Wb=UULJrIPM@_iWplWtKTl{x{=BU=>gkQT%&_>EAeu>%UijyXj|dRB!dm zYQrDt2esS}-}RKt=X*IJ)2eV@^yK_Jj(8(?edStjp?ax1YhIlVnX_cqbkWlyO_;}T zhF>y8-W(2Wq5yY?&jqa@;Hu}gTX_Mv#gY>^o^X9n{V$HySKWbg;1l$K+X~}80#__4 ztq;FGF92y{45Kq}5y}MH`+ti~?SP6De&_#svj0Q~)-ra!W2p{MYa6W1LJeAU_WJtc zT|n1{{3~6Gxc1pMPc4%gR@JP=v!iYOwEeh_;|pQbg)O22(jthB+$-)WdDNoL;t`34EO&td@MoZmwr7f zdj4ya{i)2i>34ogm%p6|Jgri{rT#lmJ@9Dm#pnX-*JUhzwWWTF?V8{7i7;EFHmv_s z`_}ozPjc^XbN|b*PY~#g-@t4A^m@_w`*tqsUv=?cz_-osYI?zw84JV@*#ni<1J5JB zh%Rvbx=A3&)pFNr&rjT5cT@fg(+}VrdHwtU3_o_6qw#+|>#|$z{?qx^`JJE4-rrXK z*AUO~pW*#qkPpQ}@_>id^~G-y+qy8SChOWB^@Ts!J@>Z<%|E=4zprd^_$;-%LjF2y z%8p1dg;qUY=8*bZu4~1^m9C#PUa1H9Ms5^!TEURO#~`vLFjftDOCl?8`9Fv6^>1AA9Mv;9Ca$8>I zOkA^ZlT|x#i+17!$2S?wPv-y7n74b`kLllhV*Nd?=$6KsJw763?Od5`be5+f-i-)@((&z0a*G%Wsm~}?aqs?h|toxl4dCl5ip1-;M z+v-QN`$yONN6&Xp|M>d8*w^YprlXu^=eNH(uXO5aaby3Z`bU;ep3ZErn~2n$BK2M^ zCZy&sN$jDx=i20FKtwdtEB+~a|F-(ua29jBD!p@)P`lBKE;0j8RjScXQ-9`xPX`6(th9mxpGDIwe?oO1**uBi~k9pwfH zC;xu^E2%vB_v@e6_DC}QKio~;|NJKZ{UY$d9EJ;@C;xu^^CH-pYk&m?lAJw~E0{+e zhzQ%!5Tr#KyjHeX$#l}gDgQoJ|Goqqkwau7U_#}%Hy3zm@8Qq(j~#l6%mIzFAejW5 z4@!ZpMviA-GyeUuf9(E8c~<$K-~7K{0H>xHGQgSS1vn>Mt4EfSN6N5_MCN6-c5p84 z2ZvEQhDQ*NaF7~xAlA}wGz96F28HgPn4h!R`ucnG`mF8nTEOiKsD_vctRRppfzjG* zw7P}1N=DlQ1X@(rVtS;HOnV~We)IP7Yrs8T$nCV1YfGxVC2LM3a{M?}lp$CK9Q!pU rdkm7o6Gw>YK8*l1j6c=_hhTuiFq4-7cb|7o+y5)J7I>Zx`~RB&kn;59 literal 0 HcmV?d00001 diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 007f2d0fede4..5870070a54c7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -1,4 +1,10 @@ """ +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe. + +But the logic can be extended to support other pipe and lookup buffer. """ from typing import TYPE_CHECKING, List, Optional, Tuple, Union diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 81d9880d2f18..fe8d8d7375f3 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -1,6 +1,5 @@ """ - Implements a distributed key-value (KV) cache transfer mechanism for vLLM - instances with buffer management. + Implements a distributed key-value (KV) cache transfer mechanism. Key Features: - Distributed KV cache transmission using PyNccl pipes. From 8d6916a5573f12804f8647c45aa6f988f8942be5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 22:18:53 +0000 Subject: [PATCH 31/33] Merge disaggregated prefill test into existing 2GPU test. Signed-off-by: KuntaiDu --- .buildkite/test-pipeline.yaml | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 24833f743ebd..fbe6c494e189 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -401,18 +401,6 @@ steps: - pytest -v -s distributed/test_comm_ops.py - pytest -v -s distributed/test_shm_broadcast.py -- label: Disaggregated Prefill Test # 2min - working_dir: "/vllm-workspace/tests" - num_gpus: 2 - source_file_dependencies: - - vllm/distributed/parallel_state.py - - vllm/distributed/kv_transfer - - vllm/worker/worker_base.py - - vllm/worker/worker.py - - vllm/worker/model_runner.py - commands: - - pytest -v -s kv_transfer/disagg_test.py - - label: 2 Node Tests (4 GPUs in total) # 16min working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -442,6 +430,9 @@ steps: - vllm/model_executor/models/ - tests/distributed/ - vllm/compilation + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/model_runner.py commands: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py @@ -455,6 +446,7 @@ steps: - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py - label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" From a05676dc5bb0e02b3b24cf69f3f7b5cb672c0724 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 22:23:21 +0000 Subject: [PATCH 32/33] Rename config to vllm_config for consistency. Signed-off-by: KuntaiDu --- vllm/distributed/parallel_state.py | 8 ++++---- vllm/worker/worker.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 7f15d8119bef..ed6929a94778 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1103,24 +1103,24 @@ def initialize_model_parallel( group_name="pp") -def ensure_kv_transfer_initialized(config: "VllmConfig") -> None: +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: """ Initialize KV cache transfer parallel group. """ global _KV_TRANSFER - if config.kv_transfer_config is None: + if vllm_config.kv_transfer_config is None: return if all([ - config.kv_transfer_config.need_kv_parallel_group, + vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER is None ]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, - config=config) + config=vllm_config) def ensure_model_parallel_initialized( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 04e3d1c14452..719ca7547b20 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -457,13 +457,13 @@ def get_cache_block_size_bytes(self) -> int: def init_worker_distributed_environment( - config: VllmConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" - parallel_config = config.parallel_config + parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, @@ -471,7 +471,7 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - ensure_kv_transfer_initialized(config) + ensure_kv_transfer_initialized(vllm_config) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): From c3aa7bb14a6ac472c0678a21f0b6efa5d378d757 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 1 Dec 2024 22:26:48 +0000 Subject: [PATCH 33/33] fix typo in README Signed-off-by: KuntaiDu --- vllm/distributed/kv_transfer/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index 76911bee4644..dab2d10c4c9d 100644 --- a/vllm/distributed/kv_transfer/README.md +++ b/vllm/distributed/kv_transfer/README.md @@ -18,7 +18,7 @@ NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed communication service already supports key-value-based lookup (like redis or RDMA database). -NOTE: If you want to not only transfer KV caches, but adjust the model exectuion flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. +NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. ## Disaggregated prefilling