fix Can't pickle local object 'check_correctness.<locals>.unsafe_execute' #30

Open · wants to merge 1 commit into master
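For context (not part of the patch): under multiprocessing's "spawn" start method, which is the default on Windows and, since Python 3.8, on macOS, the Process target has to be pickled, and a function defined inside check_correctness cannot be. A minimal sketch of the failure, with a hypothetical function name:

    import multiprocessing

    def check_correctness_broken():
        def unsafe_execute():  # nested: defined inside another function
            pass

        # Under "spawn" the child must unpickle the target, and pickle cannot
        # serialize a local (nested) function, so p.start() raises:
        #   AttributeError: Can't pickle local object
        #   'check_correctness_broken.<locals>.unsafe_execute'
        p = multiprocessing.Process(target=unsafe_execute)
        p.start()
        p.join()

    if __name__ == "__main__":
        # Forcing "spawn" reproduces the failure on Linux as well, where the
        # default "fork" start method would otherwise mask it.
        multiprocessing.set_start_method("spawn", force=True)
        check_correctness_broken()

Moving unsafe_execute to module scope, as the diff below does, lets the child process import the target by name instead of pickling a closure.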
human_eval/execution.py (105 changes: 57 additions & 48 deletions)
@@ -9,42 +9,31 @@
 import signal
 import tempfile
 
 
-def check_correctness(problem: Dict, completion: str, timeout: float,
-                      completion_id: Optional[int] = None) -> Dict:
-    """
-    Evaluates the functional correctness of a completion by running the test
-    suite provided in the problem.
-
-    :param completion_id: an optional completion ID so we can match
-        the results later even if execution finishes asynchronously.
-    """
-
-    def unsafe_execute():
-
-        with create_tempdir():
-
-            # These system calls are needed when cleaning up tempdir.
-            import os
-            import shutil
-            rmtree = shutil.rmtree
-            rmdir = os.rmdir
-            chdir = os.chdir
-
-            # Disable functionalities that can make destructive changes to the test.
-            reliability_guard()
-
-            # Construct the check program and run it.
-            check_program = (
-                problem["prompt"] + completion + "\n" +
-                problem["test"] + "\n" +
-                f"check({problem['entry_point']})"
-            )
-
-            try:
-                exec_globals = {}
-                with swallow_io():
-                    with time_limit(timeout):
+def unsafe_execute(problem, completion, result, timeout):
+
+    with create_tempdir():
+
+        # These system calls are needed when cleaning up tempdir.
+        import os
+        import shutil
+        rmtree = shutil.rmtree
+        rmdir = os.rmdir
+        chdir = os.chdir
+
+        # Disable functionalities that can make destructive changes to the test.
+        reliability_guard()
+
+        # Construct the check program and run it.
+        check_program = (
+            problem["prompt"] + completion + "\n" +
+            problem["test"] + "\n" +
+            f"check({problem['entry_point']})"
+        )
+
+        try:
+            exec_globals = {}
+            with swallow_io():
+                with time_limit(timeout):
 # WARNING
 # This program exists to execute untrusted model-generated code. Although
 # it is highly unlikely that model-generated code will do something overtly
@@ -55,22 +44,42 @@ def unsafe_execute():
 # information on how OpenAI sandboxes its code, see the accompanying paper.
 # Once you have read this disclaimer and taken appropriate precautions,
 # uncomment the following line and proceed at your own risk:
-#                         exec(check_program, exec_globals)
-                        result.append("passed")
-            except TimeoutException:
-                result.append("timed out")
-            except BaseException as e:
-                result.append(f"failed: {e}")
-
-            # Needed for cleaning up.
-            shutil.rmtree = rmtree
-            os.rmdir = rmdir
-            os.chdir = chdir
+                    exec(check_program, exec_globals)
+                    result.append("passed")
+        except TimeoutException:
+            result.append("timed out")
+        except BaseException as e:
+            result.append(f"failed: {e}")
+
+        # Needed for cleaning up.
+        shutil.rmtree = rmtree
+        os.rmdir = rmdir
+        os.chdir = chdir
+
+
+def check_correctness(problem: Dict, completion: str, timeout: float,
+                      completion_id: Optional[int] = None) -> Dict:
+    """
+    Evaluates the functional correctness of a completion by running the test
+    suite provided in the problem.
+
+    :param completion_id: an optional completion ID so we can match
+        the results later even if execution finishes asynchronously.
+    """
 
     manager = multiprocessing.Manager()
     result = manager.list()
 
-    p = multiprocessing.Process(target=unsafe_execute)
+    # p = multiprocessing.Process(target=unsafe_execute)
+    p = multiprocessing.Process(
+        target=unsafe_execute,
+        args=(
+            problem,
+            completion,
+            result,
+            timeout
+        ),
+    )
     p.start()
     p.join(timeout=timeout + 1)
     if p.is_alive():
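Overall, the patch hoists unsafe_execute to module level and passes problem, completion, result, and timeout explicitly through Process(..., args=...). A rough, self-contained sketch of that pattern (not the real harness: the worker body is a stub and the problem data is dummy, with "has_close_elements" used only as an example entry point):

    import multiprocessing

    def unsafe_execute(problem, completion, result, timeout):
        # Module-level, so the "spawn" start method can pickle it by name.
        # Here it only records what it was asked to do; the real harness
        # builds and (optionally) executes the check program instead.
        result.append(f"would check {problem['entry_point']} with timeout={timeout}")

    if __name__ == "__main__":
        manager = multiprocessing.Manager()
        result = manager.list()  # proxy list shared between parent and child

        p = multiprocessing.Process(
            target=unsafe_execute,
            args=({"entry_point": "has_close_elements"}, "    return True\n", result, 3.0),
        )
        p.start()
        p.join(timeout=3.0 + 1)
        if p.is_alive():
            p.kill()
        print(list(result))  # e.g. ['would check has_close_elements with timeout=3.0']

The manager.list() proxy is what carries "passed" / "timed out" / "failed: ..." back from the child; a plain Python list would not be visible to the parent after the child appends to it.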