diff --git a/README.md b/README.md index 944bc3b..9448a80 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ _This library is not developed or endorsed by Google._ - [Tokens counting](#tokens-counting) - [Listing models](#listing-models) - [Advanced Usages](#advanced-usages) + - [Using Beta version](#using-beta-version) - [Safety Settings and Generation Configuration](#safety-settings-and-generation-configuration) - [Using your own HTTP client](#using-your-own-http-client) - [Using your own HTTP client for streaming responses](#using-your-own-http-client-for-streaming-responses) @@ -52,10 +53,11 @@ you need to allow `php-http/discovery` composer plugin or install a PSR-18 compa ```php use GeminiAPI\Client; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; $client = new Client('GEMINI_API_KEY'); -$response = $client->geminiPro()->generateContent( +$response = $client->generativeModel(ModelName::GEMINI_PRO)->generateContent( new TextPart('PHP in less than 100 chars'), ); @@ -71,11 +73,12 @@ print $response->text(); ```php use GeminiAPI\Client; use GeminiAPI\Enums\MimeType; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\ImagePart; use GeminiAPI\Resources\Parts\TextPart; $client = new Client('GEMINI_API_KEY'); -$response = $client->geminiProVision()->generateContent( +$response = $client->generativeModel(ModelName::GEMINI_PRO)->generateContent( new TextPart('Explain what is in the image'), new ImagePart( MimeType::IMAGE_JPEG, @@ -94,10 +97,11 @@ print $response->text(); ```php use GeminiAPI\Client; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; $client = new Client('GEMINI_API_KEY'); -$chat = $client->geminiPro()->startChat(); +$chat = $client->generativeModel(ModelName::GEMINI_PRO)->startChat(); $response = $chat->sendMessage(new TextPart('Hello World in PHP')); print $response->text(); @@ -132,6 +136,7 @@ This code will print "Hello World!" to the standard output. use GeminiAPI\Client; use GeminiAPI\Enums\Role; use GeminiAPI\Resources\Content; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; $history = [ @@ -149,7 +154,7 @@ $history = [ ]; $client = new Client('GEMINI_API_KEY'); -$chat = $client->geminiPro() +$chat = $client->generativeModel(ModelName::GEMINI_PRO) ->startChat() ->withHistory($history); @@ -179,6 +184,7 @@ Long responses may be broken into separate responses, and you can start receivin ```php use GeminiAPI\Client; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; use GeminiAPI\Responses\GenerateContentResponse; @@ -191,7 +197,7 @@ $callback = function (GenerateContentResponse $response): void { }; $client = new Client('GEMINI_API_KEY'); -$client->geminiPro()->generateContentStream( +$client->generativeModel(ModelName::GEMINI_PRO)->generateContentStream( $callback, [new TextPart('PHP in less than 100 chars')], ); @@ -209,6 +215,7 @@ $client->geminiPro()->generateContentStream( use GeminiAPI\Client; use GeminiAPI\Enums\Role; use GeminiAPI\Resources\Content; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; use GeminiAPI\Responses\GenerateContentResponse; @@ -235,7 +242,7 @@ $callback = function (GenerateContentResponse $response): void { }; $client = new Client('GEMINI_API_KEY'); -$chat = $client->geminiPro() +$chat = $client->generativeModel(ModelName::GEMINI_PRO) ->startChat() ->withHistory($history); @@ -261,11 +268,11 @@ This code will print "Hello World!" to the standard output. ```php use GeminiAPI\Client; -use GeminiAPI\Enums\ModelName; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; $client = new Client('GEMINI_API_KEY'); -$response = $client->embeddingModel(ModelName::Embedding) +$response = $client->embeddingModel(ModelName::EMBEDDING_001) ->embedContent( new TextPart('PHP in less than 100 chars'), ); @@ -282,10 +289,11 @@ print_r($response->embedding->values); ```php use GeminiAPI\Client; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; $client = new Client('GEMINI_API_KEY'); -$response = $client->geminiPro()->countTokens( +$response = $client->generativeModel(ModelName::GEMINI_PRO)->countTokens( new TextPart('PHP in less than 100 chars'), ); @@ -322,6 +330,23 @@ print_r($response->models); ### Advanced Usages +#### Using Beta version + +```php +use GeminiAPI\Client; +use GeminiAPI\Resources\ModelName; +use GeminiAPI\Resources\Parts\TextPart; + +$client = (new Client('GEMINI_API_KEY')) + ->withV1BetaVersion(); +$response = $client->generativeModel(ModelName::GEMINI_PRO)->countTokens( + new TextPart('PHP in less than 100 chars'), +); + +print $response->totalTokens; +// 10 +``` + #### Safety Settings and Generation Configuration ```php @@ -329,6 +354,7 @@ use GeminiAPI\Client; use GeminiAPI\Enums\HarmCategory; use GeminiAPI\Enums\HarmBlockThreshold; use GeminiAPI\GenerationConfig; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; use GeminiAPI\SafetySetting; @@ -345,7 +371,7 @@ $generationConfig = (new GenerationConfig()) ->withStopSequences(['STOP']); $client = new Client('GEMINI_API_KEY'); -$response = $client->geminiPro() +$response = $client->generativeModel(ModelName::GEMINI_PRO) ->withAddedSafetySetting($safetySetting) ->withGenerationConfig($generationConfig) ->generateContent( @@ -357,6 +383,7 @@ $response = $client->geminiPro() ```php use GeminiAPI\Client as GeminiClient; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; use GuzzleHttp\Client as GuzzleClient; @@ -365,7 +392,7 @@ $guzzle = new GuzzleClient([ ]); $client = new GeminiClient('GEMINI_API_KEY', $guzzle); -$response = $client->geminiPro()->generateContent( +$response = $client->generativeModel(ModelName::GEMINI_PRO)->generateContent( new TextPart('PHP in less than 100 chars') ); ``` @@ -388,6 +415,7 @@ You can also pass the headers you want to be used in the requests. ```php use GeminiAPI\Client; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; use GeminiAPI\Responses\GenerateContentResponse; @@ -402,7 +430,7 @@ $client = new Client('GEMINI_API_KEY'); $client->withRequestHeaders([ 'User-Agent' => 'My Gemini-backed app' ]) - ->geminiPro() + ->generativeModel(ModelName::GEMINI_PRO) ->generateContentStream( $callback, [new TextPart('PHP in less than 100 chars')], diff --git a/src/ChatSession.php b/src/ChatSession.php index 013b9c7..34f3ae6 100644 --- a/src/ChatSession.php +++ b/src/ChatSession.php @@ -37,7 +37,7 @@ public function sendMessage(PartInterface ...$parts): GenerateContentResponse ->withGenerationConfig($config) ->generateContentWithContents($this->history); - if(!empty($response->candidates)) { + if (!empty($response->candidates)) { $parts = $response->candidates[0]->content->parts; $this->history[] = new Content($parts, Role::Model); } @@ -58,7 +58,7 @@ public function sendMessageStream( $parts = []; $partsCollectorCallback = function (GenerateContentResponse $response) use ($callback, &$parts) { - if(!empty($response->candidates)) { + if (!empty($response->candidates)) { array_push($parts, ...$response->parts()); } diff --git a/src/Client.php b/src/Client.php index c4a3544..9db1d12 100644 --- a/src/Client.php +++ b/src/Client.php @@ -40,6 +40,7 @@ class Client implements GeminiClientInterface { private string $baseUrl = 'https://generativelanguage.googleapis.com'; + private string $version = GeminiClientInterface::API_VERSION_V1; /** * @var array @@ -87,7 +88,7 @@ public function geminiProFlash1_5(): GenerativeModel } - public function generativeModel(ModelName $modelName): GenerativeModel + public function generativeModel(ModelName|string $modelName): GenerativeModel { return new GenerativeModel( $this, @@ -95,7 +96,7 @@ public function generativeModel(ModelName $modelName): GenerativeModel ); } - public function embeddingModel(ModelName $modelName): EmbeddingModel + public function embeddingModel(ModelName|string $modelName): EmbeddingModel { return new EmbeddingModel( $this, @@ -163,7 +164,7 @@ public function generateContentStream( } } - curl_setopt($ch, CURLOPT_URL, "{$this->baseUrl}/v1/{$request->getOperation()}"); + curl_setopt($ch, CURLOPT_URL, $this->getRequestUrl($request)); curl_setopt($ch, CURLOPT_POST, true); curl_setopt($ch, CURLOPT_POSTFIELDS, json_encode($request)); curl_setopt($ch, CURLOPT_HTTPHEADER, $headerLines); @@ -214,6 +215,19 @@ public function withBaseUrl(string $baseUrl): self return $clone; } + public function withV1BetaVersion(): self + { + return $this->withVersion(GeminiClientInterface::API_VERSION_V1_BETA); + } + + public function withVersion(string $version): self + { + $clone = clone $this; + $clone->version = $version; + + return $clone; + } + /** * @param array $headers * @return self @@ -241,6 +255,16 @@ private function getRequestHeaders(): array ]; } + private function getRequestUrl(RequestInterface $request): string + { + return sprintf( + '%s/%s/%s', + $this->baseUrl, + $this->version, + $request->getOperation(), + ); + } + /** * @throws ClientExceptionInterface */ @@ -250,9 +274,11 @@ private function doRequest(RequestInterface $request): string throw new RuntimeException('Missing client or factory for Gemini API operation'); } - $uri = "{$this->baseUrl}/v1/{$request->getOperation()}"; $httpRequest = $this->requestFactory - ->createRequest($request->getHttpMethod(), $uri); + ->createRequest( + $request->getHttpMethod(), + $this->getRequestUrl($request), + ); foreach ($this->getRequestHeaders() as $name => $value) { $httpRequest = $httpRequest->withAddedHeader($name, $value); diff --git a/src/ClientInterface.php b/src/ClientInterface.php index 2d5cbec..edc2094 100644 --- a/src/ClientInterface.php +++ b/src/ClientInterface.php @@ -21,12 +21,14 @@ interface ClientInterface { public const API_KEY_HEADER_NAME = 'x-goog-api-key'; + public const API_VERSION_V1 = 'v1'; + public const API_VERSION_V1_BETA = 'v1beta'; public function countTokens(CountTokensRequest $request): CountTokensResponse; public function generateContent(GenerateContentRequest $request): GenerateContentResponse; public function embedContent(EmbedContentRequest $request): EmbedContentResponse; - public function generativeModel(ModelName $modelName): GenerativeModel; - public function embeddingModel(ModelName $modelName): EmbeddingModel; + public function generativeModel(ModelName|string $modelName): GenerativeModel; + public function embeddingModel(ModelName|string $modelName): EmbeddingModel; public function listModels(): ListModelsResponse; public function withBaseUrl(string $baseUrl): self; diff --git a/src/EmbeddingModel.php b/src/EmbeddingModel.php index 384977b..dc35952 100644 --- a/src/EmbeddingModel.php +++ b/src/EmbeddingModel.php @@ -19,7 +19,7 @@ class EmbeddingModel public function __construct( private readonly Client $client, - public readonly ModelName $modelName, + public readonly ModelName|string $modelName, ) { } diff --git a/src/Enums/ModelName.php b/src/Enums/ModelName.php index 5103ae2..df2dd00 100644 --- a/src/Enums/ModelName.php +++ b/src/Enums/ModelName.php @@ -4,6 +4,9 @@ namespace GeminiAPI\Enums; +/** + * @deprecated Use constants from GeminiAPI\Resources\ModelName instead + */ enum ModelName: string { case Default = 'models/text-bison-001'; diff --git a/src/GenerativeModel.php b/src/GenerativeModel.php index 0108f35..46f9fc1 100644 --- a/src/GenerativeModel.php +++ b/src/GenerativeModel.php @@ -4,7 +4,6 @@ namespace GeminiAPI; -use BadMethodCallException; use CurlHandle; use GeminiAPI\Enums\ModelName; use GeminiAPI\Enums\Role; @@ -29,7 +28,7 @@ class GenerativeModel public function __construct( private readonly Client $client, - public readonly ModelName $modelName, + public readonly ModelName|string $modelName, ) { } diff --git a/src/Requests/CountTokensRequest.php b/src/Requests/CountTokensRequest.php index 911fecf..d33549a 100644 --- a/src/Requests/CountTokensRequest.php +++ b/src/Requests/CountTokensRequest.php @@ -5,8 +5,9 @@ namespace GeminiAPI\Requests; use GeminiAPI\Enums\ModelName; -use GeminiAPI\Traits\ArrayTypeValidator; use GeminiAPI\Resources\Content; +use GeminiAPI\Traits\ArrayTypeValidator; +use GeminiAPI\Traits\ModelNameToString; use JsonSerializable; use function json_encode; @@ -14,13 +15,14 @@ class CountTokensRequest implements JsonSerializable, RequestInterface { use ArrayTypeValidator; + use ModelNameToString; /** - * @param ModelName $modelName + * @param ModelName|string $modelName * @param Content[] $contents */ public function __construct( - public readonly ModelName $modelName, + public readonly ModelName|string $modelName, public readonly array $contents, ) { $this->ensureArrayOfType($this->contents, Content::class); @@ -28,7 +30,7 @@ public function __construct( public function getOperation(): string { - return "{$this->modelName->value}:countTokens"; + return "{$this->modelNameToString($this->modelName)}:countTokens"; } public function getHttpMethod(): string @@ -50,7 +52,7 @@ public function getHttpPayload(): string public function jsonSerialize(): array { return [ - 'model' => $this->modelName->value, + 'model' => $this->modelNameToString($this->modelName), 'contents' => $this->contents, ]; } diff --git a/src/Requests/EmbedContentRequest.php b/src/Requests/EmbedContentRequest.php index 597c6ca..e55334d 100644 --- a/src/Requests/EmbedContentRequest.php +++ b/src/Requests/EmbedContentRequest.php @@ -8,22 +8,21 @@ use GeminiAPI\Enums\ModelName; use GeminiAPI\Enums\TaskType; use GeminiAPI\Resources\Content; +use GeminiAPI\Traits\ModelNameToString; use JsonSerializable; use function json_encode; class EmbedContentRequest implements JsonSerializable, RequestInterface { + use ModelNameToString; + public function __construct( - public readonly ModelName $modelName, + public readonly ModelName|string $modelName, public readonly Content $content, public readonly ?TaskType $taskType = null, public readonly ?string $title = null, ) { - if (isset($this->taskType) && $this->modelName !== ModelName::Embedding) { - throw new BadMethodCallException('TaskType can only be set when ModelName is Embedding'); - } - if (isset($this->title) && $this->taskType !== TaskType::RETRIEVAL_DOCUMENT) { throw new BadMethodCallException('Title is only applicable when TaskType is RETRIEVAL_DOCUMENT'); } @@ -31,7 +30,7 @@ public function __construct( public function getOperation(): string { - return "{$this->modelName->value}:embedContent"; + return "{$this->modelNameToString($this->modelName)}:embedContent"; } public function getHttpMethod(): string diff --git a/src/Requests/GenerateContentRequest.php b/src/Requests/GenerateContentRequest.php index a421e3e..d3027b9 100644 --- a/src/Requests/GenerateContentRequest.php +++ b/src/Requests/GenerateContentRequest.php @@ -6,9 +6,10 @@ use GeminiAPI\Enums\ModelName; use GeminiAPI\GenerationConfig; +use GeminiAPI\Resources\Content; use GeminiAPI\SafetySetting; use GeminiAPI\Traits\ArrayTypeValidator; -use GeminiAPI\Resources\Content; +use GeminiAPI\Traits\ModelNameToString; use JsonSerializable; use function json_encode; @@ -16,15 +17,16 @@ class GenerateContentRequest implements JsonSerializable, RequestInterface { use ArrayTypeValidator; + use ModelNameToString; /** - * @param ModelName $modelName + * @param ModelName|string $modelName * @param Content[] $contents * @param SafetySetting[] $safetySettings * @param GenerationConfig|null $generationConfig */ public function __construct( - public readonly ModelName $modelName, + public readonly ModelName|string $modelName, public readonly array $contents, public readonly array $safetySettings = [], public readonly ?GenerationConfig $generationConfig = null, @@ -35,7 +37,7 @@ public function __construct( public function getOperation(): string { - return "{$this->modelName->value}:generateContent"; + return "{$this->modelNameToString($this->modelName)}:generateContent"; } public function getHttpMethod(): string @@ -59,7 +61,7 @@ public function getHttpPayload(): string public function jsonSerialize(): array { $arr = [ - 'model' => $this->modelName->value, + 'model' => $this->modelNameToString($this->modelName), 'contents' => $this->contents, ]; diff --git a/src/Requests/GenerateContentStreamRequest.php b/src/Requests/GenerateContentStreamRequest.php index 74aac7b..cf19c3a 100644 --- a/src/Requests/GenerateContentStreamRequest.php +++ b/src/Requests/GenerateContentStreamRequest.php @@ -6,9 +6,10 @@ use GeminiAPI\Enums\ModelName; use GeminiAPI\GenerationConfig; +use GeminiAPI\Resources\Content; use GeminiAPI\SafetySetting; use GeminiAPI\Traits\ArrayTypeValidator; -use GeminiAPI\Resources\Content; +use GeminiAPI\Traits\ModelNameToString; use JsonSerializable; use function json_encode; @@ -16,15 +17,16 @@ class GenerateContentStreamRequest implements JsonSerializable, RequestInterface { use ArrayTypeValidator; + use ModelNameToString; /** - * @param ModelName $modelName + * @param ModelName|string $modelName * @param Content[] $contents * @param SafetySetting[] $safetySettings * @param GenerationConfig|null $generationConfig */ public function __construct( - public readonly ModelName $modelName, + public readonly ModelName|string $modelName, public readonly array $contents, public readonly array $safetySettings = [], public readonly ?GenerationConfig $generationConfig = null, @@ -35,7 +37,7 @@ public function __construct( public function getOperation(): string { - return "{$this->modelName->value}:streamGenerateContent"; + return "{$this->modelNameToString($this->modelName)}:streamGenerateContent"; } public function getHttpMethod(): string @@ -59,7 +61,7 @@ public function getHttpPayload(): string public function jsonSerialize(): array { $arr = [ - 'model' => $this->modelName->value, + 'model' => $this->modelNameToString($this->modelName), 'contents' => $this->contents, ]; diff --git a/src/Resources/ModelName.php b/src/Resources/ModelName.php new file mode 100644 index 0000000..84b0bf5 --- /dev/null +++ b/src/Resources/ModelName.php @@ -0,0 +1,38 @@ +value; + } +} diff --git a/tests/Unit/ClientTest.php b/tests/Unit/ClientTest.php index 1e13609..0e8f25e 100644 --- a/tests/Unit/ClientTest.php +++ b/tests/Unit/ClientTest.php @@ -6,12 +6,13 @@ use GeminiAPI\Client; use GeminiAPI\ClientInterface as GeminiAPIClientInterface; -use GeminiAPI\Enums\ModelName; +use GeminiAPI\Enums\ModelName as ModelNameEnum; use GeminiAPI\GenerativeModel; use GeminiAPI\Requests\CountTokensRequest; use GeminiAPI\Requests\EmbedContentRequest; use GeminiAPI\Requests\GenerateContentRequest; use GeminiAPI\Resources\Content; +use GeminiAPI\Resources\ModelName; use GuzzleHttp\Psr7\Request; use GuzzleHttp\Psr7\Response; use GuzzleHttp\Psr7\Utils; @@ -22,7 +23,7 @@ class ClientTest extends TestCase { - public function testConstructor() + public function testConstructor(): void { $client = new Client( 'test-api-key', @@ -31,7 +32,7 @@ public function testConstructor() self::assertInstanceOf(Client::class, $client); } - public function testWithBaseUrl() + public function testWithBaseUrl(): void { $client = new Client( 'test-api-key', @@ -41,26 +42,26 @@ public function testWithBaseUrl() self::assertInstanceOf(Client::class, $client); } - public function testGeminiPro() + public function testGeminiPro(): void { $client = new Client( 'test-api-key', $this->createMock(HttpClientInterface::class), ); - $model = $client->geminiPro(); + $model = $client->generativeModel(ModelName::GEMINI_PRO); self::assertInstanceOf(GenerativeModel::class, $model); - self::assertEquals(ModelName::GeminiPro, $model->modelName); + self::assertEquals(ModelName::GEMINI_PRO, $model->modelName); } - public function testGeminiProVision() + public function testGeminiProWithEnum(): void { $client = new Client( 'test-api-key', $this->createMock(HttpClientInterface::class), ); - $model = $client->geminiProVision(); + $model = $client->generativeModel(ModelNameEnum::GeminiPro); self::assertInstanceOf(GenerativeModel::class, $model); - self::assertEquals(ModelName::GeminiProVision, $model->modelName); + self::assertEquals(ModelNameEnum::GeminiPro, $model->modelName); } public function testGenerativeModel() @@ -69,12 +70,12 @@ public function testGenerativeModel() 'test-api-key', $this->createMock(HttpClientInterface::class), ); - $model = $client->generativeModel(ModelName::Embedding); + $model = $client->generativeModel(ModelName::EMBEDDING_001); self::assertInstanceOf(GenerativeModel::class, $model); - self::assertEquals(ModelName::Embedding, $model->modelName); + self::assertEquals(ModelName::EMBEDDING_001, $model->modelName); } - public function testGenerateContent() + public function testGenerateContent(): void { $httpRequest = new Request( 'POST', @@ -144,14 +145,14 @@ public function testGenerateContent() $streamFactory, ); $request = new GenerateContentRequest( - ModelName::GeminiPro, + ModelName::GEMINI_PRO, [Content::text('this is a text')], ); $response = $client->generateContent($request); self::assertEquals('This is the Gemini Pro response', $response->text()); } - public function testEmbedContent() + public function testEmbedContent(): void { $httpRequest = new Request( 'POST', @@ -199,14 +200,14 @@ public function testEmbedContent() $streamFactory, ); $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a text'), ); $response = $client->embedContent($request); self::assertEquals([0.041395925, -0.017692696], $response->embedding->values); } - public function testCountTokens() + public function testCountTokens(): void { $httpRequest = new Request( 'POST', @@ -249,14 +250,14 @@ public function testCountTokens() $streamFactory, ); $request = new CountTokensRequest( - ModelName::GeminiPro, + ModelName::GEMINI_PRO, [Content::text('this is a text')], ); $response = $client->countTokens($request); self::assertEquals(10, $response->totalTokens); } - public function testListModels() + public function testListModels(): void { $httpRequest = new Request( 'GET', diff --git a/tests/Unit/Requests/CountTokensRequestTest.php b/tests/Unit/Requests/CountTokensRequestTest.php index 9517c63..56763e8 100644 --- a/tests/Unit/Requests/CountTokensRequestTest.php +++ b/tests/Unit/Requests/CountTokensRequestTest.php @@ -4,26 +4,26 @@ namespace GeminiAPI\Tests\Unit\Requests; -use GeminiAPI\Enums\ModelName; use GeminiAPI\Enums\Role; use GeminiAPI\Requests\CountTokensRequest; use GeminiAPI\Resources\Content; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; use InvalidArgumentException; use PHPUnit\Framework\TestCase; class CountTokensRequestTest extends TestCase { - public function testConstructorWithNoContents() + public function testConstructorWithNoContents(): void { - $request = new CountTokensRequest(ModelName::Default, []); + $request = new CountTokensRequest(ModelName::GEMINI_PRO, []); self::assertInstanceOf(CountTokensRequest::class, $request); } - public function testConstructorWithContents() + public function testConstructorWithContents(): void { $request = new CountTokensRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content([], Role::User), new Content([], Role::Model), @@ -33,12 +33,13 @@ public function testConstructorWithContents() self::assertInstanceOf(CountTokensRequest::class, $request); } - public function testConstructorWithInvalidContents() + public function testConstructorWithInvalidContents(): void { $this->expectException(InvalidArgumentException::class); new CountTokensRequest( - ModelName::Default, + ModelName::GEMINI_PRO, + // @phpstan-ignore-next-line [ new Content([], Role::User), new TextPart('This is a text'), @@ -46,42 +47,42 @@ public function testConstructorWithInvalidContents() ); } - public function testGetOperation() + public function testGetOperation(): void { - $request = new CountTokensRequest(ModelName::Default, []); - self::assertEquals('models/text-bison-001:countTokens', $request->getOperation()); + $request = new CountTokensRequest(ModelName::GEMINI_PRO, []); + self::assertEquals('models/gemini-pro:countTokens', $request->getOperation()); } - public function testGetHttpMethod() + public function testGetHttpMethod(): void { - $request = new CountTokensRequest(ModelName::Default, []); + $request = new CountTokensRequest(ModelName::GEMINI_PRO, []); self::assertEquals('POST', $request->getHttpMethod()); } - public function testGetHttpPayload() + public function testGetHttpPayload(): void { $request = new CountTokensRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content([new TextPart('This is a text')], Role::User), ], ); - $expected = '{"model":"models\/text-bison-001","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; + $expected = '{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; self::assertEquals($expected, $request->getHttpPayload()); } - public function testJsonSerialize() + public function testJsonSerialize(): void { $request = new CountTokensRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content([new TextPart('This is a text')], Role::User), ], ); $expected = [ - 'model' => 'models/text-bison-001', + 'model' => 'models/gemini-pro', 'contents' => [ new Content([new TextPart('This is a text')], Role::User), ], @@ -89,10 +90,10 @@ public function testJsonSerialize() self::assertEquals($expected, $request->jsonSerialize()); } - public function test__toString() + public function test__toString(): void { $request = new CountTokensRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content( [new TextPart('This is a text')], @@ -101,7 +102,7 @@ public function test__toString() ], ); - $expected = '{"model":"models\/text-bison-001","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; + $expected = '{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; self::assertEquals($expected, (string) $request); } } diff --git a/tests/Unit/Requests/EmbedContentRequestTest.php b/tests/Unit/Requests/EmbedContentRequestTest.php index e1a0fdb..e6cea4e 100644 --- a/tests/Unit/Requests/EmbedContentRequestTest.php +++ b/tests/Unit/Requests/EmbedContentRequestTest.php @@ -5,37 +5,37 @@ namespace GeminiAPI\Tests\Unit\Requests; use BadMethodCallException; -use GeminiAPI\Enums\ModelName; use GeminiAPI\Enums\TaskType; use GeminiAPI\Requests\EmbedContentRequest; use GeminiAPI\Resources\Content; +use GeminiAPI\Resources\ModelName; use PHPUnit\Framework\TestCase; class EmbedContentRequestTest extends TestCase { - public function testConstructor() + public function testConstructor(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), ); self::assertInstanceOf(EmbedContentRequest::class, $request); } - public function testConstructorWithTaskType() + public function testConstructorWithTaskType(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), TaskType::RETRIEVAL_DOCUMENT, ); self::assertInstanceOf(EmbedContentRequest::class, $request); } - public function testConstructorWithTitle() + public function testConstructorWithTitle(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), TaskType::RETRIEVAL_DOCUMENT, 'this is a title', @@ -43,62 +43,50 @@ public function testConstructorWithTitle() self::assertInstanceOf(EmbedContentRequest::class, $request); } - public function testConstructorWithTaskTypeAndNonEmbeddingModel() - { - $this->expectException(BadMethodCallException::class); - $this->expectExceptionMessage('TaskType can only be set when ModelName is Embedding'); - - new EmbedContentRequest( - ModelName::GeminiPro, - Content::text('this is a test'), - TaskType::RETRIEVAL_DOCUMENT, - ); - } - - public function testConstructorWithTitleAndWrongTaskType() + public function testConstructorWithTitleAndWrongTaskType(): void { $this->expectException(BadMethodCallException::class); $this->expectExceptionMessage('Title is only applicable when TaskType is RETRIEVAL_DOCUMENT'); new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), TaskType::RETRIEVAL_QUERY, 'this is a title', ); } - public function testGetHttpPayload() + public function testGetHttpPayload(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), ); self::assertEquals('{"content":{"parts":[{"text":"this is a test"}],"role":"user"}}', $request->getHttpPayload()); } - public function testGetHttpMethod() + public function testGetHttpMethod(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), ); self::assertEquals('POST', $request->getHttpMethod()); } - public function testGetOperation() + public function testGetOperation(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), ); self::assertEquals('models/embedding-001:embedContent', $request->getOperation()); } - public function testJsonSerialize() + public function testJsonSerialize(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, $content = Content::text('this is a test'), TaskType::RETRIEVAL_DOCUMENT, 'this is a title', @@ -111,10 +99,10 @@ public function testJsonSerialize() self::assertEquals($expected, $request->jsonSerialize()); } - public function test__toString() + public function test__toString(): void { $request = new EmbedContentRequest( - ModelName::Embedding, + ModelName::EMBEDDING_001, Content::text('this is a test'), ); self::assertEquals('{"content":{"parts":[{"text":"this is a test"}],"role":"user"}}', (string) $request); diff --git a/tests/Unit/Requests/GenerateContentRequestTest.php b/tests/Unit/Requests/GenerateContentRequestTest.php index 62afb95..6c3aead 100644 --- a/tests/Unit/Requests/GenerateContentRequestTest.php +++ b/tests/Unit/Requests/GenerateContentRequestTest.php @@ -7,11 +7,11 @@ use GeminiAPI\Enums\HarmBlockThreshold; use GeminiAPI\Enums\HarmCategory; use GeminiAPI\Enums\HarmProbability; -use GeminiAPI\Enums\ModelName; use GeminiAPI\Enums\Role; use GeminiAPI\GenerationConfig; use GeminiAPI\Requests\GenerateContentRequest; use GeminiAPI\Resources\Content; +use GeminiAPI\Resources\ModelName; use GeminiAPI\Resources\Parts\TextPart; use GeminiAPI\Resources\SafetyRating; use GeminiAPI\SafetySetting; @@ -20,10 +20,10 @@ class GenerateContentRequestTest extends TestCase { - public function testConstructorWithNoContents() + public function testConstructorWithNoContents(): void { $request = new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [], [], null, @@ -31,10 +31,10 @@ public function testConstructorWithNoContents() self::assertInstanceOf(GenerateContentRequest::class, $request); } - public function testConstructorWithContents() + public function testConstructorWithContents(): void { $request = new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content([], Role::User), new Content([], Role::Model), @@ -45,12 +45,12 @@ public function testConstructorWithContents() self::assertInstanceOf(GenerateContentRequest::class, $request); } - public function testConstructorWithInvalidContents() + public function testConstructorWithInvalidContents(): void { $this->expectException(InvalidArgumentException::class); new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content([], Role::User), new TextPart('This is a text'), @@ -60,10 +60,10 @@ public function testConstructorWithInvalidContents() ); } - public function testConstructorWithSafetySettings() + public function testConstructorWithSafetySettings(): void { $request = new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [], [ new SafetySetting( @@ -80,12 +80,12 @@ public function testConstructorWithSafetySettings() self::assertInstanceOf(GenerateContentRequest::class, $request); } - public function testConstructorWithInvalidSafetySettings() + public function testConstructorWithInvalidSafetySettings(): void { $this->expectException(InvalidArgumentException::class); new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [], [ new SafetySetting( @@ -102,10 +102,10 @@ public function testConstructorWithInvalidSafetySettings() ); } - public function testConstructorWithGenerationConfig() + public function testConstructorWithGenerationConfig(): void { $request = new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [], [], new GenerationConfig(), @@ -113,41 +113,41 @@ public function testConstructorWithGenerationConfig() self::assertInstanceOf(GenerateContentRequest::class, $request); } - public function testGetOperation() + public function testGetOperation(): void { - $request = new GenerateContentRequest(ModelName::Default, []); - self::assertEquals('models/text-bison-001:generateContent', $request->getOperation()); + $request = new GenerateContentRequest(ModelName::GEMINI_PRO, []); + self::assertEquals('models/gemini-pro:generateContent', $request->getOperation()); } - public function testGetHttpMethod() + public function testGetHttpMethod(): void { - $request = new GenerateContentRequest(ModelName::Default, []); + $request = new GenerateContentRequest(ModelName::GEMINI_PRO, []); self::assertEquals('POST', $request->getHttpMethod()); } - public function testGetHttpPayload() + public function testGetHttpPayload(): void { $request = new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content([new TextPart('This is a text')], Role::User), ], ); - $expected = '{"model":"models\/text-bison-001","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; + $expected = '{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; self::assertEquals($expected, $request->getHttpPayload()); } - public function testJsonSerialize() + public function testJsonSerialize(): void { $request = new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content([new TextPart('This is a text')], Role::User), ], ); $expected = [ - 'model' => 'models/text-bison-001', + 'model' => 'models/gemini-pro', 'contents' => [ new Content([new TextPart('This is a text')], Role::User), ], @@ -155,10 +155,10 @@ public function testJsonSerialize() self::assertEquals($expected, $request->jsonSerialize()); } - public function test__toString() + public function test__toString(): void { $request = new GenerateContentRequest( - ModelName::Default, + ModelName::GEMINI_PRO, [ new Content( [new TextPart('This is a text')], @@ -167,7 +167,7 @@ public function test__toString() ], ); - $expected = '{"model":"models\/text-bison-001","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; + $expected = '{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"This is a text"}],"role":"user"}]}'; self::assertEquals($expected, (string) $request); } }