Skip to content

Commit

Permalink
Refactor implementation to unify tools usage for REST & realtime
Browse files Browse the repository at this point in the history
  • Loading branch information
thekid committed Nov 1, 2024
1 parent dabe869 commit dc08c4c
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 70 deletions.
4 changes: 2 additions & 2 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
"xp-framework/core": "^12.0 | ^11.0 | ^10.0",
"xp-framework/logging": "^11.2",
"xp-framework/reflection": "^3.0 | ^2.0",
"xp-forge/marshalling": "^2.0 | ^1.0",
"xp-forge/rest-client": "^5.6",
"xp-forge/marshalling": "^2.3",
"xp-forge/rest-client": "^5.7",
"xp-forge/websockets": "^4.0",
"php" : ">=7.4.0"
},
Expand Down
15 changes: 1 addition & 14 deletions src/main/php/com/openai/Tools.class.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,7 @@ class Tools {
*/
public function __construct(...$selected) {
foreach ($selected as $select) {
if ($select instanceof Functions) {
foreach ($select->schema() as $name => $function) {
$this->selection[]= ['type' => 'function', 'function' => [
'name' => $name,
'description' => $function['description'],
'parameters' => $function['input'],
]];
}
} else {
$this->selection[]= is_string($select) ? ['type' => $select] : $select;
}
$this->selection[]= is_string($select) ? ['type' => $select] : $select;
}
}

/** @return var */
public function __serialize() { return $this->selection; }
}
21 changes: 19 additions & 2 deletions src/main/php/com/openai/realtime/RealtimeApi.class.php
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
<?php namespace com\openai\realtime;

use com\openai\Tools;
use com\openai\tools\Functions;
use lang\IllegalStateException;
use text\json\Json;
use util\URI;
use util\data\Marshalling;
use util\log\Traceable;
use util\URI;
use websocket\WebSocket;

/**
Expand All @@ -22,7 +24,22 @@ class RealtimeApi implements Traceable {
/** @param string|util.URI|websocket.WebSocket $endpoint */
public function __construct($endpoint) {
$this->ws= $endpoint instanceof WebSocket ? $endpoint : new WebSocket((string)$endpoint);
$this->marshalling= new Marshalling();
$this->marshalling= (new Marshalling())->mapping(Tools::class, function($tools) {
foreach ($tools->selection as $select) {
if ($select instanceof Functions) {
foreach ($select->schema() as $name => $function) {
yield [
'type' => 'function',
'name' => $name,
'description' => $function['description'],
'parameters' => $function['input'],
];
}
} else {
yield $select;
}
}
});
}

