1. 程式人生 > >Akka-CQRS(16)- gRPC用JWT進行許可權管理

Akka-CQRS(16)- gRPC用JWT進行許可權管理

   前面談過gRPC的SSL/TLS安全機制,發現設定過程比較複雜:比如證書簽名:需要服務端、客戶端兩頭都設定等。想想實際上用JWT會更加便捷,而且更安全和功能強大,因為除JWT的加密簽名之外還可以把私密的使用者資訊放在JWT裡加密後在服務端和客戶端之間傳遞。當然,最基本的是通過對JWT的驗證機制可以控制客戶端對某些功能的使用許可權。

通過JWT實現gRPC的函式呼叫許可權管理原理其實很簡單:客戶端首先從服務端通過身份驗證獲取JWT,然後在呼叫服務函式時把這個JWT同時傳給服務端進行許可權驗證。客戶端提交身份驗證請求返回JWT可以用一個獨立的服務函式實現,如下面.proto檔案裡的GetAuthToken:

// Credentials a client submits to GetAuthToken.
message PBPOSCredential {
    string userid = 1;
    string password = 2;
}
// Reply of GetAuthToken; `jwt` carries the issued token as a string.
message PBPOSToken {
    string jwt = 1;
}

// RPC surface of the POS command service.
service SendCommand {
    // Unary call: one PBPOSCommand in, one PBPOSResponse out.
    rpc SingleResponse(PBPOSCommand) returns (PBPOSResponse) {};
    // Server-streaming call: one PBPOSCommand in, a stream of PBTxnItem out.
    rpc GetTxnItems(PBPOSCommand) returns (stream PBTxnItem) {};
    // Unary call: exchanges credentials for a token.
    rpc GetAuthToken(PBPOSCredential) returns (PBPOSToken) {};

}

比較棘手的是如何把JWT從客戶端傳送至服務端,因為gRPC基本上騎劫了Request和Response。其中一個方法是通過Interceptor來擷取Request的header即metadata。客戶端將JWT寫入metadata,服務端從metadata讀取JWT。

我們先看看客戶端的Interceptor設定和使用:

  /**
    * Client-side interceptor that adds the given token to the request metadata
    * of every call made through the intercepted channel.
    *
    * @param jwt token string sent under the ASCII metadata key "jwt"
    */
  class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
    def interceptCall[ReqT, RespT](
        methodDescriptor: MethodDescriptor[ReqT, RespT],
        callOptions: CallOptions,
        channel: io.grpc.Channel): ClientCall[ReqT, RespT] = {
      // the real call, created on the underlying channel
      val underlying = channel.newCall(methodDescriptor, callOptions)

      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](underlying) {
        // headers are still mutable in start(), so the token is added here
        override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
          val jwtHeader = Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
          headers.put(jwtHeader, jwt)
          super.start(responseListener, headers)
        }
      }
    }
  }

...


    // Plaintext (no TLS negotiation) channel to the server at 192.168.0.189:50051.
    val unsafeChannel = NettyChannelBuilder
      .forAddress("192.168.0.189",50051)
      .negotiationType(NegotiationType.PLAINTEXT)
      .build()

   // Wrap the channel so AuthClientInterceptor puts `jwt` into each call's metadata.
   // `jwt` is expected to be in scope already (obtained via getAuthToken).
   val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))

    // Blocking stub bound to the intercepted channel.
    val securedClient = SendCommandGrpc.blockingStub(securedChannel)

    // Unary call with a default-constructed command.
    val resp = securedClient.singleResponse(PBPOSCommand())

身份驗證請求即JWT獲取是不需要Interceptor的,所以要用沒有Interceptor的unsafeChannel: 

    // Build the connection channel: plaintext (no TLS negotiation) to 192.168.0.189:50051.
    // NOTE(review): the credentials below travel over this unencrypted channel.
    val unsafeChannel = NettyChannelBuilder
      .forAddress("192.168.0.189",50051)
      .negotiationType(NegotiationType.PLAINTEXT)
      .build()


    // Blocking stub on the raw channel (no interceptor): no token exists yet.
    val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
    // Exchange userid/password for a token and read the `jwt` field of the reply.
    val jwt = authClient.getAuthToken(PBPOSCredential(userid="johnny",password="p4ssw0rd")).jwt
    println(s"got jwt: $jwt")
 

