functools下的partial模塊應用
問題
你有一個被其他python代碼使用的callable對象,可能是一個回調函數或者是一個處理器, 但是它的參數太多了,導致調用時出錯。
解決方案
如果需要減少某個函數的參數個數,你可以使用 functools.partial()
。 partial()
函數允許你給一個或多個參數設置固定的值,減少接下來被調用時的參數個數。 為了演示清楚,假設你有下面這樣的函數:
def spam(a, b, c, d):
print(a, b, c, d)
現在我們使用 partial()
函數來固定某些參數值:
>>> fromfunctools import partial >>> s1 = partial(spam, 1) # a = 1 >>> s1(2, 3, 4) 1 2 3 4 >>> s1(4, 5, 6) 1 4 5 6 >>> s2 = partial(spam, d=42) # d = 42 >>> s2(1, 2, 3) 1 2 3 42 >>> s2(4, 5, 5) 4 5 5 42 >>> s3 = partial(spam, 1, 2, d=42) # a = 1, b = 2, d = 42 >>> s3(3) 1 2 3 42 >>> s3(4) 1 2 4 42 >>> s3(5) 1 2 5 42 >>>
可以看出 partial()
固定某些參數並返回一個新的callable對象。這個新的callable接受未賦值的參數, 然後跟之前已經賦值過的參數合並起來,最後將所有參數傳遞給原始函數。
討論
本節要解決的問題是讓原本不兼容的代碼可以一起工作。下面我會列舉一系列的例子。
第一個例子是,假設你有一個點的列表來表示(x,y)坐標元組。 你可以使用下面的函數來計算兩點之間的距離:
points = [ (1, 2), (3, 4), (5, 6), (7, 8) ] import math def distance(p1, p2): x1, y1 = p1 x2, y2 = p2 return math.hypot(x2 - x1, y2 - y1)
說明一下這裏的math.hypot默認是以坐標原點為幾點計算坐標到原點的直線距離
import math print(math.hypot(6,8)) >>>10
現在假設你想以某個點為基點,根據點和基點之間的距離來排序所有的這些點。 列表的 sort()
方法接受一個關鍵字參數來自定義排序邏輯, 但是它只能接受一個單個參數的函數(distance()很明顯是不符合條件的)。 現在我們可以通過使用 partial()
來解決這個問題:
>>> pt = (4, 3) >>> points.sort(key=partial(distance,pt)) >>> points [(3, 4), (1, 2), (5, 6), (7, 8)] >>>
更進一步,partial()
通常被用來微調其他庫函數所使用的回調函數的參數。 例如,下面是一段代碼,使用 multiprocessing
來異步計算一個結果值, 然後這個值被傳遞給一個接受一個result值和一個可選logging參數的回調函數:
def output_result(result, log=None): if log is not None: log.debug(‘Got: %r‘, result) # A sample function def add(x, y): return x + y if __name__ == ‘__main__‘: import logging from multiprocessing import Pool from functools import partial logging.basicConfig(level=logging.DEBUG) log = logging.getLogger(‘test‘) p = Pool() p.apply_async(add, (3, 4), callback=partial(output_result, log=log)) p.close() p.join()
當給 apply_async()
提供回調函數時,通過使用 partial()
傳遞額外的 logging
參數。 而 multiprocessing
對這些一無所知——它僅僅只是使用單個值來調用回調函數。
作為一個類似的例子,考慮下編寫網絡服務器的問題,socketserver
模塊讓它變得很容易。 下面是個簡單的echo服務器:
from socketserver import StreamRequestHandler, TCPServer class EchoHandler(StreamRequestHandler): def handle(self): for line in self.rfile: self.wfile.write(b‘GOT:‘ + line) serv = TCPServer((‘‘, 15000), EchoHandler) serv.serve_forever()
不過,假設你想給EchoHandler增加一個可以接受其他配置選項的 __init__
方法。比如:
class EchoHandler(StreamRequestHandler): # ack is added keyword-only argument. *args, **kwargs are # any normal parameters supplied (which are passed on) def __init__(self, *args, ack, **kwargs): self.ack = ack super().__init__(*args, **kwargs) def handle(self): for line in self.rfile: self.wfile.write(self.ack + line)
這麽修改後,我們就不需要顯式地在TCPServer類中添加前綴了。 但是你再次運行程序後會報類似下面的錯誤:
Exception happened during processing of request from (‘127.0.0.1‘, 59834) Traceback (most recent call last): ... TypeError: __init__() missing 1 required keyword-only argument: ‘ack‘
初看起來好像很難修正這個錯誤,除了修改 socketserver
模塊源代碼或者使用某些奇怪的方法之外。 但是,如果使用 partial()
就能很輕松的解決——給它傳遞 ack
參數的值來初始化即可,如下:
from functools import partial serv = TCPServer((‘‘, 15000), partial(EchoHandler, ack=b‘RECEIVED:‘)) serv.serve_forever()
在這個例子中,__init__()
方法中的ack參數聲明方式看上去很有趣,其實就是聲明ack為一個強制關鍵字參數。 關於強制關鍵字參數問題我們在7.2小節我們已經討論過了,讀者可以再去回顧一下。
很多時候 partial()
能實現的效果,lambda表達式也能實現。比如,之前的幾個例子可以使用下面這樣的表達式:
points.sort(key=lambda p: distance(pt, p)) p.apply_async(add, (3, 4), callback=lambda result: output_result(result,log)) serv = TCPServer((‘‘, 15000), lambda *args, **kwargs: EchoHandler(*args, ack=b‘RECEIVED:‘, **kwargs))
這樣寫也能實現同樣的效果,不過相比而已會顯得比較臃腫,對於閱讀代碼的人來講也更加難懂。 這時候使用 partial()
可以更加直觀的表達你的意圖(給某些參數預先賦值)。
functools下的partial模塊應用