Skip to content

Commit

Permalink
fix: memory leak in the broker (#10567)
Browse files Browse the repository at this point in the history
This PR fixes a memory leak in the prover broker by cleaning up jobs
after their result is saved by the orchestrator.

The orchestrator then performs its own cleanup once the epoch has
finished.
  • Loading branch information
alexghr authored Dec 11, 2024
1 parent 09e95a1 commit ecc037f
Show file tree
Hide file tree
Showing 11 changed files with 483 additions and 174 deletions.
10 changes: 8 additions & 2 deletions yarn-project/circuit-types/src/interfaces/prover-broker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ export interface ProvingJobProducer {
enqueueProvingJob(job: ProvingJob): Promise<void>;

/**
* Cancels a proving job and clears all of its
* Cancels a proving job.
* @param id - The ID of the job to cancel
*/
removeAndCancelProvingJob(id: ProvingJobId): Promise<void>;
cancelProvingJob(id: ProvingJobId): Promise<void>;

/**
* Cleans up after a job has completed. Throws if the job is in-progress
* @param id - The ID of the job to clean up
*/
cleanUpProvingJobState(id: ProvingJobId): Promise<void>;

/**
* Returns the current status of the proving job
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class MockProvingJobSource implements ProvingJobSource {
id: 'a-job-id',
type: ProvingRequestType.PRIVATE_BASE_ROLLUP,
inputsUri: 'inputs-uri' as ProofUri,
epochNumber: 1,
});
}
heartbeat(jobId: string): Promise<void> {
Expand Down
2 changes: 1 addition & 1 deletion yarn-project/circuit-types/src/interfaces/proving-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ export type ProvingJobId = z.infer<typeof ProvingJobId>;
export const ProvingJob = z.object({
id: ProvingJobId,
type: z.nativeEnum(ProvingRequestType),
blockNumber: z.number().optional(),
epochNumber: z.number(),
inputsUri: ProofUri,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ export class MemoryProvingQueue implements ServerCircuitProver, ProvingJobSource
id: job.id,
type: job.type,
inputsUri: job.inputsUri,
epochNumber: job.epochNumber,
};
} catch (err) {
if (err instanceof TimeoutError) {
Expand Down Expand Up @@ -244,7 +245,7 @@ export class MemoryProvingQueue implements ServerCircuitProver, ProvingJobSource
reject,
attempts: 1,
heartbeat: 0,
epochNumber,
epochNumber: epochNumber ?? 0,
};

if (signal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ describe('CachingBrokerFacade', () => {
broker = mock<ProvingJobProducer>({
enqueueProvingJob: jest.fn<any>(),
getProvingJobStatus: jest.fn<any>(),
removeAndCancelProvingJob: jest.fn<any>(),
cancelProvingJob: jest.fn<any>(),
cleanUpProvingJobState: jest.fn<any>(),
waitForJobToSettle: jest.fn<any>(),
});
cache = new InMemoryProverCache();
Expand Down Expand Up @@ -101,4 +102,55 @@ describe('CachingBrokerFacade', () => {
await expect(facade.getBaseParityProof(inputs)).resolves.toEqual(result);
expect(broker.enqueueProvingJob).toHaveBeenCalledTimes(1); // job was only ever enqueued once
});

// Checks that, once a job settles with a 'fulfilled' status, the facade calls
// cleanUpProvingJobState on the broker.
it('clears broker state after a job resolves', async () => {
// Deferred promise: the test decides when waitForJobToSettle yields a result.
const { promise, resolve } = promiseWithResolvers<any>();
broker.enqueueProvingJob.mockResolvedValue(Promise.resolve());
broker.waitForJobToSettle.mockResolvedValue(promise);

const inputs = makeBaseParityInputs();
// Intentionally not awaited; the job cannot settle until `resolve` is called below.
void facade.getBaseParityProof(inputs);
await jest.advanceTimersToNextTimerAsync();

// The job object the facade passed to the broker; its id is reused to store the output.
const job = broker.enqueueProvingJob.mock.calls[0][0];
const result = makePublicInputsAndRecursiveProof(
makeParityPublicInputs(),
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyData.makeFakeHonk(),
);
// Save the output first so the settled status can point at its URI.
const outputUri = await proofStore.saveProofOutput(job.id, ProvingRequestType.BASE_PARITY, result);
resolve({
status: 'fulfilled',
value: outputUri,
});

// Advance timers again so the facade can process the settled status before asserting.
await jest.advanceTimersToNextTimerAsync();
expect(broker.cleanUpProvingJobState).toHaveBeenCalled();
});

// Checks the abort path: aborting the signal makes the facade call cancelProvingJob,
// and once the job settles as 'rejected' the facade calls cleanUpProvingJobState
// and the caller's promise rejects.
it('clears broker state after a job is canceled', async () => {
// Deferred promise: the test decides when waitForJobToSettle yields a result.
const { promise, resolve } = promiseWithResolvers<any>();
// Records the rejection passed to .catch() so it can be asserted on at the end.
const catchSpy = jest.fn();
broker.enqueueProvingJob.mockResolvedValue(Promise.resolve());
broker.waitForJobToSettle.mockResolvedValue(promise);

const inputs = makeBaseParityInputs();
const controller = new AbortController();
void facade.getBaseParityProof(inputs, controller.signal).catch(catchSpy);
await jest.advanceTimersToNextTimerAsync();

// No cancellation may reach the broker before the signal is aborted.
expect(broker.cancelProvingJob).not.toHaveBeenCalled();
controller.abort();
await jest.advanceTimersToNextTimerAsync();
expect(broker.cancelProvingJob).toHaveBeenCalled();

// Simulate the broker reporting the job as settled with a rejection.
resolve({
status: 'rejected',
reason: 'Aborted',
});

await jest.advanceTimersToNextTimerAsync();
expect(broker.cleanUpProvingJobState).toHaveBeenCalled();
expect(catchSpy).toHaveBeenCalledWith(new Error('Aborted'));
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
id: ProvingJobId,
type: T,
inputs: ProvingJobInputsMap[T],
epochNumber = 0,
signal?: AbortSignal,
): Promise<ProvingJobResultsMap[T]> {
// first try the cache
Expand Down Expand Up @@ -95,6 +96,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
id,
type,
inputsUri,
epochNumber,
});
await this.cache.setProvingJobStatus(id, { status: 'in-queue' });
} catch (err) {
Expand All @@ -107,7 +109,7 @@ export class CachingBrokerFacade implements ServerCircuitProver {
// notify broker of cancelled job
const abortFn = async () => {
signal?.removeEventListener('abort', abortFn);
await this.broker.removeAndCancelProvingJob(id);
await this.broker.cancelProvingJob(id);
};

signal?.addEventListener('abort', abortFn);
Expand Down Expand Up @@ -147,160 +149,174 @@ export class CachingBrokerFacade implements ServerCircuitProver {
}
} finally {
signal?.removeEventListener('abort', abortFn);
// we've saved the result in our cache. We can tell the broker to clear its state
await this.broker.cleanUpProvingJobState(id);
}
}

getAvmProof(
inputs: AvmCircuitInputs,
signal?: AbortSignal,
_blockNumber?: number,
epochNumber?: number,
): Promise<ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PUBLIC_VM, inputs),
ProvingRequestType.PUBLIC_VM,
inputs,
epochNumber,
signal,
);
}

getBaseParityProof(
inputs: BaseParityInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BASE_PARITY, inputs),
ProvingRequestType.BASE_PARITY,
inputs,
epochNumber,
signal,
);
}

getBlockMergeRollupProof(
input: BlockMergeRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BLOCK_MERGE_ROLLUP, input),
ProvingRequestType.BLOCK_MERGE_ROLLUP,
input,
epochNumber,
signal,
);
}

getBlockRootRollupProof(
input: BlockRootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.BLOCK_ROOT_ROLLUP, input),
ProvingRequestType.BLOCK_ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getEmptyBlockRootRollupProof(
input: EmptyBlockRootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP, input),
ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getEmptyPrivateKernelProof(
inputs: PrivateKernelEmptyInputData,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<KernelCircuitPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PRIVATE_KERNEL_EMPTY, inputs),
ProvingRequestType.PRIVATE_KERNEL_EMPTY,
inputs,
epochNumber,
signal,
);
}

getMergeRollupProof(
input: MergeRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.MERGE_ROLLUP, input),
ProvingRequestType.MERGE_ROLLUP,
input,
epochNumber,
signal,
);
}
getPrivateBaseRollupProof(
baseRollupInput: PrivateBaseRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PRIVATE_BASE_ROLLUP, baseRollupInput),
ProvingRequestType.PRIVATE_BASE_ROLLUP,
baseRollupInput,
epochNumber,
signal,
);
}

getPublicBaseRollupProof(
inputs: PublicBaseRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.PUBLIC_BASE_ROLLUP, inputs),
ProvingRequestType.PUBLIC_BASE_ROLLUP,
inputs,
epochNumber,
signal,
);
}

getRootParityProof(
inputs: RootParityInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof NESTED_RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.ROOT_PARITY, inputs),
ProvingRequestType.ROOT_PARITY,
inputs,
epochNumber,
signal,
);
}

getRootRollupProof(
input: RootRollupInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<PublicInputsAndRecursiveProof<RootRollupPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.ROOT_ROLLUP, input),
ProvingRequestType.ROOT_ROLLUP,
input,
epochNumber,
signal,
);
}

getTubeProof(
tubeInput: TubeInputs,
signal?: AbortSignal,
_epochNumber?: number,
epochNumber?: number,
): Promise<ProofAndVerificationKey<typeof TUBE_PROOF_LENGTH>> {
return this.enqueueAndWaitForJob(
this.generateId(ProvingRequestType.TUBE_PROOF, tubeInput),
ProvingRequestType.TUBE_PROOF,
tubeInput,
epochNumber,
signal,
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ describe('ProvingAgent', () => {
const inputs: ProvingJobInputs = { type: ProvingRequestType.BASE_PARITY, inputs: makeBaseParityInputs() };
const job: ProvingJob = {
id: randomBytes(8).toString('hex') as ProvingJobId,
blockNumber: 1,
epochNumber: 1,
type: ProvingRequestType.BASE_PARITY,
inputsUri: randomBytes(8).toString('hex') as ProofUri,
};
Expand Down
Loading

0 comments on commit ecc037f

Please sign in to comment.