JWT的構建和使用已經在前面的幾篇博文裡討論過了: 

package com.datatech.auth

import pdi.jwt._
import org.json4s.native.Json
import org.json4s._
import org.json4s.jackson.JsonMethods._
import pdi.jwt.algorithms._
import scala.util._

/** JWT helper: issues and verifies tokens that carry a "userinfo" claim. */
object AuthBase {
  /** Free-form user attributes stored in the token under the "userinfo" claim. */
  type UserInfo = Map[String, Any]

  /**
    * Immutable JWT configuration together with the issue/verify operations.
    *
    * @param algorithm   signing algorithm used for both issuing and verifying.
    *                    NOTE(review): the HMD5 default is weak; prefer HS256 or
    *                    stronger via [[withAlgorithm]].
    * @param secret      key string handed to the jwt library for signing/verifying
    * @param getUserInfo credential lookup `(userid, password) => user info`; it
    *                    defaults to `null`, so set it with [[withUserFunc]] before
    *                    invoking it
    */
  case class AuthBase(
                       algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
                       secret: String = "OpenSesame",
                       getUserInfo: (String,String) => Option[UserInfo] = null) {
    ctx =>

    /** Returns a copy that signs and verifies with `algo`. */
    def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)

    /** Returns a copy that uses `key` as the secret. */
    def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)

    /** Returns a copy that uses `f` as the credential lookup function. */
    def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)

    /**
      * Checks `token` against the configured secret and algorithm.
      *
      * @return `Some(token)` when the jwt library reports the token as valid,
      *         `None` otherwise
      */
    def authenticateToken(token: String): Option[String] = {
      // The library has separate overloads for asymmetric and HMAC algorithms,
      // so the matched (already narrowed) value is passed instead of a cast.
      val valid = algorithm match {
        case asym: JwtAsymmetricAlgorithm => Jwt.isValid(token, secret, Seq(asym))
        case hmac: JwtHmacAlgorithm       => Jwt.isValid(token, secret, Seq(hmac))
        case _                            => false
      }
      if (valid) Some(token) else None
    }

    /**
      * Verifies `token` and extracts its "userinfo" claim.
      *
      * @return the claim's fields when the token verifies and the claim is a
      *         JSON object; `None` when verification fails, the claim is not
      *         parseable JSON, or there is no "userinfo" object in it
      */
    def getUserInfo(token: String): Option[UserInfo] = {
      val decoded: Try[(String, String, String)] = algorithm match {
        case asym: JwtAsymmetricAlgorithm => Jwt.decodeRawAll(token, secret, Seq(asym))
        case hmac: JwtHmacAlgorithm       => Jwt.decodeRawAll(token, secret, Seq(hmac))
        case other =>
          Failure(new IllegalArgumentException(s"unsupported JWT algorithm: $other"))
      }
      // decodeRawAll yields (header, claim, signature); only the claim is needed
      decoded.toOption.flatMap { case (_, claim, _) => extractUserInfo(claim) }
    }

    /** Parses the claim JSON and returns the "userinfo" object's fields, if present. */
    private def extractUserInfo(claim: String): Option[UserInfo] =
      // Pattern match instead of asInstanceOf: a missing or non-object
      // "userinfo" yields None rather than a ClassCastException.
      Try(parse(claim) \ "userinfo").toOption.collect {
        case obj: JObject => obj.values
      }

    /** Builds a token whose claim contains `userinfo` under the "userinfo" key. */
    def issueJwt(userinfo: UserInfo): String = {
      val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
      Jwt.encode(claims, secret, algorithm)
    }
  }

}

服務端Interceptor的構建和設定如下: 

/**
  * A `ServerCall.Listener` that forwards each callback to a listener which
  * only becomes available asynchronously.
  *
  * Callbacks are forwarded with `Future.foreach`, so they run only if the
  * future completes successfully; on failure they are dropped.
  * NOTE(review): each callback is scheduled separately on `ec`; nothing here
  * enforces their relative order — confirm the execution context used.
  *
  * @param ec execution context on which the forwarded callbacks run
  */
abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {

  /** The listener to forward to, once available. */
  protected val delegate: Future[Listener[Q]]

  // A method, not a `val`: subclasses initialise `delegate` after this class's
  // constructor has run, so nothing here may touch `delegate` at construction time.
  private def eventually(f: Listener[Q] => Unit): Unit = delegate.foreach(f)

  override def onComplete(): Unit = eventually(_.onComplete())
  override def onCancel(): Unit = eventually(_.onCancel())
  override def onMessage(message: Q): Unit = eventually(_.onMessage(message))
  override def onHalfClose(): Unit = eventually(_.onHalfClose())
  override def onReady(): Unit = eventually(_.onReady())

}

/**
  * Keys under which the token travels: one for request metadata (headers),
  * one for the gRPC `Context`. Both use the name "jwt".
  */
object Keys {
  val AUTH_CTX_KEY: Context.Key[String] = Context.key("jwt")
  val AUTH_META_KEY: Metadata.Key[String] =
    Metadata.Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
}

/**
  * Server interceptor that copies the "jwt" request header into the gRPC
  * `Context`, where service methods can read it through `Keys.AUTH_CTX_KEY`.
  *
  * It does not validate the token; that is left to the service methods.
  *
  * @param ec kept for source compatibility with existing call sites; the
  *           interception itself is synchronous and does not use it
  */
class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
  override def interceptCall[Q, R](
                                    call: ServerCall[Q, R],
                                    headers: Metadata,
                                    next: ServerCallHandler[Q, R]
                                  ): Listener[Q] = {

    // null when the client sent no "jwt" header
    val jwt = headers.get(Keys.AUTH_META_KEY)

    // NOTE(review): debug output prints the raw token; remove before production use.
    println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")

    // Done synchronously: nothing here needs a Future, and forwarding listener
    // callbacks through one gives no ordering guarantee between them and
    // silently drops them if the Future fails.
    val nextCtx = Context.current.withValue(Keys.AUTH_CTX_KEY, jwt)
    // Contexts.interceptCall starts the call with nextCtx attached and returns
    // a listener that re-attaches it around every callback.
    Contexts.interceptCall(nextCtx, call, headers, next)
  }
}

/** Mix-in that boots a Netty gRPC server on port 50051 for one service definition. */
trait gRPCServer {

  /**
    * Starts the server with `service` wrapped by an [[AuthorizationInterceptor]]
    * and registers a JVM shutdown hook that stops it.
    *
    * @param service  the service definition to expose
    * @param actorSys actor system whose dispatcher is passed to the interceptor
    */
  def runServer(service: ServerServiceDefinition)(implicit actorSys: ActorSystem): Unit = {
    import actorSys.dispatcher

    // every call to the service passes through the interceptor first
    val guardedService = ServerInterceptors.intercept(service, new AuthorizationInterceptor)
    val server = NettyServerBuilder.forPort(50051).addService(guardedService).build.start

    // make sure our server is stopped when jvm is shut down
    val stopOnExit = new Thread() {
      override def run(): Unit = {
        server.shutdown()
        server.awaitTermination()
      }
    }
    Runtime.getRuntime.addShutdownHook(stopOnExit)
  }

}

注意:客戶端上傳的request-header只能在構建server時接觸到,在具體服務函式裡是無法呼叫request-header的,但gRPC有一個結構Context,在兩個地方都能呼叫。所以,我們可以在構建server時把JWT從header搬到Context裡。不過,千萬注意這個Context的讀寫必須在同一個執行緒裡。在服務端的Interceptor裡我們把JWT從metadata裡讀出然後寫入Context。在需要許可權管理的服務函式裡再從Context裡讀取JWT進行驗證: 

   /**
     * Reads the token from the gRPC Context, extracts its user info and replies
     * with the "shopid" entry; replies with "invalid token!" when no user info
     * can be extracted.
     */
   override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
      // NOTE(review): debug output prints the raw token.
      val jwt = AUTH_CTX_KEY.get
      println(s"***********$jwt**************")
      val shopid = authenticator
        .getUserInfo(jwt)
        .fold[Any]("invalid token!")(info => info("shopid"))
      val reply = PBPOSResponse(msg=s"shopid:$shopid")
      FastFuture.successful(reply)
   }