/** @param ?util.log.LogCategory $cat */
Expand Down
23 changes: 3 additions & 20 deletions src/main/php/com/openai/rest/AzureAIEndpoint.class.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
*
* @test com.openai.unittest.AzureAIEndpointTest
*/
class AzureAIEndpoint extends ApiEndpoint {
private $endpoint, $rateLimit;
class AzureAIEndpoint extends RestEndpoint {
public $version;

/**
Expand All @@ -22,29 +21,13 @@ class AzureAIEndpoint extends ApiEndpoint {
*/
public function __construct($arg, $version= null) {
if ($arg instanceof Endpoint) {
$this->endpoint= $arg;
$this->version= $version;
parent::__construct($arg);
} else {
$uri= $arg instanceof URI ? $arg : new URI($arg);
$this->version= $version ?? $uri->param('api-version');
$this->endpoint= (new Endpoint($uri))->with(['Authorization' => null, 'API-Key' => $uri->user()]);
parent::__construct((new Endpoint($uri))->with(['Authorization' => null, 'API-Key' => $uri->user()]));
}
$this->rateLimit= new RateLimit();
}

/** Returns rate limit */
public function rateLimit(): RateLimit { return $this->rateLimit; }

/** @return [:var] */
public function headers() { return $this->endpoint->headers(); }

/**
* Provides a log category for tracing requests
*
* @param ?util.log.LogCategory $cat
*/
public function setTrace($cat) {
$this->endpoint->setTrace($cat);
}

/** Returns an API */
Expand Down
21 changes: 2 additions & 19 deletions src/main/php/com/openai/rest/OpenAIEndpoint.class.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
* @see https://platform.openai.com/docs/api-reference/authentication
* @test com.openai.unittest.OpenAIEndpointTest
*/
class OpenAIEndpoint extends ApiEndpoint {
private $endpoint, $rateLimit;
class OpenAIEndpoint extends RestEndpoint {

/**
* Creates a new OpenAI endpoint
Expand All @@ -19,8 +18,7 @@ class OpenAIEndpoint extends ApiEndpoint {
* @param ?string $project
*/
public function __construct($arg, $organization= null, $project= null) {
$this->endpoint= $arg instanceof Endpoint ? $arg : new Endpoint($arg);
$this->rateLimit= new RateLimit();
parent::__construct($arg instanceof Endpoint ? $arg : new Endpoint($arg));

// Pass optional organization and project IDs
$headers= [];
Expand All @@ -29,21 +27,6 @@ public function __construct($arg, $organization= null, $project= null) {
$headers && $this->endpoint->with($headers);
}

/** Returns rate limit */
public function rateLimit(): RateLimit { return $this->rateLimit; }

/** @return [:var] */
public function headers() { return $this->endpoint->headers(); }

/**
* Provides a log category for tracing requests
*
* @param ?util.log.LogCategory $cat
*/
public function setTrace($cat) {
$this->endpoint->setTrace($cat);
}

/** Returns an API */
public function api(string $path, array $segments= []): Api {
return new Api($this->endpoint->resource(ltrim($path, '/'), $segments), $this->rateLimit);
Expand Down
45 changes: 45 additions & 0 deletions src/main/php/com/openai/rest/RestEndpoint.class.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
<?php namespace com\openai\rest;

use com\openai\Tools;
use com\openai\tools\Functions;

/** Base class for OpenAI and AzureAI implementations */
abstract class RestEndpoint extends ApiEndpoint {
protected $endpoint, $rateLimit;

/** @param webservices.rest.Endpoint */
public function __construct($endpoint) {
$this->endpoint= $endpoint;
$this->endpoint->marshalling->mapping(Tools::class, function($tools) {
foreach ($tools->selection as $select) {
if ($select instanceof Functions) {
foreach ($select->schema() as $name => $function) {
yield ['type' => 'function', 'function' => [
'name' => $name,
'description' => $function['description'],
'parameters' => $function['input'],
]];
}
} else {
yield $select;
}
}
});
$this->rateLimit= new RateLimit();
}

/** Returns rate limit */
public function rateLimit(): RateLimit { return $this->rateLimit; }

/** @return [:var] */
public function headers() { return $this->endpoint->headers(); }

/**
* Provides a log category for tracing requests
*
* @param ?util.log.LogCategory $cat
*/
public function setTrace($cat) {
$this->endpoint->setTrace($cat);
}
}
72 changes: 59 additions & 13 deletions src/test/php/com/openai/unittest/ToolsTest.class.php
Original file line number Diff line number Diff line change
@@ -1,40 +1,86 @@
<?php namespace com\openai\unittest;

use com\openai\Tools;
use com\openai\realtime\RealtimeApi;
use com\openai\rest\OpenAIEndpoint;
use com\openai\tools\Functions;
use test\{Assert, Test};
use test\{Assert, Test, Values};
use webservices\rest\TestEndpoint;

class ToolsTest {

/** Returns a testing API endpoint */
private function testingEndpoint(): TestEndpoint {
return new TestEndpoint([
'POST /echo' => function($call) {
return $call->respond(200, 'OK', ['Content-Type' => 'application/json'], $call->content());
}
]);
}

/** Returns functions with a "Hello World!" registration */
private function functions(): Functions {
return (new Functions())->register('greet', new class() {
public function world($name= 'World') { return "Hello {$name}!"; }
});
}

#[Test]
public function can_create() {
new Tools();
}

#[Test]
public function code_interpreter() {
Assert::equals(
[['type' => 'code_interpreter']],
(new Tools('code_interpreter'))->selection
);
#[Test, Values([['code_interpreter'], [['type' => 'code_interpreter']]])]
public function code_interpreter($tool) {
Assert::equals([['type' => 'code_interpreter']], (new Tools($tool))->selection);
}

#[Test]
public function with_custom_functions() {
$functions= (new Functions())->register('greet', new class() {
public function world() { return 'Hello World!'; }
});
$functions= $this->functions();
Assert::equals([$functions], (new Tools($functions))->selection);
}

#[Test]
public function serialized_for_rest_api() {
$functions= $this->functions();
$endpoint= new OpenAIEndpoint($this->testingEndpoint());
$result= $endpoint->api('/echo')->invoke(['tools' => new Tools($functions)]);

Assert::equals(
[[
['tools' => [[
'type' => 'function',
'function' => [
'name' => 'greet_world',
'description' => 'World',
'parameters' => $functions->schema()->current()['input'],
],
]],
(new Tools($functions))->selection
]]],
$result
);
}

#[Test]
public function serialized_for_realtime_api() {
$functions= $this->functions();
$api= new RealtimeApi(new TestingSocket([
'{"type": "session.created"}',
'{"type": "session.update", "session": {
"tools": [{
"type": "function",
"name": "greet_world",
"description": "World",
"parameters": {
"type": "object",
"properties": {
"name": {"type": "string", "description": "Name"}
},
"required": []
}
}]
}}',
]));
$api->connect();
$api->send(['type' => 'session.update', 'session' => ['tools' => new Tools($functions)]]);
}
}

0 comments on commit dc08c4c

Please sign in to comment.