Automated Test: span-flusher-multiprocess #333

Closed

11 changes: 11 additions & 0 deletions CLAUDE.md
@@ -445,6 +445,17 @@ for org in organizations:

# RIGHT: Use prefetch_related
organizations.prefetch_related('projects')

# WRONG: Use hasattr() for unions
x: str | None = "hello"
if hasattr(x, "replace"):
x = x.replace("e", "a")

# RIGHT: Use isinstance()
x: str | None = "hello"
if isinstance(x, str):
x = x.replace("e", "a")

```

### Frontend
10 changes: 9 additions & 1 deletion src/sentry/consumers/__init__.py
@@ -427,7 +427,15 @@ def ingest_transactions_options() -> list[click.Option]:
"topic": Topic.INGEST_SPANS,
"dlq_topic": Topic.INGEST_SPANS_DLQ,
"strategy_factory": "sentry.spans.consumers.process.factory.ProcessSpansStrategyFactory",
"click_options": multiprocessing_options(default_max_batch_size=100),
"click_options": [
*multiprocessing_options(default_max_batch_size=100),
click.Option(
["--flusher-processes", "flusher_processes"],
default=1,
type=int,
help="Maximum number of processes for the span flusher. Defaults to 1.",
),
],
Comment on lines +430 to +438

⚠️ Potential issue | 🟡 Minor

No validation that --flusher-processes is ≥ 1.

A user could pass --flusher-processes 0 or a negative value, which would cause unexpected behavior downstream (e.g., min(0, num_shards) → 0 processes, leading to empty dicts and no flushing). Consider adding a callback or clamping in the factory.

🛡️ Proposed fix: add a callback validator
+def _validate_positive_int(ctx, param, value):
+    if value < 1:
+        raise click.BadParameter("must be >= 1")
+    return value
+
 "process-spans": {
     "topic": Topic.INGEST_SPANS,
     "dlq_topic": Topic.INGEST_SPANS_DLQ,
     "strategy_factory": "sentry.spans.consumers.process.factory.ProcessSpansStrategyFactory",
     "click_options": [
         *multiprocessing_options(default_max_batch_size=100),
         click.Option(
             ["--flusher-processes", "flusher_processes"],
             default=1,
             type=int,
+            callback=_validate_positive_int,
             help="Maximum number of processes for the span flusher. Defaults to 1.",
         ),
     ],
 },
🤖 Prompt for AI Agents
In `@src/sentry/consumers/__init__.py` around lines 430-438, the click option
"--flusher-processes" (click.Option with dest "flusher_processes") lacks
validation, allowing 0 or negative values. Update the option to enforce a
minimum of 1 by adding a callback validator on that click.Option (or clamp the
value in the span flusher factory where flusher_processes is consumed) so any
input < 1 is rejected or coerced to 1, and ensure the error text clearly states
it must be >= 1.
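
As an alternative to a custom callback, here is a minimal sketch using click's built-in IntRange type (standard click API; the option names mirror the PR):

```python
import click

# click.IntRange(min=1) rejects 0 and negatives with a clear usage error,
# making a separate callback validator unnecessary.
flusher_processes_option = click.Option(
    ["--flusher-processes", "flusher_processes"],
    default=1,
    type=click.IntRange(min=1),
    help="Maximum number of processes for the span flusher. Defaults to 1.",
)
```

Either approach works; IntRange keeps the constraint visible in the option declaration and in the generated --help output.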

},
"process-segments": {
"topic": Topic.BUFFERED_SEGMENTS,
3 changes: 3 additions & 0 deletions src/sentry/spans/consumers/process/factory.py
@@ -38,6 +38,7 @@ def __init__(
num_processes: int,
input_block_size: int | None,
output_block_size: int | None,
flusher_processes: int | None = None,
produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
):
super().__init__()
@@ -48,6 +49,7 @@ def __init__(
self.input_block_size = input_block_size
self.output_block_size = output_block_size
self.num_processes = num_processes
self.flusher_processes = flusher_processes
self.produce_to_pipe = produce_to_pipe

if self.num_processes != 1:
@@ -69,6 +71,7 @@ def create_with_partitions(
flusher = self._flusher = SpanFlusher(
buffer,
next_step=committer,
max_processes=self.flusher_processes,
produce_to_pipe=self.produce_to_pipe,
)

174 changes: 127 additions & 47 deletions src/sentry/spans/consumers/process/flusher.py
@@ -15,6 +15,7 @@

from sentry import options
from sentry.conf.types.kafka_definition import Topic
from sentry.processing.backpressure.memory import ServiceMemory
from sentry.spans.buffer import SpansBuffer
from sentry.utils import metrics
from sentry.utils.arroyo import run_with_initialized_sentry
@@ -27,7 +28,8 @@

class SpanFlusher(ProcessingStrategy[FilteredPayload | int]):
"""
A background thread that polls Redis for new segments to flush and to produce to Kafka.
A background multiprocessing manager that polls Redis for new segments to flush and produces them to Kafka.
Creates up to max_processes flusher processes, assigning shards to them round-robin (one process per shard by default).
This is a processing step to be embedded into the consumer that writes to
Redis. It takes and forwards integer messages that represent recently
@@ -42,27 +44,53 @@ def __init__(
self,
buffer: SpansBuffer,
next_step: ProcessingStrategy[FilteredPayload | int],
max_processes: int | None = None,
produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
):
self.buffer = buffer
self.next_step = next_step
self.max_processes = max_processes or len(buffer.assigned_shards)

self.mp_context = mp_context = multiprocessing.get_context("spawn")
self.stopped = mp_context.Value("i", 0)
self.redis_was_full = False
self.current_drift = mp_context.Value("i", 0)
self.backpressure_since = mp_context.Value("i", 0)
self.healthy_since = mp_context.Value("i", 0)
self.process_restarts = 0
self.produce_to_pipe = produce_to_pipe

self._create_process()

def _create_process(self):
# Determine which shards get their own processes vs shared processes
self.num_processes = min(self.max_processes, len(buffer.assigned_shards))
self.process_to_shards_map: dict[int, list[int]] = {
i: [] for i in range(self.num_processes)
}
for i, shard in enumerate(buffer.assigned_shards):
process_index = i % self.num_processes
self.process_to_shards_map[process_index].append(shard)

self.processes: dict[int, multiprocessing.context.SpawnProcess | threading.Thread] = {}
self.process_healthy_since = {
process_index: mp_context.Value("i", int(time.time()))
for process_index in range(self.num_processes)
}
self.process_backpressure_since = {
process_index: mp_context.Value("i", 0) for process_index in range(self.num_processes)
}
self.process_restarts = {process_index: 0 for process_index in range(self.num_processes)}
self.buffers: dict[int, SpansBuffer] = {}

self._create_processes()

def _create_processes(self):
# Create processes based on shard mapping
for process_index, shards in self.process_to_shards_map.items():
self._create_process_for_shards(process_index, shards)

def _create_process_for_shards(self, process_index: int, shards: list[int]):
# Optimistically reset healthy_since to avoid a race between the
# starting process and the next flush cycle. Keep back pressure across
# the restart, however.
self.healthy_since.value = int(time.time())
self.process_healthy_since[process_index].value = int(time.time())

# Create a buffer for these specific shards
shard_buffer = SpansBuffer(shards)

make_process: Callable[..., multiprocessing.context.SpawnProcess | threading.Thread]
if self.produce_to_pipe is None:
@@ -72,37 +100,50 @@ def _create_process(self):
# pickled separately. at the same time, pickling
# synchronization primitives like multiprocessing.Value can
# only be done by the Process
self.buffer,
shard_buffer,
)
make_process = self.mp_context.Process
else:
target = partial(SpanFlusher.main, self.buffer)
target = partial(SpanFlusher.main, shard_buffer)
make_process = threading.Thread

self.process = make_process(
process = make_process(
target=target,
args=(
shards,
self.stopped,
self.current_drift,
self.backpressure_since,
self.healthy_since,
self.process_backpressure_since[process_index],
self.process_healthy_since[process_index],
self.produce_to_pipe,
),
daemon=True,
)

self.process.start()
process.start()
self.processes[process_index] = process
self.buffers[process_index] = shard_buffer

def _create_process_for_shard(self, shard: int):
# Find which process this shard belongs to and restart that process
for process_index, shards in self.process_to_shards_map.items():
if shard in shards:
self._create_process_for_shards(process_index, shards)
break

@staticmethod
def main(
buffer: SpansBuffer,
shards: list[int],
stopped,
current_drift,
backpressure_since,
healthy_since,
produce_to_pipe: Callable[[KafkaPayload], None] | None,
) -> None:
shard_tag = ",".join(map(str, shards))
sentry_sdk.set_tag("sentry_spans_buffer_component", "flusher")
sentry_sdk.set_tag("sentry_spans_buffer_shards", shard_tag)

try:
producer_futures = []
@@ -134,23 +175,28 @@ def produce(payload: KafkaPayload) -> None:
else:
backpressure_since.value = 0

# Update healthy_since for all shards handled by this process
healthy_since.value = system_now

if not flushed_segments:
time.sleep(1)
continue

with metrics.timer("spans.buffer.flusher.produce"):
for _, flushed_segment in flushed_segments.items():
with metrics.timer("spans.buffer.flusher.produce", tags={"shard": shard_tag}):
for flushed_segment in flushed_segments.values():
if not flushed_segment.spans:
continue

spans = [span.payload for span in flushed_segment.spans]
kafka_payload = KafkaPayload(None, orjson.dumps({"spans": spans}), [])
metrics.timing("spans.buffer.segment_size_bytes", len(kafka_payload.value))
metrics.timing(
"spans.buffer.segment_size_bytes",
len(kafka_payload.value),
tags={"shard": shard_tag},
)
produce(kafka_payload)

with metrics.timer("spans.buffer.flusher.wait_produce"):
with metrics.timer("spans.buffer.flusher.wait_produce", tags={"shards": shard_tag}):
Comment on lines +185 to +199

⚠️ Potential issue | 🟡 Minor

Inconsistent metric tag key: "shard" vs "shards".

Lines 185 and 195 use tags={"shard": shard_tag} (singular), while line 199 uses tags={"shards": shard_tag} (plural). This will create split metrics in your dashboards.

🔧 Proposed fix
-                with metrics.timer("spans.buffer.flusher.wait_produce", tags={"shards": shard_tag}):
+                with metrics.timer("spans.buffer.flusher.wait_produce", tags={"shard": shard_tag}):
🤖 Prompt for AI Agents
In `@src/sentry/spans/consumers/process/flusher.py` around lines 185-199, the
metric tag key is inconsistent: change the tags on the metrics.timer call that
currently uses tags={"shards": shard_tag} to the same key as the others
(tags={"shard": shard_tag}) so
metrics.timer("spans.buffer.flusher.wait_produce", ...) aligns with the earlier
metrics.timer and metrics.timing calls that use shard_tag. Update the call near
the produce loop that references wait_produce to use "shard" (symbols to locate:
metrics.timer, metrics.timing, produce, flushed_segments, shard_tag,
KafkaPayload).
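
One defensive pattern (an assumption of mine, not something in this PR) is to centralize tag construction so the key cannot drift between call sites:

```python
# Hypothetical helper: a single source of truth for the flusher's metric tags.
FLUSHER_TAG_KEY = "shard"

def flusher_tags(shard_tag: str) -> dict[str, str]:
    """Build the tag dict that every flusher metric shares."""
    return {FLUSHER_TAG_KEY: shard_tag}

# e.g. metrics.timer("spans.buffer.flusher.produce", tags=flusher_tags(shard_tag))
print(flusher_tags("0,2,4"))  # {'shard': '0,2,4'}
```

With one helper, a mismatched key like "shards" cannot be reintroduced by hand at a single call site.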

for future in producer_futures:
future.result()

@@ -169,46 +215,71 @@ def produce(payload: KafkaPayload) -> None:
def poll(self) -> None:
self.next_step.poll()

def _ensure_process_alive(self) -> None:
def _ensure_processes_alive(self) -> None:
max_unhealthy_seconds = options.get("spans.buffer.flusher.max-unhealthy-seconds")
if not self.process.is_alive():
exitcode = getattr(self.process, "exitcode", "unknown")
cause = f"no_process_{exitcode}"
elif int(time.time()) - self.healthy_since.value > max_unhealthy_seconds:
cause = "hang"
else:
return # healthy

metrics.incr("spans.buffer.flusher_unhealthy", tags={"cause": cause})
if self.process_restarts > MAX_PROCESS_RESTARTS:
raise RuntimeError(f"flusher process crashed repeatedly ({cause}), restarting consumer")
for process_index, process in self.processes.items():
if not process:
continue

shards = self.process_to_shards_map[process_index]

cause = None
if not process.is_alive():
exitcode = getattr(process, "exitcode", "unknown")
cause = f"no_process_{exitcode}"
elif (
int(time.time()) - self.process_healthy_since[process_index].value
> max_unhealthy_seconds
):
# Check if any shard handled by this process is unhealthy
cause = "hang"

if cause is None:
continue # healthy

# Report unhealthy for all shards handled by this process
for shard in shards:
metrics.incr(
"spans.buffer.flusher_unhealthy", tags={"cause": cause, "shard": shard}
)

try:
self.process.kill()
except ValueError:
pass # Process already closed, ignore
if self.process_restarts[process_index] > MAX_PROCESS_RESTARTS:
raise RuntimeError(
f"flusher process for shards {shards} crashed repeatedly ({cause}), restarting consumer"
)
self.process_restarts[process_index] += 1

self.process_restarts += 1
self._create_process()
try:
if isinstance(process, multiprocessing.Process):
process.kill()
except (ValueError, AttributeError):
pass # Process already closed, ignore

self._create_process_for_shards(process_index, shards)

def submit(self, message: Message[FilteredPayload | int]) -> None:
# Note that submit is not actually a hot path. The message payloads
# are mapped from *batches* of spans, and there are a handful of spans
# per second at most. If anything, self.poll() might even be called
# more often than submit()

self._ensure_process_alive()
self._ensure_processes_alive()

self.buffer.record_stored_segments()
for buffer in self.buffers.values():
buffer.record_stored_segments()

# We pause insertion into Redis if the flusher is not making progress
# fast enough. We could backlog into Redis, but we assume, despite best
# efforts, it is still always going to be less durable than Kafka.
# Minimizing our Redis memory usage also makes COGS easier to reason
# about.
if self.backpressure_since.value > 0:
backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
if int(time.time()) - self.backpressure_since.value > backpressure_secs:
backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
for backpressure_since in self.process_backpressure_since.values():
if (
backpressure_since.value > 0
and int(time.time()) - backpressure_since.value > backpressure_secs
):
metrics.incr("spans.buffer.flusher.backpressure")
raise MessageRejected()

@@ -225,7 +296,9 @@ def submit(self, message: Message[FilteredPayload | int]) -> None:
# wait until the situation is improved manually.
max_memory_percentage = options.get("spans.buffer.max-memory-percentage")
if max_memory_percentage < 1.0:
memory_infos = list(self.buffer.get_memory_info())
memory_infos: list[ServiceMemory] = []
for buffer in self.buffers.values():
memory_infos.extend(buffer.get_memory_info())
used = sum(x.used for x in memory_infos)
available = sum(x.available for x in memory_infos)
if available > 0 and used / available > max_memory_percentage:
@@ -253,15 +326,22 @@ def close(self) -> None:
self.next_step.close()

def join(self, timeout: float | None = None):
# set stopped flag first so we can "flush" the background thread while
# set stopped flag first so we can "flush" the background threads while
# next_step is also shutting down. we can do two things at once!
self.stopped.value = True
deadline = time.time() + timeout if timeout else None

self.next_step.join(timeout)

while self.process.is_alive() and (deadline is None or deadline > time.time()):
time.sleep(0.1)
# Wait for all processes to finish
for process_index, process in self.processes.items():
if deadline is not None:
remaining_time = deadline - time.time()
if remaining_time <= 0:
break

while process.is_alive() and (deadline is None or deadline > time.time()):
time.sleep(0.1)

if isinstance(self.process, multiprocessing.Process):
self.process.terminate()
if isinstance(process, multiprocessing.Process):
process.terminate()
Comment on lines +336 to +347

⚠️ Potential issue | 🟠 Major

break on deadline expiry skips terminate() for remaining processes, risking zombie processes.

If the deadline is exceeded while waiting for the first process, the break on line 341 exits the loop entirely, leaving the remaining processes unterminated. Since stopped.value = True is only a cooperative signal, a stuck process could linger.

🐛 Proposed fix: terminate all processes regardless of deadline
     def join(self, timeout: float | None = None):
         # set stopped flag first so we can "flush" the background threads while
         # next_step is also shutting down. we can do two things at once!
         self.stopped.value = True
         deadline = time.time() + timeout if timeout else None

         self.next_step.join(timeout)

         # Wait for all processes to finish
-        for process_index, process in self.processes.items():
+        for process in self.processes.values():
             if deadline is not None:
                 remaining_time = deadline - time.time()
                 if remaining_time <= 0:
                     break

             while process.is_alive() and (deadline is None or deadline > time.time()):
                 time.sleep(0.1)

+        # Terminate all remaining processes regardless of deadline
+        for process in self.processes.values():
             if isinstance(process, multiprocessing.Process):
                 process.terminate()
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 337-337: Loop control variable process_index not used within loop body

Rename unused process_index to _process_index

(B007)

🤖 Prompt for AI Agents
In `@src/sentry/spans/consumers/process/flusher.py` around lines 336-347, the
loop over self.processes currently breaks out when the deadline expires, which
skips calling process.terminate() for the remaining entries. Change the logic so
that when remaining_time <= 0 you do not break the for-loop but instead skip the
waiting while-loop and continue to the termination step for each remaining
process. Concretely, in the for process_index, process in self.processes.items()
loop, remove the break and add a branch (e.g., if remaining_time <= 0: skip
waiting) so the subsequent isinstance(process, multiprocessing.Process):
process.terminate() still runs for every process even after the deadline has
passed.
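
For reference, a minimal standalone sketch of the shard-to-process distribution this PR introduces (the names are illustrative; the real logic lives in SpanFlusher.__init__):

```python
def map_shards_to_processes(
    assigned_shards: list[int], max_processes: int
) -> dict[int, list[int]]:
    """Distribute shards round-robin across at most max_processes processes."""
    num_processes = min(max_processes, len(assigned_shards))
    mapping: dict[int, list[int]] = {i: [] for i in range(num_processes)}
    for i, shard in enumerate(assigned_shards):
        mapping[i % num_processes].append(shard)
    return mapping

# Five shards capped at two flusher processes:
print(map_shards_to_processes([0, 1, 2, 3, 4], 2))
# {0: [0, 2, 4], 1: [1, 3]}
```

With the consumer's default of --flusher-processes 1, all shards share a single flusher process; larger values fan shards out, up to one process per shard.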