JWT的構建也是一個服務函式: 

   // JWT helper configured with HS256, the shared secret and the credential
   // lookup function; used to issue tokens and to read them back.
   val authenticator = new AuthBase()
      .withAlgorithm(JwtAlgorithm.HS256)
      .withSecretKey("OpenSesame")
      .withUserFunc(getValidUser)

    /**
      * Checks the submitted credentials and replies with a freshly issued token.
      * For unknown credentials the reply carries the literal "Invalid Token!"
      * in place of a token (the call itself still succeeds).
      */
    override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
      val tokenText = getValidUser(request.userid, request.password)
        .map(userinfo => authenticator.issueJwt(userinfo))
        .getOrElse("Invalid Token!")
      FastFuture.successful(PBPOSToken(tokenText))
    }

還需要一個模擬的身份驗證服務函式: 

package com.datatech.auth

/** In-memory stand-in for a user store, used to demonstrate credential checks. */
object MockUserAuthService {
  /** Attributes attached to a user. */
  type UserInfo = Map[String,Any]

  /** A known user: login name, password and the attributes returned on success. */
  case class User(username: String, password: String, userInfo: UserInfo)

  /** The fixed set of users this mock accepts. */
  val validUsers = Seq(
    User("johnny", "p4ssw0rd", Map("shopid" -> "1101", "userid" -> "101")),
    User("tiger", "secret", Map("shopid" -> "1101", "userid" -> "102"))
  )

  /**
    * Looks up a user by exact name and password.
    *
    * @return the user's attributes on a match, `None` otherwise
    */
  def getValidUser(userid: String, pswd: String): Option[UserInfo] =
    validUsers.collectFirst {
      case known if known.username == userid && known.password == pswd => known.userInfo
    }
}

下面是本次示範的原始碼:

project/plugins.sbt

// sbt plugins used by this build (each declared once).
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")
addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.3.15")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.21")
// ScalaPB code generator, put on the build classpath for sbt-protoc.
libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.9.0-M6"

build.sbt

name := "grpc-jwt"

version := "0.1"

scalaVersion := "2.12.8"

scalacOptions += "-Ypartial-unification"

val akkaversion = "2.5.23"

// `++=` appends to whatever is already configured (e.g. by plugins);
// `:=` would replace it.
// NOTE(review): json4s-native is 3.6.1 while json4s-jackson / json4s-ext are
// 3.6.7 — consider aligning the json4s versions.
libraryDependencies ++= Seq(
  "com.typesafe.akka" %% "akka-cluster-metrics" % akkaversion,
  "com.typesafe.akka" %% "akka-cluster-sharding" % akkaversion,
  "com.typesafe.akka" %% "akka-persistence" % akkaversion,
  "com.lightbend.akka" %% "akka-stream-alpakka-cassandra" % "1.0.1",
  "org.mongodb.scala" %% "mongo-scala-driver" % "2.6.0",
  "com.lightbend.akka" %% "akka-stream-alpakka-mongodb" % "1.0.1",
  "com.typesafe.akka" %% "akka-persistence-query" % akkaversion,
  "com.typesafe.akka" %% "akka-persistence-cassandra" % "0.97",
  "com.datastax.cassandra" % "cassandra-driver-core" % "3.6.0",
  "com.datastax.cassandra" % "cassandra-driver-extras" % "3.6.0",
  "ch.qos.logback"  %  "logback-classic"   % "1.2.3",
  "io.monix" %% "monix" % "3.0.0-RC2",
  "org.typelevel" %% "cats-core" % "2.0.0-M1",
  "io.grpc" % "grpc-netty" % scalapb.compiler.Version.grpcJavaVersion,
  "io.netty" % "netty-tcnative-boringssl-static" % "2.0.22.Final",
  "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",
  "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion,

  "com.pauldijou" %% "jwt-core" % "3.0.1",
  "de.heikoseeberger" %% "akka-http-json4s" % "1.22.0",
  "org.json4s" %% "json4s-native" % "3.6.1",
  "com.typesafe.akka" %% "akka-http-spray-json" % "10.1.8",
  "org.json4s" %% "json4s-jackson" % "3.6.7",
  "org.json4s" %% "json4s-ext" % "3.6.7"

)

