@@ -32,51 +32,14 @@ class ChainSerializationTest(TestCase):
3232 # We expect users to use the same LangChain version for serialize and de-serialize
3333
3434 def setUp (self ) -> None :
35- self .maxDiff = None
35+ # self.maxDiff = None
3636 return super ().setUp ()
3737
3838 PROMPT_TEMPLATE = "Tell me a joke about {subject}"
3939 COMPARTMENT_ID = "<ocid>"
4040 GEN_AI_KWARGS = {"service_endpoint" : "https://endpoint.oraclecloud.com" }
4141 ENDPOINT = "https://modeldeployment.customer-oci.com/ocid/predict"
4242
43- EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
44- "lc" : 1 ,
45- "type" : "constructor" ,
46- "id" : ["langchain" , "chains" , "llm" , "LLMChain" ],
47- "kwargs" : {
48- "prompt" : {
49- "lc" : 1 ,
50- "type" : "constructor" ,
51- "kwargs" : {
52- "input_variables" : ["subject" ],
53- "template" : "Tell me a joke about {subject}" ,
54- "template_format" : "f-string" ,
55- "partial_variables" : {},
56- },
57- },
58- "llm" : {
59- "lc" : 1 ,
60- "type" : "constructor" ,
61- "id" : ["ads" , "llm" , "ModelDeploymentVLLM" ],
62- "kwargs" : {
63- "endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" ,
64- "model" : "my_model" ,
65- },
66- },
67- },
68- }
69-
70- EXPECTED_GEN_AI_LLM = {
71- "lc" : 1 ,
72- "type" : "constructor" ,
73- "id" : ["ads" , "llm" , "GenerativeAI" ],
74- "kwargs" : {
75- "compartment_id" : "<ocid>" ,
76- "client_kwargs" : {"service_endpoint" : "https://endpoint.oraclecloud.com" },
77- },
78- }
79-
8043 EXPECTED_GEN_AI_EMBEDDINGS = {
8144 "lc" : 1 ,
8245 "type" : "constructor" ,
@@ -170,10 +133,6 @@ def test_llm_chain_serialization_with_oci(self):
170133 template = PromptTemplate .from_template (self .PROMPT_TEMPLATE )
171134 llm_chain = LLMChain (prompt = template , llm = llm )
172135 serialized = dump (llm_chain )
173- # Do not check the ID field.
174- expected = deepcopy (self .EXPECTED_LLM_CHAIN_WITH_OCI_MD )
175- expected ["kwargs" ]["prompt" ]["id" ] = serialized ["kwargs" ]["prompt" ]["id" ]
176- self .assertEqual (serialized , expected )
177136 llm_chain = load (serialized )
178137 self .assertIsInstance (llm_chain , LLMChain )
179138 self .assertIsInstance (llm_chain .prompt , PromptTemplate )
@@ -193,10 +152,10 @@ def test_oci_gen_ai_serialization(self):
193152 except ImportError as ex :
194153 raise SkipTest ("OCI SDK does not support Generative AI." ) from ex
195154 serialized = dump (llm )
196- self .assertEqual (serialized , self .EXPECTED_GEN_AI_LLM )
197155 llm = load (serialized )
198156 self .assertIsInstance (llm , GenerativeAI )
199157 self .assertEqual (llm .compartment_id , self .COMPARTMENT_ID )
158+ self .assertEqual (llm .client_kwargs , self .GEN_AI_KWARGS )
200159
201160 def test_gen_ai_embeddings_serialization (self ):
202161 """Tests serialization of OCI Gen AI embeddings."""
0 commit comments