1. 程式人生 > >functools下的partial模塊應用

functools下的partial模塊應用

tle cep 工作 庫函數 trace echo super() amp 合並

問題

你有一個被其他python代碼使用的callable對象,可能是一個回調函數或者是一個處理器, 但是它的參數太多了,導致調用時出錯。

解決方案

如果需要減少某個函數的參數個數,你可以使用 functools.partial()partial() 函數允許你給一個或多個參數設置固定的值,減少接下來被調用時的參數個數。 為了演示清楚,假設你有下面這樣的函數:

def spam(a, b, c, d):
    print(a, b, c, d)

現在我們使用 partial() 函數來固定某些參數值:

>>> from
functools import partial >>> s1 = partial(spam, 1) # a = 1 >>> s1(2, 3, 4) 1 2 3 4 >>> s1(4, 5, 6) 1 4 5 6 >>> s2 = partial(spam, d=42) # d = 42 >>> s2(1, 2, 3) 1 2 3 42 >>> s2(4, 5, 5) 4 5 5 42 >>> s3 = partial(spam, 1, 2, d=42
) # a = 1, b = 2, d = 42 >>> s3(3) 1 2 3 42 >>> s3(4) 1 2 4 42 >>> s3(5) 1 2 5 42 >>>

可以看出 partial() 固定某些參數並返回一個新的callable對象。這個新的callable接受未賦值的參數, 然後跟之前已經賦值過的參數合並起來,最後將所有參數傳遞給原始函數。

討論

本節要解決的問題是讓原本不兼容的代碼可以一起工作。下面我會列舉一系列的例子。

第一個例子是,假設你有一個點的列表來表示(x,y)坐標元組。 你可以使用下面的函數來計算兩點之間的距離:

points = [ (1, 2), (3, 4), (5, 6), (7, 8) ]

import math
def distance(p1, p2):
    x1, y1 = p1
    x2, y2 = p2
    return math.hypot(x2 - x1, y2 - y1)

說明一下這裏的math.hypot默認是以坐標原點為幾點計算坐標到原點的直線距離

import math
print(math.hypot(6,8))
>>>10

現在假設你想以某個點為基點,根據點和基點之間的距離來排序所有的這些點。 列表的 sort() 方法接受一個關鍵字參數來自定義排序邏輯, 但是它只能接受一個單個參數的函數(distance()很明顯是不符合條件的)。 現在我們可以通過使用 partial() 來解決這個問題:

>>> pt = (4, 3)
>>> points.sort(key=partial(distance,pt))
>>> points
[(3, 4), (1, 2), (5, 6), (7, 8)]
>>>

更進一步,partial() 通常被用來微調其他庫函數所使用的回調函數的參數。 例如,下面是一段代碼,使用 multiprocessing 來異步計算一個結果值, 然後這個值被傳遞給一個接受一個result值和一個可選logging參數的回調函數:

def output_result(result, log=None):
    if log is not None:
        log.debug(Got: %r, result)

# A sample function
def add(x, y):
    return x + y

if __name__ == __main__:
    import logging
    from multiprocessing import Pool
    from functools import partial

    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger(test)

    p = Pool()
    p.apply_async(add, (3, 4), callback=partial(output_result, log=log))
    p.close()
    p.join()

當給 apply_async() 提供回調函數時,通過使用 partial() 傳遞額外的 logging 參數。 而 multiprocessing 對這些一無所知——它僅僅只是使用單個值來調用回調函數。

作為一個類似的例子,考慮下編寫網絡服務器的問題,socketserver 模塊讓它變得很容易。 下面是個簡單的echo服務器:

from socketserver import StreamRequestHandler, TCPServer

class EchoHandler(StreamRequestHandler):
    def handle(self):
        for line in self.rfile:
            self.wfile.write(bGOT: + line)

serv = TCPServer((‘‘, 15000), EchoHandler)
serv.serve_forever()

不過,假設你想給EchoHandler增加一個可以接受其他配置選項的 __init__ 方法。比如:

class EchoHandler(StreamRequestHandler):
    # ack is added keyword-only argument. *args, **kwargs are
    # any normal parameters supplied (which are passed on)
    def __init__(self, *args, ack, **kwargs):
        self.ack = ack
        super().__init__(*args, **kwargs)

    def handle(self):
        for line in self.rfile:
            self.wfile.write(self.ack + line)

這麽修改後,我們就不需要顯式地在TCPServer類中添加前綴了。 但是你再次運行程序後會報類似下面的錯誤:

Exception happened during processing of request from (127.0.0.1, 59834)
Traceback (most recent call last):
...
TypeError: __init__() missing 1 required keyword-only argument: ack

初看起來好像很難修正這個錯誤,除了修改 socketserver 模塊源代碼或者使用某些奇怪的方法之外。 但是,如果使用 partial() 就能很輕松的解決——給它傳遞 ack 參數的值來初始化即可,如下:

from functools import partial
serv = TCPServer((‘‘, 15000), partial(EchoHandler, ack=bRECEIVED:))
serv.serve_forever()

在這個例子中,__init__() 方法中的ack參數聲明方式看上去很有趣,其實就是聲明ack為一個強制關鍵字參數。 關於強制關鍵字參數問題我們在7.2小節我們已經討論過了,讀者可以再去回顧一下。

很多時候 partial() 能實現的效果,lambda表達式也能實現。比如,之前的幾個例子可以使用下面這樣的表達式:

points.sort(key=lambda p: distance(pt, p))
p.apply_async(add, (3, 4), callback=lambda result: output_result(result,log))
serv = TCPServer((‘‘, 15000),
        lambda *args, **kwargs: EchoHandler(*args, ack=bRECEIVED:, **kwargs))

這樣寫也能實現同樣的效果,不過相比而已會顯得比較臃腫,對於閱讀代碼的人來講也更加難懂。 這時候使用 partial() 可以更加直觀的表達你的意圖(給某些參數預先賦值)。

functools下的partial模塊應用