// (optional) If you need scalapb/scalapb.proto or anything from
// google/protobuf/*.proto
//libraryDependencies += "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf"


// Generate Scala sources from the .proto files into the managed source directory.
PB.targets in Compile := Seq(
  scalapb.gen() -> (sourceManaged in Compile).value
)

enablePlugins(JavaAppPackaging)

main/protobuf/posmessages.proto

syntax = "proto3";

import "google/protobuf/wrappers.proto";
import "google/protobuf/any.proto";
import "scalapb/scalapb.proto";

// File-level ScalaPB generator options for this .proto file.
option (scalapb.options) = {
  // use a custom Scala package name
  // package_name: "io.ontherocks.introgrpc.demo"

  // don't append file name to package
  flat_package: true

  // generate one Scala file for all messages (services still get their own file)
  single_file: true

  // add imports to generated file
  // useful when extending traits or using custom types
  // import: "io.ontherocks.hellogrpc.RockingMessage"

  // code to put at the top of generated file
  // works only with `single_file: true`
  //preamble: "sealed trait SomeSealedTrait"
};

package com.datatech.pos.messages;

message PBVchState {      // voucher state
    string opr  = 1;    // cashier
    int64  jseq = 2;    //begin journal sequence for read-side replay
    int32  num  = 3;    // current voucher number
    int32  seq  = 4;    // current sequence number
    bool   void = 5;    // void mode
    bool   refd = 6;    // refund mode
    bool   susp = 7;    // voucher suspended
    bool   canc = 8;    // voucher cancelled
    bool   due  = 9;    // current balance -- NOTE(review): declared bool although described as a balance; confirm intent
    string su   = 10;   // supervisor id
    string mbr  = 11;   // member number
    int32  mode = 12;   // current operation flow: 0=logOff, 1=LogOn, 2=Payment
}

message PBTxnItem {       // transaction record
    string txndate    = 1;   // transaction date
    string txntime    = 2;   // entry time
    string opr        = 3;   // operator
    int32  num        = 4;   // sales voucher number
    int32  seq        = 5;   // transaction sequence number
    int32  txntype    = 6;   // transaction type
    int32  salestype  = 7;   // sales type
    int32  qty        = 8;   // transaction quantity
    int32  price      = 9;   // unit price (in fen, 1/100 of the currency unit)
    int32  amount     = 10;  // list-price amount (in fen)
    int32  disc       = 11;  // discount rate (%)
    int32  dscamt     = 12;  // discount amount, a negative value: net = amount + dscamt
    string member     = 13;  // member card number
    string code       = 14;  // code (product, card number...)
    string acct       = 15;  // account number
    string dpt        = 16;  // department / category
}

// Reply to a POS command.
message PBPOSResponse {
    int32  sts                  = 1;  // presumably a status code -- TODO confirm
    string msg                  = 2;  // message text
    PBVchState voucher          = 3;  // voucher state attached to the reply
    repeated PBTxnItem txnitems   = 4;  // transaction items attached to the reply

}

// A POS command request.
message PBPOSCommand {
    string commandname = 1;      // name of the command
    string delimitedparams = 2;  // command parameters packed into one string; delimiter not shown here
}

// Credentials a client submits to GetAuthToken.
message PBPOSCredential {
    string userid = 1;
    string password = 2;
}
// Reply of GetAuthToken; `jwt` carries the issued token as a string.
message PBPOSToken {
    string jwt = 1;
}

// RPC surface of the POS command service.
service SendCommand {
    // Unary call: one PBPOSCommand in, one PBPOSResponse out.
    rpc SingleResponse(PBPOSCommand) returns (PBPOSResponse) {};
    // Server-streaming call: one PBPOSCommand in, a stream of PBTxnItem out.
    rpc GetTxnItems(PBPOSCommand) returns (stream PBTxnItem) {};
    // Unary call: exchanges credentials for a token.
    rpc GetAuthToken(PBPOSCredential) returns (PBPOSToken) {};

}

