@@ -203,8 +203,12 @@ def _append_end_assert(pattern):
203203 else :
204204 return pattern if pattern .endswith (rb"\Z" ) else pattern + rb"\Z"
205205
206+ def _is_bytes_like (object ):
207+ return isinstance (object , (bytes , bytearray , memoryview ))
208+
206209class SRE_Pattern ():
207210 def __init__ (self , pattern , flags ):
211+ self .__binary = isinstance (pattern , bytes )
208212 self .pattern = pattern
209213 self .flags = flags
210214 flags_str = []
@@ -220,9 +224,18 @@ def __init__(self, pattern, flags):
220224 self .groupindex [group_name ] = self .__compiled_regexes [self .pattern ].groups [group_name ]
221225
222226
227+ def __check_input_type (self , input ):
228+ if not isinstance (input , str ) and not _is_bytes_like (input ):
229+ raise TypeError ("expected string or bytes-like object" )
230+ if not self .__binary and _is_bytes_like (input ):
231+ raise TypeError ("cannot use a string pattern on a bytes-like object" )
232+ if self .__binary and isinstance (input , str ):
233+ raise TypeError ("cannot use a bytes pattern on a string-like object" )
234+
235+
223236 def __tregex_compile (self , pattern ):
224237 if pattern not in self .__compiled_regexes :
225- tregex_engine = TREGEX_ENGINE_STR if isinstance ( pattern , str ) else TREGEX_ENGINE_BYTES
238+ tregex_engine = TREGEX_ENGINE_BYTES if self . __binary else TREGEX_ENGINE_STR
226239 try :
227240 self .__compiled_regexes [pattern ] = tregex_call_compile (tregex_engine , pattern , self .flags_str )
228241 except ValueError as e :
@@ -266,12 +279,15 @@ def _search(self, pattern, string, pos, endpos):
266279 return None
267280
268281 def search (self , string , pos = 0 , endpos = None ):
282+ self .__check_input_type (string )
269283 return self ._search (self .pattern , string , pos , default (endpos , - 1 ))
270284
271285 def match (self , string , pos = 0 , endpos = None ):
286+ self .__check_input_type (string )
272287 return self ._search (_prepend_begin_assert (self .pattern ), string , pos , default (endpos , - 1 ))
273288
274289 def fullmatch (self , string , pos = 0 , endpos = None ):
290+ self .__check_input_type (string )
275291 return self ._search (_append_end_assert (_prepend_begin_assert (self .pattern )), string , pos , default (endpos , - 1 ))
276292
277293 def __sanitize_out_type (self , elem ):
@@ -283,6 +299,7 @@ def __sanitize_out_type(self, elem):
283299 return str (elem )
284300
285301 def findall (self , string , pos = 0 , endpos = - 1 ):
302+ self .__check_input_type (string )
286303 if endpos > len (string ):
287304 endpos = len (string )
288305 elif endpos < 0 :
@@ -312,20 +329,20 @@ def group(match_result, group_nr, string):
312329 return string [group_start :group_end ]
313330
314331 n = len (repl )
315- result = ""
332+ result = b"" if self . __binary else ""
316333 start = 0
317- backslash = '\\ '
334+ backslash = b' \\ ' if self . __binary else '\\ '
318335 pos = repl .find (backslash , start )
319336 while pos != - 1 and start < n :
320337 if pos + 1 < n :
321338 if repl [pos + 1 ].isdigit () and match_result .groupCount > 0 :
322- group_nr = int (repl [pos + 1 ])
339+ group_nr = int (repl [pos + 1 ]. decode ( 'ascii' )) if self . __binary else int ( repl [ pos + 1 ] )
323340 group_str = group (match_result , group_nr , string )
324341 if group_str is None :
325342 raise ValueError ("invalid group reference %s at position %s" % (group_nr , pos ))
326343 result += repl [start :pos ] + group_str
327344 start = pos + 2
328- elif repl [pos + 1 ] == 'g' :
345+ elif repl [pos + 1 ] == ( b 'g' if self . __binary else 'g' ) :
329346 group_ref , group_ref_end , digits_only = self .__extract_groupname (repl , pos + 2 )
330347 if group_ref :
331348 group_str = group (match_result , int (group_ref ) if digits_only else pattern .groups [group_ref ], string )
@@ -345,26 +362,30 @@ def group(match_result, group_nr, string):
345362
346363
347364 def __extract_groupname (self , repl , pos ):
348- if repl [pos ] == '<' :
365+ if repl [pos ] == ( b '<' if self . __binary else '<' ) :
349366 digits_only = True
350367 n = len (repl )
351368 i = pos + 1
352- while i < n and repl [i ] != '>' :
369+ while i < n and repl [i ] != ( b '>' if self . __binary else '>' ) :
353370 digits_only = digits_only and repl [i ].isdigit ()
354371 i += 1
355372 if i < n :
356373 # found '>'
357- return repl [pos + 1 : i ], i , digits_only
374+ group_ref = repl [pos + 1 : i ]
375+ group_ref_str = group_ref .decode ('ascii' ) if self .__binary else group_ref
376+ return group_ref_str , i , digits_only
358377 return None , pos , False
359378
360379
361380 def sub (self , repl , string , count = 0 ):
381+ self .__check_input_type (string )
362382 n = 0
363383 pattern = self .__tregex_compile (self .pattern )
364384 result = []
365385 pos = 0
366- is_string_rep = isinstance (repl , str ) or isinstance (repl , bytes ) or isinstance ( repl , bytearray )
386+ is_string_rep = isinstance (repl , str ) or _is_bytes_like (repl )
367387 if is_string_rep :
388+ self .__check_input_type (repl )
368389 repl = _process_escape_sequences (repl )
369390 while (count == 0 or n < count ) and pos <= len (string ):
370391 match_result = tregex_call_exec (pattern .exec , string , pos )
@@ -386,7 +407,10 @@ def sub(self, repl, string, count=0):
386407 result .append (string [pos ])
387408 pos = pos + 1
388409 result .append (string [pos :])
389- return "" .join (result )
410+ if self .__binary :
411+ return b"" .join (result )
412+ else :
413+ return "" .join (result )
390414
391415 def split (self , string , maxsplit = 0 ):
392416 n = 0
0 commit comments