gRPCServer.scala

package com.datatech.grpc.server

import io.grpc.ServerServiceDefinition
import io.grpc.netty.NettyServerBuilder
import io.grpc.ServerInterceptors
import scala.concurrent._
import io.grpc.Context
import io.grpc.Contexts
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import io.grpc.Metadata
import io.grpc.Metadata.Key.of
import io.grpc.Context.key
import io.grpc.ServerCall.Listener
import akka.actor._


/**
  * A `ServerCall.Listener` that forwards each callback to a listener which
  * only becomes available asynchronously.
  *
  * Callbacks are forwarded with `Future.foreach`, so they run only if the
  * future completes successfully; on failure they are dropped.
  * NOTE(review): each callback is scheduled separately on `ec`; nothing here
  * enforces their relative order — confirm the execution context used.
  *
  * @param ec execution context on which the forwarded callbacks run
  */
abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {

  /** The listener to forward to, once available. */
  protected val delegate: Future[Listener[Q]]

  // A method, not a `val`: subclasses initialise `delegate` after this class's
  // constructor has run, so nothing here may touch `delegate` at construction time.
  private def eventually(f: Listener[Q] => Unit): Unit = delegate.foreach(f)

  override def onComplete(): Unit = eventually(_.onComplete())
  override def onCancel(): Unit = eventually(_.onCancel())
  override def onMessage(message: Q): Unit = eventually(_.onMessage(message))
  override def onHalfClose(): Unit = eventually(_.onHalfClose())
  override def onReady(): Unit = eventually(_.onReady())

}

/**
  * Keys under which the token travels: one for request metadata (headers),
  * one for the gRPC `Context`. Both use the name "jwt".
  */
object Keys {
  val AUTH_CTX_KEY: Context.Key[String] = Context.key("jwt")
  val AUTH_META_KEY: Metadata.Key[String] =
    Metadata.Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
}

/**
  * Server interceptor that copies the "jwt" request header into the gRPC
  * `Context`, where service methods can read it through `Keys.AUTH_CTX_KEY`.
  *
  * It does not validate the token; that is left to the service methods.
  *
  * @param ec kept for source compatibility with existing call sites; the
  *           interception itself is synchronous and does not use it
  */
class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
  override def interceptCall[Q, R](
                                    call: ServerCall[Q, R],
                                    headers: Metadata,
                                    next: ServerCallHandler[Q, R]
                                  ): Listener[Q] = {

    // null when the client sent no "jwt" header
    val jwt = headers.get(Keys.AUTH_META_KEY)

    // NOTE(review): debug output prints the raw token; remove before production use.
    println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")

    // Done synchronously: nothing here needs a Future, and forwarding listener
    // callbacks through one gives no ordering guarantee between them and
    // silently drops them if the Future fails.
    val nextCtx = Context.current.withValue(Keys.AUTH_CTX_KEY, jwt)
    // Contexts.interceptCall starts the call with nextCtx attached and returns
    // a listener that re-attaches it around every callback.
    Contexts.interceptCall(nextCtx, call, headers, next)
  }
}

/** Mix-in that boots a Netty gRPC server on port 50051 for one service definition. */
trait gRPCServer {

  /**
    * Starts the server with `service` wrapped by an [[AuthorizationInterceptor]]
    * and registers a JVM shutdown hook that stops it.
    *
    * @param service  the service definition to expose
    * @param actorSys actor system whose dispatcher is passed to the interceptor
    */
  def runServer(service: ServerServiceDefinition)(implicit actorSys: ActorSystem): Unit = {
    import actorSys.dispatcher

    // every call to the service passes through the interceptor first
    val guardedService = ServerInterceptors.intercept(service, new AuthorizationInterceptor)
    val server = NettyServerBuilder.forPort(50051).addService(guardedService).build.start

    // make sure our server is stopped when jvm is shut down
    val stopOnExit = new Thread() {
      override def run(): Unit = {
        server.shutdown()
        server.awaitTermination()
      }
    }
    Runtime.getRuntime.addShutdownHook(stopOnExit)
  }

}

POSServices.scala

package com.datatech.pos.service
import com.datatech.grpc.server.Keys._
import akka.http.scaladsl.util.FastFuture
import com.datatech.pos.messages._
import com.datatech.grpc.server._
import com.datatech.auth.MockUserAuthService._

import scala.concurrent.Future
import com.datatech.auth.AuthBase._
import pdi.jwt._
import akka.actor._
import io.grpc.stub.StreamObserver


/** Entry point and service implementation for the demo POS gRPC server. */
object POSServices extends gRPCServer {
  type UserInfo = Map[String, Any]

  /** Implementation of the `SendCommand` service generated from the .proto file. */
  class POSServices extends SendCommandGrpc.SendCommand {

    // JWT helper configured with HS256, the shared secret and the credential lookup.
    val authenticator = new AuthBase()
      .withAlgorithm(JwtAlgorithm.HS256)
      .withSecretKey("OpenSesame")
      .withUserFunc(getValidUser)

    /** Not implemented: invoking it throws `NotImplementedError` (`???`). */
    override def getTxnItems(request: PBPOSCommand, responseObserver: StreamObserver[PBTxnItem]): Unit = ???

    /**
      * Reads the token from the gRPC Context, extracts its user info and replies
      * with the "shopid" entry; replies with "invalid token!" when no user info
      * can be extracted.
      */
    override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
      // NOTE(review): debug output prints the raw token.
      val jwt = AUTH_CTX_KEY.get
      println(s"***********$jwt**************")
      val shopid = authenticator
        .getUserInfo(jwt)
        .fold[Any]("invalid token!")(info => info("shopid"))
      val reply = PBPOSResponse(msg=s"shopid:$shopid")
      FastFuture.successful(reply)
    }

    /**
      * Checks the submitted credentials and replies with a freshly issued token.
      * For unknown credentials the reply carries the literal "Invalid Token!"
      * in place of a token (the call itself still succeeds).
      */
    override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
      val tokenText = getValidUser(request.userid, request.password)
        .map(userinfo => authenticator.issueJwt(userinfo))
        .getOrElse("Invalid Token!")
      FastFuture.successful(PBPOSToken(tokenText))
    }
  }

  /** Creates the actor system, binds the service and starts the server. */
  def main(args: Array[String]): Unit = {
    implicit val system: ActorSystem = ActorSystem("grpc-system")
    val boundService = SendCommandGrpc.bindService(new POSServices, system.dispatcher)
    runServer(boundService)
  }
}

AuthBase.scala

package com.datatech.auth

import pdi.jwt._
import org.json4s.native.Json
import org.json4s._
import org.json4s.jackson.JsonMethods._
import pdi.jwt.algorithms._
import scala.util._

/** JWT helper: issues and verifies tokens that carry a "userinfo" claim. */
object AuthBase {
  /** Free-form user attributes stored in the token under the "userinfo" claim. */
  type UserInfo = Map[String, Any]

  /**
    * Immutable JWT configuration together with the issue/verify operations.
    *
    * @param algorithm   signing algorithm used for both issuing and verifying.
    *                    NOTE(review): the HMD5 default is weak; prefer HS256 or
    *                    stronger via [[withAlgorithm]].
    * @param secret      key string handed to the jwt library for signing/verifying
    * @param getUserInfo credential lookup `(userid, password) => user info`; it
    *                    defaults to `null`, so set it with [[withUserFunc]] before
    *                    invoking it
    */
  case class AuthBase(
                       algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
                       secret: String = "OpenSesame",
                       getUserInfo: (String,String) => Option[UserInfo] = null) {
    ctx =>

    /** Returns a copy that signs and verifies with `algo`. */
    def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)

    /** Returns a copy that uses `key` as the secret. */
    def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)

    /** Returns a copy that uses `f` as the credential lookup function. */
    def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)

    /**
      * Checks `token` against the configured secret and algorithm.
      *
      * @return `Some(token)` when the jwt library reports the token as valid,
      *         `None` otherwise
      */
    def authenticateToken(token: String): Option[String] = {
      // The library has separate overloads for asymmetric and HMAC algorithms,
      // so the matched (already narrowed) value is passed instead of a cast.
      val valid = algorithm match {
        case asym: JwtAsymmetricAlgorithm => Jwt.isValid(token, secret, Seq(asym))
        case hmac: JwtHmacAlgorithm       => Jwt.isValid(token, secret, Seq(hmac))
        case _                            => false
      }
      if (valid) Some(token) else None
    }

    /**
      * Verifies `token` and extracts its "userinfo" claim.
      *
      * @return the claim's fields when the token verifies and the claim is a
      *         JSON object; `None` when verification fails, the claim is not
      *         parseable JSON, or there is no "userinfo" object in it
      */
    def getUserInfo(token: String): Option[UserInfo] = {
      val decoded: Try[(String, String, String)] = algorithm match {
        case asym: JwtAsymmetricAlgorithm => Jwt.decodeRawAll(token, secret, Seq(asym))
        case hmac: JwtHmacAlgorithm       => Jwt.decodeRawAll(token, secret, Seq(hmac))
        case other =>
          Failure(new IllegalArgumentException(s"unsupported JWT algorithm: $other"))
      }
      // decodeRawAll yields (header, claim, signature); only the claim is needed
      decoded.toOption.flatMap { case (_, claim, _) => extractUserInfo(claim) }
    }

    /** Parses the claim JSON and returns the "userinfo" object's fields, if present. */
    private def extractUserInfo(claim: String): Option[UserInfo] =
      // Pattern match instead of asInstanceOf: a missing or non-object
      // "userinfo" yields None rather than a ClassCastException.
      Try(parse(claim) \ "userinfo").toOption.collect {
        case obj: JObject => obj.values
      }

    /** Builds a token whose claim contains `userinfo` under the "userinfo" key. */
    def issueJwt(userinfo: UserInfo): String = {
      val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
      Jwt.encode(claims, secret, algorithm)
    }
  }

}

POSClient.scala

package com.datatech.pos.client

import com.datatech.pos.messages.{PBPOSCommand, PBPOSCredential, SendCommandGrpc}
import io.grpc.stub.StreamObserver
import io.grpc.netty.{ NegotiationType, NettyChannelBuilder}
import io.grpc.CallOptions
import io.grpc.ClientCall
import io.grpc.ClientInterceptor
import io.grpc.ForwardingClientCall
import io.grpc.Metadata
import io.grpc.Metadata.Key
import io.grpc.MethodDescriptor
import io.grpc.ClientInterceptors

/** Demo client: obtains a token, then calls a protected RPC with it. */
object POSClient {
  /**
    * Client-side interceptor that adds the given token to the request metadata
    * of every call made through the intercepted channel.
    *
    * @param jwt token string sent under the ASCII metadata key "jwt"
    */
  class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
    def interceptCall[ReqT, RespT](methodDescriptor: MethodDescriptor[ReqT, RespT], callOptions: CallOptions, channel: io.grpc.Channel): ClientCall[ReqT, RespT] =
      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](channel.newCall(methodDescriptor, callOptions)) {
        // headers are still mutable in start(), so the token is added here
        override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
          headers.put(Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER), jwt)
          super.start(responseListener, headers)
        }
      }
  }

  /**
    * Requests a token over the plain channel, repeats a call through the
    * intercepted channel, prints both results and waits for a line on stdin.
    */
  def main(args: Array[String]): Unit = {

    //build connection channel
    // NOTE(review): PLAINTEXT means credentials and token travel unencrypted.
    val unsafeChannel = NettyChannelBuilder
      .forAddress("192.168.0.189",50051)
      .negotiationType(NegotiationType.PLAINTEXT)
      .build()

    try {
      // the token request itself needs no interceptor
      val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
      val jwt = authClient.getAuthToken(PBPOSCredential(userid="johnny",password="p4ssw0rd")).jwt
      println(s"got jwt: $jwt")

      // same channel, now with the token added to every call
      val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))

      val securedClient = SendCommandGrpc.blockingStub(securedChannel)

      val resp = securedClient.singleResponse(PBPOSCommand())

      println(s"secured response: $resp")

      // wait for async execution
      scala.io.StdIn.readLine()
    } finally {
      // release the channel's resources even if one of the calls above throws
      unsafeChannel.shutdown()
    }
  }


}

&n